feat: 支持从client访问target_ip

This commit is contained in:
Pan Qiancheng 2025-10-14 16:47:24 +08:00
parent 7b719a33e7
commit c374741c5b
7 changed files with 55 additions and 28 deletions

View File

@ -278,21 +278,30 @@ func (c *Client) handleTunnelMessage(msg *TunnelMessage) {
// handleConnectRequest 处理连接请求
func (c *Client) handleConnectRequest(msg *TunnelMessage) {
if len(msg.Data) < 6 {
// 解析: connID(4) + targetPort(2) + targetIPLen(1) + targetIP(变长)
if len(msg.Data) < 7 {
log.Printf("连接请求数据太短")
return
}
connID := binary.BigEndian.Uint32(msg.Data[0:4])
targetPort := binary.BigEndian.Uint16(msg.Data[4:6])
targetAddr := fmt.Sprintf("127.0.0.1:%d", targetPort)
targetIPLen := int(msg.Data[6])
log.Printf("收到连接请求: ID=%d, 端口=%d", connID, targetPort)
if len(msg.Data) < 7+targetIPLen {
log.Printf("连接请求数据不完整")
return
}
// 尝试连接到本地服务
targetIP := string(msg.Data[7 : 7+targetIPLen])
targetAddr := net.JoinHostPort(targetIP, fmt.Sprintf("%d", targetPort))
log.Printf("收到连接请求: ID=%d, 地址=%s", connID, targetAddr)
// 尝试连接到目标服务
localConn, err := net.DialTimeout("tcp", targetAddr, ConnectTimeout)
if err != nil {
log.Printf("连接本地服务失败 (ID=%d -> %s): %v", connID, targetAddr, err)
log.Printf("连接目标服务失败 (ID=%d -> %s): %v", connID, targetAddr, err)
c.sendConnectResponse(connID, ConnStatusFailed)
return
}
@ -309,7 +318,7 @@ func (c *Client) handleConnectRequest(msg *TunnelMessage) {
c.connections[connID] = connection
c.connMu.Unlock()
log.Printf("建立本地连接: ID=%d -> %s", connID, targetAddr)
log.Printf("建立目标连接: ID=%d -> %s", connID, targetAddr)
// 发送连接成功响应
c.sendConnectResponse(connID, ConnStatusSuccess)
@ -337,9 +346,8 @@ func (c *Client) handleDataMessage(msg *TunnelMessage) {
return
}
// 写入到本地连接
if _, err := connection.Conn.Write(data); err != nil {
log.Printf("写入本地连接失败 (ID=%d): %v", connID, err)
log.Printf("写入目标连接失败 (ID=%d): %v", connID, err)
c.closeConnection(connID)
}
}
@ -410,7 +418,7 @@ func (c *Client) forwardData(connection *LocalConnection) {
n, err := connection.Conn.Read(buffer)
if err != nil {
if err != io.EOF && !isTimeout(err) {
log.Printf("读取本地连接失败 (ID=%d): %v", connection.ID, err)
log.Printf("读取目标连接失败 (ID=%d): %v", connection.ID, err)
}
return
}

View File

@ -97,8 +97,15 @@ func (h *Handler) handleCreateMapping(w http.ResponseWriter, r *http.Request) {
h.writeError(w, http.StatusServiceUnavailable, "隧道未连接")
return
}
// 隧道模式使用本地地址
req.TargetIP = "127.0.0.1"
// 隧道模式也需要目标IP客户端会连接到该IP
if req.TargetIP == "" {
h.writeError(w, http.StatusBadRequest, "target_ip 不能为空")
return
}
if net.ParseIP(req.TargetIP) == nil {
h.writeError(w, http.StatusBadRequest, "target_ip 格式无效")
return
}
} else {
// 直接模式需要验证 IP
if req.TargetIP == "" {
@ -121,7 +128,7 @@ func (h *Handler) handleCreateMapping(w http.ResponseWriter, r *http.Request) {
var err error
if req.UseTunnel {
// 隧道模式:使用隧道转发
err = h.forwarderMgr.AddTunnel(req.SourcePort, req.SourcePort, h.tunnelServer)
err = h.forwarderMgr.AddTunnel(req.SourcePort, req.TargetIP, req.TargetPort, h.tunnelServer)
} else {
// 直接模式直接TCP转发
err = h.forwarderMgr.Add(req.SourcePort, req.TargetIP, req.TargetPort)

View File

@ -12,7 +12,7 @@ import (
// TunnelServer 隧道服务器接口
type TunnelServer interface {
ForwardConnection(clientConn net.Conn, targetPort int) error
ForwardConnection(clientConn net.Conn, targetIP string, targetPort int) error
IsConnected() bool
}
@ -20,6 +20,7 @@ type TunnelServer interface {
type Forwarder struct {
sourcePort int
targetPort int
targetIP string
targetAddr string
listener net.Listener
cancel context.CancelFunc
@ -35,6 +36,7 @@ func NewForwarder(sourcePort int, targetIP string, targetPort int) *Forwarder {
return &Forwarder{
sourcePort: sourcePort,
targetPort: targetPort,
targetIP: targetIP,
targetAddr: fmt.Sprintf("%s:%d", targetIP, targetPort),
cancel: cancel,
ctx: ctx,
@ -43,11 +45,12 @@ func NewForwarder(sourcePort int, targetIP string, targetPort int) *Forwarder {
}
// NewTunnelForwarder 创建使用隧道的端口转发器
func NewTunnelForwarder(sourcePort int, targetPort int, tunnelServer TunnelServer) *Forwarder {
func NewTunnelForwarder(sourcePort int, targetIP string, targetPort int, tunnelServer TunnelServer) *Forwarder {
ctx, cancel := context.WithCancel(context.Background())
return &Forwarder{
sourcePort: sourcePort,
targetPort: targetPort,
targetIP: targetIP,
tunnelServer: tunnelServer,
useTunnel: true,
cancel: cancel,
@ -117,8 +120,8 @@ func (f *Forwarder) handleConnection(clientConn net.Conn) {
}
// 将连接转发到隧道ForwardConnection 会处理连接关闭
if err := f.tunnelServer.ForwardConnection(clientConn, f.targetPort); err != nil {
log.Printf("隧道转发失败 (端口 %d -> %d): %v", f.sourcePort, f.targetPort, err)
if err := f.tunnelServer.ForwardConnection(clientConn, f.targetIP, f.targetPort); err != nil {
log.Printf("隧道转发失败 (端口 %d -> %s:%d): %v", f.sourcePort, f.targetIP, f.targetPort, err)
}
return
}
@ -221,7 +224,7 @@ func (m *Manager) Add(sourcePort int, targetIP string, targetPort int) error {
}
// AddTunnel 添加使用隧道的转发器
func (m *Manager) AddTunnel(sourcePort int, targetPort int, tunnelServer TunnelServer) error {
func (m *Manager) AddTunnel(sourcePort int, targetIP string, targetPort int, tunnelServer TunnelServer) error {
m.mu.Lock()
defer m.mu.Unlock()
@ -229,7 +232,7 @@ func (m *Manager) AddTunnel(sourcePort int, targetPort int, tunnelServer TunnelS
return fmt.Errorf("端口 %d 已被占用", sourcePort)
}
forwarder := NewTunnelForwarder(sourcePort, targetPort, tunnelServer)
forwarder := NewTunnelForwarder(sourcePort, targetIP, targetPort, tunnelServer)
if err := forwarder.Start(); err != nil {
return err
}

View File

@ -34,7 +34,7 @@ type mockTunnelServer struct {
connected bool
}
func (m *mockTunnelServer) ForwardConnection(clientConn net.Conn, targetPort int) error {
func (m *mockTunnelServer) ForwardConnection(clientConn net.Conn, targetIp string, targetPort int) error {
// 简单的模拟实现
defer clientConn.Close()
return nil
@ -49,7 +49,7 @@ func TestNewTunnelForwarder(t *testing.T) {
// 创建模拟隧道服务器
mockServer := &mockTunnelServer{connected: true}
fwd := NewTunnelForwarder(8080, 80, mockServer)
fwd := NewTunnelForwarder(8080, "127.0.0.1", 80, mockServer)
if fwd == nil {
t.Fatal("创建隧道转发器失败")

View File

@ -73,7 +73,7 @@ func main() {
log.Printf("警告: 端口 %d 需要隧道模式但隧道服务未启用,跳过", mapping.SourcePort)
continue
}
err = fwdManager.AddTunnel(mapping.SourcePort, mapping.TargetPort, tunnelServer)
err = fwdManager.AddTunnel(mapping.SourcePort, mapping.TargetIP, mapping.TargetPort, tunnelServer)
} else {
// 直接模式
err = fwdManager.Add(mapping.SourcePort, mapping.TargetIP, mapping.TargetPort)

View File

@ -53,6 +53,7 @@ type TunnelMessage struct {
type ConnectRequestData struct {
ConnID uint32 // 连接ID
TargetPort uint16 // 目标端口
TargetIP string // 目标IP地址
}
// ConnectResponseData 连接响应数据
@ -77,6 +78,7 @@ type PendingConnection struct {
ID uint32
ClientConn net.Conn
TargetPort int
TargetIP string
Created time.Time
ResponseChan chan bool // 用于接收连接响应
}
@ -86,6 +88,7 @@ type ActiveConnection struct {
ID uint32
ClientConn net.Conn
TargetPort int
TargetIP string
Created time.Time
}
@ -347,6 +350,7 @@ func (s *Server) handleConnectResponse(msg *TunnelMessage) {
ID: connID,
ClientConn: pending.ClientConn,
TargetPort: pending.TargetPort,
TargetIP: pending.TargetIP,
Created: time.Now(),
}
@ -354,7 +358,7 @@ func (s *Server) handleConnectResponse(msg *TunnelMessage) {
s.activeConns[connID] = active
s.connMu.Unlock()
log.Printf("连接已建立: ID=%d, 端口=%d", connID, pending.TargetPort)
log.Printf("连接已建立: ID=%d, 地址=%s:%d", connID, pending.TargetIP, pending.TargetPort)
// 启动数据转发
s.wg.Add(1)
@ -512,7 +516,7 @@ func (s *Server) closeConnection(connID uint32) {
}
// ForwardConnection 转发连接到隧道(新的透明代理实现)
func (s *Server) ForwardConnection(clientConn net.Conn, targetPort int) error {
func (s *Server) ForwardConnection(clientConn net.Conn, targetIP string, targetPort int) error {
s.mu.RLock()
tunnelConnected := s.tunnelConn != nil
s.mu.RUnlock()
@ -530,6 +534,7 @@ func (s *Server) ForwardConnection(clientConn net.Conn, targetPort int) error {
ID: connID,
ClientConn: clientConn,
TargetPort: targetPort,
TargetIP: targetIP,
Created: time.Now(),
ResponseChan: make(chan bool, 1),
}
@ -537,14 +542,18 @@ func (s *Server) ForwardConnection(clientConn net.Conn, targetPort int) error {
s.connMu.Unlock()
// 发送连接请求
reqData := make([]byte, 6)
// 格式: connID(4) + targetPort(2) + targetIPLen(1) + targetIP(变长)
targetIPBytes := []byte(targetIP)
reqData := make([]byte, 7+len(targetIPBytes))
binary.BigEndian.PutUint32(reqData[0:4], connID)
binary.BigEndian.PutUint16(reqData[4:6], uint16(targetPort))
reqData[6] = byte(len(targetIPBytes))
copy(reqData[7:], targetIPBytes)
msg := &TunnelMessage{
Version: ProtocolVersion,
Type: MsgTypeConnectRequest,
Length: 6,
Length: uint32(len(reqData)),
Data: reqData,
}
@ -557,7 +566,7 @@ func (s *Server) ForwardConnection(clientConn net.Conn, targetPort int) error {
return fmt.Errorf("发送连接请求超时")
}
log.Printf("发送连接请求: ID=%d, 端口=%d", connID, targetPort)
log.Printf("发送连接请求: ID=%d, 地址=%s:%d", connID, targetIP, targetPort)
// 等待连接响应
select {

View File

@ -456,7 +456,7 @@ func TestForwardConnection(t *testing.T) {
// 启动连接转发
go func() {
err := server.ForwardConnection(serverSideConn, targetPort)
err := server.ForwardConnection(serverSideConn, "127.0.0.1", targetPort)
if err != nil {
t.Errorf("转发连接失败: %v", err)
}