代码分析

2020-05-26  本文已影响0人  16在这儿
// Pool is the connection-pool abstraction: Get returns a net.Conn for the
// given network and address, honoring ctx.
type Pool interface {
	Get(ctx context.Context, network, address string) (net.Conn, error)
}
package connpool

type Options struct{
    initialCap int
    maxCap int
    idleTimeout time.Duration
    maxIdle time.Duration
    dialTimeout time.Duration
}

type Option func(*Options)

With... (the WithXxx functional-option helpers are omitted here)

type pool{
    opts *Options
    conns *sync.Map
}

func NewConnPool(opt ...Option) *pool{
    //default
    opts := &Options {
        maxCap: 1000,
        idleTimeout: 1 * time.Minute,
        dialTimeout: 200 * time.Millisecond,
    }
    
    m := &sync.Map{}
    
    p := &pool{
        opts:opts,
        conns:m,
    }
       
    // 选项模式
    for _,o := opt{
        o(p.opts)
    }
    
    return p
}

// Get returns a connection to address, creating the per-address sub-pool on
// first use.
//
// Sub-pools are keyed by address only. NOTE(review): the network of the first
// call for an address is the one that sub-pool dials with — callers should not
// mix networks for one address.
func (p *pool) Get(ctx context.Context, network string, address string) (net.Conn, error) {
	// Fast path: an existing sub-pool for this address.
	if value, ok := p.conns.Load(address); ok {
		if cp, ok := value.(*channelPool); ok {
			return cp.Get(ctx)
		}
	}

	// Slow path: build a sub-pool. Concurrent first calls for the same address
	// may each get here, so publish with LoadOrStore and keep the winner.
	cp, err := p.NewChannelPool(ctx, network, address)
	if err != nil {
		return nil, err
	}

	actual, loaded := p.conns.LoadOrStore(address, cp)
	if loaded {
		if winner, ok := actual.(*channelPool); ok {
			// Lost the race: release our connections and use the published pool.
			cp.Close()
			return winner.Get(ctx)
		}
		// Unexpected value type under this key; overwrite it as before.
		p.conns.Store(address, cp)
	}
	return cp.Get(ctx)
}

// NewChannelPool builds the sub-pool for network/address. It eagerly dials
// initialCap connections (at least one) into it and starts the periodic
// health checker. ctx bounds the initial dials; if any dial fails, the
// connections dialed so far are closed and the error is returned.
func (p *pool) NewChannelPool(ctx context.Context, network string, address string) (*channelPool, error) {
	// Normalize into a local so the shared Options are not mutated.
	initialCap := p.opts.initialCap
	if initialCap <= 0 {
		initialCap = 1
	}

	c := &channelPool{
		initialCap: initialCap,
		maxCap:     p.opts.maxCap,
		// Dial opens a fresh raw connection to address. It fails fast if ctx
		// is already done; if ctx carries a deadline, the time remaining until
		// it replaces the configured dialTimeout.
		Dial: func(ctx context.Context) (net.Conn, error) {
			select {
			case <-ctx.Done():
				return nil, ctx.Err()
			default:
			}

			timeout := p.opts.dialTimeout
			if deadline, ok := ctx.Deadline(); ok {
				timeout = time.Until(deadline)
			}
			return net.DialTimeout(network, address, timeout)
		},
		conns: make(chan *PoolConn, p.opts.maxCap),
		// Read by Checker when deciding whether to drop a pooled connection.
		idleTimeout: p.opts.idleTimeout,
		// Copied into each PoolConn by wrapConn.
		dialTimeout: p.opts.dialTimeout,
	}

	// Pre-fill the pool.
	for i := 0; i < initialCap; i++ {
		conn, err := c.Dial(ctx)
		if err != nil {
			// Close the pool so connections dialed in earlier iterations are
			// released instead of leaked.
			c.Close()
			return nil, err
		}
		// Put closes the connection itself if the pool cannot hold it, so its
		// error needs no further handling here.
		_ = c.Put(c.wrapConn(conn))
	}

	c.RegisterChecker(3*time.Second, c.Checker)
	return c, nil
}

var (
	// poolMap is the registry of named pools consulted by GetPool.
	poolMap = make(map[string]Pool)

	// oneByte is a one-byte scratch buffer.
	oneByte = make([]byte, 1)
)

// registerPool stores pool under name, replacing any earlier entry.
// NOTE(review): poolMap has no lock; registration should happen before any
// concurrent GetPool calls.
func registerPool(name string, pool Pool) {
	poolMap[name] = pool
}

// GetPool returns the pool registered under name, or DefaultPool when no
// entry exists.
func GetPool(name string) Pool {
	found, ok := poolMap[name]
	if !ok {
		return DefaultPool
	}
	return found
}
// channelPool keeps idle connections to a single address in a buffered
// channel and dials new ones on demand through Dial.
type channelPool struct {
	// Embedded but never assigned by the code shown here.
	net.Conn

	// initialCap is the number of connections dialed when the pool is built.
	initialCap int

	// maxCap is the capacity of the idle-connection channel.
	maxCap int

	// maxIdle is intended as a max idle-connection count; it is not set by
	// the pool constructor shown here.
	maxIdle int

	// idleTimeout: Checker rejects a pooled connection once this much time
	// has passed since its timestamp.
	idleTimeout time.Duration

	// dialTimeout is copied into each PoolConn by wrapConn.
	dialTimeout time.Duration

	// Dial opens a new raw connection; Close sets it to nil.
	Dial func(context.Context) (net.Conn, error)

	// conns buffers idle connections; Close sets it to nil.
	conns chan *PoolConn

	// mu guards conns and Dial against Close.
	mu sync.RWMutex
}

// Get returns an idle pooled connection if one is queued, otherwise it dials a
// new one. Queued connections that were marked unusable are closed and
// skipped rather than failing the call. Get returns ErrConnClosed once the
// pool has been closed.
func (c *channelPool) Get(ctx context.Context) (net.Conn, error) {
	// Snapshot under the lock: Close swaps both fields to nil concurrently.
	c.mu.RLock()
	conns, dial := c.conns, c.Dial
	c.mu.RUnlock()

	if conns == nil || dial == nil {
		return nil, ErrConnClosed
	}

	for {
		select {
		case pc, ok := <-conns:
			if !ok || pc == nil {
				// The channel was closed by a concurrent Close.
				return nil, ErrConnClosed
			}

			pc.mu.RLock()
			unusable := pc.unusable
			pc.mu.RUnlock()
			if unusable {
				// Stale entry: release it and try the next idle connection.
				if pc.Conn != nil {
					_ = pc.Conn.Close()
				}
				continue
			}
			return pc, nil
		default:
			// No idle connection available: dial a fresh one.
			conn, err := dial(ctx)
			if err != nil {
				return nil, err
			}
			return c.wrapConn(conn), nil
		}
	}
}

// Close shuts the pool down. It detaches the idle channel and Dial under the
// lock, so later Get calls fail with ErrConnClosed. It then closes the
// channel and marks unusable and closes every connection still buffered in
// it. A second call finds the channel already detached and returns at once.
func (c *channelPool) Close() {
	var idle chan *PoolConn
	func() {
		c.mu.Lock()
		defer c.mu.Unlock()
		idle = c.conns
		c.conns, c.Dial = nil, nil
	}()

	if idle == nil {
		return
	}

	close(idle)
	for pc := range idle {
		pc.MarkUnusable()
		pc.Close()
	}
}

// Put returns conn to the idle queue. If the pool has been closed, or its idle
// channel is already full, the connection is closed instead. A nil conn is
// rejected with an error.
func (c *channelPool) Put(conn *PoolConn) error {
	if conn == nil {
		return errors.New("connection closed")
	}

	c.mu.RLock()
	defer c.mu.RUnlock()

	// Discarded connections are closed directly on the wrapped net.Conn.
	// Going through PoolConn.Close instead would call back into Put, which
	// recurses without bound while the pool is full. It can also deadlock on
	// conn.mu, which PoolConn.Close may hold while calling Put.
	discard := func() error {
		if conn.Conn == nil {
			return nil
		}
		return conn.Conn.Close()
	}

	if c.conns == nil {
		// Pool already closed; stop here instead of falling through.
		return discard()
	}

	// Stamp the moment the connection became idle, so Checker's idleTimeout
	// measures idle time rather than total connection age.
	conn.t = time.Now()

	select {
	case c.conns <- conn:
		return nil
	default:
		// Idle channel is full.
		return discard()
	}
}

// RegisterChecker starts a background goroutine that runs every interval.
// On each run it takes every connection currently idle in the pool and passes
// it to checker. A connection is re-queued if checker returns true, and marked
// unusable and closed if it returns false. RegisterChecker does nothing if
// interval <= 0 or checker is nil. The goroutine exits once the pool has been
// closed.
func (c *channelPool) RegisterChecker(interval time.Duration, checker func(conn *PoolConn) bool) {
	if interval <= 0 || checker == nil {
		return
	}

	go func() {
		ticker := time.NewTicker(interval)
		defer ticker.Stop()

		for range ticker.C {
			// Read conns under the lock: Close sets it to nil concurrently.
			c.mu.RLock()
			conns := c.conns
			c.mu.RUnlock()

			if conns == nil {
				return // pool closed; stop checking
			}

			// Bound the pass by the current queue length so connections
			// re-queued below are not examined twice in one pass.
			for i, n := 0, len(conns); i < n; i++ {
				var pc *PoolConn
				select {
				case got, ok := <-conns:
					if !ok {
						return // channel closed by Close mid-pass
					}
					pc = got
				default:
					// Queue drained by concurrent Get calls.
				}
				if pc == nil {
					break
				}

				if checker(pc) {
					// Put closes pc itself if it cannot be queued.
					_ = c.Put(pc)
					continue
				}
				pc.MarkUnusable()
				_ = pc.Close()
			}
		}
	}()
}

// Checker is the default health check used by RegisterChecker. It returns
// false once more than idleTimeout has elapsed since pc.t. Otherwise it
// returns the result of probing the underlying connection with isConnAlive.
func (c *channelPool) Checker(pc *PoolConn) bool {
	if expired := time.Since(pc.t) > c.idleTimeout; expired {
		return false
	}
	return isConnAlive(pc.Conn)
}

// isConnAlive probes conn with a read under a 1ms deadline. It reports false
// in three cases: the deadline cannot be set; the read returns any data (that
// byte has now been consumed); or the read fails with anything other than a
// timeout (io.EOF, connection reset, use of a closed connection). Otherwise
// it clears the read deadline and reports whether that succeeded.
func isConnAlive(conn net.Conn) bool {
	if err := conn.SetReadDeadline(time.Now().Add(time.Millisecond)); err != nil {
		return false
	}

	// A per-call buffer: each pool's checker goroutine may probe concurrently,
	// so a shared package-level scratch slice would be a data race.
	var probe [1]byte
	n, err := conn.Read(probe[:])
	if n > 0 {
		return false
	}
	if err != nil {
		// Only a timeout means "nothing to read yet"; any other error means
		// the connection is unusable.
		if ne, ok := err.(net.Error); !ok || !ne.Timeout() {
			return false
		}
	}

	return conn.SetReadDeadline(time.Time{}) == nil
}
package connpool

import (
    "errors"
    "net"
    "sync"
    "time"
)

// ErrConnClosed is returned when a pool has been closed or a pooled
// connection has been marked unusable.
var ErrConnClosed = errors.New("connection closed ...")

// PoolConn wraps a net.Conn handed out by a channelPool. Its Close method hands
// the connection back to the pool instead of closing it, unless the
// connection has been marked unusable.
type PoolConn struct {
	// Conn is the wrapped network connection.
	net.Conn

	// c is the owning pool; Close hands the connection back to it.
	c *channelPool

	// unusable is set by MarkUnusable; when true, Close really closes Conn.
	unusable bool

	// mu is taken by MarkUnusable (write) and Close (read) around unusable.
	mu sync.RWMutex

	// t is the reference time for idle expiry: Checker rejects the connection
	// once idleTimeout has elapsed since t.
	t time.Time

	// dialTimeout is copied from the pool by wrapConn; not read in the code
	// shown here.
	dialTimeout time.Duration
}

// Close overrides net.Conn.Close so connections can be reused. If the
// connection was marked unusable, the underlying net.Conn is closed.
// Otherwise its deadlines are cleared and it is handed back to its pool via
// Put. If the deadline cannot be reset, the connection is treated as broken
// and closed instead of pooled.
func (p *PoolConn) Close() error {
	// Snapshot the flag and release the lock before doing any work. Put (and
	// MarkUnusable below) must not run while p.mu is held, or they deadlock.
	p.mu.RLock()
	unusable := p.unusable
	p.mu.RUnlock()

	if p.Conn == nil {
		// Nothing underneath to close or reuse.
		return nil
	}
	if unusable {
		return p.Conn.Close()
	}

	// Clear any deadline the previous user set before the connection is reused.
	if err := p.Conn.SetDeadline(time.Time{}); err != nil {
		p.MarkUnusable()
		return p.Conn.Close()
	}

	return p.c.Put(p)
}

// MarkUnusable flags the connection so that Close releases the underlying
// net.Conn instead of returning it to the pool. The write is done while
// holding p.mu.
func (p *PoolConn) MarkUnusable() {
	p.mu.Lock()
	defer p.mu.Unlock()
	p.unusable = true
}

// Read reads from the underlying connection. It fails with ErrConnClosed if
// the connection has been marked unusable. On any read error (including a
// deadline timeout), the connection is marked unusable and its underlying
// net.Conn is closed, so it is never handed out again.
func (p *PoolConn) Read(b []byte) (int, error) {
	// Read the flag under the lock; MarkUnusable writes it concurrently.
	p.mu.RLock()
	unusable := p.unusable
	p.mu.RUnlock()
	if unusable {
		return 0, ErrConnClosed
	}

	n, err := p.Conn.Read(b)
	if err != nil {
		p.MarkUnusable()
		_ = p.Conn.Close()
	}
	return n, err
}

// Write writes to the underlying connection. It fails with ErrConnClosed if
// the connection has been marked unusable. On any write error, the connection
// is marked unusable and its underlying net.Conn is closed, so it is never
// handed out again.
func (p *PoolConn) Write(b []byte) (int, error) {
	// Read the flag under the lock; MarkUnusable writes it concurrently.
	p.mu.RLock()
	unusable := p.unusable
	p.mu.RUnlock()
	if unusable {
		return 0, ErrConnClosed
	}

	n, err := p.Conn.Write(b)
	if err != nil {
		p.MarkUnusable()
		_ = p.Conn.Close()
	}
	return n, err
}

// wrapConn wraps a raw net.Conn in a PoolConn owned by c. The wrapper is
// stamped with the current time and the pool's dial timeout.
func (c *channelPool) wrapConn(conn net.Conn) *PoolConn {
	return &PoolConn{
		Conn:        conn,
		c:           c,
		t:           time.Now(),
		dialTimeout: c.dialTimeout,
	}
}
上一篇下一篇

猜你喜欢

热点阅读