541 lines
11 KiB
Go
541 lines
11 KiB
Go
package tunnel
|
||
|
||
import (
	"context"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"log"
	"net"
	"sync"
	"time"
)
|
||
|
||
// Wire-format parameters of the tunnel protocol.
const (
	// ProtocolVersion is the version byte every frame carries.
	ProtocolVersion = 0x01
	// HeaderSize is the fixed frame header length:
	// version (1 byte) + type (1 byte) + big-endian payload length (4 bytes).
	HeaderSize = 1 + 1 + 4
	// MaxPacketSize is the largest payload length accepted in a frame (1 MiB).
	MaxPacketSize = 1 << 20
)

// ReconnectDelay is the pause before the client dials the tunnel server again.
const ReconnectDelay = 5 * time.Second

// Message types, carried in byte 1 of the frame header.
const (
	MsgTypeConnectRequest  = 0x01 // server asks the client to open a local connection
	MsgTypeConnectResponse = 0x02 // client reports the outcome of a connect request
	MsgTypeData            = 0x03 // payload bytes for one connection
	MsgTypeClose           = 0x04 // a connection is being closed
	MsgTypeKeepAlive       = 0x05 // heartbeat
)

// Status byte carried in a MsgTypeConnectResponse payload.
const (
	ConnStatusSuccess = 0x00 // local connection established
	ConnStatusFailed  = 0x01 // local connection could not be established
)

// Timeouts and intervals.
const (
	ConnectTimeout    = 10 * time.Second // dial timeout
	ReadTimeout       = 30 * time.Second // per-read deadline on local connections
	KeepAliveInterval = 15 * time.Second // period between client heartbeats
)
|
||
|
||
// TunnelMessage is one frame of the tunnel protocol: a HeaderSize-byte header
// (Version, Type, Length) followed by Length bytes of Data.
type TunnelMessage struct {
	Version byte   // protocol version byte from the header
	Type    byte   // message kind (see the MsgType* constants)
	Length  uint32 // payload length as carried in the header
	Data    []byte // payload; nil when there is none
}
|
||
|
||
// LocalConnection tracks one proxied stream between the tunnel and a service
// reachable from this host.
type LocalConnection struct {
	ID         uint32   // connection ID taken from the server's connect request
	TargetAddr string   // address that was dialled for this stream
	Conn       net.Conn // socket to the local service

	closeChan chan struct{} // closed (exactly once, via closeOnce) when the stream shuts down
	closeOnce sync.Once
}
|
||
|
||
// Client 内网穿透客户端
|
||
type Client struct {
|
||
serverAddr string
|
||
serverConn net.Conn
|
||
cancel context.CancelFunc
|
||
ctx context.Context
|
||
wg sync.WaitGroup
|
||
mu sync.RWMutex
|
||
|
||
// 连接管理
|
||
connections map[uint32]*LocalConnection
|
||
connMu sync.RWMutex
|
||
|
||
// 消息队列
|
||
sendChan chan *TunnelMessage
|
||
}
|
||
|
||
// NewClient 创建新的隧道客户端
|
||
func NewClient(serverAddr string) *Client {
|
||
ctx, cancel := context.WithCancel(context.Background())
|
||
return &Client{
|
||
serverAddr: serverAddr,
|
||
cancel: cancel,
|
||
ctx: ctx,
|
||
connections: make(map[uint32]*LocalConnection),
|
||
sendChan: make(chan *TunnelMessage, 1000),
|
||
}
|
||
}
|
||
|
||
// Start 启动隧道客户端
|
||
func (c *Client) Start() error {
|
||
log.Printf("正在连接到隧道服务器: %s", c.serverAddr)
|
||
|
||
c.wg.Add(1)
|
||
go c.connectLoop()
|
||
|
||
return nil
|
||
}
|
||
|
||
// connectLoop 连接循环(支持自动重连)
|
||
func (c *Client) connectLoop() {
|
||
defer c.wg.Done()
|
||
|
||
for {
|
||
select {
|
||
case <-c.ctx.Done():
|
||
return
|
||
default:
|
||
}
|
||
|
||
conn, err := net.DialTimeout("tcp", c.serverAddr, 10*time.Second)
|
||
if err != nil {
|
||
log.Printf("连接隧道服务器失败: %v,%v 后重试", err, ReconnectDelay)
|
||
time.Sleep(ReconnectDelay)
|
||
continue
|
||
}
|
||
|
||
log.Printf("已连接到隧道服务器: %s", c.serverAddr)
|
||
|
||
c.mu.Lock()
|
||
c.serverConn = conn
|
||
c.mu.Unlock()
|
||
|
||
// 处理连接
|
||
var connWg sync.WaitGroup
|
||
connWg.Add(3)
|
||
go func() {
|
||
defer connWg.Done()
|
||
c.handleServerRead(conn)
|
||
}()
|
||
go func() {
|
||
defer connWg.Done()
|
||
c.handleServerWrite(conn)
|
||
}()
|
||
go func() {
|
||
defer connWg.Done()
|
||
c.keepAliveLoop(conn)
|
||
}()
|
||
|
||
// 等待连接断开
|
||
connWg.Wait()
|
||
|
||
c.mu.Lock()
|
||
c.serverConn = nil
|
||
c.mu.Unlock()
|
||
|
||
// 关闭所有本地连接
|
||
c.connMu.Lock()
|
||
for _, conn := range c.connections {
|
||
conn.closeOnce.Do(func() {
|
||
close(conn.closeChan)
|
||
})
|
||
if conn.Conn != nil {
|
||
conn.Conn.Close()
|
||
}
|
||
}
|
||
c.connections = make(map[uint32]*LocalConnection)
|
||
c.connMu.Unlock()
|
||
|
||
log.Printf("与服务器断开连接,%v 后重连", ReconnectDelay)
|
||
time.Sleep(ReconnectDelay)
|
||
}
|
||
}
|
||
|
||
// handleServerRead 处理服务器读取
|
||
func (c *Client) handleServerRead(conn net.Conn) {
|
||
defer conn.Close()
|
||
|
||
for {
|
||
select {
|
||
case <-c.ctx.Done():
|
||
return
|
||
default:
|
||
}
|
||
|
||
msg, err := c.readTunnelMessage(conn)
|
||
if err != nil {
|
||
if err != io.EOF {
|
||
log.Printf("读取隧道消息失败: %v", err)
|
||
}
|
||
return
|
||
}
|
||
|
||
c.handleTunnelMessage(msg)
|
||
}
|
||
}
|
||
|
||
// handleServerWrite 处理服务器写入
|
||
func (c *Client) handleServerWrite(conn net.Conn) {
|
||
for {
|
||
select {
|
||
case <-c.ctx.Done():
|
||
return
|
||
case msg := <-c.sendChan:
|
||
if err := c.writeTunnelMessage(conn, msg); err != nil {
|
||
log.Printf("写入隧道消息失败: %v", err)
|
||
return
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// readTunnelMessage 读取隧道消息
|
||
func (c *Client) readTunnelMessage(conn net.Conn) (*TunnelMessage, error) {
|
||
// 读取消息头
|
||
header := make([]byte, HeaderSize)
|
||
if _, err := io.ReadFull(conn, header); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
version := header[0]
|
||
msgType := header[1]
|
||
dataLen := binary.BigEndian.Uint32(header[2:6])
|
||
|
||
if version != ProtocolVersion {
|
||
return nil, fmt.Errorf("不支持的协议版本: %d", version)
|
||
}
|
||
|
||
if dataLen > MaxPacketSize {
|
||
return nil, fmt.Errorf("数据包过大: %d bytes", dataLen)
|
||
}
|
||
|
||
// 读取数据
|
||
var data []byte
|
||
if dataLen > 0 {
|
||
data = make([]byte, dataLen)
|
||
if _, err := io.ReadFull(conn, data); err != nil {
|
||
return nil, err
|
||
}
|
||
}
|
||
|
||
return &TunnelMessage{
|
||
Version: version,
|
||
Type: msgType,
|
||
Length: dataLen,
|
||
Data: data,
|
||
}, nil
|
||
}
|
||
|
||
// writeTunnelMessage 写入隧道消息
|
||
func (c *Client) writeTunnelMessage(conn net.Conn, msg *TunnelMessage) error {
|
||
// 构建消息头
|
||
header := make([]byte, HeaderSize)
|
||
header[0] = msg.Version
|
||
header[1] = msg.Type
|
||
binary.BigEndian.PutUint32(header[2:6], msg.Length)
|
||
|
||
// 写入消息头
|
||
if _, err := conn.Write(header); err != nil {
|
||
return err
|
||
}
|
||
|
||
// 写入数据
|
||
if msg.Length > 0 && msg.Data != nil {
|
||
if _, err := conn.Write(msg.Data); err != nil {
|
||
return err
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// handleTunnelMessage 处理隧道消息
|
||
func (c *Client) handleTunnelMessage(msg *TunnelMessage) {
|
||
switch msg.Type {
|
||
case MsgTypeConnectRequest:
|
||
c.handleConnectRequest(msg)
|
||
case MsgTypeData:
|
||
c.handleDataMessage(msg)
|
||
case MsgTypeClose:
|
||
c.handleCloseMessage(msg)
|
||
case MsgTypeKeepAlive:
|
||
c.handleKeepAlive(msg)
|
||
default:
|
||
log.Printf("未知消息类型: %d", msg.Type)
|
||
}
|
||
}
|
||
|
||
// handleConnectRequest 处理连接请求
|
||
func (c *Client) handleConnectRequest(msg *TunnelMessage) {
|
||
if len(msg.Data) < 6 {
|
||
log.Printf("连接请求数据太短")
|
||
return
|
||
}
|
||
|
||
connID := binary.BigEndian.Uint32(msg.Data[0:4])
|
||
targetPort := binary.BigEndian.Uint16(msg.Data[4:6])
|
||
targetAddr := fmt.Sprintf("127.0.0.1:%d", targetPort)
|
||
|
||
log.Printf("收到连接请求: ID=%d, 端口=%d", connID, targetPort)
|
||
|
||
// 尝试连接到本地服务
|
||
localConn, err := net.DialTimeout("tcp", targetAddr, ConnectTimeout)
|
||
if err != nil {
|
||
log.Printf("连接本地服务失败 (ID=%d -> %s): %v", connID, targetAddr, err)
|
||
c.sendConnectResponse(connID, ConnStatusFailed)
|
||
return
|
||
}
|
||
|
||
// 创建本地连接对象
|
||
connection := &LocalConnection{
|
||
ID: connID,
|
||
TargetAddr: targetAddr,
|
||
Conn: localConn,
|
||
closeChan: make(chan struct{}),
|
||
}
|
||
|
||
c.connMu.Lock()
|
||
c.connections[connID] = connection
|
||
c.connMu.Unlock()
|
||
|
||
log.Printf("建立本地连接: ID=%d -> %s", connID, targetAddr)
|
||
|
||
// 发送连接成功响应
|
||
c.sendConnectResponse(connID, ConnStatusSuccess)
|
||
|
||
// 启动数据转发
|
||
go c.forwardData(connection)
|
||
}
|
||
|
||
// handleDataMessage 处理数据消息
|
||
func (c *Client) handleDataMessage(msg *TunnelMessage) {
|
||
if len(msg.Data) < 4 {
|
||
log.Printf("数据消息太短")
|
||
return
|
||
}
|
||
|
||
connID := binary.BigEndian.Uint32(msg.Data[0:4])
|
||
data := msg.Data[4:]
|
||
|
||
c.connMu.RLock()
|
||
connection, exists := c.connections[connID]
|
||
c.connMu.RUnlock()
|
||
|
||
if !exists {
|
||
log.Printf("收到未知连接的数据: %d", connID)
|
||
return
|
||
}
|
||
|
||
// 写入到本地连接
|
||
if _, err := connection.Conn.Write(data); err != nil {
|
||
log.Printf("写入本地连接失败 (ID=%d): %v", connID, err)
|
||
c.closeConnection(connID)
|
||
}
|
||
}
|
||
|
||
// handleCloseMessage 处理关闭消息
|
||
func (c *Client) handleCloseMessage(msg *TunnelMessage) {
|
||
if len(msg.Data) < 4 {
|
||
log.Printf("关闭消息数据太短")
|
||
return
|
||
}
|
||
|
||
connID := binary.BigEndian.Uint32(msg.Data[0:4])
|
||
c.closeConnection(connID)
|
||
}
|
||
|
||
// handleKeepAlive 处理心跳消息
|
||
func (c *Client) handleKeepAlive(msg *TunnelMessage) {
|
||
// 回应心跳
|
||
response := &TunnelMessage{
|
||
Version: ProtocolVersion,
|
||
Type: MsgTypeKeepAlive,
|
||
Length: 0,
|
||
Data: nil,
|
||
}
|
||
|
||
select {
|
||
case c.sendChan <- response:
|
||
default:
|
||
log.Printf("发送心跳响应失败: 发送队列已满")
|
||
}
|
||
}
|
||
|
||
// sendConnectResponse 发送连接响应
|
||
func (c *Client) sendConnectResponse(connID uint32, status byte) {
|
||
responseData := make([]byte, 5)
|
||
binary.BigEndian.PutUint32(responseData[0:4], connID)
|
||
responseData[4] = status
|
||
|
||
msg := &TunnelMessage{
|
||
Version: ProtocolVersion,
|
||
Type: MsgTypeConnectResponse,
|
||
Length: 5,
|
||
Data: responseData,
|
||
}
|
||
|
||
select {
|
||
case c.sendChan <- msg:
|
||
default:
|
||
log.Printf("发送连接响应失败: 发送队列已满")
|
||
}
|
||
}
|
||
|
||
// forwardData 转发数据
|
||
func (c *Client) forwardData(connection *LocalConnection) {
|
||
defer c.closeConnection(connection.ID)
|
||
|
||
buffer := make([]byte, 32*1024)
|
||
for {
|
||
select {
|
||
case <-connection.closeChan:
|
||
return
|
||
case <-c.ctx.Done():
|
||
return
|
||
default:
|
||
}
|
||
|
||
connection.Conn.SetReadDeadline(time.Now().Add(ReadTimeout))
|
||
n, err := connection.Conn.Read(buffer)
|
||
if err != nil {
|
||
if err != io.EOF && !isTimeout(err) {
|
||
log.Printf("读取本地连接失败 (ID=%d): %v", connection.ID, err)
|
||
}
|
||
return
|
||
}
|
||
|
||
// 发送数据到隧道
|
||
dataMsg := make([]byte, 4+n)
|
||
binary.BigEndian.PutUint32(dataMsg[0:4], connection.ID)
|
||
copy(dataMsg[4:], buffer[:n])
|
||
|
||
msg := &TunnelMessage{
|
||
Version: ProtocolVersion,
|
||
Type: MsgTypeData,
|
||
Length: uint32(len(dataMsg)),
|
||
Data: dataMsg,
|
||
}
|
||
|
||
select {
|
||
case c.sendChan <- msg:
|
||
case <-time.After(5 * time.Second):
|
||
log.Printf("发送数据超时 (ID=%d)", connection.ID)
|
||
return
|
||
case <-c.ctx.Done():
|
||
return
|
||
}
|
||
}
|
||
}
|
||
|
||
// closeConnection 关闭连接
|
||
func (c *Client) closeConnection(connID uint32) {
|
||
c.connMu.Lock()
|
||
connection, exists := c.connections[connID]
|
||
if exists {
|
||
delete(c.connections, connID)
|
||
connection.closeOnce.Do(func() {
|
||
close(connection.closeChan)
|
||
})
|
||
connection.Conn.Close()
|
||
}
|
||
c.connMu.Unlock()
|
||
|
||
// 发送关闭消息
|
||
closeData := make([]byte, 4)
|
||
binary.BigEndian.PutUint32(closeData, connID)
|
||
|
||
msg := &TunnelMessage{
|
||
Version: ProtocolVersion,
|
||
Type: MsgTypeClose,
|
||
Length: 4,
|
||
Data: closeData,
|
||
}
|
||
|
||
select {
|
||
case c.sendChan <- msg:
|
||
default:
|
||
// 发送队列满,忽略
|
||
}
|
||
|
||
if exists {
|
||
log.Printf("连接已关闭: ID=%d", connID)
|
||
}
|
||
}
|
||
|
||
// Stop 停止隧道客户端
|
||
func (c *Client) Stop() error {
|
||
log.Println("正在停止隧道客户端...")
|
||
c.cancel()
|
||
|
||
c.mu.Lock()
|
||
if c.serverConn != nil {
|
||
c.serverConn.Close()
|
||
}
|
||
c.mu.Unlock()
|
||
|
||
// 等待所有协程结束
|
||
done := make(chan struct{})
|
||
go func() {
|
||
c.wg.Wait()
|
||
close(done)
|
||
}()
|
||
|
||
select {
|
||
case <-done:
|
||
log.Println("隧道客户端已停止")
|
||
case <-time.After(5 * time.Second):
|
||
log.Println("隧道客户端停止超时")
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// isTimeout reports whether err, or any error it wraps, is a net.Error whose
// Timeout method returns true (for example a read that hit its deadline).
// A nil error is not a timeout. Unlike a plain type assertion, this also
// recognises timeouts wrapped with fmt.Errorf("...: %w", err).
func isTimeout(err error) bool {
	var netErr net.Error
	return errors.As(err, &netErr) && netErr.Timeout()
}
|
||
|
||
// keepAliveLoop 心跳循环
|
||
func (c *Client) keepAliveLoop(conn net.Conn) {
|
||
ticker := time.NewTicker(KeepAliveInterval)
|
||
defer ticker.Stop()
|
||
|
||
for {
|
||
select {
|
||
case <-c.ctx.Done():
|
||
return
|
||
case <-ticker.C:
|
||
// 发送心跳消息
|
||
keepAliveMsg := &TunnelMessage{
|
||
Version: ProtocolVersion,
|
||
Type: MsgTypeKeepAlive,
|
||
Length: 0,
|
||
Data: nil,
|
||
}
|
||
|
||
select {
|
||
case c.sendChan <- keepAliveMsg:
|
||
log.Printf("发送心跳消息")
|
||
case <-time.After(5 * time.Second):
|
||
log.Printf("发送心跳消息超时")
|
||
return
|
||
case <-c.ctx.Done():
|
||
return
|
||
}
|
||
}
|
||
}
|
||
} |