go语言实现内网穿透

2024-02-13  本文已影响0人  今天i你好吗

相关内容

node.js实现内网穿透: https://www.jianshu.com/p/d2d4f8bff599
kotlin实现内网穿透: https://www.jianshu.com/p/c8dc095c758e
可以与 node.js 版、kotlin 版混用。
使用方式见 node.js 版，大同小异，部分逻辑稍有调整。

实现代码

服务端:

package main

import (
    "crypto/rsa"
    "crypto/sha1"
    "crypto/x509"
    "encoding/pem"
    "fmt"
    "io"
    "log"
    "net"
    "os"
    "strings"
    "sync"
    "time"
)

// Dispatcher configuration and shared state.
var (
    // privateKeyStr is the PEM-encoded, unencrypted PKCS#8 RSA private key that
    // handleClientRequest uses (via loadPrivateKey) to decrypt the request
    // header sent by NAT clients.
    // NOTE(review): this key is embedded in the source and published with it,
    // so anyone holding the source can decrypt or forge requests. Generate a
    // fresh key pair before deploying.
    privateKeyStr = "-----BEGIN PRIVATE KEY-----\n" +
        "MIICdQIBADANBgkqhkiG9w0BAQEFAASCAl8wggJbAgEAAoGBAJYBf37uy0sVXyxb\n" +
        "bMDjvQzxv/ke3UWCSkYhUd8e+MjGHeT8A9V9aVemg3qUogND/Pgtlz6bTd9p5H+q\n" +
        "OCXnrZbSwdAN7O/9x3zhRzaEOH+9OJ8vpF80DuatbdqqJXFPiDO+nfbufNyT8n+b\n" +
        "N9UXISAVv7Nay+vTEySD401czDDzAgMBAAECgYAE9u26TLrrtDxfInN5+s+R8xpQ\n" +
        "a2YVW9eLdKTaBpNjSbNJldGmqiznWrp1PyARjZl8uT2NM+Si5UVLuF19W6qSC1Tw\n" +
        "bn2brBSsFsufKQff0XXRsD8OqKT5h5/PlRATYgisZonD/v0SPfDHROcNCFiOelEv\n" +
        "kKZAJgJS3vghZx5jCQJBAMTb4esPOApK+6wXlfVShvRiac9UWW9KBYLwpqya5Cjd\n" +
        "X/QUGVwywMFkNgRnBz7yWdJd1zuXfuI77N87BXYOE68CQQDDEjfaKCyZJ73RfNJo\n" +
        "ED5DRfZXm86RPBDlZznJ4shjNVbk1sGClYAk0WuAHeIZJVpm1HpME6NSCfumqeOq\n" +
        "z1P9AkAQu+xJcgK+hT89ksexkfFc5ty9vhrYJf+v8MsKUyRgAOl+MxMwzjOqfN1G\n" +
        "pIduJ2XRRx7btvYXPybUlwzQy0OLAkBIexlznt/LXH/kOcv4TKjF2FYLAWKEhlwE\n" +
        "0REg2Xn5mtUZnE40lhYSGBoodXIQQ9fOQ37Zi6ZwkjMGHzPvwK+FAkAe9gHRMI2u\n" +
        "lzwVp/AQntBMXmw92IULIlfRmfV1jDBYuT0JHUUGZqfCrz+iDW8Ot24QBzLbwKxJ\n" +
        "JRrmXbUxObmC\n" +
        "-----END PRIVATE KEY-----"
    // natPasswords lists the passwords a NAT client may present in its request.
    natPasswords = []string{"yzh"}
    // headSize is the number of bytes handleClientRequest reads as the
    // encrypted request header.
    // NOTE(review): presumably this must equal the RSA modulus size of
    // privateKeyStr (128 bytes for a 1024-bit key); confirm if the key changes.
    headSize     = 128
    // serverMap maps a listen port (as a string) to the Server owning it; see getServer.
    serverMap    = map[string]*Server{}
    // mutex is held by getServer while it reads and updates serverMap.
    mutex        sync.Mutex
    // timeOut is the read/write deadline used for handshakes and, in
    // switchData, the idle limit for relayed connections.
    timeOut      = time.Second * 15
)

// ConnPair links an idle NAT client connection with the public client
// connection it is eventually paired with.
type ConnPair struct {
    // serverClient is the connection from the NAT client (handled by
    // handleClientRequest).
    serverClient *net.Conn
    // client is the public connection; nil until startNat assigns it while
    // holding the owning Server's wMutex.
    client       *net.Conn
}

// Server owns one public listener opened on behalf of NAT clients (see
// getServer) and matches accepted public connections with idle NAT connections.
type Server struct {
    // emptyCount counts consecutive one-minute checks in accept that found no
    // idle NAT connection; accept closes the listener once it exceeds 5.
    emptyCount   int
    // address is the listen address exactly as requested by the NAT client.
    address      string
    // mutex is held by dispatcherNat, addNatServerClient and
    // removeNatServerClient while they modify natConnPairs and paddingConns.
    mutex        sync.Mutex
    // wMutex serializes startNat's assignment of ConnPair.client with the
    // heartbeat replies written in handleClientRequest.
    wMutex       sync.Mutex
    // natConnPairs holds idle NAT connections waiting for a public client,
    // oldest first.
    natConnPairs []*ConnPair
    // paddingConns holds public connections waiting for a NAT connection,
    // oldest first.
    paddingConns []*net.Conn
}

// accept runs the public listener of one Server, registered in serverMap under
// port. Every accepted connection is handed to dispatcherNat.
//
// A background monitor checks once a minute whether any idle NAT connection is
// queued. After more than five consecutive empty checks, it closes the
// listener. The loop also stops after more than five consecutive Accept errors.
//
// On exit, accept does three things:
//   - closes the listener;
//   - closes public connections still waiting for a NAT connection;
//   - removes port from serverMap, if the entry still points at this server,
//     so getServer can listen on it again later.
func (server *Server) accept(listener *net.Listener, port string) {
    ln := *listener
    idle := make(chan struct{}) // closed by the monitor right before it closes ln
    done := make(chan struct{}) // closed when this function returns

    defer func() {
        close(done)
        // Closing here also covers the Accept-error exit path. Otherwise the
        // port would stay bound and getServer could never listen on it again.
        ln.Close()

        // Nobody will ever serve these public clients; do not leak them.
        server.mutex.Lock()
        for _, waiting := range server.paddingConns {
            (*waiting).Close()
        }
        server.paddingConns = nil
        server.mutex.Unlock()

        // serverMap is shared with getServer, so update it under the global mutex.
        mutex.Lock()
        if serverMap[port] == server {
            delete(serverMap, port)
        }
        mutex.Unlock()
    }()

    // Idle monitor. natConnPairs and emptyCount are also modified by
    // addNatServerClient and friends, so they are only touched under server.mutex.
    go func() {
        ticker := time.NewTicker(time.Second * 60)
        defer ticker.Stop()
        for {
            select {
            case <-done:
                return
            case <-ticker.C:
            }
            server.mutex.Lock()
            if len(server.natConnPairs) == 0 {
                server.emptyCount++
            } else {
                server.emptyCount = 0
            }
            expired := server.emptyCount > 5
            server.mutex.Unlock()
            if expired {
                close(idle)
                ln.Close()
                return
            }
        }
    }()

    failures := 0
    for {
        client, err := ln.Accept()
        if err != nil {
            // An error caused by the idle monitor closing ln is terminal.
            select {
            case <-idle:
                log.Println("idle timeout, closing", server.address)
                return
            default:
            }
            log.Println(err)
            failures++
            if failures > 5 {
                return
            }
            time.Sleep(time.Second)
            continue
        }
        log.Println(client.LocalAddr(), client.RemoteAddr())
        server.dispatcherNat(&client)
        failures = 0
    }
}
// dispatcherNat routes a newly accepted public connection. If an idle NAT
// connection is queued, the oldest one is paired with it in a new startNat
// goroutine. Otherwise the public connection is parked in paddingConns until
// addNatServerClient supplies a NAT connection.
func (server *Server) dispatcherNat(client *net.Conn) {
    server.mutex.Lock()
    defer server.mutex.Unlock()
    if len(server.natConnPairs) == 0 {
        server.paddingConns = append(server.paddingConns, client)
        return
    }
    oldest := server.natConnPairs[0]
    server.natConnPairs = server.natConnPairs[1:]
    go startNat(oldest, client, &server.wMutex)
}
// addNatServerClient registers an idle NAT connection and resets the idle
// counter. If a public client is already waiting, the oldest one is paired
// with connPair right away in a new startNat goroutine. Otherwise connPair is
// queued for dispatcherNat.
func (server *Server) addNatServerClient(connPair *ConnPair) {
    server.mutex.Lock()
    defer server.mutex.Unlock()
    server.emptyCount = 0
    if len(server.paddingConns) == 0 {
        server.natConnPairs = append(server.natConnPairs, connPair)
        return
    }
    waiting := server.paddingConns[0]
    server.paddingConns = server.paddingConns[1:]
    go startNat(connPair, waiting, &server.wMutex)
}
// removeNatServerClient drops connPair from the idle NAT queue if it is still
// there. It is no longer queued once dispatcherNat has paired it. Only the
// first matching entry is removed.
func (server *Server) removeNatServerClient(connPair *ConnPair) {
    server.mutex.Lock()
    defer server.mutex.Unlock()
    pairs := server.natConnPairs
    for idx := 0; idx < len(pairs); idx++ {
        if pairs[idx] != connPair {
            continue
        }
        server.natConnPairs = append(pairs[:idx], pairs[idx+1:]...)
        return
    }
}
// startNat binds the public connection client to the idle NAT connection in
// connPair and tells the NAT client to start by writing the single byte 1.
// It then copies public client -> NAT client until that direction ends.
// Both connections are closed on return.
//
// wMutex is the owning Server's write mutex. Assigning connPair.client under
// it stops handleClientRequest from answering further heartbeats on this
// connection.
func startNat(connPair *ConnPair, client *net.Conn, wMutex *sync.Mutex) {
    wMutex.Lock()
    connPair.client = client
    wMutex.Unlock()

    natConn := *connPair.serverClient
    publicConn := *connPair.client
    defer natConn.Close()
    defer publicConn.Close()

    natConn.SetWriteDeadline(time.Now().Add(timeOut))
    if _, err := natConn.Write([]byte{1}); err != nil {
        log.Println("Write err", err)
        return
    }
    switchData(natConn, publicConn)
}

// main starts the NAT dispatcher. A single optional argument overrides the
// listen address (default ":8989"). Each incoming NAT client connection is
// served by handleClientRequest in its own goroutine. Startup failures panic.
func main() {
    log.Println("入参: " + strings.Join(os.Args[1:], " "))
    addr := ":8989"
    if len(os.Args) == 2 {
        addr = os.Args[1]
    }
    log.SetFlags(log.LstdFlags | log.Lshortfile)
    log.Println("Nat分发服务地址: " + addr)

    privateKey, err := loadPrivateKey(privateKeyStr)
    if err != nil {
        log.Panicln(err)
    }
    ln, err := net.Listen("tcp", addr)
    if err != nil {
        log.Panicln(err)
    }
    for {
        conn, acceptErr := ln.Accept()
        if acceptErr != nil {
            // Back off briefly instead of spinning on a persistent error.
            log.Println(acceptErr)
            time.Sleep(time.Second)
            continue
        }
        log.Println(conn.LocalAddr(), conn.RemoteAddr())
        go handleClientRequest(&conn, privateKey)
    }
}

func handleClientRequest(client *net.Conn, privateKey *rsa.PrivateKey) {
    serverClient := *client
    defer serverClient.Close()
    serverClient.SetReadDeadline(time.Now().Add(timeOut))
    cmdData := make([]byte, headSize)
    n, err := io.ReadFull(serverClient, cmdData)
    buf := make([]byte, 1)
    buf[0] = 2
    serverClient.SetWriteDeadline(time.Now().Add(timeOut))
    if err != nil {
        log.Println("ReadFull err", err)
        serverClient.Write(buf)
        return
    }
    if n != len(cmdData) {
        log.Println("读取长度错误:", n)
        serverClient.Write(buf)
        return
    }

    // decryptedText, err := rsa.DecryptPKCS1v15(nil, privateKey, cmdData)
    decryptedText, err := rsa.DecryptOAEP(sha1.New(), nil, privateKey, cmdData, nil)
    if err != nil {
        log.Println("DecryptPKCS1v15 err", err)
        serverClient.Write(buf)
        return
    }
    info := string(decryptedText)
    infos := strings.Split(info, "-")
    if len(infos) != 2 {
        log.Println("infos error", infos)
        serverClient.Write(buf)
        return
    }
    pwdOk := false
    for _, v := range natPasswords {
        if v == infos[1] {
            pwdOk = true
            break
        }
    }
    if !pwdOk {
        log.Println("密码错误", infos)
        serverClient.Write(buf)
        return
    }
    server := getServer(infos[0])
    if server == nil {
        log.Println("getServer nil")
        serverClient.Write(buf)
        return
    }
    connPair := &ConnPair{serverClient: &serverClient}
    server.addNatServerClient(connPair)

    n = 1
    for {
        if n == 1 {
            serverClient.SetReadDeadline(time.Now().Add(timeOut))
        }
        n, err = serverClient.Read(buf)
        if err != nil {
            server.removeNatServerClient(connPair)
            return
        }
        if n != 1 {
            continue
        }
        if buf[0] == 0 {
            break
        }
        func() {
            server.wMutex.Lock()
            defer server.wMutex.Unlock()
            if connPair.client != nil {
                log.Println("nating")
                return
            }
            serverClient.SetWriteDeadline(time.Now().Add(timeOut))
            buf[0] = 0
            _, err = serverClient.Write(buf)
            if err != nil {
                server.removeNatServerClient(connPair)
                return
            }
        }()
    }
    if connPair.client == nil {
        log.Println("nat client nil")
        return
    }
    defer (*connPair.client).Close()
    switchData(*connPair.client, *connPair.serverClient)
}

// getServer returns the Server for the listen address requested by a NAT
// client. On first use it creates the Server, starts listening and runs its
// accept loop. Servers are keyed by port. A request for a port already held
// under a different address string is rejected. Returns nil on any error.
func getServer(address string) *Server {
    // net.SplitHostPort also accepts bracketed IPv6 hosts such as "[::1]:9001",
    // which a plain split on ":" rejects. An empty port would make net.Listen
    // pick a random port the NAT client cannot know, so it is refused too.
    _, port, err := net.SplitHostPort(address)
    if err != nil || port == "" {
        log.Println("address error", address)
        return nil
    }
    mutex.Lock()
    defer mutex.Unlock()
    if existing := serverMap[port]; existing != nil {
        if existing.address != address {
            log.Println("端口重复", existing.address, address)
            return nil
        }
        return existing
    }
    listener, err := net.Listen("tcp", address)
    if err != nil {
        log.Println("listen error", err)
        return nil
    }
    server := &Server{address: address}
    serverMap[port] = server
    go server.accept(&listener, port)
    return server
}

// loadPrivateKey parses a PEM-encoded, unencrypted RSA private key.
//
// The DER body may be PKCS#8 (the "PRIVATE KEY" form embedded in this
// program) or PKCS#1 (the "RSA PRIVATE KEY" form). PKCS#8 is tried first. If
// both parsers fail, the PKCS#8 error is returned. A PKCS#8 key that is not
// RSA is rejected.
func loadPrivateKey(privateKeyStr string) (privateKey *rsa.PrivateKey, err error) {
    block, _ := pem.Decode([]byte(privateKeyStr))
    if block == nil {
        return nil, fmt.Errorf("解码私钥失败")
    }
    key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
    if err != nil {
        // Fall back to PKCS#1 so "RSA PRIVATE KEY" blocks work as well.
        if pkcs1Key, pkcs1Err := x509.ParsePKCS1PrivateKey(block.Bytes); pkcs1Err == nil {
            return pkcs1Key, nil
        }
        return nil, err
    }
    rsaKey, ok := key.(*rsa.PrivateKey)
    if !ok {
        return nil, fmt.Errorf("非法私钥文件")
    }
    return rsaKey, nil
}

// switchData copies bytes from src to dst. It stops when src reports EOF or
// an error, or when a write to dst fails or is short.
//
// Before the first read, and before every read that follows a successful
// transfer, it pushes the read deadlines of both src and dst timeOut into the
// future. Activity in this direction therefore also keeps reads in the
// opposite direction alive. If nothing flows for timeOut, the read fails with
// a timeout.
//
// It returns the number of bytes written and the terminating error. io.EOF
// from src is not reported. A non-nil error is also logged.
func switchData(dst net.Conn, src net.Conn) (written int64, err error) {
    chunk := make([]byte, 10240)
    refresh := true
    for {
        if refresh {
            src.SetReadDeadline(time.Now().Add(timeOut))
            dst.SetReadDeadline(time.Now().Add(timeOut))
            refresh = false
        }
        nr, readErr := src.Read(chunk)
        if nr > 0 {
            refresh = true
            dst.SetWriteDeadline(time.Now().Add(timeOut))
            nw, writeErr := dst.Write(chunk[:nr])
            // Guard against a misbehaving Writer, mirroring io.Copy.
            if nw < 0 || nw > nr {
                nw = 0
                if writeErr == nil {
                    writeErr = fmt.Errorf("invalid write result")
                }
            }
            written += int64(nw)
            switch {
            case writeErr != nil:
                err = writeErr
            case nw != nr:
                err = io.ErrShortWrite
            }
            if err != nil {
                break
            }
        }
        if readErr != nil {
            if readErr != io.EOF {
                err = readErr
            }
            break
        }
    }
    if err != nil {
        log.Println(err)
    }
    return written, err
}

客户端:

package main

import (
    "crypto/rand"
    "crypto/rsa"
    "crypto/sha1"
    "crypto/x509"
    "encoding/pem"
    "fmt"
    "io"
    "log"
    "net"
    "os"
    "strings"
    "sync"
    "time"
)

// Client configuration.
var (
    // publicKeyStr is the PEM-encoded (PKIX) RSA public key used to encrypt
    // the request header. It must match the dispatcher's private key.
    // NOTE(review): the matching private key is published with the dispatcher
    // source; generate a fresh pair before deploying.
    publicKeyStr = "-----BEGIN PUBLIC KEY-----\nMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQCWAX9+7stLFV8sW2zA470M8b/5\nHt1FgkpGIVHfHvjIxh3k/APVfWlXpoN6lKIDQ/z4LZc+m03faeR/qjgl562W0sHQ\nDezv/cd84Uc2hDh/vTifL6RfNA7mrW3aqiVxT4gzvp327nzck/J/mzfVFyEgFb+z\nWsvr0xMkg+NNXMww8wIDAQAB\n-----END PUBLIC KEY-----\n"

    // natDispatcherAddr is the address of the dispatcher (server) program.
    natDispatcherAddr = "192.168.10.18:8989"
    // natServerOpenAddr is the address the dispatcher is asked to listen on
    // for public clients. It is sent inside the encrypted request header.
    natServerOpenAddr = ":9001"
    // localServerAddr is the local service each paired tunnel is relayed to.
    localServerAddr   = "127.0.0.1:8080"
    // natPassword must be one of the dispatcher's accepted passwords.
    natPassword       = "yzh"
    // maxFreeNat is the number of tunnel workers main runs concurrently.
    maxFreeNat        = 3
    // rateInterval is the heartbeat interval while a tunnel waits to be paired.
    rateInterval      = time.Second * 5
    // timeOut is the dial timeout, the handshake read/write deadline, and the
    // idle limit for relayed connections.
    timeOut           = time.Second * 15
)

// main encrypts the tunnel request once, then runs maxFreeNat workers that
// each keep one idle tunnel open to the dispatcher. Worker 0 runs on the main
// goroutine, so main never returns. Startup failures panic.
func main() {
    log.Println("入参: " + strings.Join(os.Args[1:], " "))
    log.SetFlags(log.LstdFlags | log.Lshortfile)
    log.Println("Nat分发服务地址: " + natDispatcherAddr)
    log.Println("Nat服务器开放地址: " + natServerOpenAddr)
    log.Println("本地服务地址: " + localServerAddr)

    publicKey, err := loadPublicKey(publicKeyStr)
    if err != nil {
        log.Panicln(err)
    }
    // The same ciphertext is reused for every connection attempt.
    // NOTE(review): anyone who captures it can replay it to the dispatcher.
    plain := []byte(natServerOpenAddr + "-" + natPassword)
    natInfo, err := rsa.EncryptOAEP(sha1.New(), rand.Reader, publicKey, plain, nil)
    if err != nil {
        log.Panicln(err)
    }
    for worker := 1; worker < maxFreeNat; worker++ {
        go start(natInfo, worker)
    }
    start(natInfo, 0)
}
// start runs tunnel worker index forever. startNat returns when its
// connection fails, or right after it hands a paired connection to
// startServer. Either way a fresh idle tunnel is opened immediately.
func start(natInfo []byte, index int) {
    for {
        log.Println("startNatId", index)
        startNat(natInfo)
    }
}
// startNat opens one tunnel connection to the dispatcher and waits on it
// until the dispatcher pairs it with a public client.
//
// Protocol as implemented here:
//   - The encrypted natInfo header is sent first.
//   - While waiting, a background goroutine writes the heartbeat byte 1 every
//     rateInterval, and the dispatcher answers each with 0.
//   - A byte 1 from the dispatcher means "paired": the connection is handed
//     to startServer.
//   - A byte 2 means the request was rejected. This is treated as a fatal
//     configuration error (log.Panicln).
//
// On any other failure the connection is closed and startNat returns so the
// caller can retry.
func startNat(natInfo []byte) {
    server, err := net.DialTimeout("tcp", natDispatcherAddr, timeOut)
    if err != nil {
        log.Println(err)
        time.Sleep(time.Second)
        return
    }
    server.SetWriteDeadline(time.Now().Add(timeOut))
    n, err := server.Write(natInfo)
    if err != nil {
        log.Println(err)
        server.Close()
        return
    }
    if n != len(natInfo) {
        log.Println("写入长度错误", n, len(natInfo))
        server.Close()
        return
    }

    // mutex serializes writes on server between the heartbeat goroutine and
    // startServer, and guards nating.
    mutex := &sync.Mutex{}
    nating := false
    go func() {
        // Goroutine-local results: the read loop below assigns err
        // concurrently, so sharing the enclosing n/err would be a data race.
        heartbeat := []byte{1}
        written := 1
        for {
            time.Sleep(rateInterval)
            mutex.Lock()
            if nating {
                mutex.Unlock()
                return
            }
            if written == 1 {
                server.SetWriteDeadline(time.Now().Add(timeOut))
            }
            var werr error
            written, werr = server.Write(heartbeat)
            mutex.Unlock()
            if werr != nil {
                log.Println(werr)
                return
            }
        }
    }()

    // Wait for the dispatcher's start byte (1). A 0 is a heartbeat ack.
    buf := make([]byte, 1)
    resetDeadline := true
    for {
        if resetDeadline {
            server.SetReadDeadline(time.Now().Add(timeOut))
        }
        r, rerr := server.Read(buf)
        if rerr != nil {
            log.Println(rerr)
            server.Close()
            return
        }
        resetDeadline = r == 1
        if r != 1 {
            continue
        }
        if buf[0] == 1 {
            break
        }
        if buf[0] == 2 {
            log.Panicln("config error or port used", natServerOpenAddr)
        }
    }
    // Stop heartbeats before startServer writes its own start byte.
    mutex.Lock()
    nating = true
    mutex.Unlock()
    go startServer(&server)
}

// startServer finishes the handshake on a paired tunnel connection. It sends
// the start byte 0, dials the local service at localServerAddr, and relays
// data in both directions.
//
// Whichever direction finishes first closes both connections. The other
// direction, and the public client behind the dispatcher, then see EOF
// promptly instead of waiting for the idle timeout. The dispatcher side tears
// down its pair the same way.
func startServer(client *net.Conn) {
    tunnel := *client
    defer tunnel.Close()
    tunnel.SetWriteDeadline(time.Now().Add(timeOut))
    n, err := tunnel.Write([]byte{0})
    if err != nil {
        log.Println(err)
        return
    }
    if n != 1 {
        log.Println("写入开始命令错误")
        return
    }
    local, err := net.DialTimeout("tcp", localServerAddr, timeOut)
    if err != nil {
        log.Println(err)
        return
    }
    defer local.Close()
    go func() {
        switchData(tunnel, local)
        // The local service stopped sending. Close both sides so the copy
        // below returns and the dispatcher forwards EOF to the public client.
        tunnel.Close()
        local.Close()
    }()
    switchData(local, tunnel)
}

// loadPublicKey parses a PEM-encoded RSA public key.
//
// The DER body may be PKIX / SubjectPublicKeyInfo (the "PUBLIC KEY" form
// embedded in this program) or PKCS#1 (the "RSA PUBLIC KEY" form). PKIX is
// tried first. If both parsers fail, the PKIX error is returned. A PKIX key
// that is not RSA is rejected.
func loadPublicKey(publicKeyStr string) (publicKey *rsa.PublicKey, err error) {
    block, _ := pem.Decode([]byte(publicKeyStr))
    if block == nil {
        return nil, fmt.Errorf("解码公钥失败")
    }
    key, err := x509.ParsePKIXPublicKey(block.Bytes)
    if err != nil {
        // Fall back to PKCS#1 so "RSA PUBLIC KEY" blocks work as well.
        if pkcs1Key, pkcs1Err := x509.ParsePKCS1PublicKey(block.Bytes); pkcs1Err == nil {
            return pkcs1Key, nil
        }
        return nil, err
    }
    rsaKey, ok := key.(*rsa.PublicKey)
    if !ok {
        return nil, fmt.Errorf("非法公钥文件")
    }
    return rsaKey, nil
}

// switchData copies bytes from src to dst. It stops when src reports EOF or
// an error, or when a write to dst fails or is short.
//
// Before the first read, and before every read that follows a successful
// transfer, it pushes the read deadlines of both src and dst timeOut into the
// future. Activity in this direction therefore also keeps reads in the
// opposite direction alive. If nothing flows for timeOut, the read fails with
// a timeout.
//
// It returns the number of bytes written and the terminating error. io.EOF
// from src is not reported. A non-nil error is also logged.
func switchData(dst net.Conn, src net.Conn) (written int64, err error) {
    chunk := make([]byte, 10240)
    refresh := true
    for {
        if refresh {
            src.SetReadDeadline(time.Now().Add(timeOut))
            dst.SetReadDeadline(time.Now().Add(timeOut))
            refresh = false
        }
        nr, readErr := src.Read(chunk)
        if nr > 0 {
            refresh = true
            dst.SetWriteDeadline(time.Now().Add(timeOut))
            nw, writeErr := dst.Write(chunk[:nr])
            // Guard against a misbehaving Writer, mirroring io.Copy.
            if nw < 0 || nw > nr {
                nw = 0
                if writeErr == nil {
                    writeErr = fmt.Errorf("invalid write result")
                }
            }
            written += int64(nw)
            switch {
            case writeErr != nil:
                err = writeErr
            case nw != nr:
                err = io.ErrShortWrite
            }
            if err != nil {
                break
            }
        }
        if readErr != nil {
            if readErr != io.EOF {
                err = readErr
            }
            break
        }
    }
    if err != nil {
        log.Println(err)
    }
    return written, err
}

由于未找到读取加密私钥的较好方案，服务端改为使用未加密的私钥（PKCS#8 格式）。注意：文中的私钥与公钥已随代码公开，实际部署前请自行生成新的密钥对。

上一篇下一篇

猜你喜欢

热点阅读