udp转tcp再转udp工具

2024-08-09  本文已影响0人  今天i你好吗

相关文章: https://www.jianshu.com/p/44eded155258
前面分享了一个 UDP 端口转发工具,但在实际使用中经常出现数据发出去了却收不到回复的情况(应该是因为运营商的 IPv6 网络还不够完善)。为了解决这个问题,决定在通过 IPv6 传输数据的这一段改用 TCP 协议。

这两个工具必须配合使用,数据只能按 UDP→TCP→UDP 的路径转发:两端在 TCP 连接上都用 2 字节长度前缀(高字节在前)来标记每个 UDP 数据报的边界。下面直接上代码。

udp转tcp:

package main

import (
    "errors"
    "fmt"
    "io"
    "log"
    "net"
    "os"
    "strings"
    "sync"
    "time"
)

// Package-level state shared by every forwarding session.
var (
    // maxDTime is the longest gap, in seconds, observed between two
    // consecutive frames read back from a TCP connection. The first gap of a
    // session is measured from the moment its reader goroutine starts.
    maxDTime float64
    // timeOut is the deadline applied to each TCP read and write.
    timeOut = 2 * time.Minute
)

// main starts the UDP-to-TCP forwarder and then prints its status once a
// minute. With exactly two command-line arguments they are used as the listen
// address and the forward address; otherwise the built-in defaults apply.
func main() {
    println("入参监听地址和转发地址: " + strings.Join(os.Args[1:], " "))

    listenerAddress, forwardAddress := ":1701", "127.0.0.1:1701"
    if args := os.Args; len(args) == 3 {
        listenerAddress, forwardAddress = args[1], args[2]
    }
    log.SetFlags(log.LstdFlags | log.Lshortfile)

    udpForwardBean := Start(listenerAddress, forwardAddress)
    if udpForwardBean == nil {
        // Start has already logged why it could not set up the listener.
        return
    }
    for {
        println(udpForwardBean.String())
        println()
        time.Sleep(time.Minute)
    }
}

// Start opens a UDP listener on listenerAddress and, in a background
// goroutine, hands every received datagram to the TCP session of the sending
// client (created on demand by handleClientRequest). Each datagram is queued
// as one frame: a 2-byte big-endian length followed by the payload.
//
// It returns nil when the address cannot be resolved or the listener cannot
// be opened; the error is logged in that case.
func Start(listenerAddress string, forwardAddress string) *UdpForwardBean {
    log.Println(listenerAddress + "=>" + forwardAddress)

    listenerAddr, err := net.ResolveUDPAddr("udp", listenerAddress)
    if err != nil {
        log.Println(err)
        return nil
    }

    // Pin the address family when the listen address names a concrete IP;
    // a port-only address keeps the generic "udp" network.
    network := "udp"
    switch ip := listenerAddr.IP; {
    case ip.To4() != nil:
        network = "udp4"
    case ip.To16() != nil:
        network = "udp6"
    }

    listenerConn, err := net.ListenUDP(network, listenerAddr)
    if err != nil {
        log.Println(err)
        return nil
    }

    sessions := map[string]*ForwardBean{}
    udpForwardBean := &UdpForwardBean{
        isClosed:        false,
        listenerAddress: listenerAddress,
        forwardAddress:  forwardAddress,
        listenerConn:    listenerConn,
        forwardMap:      sessions,
    }
    packet := make([]byte, 1024*64)

    receiveLoop := func() {
        defer listenerConn.Close()
        for {
            n, clientAddr, err := listenerConn.ReadFromUDP(packet)
            if errors.Is(err, net.ErrClosed) {
                // The listener was closed: stop receiving for good.
                break
            }
            if err != nil {
                log.Println(err)
                time.Sleep(time.Second)
                continue
            }
            session := handleClientRequest(clientAddr, listenerConn, forwardAddress, sessions)
            if session == nil {
                continue
            }
            // The frame gets its own storage because packet is reused by
            // the next read.
            frame := make([]byte, 2, 2+n)
            frame[0], frame[1] = byte(n>>8), byte(n)
            frame = append(frame, packet[:n]...)
            // log.Println("client message:", n, clientAddr.String())
            session.bufferedCh <- frame
        }
        udpForwardBean.isClosed = true
    }
    go receiveLoop()

    return udpForwardBean
}

// stateMu guards forwardMap and maxDTime. Both are touched by the listener
// goroutine, by every per-client reader goroutine, and by String and Close,
// and Go maps must not be accessed concurrently without synchronization.
var stateMu sync.Mutex

// dialTimeout bounds the TCP connect made for a new client. The connect runs
// on the caller's goroutine, so an unreachable target must not be allowed to
// hold it up indefinitely.
const dialTimeout = 10 * time.Second

// handleClientRequest returns the forwarding session for clientAddr, creating
// it (TCP connection plus reader and writer goroutines) on first use.
//
// Frames exchanged on the TCP connection are a 2-byte big-endian length
// followed by that many payload bytes. Frames read from the TCP side are sent
// to clientAddr through listenerConn; frames queued on the returned bean's
// bufferedCh are written to the TCP side.
//
// It returns nil when clientAddr is nil or when the forward address cannot be
// resolved or connected to; the error is logged.
//
// The lookup and the later insert are separate critical sections. In this
// file the function is only called from the single listener goroutine, so no
// second session can be created for the same client in between.
func handleClientRequest(clientAddr *net.UDPAddr, listenerConn *net.UDPConn, forwardAddress string, forwardMap map[string]*ForwardBean) *ForwardBean {
    if clientAddr == nil {
        return nil
    }
    clientAddrString := clientAddr.String()

    stateMu.Lock()
    forwardBean := forwardMap[clientAddrString]
    stateMu.Unlock()
    if forwardBean != nil {
        return forwardBean
    }

    forwardAddr, err := net.ResolveTCPAddr("tcp", forwardAddress)
    if err != nil {
        log.Println(err)
        return nil
    }
    conn, err := net.DialTimeout("tcp", forwardAddr.String(), dialTimeout)
    if err != nil {
        log.Println(err)
        return nil
    }
    forwardConn, ok := conn.(*net.TCPConn)
    if !ok {
        conn.Close()
        log.Println("unexpected connection type for", forwardAddr.String())
        return nil
    }

    bufferedCh := make(chan []byte, 50)
    // done is closed by the reader goroutine once the session is removed.
    // bufferedCh itself is never closed: the sender lives outside this
    // function and could otherwise hit a send on a closed channel.
    done := make(chan struct{})
    infoStr := clientAddrString + "=>" + forwardAddress + "=>" + forwardConn.LocalAddr().String() + "=>" + forwardConn.RemoteAddr().String()
    log.Println("添加U2T转发:", infoStr)

    // Register before starting the goroutines so that a reader which fails
    // immediately cannot run its removal ahead of the insert and leave a
    // dead entry behind.
    bean := &ForwardBean{bufferedCh, forwardConn}
    stateMu.Lock()
    forwardMap[clientAddrString] = bean
    stateMu.Unlock()

    go func() {
        defer forwardConn.Close()
        relayFrames(forwardConn, listenerConn, clientAddr, clientAddrString)
        log.Println("移除U2T:" + infoStr)
        stateMu.Lock()
        // Only remove the entry if it is still this session.
        if forwardMap[clientAddrString] == bean {
            delete(forwardMap, clientAddrString)
        }
        stateMu.Unlock()
        close(done)
    }()
    go writeFrames(forwardConn, bufferedCh, done)

    return bean
}

// relayFrames reads length-prefixed frames from forwardConn and sends each
// payload to clientAddr through listenerConn. It returns when a read fails,
// which includes the read deadline (timeOut) expiring and the connection
// being closed.
func relayFrames(forwardConn *net.TCPConn, listenerConn *net.UDPConn, clientAddr *net.UDPAddr, clientAddrString string) {
    lengthBuffer := make([]byte, 2)
    // A 2-byte length never exceeds 0xFFFF, so one buffer serves every frame.
    payload := make([]byte, 0xFFFF)
    startTime := time.Now()
    forwardSuccess := false
    for {
        // One deadline covers the whole frame (header and payload).
        forwardConn.SetReadDeadline(time.Now().Add(timeOut))
        n, err := io.ReadFull(forwardConn, lengthBuffer)
        if err != nil {
            log.Println(n, err)
            return
        }
        length := int(lengthBuffer[0])<<8 | int(lengthBuffer[1])
        data := payload[:length]
        n, err = io.ReadFull(forwardConn, data)
        if err != nil {
            log.Println(n, length, err)
            return
        }

        dTime := time.Since(startTime).Seconds()
        startTime = time.Now()
        stateMu.Lock()
        if dTime > maxDTime {
            maxDTime = dTime
        }
        stateMu.Unlock()

        if !forwardSuccess {
            forwardSuccess = true
            log.Println("U2T转发成功:", n, forwardConn.RemoteAddr().String()+"=>"+clientAddrString)
        }
        // log.Println("server message:", n, forwardConn.RemoteAddr().String()+"=>"+clientAddrString)

        // UDP delivery is best effort: a datagram that cannot be sent is
        // dropped, exactly as if it had been lost on the wire.
        listenerConn.WriteToUDP(data, clientAddr)
    }
}

// writeFrames writes the frames queued on bufferedCh to forwardConn until
// done is closed.
//
// A failed write may have put only part of a frame on the stream, after which
// the peer can no longer find frame boundaries. The connection is therefore
// closed on the first write error, which also makes the reader goroutine tear
// the session down. From then on frames are received and discarded so that a
// sender is never left blocked on a full channel.
func writeFrames(forwardConn *net.TCPConn, bufferedCh <-chan []byte, done <-chan struct{}) {
    failed := false
    for {
        select {
        case <-done:
            // Empty the queue before leaving so that a sender which looked
            // the session up just before its removal still finds room.
            for {
                select {
                case <-bufferedCh:
                default:
                    return
                }
            }
        case data := <-bufferedCh:
            if failed {
                continue
            }
            forwardConn.SetWriteDeadline(time.Now().Add(timeOut))
            if _, err := forwardConn.Write(data); err != nil {
                log.Println(err)
                failed = true
                forwardConn.Close()
            }
        }
    }
}

// ForwardBean is one client's forwarding session.
type ForwardBean struct {
    // bufferedCh queues complete frames (length prefix included) waiting to
    // be written to forwardConn.
    bufferedCh chan []byte
    // forwardConn is the TCP connection to the forward address.
    forwardConn *net.TCPConn
}

// UdpForwardBean describes a running UDP-to-TCP forwarder.
type UdpForwardBean struct {
    // isClosed is set once the listener's receive loop has ended.
    isClosed bool
    // listenerAddress and forwardAddress are the addresses passed to Start.
    listenerAddress string
    forwardAddress  string
    // listenerConn is the UDP socket clients send to.
    listenerConn *net.UDPConn
    // forwardMap maps a client's "ip:port" string to its session. Access it
    // only while holding stateMu.
    forwardMap map[string]*ForwardBean
}

// String returns a status report: the configured addresses, the closed flag,
// maxDTime, and one line per active session.
func (bean *UdpForwardBean) String() string {
    stateMu.Lock()
    defer stateMu.Unlock()

    var sessions strings.Builder
    for key, value := range bean.forwardMap {
        sessions.WriteString(",\n" + key + "=>" + value.forwardConn.LocalAddr().String() + "=>" + value.forwardConn.RemoteAddr().String())
    }
    return fmt.Sprintf("U2T转发中: %s %s %t %.2f %s", bean.listenerAddress, bean.forwardAddress, bean.isClosed, maxDTime, sessions.String())
}

// Close closes the UDP listener and the TCP connection of every session.
//
// Closing a session's connection makes its reader goroutine fail its read and
// remove the session. The queues are deliberately not closed here: the
// listener goroutine may still be sending on one.
func (bean *UdpForwardBean) Close() {
    bean.listenerConn.Close()

    stateMu.Lock()
    defer stateMu.Unlock()
    for _, v := range bean.forwardMap {
        v.forwardConn.Close()
    }
}

tcp转udp:

package main

import (
    "errors"
    "fmt"
    "io"
    "log"
    "net"
    "os"
    "strings"
    "sync"
    "time"
)

// Package-level state shared by every forwarding session.
var (
    // maxDTime is the longest gap, in seconds, observed between two
    // consecutive datagrams read from a UDP forward socket. The first gap of
    // a session is measured from the moment its reader goroutine starts.
    maxDTime float64
    // timeOut is the deadline applied to TCP reads and writes and to reads
    // from the UDP forward socket.
    timeOut = 2 * time.Minute
)

// main starts the TCP-to-UDP forwarder and then prints its status once a
// minute. With exactly two command-line arguments they are used as the listen
// address and the forward address; otherwise the built-in defaults apply.
func main() {
    println("入参监听地址和转发地址: " + strings.Join(os.Args[1:], " "))

    listenerAddress, forwardAddress := ":1701", "192.168.1.8:1701"
    if args := os.Args; len(args) == 3 {
        listenerAddress, forwardAddress = args[1], args[2]
    }
    log.SetFlags(log.LstdFlags | log.Lshortfile)

    tcpForwardBean := Start(listenerAddress, forwardAddress)
    if tcpForwardBean == nil {
        // Start has already logged why it could not set up the listener.
        return
    }
    for {
        println(tcpForwardBean.String())
        println()
        time.Sleep(time.Minute)
    }
}

// Start opens a TCP listener on listenerAddress and, in a background
// goroutine, accepts connections and runs forward for each of them in its own
// goroutine.
//
// It returns nil when the address cannot be resolved or the listener cannot
// be opened; the error is logged in that case.
func Start(listenerAddress string, forwardAddress string) *TcpForwardBean {
    log.Println(listenerAddress + "=>" + forwardAddress)

    listenerAddr, err := net.ResolveTCPAddr("tcp", listenerAddress)
    if err != nil {
        log.Println(err)
        return nil
    }

    // Pin the address family when the listen address names a concrete IP;
    // a port-only address keeps the generic "tcp" network.
    network := "tcp"
    switch ip := listenerAddr.IP; {
    case ip.To4() != nil:
        network = "tcp4"
    case ip.To16() != nil:
        network = "tcp6"
    }

    tcpListener, err := net.ListenTCP(network, listenerAddr)
    if err != nil {
        log.Println(err)
        return nil
    }

    connections := map[*net.TCPConn]*net.UDPConn{}
    tcpForwardBean := &TcpForwardBean{
        isClosed:        false,
        listenerAddress: listenerAddress,
        forwardAddress:  forwardAddress,
        tcpListener:     tcpListener,
        forwardMap:      connections,
    }

    acceptLoop := func() {
        defer tcpListener.Close()
        for {
            client, err := tcpListener.AcceptTCP()
            switch {
            case err == nil:
                go forward(client, forwardAddress, connections)
                continue
            case errors.Is(err, net.ErrClosed):
                // The listener was closed: stop accepting for good.
                tcpForwardBean.isClosed = true
                return
            }
            // Any other accept error: log it and retry after a pause.
            log.Println(err)
            time.Sleep(time.Second)
        }
    }
    go acceptLoop()

    return tcpForwardBean
}

// stateMu guards forwardMap and maxDTime. forward runs in one goroutine per
// accepted connection, and String and Close run on other goroutines, while Go
// maps must not be accessed concurrently without synchronization.
var stateMu sync.Mutex

// forward relays one accepted TCP connection to a dedicated UDP socket
// connected to forwardAddress, and relays that socket's replies back.
//
// Frames on the TCP connection are a 2-byte big-endian length followed by
// that many payload bytes; every frame corresponds to one UDP datagram.
// forward blocks until the TCP side fails (error, EOF, or the timeOut read
// deadline expiring) and closes both connections before it returns. A nil
// client is ignored.
func forward(client *net.TCPConn, forwardAddress string, forwardMap map[*net.TCPConn]*net.UDPConn) {
    if client == nil {
        return
    }
    // log.Println("client address:", client.RemoteAddr(), "server address:", client.LocalAddr())
    defer client.Close()

    forwardAddr, err := net.ResolveUDPAddr("udp", forwardAddress)
    if err != nil {
        log.Println(err)
        return
    }
    forwardConn, err := net.DialUDP("udp", nil, forwardAddr)
    if err != nil {
        log.Println(err)
        return
    }
    defer forwardConn.Close()

    stateMu.Lock()
    forwardMap[client] = forwardConn
    stateMu.Unlock()
    defer func() {
        stateMu.Lock()
        delete(forwardMap, client)
        stateMu.Unlock()
    }()

    infoStr := client.RemoteAddr().String() + "=>" + client.LocalAddr().String() + "=>" + forwardConn.LocalAddr().String() + "=>" + forwardConn.RemoteAddr().String()
    log.Println("添加T2U转发:", infoStr)

    // UDP -> TCP direction. When it ends it closes both connections, which
    // makes the read loop below fail and return.
    go func() {
        defer forwardConn.Close()
        defer client.Close()
        relayDatagrams(client, forwardConn)
        log.Println("移除T2U:", infoStr)
    }()

    // TCP -> UDP direction.
    lengthBuffer := make([]byte, 2)
    // A 2-byte length never exceeds 0xFFFF, so one buffer serves every frame.
    payload := make([]byte, 0xFFFF)
    for {
        // One deadline covers the whole frame (header and payload).
        client.SetReadDeadline(time.Now().Add(timeOut))
        n, err := io.ReadFull(client, lengthBuffer)
        if err != nil {
            log.Println(n, err)
            return
        }
        length := int(lengthBuffer[0])<<8 | int(lengthBuffer[1])
        data := payload[:length]
        n, err = io.ReadFull(client, data)
        if err != nil {
            log.Println(n, length, err)
            return
        }
        // log.Println("client message:", n, client.RemoteAddr().String())

        // UDP delivery is best effort: a datagram that cannot be sent is
        // dropped, exactly as if it had been lost on the wire.
        forwardConn.Write(data)
    }
}

// relayDatagrams reads datagrams from forwardConn and queues each one, framed
// with its 2-byte big-endian length, for writing to client. It returns when
// forwardConn is closed or when no datagram arrives within timeOut.
func relayDatagrams(client *net.TCPConn, forwardConn *net.UDPConn) {
    bufferedCh := make(chan []byte, 50)
    // This goroutine is the only sender, so it is the one that closes the
    // queue; that also ends writeFrames.
    defer close(bufferedCh)
    go writeFrames(client, bufferedCh)

    // Check the sender against the socket's connected peer rather than the
    // resolved forward address: a port-only address such as ":1701" resolves
    // to a nil IP, which no datagram's source would ever match.
    peer, _ := forwardConn.RemoteAddr().(*net.UDPAddr)

    buffer := make([]byte, 1024*64)
    startTime := time.Now()
    forwardSuccess := false
    for {
        forwardConn.SetReadDeadline(time.Now().Add(timeOut))
        n, serverAddr, err := forwardConn.ReadFromUDP(buffer)
        if err != nil {
            if errors.Is(err, net.ErrClosed) {
                return
            }
            var nerr net.Error
            if errors.As(err, &nerr) && nerr.Timeout() {
                return
            }
            log.Println(err)
            time.Sleep(time.Second)
            continue
        }
        if peer != nil && (serverAddr.Port != peer.Port || !serverAddr.IP.Equal(peer.IP)) {
            log.Println("异常消息:", serverAddr.String(), peer.String())
            continue
        }

        dTime := time.Since(startTime).Seconds()
        startTime = time.Now()
        stateMu.Lock()
        if dTime > maxDTime {
            maxDTime = dTime
        }
        stateMu.Unlock()

        if !forwardSuccess {
            forwardSuccess = true
            log.Println("T2U转发成功:", n, serverAddr.String()+"=>"+client.RemoteAddr().String())
        }
        // log.Println("server message:", n, serverAddr.String()+"=>"+client.RemoteAddr().String())

        // The frame gets its own storage because buffer is reused by the
        // next read.
        frame := make([]byte, 2, 2+n)
        frame[0], frame[1] = byte(n>>8), byte(n)
        bufferedCh <- append(frame, buffer[:n]...)
    }
}

// writeFrames writes the frames queued on bufferedCh to client until the
// channel is closed.
//
// A failed write may have put only part of a frame on the stream, after which
// the peer can no longer find frame boundaries. The connection is therefore
// closed on the first write error, which makes forward's read loop fail and
// tear the session down. From then on frames are received and discarded so
// that the sender is never left blocked on a full channel.
func writeFrames(client *net.TCPConn, bufferedCh <-chan []byte) {
    failed := false
    for data := range bufferedCh {
        if failed {
            continue
        }
        client.SetWriteDeadline(time.Now().Add(timeOut))
        if _, err := client.Write(data); err != nil {
            log.Println(err)
            failed = true
            client.Close()
        }
    }
}

// TcpForwardBean describes a running TCP-to-UDP forwarder.
type TcpForwardBean struct {
    // isClosed is set once the listener's accept loop has ended.
    isClosed bool
    // listenerAddress and forwardAddress are the addresses passed to Start.
    listenerAddress string
    forwardAddress  string
    // tcpListener accepts the incoming TCP connections.
    tcpListener *net.TCPListener
    // forwardMap maps each accepted TCP connection to its UDP forward
    // socket. Access it only while holding stateMu.
    forwardMap map[*net.TCPConn]*net.UDPConn
}

// String returns a status report: the configured addresses, the closed flag,
// maxDTime, and one line per active connection pair.
func (bean *TcpForwardBean) String() string {
    stateMu.Lock()
    defer stateMu.Unlock()

    var sessions strings.Builder
    for key, value := range bean.forwardMap {
        sessions.WriteString(",\n" + key.RemoteAddr().String() + "=>" + key.LocalAddr().String() + "=>" + value.LocalAddr().String() + "=>" + value.RemoteAddr().String())
    }
    return fmt.Sprintf("T2U转发中: %s %s %t %.2f %s", bean.listenerAddress, bean.forwardAddress, bean.isClosed, maxDTime, sessions.String())
}

// Close closes the TCP listener and both connections of every active pair.
// Each forward call then sees its read fail and removes its own map entry.
func (bean *TcpForwardBean) Close() {
    bean.tcpListener.Close()

    stateMu.Lock()
    defer stateMu.Unlock()
    for k, v := range bean.forwardMap {
        k.Close()
        v.Close()
    }
}

上一篇下一篇

猜你喜欢

热点阅读