feat: 支持从client访问target_ip
This commit is contained in:
parent
7b719a33e7
commit
c374741c5b
|
|
@ -278,21 +278,30 @@ func (c *Client) handleTunnelMessage(msg *TunnelMessage) {
|
|||
|
||||
// handleConnectRequest 处理连接请求
|
||||
func (c *Client) handleConnectRequest(msg *TunnelMessage) {
|
||||
if len(msg.Data) < 6 {
|
||||
// 解析: connID(4) + targetPort(2) + targetIPLen(1) + targetIP(变长)
|
||||
if len(msg.Data) < 7 {
|
||||
log.Printf("连接请求数据太短")
|
||||
return
|
||||
}
|
||||
|
||||
connID := binary.BigEndian.Uint32(msg.Data[0:4])
|
||||
targetPort := binary.BigEndian.Uint16(msg.Data[4:6])
|
||||
targetAddr := fmt.Sprintf("127.0.0.1:%d", targetPort)
|
||||
targetIPLen := int(msg.Data[6])
|
||||
|
||||
if len(msg.Data) < 7+targetIPLen {
|
||||
log.Printf("连接请求数据不完整")
|
||||
return
|
||||
}
|
||||
|
||||
targetIP := string(msg.Data[7 : 7+targetIPLen])
|
||||
targetAddr := net.JoinHostPort(targetIP, fmt.Sprintf("%d", targetPort))
|
||||
|
||||
log.Printf("收到连接请求: ID=%d, 端口=%d", connID, targetPort)
|
||||
log.Printf("收到连接请求: ID=%d, 地址=%s", connID, targetAddr)
|
||||
|
||||
// 尝试连接到本地服务
|
||||
// 尝试连接到目标服务
|
||||
localConn, err := net.DialTimeout("tcp", targetAddr, ConnectTimeout)
|
||||
if err != nil {
|
||||
log.Printf("连接本地服务失败 (ID=%d -> %s): %v", connID, targetAddr, err)
|
||||
log.Printf("连接目标服务失败 (ID=%d -> %s): %v", connID, targetAddr, err)
|
||||
c.sendConnectResponse(connID, ConnStatusFailed)
|
||||
return
|
||||
}
|
||||
|
|
@ -309,7 +318,7 @@ func (c *Client) handleConnectRequest(msg *TunnelMessage) {
|
|||
c.connections[connID] = connection
|
||||
c.connMu.Unlock()
|
||||
|
||||
log.Printf("建立本地连接: ID=%d -> %s", connID, targetAddr)
|
||||
log.Printf("建立目标连接: ID=%d -> %s", connID, targetAddr)
|
||||
|
||||
// 发送连接成功响应
|
||||
c.sendConnectResponse(connID, ConnStatusSuccess)
|
||||
|
|
@ -337,9 +346,8 @@ func (c *Client) handleDataMessage(msg *TunnelMessage) {
|
|||
return
|
||||
}
|
||||
|
||||
// 写入到本地连接
|
||||
if _, err := connection.Conn.Write(data); err != nil {
|
||||
log.Printf("写入本地连接失败 (ID=%d): %v", connID, err)
|
||||
if _, err := connection.Conn.Write(data); err != nil {
|
||||
log.Printf("写入目标连接失败 (ID=%d): %v", connID, err)
|
||||
c.closeConnection(connID)
|
||||
}
|
||||
}
|
||||
|
|
@ -410,7 +418,7 @@ func (c *Client) forwardData(connection *LocalConnection) {
|
|||
n, err := connection.Conn.Read(buffer)
|
||||
if err != nil {
|
||||
if err != io.EOF && !isTimeout(err) {
|
||||
log.Printf("读取本地连接失败 (ID=%d): %v", connection.ID, err)
|
||||
log.Printf("读取目标连接失败 (ID=%d): %v", connection.ID, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
|
|||
|
|
@ -97,8 +97,15 @@ func (h *Handler) handleCreateMapping(w http.ResponseWriter, r *http.Request) {
|
|||
h.writeError(w, http.StatusServiceUnavailable, "隧道未连接")
|
||||
return
|
||||
}
|
||||
// 隧道模式使用本地地址
|
||||
req.TargetIP = "127.0.0.1"
|
||||
// 隧道模式也需要目标IP(客户端会连接到该IP)
|
||||
if req.TargetIP == "" {
|
||||
h.writeError(w, http.StatusBadRequest, "target_ip 不能为空")
|
||||
return
|
||||
}
|
||||
if net.ParseIP(req.TargetIP) == nil {
|
||||
h.writeError(w, http.StatusBadRequest, "target_ip 格式无效")
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// 直接模式需要验证 IP
|
||||
if req.TargetIP == "" {
|
||||
|
|
@ -121,7 +128,7 @@ func (h *Handler) handleCreateMapping(w http.ResponseWriter, r *http.Request) {
|
|||
var err error
|
||||
if req.UseTunnel {
|
||||
// 隧道模式:使用隧道转发
|
||||
err = h.forwarderMgr.AddTunnel(req.SourcePort, req.SourcePort, h.tunnelServer)
|
||||
err = h.forwarderMgr.AddTunnel(req.SourcePort, req.TargetIP, req.TargetPort, h.tunnelServer)
|
||||
} else {
|
||||
// 直接模式:直接TCP转发
|
||||
err = h.forwarderMgr.Add(req.SourcePort, req.TargetIP, req.TargetPort)
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ import (
|
|||
|
||||
// TunnelServer 隧道服务器接口
|
||||
type TunnelServer interface {
|
||||
ForwardConnection(clientConn net.Conn, targetPort int) error
|
||||
ForwardConnection(clientConn net.Conn, targetIP string, targetPort int) error
|
||||
IsConnected() bool
|
||||
}
|
||||
|
||||
|
|
@ -20,6 +20,7 @@ type TunnelServer interface {
|
|||
type Forwarder struct {
|
||||
sourcePort int
|
||||
targetPort int
|
||||
targetIP string
|
||||
targetAddr string
|
||||
listener net.Listener
|
||||
cancel context.CancelFunc
|
||||
|
|
@ -35,6 +36,7 @@ func NewForwarder(sourcePort int, targetIP string, targetPort int) *Forwarder {
|
|||
return &Forwarder{
|
||||
sourcePort: sourcePort,
|
||||
targetPort: targetPort,
|
||||
targetIP: targetIP,
|
||||
targetAddr: fmt.Sprintf("%s:%d", targetIP, targetPort),
|
||||
cancel: cancel,
|
||||
ctx: ctx,
|
||||
|
|
@ -43,11 +45,12 @@ func NewForwarder(sourcePort int, targetIP string, targetPort int) *Forwarder {
|
|||
}
|
||||
|
||||
// NewTunnelForwarder 创建使用隧道的端口转发器
|
||||
func NewTunnelForwarder(sourcePort int, targetPort int, tunnelServer TunnelServer) *Forwarder {
|
||||
func NewTunnelForwarder(sourcePort int, targetIP string, targetPort int, tunnelServer TunnelServer) *Forwarder {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &Forwarder{
|
||||
sourcePort: sourcePort,
|
||||
targetPort: targetPort,
|
||||
targetIP: targetIP,
|
||||
tunnelServer: tunnelServer,
|
||||
useTunnel: true,
|
||||
cancel: cancel,
|
||||
|
|
@ -117,8 +120,8 @@ func (f *Forwarder) handleConnection(clientConn net.Conn) {
|
|||
}
|
||||
|
||||
// 将连接转发到隧道,ForwardConnection 会处理连接关闭
|
||||
if err := f.tunnelServer.ForwardConnection(clientConn, f.targetPort); err != nil {
|
||||
log.Printf("隧道转发失败 (端口 %d -> %d): %v", f.sourcePort, f.targetPort, err)
|
||||
if err := f.tunnelServer.ForwardConnection(clientConn, f.targetIP, f.targetPort); err != nil {
|
||||
log.Printf("隧道转发失败 (端口 %d -> %s:%d): %v", f.sourcePort, f.targetIP, f.targetPort, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
@ -221,7 +224,7 @@ func (m *Manager) Add(sourcePort int, targetIP string, targetPort int) error {
|
|||
}
|
||||
|
||||
// AddTunnel 添加使用隧道的转发器
|
||||
func (m *Manager) AddTunnel(sourcePort int, targetPort int, tunnelServer TunnelServer) error {
|
||||
func (m *Manager) AddTunnel(sourcePort int, targetIP string, targetPort int, tunnelServer TunnelServer) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
|
|
@ -229,7 +232,7 @@ func (m *Manager) AddTunnel(sourcePort int, targetPort int, tunnelServer TunnelS
|
|||
return fmt.Errorf("端口 %d 已被占用", sourcePort)
|
||||
}
|
||||
|
||||
forwarder := NewTunnelForwarder(sourcePort, targetPort, tunnelServer)
|
||||
forwarder := NewTunnelForwarder(sourcePort, targetIP, targetPort, tunnelServer)
|
||||
if err := forwarder.Start(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ type mockTunnelServer struct {
|
|||
connected bool
|
||||
}
|
||||
|
||||
func (m *mockTunnelServer) ForwardConnection(clientConn net.Conn, targetPort int) error {
|
||||
func (m *mockTunnelServer) ForwardConnection(clientConn net.Conn, targetIp string, targetPort int) error {
|
||||
// 简单的模拟实现
|
||||
defer clientConn.Close()
|
||||
return nil
|
||||
|
|
@ -49,7 +49,7 @@ func TestNewTunnelForwarder(t *testing.T) {
|
|||
// 创建模拟隧道服务器
|
||||
mockServer := &mockTunnelServer{connected: true}
|
||||
|
||||
fwd := NewTunnelForwarder(8080, 80, mockServer)
|
||||
fwd := NewTunnelForwarder(8080, "127.0.0.1", 80, mockServer)
|
||||
|
||||
if fwd == nil {
|
||||
t.Fatal("创建隧道转发器失败")
|
||||
|
|
|
|||
|
|
@ -73,7 +73,7 @@ func main() {
|
|||
log.Printf("警告: 端口 %d 需要隧道模式但隧道服务未启用,跳过", mapping.SourcePort)
|
||||
continue
|
||||
}
|
||||
err = fwdManager.AddTunnel(mapping.SourcePort, mapping.TargetPort, tunnelServer)
|
||||
err = fwdManager.AddTunnel(mapping.SourcePort, mapping.TargetIP, mapping.TargetPort, tunnelServer)
|
||||
} else {
|
||||
// 直接模式
|
||||
err = fwdManager.Add(mapping.SourcePort, mapping.TargetIP, mapping.TargetPort)
|
||||
|
|
|
|||
|
|
@ -53,6 +53,7 @@ type TunnelMessage struct {
|
|||
type ConnectRequestData struct {
|
||||
ConnID uint32 // 连接ID
|
||||
TargetPort uint16 // 目标端口
|
||||
TargetIP string // 目标IP地址
|
||||
}
|
||||
|
||||
// ConnectResponseData 连接响应数据
|
||||
|
|
@ -77,6 +78,7 @@ type PendingConnection struct {
|
|||
ID uint32
|
||||
ClientConn net.Conn
|
||||
TargetPort int
|
||||
TargetIP string
|
||||
Created time.Time
|
||||
ResponseChan chan bool // 用于接收连接响应
|
||||
}
|
||||
|
|
@ -86,6 +88,7 @@ type ActiveConnection struct {
|
|||
ID uint32
|
||||
ClientConn net.Conn
|
||||
TargetPort int
|
||||
TargetIP string
|
||||
Created time.Time
|
||||
}
|
||||
|
||||
|
|
@ -347,6 +350,7 @@ func (s *Server) handleConnectResponse(msg *TunnelMessage) {
|
|||
ID: connID,
|
||||
ClientConn: pending.ClientConn,
|
||||
TargetPort: pending.TargetPort,
|
||||
TargetIP: pending.TargetIP,
|
||||
Created: time.Now(),
|
||||
}
|
||||
|
||||
|
|
@ -354,7 +358,7 @@ func (s *Server) handleConnectResponse(msg *TunnelMessage) {
|
|||
s.activeConns[connID] = active
|
||||
s.connMu.Unlock()
|
||||
|
||||
log.Printf("连接已建立: ID=%d, 端口=%d", connID, pending.TargetPort)
|
||||
log.Printf("连接已建立: ID=%d, 地址=%s:%d", connID, pending.TargetIP, pending.TargetPort)
|
||||
|
||||
// 启动数据转发
|
||||
s.wg.Add(1)
|
||||
|
|
@ -512,7 +516,7 @@ func (s *Server) closeConnection(connID uint32) {
|
|||
}
|
||||
|
||||
// ForwardConnection 转发连接到隧道(新的透明代理实现)
|
||||
func (s *Server) ForwardConnection(clientConn net.Conn, targetPort int) error {
|
||||
func (s *Server) ForwardConnection(clientConn net.Conn, targetIP string, targetPort int) error {
|
||||
s.mu.RLock()
|
||||
tunnelConnected := s.tunnelConn != nil
|
||||
s.mu.RUnlock()
|
||||
|
|
@ -530,6 +534,7 @@ func (s *Server) ForwardConnection(clientConn net.Conn, targetPort int) error {
|
|||
ID: connID,
|
||||
ClientConn: clientConn,
|
||||
TargetPort: targetPort,
|
||||
TargetIP: targetIP,
|
||||
Created: time.Now(),
|
||||
ResponseChan: make(chan bool, 1),
|
||||
}
|
||||
|
|
@ -537,14 +542,18 @@ func (s *Server) ForwardConnection(clientConn net.Conn, targetPort int) error {
|
|||
s.connMu.Unlock()
|
||||
|
||||
// 发送连接请求
|
||||
reqData := make([]byte, 6)
|
||||
// 格式: connID(4) + targetPort(2) + targetIPLen(1) + targetIP(变长)
|
||||
targetIPBytes := []byte(targetIP)
|
||||
reqData := make([]byte, 7+len(targetIPBytes))
|
||||
binary.BigEndian.PutUint32(reqData[0:4], connID)
|
||||
binary.BigEndian.PutUint16(reqData[4:6], uint16(targetPort))
|
||||
reqData[6] = byte(len(targetIPBytes))
|
||||
copy(reqData[7:], targetIPBytes)
|
||||
|
||||
msg := &TunnelMessage{
|
||||
Version: ProtocolVersion,
|
||||
Type: MsgTypeConnectRequest,
|
||||
Length: 6,
|
||||
Length: uint32(len(reqData)),
|
||||
Data: reqData,
|
||||
}
|
||||
|
||||
|
|
@ -557,7 +566,7 @@ func (s *Server) ForwardConnection(clientConn net.Conn, targetPort int) error {
|
|||
return fmt.Errorf("发送连接请求超时")
|
||||
}
|
||||
|
||||
log.Printf("发送连接请求: ID=%d, 端口=%d", connID, targetPort)
|
||||
log.Printf("发送连接请求: ID=%d, 地址=%s:%d", connID, targetIP, targetPort)
|
||||
|
||||
// 等待连接响应
|
||||
select {
|
||||
|
|
|
|||
|
|
@ -456,7 +456,7 @@ func TestForwardConnection(t *testing.T) {
|
|||
|
||||
// 启动连接转发
|
||||
go func() {
|
||||
err := server.ForwardConnection(serverSideConn, targetPort)
|
||||
err := server.ForwardConnection(serverSideConn, "127.0.0.1", targetPort)
|
||||
if err != nil {
|
||||
t.Errorf("转发连接失败: %v", err)
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue