go语言实现内网穿透(udp版)

2024-10-23  本文已影响0人  今天i你好吗

相关内容

node.js实现内网穿透: https://www.jianshu.com/p/d2d4f8bff599
kotlin实现内网穿透: https://www.jianshu.com/p/c8dc095c758e

最大设计连接数: 65535

前面写了个udp转tcp再转udp的工具, 打算把它和tcp内网穿透结合使用来实现udp内网穿透, 但是在实际使用中发现存在网速较慢的问题, 初步判断为运营商网络问题(使用http下载也一样: 单线程只能达到1MB/s以内, 3个线程就可以达到10MB/s; 上传没有问题). 这个问题没法解决, 就只好再写个udp版. 本来想用udp打洞来写, 但是有一个网络不支持, 只好改用服务器转发. 目前仍然存在一个小问题, 暂时不打算解决, 结尾会讲.

实现代码

服务端:

package main

import (
    "crypto/rsa"
    "crypto/sha1"
    "crypto/x509"
    "encoding/pem"
    "errors"
    "fmt"
    "log"
    "net"
    "os"
    "strings"
    "sync"
    "time"
)

// Package-level configuration and shared state of the relay server.
var (
    // privateKeyStr is the PEM-encoded PKCS#8 RSA private key that startServer
    // loads and handleNatUAuth uses to decrypt the heartbeat/auth payload.
    // NOTE(review): the private key is hard-coded in the source, so anyone who
    // can read this file can read the key; consider loading it from outside.
    privateKeyStr = "-----BEGIN PRIVATE KEY-----\n" +
        "MIICdQIBADANBgkqhkiG9w0BAQEFAASCAl8wggJbAgEAAoGBAJYBf37uy0sVXyxb\n" +
        "bMDjvQzxv/ke3UWCSkYhUd8e+MjGHeT8A9V9aVemg3qUogND/Pgtlz6bTd9p5H+q\n" +
        "OCXnrZbSwdAN7O/9x3zhRzaEOH+9OJ8vpF80DuatbdqqJXFPiDO+nfbufNyT8n+b\n" +
        "N9UXISAVv7Nay+vTEySD401czDDzAgMBAAECgYAE9u26TLrrtDxfInN5+s+R8xpQ\n" +
        "a2YVW9eLdKTaBpNjSbNJldGmqiznWrp1PyARjZl8uT2NM+Si5UVLuF19W6qSC1Tw\n" +
        "bn2brBSsFsufKQff0XXRsD8OqKT5h5/PlRATYgisZonD/v0SPfDHROcNCFiOelEv\n" +
        "kKZAJgJS3vghZx5jCQJBAMTb4esPOApK+6wXlfVShvRiac9UWW9KBYLwpqya5Cjd\n" +
        "X/QUGVwywMFkNgRnBz7yWdJd1zuXfuI77N87BXYOE68CQQDDEjfaKCyZJ73RfNJo\n" +
        "ED5DRfZXm86RPBDlZznJ4shjNVbk1sGClYAk0WuAHeIZJVpm1HpME6NSCfumqeOq\n" +
        "z1P9AkAQu+xJcgK+hT89ksexkfFc5ty9vhrYJf+v8MsKUyRgAOl+MxMwzjOqfN1G\n" +
        "pIduJ2XRRx7btvYXPybUlwzQy0OLAkBIexlznt/LXH/kOcv4TKjF2FYLAWKEhlwE\n" +
        "0REg2Xn5mtUZnE40lhYSGBoodXIQQ9fOQ37Zi6ZwkjMGHzPvwK+FAkAe9gHRMI2u\n" +
        "lzwVp/AQntBMXmw92IULIlfRmfV1jDBYuT0JHUUGZqfCrz+iDW8Ot24QBzLbwKxJ\n" +
        "JRrmXbUxObmC\n" +
        "-----END PRIVATE KEY-----"
    // natDispatcherAddr is the UDP address the dispatcher listens on; main
    // overrides it from the single command-line argument and Start from its parameter.
    natDispatcherAddr = ":8989"
    // natPasswords lists the passwords handleNatUAuth accepts.
    natPasswords      = []string{"yzh"}
    // timeOut is used as the read deadline of every opened port (getServer) and
    // as the idle limit after which accept drops an external peer.
    timeOut           = time.Second * 60 * 2
    // serverMap holds the opened ports, keyed by port string.
    serverMap         = map[string]*Server{}
    // mapMutex is locked around accesses to serverMap in getServer, accept,
    // handleRealData and the status goroutine.
    mapMutex          sync.Mutex
    // clientIndex is the counter accept increments to number external peers.
    clientIndex       = 0
    // maxLength records the largest datagram size received so far (reported in
    // the periodic status log).
    maxLength         = 0
)

// main logs the command-line arguments, takes the dispatcher listen address
// from the argument when exactly one is given, switches the logger to include
// file/line information and then runs the server loop.
func main() {
    args := os.Args[1:]
    log.Println("入参: " + strings.Join(args, " "))
    if len(args) == 1 {
        natDispatcherAddr = args[0]
    }
    log.SetFlags(log.LstdFlags | log.Lshortfile)
    startServer()
}

// Start is the programmatic entry point: it replaces the dispatcher listen
// address and the accepted password list, then runs startServer and returns
// only when startServer does.
func Start(natDispatcherAddrStr string, natPasswordsArr []string) {
    natDispatcherAddr, natPasswords = natDispatcherAddrStr, natPasswordsArr
    startServer()
}

// startServer runs the dispatcher.
//
// It starts a goroutine that logs the opened ports once a minute, loads the
// RSA private key, binds the dispatcher UDP socket on natDispatcherAddr and
// then loops over incoming datagrams. The first byte of each datagram is the
// command:
//
//   - 1: heartbeat/auth. The rest of the datagram is handed to handleNatUAuth
//     and its one-byte result is sent back to the sender.
//   - 0: payload coming back from the NAT client. Bytes 1-2 are the big-endian
//     peer index and the rest is forwarded through handleRealData.
//
// Other commands are ignored. The function returns when the address cannot be
// resolved or bound, or when the socket reports net.ErrClosed; it panics if
// the private key cannot be loaded.
func startServer() {
    log.Println("NatU分发服务地址: " + natDispatcherAddr)
    go func() {
        for {
            time.Sleep(time.Second * 60)
            println()
            // Take the lock before touching serverMap at all, including len().
            mapMutex.Lock()
            log.Println("natU转发Size:", len(serverMap), ",maxLength:", maxLength)
            for _, value := range serverMap {
                log.Println("natU转发中:", value.toString())
            }
            mapMutex.Unlock()
        }
    }()
    privateKey, err := loadPrivateKey(privateKeyStr)
    if err != nil {
        log.Panicln(err)
    }

    listenerAddr, err := net.ResolveUDPAddr("udp", natDispatcherAddr)
    if err != nil {
        log.Println(err)
        return
    }
    // Pin the address family when the configured address carries an explicit IP.
    network := "udp"
    if listenerAddr.IP.To4() != nil {
        network = "udp4"
    } else if listenerAddr.IP.To16() != nil {
        network = "udp6"
    }
    listenerConn, err := net.ListenUDP(network, listenerAddr)
    if err != nil {
        log.Println(err)
        return
    }
    defer listenerConn.Close()
    buffer := make([]byte, 1024*64)
    for {
        n, clientAddr, err := listenerConn.ReadFromUDP(buffer)
        if err != nil {
            log.Println(err)
            if errors.Is(err, net.ErrClosed) {
                break
            }
            // Back off briefly so a persistent error does not spin the loop.
            time.Sleep(time.Second)
            continue
        }
        if n > maxLength {
            maxLength = n
        }
        if n < 1 {
            log.Println("异常消息:", clientAddr)
            continue
        }
        // Copy out of the shared read buffer before handing the bytes on.
        data := make([]byte, n)
        copy(data, buffer)

        switch cmd := data[0]; cmd {
        case 1:
            // Heartbeat: authenticate and answer with the one-byte status.
            value := handleNatUAuth(data[1:], privateKey, clientAddr, listenerConn)
            listenerConn.SetWriteDeadline(time.Now().Add(time.Second * 3))
            if _, err := listenerConn.WriteToUDP([]byte{value}, clientAddr); err != nil {
                log.Println(err)
            }
        case 0:
            // Payload: needs the command byte plus the 2-byte peer index.
            // Shorter datagrams used to panic on data[1]/data[2].
            if n < 3 {
                log.Println("异常消息:", clientAddr)
                continue
            }
            cIndex := int(data[1])*256 + int(data[2])
            handleRealData(cIndex, data[3:])
        }
    }
}

func handleRealData(cIndex int, realData []byte) {
    servers := []*Server{}
    mapMutex.Lock()
    for _, server := range serverMap {
        servers = append(servers, server)
    }
    mapMutex.Unlock()
    for _, server := range servers {
        server.clientMutex.Lock()
        for _, value := range server.clientMap {
            if value.index == cIndex {
                // log.Println("转成功:", cIndex, value.clientAddr)
                value.openConn.SetWriteDeadline(time.Now().Add(time.Second * 3))
                value.openConn.WriteToUDP(realData, value.clientAddr)
                value.lastLime = time.Now()
                break
            }
        }
        server.clientMutex.Unlock()
    }
}

// handleNatUAuth validates a heartbeat payload and makes sure every address it
// requests is opened.
//
// cmdData is decrypted with RSA-OAEP (SHA-1) and split on "-". Two layouts are
// recognised:
//
//   - version 2: the first field is the password, the remaining fields are the
//     addresses to open;
//   - version 1: the second field is the password and the first field is the
//     single address to open.
//
// Version 2 wins when both positions hold a valid password.
//
// Returns 1 on success, 3 when decryption fails, 4 when fewer than two fields
// are present, 5 when no password matches and 6 when getServer rejects an
// address.
func handleNatUAuth(cmdData []byte, privateKey *rsa.PrivateKey, clientAddr *net.UDPAddr, listenerConn *net.UDPConn) byte {
    plain, err := rsa.DecryptOAEP(sha1.New(), nil, privateKey, cmdData, nil)
    if err != nil {
        log.Println("DecryptPKCS1v15 err", err)
        return 3
    }
    infos := strings.Split(string(plain), "-")
    if len(infos) < 2 {
        log.Println("infos error", infos)
        return 4
    }

    isPassword := func(candidate string) bool {
        for _, pwd := range natPasswords {
            if pwd == candidate {
                return true
            }
        }
        return false
    }

    var (
        version   int
        openAddrs []string
    )
    switch {
    case isPassword(infos[0]):
        version, openAddrs = 2, infos[1:]
    case isPassword(infos[1]):
        version, openAddrs = 1, infos[:1]
    default:
        log.Println("密码错误", infos)
        return 5
    }

    joined := strings.Join(openAddrs, "-")
    for i := range openAddrs {
        if getServer(openAddrs[i], i, version, clientAddr, joined, listenerConn) == nil {
            log.Println("getServer nil")
            return 6
        }
    }
    return 1
}

// getServer returns the Server that listens on address, opening the port and
// starting its accept goroutine on first use.
//
// Servers are keyed by the text after the last ":" in address. An existing
// entry is reused only when its address, joined address list and version all
// match the request; otherwise nil is returned. nil is also returned when the
// address is malformed or cannot be resolved or bound.
//
// On success the server's NAT-client address is replaced with clientAddr and
// the read deadline of the opened socket is pushed timeOut into the future.
func getServer(address string, index int, version int, clientAddr *net.UDPAddr, openAddrsStr string, listenerConn *net.UDPConn) *Server {
    parts := strings.Split(address, ":")
    if len(parts) < 2 {
        log.Println("address error", address)
        return nil
    }
    port := parts[len(parts)-1]

    mapMutex.Lock()
    defer mapMutex.Unlock()

    server := serverMap[port]
    if server == nil {
        udpAddr, err := net.ResolveUDPAddr("udp", address)
        if err != nil {
            log.Println(err)
            return nil
        }
        network := "udp"
        switch {
        case udpAddr.IP.To4() != nil:
            network = "udp4"
        case udpAddr.IP.To16() != nil:
            network = "udp6"
        }
        openConn, err := net.ListenUDP(network, udpAddr)
        if err != nil {
            log.Println(err)
            return nil
        }
        log.Println("开放地址:", address, "对方index:", index, "版本:", version)
        server = &Server{
            address:      address,
            index:        index,
            version:      version,
            createTime:   time.Now(),
            openConn:     openConn,
            openAddrsStr: openAddrsStr,
            clientMap:    map[string]*Client{},
        }
        serverMap[port] = server
        go server.accept(port, listenerConn)
    } else {
        // The port is already open: the request must describe the same mapping.
        switch {
        case server.address != address:
            log.Println("端口重复", server.address, address)
            return nil
        case server.openAddrsStr != openAddrsStr:
            log.Println("openAddrsStr不同", server.openAddrsStr, openAddrsStr)
            return nil
        case server.version != version:
            log.Println("版本不同", server.version, version)
            return nil
        }
    }

    server.clientAddr = clientAddr
    server.openConn.SetReadDeadline(time.Now().Add(timeOut))
    return server
}

// Server describes one opened port on the relay together with the external
// peers currently talking to it.
type Server struct {
    // openAddrsStr is the "-"-joined list of addresses from the auth payload
    // that created this server; getServer compares it on later requests.
    openAddrsStr string
    // index is the position of address within that list; accept sends it to
    // the NAT client as byte(10 + index).
    index        int
    // version is the auth payload layout (1 or 2) detected by handleNatUAuth.
    version      int
    // createTime is when the port was opened; used by toString for the uptime.
    createTime   time.Time
    // address is the open address string exactly as requested.
    address      string
    // openConn is the UDP socket bound to address.
    openConn     *net.UDPConn
    // clientAddr is the NAT client's address, refreshed by getServer on every
    // accepted request.
    clientAddr   *net.UDPAddr
    // clientMap holds the external peers, keyed by their address string.
    clientMap    map[string]*Client
    // clientMutex is locked around clientMap accesses in accept and handleRealData.
    clientMutex  sync.Mutex
}

// Client is one external peer that sent traffic to an opened port.
type Client struct {
    // index identifies the peer in the 2-byte header field exchanged with the
    // NAT client.
    index      int
    // lastLime is the time of the peer's most recent traffic in either
    // direction (the field name is presumably a typo of "lastTime").
    lastLime   time.Time
    // clientAddr is the peer's UDP address.
    clientAddr *net.UDPAddr
    // openConn is the opened-port socket used to send data back to the peer.
    openConn   *net.UDPConn
}

// accept is the receive loop of one opened port.
//
// Every datagram read from server.openConn is attributed to an external peer
// (created on first sight), prefixed with a 3-byte header
// {10 + server.index, peer index high byte, peer index low byte} and sent to
// the NAT client at server.clientAddr through listenerConn, the dispatcher
// socket.
//
// The loop ends on the first read error, which includes expiry of the read
// deadline that getServer sets; the socket is then closed and the entry for
// port is removed from serverMap.
func (server *Server) accept(port string, listenerConn *net.UDPConn) {
    buffer := make([]byte, 1024*64)
    for {
        n, clientAddr, err := server.openConn.ReadFromUDP(buffer)
        if err != nil {
            log.Println(err)
            break
        }
        if n > maxLength {
            maxLength = n
        }
        if n < 1 {
            log.Println("空消息:", clientAddr, port)
            continue
        }
        // Copy the payload out of the shared read buffer.
        data := make([]byte, n)
        copy(data, buffer)
        // debug: log the forwarded message (size, clientAddr, "=>", server.index)

        server.clientMutex.Lock()
        client := server.clientMap[clientAddr.String()]
        if client == nil {
            // Unknown peer: take the next value of the package-level counter and
            // retry with a fresh value whenever it is already used in this map.
            // NOTE(review): clientIndex is shared by all accept goroutines but is
            // incremented under this server's clientMutex only — confirm whether
            // concurrent increments are a concern.
            // NOTE(review): the index is sent below in two bytes, and nothing here
            // bounds or wraps clientIndex, so values above 65535 would be truncated
            // on the wire.
        getClient:
            for {
                clientIndex++
                index := clientIndex
                for key, value := range server.clientMap {
                    if value.index == index {
                        log.Println("index无效:", index)
                        continue getClient
                    }
                    // Expired peers are only purged here, i.e. while a new peer is
                    // being registered.
                    if time.Since(value.lastLime) >= timeOut {
                        delete(server.clientMap, key)
                    }
                }
                client = &Client{index: index, clientAddr: clientAddr, openConn: server.openConn}
                server.clientMap[clientAddr.String()] = client
                break
            }
        }
        client.lastLime = time.Now()
        server.clientMutex.Unlock()
        // Header: address slot (offset by 10 so it never collides with the status
        // bytes 1..9), then the peer index as two bytes, high byte first.
        listenerConn.SetWriteDeadline(time.Now().Add(time.Second * 3))
        listenerConn.WriteToUDP(append([]byte{byte(10 + server.index), byte(client.index / 256), byte(client.index % 256)}, data...), server.clientAddr)
    }
    // Release the port and forget the server.
    server.openConn.Close()
    mapMutex.Lock()
    delete(serverMap, port)
    mapMutex.Unlock()
    log.Println("释放端口:", port)
}

// toString renders a one-line summary of the server for the status log: the
// open address, the NAT client's address, the slot index, the payload version
// and the time since the port was opened as days/hours/minutes/seconds.
func (server *Server) toString() string {
    totalSeconds := time.Since(server.createTime).Milliseconds() / 1000
    totalMinutes := totalSeconds / 60
    totalHours := totalMinutes / 60
    uptime := fmt.Sprintf("%d天%d时%d分%d秒",
        totalHours/24, totalHours%24, totalMinutes%60, totalSeconds%60)
    return fmt.Sprintf("%s=>%s, index: %d, version: %d, %s",
        server.address, server.clientAddr, server.index, server.version, uptime)
}

// loadPrivateKey parses a PEM-encoded PKCS#8 private key and returns it as an
// RSA key.
//
// An error is returned when privateKeyStr contains no PEM block, when the
// block is not a valid PKCS#8 key, or when the key is not an RSA key.
func loadPrivateKey(privateKeyStr string) (privateKey *rsa.PrivateKey, err error) {
    pemBlock, _ := pem.Decode([]byte(privateKeyStr))
    if pemBlock == nil {
        return nil, fmt.Errorf("解码私钥失败")
    }
    parsed, err := x509.ParsePKCS8PrivateKey(pemBlock.Bytes)
    if err != nil {
        return nil, err
    }
    if rsaKey, isRSA := parsed.(*rsa.PrivateKey); isRSA {
        return rsaKey, nil
    }
    return nil, fmt.Errorf("非法私钥文件")
}

客户端:

package main

import (
    "crypto/rand"
    "crypto/rsa"
    "crypto/sha1"
    "crypto/x509"
    "encoding/pem"
    "errors"
    "fmt"
    "log"
    "net"
    "os"
    "strings"
    "sync"
    "time"
)

// Package-level configuration and shared state of the NAT client.
var (
    // publicKeyStr is the PEM-encoded RSA public key used by startClient to
    // encrypt the auth payload.
    publicKeyStr = "-----BEGIN PUBLIC KEY-----\nMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQCWAX9+7stLFV8sW2zA470M8b/5\nHt1FgkpGIVHfHvjIxh3k/APVfWlXpoN6lKIDQ/z4LZc+m03faeR/qjgl562W0sHQ\nDezv/cd84Uc2hDh/vTifL6RfNA7mrW3aqiVxT4gzvp327nzck/J/mzfVFyEgFb+z\nWsvr0xMkg+NNXMww8wIDAQAB\n-----END PUBLIC KEY-----\n"

    // natDispatcherAddr is the dispatcher's UDP address; main overrides it from
    // the single command-line argument and Start from its parameter.
    natDispatcherAddr = "127.0.0.1:8989"
    // natPassword is sent as the first field of the auth payload.
    natPassword       = "yzh"
    // rateInterval is the pause between two heartbeats.
    rateInterval      = time.Second * 15
    // timeOut is the read deadline applied to the dispatcher socket and to each
    // forwarding socket.
    timeOut           = time.Second * 60 * 2
    // natMapArr lists the mappings as "<address to open on the server>-<local
    // service address>".
    natMapArr         = []string{
        ":1701-192.168.3.25:1701",
        ":11771-127.0.0.1:1771",
        ":11772-127.0.0.1:1772",
        ":11773-127.0.0.1:1773",
    }
    // errMap translates the status bytes 2..9 received from the dispatcher into
    // messages; creatClient falls back to "config error" for empty entries.
    errMap = map[int]string{
        2: "config error or port used",
        3: "DecryptPKCS1v15 err",
        4: "infos error",
        5: "密码错误",
        6: "getServer nil",
        7: "",
        8: "",
        9: "",
    }
    // listenerConn is the socket used to talk to the dispatcher; it is nil while
    // no session is running.
    listenerConn *net.UDPConn
    // serverAddr is the resolved dispatcher address; nil while no session is running.
    serverAddr   *net.UDPAddr
    // mapMutex is locked around forwardMap accesses in handleClientRequest.
    mapMutex     sync.Mutex
    // forwardMap holds one socket to a local service per external peer index.
    forwardMap   = map[int]*net.UDPConn{}
    // natSuccess becomes true once the dispatcher has answered a heartbeat with 1.
    natSuccess   = false
    // maxLength records the largest datagram size received so far.
    maxLength    = 0
)

// main logs the command-line arguments, takes the dispatcher address from the
// argument when exactly one is given, switches the logger to include file/line
// information and then runs the client.
func main() {
    args := os.Args[1:]
    log.Println("入参: " + strings.Join(args, " "))
    if len(args) == 1 {
        natDispatcherAddr = args[0]
    }
    log.SetFlags(log.LstdFlags | log.Lshortfile)
    startClient()
}

// Start is the programmatic entry point: it replaces the dispatcher address,
// the password and the mapping list, then runs startClient.
func Start(natDispatcherAddrStr string, natPasswordStr string, natMapStr []string) {
    natDispatcherAddr, natPassword, natMapArr = natDispatcherAddrStr, natPasswordStr, natMapStr
    startClient()
}

// startClient runs the NAT client and never returns.
//
// It builds the auth payload "<password>-<open addr 1>-<open addr 2>-..." from
// natMapArr, encrypts it once with RSA-OAEP (SHA-1) and then runs three loops:
//
//   - a session goroutine that calls creatClient again, one second after each
//     time it returns;
//   - a heartbeat goroutine that sends {1, payload...} to the dispatcher every
//     rateInterval while a session socket exists;
//   - the calling goroutine, which logs the configured mappings once a minute.
//
// It panics when the public key cannot be loaded or the payload cannot be
// encrypted.
func startClient() {
    log.Println("NatU分发服务地址: " + natDispatcherAddr)
    publicKey, err := loadPublicKey(publicKeyStr)
    if err != nil {
        log.Panicln(err)
    }

    // The position of a mapping in natMapArr is the slot index the dispatcher
    // later uses to say which local service a packet is meant for.
    natParams := []string{natPassword}
    localServerAddrs := []string{}
    for _, mapping := range natMapArr {
        parts := strings.Split(mapping, "-")
        openAddr, localAddr := parts[0], parts[1]
        log.Println("NatU服务器开放地址: "+openAddr, "本地服务地址: "+localAddr)
        natParams = append(natParams, openAddr)
        localServerAddrs = append(localServerAddrs, localAddr)
    }

    // Payload: password followed by the addresses to open.
    payload := strings.Join(natParams, "-")
    natInfo, err := rsa.EncryptOAEP(sha1.New(), rand.Reader, publicKey, []byte(payload), nil)
    if err != nil {
        log.Panicln(err)
    }

    // Session loop: re-establish the session whenever it ends.
    go func() {
        for {
            log.Println("startNat")
            creatClient(localServerAddrs)
            listenerConn = nil
            serverAddr = nil
            natSuccess = false
            time.Sleep(time.Second)
        }
    }()

    // Heartbeat loop: poll every second until a session exists, then send one
    // heartbeat per rateInterval.
    go func() {
        for {
            pause := time.Second
            if listenerConn != nil && serverAddr != nil {
                heartbeat := append([]byte{1}, natInfo...)
                listenerConn.SetWriteDeadline(time.Now().Add(time.Second * 5))
                listenerConn.WriteToUDP(heartbeat, serverAddr)
                pause = rateInterval
            }
            time.Sleep(pause)
        }
    }()

    // Status loop.
    for {
        time.Sleep(time.Second * 60)
        println()
        log.Println("natU转发Size:", len(natMapArr), ",maxLength:", maxLength, ",natSuccess:", natSuccess)
        for _, mapping := range natMapArr {
            log.Println("natU转发中:", strings.ReplaceAll(mapping, "-", "=>"))
        }
    }
}

// creatClient runs one session with the dispatcher.
//
// It resolves natDispatcherAddr into serverAddr, binds an ephemeral UDP socket
// into listenerConn and then processes datagrams until a read fails (including
// the timeOut read deadline). The first byte of each datagram is the command:
//
//   - 1: heartbeat acknowledged; natSuccess is set;
//   - 2..9: error status from the dispatcher; the process panics with the
//     message from errMap;
//   - >= 10: forwarded traffic. cmd-10 selects the entry of localServerAddrs,
//     bytes 1-2 are the big-endian external peer index and the rest is the
//     payload handed to handleClientRequest.
//
// Datagrams whose source port differs from the dispatcher's are dropped, as
// are forwarded datagrams that are too short for the 3-byte header or that
// name a slot outside localServerAddrs.
func creatClient(localServerAddrs []string) {
    var err error
    serverAddr, err = net.ResolveUDPAddr("udp", natDispatcherAddr)
    if err != nil {
        log.Println(err)
        return
    }
    listenerAddr, err := net.ResolveUDPAddr("udp", ":0")
    if err != nil {
        log.Println(err)
        return
    }
    listenerConn, err = net.ListenUDP("udp", listenerAddr)
    if err != nil {
        log.Println(err)
        return
    }
    defer listenerConn.Close()
    buffer := make([]byte, 1024*64)
    for {
        // The heartbeat replies keep this deadline from expiring; silence for
        // timeOut ends the session.
        listenerConn.SetReadDeadline(time.Now().Add(timeOut))
        n, clientAddr, err := listenerConn.ReadFromUDP(buffer)
        if err != nil {
            log.Println(err)
            break
        }
        if n > maxLength {
            maxLength = n
        }
        if n < 1 {
            log.Println("异常消息:", clientAddr)
            continue
        }
        // Copy out of the shared read buffer before handing the bytes on.
        data := make([]byte, n)
        copy(data, buffer)

        // Only the port is compared: the dispatcher may answer from a different
        // local IP than the one the heartbeat was sent to.
        if clientAddr.Port != serverAddr.Port {
            log.Println("异常消息:", serverAddr, clientAddr)
            continue
        }
        cmd := data[0]
        switch {
        case cmd == 1:
            if !natSuccess {
                natSuccess = true
                log.Println("natU建立成功")
            }
        case cmd > 1 && cmd < 10:
            errMsg := errMap[int(cmd)]
            if errMsg == "" {
                errMsg = "config error"
            }
            log.Panicln(cmd, errMsg, natMapArr)
        case cmd >= 10:
            // Validate the header before indexing; a short packet or an unknown
            // slot used to panic with an index-out-of-range error.
            slot := int(cmd) - 10
            if n < 3 || slot >= len(localServerAddrs) {
                log.Println("异常消息:", clientAddr, n, slot)
                continue
            }
            forwardAddress := localServerAddrs[slot]
            clinetIndex := int(data[1])*256 + int(data[2])
            handleClientRequest(clientAddr, data[3:], forwardAddress, clinetIndex)
        }
    }
}

// handleClientRequest forwards one payload from an external peer to the local
// service at forwardAddress.
//
// clientAddr is the dispatcher address the payload arrived from (replies are
// sent back to it), clientData is the payload without the 3-byte header and
// clinetIndex is the external peer's index. One UDP socket per clinetIndex is
// kept in forwardMap; it is dialed on first use together with a goroutine that
// relays the local service's replies back to the dispatcher, prefixed with
// {0, index high byte, index low byte}.
//
// mapMutex is held for the whole call, including address resolution, dialing
// and the final write.
func handleClientRequest(clientAddr *net.UDPAddr, clientData []byte, forwardAddress string, clinetIndex int) {
    if clientAddr == nil {
        return
    }
    clientAddrString := clientAddr.String()
    mapMutex.Lock()
    defer mapMutex.Unlock()
    forwardConn := forwardMap[clinetIndex]
    if forwardConn == nil {
        forwardAddr, err := net.ResolveUDPAddr("udp", forwardAddress)
        if err != nil {
            log.Println(err)
            return
        }
        forwardConn, err = net.DialUDP("udp", nil, forwardAddr)
        if err != nil {
            log.Println(err)
            return
        }
        infoStr := clientAddrString + "=>" + forwardAddress + "=>" + forwardConn.LocalAddr().String() + "=>" + forwardConn.RemoteAddr().String()
        log.Println("添加udp转发:" + infoStr)
        forwardMap[clinetIndex] = forwardConn
        // Read buffer used only by the reply goroutine below.
        buffer := make([]byte, 1024*64)
        go func() {
            defer forwardConn.Close()
            forwardSuccess := false
            for {
                // The deadline is renewed before every read, so the goroutine ends
                // after timeOut without any reply from the local service.
                forwardConn.SetReadDeadline(time.Now().Add(timeOut))
                // Note: this serverAddr shadows the package-level variable; here it
                // is the address the reply came from.
                n, serverAddr, err := forwardConn.ReadFromUDP(buffer)
                if err != nil {
                    log.Println(err)
                    // A closed socket or a timeout ends the relay; any other error
                    // is retried after one second.
                    if errors.Is(err, net.ErrClosed) {
                        break
                    }
                    if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
                        break
                    }
                    time.Sleep(time.Second)
                    continue
                }
                // Replies are validated by source port only; the stricter check
                // that also compares the IP is left commented out below.
                // if serverAddr.Port != forwardAddr.Port || serverAddr.IP.String() != forwardAddr.IP.String() {
                if serverAddr.Port != forwardAddr.Port {
                    log.Println("异常消息:", serverAddr.String(), forwardAddr.String())
                    continue
                }
                if n > maxLength {
                    maxLength = n
                }
                if !forwardSuccess {
                    forwardSuccess = true
                    log.Println("udp转发成功:", serverAddr.String(), n, clientAddrString)
                }
                // debug: log the reply from the local service (serverAddr, n, clientAddrString)
                data := make([]byte, n)
                copy(data, buffer)
                // Replies go out through the package-level dispatcher socket as it is
                // at send time, addressed to the clientAddr captured at creation.
                listenerConn.SetWriteDeadline(time.Now().Add(time.Second * 5))
                listenerConn.WriteToUDP(append([]byte{0, byte(clinetIndex / 256), byte(clinetIndex % 256)}, data...), clientAddr)
            }
            log.Println("移除udp:" + infoStr)
            mapMutex.Lock()
            delete(forwardMap, clinetIndex)
            mapMutex.Unlock()
        }()
    }
    // debug: log the peer message (clientAddrString, len(clientData))
    forwardConn.SetWriteDeadline(time.Now().Add(time.Second * 5))
    forwardConn.Write(clientData)
}

// loadPublicKey parses a PEM-encoded PKIX public key and returns it as an RSA
// key.
//
// An error is returned when publicKeyStr contains no PEM block, when the block
// is not a valid PKIX key, or when the key is not an RSA key.
func loadPublicKey(publicKeyStr string) (publicKey *rsa.PublicKey, err error) {
    pemBlock, _ := pem.Decode([]byte(publicKeyStr))
    if pemBlock == nil {
        return nil, fmt.Errorf("解码公钥失败")
    }
    parsed, err := x509.ParsePKIXPublicKey(pemBlock.Bytes)
    if err != nil {
        return nil, err
    }
    if rsaKey, isRSA := parsed.(*rsa.PublicKey); isRSA {
        return rsaKey, nil
    }
    return nil, fmt.Errorf("非法公钥文件")
}


存在的问题:

假设服务端同时拥有fd00::5和fd00::6两个地址: 客户端fd00::2给fd00::5发消息, 服务端知道消息是fd00::2发来的, 但是不知道是自己的哪个ip接收的, 也无法控制使用哪个ip回消息. 测试中发现服务端可能会用fd00::6发消息给fd00::2, 在部分网络下这个消息是发送不过去的(这也是我没用打洞法的原因), 问题点就在这里. 解决方案也很简单: 分别监听每个ip. 但是这样需要监听设备ip的变化, 不想这样做. 不知道有没有大佬有更好的解决方案.

画个草图好理解些:

image.png
上一篇下一篇

猜你喜欢

热点阅读