用 Go 语言在 200 行内实现连接池

2019-08-09  本文已影响0人  成功的失败者
package main

import (
    "errors"
    "fmt"
    "net"
    "os"
    "sync"
    "time"
)

// Config describes how NewPool builds a Pool.
type Config struct {
	// InitCap is the number of connections NewPool opens up front.
	InitCap int
	// MaxCap is the capacity of the idle-connection buffer.
	MaxCap int
	// InitConn opens a new connection.
	InitConn func() (interface{}, error)
	// Close disposes of a connection.
	Close func(interface{}) error
	// Validate, if non-nil, is called on connections entering or leaving the
	// pool; a non-nil error causes the connection to be closed and dropped.
	Validate func(interface{}) error
	// Timeout, if positive, is how long a connection may sit in the pool
	// (measured from when it was added) before Get discards it.
	Timeout time.Duration
}

// Pool holds idle connections in a buffered channel.
type Pool struct {
	// mu is held while the pool mutates poolMap and its own fields.
	mu sync.Mutex
	// conns buffers the idle connections; it is nil once the pool is released.
	conns chan *conn
	// poolMap records which connections are currently idle in conns so the
	// same connection cannot be added twice.
	poolMap map[interface{}]int
	// initConn, close, validate and timeout are copied from Config.
	initConn func() (interface{}, error)
	close    func(interface{}) error
	validate func(interface{}) error
	timeout  time.Duration
}

// conn wraps an idle connection together with the time it entered the pool.
type conn struct {
	connect    interface{}
	createTime time.Time
}

// ErrClosed is returned by Get once the pool has been released.
var ErrClosed = errors.New("pool is closed")

// ErrInvalidCap is returned by NewPool for inconsistent InitCap/MaxCap values.
var ErrInvalidCap = errors.New("invalid cap settings")

// ErrInvalidInit is returned when no InitConn function is available, or when
// NewPool cannot open its initial connections.
var ErrInvalidInit = errors.New("invalid init func settings")

// ErrInvalidClose is returned by NewPool when Config.Close is nil.
var ErrInvalidClose = errors.New("invalid close func settings")

// ErrInvalidConn is returned by Put for a nil or invalid connection.
var ErrInvalidConn = errors.New("invalid conn")

// addr is the address the demo server listens on and the demo client dials.
const addr = "127.0.0.1:8230"

// NewPool validates config and returns a Pool pre-filled with config.InitCap
// connections. config must be non-nil.
//
// It returns ErrInvalidCap when InitCap < 0, MaxCap <= 0 or InitCap > MaxCap;
// ErrInvalidInit when InitConn is nil or any initial connection fails to
// open; and ErrInvalidClose when Close is nil. When an initial connection
// fails, every connection opened so far is closed before returning.
func NewPool(config *Config) (*Pool, error) {
	if config.InitCap < 0 || config.MaxCap <= 0 || config.InitCap > config.MaxCap {
		return nil, ErrInvalidCap
	}
	if config.InitConn == nil {
		return nil, ErrInvalidInit
	}
	if config.Close == nil {
		return nil, ErrInvalidClose
	}
	pool := &Pool{
		conns:    make(chan *conn, config.MaxCap),
		poolMap:  make(map[interface{}]int, config.MaxCap),
		initConn: config.InitConn,
		close:    config.Close,
		validate: config.Validate, // nil means "skip validation"
		timeout:  config.Timeout,
	}
	for i := 0; i < config.InitCap; i++ {
		connect, err := pool.initConn()
		if err != nil {
			// The pool has not been handed out yet, so nothing else can touch
			// the channel: close it and drain it directly. Closing first is what
			// lets the range terminate once the buffered connections are gone.
			close(pool.conns)
			for wrapConn := range pool.conns {
				_ = pool.close(wrapConn.connect) // best-effort cleanup
			}
			return nil, ErrInvalidInit
		}
		// Record the connection only after it was opened successfully.
		pool.poolMap[connect] = 1
		pool.conns <- &conn{connect: connect, createTime: time.Now()}
	}
	return pool, nil
}

// Get returns an idle connection from the pool, or opens a new one through
// InitConn when none is available.
//
// Idle connections that have been in the pool longer than the configured
// timeout, or that fail validation, are closed (best effort) and skipped.
// Get returns ErrClosed once the pool has been released, ErrInvalidInit when
// no InitConn function is available, and otherwise any error from InitConn.
// Connections opened here are not counted against MaxCap.
func (pool *Pool) Get() (interface{}, error) {
	for {
		pool.mu.Lock()
		// Read pool.conns only under the lock: Release sets it to nil.
		if pool.conns == nil {
			pool.mu.Unlock()
			return nil, ErrClosed
		}
		select {
		case wrapConn := <-pool.conns:
			if wrapConn == nil {
				// Only a closed channel yields a nil entry.
				pool.mu.Unlock()
				return nil, ErrClosed
			}
			// The connection leaves the pool whatever happens below.
			delete(pool.poolMap, wrapConn.connect)
			closeFn, validate, timeout := pool.close, pool.validate, pool.timeout
			pool.mu.Unlock()

			// Close and validate outside the lock so slow I/O does not stall
			// other Get/Put callers. close is always set while conns is non-nil.
			expired := timeout > 0 && wrapConn.createTime.Add(timeout).Before(time.Now())
			if expired || (validate != nil && validate(wrapConn.connect) != nil) {
				_ = closeFn(wrapConn.connect) // discarded anyway; error is not actionable
				continue
			}
			return wrapConn.connect, nil
		default:
			initConn := pool.initConn
			pool.mu.Unlock()
			if initConn == nil {
				return nil, ErrInvalidInit
			}
			// Dial without holding the lock so a slow InitConn does not block
			// concurrent Get/Put calls.
			connect, err := initConn()
			if err != nil {
				return nil, err
			}
			return connect, nil
		}
	}
}

// Put hands connect back to the pool.
//
// A nil connection is rejected with ErrInvalidConn. A connection that is
// already idle in the pool is ignored (nil is returned). If validation is
// configured and rejects the connection, it is closed and ErrInvalidConn is
// returned. When the pool is full the connection is closed and Close's error
// is returned. On a released pool the connection is closed if a Close
// function is still available (returning its error); otherwise ErrClosed is
// returned and the caller keeps ownership of the connection.
//
// connect is used as a map key, so it must be comparable; an uncomparable
// value (e.g. a slice) makes the map access panic. Validate runs while the
// pool lock is held.
func (pool *Pool) Put(connect interface{}) error {
	if connect == nil {
		return ErrInvalidConn
	}
	pool.mu.Lock()
	closeFn := pool.close
	if pool.conns == nil {
		pool.mu.Unlock()
		// A released pool may have dropped its Close function; calling a nil
		// func would panic, so report ErrClosed instead.
		if closeFn == nil {
			return ErrClosed
		}
		return closeFn(connect)
	}
	// Check for duplicates before validating, so a failed validation can never
	// close a connection that is still sitting idle in the pool.
	if _, ok := pool.poolMap[connect]; ok {
		pool.mu.Unlock()
		return nil
	}
	if pool.validate != nil {
		if err := pool.validate(connect); err != nil {
			pool.mu.Unlock()
			_ = closeFn(connect) // best effort; the validation failure is what we report
			return ErrInvalidConn
		}
	}
	select {
	case pool.conns <- &conn{connect: connect, createTime: time.Now()}:
		pool.poolMap[connect] = 1
		pool.mu.Unlock()
		return nil
	default:
		pool.mu.Unlock()
		// The pool is full: close the surplus connection.
		return closeFn(connect)
	}
}

// Release 释放连接池中所有连接
// Release shuts the pool down and closes every idle connection it holds.
//
// After Release, Get returns ErrClosed. Calling Release more than once is a
// no-op. The Close function is kept so that connections handed back later
// through Put can still be closed.
func (pool *Pool) Release() {
	pool.mu.Lock()
	conns := pool.conns
	if conns == nil {
		// Already released; ranging over or closing a nil channel would block
		// forever or panic.
		pool.mu.Unlock()
		return
	}
	closeFn := pool.close
	pool.conns = nil
	pool.initConn = nil
	pool.validate = nil
	pool.poolMap = nil
	// Close the channel before draining it: ranging over an open channel
	// blocks forever once the buffered connections are consumed.
	close(conns)
	pool.mu.Unlock()

	// Close connections outside the lock so slow Close calls do not block
	// other pool users.
	for wrapConn := range conns {
		if closeFn != nil {
			_ = closeFn(wrapConn.connect) // best effort during shutdown
		}
	}
}

// Len reports how many idle connections are currently buffered in the pool.
// It returns 0 once the pool has been released.
func (pool *Pool) Len() int {
	// pool.conns is reassigned by Release under mu, so read it under mu too.
	pool.mu.Lock()
	defer pool.mu.Unlock()
	return len(pool.conns)
}

//以下是测试代码
// main drives the demo (test code below this point). It starts the TCP server
// in the background, sleeps so the listener can come up, exercises the pool
// through client, prints an exit message and then sleeps again before the
// process ends — presumably so the server goroutine has time to print what
// it received.
func main() {
	const (
		serverStartup = 2 * time.Second  // crude wait; server exposes no readiness signal
		finalWait     = 20 * time.Second // keep the process alive after the client finishes
	)

	go server()
	time.Sleep(serverStartup)

	client()

	fmt.Println("服务退出")
	time.Sleep(finalWait)
}

// client demonstrates the pool against the local demo server. It builds a
// pool (5 initial connections, buffer capacity 30, 15s idle timeout) and
// checks out two connections. It writes to the first connection and hands
// that one back three times; the repeats are ignored as duplicates. Finally
// it prints how many idle connections remain. Errors are printed and end
// the demo early instead of panicking.
func client() {
	// Named closeConn/netConn to avoid shadowing the builtin close and the
	// package-level conn type.
	initConn := func() (interface{}, error) { return net.Dial("tcp", addr) }
	closeConn := func(v interface{}) error { return v.(net.Conn).Close() }

	poolConfig := &Config{
		InitCap:  5,
		MaxCap:   30,
		InitConn: initConn,
		Close:    closeConn,
		Timeout:  15 * time.Second,
	}
	p, err := NewPool(poolConfig)
	if err != nil {
		fmt.Println("err=", err)
		return // p is nil; continuing would panic
	}

	v, err := p.Get()
	if err != nil {
		fmt.Println("err=", err)
		return // v is nil; the type assertion below would panic
	}
	// NOTE(review): this second connection is checked out and never returned
	// or closed, so the final Len is one lower. Kept as in the original demo.
	_, _ = p.Get()

	// Every connection in this pool comes from initConn, i.e. is a net.Conn.
	netConn := v.(net.Conn)
	if _, err := netConn.Write([]byte("guoqiang")); err != nil {
		fmt.Println("err=", err)
		_ = netConn.Close() // a failed connection must not go back into the pool
		return
	}

	// Return the connection three times; only the first Put adds it, the
	// others are ignored as duplicates.
	for i := 0; i < 3; i++ {
		if err := p.Put(v); err != nil {
			fmt.Println("err=", err)
		}
	}

	current := p.Len()
	fmt.Println("len=", current)
}

// server runs the demo TCP server on addr. Each accepted connection is
// handled in its own goroutine, which reads one message, prints it and
// closes the connection. Failing to listen terminates the process.
func server() {
	l, err := net.Listen("tcp", addr)
	if err != nil {
		fmt.Println("Error listening: ", err)
		os.Exit(1)
	}
	defer l.Close()
	fmt.Println("Listening on ", addr)
	for {
		c, err := l.Accept()
		if err != nil {
			fmt.Println("Error accepting: ", err)
			continue // c is nil here; using it would panic
		}
		// Handle connections concurrently: the pool keeps several idle
		// connections open, and a blocking Read on one of them must not stop
		// the server from serving the others. c is passed as an argument so
		// each goroutine gets its own connection.
		go func(c net.Conn) {
			defer c.Close()
			buffer := make([]byte, 20480)
			n, err := c.Read(buffer)
			if err != nil && n == 0 {
				fmt.Println("Error reading: ", err)
				return
			}
			// Print only the bytes actually received, not the zero padding.
			fmt.Printf("Received message %s -> %s message: %s\n", c.RemoteAddr(), c.LocalAddr(), buffer[:n])
		}(c)
	}
}

上一篇 下一篇

猜你喜欢

热点阅读