Golang

Rpcx源码之Server

2019-02-18  本文已影响25人  神奇的考拉

一、概述

在Rpcx框架源码中,存在Server的角色,用来完成承担Server stub;相对来说Server,每个需要对外提供功能的Service(在rpcx中抽象出来的服务提供者)都需要进行注册。一旦服务定义玩后,可将其暴露出去来使用。通过启动一个TCP或UDP服务器来监听请求。对应的Server提供了如下的功能:

    func NewServer(options ...OptionFn) *Server   // 新建Server
    func (s *Server) Close() error                            // 关闭Server
    func (s *Server) RegisterOnShutdown(f func())    // 注册Shutdown为Server 用于优雅关闭connection
    func (s *Server) Serve(network, address string) (err error)  // 启动Server  以TCP或UDP协议与客户端通信
    func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) // 启动Server  服务以HTTP方式提供

接下来详细剖析对应的源码

二、源码

本文源码主要位于:rpcx\server包中server.go
Server结构

// rpc的server端  支持TCP、UDP
type Server struct {
    ln                net.Listener                   // 监听器 用于
    readTimeout       time.Duration        // 读取client端请求数据包的有效时间
    writeTimeout      time.Duration        // 写入client端响应数据包的有效时间
    gatewayHTTPServer *http.Server         // http网关

    serviceMapMu sync.RWMutex             // 保护service提供service记录表的安全(读多写少使用读写锁)
    serviceMap   map[string]*service         // server端提供service记录表

    mu         sync.RWMutex                      //
    activeConn map[net.Conn]struct{}      // server提供的活跃connection记录表
    doneChan   chan struct{}                    // service完成通知channel
    seq        uint64                                     // server端ID

    inShutdown int32                                 //  中断
    onShutdown []func(s *Server)              // 中断处理函数

    // TLSConfig for creating tls tcp connection.
    tlsConfig *tls.Config  // tls tcp连接时的配置项
    // BlockCrypt for kcp.BlockCrypt
    options map[string]interface{}  // 主要使用kcp协议时的提供的一些限制
    
        Plugins PluginContainer   // 通过plugin的方式 增加server的一些特性

    AuthFunc func(ctx context.Context, req *protocol.Message, token string) error  // 认证

    handlerMsgNum int32  // 处理消息量
}

从上面的Server结构体的源码可以看到提供了Server启动的Listener、读写操作的有效期、以Plugin方式提供Server的额外特性、Server认证、以及优雅关闭Connecton等操作;具体的提供的函数如下:

1.1 新建Server

// 新建Server
// 目前支持OptionFn: WithTLSConfig()、WithReadTimeout()、WithWriteTimeout()
// NewServer returns a server.
func NewServer(options ...OptionFn) *Server {
    s := &Server{
        Plugins: &pluginContainer{},
        options: make(map[string]interface{}),
    }

    for _, op := range options {
        op(s)
    }

    return s
}

1.2 启动Server

// 启动server并监听client的请求
// 该操作属于阻塞的  直到接收到client的连接connection
// Serve支持TCP/UDP以及Http
func (s *Server) Serve(network, address string) (err error) {
    s.startShutdownListener()       // 开启server中断监听 主要用于优雅关闭对应的Connecton
    var ln net.Listener
    ln, err = s.makeListener(network, address)  // 获取对应的listener
    if err != nil {
        return
    }

    if network == "http" {  // 通过http协议
        s.serveByHTTP(ln, "")  // 劫持Http连接
        return nil
    }

    // try to start gateway
    ln = s.startGateway(network, ln)

    return s.serveListener(ln)
}

当非Http协议时

// serveListener 接收ln上的connection
// connection开启一个goroutine来处理对应的请求.
// 对应Service的goroutine读取request并调用真正的Service来执行再返回对应的结果.
func (s *Server) serveListener(ln net.Listener) error {
    if s.Plugins == nil {
        s.Plugins = &pluginContainer{}
    }

    var tempDelay time.Duration

    s.mu.Lock()
    s.ln = ln
    if s.activeConn == nil {  // 记录当前Server活跃的connection
        s.activeConn = make(map[net.Conn]struct{})
    }
    s.mu.Unlock()

    for {
        conn, e := ln.Accept()    // 接收客户端请求
        if e != nil {
            select {
            case <-s.getDoneChan():   // 通过Done 接收一些中断异常的信号 直接停止Server
                return ErrServerClosed
            default:
            }

            if ne, ok := e.(net.Error); ok && ne.Temporary() { // 当出现非中断错误时 进行延迟重试
                if tempDelay == 0 {
                    tempDelay = 5 * time.Millisecond  // 默认5ms
                } else {                                                   // 当设置延迟重试时间 以2^n进行倍数递增(<=1s)
                    tempDelay *= 2
                }

                if max := 1 * time.Second; tempDelay > max { // 防止延迟重试时间过久
                    tempDelay = max
                }

                log.Errorf("rpcx: Accept error: %v; retrying in %v", e, tempDelay)
                time.Sleep(tempDelay) // 休眠当前执行线程
                continue
            }
            return e
        }
        tempDelay = 0  // 延迟重试时间需要置0,防止影响下一次出现net.Error时,延迟重试时间错误

        if tc, ok := conn.(*net.TCPConn); ok {  // tcp连接
            tc.SetKeepAlive(true)     // 保持长连接
            tc.SetKeepAlivePeriod(3 * time.Minute) // 检查周期3min
            tc.SetLinger(10)  // 指定connection的关闭行为
        }

        s.mu.Lock()
        s.activeConn[conn] = struct{}{} // 使用空struct来标记connection已用
        s.mu.Unlock()

        conn, ok := s.Plugins.DoPostConnAccept(conn) // 处理接收到conn:执行connection定义的plugin
        if !ok {
            continue
        }

        go s.serveConn(conn) // 单独开启goroutine来执行connection
    }
}

开启connection,接收request

func (s *Server) serveConn(conn net.Conn) {
        // 
    ...
        // 当=tsl connection需要进行额外设置
    if tlsConn, ok := conn.(*tls.Conn); ok {
        if d := s.readTimeout; d != 0 {
            conn.SetReadDeadline(time.Now().Add(d))
        }
        if d := s.writeTimeout; d != 0 {
            conn.SetWriteDeadline(time.Now().Add(d))
        }
        if err := tlsConn.Handshake(); err != nil {
            log.Errorf("rpcx: TLS handshake error from %s: %v", conn.RemoteAddr(), err)
            return
        }
    }
        // 设置conn读取buffer 默认1kb
    r := bufio.NewReaderSize(conn, ReaderBuffsize)

    for { // 循环执行
        if isShutdown(s) {  //connection是否中断
            closeChannel(s, conn)
            return
        }

        t0 := time.Now()
        if s.readTimeout != 0 {
            conn.SetReadDeadline(t0.Add(s.readTimeout))
        }

        ctx := context.WithValue(context.Background(), RemoteConnContextKey, conn)
        req, err := s.readRequest(ctx, r)  // 获取client的request
        ... 读取request时出现error处理

        if s.writeTimeout != 0 {
            conn.SetWriteDeadline(t0.Add(s.writeTimeout))
        }
              
        ctx = context.WithValue(ctx, StartRequestContextKey, time.Now().UnixNano())
        if !req.IsHeartbeat() { // 认证 非心跳请求
            err = s.auth(ctx, req)
        }

        if err != nil {
            if !req.IsOneway() {  // 是否需要响应
                res := req.Clone() // 
                res.SetMessageType(protocol.Response) // 执行消息类型
                if len(res.Payload) > 1024 && req.CompressType() != protocol.None { //数据压缩传输
                    res.SetCompressType(req.CompressType())
                }
                handleError(res, err) // 
                data := res.Encode() // 编码
  
                                // 处理response
                s.Plugins.DoPreWriteResponse(ctx, req, res)
                conn.Write(data)
                s.Plugins.DoPostWriteResponse(ctx, req, res, err)
                protocol.FreeMsg(res)
            } else { // 不需要response
                s.Plugins.DoPreWriteResponse(ctx, req, nil) 
            }
            protocol.FreeMsg(req) //为减少Message实例化的资源占用,采用缓存Message的方式重用
            continue
        }
        go func() { // 开启一个goroutine处理request
            atomic.AddInt32(&s.handlerMsgNum, 1) // 记录处理msg的数量:待执行时+1
            defer func() { // 执行完成-1
                atomic.AddInt32(&s.handlerMsgNum, -1)
            }()
            if req.IsHeartbeat() { // 心跳包:直接设置Message类型,并对request重新编码返回
                req.SetMessageType(protocol.Response)
                data := req.Encode()
                conn.Write(data)
                return
            }

                        // Metadata获取
            resMetadata := make(map[string]string)
            newCtx := context.WithValue(context.WithValue(ctx, share.ReqMetaDataKey, req.Metadata),
                share.ResMetaDataKey, resMetadata)

            res, err := s.handleRequest(newCtx, req) // 处理request

            if err != nil {
                log.Warnf("rpcx: failed to handle request: %v", err)
            }

            s.Plugins.DoPreWriteResponse(newCtx, req, res) // 预响应前的plugin
            if !req.IsOneway() { // 
                if len(resMetadata) > 0 { //复制request中的metadata
                    meta := res.Metadata
                    if meta == nil {
                        res.Metadata = resMetadata
                    } else {
                        for k, v := range resMetadata {
                            meta[k] = v
                        }
                    }
                }

                if len(res.Payload) > 1024 && req.CompressType() != protocol.None {// 数据压缩
                    res.SetCompressType(req.CompressType())
                }
                data := res.Encode() // 数据编码
                conn.Write(data) // 写入到connection
                //res.WriteTo(conn)
            }
            s.Plugins.DoPostWriteResponse(newCtx, req, res, err) // 执行response后的plugin

            protocol.FreeMsg(req) // Message返回pool中,便于重用
            protocol.FreeMsg(res)
        }()
    }
}

相对来说Http处理比较特殊 单独提出来

func (s *Server) serveByHTTP(ln net.Listener, rpcPath string) {
    s.ln = ln

    if s.Plugins == nil {
        s.Plugins = &pluginContainer{}
    }

    if rpcPath == "" {  // rpcPath为空字符串时,则使用"/_rpcx_"代替
        rpcPath = share.DefaultRPCPath
    }
    http.Handle(rpcPath, s)         // 执行http的Handle  s本身就是一个Handler
    srv := &http.Server{Handler: nil}  // 构建http Server

    s.mu.Lock()
    if s.activeConn == nil {
        s.activeConn = make(map[net.Conn]struct{})
    }
    s.mu.Unlock()

    srv.Serve(ln)  // 启动http Server
}

真正处理request的请求

func (s *Server) handleRequest(ctx context.Context, req *protocol.Message) (res *protocol.Message, err error) {
    serviceName := req.ServicePath   // Service定义的Path
    methodName := req.ServiceMethod // Service中的方法

    res = req.Clone() // 复制request

    res.SetMessageType(protocol.Response) 
    s.serviceMapMu.RLock()
    service := s.serviceMap[serviceName] // 获取注册的service
    s.serviceMapMu.RUnlock()
    ...省略代码

    mtype := service.method[methodName]
    if mtype == nil {
        if service.function[methodName] != nil { //check raw functions
            return s.handleRequestForFunction(ctx, req)
        }
        err = errors.New("rpcx: can't find method " + methodName)
        return handleError(res, err)
    }

    var argv = argsReplyPools.Get(mtype.ArgType)

    codec := share.Codecs[req.SerializeType()] //获取编码格式
    ...省略代码

    err = codec.Decode(req.Payload, argv) // 解码请求内容
    if err != nil {
        return handleError(res, err)
    }

    ...省略代码

    if mtype.ArgType.Kind() != reflect.Ptr { // service调用
        err = service.call(ctx, mtype, reflect.ValueOf(argv).Elem(), reflect.ValueOf(replyv))
    } else {
        err = service.call(ctx, mtype, reflect.ValueOf(argv), reflect.ValueOf(replyv))
    }

    ...省略代码
    if !req.IsOneway() {
        data, err := codec.Encode(replyv) //编码response
        argsReplyPools.Put(mtype.ReplyType, replyv)
        if err != nil {
            return handleError(res, err)

        }
        res.Payload = data
    }

    return res, nil  //返回response结果
}

目前rpcx 支持如下的网络类型:

附上优雅关闭连接Connection

// 开启shutdown监听器
func (s *Server) startShutdownListener() {
    go func(s *Server) {
        log.Info("server pid:", os.Getpid())
        c := make(chan os.Signal, 1)   // 使用系统信号
        signal.Notify(c, syscall.SIGTERM)  // 开启服务端shutdown信号通知
        si := <-c                          // 获取channel 系统信号
        if si.String() == "terminated" {   // 中断信号
            if nil != s.onShutdown && len(s.onShutdown) > 0 {  // 执行中断后续操作
                for _, sd := range s.onShutdown {
                    sd(s)
                }
            }
            os.Exit(0)  // 系统退出
        }
    }(s)
}

三、使用

实例server端代码

import (
    "context"
    "flag"
    "github.com/smallnest/rpcx/server"
    "log"
    "net"
    "rpcx/examples/models"
    )

var (
    addr = flag.String("addr","localhost:8972","server address")
)

func main() {
    flag.Parse()

    s := server.NewServer()
    //s.Plugins.Add(&ConnectionListerPlugin{})

    s.Register(new(models.Arith),"")  // 只提供rcvr不指定servicePath和method以及对应的name
    s.RegisterName("PBArith", new(models.PBArith),"") // 提供rcvr及其name
    s.RegisterFunction("PB-Mul",models.Mul,"") // 注入函数(对应的函数不需要提供调用方) 提供servicePath和method
    s.RegisterFunctionName("PMul","mul",models.Mul,"") // 提供servicePath和method及其name

    s.Serve("tcp", *addr)
    //log.Println(" Server Address= " + s.Address().String())
}

type ConnectionListerPlugin struct {
}

func (clis *ConnectionListerPlugin) HandleLister(conn net.Conn) (net.Conn, bool){
    log.Printf("Server Listener Address %v \n", conn.LocalAddr().String())
    return conn,true
}

func mul(ctx context.Context, args *models.Args, reply *models.Reply) error {
    reply.C = args.A * args.B
    return nil
}

实例client端代码

import (
    "context"
    "encoding/json"
    "flag"
    "fmt"
    "rpcx/examples/models"
    "log"
    "github.com/smallnest/rpcx/client"
    "github.com/smallnest/rpcx/protocol"
)

var (
    addr = flag.String("addr", "localhost:8972","server address")
)

func main() {
    flag.Parse()

    c := client.NewClient(client.DefaultOption)
    c.Connect("tcp", *addr)
    defer c.Close()

    args := &models.Args{
        A: 10,
        B: 20,
    }

    reply := &models.Reply{}

    payload,_ := json.Marshal(args)
    // 构建message
    req := protocol.NewMessage()
    req.SetVersion(1)
    req.SetMessageType(protocol.Request)
    req.SetHeartbeat(false)
    req.SetOneway(false)
    req.SetCompressType(protocol.None)
    req.SetMessageStatusType(protocol.Normal)
    req.SetSerializeType(protocol.JSON)
    req.SetSeq(1234567890)

    m := make(map[string]string)
    req.ServicePath = "Arith"
    req.ServiceMethod = "Mul"
    m["__ID"] = "6ba7b810-9dad-11d1-80b4-00c04fd430c9"
    req.Metadata = m
    req.Payload = payload

    _, bytes, err := c.SendRaw(context.Background(),req)
    if err != nil{
        log.Fatalf("failed to call: %v", err)
    }

    json.Unmarshal(bytes,reply)
    //
    fmt.Println(reply.C)
}

辅助代码

package models

import (
    "context"
    "fmt"
    "rpcx/_testutils"
)

// 参数
type Args struct {
    A int
    B int
}

// 回复结果
type Reply struct {
    C int
}

// 服务
type Arith int

func (t *Arith) Mul(ctx context.Context, args *Args, reply *Reply) error {
    reply.C = args.A * args.B
    fmt.Printf("call: %d * %d = %d\n", args.A, args.B, reply.C)
    return nil
}

func (t *Arith) Add(ctx context.Context, args *Args, reply *Reply) error {
    reply.C = args.A + args.B
    fmt.Printf("call: %d + %d = %d\n", args.A, args.B, reply.C)
    return nil
}

func (t *Arith) Say(ctx context.Context, args *string, reply *string) error {
    *reply = "hello " + *args
    return nil
}

type PBArith int

func (t *PBArith) Mul(ctx context.Context, args *testutils.ProtoArgs, reply *testutils.ProtoReply) error {
    reply.C = args.A * args.B
    return nil
}

func (t *Arith) ThriftMul(ctx context.Context, args *testutils.ThriftArgs_, reply *testutils.ThriftReply) error {
    reply.C = args.A * args.B
    return nil
}

func Mul(ctx context.Context, args *testutils.ProtoArgs, reply *testutils.ProtoReply) error {
    reply.C = args.A * args.B
    return nil
}

四、其他

Server源码
Server执行链路如下:


Server分析源码
上一篇下一篇

猜你喜欢

热点阅读