diff --git a/src/client/tunnel/client.go b/src/client/tunnel/client.go index f749548..5f2c657 100644 --- a/src/client/tunnel/client.go +++ b/src/client/tunnel/client.go @@ -278,21 +278,30 @@ func (c *Client) handleTunnelMessage(msg *TunnelMessage) { // handleConnectRequest 处理连接请求 func (c *Client) handleConnectRequest(msg *TunnelMessage) { - if len(msg.Data) < 6 { + // 解析: connID(4) + targetPort(2) + targetIPLen(1) + targetIP(变长) + if len(msg.Data) < 7 { log.Printf("连接请求数据太短") return } connID := binary.BigEndian.Uint32(msg.Data[0:4]) targetPort := binary.BigEndian.Uint16(msg.Data[4:6]) - targetAddr := fmt.Sprintf("127.0.0.1:%d", targetPort) + targetIPLen := int(msg.Data[6]) + + if len(msg.Data) < 7+targetIPLen { + log.Printf("连接请求数据不完整") + return + } + + targetIP := string(msg.Data[7 : 7+targetIPLen]) + targetAddr := net.JoinHostPort(targetIP, fmt.Sprintf("%d", targetPort)) - log.Printf("收到连接请求: ID=%d, 端口=%d", connID, targetPort) + log.Printf("收到连接请求: ID=%d, 地址=%s", connID, targetAddr) - // 尝试连接到本地服务 + // 尝试连接到目标服务 localConn, err := net.DialTimeout("tcp", targetAddr, ConnectTimeout) if err != nil { - log.Printf("连接本地服务失败 (ID=%d -> %s): %v", connID, targetAddr, err) + log.Printf("连接目标服务失败 (ID=%d -> %s): %v", connID, targetAddr, err) c.sendConnectResponse(connID, ConnStatusFailed) return } @@ -309,7 +318,7 @@ func (c *Client) handleConnectRequest(msg *TunnelMessage) { c.connections[connID] = connection c.connMu.Unlock() - log.Printf("建立本地连接: ID=%d -> %s", connID, targetAddr) + log.Printf("建立目标连接: ID=%d -> %s", connID, targetAddr) // 发送连接成功响应 c.sendConnectResponse(connID, ConnStatusSuccess) @@ -337,9 +346,8 @@ func (c *Client) handleDataMessage(msg *TunnelMessage) { return } - // 写入到本地连接 - if _, err := connection.Conn.Write(data); err != nil { - log.Printf("写入本地连接失败 (ID=%d): %v", connID, err) + if _, err := connection.Conn.Write(data); err != nil { + log.Printf("写入目标连接失败 (ID=%d): %v", connID, err) c.closeConnection(connID) } } @@ -410,7 +418,7 @@ func (c *Client) forwardData(connection *LocalConnection) { n, err := connection.Conn.Read(buffer) if err != nil { if err != io.EOF && !isTimeout(err) { - log.Printf("读取本地连接失败 (ID=%d): %v", connection.ID, err) + log.Printf("读取目标连接失败 (ID=%d): %v", connection.ID, err) } return } diff --git a/src/server/api/api.go b/src/server/api/api.go index f682ed0..948e77d 100644 --- a/src/server/api/api.go +++ b/src/server/api/api.go @@ -97,8 +97,15 @@ func (h *Handler) handleCreateMapping(w http.ResponseWriter, r *http.Request) { h.writeError(w, http.StatusServiceUnavailable, "隧道未连接") return } - // 隧道模式使用本地地址 - req.TargetIP = "127.0.0.1" + // 隧道模式也需要目标IP(客户端会连接到该IP) + if req.TargetIP == "" { + h.writeError(w, http.StatusBadRequest, "target_ip 不能为空") + return + } + if net.ParseIP(req.TargetIP) == nil { + h.writeError(w, http.StatusBadRequest, "target_ip 格式无效") + return + } } else { // 直接模式需要验证 IP if req.TargetIP == "" { @@ -121,7 +128,7 @@ func (h *Handler) handleCreateMapping(w http.ResponseWriter, r *http.Request) { var err error if req.UseTunnel { // 隧道模式:使用隧道转发 - err = h.forwarderMgr.AddTunnel(req.SourcePort, req.SourcePort, h.tunnelServer) + err = h.forwarderMgr.AddTunnel(req.SourcePort, req.TargetIP, req.TargetPort, h.tunnelServer) } else { // 直接模式:直接TCP转发 err = h.forwarderMgr.Add(req.SourcePort, req.TargetIP, req.TargetPort) diff --git a/src/server/forwarder/forwarder.go b/src/server/forwarder/forwarder.go index 146410d..320e956 100644 --- a/src/server/forwarder/forwarder.go +++ b/src/server/forwarder/forwarder.go @@ -12,7 +12,7 @@ import ( // TunnelServer 隧道服务器接口 type TunnelServer interface { - ForwardConnection(clientConn net.Conn, targetPort int) error + ForwardConnection(clientConn net.Conn, targetIP string, targetPort int) error IsConnected() bool } @@ -20,6 +20,7 @@ type TunnelServer interface { type Forwarder struct { sourcePort int targetPort int + targetIP string targetAddr string listener net.Listener cancel context.CancelFunc @@ -35,6 +36,7 @@ func NewForwarder(sourcePort int, targetIP string, targetPort int) *Forwarder { return &Forwarder{ sourcePort: sourcePort, targetPort: targetPort, + targetIP: targetIP, targetAddr: fmt.Sprintf("%s:%d", targetIP, targetPort), cancel: cancel, ctx: ctx, @@ -43,11 +45,12 @@ func NewForwarder(sourcePort int, targetIP string, targetPort int) *Forwarder { } // NewTunnelForwarder 创建使用隧道的端口转发器 -func NewTunnelForwarder(sourcePort int, targetPort int, tunnelServer TunnelServer) *Forwarder { +func NewTunnelForwarder(sourcePort int, targetIP string, targetPort int, tunnelServer TunnelServer) *Forwarder { ctx, cancel := context.WithCancel(context.Background()) return &Forwarder{ sourcePort: sourcePort, targetPort: targetPort, + targetIP: targetIP, tunnelServer: tunnelServer, useTunnel: true, cancel: cancel, @@ -117,8 +120,8 @@ func (f *Forwarder) handleConnection(clientConn net.Conn) { } // 将连接转发到隧道,ForwardConnection 会处理连接关闭 - if err := f.tunnelServer.ForwardConnection(clientConn, f.targetPort); err != nil { - log.Printf("隧道转发失败 (端口 %d -> %d): %v", f.sourcePort, f.targetPort, err) + if err := f.tunnelServer.ForwardConnection(clientConn, f.targetIP, f.targetPort); err != nil { + log.Printf("隧道转发失败 (端口 %d -> %s:%d): %v", f.sourcePort, f.targetIP, f.targetPort, err) } return } @@ -221,7 +224,7 @@ func (m *Manager) Add(sourcePort int, targetIP string, targetPort int) error { } // AddTunnel 添加使用隧道的转发器 -func (m *Manager) AddTunnel(sourcePort int, targetPort int, tunnelServer TunnelServer) error { +func (m *Manager) AddTunnel(sourcePort int, targetIP string, targetPort int, tunnelServer TunnelServer) error { m.mu.Lock() defer m.mu.Unlock() @@ -229,7 +232,7 @@ func (m *Manager) AddTunnel(sourcePort int, targetPort int, tunnelServer TunnelS return fmt.Errorf("端口 %d 已被占用", sourcePort) } - forwarder := NewTunnelForwarder(sourcePort, targetPort, tunnelServer) + forwarder := NewTunnelForwarder(sourcePort, targetIP, targetPort, tunnelServer) if err := forwarder.Start(); err != nil { return err } diff --git a/src/server/forwarder/forwarder_test.go b/src/server/forwarder/forwarder_test.go index aed72da..812cde7 100644 --- a/src/server/forwarder/forwarder_test.go +++ b/src/server/forwarder/forwarder_test.go @@ -34,7 +34,7 @@ type mockTunnelServer struct { connected bool } -func (m *mockTunnelServer) ForwardConnection(clientConn net.Conn, targetPort int) error { +func (m *mockTunnelServer) ForwardConnection(clientConn net.Conn, targetIp string, targetPort int) error { // 简单的模拟实现 defer clientConn.Close() return nil @@ -49,7 +49,7 @@ func TestNewTunnelForwarder(t *testing.T) { // 创建模拟隧道服务器 mockServer := &mockTunnelServer{connected: true} - fwd := NewTunnelForwarder(8080, 80, mockServer) + fwd := NewTunnelForwarder(8080, "127.0.0.1", 80, mockServer) if fwd == nil { t.Fatal("创建隧道转发器失败") diff --git a/src/server/main.go b/src/server/main.go index b598350..2c4f99a 100644 --- a/src/server/main.go +++ b/src/server/main.go @@ -73,7 +73,7 @@ func main() { log.Printf("警告: 端口 %d 需要隧道模式但隧道服务未启用,跳过", mapping.SourcePort) continue } - err = fwdManager.AddTunnel(mapping.SourcePort, mapping.TargetPort, tunnelServer) + err = fwdManager.AddTunnel(mapping.SourcePort, mapping.TargetIP, mapping.TargetPort, tunnelServer) } else { // 直接模式 err = fwdManager.Add(mapping.SourcePort, mapping.TargetIP, mapping.TargetPort) diff --git a/src/server/tunnel/tunnel.go b/src/server/tunnel/tunnel.go index 7dbd9b4..491c8a5 100644 --- a/src/server/tunnel/tunnel.go +++ b/src/server/tunnel/tunnel.go @@ -53,6 +53,7 @@ type TunnelMessage struct { type ConnectRequestData struct { ConnID uint32 // 连接ID TargetPort uint16 // 目标端口 + TargetIP string // 目标IP地址 } // ConnectResponseData 连接响应数据 @@ -77,6 +78,7 @@ type PendingConnection struct { ID uint32 ClientConn net.Conn TargetPort int + TargetIP string Created time.Time ResponseChan chan bool // 用于接收连接响应 } @@ -86,6 +88,7 @@ type ActiveConnection struct { ID uint32 ClientConn net.Conn TargetPort int + TargetIP string Created time.Time } @@ -347,6 +350,7 @@ func (s *Server) handleConnectResponse(msg *TunnelMessage) { ID: connID, ClientConn: pending.ClientConn, TargetPort: pending.TargetPort, + TargetIP: pending.TargetIP, Created: time.Now(), } @@ -354,7 +358,7 @@ func (s *Server) handleConnectResponse(msg *TunnelMessage) { s.activeConns[connID] = active s.connMu.Unlock() - log.Printf("连接已建立: ID=%d, 端口=%d", connID, pending.TargetPort) + log.Printf("连接已建立: ID=%d, 地址=%s:%d", connID, pending.TargetIP, pending.TargetPort) // 启动数据转发 s.wg.Add(1) @@ -512,7 +516,7 @@ func (s *Server) closeConnection(connID uint32) { } // ForwardConnection 转发连接到隧道(新的透明代理实现) -func (s *Server) ForwardConnection(clientConn net.Conn, targetPort int) error { +func (s *Server) ForwardConnection(clientConn net.Conn, targetIP string, targetPort int) error { s.mu.RLock() tunnelConnected := s.tunnelConn != nil s.mu.RUnlock() @@ -530,6 +534,7 @@ func (s *Server) ForwardConnection(clientConn net.Conn, targetPort int) error { ID: connID, ClientConn: clientConn, TargetPort: targetPort, + TargetIP: targetIP, Created: time.Now(), ResponseChan: make(chan bool, 1), } @@ -537,14 +542,18 @@ func (s *Server) ForwardConnection(clientConn net.Conn, targetPort int) error { s.connMu.Unlock() // 发送连接请求 - reqData := make([]byte, 6) + // 格式: connID(4) + targetPort(2) + targetIPLen(1) + targetIP(变长) + targetIPBytes := []byte(targetIP) + reqData := make([]byte, 7+len(targetIPBytes)) binary.BigEndian.PutUint32(reqData[0:4], connID) binary.BigEndian.PutUint16(reqData[4:6], uint16(targetPort)) + reqData[6] = byte(len(targetIPBytes)) + copy(reqData[7:], targetIPBytes) msg := &TunnelMessage{ Version: ProtocolVersion, Type: MsgTypeConnectRequest, - Length: 6, + Length: uint32(len(reqData)), Data: reqData, } @@ -557,7 +566,7 @@ func (s *Server) ForwardConnection(clientConn net.Conn, targetPort int) error { return fmt.Errorf("发送连接请求超时") } - log.Printf("发送连接请求: ID=%d, 端口=%d", connID, targetPort) + log.Printf("发送连接请求: ID=%d, 地址=%s:%d", connID, targetIP, targetPort) // 等待连接响应 select { diff --git a/src/server/tunnel/tunnel_test.go b/src/server/tunnel/tunnel_test.go index a3f4761..1f03f1a 100644 --- a/src/server/tunnel/tunnel_test.go +++ b/src/server/tunnel/tunnel_test.go @@ -456,7 +456,7 @@ func TestForwardConnection(t *testing.T) { // 启动连接转发 go func() { - err := server.ForwardConnection(serverSideConn, targetPort) + err := server.ForwardConnection(serverSideConn, "127.0.0.1", targetPort) if err != nil { t.Errorf("转发连接失败: %v", err) }