Golang gRPC：mTLS => TCP 自动服务降级

2022-07-19  本文已影响0人  万万没想到367

整体思路

借助 protoc 插件，根据 proto 文件生成支持自动降级的 client 文件：每次调用优先通过 mTLS 连接发起，若调用返回错误，则自动使用普通 TCP 连接重试同一请求。

具体实现

generate.go

// BaseGenerator holds the inputs shared by the plugin's generators: the
// parsed proto file descriptions and the protogen plugin handle.
type BaseGenerator struct {
    // Pkg is the list of parsed proto files the generator iterates over.
    Pkg []*models.File
    // Gen is the protogen plugin; Generate uses it to create output files
    // (NewGeneratedFile) and to resolve imported proto files (FilesByPath).
    Gen *protogen.Plugin
}

// GeneratorFactory returns the generators to run over the parsed files.
// Currently this is a single generator that produces the mTLS clients with
// TCP fallback.
func GeneratorFactory(files []*models.File, gen *protogen.Plugin) []Generator {
	mtls := &MTLSGenerator{
		BaseGenerator: BaseGenerator{Pkg: files, Gen: gen},
	}
	return []Generator{mtls}
}

// Generator is implemented by each code generator the plugin runs.
type Generator interface {
    // Generate emits this generator's output files and reports any failure.
    Generate() error
}

// MTLSGenerator writes, for each parsed proto file that declares services, a
// "<prefix>_mtls.pb.go" file with client wrappers. Each wrapper calls over an
// mTLS connection first and retries the same request over a plain TCP
// connection when that call returns an error.
type MTLSGenerator struct {
    BaseGenerator
}

// Generate writes one "<prefix>_mtls.pb.go" file for every parsed proto file
// that declares at least one service. For each service the file contains:
//
//   - an unexported "<service>MTLSClient" struct holding two <Service>Client
//     values, one built on the mTLS connection and one on the TCP connection;
//   - a New<Service>MTLSClient constructor that returns the plain TCP client
//     when no usable mTLS connection is supplied;
//   - one wrapper per method that calls over mTLS and, on any error, retries
//     the same request over TCP.
//
// Every method is emitted with the unary client signature.
// NOTE(review): streaming RPCs use a different client signature in grpc-go,
// so a service with streaming methods will likely not compile against the
// generated wrapper — confirm such services are excluded.
func (m *MTLSGenerator) Generate() error {
	for _, p := range m.Pkg {
		if len(p.Services) == 0 {
			continue
		}
		g := m.Gen.NewGeneratedFile(getFilePath(p.FilenamePrefix), "")
		g.P("// Code generated by protoc-gen-mtls. DO NOT EDIT.")
		g.P()
		g.P("package ", p.GoPackageName)
		g.P()
		writeMTLSImports(g, m.Gen, p)
		for _, service := range p.Services {
			writeMTLSServiceClient(g, p.GoPackageName, service)
		}
	}
	return nil
}

// writeMTLSImports emits the generated file's import block. context and grpc
// are always imported. A non-weak imported proto file is also imported, aliased
// by its proto package short name, when the service signatures need that alias
// (p.NeedImport) and the file lives in a different Go package.
// Imports are written in the proto file's declaration order, not map order,
// so repeated runs produce byte-identical output.
func writeMTLSImports(g *protogen.GeneratedFile, gen *protogen.Plugin, p *models.File) {
	g.P("import (")
	g.P("  ", "context \"context\"")
	g.P()
	g.P("  ", "grpc \"google.golang.org/grpc\"")
	g.P()
	emitted := map[string]bool{}
	for i, imps := 0, p.Imports; i < imps.Len(); i++ {
		imp := imps.Get(i)
		if imp.IsWeak {
			continue
		}
		impFile, ok := gen.FilesByPath[imp.Path()]
		if !ok {
			continue
		}
		// Don't generate imports or aliases for types in the same Go package.
		if string(impFile.GoImportPath) == p.GoImportPath {
			continue
		}
		alias := string(imp.Name())
		if !p.NeedImport[alias] || emitted[alias] {
			continue
		}
		emitted[alias] = true
		g.P("  ", alias, " ", fmt.Sprintf("%q", string(impFile.GoImportPath)))
	}
	g.P(")")
	g.P()
}

// writeMTLSServiceClient emits the fallback client type, its constructor and
// one wrapper per method for a single service. pkgName is the Go package name
// of the generated file; getRequest uses it to decide whether a message type
// needs a package qualifier.
func writeMTLSServiceClient(g *protogen.GeneratedFile, pkgName string, service *models.Service) {
	typeName := getFirstLowServiceName(service.Name) + "MTLSClient"
	iface := service.Name + "Client"

	g.P("type ", typeName, " struct{")
	g.P("  mtlsClient ", iface)
	g.P("  tcpClient ", iface)
	g.P("}")
	g.P()
	g.P("func New", service.Name, "MTLSClient(mtls, tcp grpc.ClientConnInterface) ", iface, " {")
	// A nil *grpc.ClientConn stored in the interface does not compare equal to
	// nil, so it is checked explicitly. Otherwise the mTLS client would be
	// built on a nil connection.
	g.P("  if cc, ok := mtls.(*grpc.ClientConn); mtls == nil || (ok && cc == nil) {")
	g.P("    return New", service.Name, "Client(tcp)")
	g.P("  }")
	g.P("  return &", typeName, "{", getServiceName(service.Name, "mtls"), ", ", getServiceName(service.Name, "tcp"), "}")
	g.P("}")
	g.P()
	for _, method := range service.Methods {
		req := getRequest(pkgName, method.Req)
		resp := getRequest(pkgName, method.Resp)
		g.P("func (c *", typeName, ") ", method.Name, "(ctx context.Context, in *", req, ", opts ...grpc.CallOption) (*", resp, ", error) {")
		g.P("  resp, err := c.mtlsClient.", method.Name, "(ctx, in, opts...)")
		g.P("  if err != nil {")
		g.P("    return c.tcpClient.", method.Name, "(ctx, in, opts...)")
		g.P("  }")
		g.P("  return resp, err")
		g.P("}")
		g.P()
	}
}

// getRequest renders a fully-qualified proto message name as the Go type
// expression used in generated code.
//
// If the second-to-last dot-separated component equals packageName, the bare
// last component is returned ("Msg"). Otherwise the last two components are
// joined ("other.Msg"). A name without any dot is returned unchanged.
// NOTE(review): for a nested message ("pkg.Outer.Inner") this yields
// "Outer.Inner", which is not the Go identifier protoc-gen-go emits — confirm
// nested request/response types are not used.
func getRequest(packageName, req string) string {
	parts := strings.Split(req, ".")
	n := len(parts)
	if n == 1 {
		return parts[0]
	}
	qualifier, ident := parts[n-2], parts[n-1]
	if qualifier == packageName {
		return ident
	}
	return qualifier + "." + ident
}

// getServiceName builds the Go expression New<name>Client(<key>), which
// constructs a gRPC service client from the connection variable named key.
func getServiceName(name, key string) string {
	return "New" + name + "Client(" + key + ")"
}

// getFirstLowServiceName returns name with its first character lower-cased,
// e.g. "Greeter" -> "greeter". It produces the unexported type name of the
// generated client. Only the first rune changes; the rest is kept verbatim.
// An empty name yields "" instead of panicking.
func getFirstLowServiceName(name string) string {
	// Ranging over a string yields rune start offsets. The second offset is
	// therefore the byte length of the first rune, so multi-byte first
	// characters are lowered whole instead of being split mid-encoding.
	for i := range name {
		if i > 0 {
			return strings.ToLower(name[:i]) + name[i:]
		}
	}
	// Zero or one rune: lower-case the whole string.
	return strings.ToLower(name)
}

// getFilePath returns the output file name for a proto file's generated
// filename prefix by appending "_mtls.pb.go".
func getFilePath(prefix string) string {
	const suffix = "_mtls.pb.go"
	return prefix + suffix
}

parse

// getMessageMap indexes every message declared in the plugin's input files,
// including nested messages, by its fully-qualified proto name
// (e.g. "pkg.Outer.Inner"). Without the nested walk, lookups for RPCs whose
// request or response is a nested type would miss.
func getMessageMap(gen *protogen.Plugin) map[string]*protogen.Message {
	messageMap := make(map[string]*protogen.Message)
	var add func(msgs []*protogen.Message)
	add = func(msgs []*protogen.Message) {
		for _, message := range msgs {
			messageMap[string(message.Desc.FullName())] = message
			// Nested declarations are listed on the parent message, not on the file.
			add(message.Messages)
		}
	}
	for _, f := range gen.Files {
		add(f.Messages)
	}
	return messageMap
}

// parseFile converts each proto file that protoc asked this plugin to generate
// into a models.File. The result describes the file's output prefix, Go
// package, proto package, imports, services, and the proto package aliases the
// generated code must import.
//
// Files present only as dependencies (f.Generate == false) are skipped. No
// output should be written for them in this run, and their <Service>Client
// constructors are not being generated alongside.
// Any error from parseMethod is returned unchanged.
func parseFile(gen *protogen.Plugin) ([]*models.File, error) {
	var files []*models.File
	messageMap := getMessageMap(gen)
	codeMap := make(map[string]string)
	for _, f := range gen.Files {
		if !f.Generate {
			continue
		}
		// GetPackage returns "" for a proto file without a package statement,
		// where dereferencing f.Proto.Package would panic.
		file := &models.File{
			FilenamePrefix:   f.GeneratedFilenamePrefix,
			GoImportPath:     string(f.GoImportPath),
			GoPackageName:    string(f.GoPackageName),
			ProtoPackageName: f.Proto.GetPackage(),
			Dependency:       f.Proto.Dependency,
			Imports:          f.Desc.Imports(),
		}
		needImport := map[string]bool{}
		var services []*models.Service
		for _, service := range f.Services {
			methods, imports, err := parseMethod(string(f.GoPackageName), service, codeMap, messageMap)
			if err != nil {
				return nil, err
			}
			needImport = mergeMap(needImport, imports)
			services = append(services, &models.Service{Name: service.GoName, Methods: methods})
		}
		file.NeedImport = needImport
		file.Services = services
		files = append(files, file)
	}
	return files, nil
}

// mergeMap copies every entry of o into n, overwriting existing keys, and
// returns n. n is modified in place. When n is nil and o has entries, a new
// map is allocated so the merge does not panic. An empty o returns n as-is.
func mergeMap(n, o map[string]bool) map[string]bool {
	if len(o) == 0 {
		return n
	}
	if n == nil {
		n = make(map[string]bool, len(o))
	}
	for key, value := range o {
		n[key] = value
	}
	return n
}

// parseMethod converts the RPCs of one service into models.Method values. It
// also returns the set of proto package short names that the generated code
// will use as qualifiers, so the caller knows which imports to emit.
//
// name is the Go package name of the file being generated. A message type is
// treated as foreign, and its qualifier recorded, when its full name has at
// least two dot-separated components and the second-to-last differs from
// name. This is the same rule getRequest uses to decide whether to qualify a
// type. A whole-component comparison avoids substring false matches (e.g. "pb"
// inside "pbx"). Names with no package component need no import, so they no
// longer cause an out-of-range panic.
//
// codeMap and messageMap are kept for signature compatibility and are unused.
// The returned error is currently always nil.
func parseMethod(name string, service *protogen.Service, codeMap map[string]string, messageMap map[string]*protogen.Message) ([]*models.Method, map[string]bool, error) {
	var methods []*models.Method
	needImport := map[string]bool{}
	markForeign := func(fullName string) {
		parts := strings.Split(fullName, ".")
		if len(parts) < 2 {
			return
		}
		if qualifier := parts[len(parts)-2]; qualifier != name {
			needImport[qualifier] = true
		}
	}
	for _, method := range service.Methods {
		m := &models.Method{
			Name: method.GoName,
			Req:  string(method.Input.Desc.FullName()),
			Resp: string(method.Output.Desc.FullName()),
		}
		markForeign(m.Req)
		markForeign(m.Resp)
		methods = append(methods, m)
	}
	return methods, needImport, nil
}

// getImportName returns the second-to-last dot-separated component of a
// fully-qualified proto name. For "api.v1.User" that is "v1", which for a
// top-level message is the last component of its proto package. A name with
// no dot has no package component and yields "" instead of indexing out of
// range.
func getImportName(i string) string {
	r := strings.Split(i, ".")
	if len(r) < 2 {
		return ""
	}
	return r[len(r)-2]
}

上一篇 下一篇

猜你喜欢

热点阅读