feat: 端口转发已实现

This commit is contained in:
Pan Qiancheng 2025-10-14 16:05:19 +08:00
parent 6b02def8de
commit 8c2f803b77
18 changed files with 5450 additions and 469 deletions

1505
README.md

File diff suppressed because it is too large Load Diff

View File

@ -15,24 +15,24 @@ build: build-server build-client
# 编译服务器
build-server:
@echo "编译服务器..."
go build -o bin/server ./server
go build -o ../bin/server ./server
@echo "服务器编译完成: bin/server"
# 编译客户端
build-client:
@echo "编译客户端..."
go build -o bin/client ./client
go build -o ../bin/client ./client
@echo "客户端编译完成: bin/client"
# 运行服务器
run-server: build-server
@echo "启动服务器..."
./bin/server -config config.yaml
../bin/server -config config.yaml
# 运行客户端
run-client: build-client
@echo "启动客户端..."
./bin/client -server localhost:9000
../bin/client -server localhost:9000
# 清理编译文件
clean:
@ -56,8 +56,8 @@ init:
# 交叉编译 Linux
build-linux:
@echo "交叉编译 Linux 版本..."
GOOS=linux GOARCH=amd64 go build -o bin/server-linux ./server
GOOS=linux GOARCH=amd64 go build -o bin/client-linux ./client
GOOS=linux GOARCH=amd64 go build -o ../bin/server-linux ./server
GOOS=linux GOARCH=amd64 go build -o ../bin/client-linux ./client
@echo "Linux 版本编译完成"
# 格式化代码

View File

@ -12,14 +12,51 @@ import (
)
const (
// HeaderSize 消息头大小
HeaderSize = 8
// MaxPacketSize 最大包大小
// 协议版本
ProtocolVersion = 0x01
// 消息头大小
HeaderSize = 6 // 版本(1) + 类型(1) + 长度(4)
// 最大包大小
MaxPacketSize = 1024 * 1024
// ReconnectDelay 重连延迟
// 重连延迟
ReconnectDelay = 5 * time.Second
// 消息类型
MsgTypeConnectRequest = 0x01 // 连接请求
MsgTypeConnectResponse = 0x02 // 连接响应
MsgTypeData = 0x03 // 数据传输
MsgTypeClose = 0x04 // 关闭连接
MsgTypeKeepAlive = 0x05 // 心跳
// 连接响应状态
ConnStatusSuccess = 0x00 // 连接成功
ConnStatusFailed = 0x01 // 连接失败
// 超时设置
ConnectTimeout = 10 * time.Second // 连接超时
ReadTimeout = 30 * time.Second // 读取超时
)
// TunnelMessage 隧道消息
type TunnelMessage struct {
Version byte
Type byte
Length uint32
Data []byte
}
// LocalConnection 本地连接
type LocalConnection struct {
ID uint32
TargetAddr string
Conn net.Conn
closeChan chan struct{}
closeOnce sync.Once
}
// Client 内网穿透客户端
type Client struct {
serverAddr string
@ -32,14 +69,9 @@ type Client struct {
// 连接管理
connections map[uint32]*LocalConnection
connMu sync.RWMutex
}
// LocalConnection 本地连接
type LocalConnection struct {
ID uint32
TargetAddr string
Conn net.Conn
closeChan chan struct{}
// 消息队列
sendChan chan *TunnelMessage
}
// NewClient 创建新的隧道客户端
@ -50,6 +82,7 @@ func NewClient(serverAddr string) *Client {
cancel: cancel,
ctx: ctx,
connections: make(map[uint32]*LocalConnection),
sendChan: make(chan *TunnelMessage, 1000),
}
}
@ -88,11 +121,19 @@ func (c *Client) connectLoop() {
c.mu.Unlock()
// 处理连接
if err := c.handleServerConnection(conn); err != nil {
if err != io.EOF {
log.Printf("处理服务器连接出错: %v", err)
}
}
var connWg sync.WaitGroup
connWg.Add(2)
go func() {
defer connWg.Done()
c.handleServerRead(conn)
}()
go func() {
defer connWg.Done()
c.handleServerWrite(conn)
}()
// 等待连接断开
connWg.Wait()
c.mu.Lock()
c.serverConn = nil
@ -101,7 +142,9 @@ func (c *Client) connectLoop() {
// 关闭所有本地连接
c.connMu.Lock()
for _, conn := range c.connections {
close(conn.closeChan)
conn.closeOnce.Do(func() {
close(conn.closeChan)
})
if conn.Conn != nil {
conn.Conn.Close()
}
@ -114,140 +157,317 @@ func (c *Client) connectLoop() {
}
}
// handleServerConnection 处理服务器连接
func (c *Client) handleServerConnection(conn net.Conn) error {
for {
select {
case <-c.ctx.Done():
return nil
default:
}
// 读取消息头
header := make([]byte, HeaderSize)
if _, err := io.ReadFull(conn, header); err != nil {
return err
}
dataLen := binary.BigEndian.Uint32(header[0:4])
connID := binary.BigEndian.Uint32(header[4:8])
if dataLen > MaxPacketSize {
return fmt.Errorf("数据包过大: %d bytes", dataLen)
}
// 读取数据
data := make([]byte, dataLen)
if _, err := io.ReadFull(conn, data); err != nil {
return err
}
// 处理数据
c.handleData(connID, data)
}
}
// handleData 处理接收到的数据
func (c *Client) handleData(connID uint32, data []byte) {
c.connMu.Lock()
localConn, exists := c.connections[connID]
if !exists {
// 新连接,需要建立到本地服务的连接
// 从数据中解析目标端口(这里简化处理,实际应该从协议中获取)
localConn = &LocalConnection{
ID: connID,
closeChan: make(chan struct{}),
}
c.connections[connID] = localConn
c.connMu.Unlock()
// 启动本地连接处理
c.wg.Add(1)
go c.handleLocalConnection(localConn)
// 重新获取锁并发送数据
c.connMu.Lock()
}
c.connMu.Unlock()
// 发送数据到本地连接
if localConn.Conn != nil {
localConn.Conn.Write(data)
}
}
// handleLocalConnection 处理本地连接
func (c *Client) handleLocalConnection(localConn *LocalConnection) {
defer c.wg.Done()
defer func() {
c.connMu.Lock()
delete(c.connections, localConn.ID)
c.connMu.Unlock()
if localConn.Conn != nil {
localConn.Conn.Close()
}
}()
// 连接到本地目标服务
// 这里使用固定的本地地址,实际应该根据映射配置
targetAddr := localConn.TargetAddr
if targetAddr == "" {
targetAddr = "127.0.0.1:22" // 默认 SSH
}
conn, err := net.DialTimeout("tcp", targetAddr, 5*time.Second)
if err != nil {
log.Printf("连接本地服务失败 (连接 %d -> %s): %v", localConn.ID, targetAddr, err)
return
}
// handleServerRead 处理服务器读取
func (c *Client) handleServerRead(conn net.Conn) {
defer conn.Close()
localConn.Conn = conn
log.Printf("建立本地连接: %d -> %s", localConn.ID, targetAddr)
// 从本地服务读取数据并发送到服务器
buffer := make([]byte, 32*1024)
for {
select {
case <-localConn.closeChan:
return
case <-c.ctx.Done():
return
default:
}
conn.SetReadDeadline(time.Now().Add(30 * time.Second))
n, err := conn.Read(buffer)
msg, err := c.readTunnelMessage(conn)
if err != nil {
if err != io.EOF && !isTimeout(err) {
log.Printf("读取本地连接失败 (连接 %d): %v", localConn.ID, err)
if err != io.EOF {
log.Printf("读取隧道消息失败: %v", err)
}
return
}
// 发送到服务器
c.mu.RLock()
serverConn := c.serverConn
c.mu.RUnlock()
c.handleTunnelMessage(msg)
}
}
if serverConn == nil {
// handleServerWrite 处理服务器写入
func (c *Client) handleServerWrite(conn net.Conn) {
for {
select {
case <-c.ctx.Done():
return
case msg := <-c.sendChan:
if err := c.writeTunnelMessage(conn, msg); err != nil {
log.Printf("写入隧道消息失败: %v", err)
return
}
}
}
}
// readTunnelMessage 读取隧道消息
func (c *Client) readTunnelMessage(conn net.Conn) (*TunnelMessage, error) {
// 读取消息头
header := make([]byte, HeaderSize)
if _, err := io.ReadFull(conn, header); err != nil {
return nil, err
}
version := header[0]
msgType := header[1]
dataLen := binary.BigEndian.Uint32(header[2:6])
if version != ProtocolVersion {
return nil, fmt.Errorf("不支持的协议版本: %d", version)
}
if dataLen > MaxPacketSize {
return nil, fmt.Errorf("数据包过大: %d bytes", dataLen)
}
// 读取数据
var data []byte
if dataLen > 0 {
data = make([]byte, dataLen)
if _, err := io.ReadFull(conn, data); err != nil {
return nil, err
}
}
return &TunnelMessage{
Version: version,
Type: msgType,
Length: dataLen,
Data: data,
}, nil
}
// writeTunnelMessage 写入隧道消息
func (c *Client) writeTunnelMessage(conn net.Conn, msg *TunnelMessage) error {
// 构建消息头
header := make([]byte, HeaderSize)
header[0] = msg.Version
header[1] = msg.Type
binary.BigEndian.PutUint32(header[2:6], msg.Length)
// 写入消息头
if _, err := conn.Write(header); err != nil {
return err
}
// 写入数据
if msg.Length > 0 && msg.Data != nil {
if _, err := conn.Write(msg.Data); err != nil {
return err
}
}
return nil
}
// handleTunnelMessage 处理隧道消息
func (c *Client) handleTunnelMessage(msg *TunnelMessage) {
switch msg.Type {
case MsgTypeConnectRequest:
c.handleConnectRequest(msg)
case MsgTypeData:
c.handleDataMessage(msg)
case MsgTypeClose:
c.handleCloseMessage(msg)
case MsgTypeKeepAlive:
c.handleKeepAlive(msg)
default:
log.Printf("未知消息类型: %d", msg.Type)
}
}
// handleConnectRequest 处理连接请求
func (c *Client) handleConnectRequest(msg *TunnelMessage) {
if len(msg.Data) < 6 {
log.Printf("连接请求数据太短")
return
}
connID := binary.BigEndian.Uint32(msg.Data[0:4])
targetPort := binary.BigEndian.Uint16(msg.Data[4:6])
targetAddr := fmt.Sprintf("127.0.0.1:%d", targetPort)
log.Printf("收到连接请求: ID=%d, 端口=%d", connID, targetPort)
// 尝试连接到本地服务
localConn, err := net.DialTimeout("tcp", targetAddr, ConnectTimeout)
if err != nil {
log.Printf("连接本地服务失败 (ID=%d -> %s): %v", connID, targetAddr, err)
c.sendConnectResponse(connID, ConnStatusFailed)
return
}
// 创建本地连接对象
connection := &LocalConnection{
ID: connID,
TargetAddr: targetAddr,
Conn: localConn,
closeChan: make(chan struct{}),
}
c.connMu.Lock()
c.connections[connID] = connection
c.connMu.Unlock()
log.Printf("建立本地连接: ID=%d -> %s", connID, targetAddr)
// 发送连接成功响应
c.sendConnectResponse(connID, ConnStatusSuccess)
// 启动数据转发
go c.forwardData(connection)
}
// handleDataMessage 处理数据消息
func (c *Client) handleDataMessage(msg *TunnelMessage) {
if len(msg.Data) < 4 {
log.Printf("数据消息太短")
return
}
connID := binary.BigEndian.Uint32(msg.Data[0:4])
data := msg.Data[4:]
c.connMu.RLock()
connection, exists := c.connections[connID]
c.connMu.RUnlock()
if !exists {
log.Printf("收到未知连接的数据: %d", connID)
return
}
// 写入到本地连接
if _, err := connection.Conn.Write(data); err != nil {
log.Printf("写入本地连接失败 (ID=%d): %v", connID, err)
c.closeConnection(connID)
}
}
// handleCloseMessage 处理关闭消息
func (c *Client) handleCloseMessage(msg *TunnelMessage) {
if len(msg.Data) < 4 {
log.Printf("关闭消息数据太短")
return
}
connID := binary.BigEndian.Uint32(msg.Data[0:4])
c.closeConnection(connID)
}
// handleKeepAlive 处理心跳消息
func (c *Client) handleKeepAlive(msg *TunnelMessage) {
// 回应心跳
response := &TunnelMessage{
Version: ProtocolVersion,
Type: MsgTypeKeepAlive,
Length: 0,
Data: nil,
}
select {
case c.sendChan <- response:
default:
log.Printf("发送心跳响应失败: 发送队列已满")
}
}
// sendConnectResponse 发送连接响应
func (c *Client) sendConnectResponse(connID uint32, status byte) {
responseData := make([]byte, 5)
binary.BigEndian.PutUint32(responseData[0:4], connID)
responseData[4] = status
msg := &TunnelMessage{
Version: ProtocolVersion,
Type: MsgTypeConnectResponse,
Length: 5,
Data: responseData,
}
select {
case c.sendChan <- msg:
default:
log.Printf("发送连接响应失败: 发送队列已满")
}
}
// forwardData 转发数据
func (c *Client) forwardData(connection *LocalConnection) {
defer c.closeConnection(connection.ID)
buffer := make([]byte, 32*1024)
for {
select {
case <-connection.closeChan:
return
case <-c.ctx.Done():
return
default:
}
connection.Conn.SetReadDeadline(time.Now().Add(ReadTimeout))
n, err := connection.Conn.Read(buffer)
if err != nil {
if err != io.EOF && !isTimeout(err) {
log.Printf("读取本地连接失败 (ID=%d): %v", connection.ID, err)
}
return
}
data := make([]byte, HeaderSize+n)
binary.BigEndian.PutUint32(data[0:4], uint32(n))
binary.BigEndian.PutUint32(data[4:8], localConn.ID)
copy(data[HeaderSize:], buffer[:n])
// 发送数据到隧道
dataMsg := make([]byte, 4+n)
binary.BigEndian.PutUint32(dataMsg[0:4], connection.ID)
copy(dataMsg[4:], buffer[:n])
if _, err := serverConn.Write(data); err != nil {
log.Printf("发送数据到服务器失败 (连接 %d): %v", localConn.ID, err)
msg := &TunnelMessage{
Version: ProtocolVersion,
Type: MsgTypeData,
Length: uint32(len(dataMsg)),
Data: dataMsg,
}
select {
case c.sendChan <- msg:
case <-time.After(5 * time.Second):
log.Printf("发送数据超时 (ID=%d)", connection.ID)
return
case <-c.ctx.Done():
return
}
}
}
// closeConnection 关闭连接
func (c *Client) closeConnection(connID uint32) {
c.connMu.Lock()
connection, exists := c.connections[connID]
if exists {
delete(c.connections, connID)
connection.closeOnce.Do(func() {
close(connection.closeChan)
})
connection.Conn.Close()
}
c.connMu.Unlock()
// 发送关闭消息
closeData := make([]byte, 4)
binary.BigEndian.PutUint32(closeData, connID)
msg := &TunnelMessage{
Version: ProtocolVersion,
Type: MsgTypeClose,
Length: 4,
Data: closeData,
}
select {
case c.sendChan <- msg:
default:
// 发送队列满,忽略
}
if exists {
log.Printf("连接已关闭: ID=%d", connID)
}
}
// Stop 停止隧道客户端
func (c *Client) Stop() error {
log.Println("正在停止隧道客户端...")

View File

@ -0,0 +1,482 @@
package tunnel
import (
"encoding/binary"
"io"
"net"
"sync"
"testing"
"time"
)
// TestNewClient 测试创建隧道客户端
func TestNewClient(t *testing.T) {
client := NewClient("127.0.0.1:9000")
if client == nil {
t.Fatal("创建隧道客户端失败")
}
if client.serverAddr != "127.0.0.1:9000" {
t.Errorf("服务器地址不正确,期望 127.0.0.1:9000得到 %s", client.serverAddr)
}
if client.connections == nil {
t.Error("连接映射未初始化")
}
if client.sendChan == nil {
t.Error("发送通道未初始化")
}
if client.ctx == nil {
t.Error("上下文未初始化")
}
}
// TestClientStartStop 测试客户端启动和停止
func TestClientStartStop(t *testing.T) {
// 创建模拟服务器
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("创建模拟服务器失败: %v", err)
}
defer listener.Close()
serverAddr := listener.Addr().String()
// 启动模拟服务器
go func() {
for {
conn, err := listener.Accept()
if err != nil {
return
}
// 保持连接短暂时间然后关闭
time.Sleep(100 * time.Millisecond)
conn.Close()
}
}()
client := NewClient(serverAddr)
err = client.Start()
if err != nil {
t.Fatalf("启动客户端失败: %v", err)
}
// 等待客户端尝试连接
time.Sleep(300 * time.Millisecond)
// 停止客户端
err = client.Stop()
if err != nil {
t.Errorf("停止客户端失败: %v", err)
}
}
// TestClientTunnelMessage 测试客户端隧道消息处理
func TestClientTunnelMessage(t *testing.T) {
client := NewClient("127.0.0.1:9000")
// 创建测试消息
data := []byte("test data")
msg := &TunnelMessage{
Version: ProtocolVersion,
Type: MsgTypeData,
Length: uint32(len(data)),
Data: data,
}
// 创建模拟连接
serverConn, clientConn := net.Pipe()
defer serverConn.Close()
defer clientConn.Close()
// 测试写入消息
go func() {
err := client.writeTunnelMessage(serverConn, msg)
if err != nil {
t.Errorf("写入隧道消息失败: %v", err)
}
}()
// 测试读取消息
receivedMsg, err := client.readTunnelMessage(clientConn)
if err != nil {
t.Fatalf("读取隧道消息失败: %v", err)
}
// 验证消息内容
if receivedMsg.Version != msg.Version {
t.Errorf("版本不匹配,期望 %d得到 %d", msg.Version, receivedMsg.Version)
}
if receivedMsg.Type != msg.Type {
t.Errorf("类型不匹配,期望 %d得到 %d", msg.Type, receivedMsg.Type)
}
if receivedMsg.Length != msg.Length {
t.Errorf("长度不匹配,期望 %d得到 %d", msg.Length, receivedMsg.Length)
}
if string(receivedMsg.Data) != string(msg.Data) {
t.Errorf("数据不匹配,期望 %s得到 %s", string(msg.Data), string(receivedMsg.Data))
}
}
// TestClientHandleConnectRequest 测试客户端处理连接请求
func TestClientHandleConnectRequest(t *testing.T) {
// 启动本地测试服务
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("启动本地测试服务失败: %v", err)
}
defer listener.Close()
localPort := listener.Addr().(*net.TCPAddr).Port
// 启动echo服务
go func() {
for {
conn, err := listener.Accept()
if err != nil {
return
}
go func(c net.Conn) {
defer c.Close()
io.Copy(c, c) // echo服务
}(conn)
}
}()
client := NewClient("127.0.0.1:9000")
// 创建连接请求消息
connID := uint32(12345)
reqData := make([]byte, 6)
binary.BigEndian.PutUint32(reqData[0:4], connID)
binary.BigEndian.PutUint16(reqData[4:6], uint16(localPort))
msg := &TunnelMessage{
Version: ProtocolVersion,
Type: MsgTypeConnectRequest,
Length: 6,
Data: reqData,
}
// 创建模拟服务器连接
serverConn, clientServerConn := net.Pipe()
defer serverConn.Close()
defer clientServerConn.Close()
// 启动发送处理
go func() {
for {
select {
case msg := <-client.sendChan:
client.writeTunnelMessage(clientServerConn, msg)
case <-time.After(time.Second):
return
}
}
}()
// 处理连接请求
client.handleTunnelMessage(msg)
// 等待连接建立
time.Sleep(100 * time.Millisecond)
// 验证连接是否建立
client.connMu.RLock()
_, exists := client.connections[connID]
client.connMu.RUnlock()
if !exists {
t.Error("连接未建立")
}
// 读取连接响应
response, err := client.readTunnelMessage(serverConn)
if err != nil {
t.Fatalf("读取连接响应失败: %v", err)
}
if response.Type != MsgTypeConnectResponse {
t.Errorf("响应类型不正确,期望 %d得到 %d", MsgTypeConnectResponse, response.Type)
}
if len(response.Data) < 5 {
t.Fatal("响应数据太短")
}
responseConnID := binary.BigEndian.Uint32(response.Data[0:4])
status := response.Data[4]
if responseConnID != connID {
t.Errorf("响应连接ID不匹配期望 %d得到 %d", connID, responseConnID)
}
if status != ConnStatusSuccess {
t.Errorf("连接状态不正确,期望 %d得到 %d", ConnStatusSuccess, status)
}
// 清理连接
client.closeConnection(connID)
}
// TestClientHandleDataMessage 测试客户端处理数据消息
func TestClientHandleDataMessage(t *testing.T) {
client := NewClient("127.0.0.1:9000")
// 创建模拟本地连接
localConn, remoteConn := net.Pipe()
defer localConn.Close()
defer remoteConn.Close()
connID := uint32(12345)
connection := &LocalConnection{
ID: connID,
TargetAddr: "127.0.0.1:8080",
Conn: localConn,
closeChan: make(chan struct{}),
}
client.connMu.Lock()
client.connections[connID] = connection
client.connMu.Unlock()
// 创建数据消息
testData := []byte("Hello, World!")
dataMsg := make([]byte, 4+len(testData))
binary.BigEndian.PutUint32(dataMsg[0:4], connID)
copy(dataMsg[4:], testData)
msg := &TunnelMessage{
Version: ProtocolVersion,
Type: MsgTypeData,
Length: uint32(len(dataMsg)),
Data: dataMsg,
}
// 处理数据消息
go client.handleTunnelMessage(msg)
// 从远程连接读取数据
buffer := make([]byte, len(testData))
remoteConn.SetReadDeadline(time.Now().Add(2 * time.Second))
n, err := remoteConn.Read(buffer)
if err != nil {
t.Fatalf("读取数据失败: %v", err)
}
if n != len(testData) {
t.Errorf("数据长度不匹配,期望 %d得到 %d", len(testData), n)
}
if string(buffer[:n]) != string(testData) {
t.Errorf("数据内容不匹配,期望 %s得到 %s", string(testData), string(buffer[:n]))
}
// 清理连接
client.closeConnection(connID)
}
// TestClientHandleCloseMessage 测试客户端处理关闭消息
func TestClientHandleCloseMessage(t *testing.T) {
client := NewClient("127.0.0.1:9000")
// 创建模拟本地连接
localConn, _ := net.Pipe()
defer localConn.Close()
connID := uint32(12345)
connection := &LocalConnection{
ID: connID,
TargetAddr: "127.0.0.1:8080",
Conn: localConn,
closeChan: make(chan struct{}),
}
client.connMu.Lock()
client.connections[connID] = connection
client.connMu.Unlock()
// 创建关闭消息
closeData := make([]byte, 4)
binary.BigEndian.PutUint32(closeData, connID)
msg := &TunnelMessage{
Version: ProtocolVersion,
Type: MsgTypeClose,
Length: 4,
Data: closeData,
}
// 处理关闭消息
client.handleTunnelMessage(msg)
// 等待连接关闭
time.Sleep(100 * time.Millisecond)
// 验证连接是否被移除
client.connMu.RLock()
_, exists := client.connections[connID]
client.connMu.RUnlock()
if exists {
t.Error("连接未被移除")
}
}
// TestClientKeepAlive 测试客户端心跳处理
func TestClientKeepAlive(t *testing.T) {
client := NewClient("127.0.0.1:9000")
// 创建模拟连接
serverConn, clientConn := net.Pipe()
defer serverConn.Close()
defer clientConn.Close()
// 启动发送处理
go func() {
for {
select {
case msg := <-client.sendChan:
client.writeTunnelMessage(clientConn, msg)
case <-time.After(time.Second):
return
}
}
}()
// 创建心跳消息
keepAliveMsg := &TunnelMessage{
Version: ProtocolVersion,
Type: MsgTypeKeepAlive,
Length: 0,
Data: nil,
}
// 处理心跳消息
client.handleTunnelMessage(keepAliveMsg)
// 读取心跳响应
response, err := client.readTunnelMessage(serverConn)
if err != nil {
t.Fatalf("读取心跳响应失败: %v", err)
}
if response.Type != MsgTypeKeepAlive {
t.Errorf("心跳响应类型不正确,期望 %d得到 %d", MsgTypeKeepAlive, response.Type)
}
}
// TestClientReconnect 测试客户端重连功能
func TestClientReconnect(t *testing.T) {
// 这个测试比较复杂,需要模拟服务器关闭和重启
// 简化版本只测试客户端的创建和基本功能
client := NewClient("127.0.0.1:19999") // 使用不存在的端口
err := client.Start()
if err != nil {
t.Fatalf("启动客户端失败: %v", err)
}
// 等待客户端尝试连接
time.Sleep(200 * time.Millisecond)
err = client.Stop()
if err != nil {
t.Errorf("停止客户端失败: %v", err)
}
}
// TestClientConcurrentConnections 测试客户端并发连接处理
func TestClientConcurrentConnections(t *testing.T) {
client := NewClient("127.0.0.1:9000")
var wg sync.WaitGroup
connCount := 5
// 创建多个模拟连接
for i := 0; i < connCount; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
// 创建模拟本地连接
localConn, _ := net.Pipe()
defer localConn.Close()
connID := uint32(id + 1)
connection := &LocalConnection{
ID: connID,
TargetAddr: "127.0.0.1:8080",
Conn: localConn,
closeChan: make(chan struct{}),
}
client.connMu.Lock()
client.connections[connID] = connection
client.connMu.Unlock()
// 保持连接一段时间
time.Sleep(100 * time.Millisecond)
// 清理连接
client.closeConnection(connID)
}(i)
}
wg.Wait()
// 验证所有连接都被清理
client.connMu.RLock()
remainingConns := len(client.connections)
client.connMu.RUnlock()
if remainingConns != 0 {
t.Errorf("还有 %d 个连接未清理", remainingConns)
}
}
// TestProtocolConstants 测试协议常量
func TestProtocolConstants(t *testing.T) {
if ProtocolVersion != 0x01 {
t.Errorf("协议版本不正确,期望 0x01得到 0x%02x", ProtocolVersion)
}
if HeaderSize != 6 {
t.Errorf("消息头大小不正确,期望 6得到 %d", HeaderSize)
}
if MaxPacketSize != 1024*1024 {
t.Errorf("最大包大小不正确,期望 %d得到 %d", 1024*1024, MaxPacketSize)
}
// 验证消息类型常量
expectedTypes := map[string]byte{
"MsgTypeConnectRequest": MsgTypeConnectRequest,
"MsgTypeConnectResponse": MsgTypeConnectResponse,
"MsgTypeData": MsgTypeData,
"MsgTypeClose": MsgTypeClose,
"MsgTypeKeepAlive": MsgTypeKeepAlive,
}
for name, value := range expectedTypes {
if value == 0 {
t.Errorf("消息类型 %s 未定义", name)
}
}
// 验证连接状态常量
if ConnStatusSuccess != 0x00 {
t.Errorf("连接成功状态不正确,期望 0x00得到 0x%02x", ConnStatusSuccess)
}
if ConnStatusFailed != 0x01 {
t.Errorf("连接失败状态不正确,期望 0x01得到 0x%02x", ConnStatusFailed)
}
}

View File

@ -20,25 +20,25 @@ type Handler struct {
tunnelServer *tunnel.Server
portRangeFrom int
portRangeEnd int
useTunnel bool
}
// NewHandler 创建新的 API 处理器
func NewHandler(database *db.Database, fwdMgr *forwarder.Manager, ts *tunnel.Server, portFrom, portEnd int, useTunnel bool) *Handler {
func NewHandler(database *db.Database, fwdMgr *forwarder.Manager, ts *tunnel.Server, portFrom, portEnd int) *Handler {
return &Handler{
db: database,
forwarderMgr: fwdMgr,
tunnelServer: ts,
portRangeFrom: portFrom,
portRangeEnd: portEnd,
useTunnel: useTunnel,
}
}
// CreateMappingRequest 创建映射请求
type CreateMappingRequest struct {
Port int `json:"port"` // 源端口和目标端口(相同)
TargetIP string `json:"target_ip"` // 目标 IP非隧道模式使用
SourcePort int `json:"source_port"` // 源端口(本地监听端口)
TargetPort int `json:"target_port"` // 目标端口(远程服务端口)
TargetIP string `json:"target_ip"` // 目标 IP非隧道模式使用
UseTunnel bool `json:"use_tunnel"` // 是否使用隧道模式
}
// RemoveMappingRequest 删除映射请求
@ -75,57 +75,72 @@ func (h *Handler) handleCreateMapping(w http.ResponseWriter, r *http.Request) {
}
// 验证端口范围
if req.Port < h.portRangeFrom || req.Port > h.portRangeEnd {
if req.SourcePort < h.portRangeFrom || req.SourcePort > h.portRangeEnd {
h.writeError(w, http.StatusBadRequest, fmt.Sprintf("端口必须在 %d-%d 范围内", h.portRangeFrom, h.portRangeEnd))
return
}
// 检查端口是否已被使用
if h.forwarderMgr.Exists(req.Port) {
if h.forwarderMgr.Exists(req.SourcePort) {
h.writeError(w, http.StatusConflict, "端口已被占用")
return
}
// 非隧道模式需要验证 IP
if !h.useTunnel {
if req.TargetIP == "" {
h.writeError(w, http.StatusBadRequest, "target_ip 不能为空")
// 根据请求决定使用哪种模式
if req.UseTunnel {
// 隧道模式,检查隧道服务器是否可用
if h.tunnelServer == nil {
h.writeError(w, http.StatusServiceUnavailable, "隧道服务未启用")
return
}
if net.ParseIP(req.TargetIP) == nil {
h.writeError(w, http.StatusBadRequest, "target_ip 格式无效")
return
}
} else {
// 隧道模式,检查隧道是否连接
if !h.tunnelServer.IsConnected() {
h.writeError(w, http.StatusServiceUnavailable, "隧道未连接")
return
}
// 隧道模式使用本地地址
req.TargetIP = "127.0.0.1"
} else {
// 直接模式需要验证 IP
if req.TargetIP == "" {
h.writeError(w, http.StatusBadRequest, "非隧道模式下 target_ip 不能为空")
return
}
if net.ParseIP(req.TargetIP) == nil {
h.writeError(w, http.StatusBadRequest, "target_ip 格式无效")
return
}
}
// 添加到数据库
if err := h.db.AddMapping(req.Port, req.TargetIP, req.Port); err != nil {
if err := h.db.AddMapping(req.SourcePort, req.TargetIP, req.TargetPort, req.UseTunnel); err != nil {
h.writeError(w, http.StatusInternalServerError, "保存映射失败: "+err.Error())
return
}
// 启动转发器
if err := h.forwarderMgr.Add(req.Port, req.TargetIP, req.Port); err != nil {
var err error
if req.UseTunnel {
// 隧道模式:使用隧道转发
err = h.forwarderMgr.AddTunnel(req.SourcePort, req.SourcePort, h.tunnelServer)
} else {
// 直接模式直接TCP转发
err = h.forwarderMgr.Add(req.SourcePort, req.TargetIP, req.TargetPort)
}
if err != nil {
// 回滚数据库操作
h.db.RemoveMapping(req.Port)
h.db.RemoveMapping(req.SourcePort)
h.writeError(w, http.StatusInternalServerError, "启动转发失败: "+err.Error())
return
}
log.Printf("创建端口映射: %d -> %s:%d", req.Port, req.TargetIP, req.Port)
log.Printf("创建端口映射: %d -> %s:%d (tunnel: %v)", req.SourcePort, req.TargetIP, req.TargetPort, req.UseTunnel)
h.writeSuccess(w, "端口映射创建成功", map[string]interface{}{
"port": req.Port,
"target_ip": req.TargetIP,
"use_tunnel": h.useTunnel,
"source_port": req.SourcePort,
"target_ip": req.TargetIP,
"target_port": req.TargetPort,
"use_tunnel": req.UseTunnel,
})
}
@ -187,9 +202,8 @@ func (h *Handler) handleListMappings(w http.ResponseWriter, r *http.Request) {
}
h.writeSuccess(w, "获取映射列表成功", map[string]interface{}{
"mappings": mappings,
"count": len(mappings),
"use_tunnel": h.useTunnel,
"mappings": mappings,
"count": len(mappings),
})
}
@ -197,11 +211,11 @@ func (h *Handler) handleListMappings(w http.ResponseWriter, r *http.Request) {
func (h *Handler) handleHealth(w http.ResponseWriter, r *http.Request) {
status := map[string]interface{}{
"status": "ok",
"tunnel_enabled": h.useTunnel,
"tunnel_enabled": h.tunnelServer != nil,
"tunnel_connected": false,
}
if h.useTunnel {
if h.tunnelServer != nil {
status["tunnel_connected"] = h.tunnelServer.IsConnected()
}

580
src/server/api/api_test.go Normal file
View File

@ -0,0 +1,580 @@
package api
import (
"bytes"
"encoding/json"
"net"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"port-forward/server/db"
"port-forward/server/forwarder"
"port-forward/server/tunnel"
"testing"
)
// setupTestHandler 创建测试用的 Handler
func setupTestHandler(t *testing.T, useTunnel bool) (*Handler, *db.Database, func()) {
// 创建临时数据库
tmpDir := t.TempDir()
dbPath := filepath.Join(tmpDir, "test.db")
database, err := db.New(dbPath)
if err != nil {
t.Fatalf("创建数据库失败: %v", err)
}
// 创建转发器管理器
fwdMgr := forwarder.NewManager()
// 创建隧道服务器(如果启用)
var tunnelServer *tunnel.Server
if useTunnel {
// 使用随机端口
listener, _ := net.Listen("tcp", "127.0.0.1:0")
port := listener.Addr().(*net.TCPAddr).Port
listener.Close()
tunnelServer = tunnel.NewServer(port)
tunnelServer.Start()
}
handler := NewHandler(database, fwdMgr, tunnelServer, 10000, 20000)
cleanup := func() {
fwdMgr.StopAll()
if tunnelServer != nil {
tunnelServer.Stop()
}
database.Close()
os.RemoveAll(tmpDir)
}
return handler, database, cleanup
}
// TestNewHandler 测试创建处理器
func TestNewHandler(t *testing.T) {
handler, _, cleanup := setupTestHandler(t, false)
defer cleanup()
if handler == nil {
t.Fatal("创建处理器失败")
}
if handler.portRangeFrom != 10000 {
t.Errorf("起始端口不正确,期望 10000得到 %d", handler.portRangeFrom)
}
if handler.portRangeEnd != 20000 {
t.Errorf("结束端口不正确,期望 20000得到 %d", handler.portRangeEnd)
}
}
// TestHandleHealth 测试健康检查
func TestHandleHealth(t *testing.T) {
handler, _, cleanup := setupTestHandler(t, false)
defer cleanup()
req := httptest.NewRequest(http.MethodGet, "/health", nil)
w := httptest.NewRecorder()
handler.handleHealth(w, req)
if w.Code != http.StatusOK {
t.Errorf("状态码不正确,期望 200得到 %d", w.Code)
}
var result map[string]interface{}
err := json.NewDecoder(w.Body).Decode(&result)
if err != nil {
t.Fatalf("解析响应失败: %v", err)
}
if result["status"] != "ok" {
t.Errorf("健康状态不正确,期望 ok得到 %v", result["status"])
}
}
// TestHandleHealthWithTunnel 测试带隧道的健康检查
func TestHandleHealthWithTunnel(t *testing.T) {
handler, _, cleanup := setupTestHandler(t, true)
defer cleanup()
req := httptest.NewRequest(http.MethodGet, "/health", nil)
w := httptest.NewRecorder()
handler.handleHealth(w, req)
if w.Code != http.StatusOK {
t.Errorf("状态码不正确,期望 200得到 %d", w.Code)
}
var result map[string]interface{}
json.NewDecoder(w.Body).Decode(&result)
if result["tunnel_enabled"] != true {
t.Error("隧道应该启用")
}
// 隧道未连接客户端时应该为 false
if result["tunnel_connected"] != false {
t.Error("隧道应该未连接")
}
}
// TestHandleCreateMapping 测试创建映射
func TestHandleCreateMapping(t *testing.T) {
handler, database, cleanup := setupTestHandler(t, false)
defer cleanup()
reqBody := CreateMappingRequest{
// Port: 15000,
SourcePort: 15000,
TargetPort: 15000,
TargetIP: "192.168.1.100",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
w := httptest.NewRecorder()
handler.handleCreateMapping(w, req)
if w.Code != http.StatusOK {
t.Errorf("状态码不正确,期望 200得到 %d", w.Code)
}
var result Response
json.NewDecoder(w.Body).Decode(&result)
if !result.Success {
t.Errorf("创建映射失败: %s", result.Message)
}
// 验证数据库中存在映射
mapping, err := database.GetMapping(15000)
if err != nil {
t.Fatalf("获取映射失败: %v", err)
}
if mapping == nil {
t.Fatal("映射不存在")
}
if mapping.TargetIP != "192.168.1.100" {
t.Errorf("目标 IP 不正确,期望 192.168.1.100,得到 %s", mapping.TargetIP)
}
}
// TestHandleCreateMappingInvalidPort 测试创建映射时端口无效
func TestHandleCreateMappingInvalidPort(t *testing.T) {
handler, _, cleanup := setupTestHandler(t, false)
defer cleanup()
tests := []struct {
name string
port int
}{
{"端口太小", 5000},
{"端口太大", 25000},
{"端口为0", 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
reqBody := CreateMappingRequest{
SourcePort: tt.port,
TargetPort: tt.port,
TargetIP: "192.168.1.100",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
w := httptest.NewRecorder()
handler.handleCreateMapping(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("状态码不正确,期望 400得到 %d", w.Code)
}
})
}
}
// TestHandleCreateMappingDuplicate 测试创建重复映射
func TestHandleCreateMappingDuplicate(t *testing.T) {
handler, _, cleanup := setupTestHandler(t, false)
defer cleanup()
reqBody := CreateMappingRequest{
// Port: 15000,
SourcePort: 15000,
TargetPort: 15000,
TargetIP: "192.168.1.100",
}
// 第一次创建
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
w := httptest.NewRecorder()
handler.handleCreateMapping(w, req)
if w.Code != http.StatusOK {
t.Fatalf("第一次创建失败")
}
// 第二次创建(应该失败)
body, _ = json.Marshal(reqBody)
req = httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
w = httptest.NewRecorder()
handler.handleCreateMapping(w, req)
if w.Code != http.StatusConflict {
t.Errorf("状态码不正确,期望 409得到 %d", w.Code)
}
}
// TestHandleCreateMappingInvalidJSON 测试无效的 JSON
func TestHandleCreateMappingInvalidJSON(t *testing.T) {
handler, _, cleanup := setupTestHandler(t, false)
defer cleanup()
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader([]byte("invalid json")))
w := httptest.NewRecorder()
handler.handleCreateMapping(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("状态码不正确,期望 400得到 %d", w.Code)
}
}
// TestHandleCreateMappingInvalidIP 测试无效的 IP 地址
func TestHandleCreateMappingInvalidIP(t *testing.T) {
handler, _, cleanup := setupTestHandler(t, false)
defer cleanup()
reqBody := CreateMappingRequest{
// Port: 15000,
SourcePort: 15000,
TargetPort: 15000,
TargetIP: "invalid-ip",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
w := httptest.NewRecorder()
handler.handleCreateMapping(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("状态码不正确,期望 400得到 %d", w.Code)
}
}
// TestHandleCreateMappingEmptyIP 测试空 IP 地址
func TestHandleCreateMappingEmptyIP(t *testing.T) {
handler, _, cleanup := setupTestHandler(t, false)
defer cleanup()
reqBody := CreateMappingRequest{
// Port: 15000,
SourcePort: 15000,
TargetPort: 15000,
TargetIP: "",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
w := httptest.NewRecorder()
handler.handleCreateMapping(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("状态码不正确,期望 400得到 %d", w.Code)
}
}
// TestHandleCreateMappingTunnelNotConnected 测试隧道未连接时创建映射
func TestHandleCreateMappingTunnelNotConnected(t *testing.T) {
handler, _, cleanup := setupTestHandler(t, true)
defer cleanup()
reqBody := CreateMappingRequest{
// Port: 15000,
SourcePort: 15000,
TargetPort: 15000,
UseTunnel: true, // 明确指定使用隧道模式
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
w := httptest.NewRecorder()
handler.handleCreateMapping(w, req)
if w.Code != http.StatusServiceUnavailable {
t.Errorf("状态码不正确,期望 503得到 %d", w.Code)
}
}
// TestHandleRemoveMapping 测试删除映射
func TestHandleRemoveMapping(t *testing.T) {
handler, database, cleanup := setupTestHandler(t, false)
defer cleanup()
// 先创建一个映射
database.AddMapping(15000, "192.168.1.100", 15000, false)
handler.forwarderMgr.Add(15000, "192.168.1.100", 15000)
reqBody := RemoveMappingRequest{
Port: 15000,
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest(http.MethodPost, "/api/mapping/remove", bytes.NewReader(body))
w := httptest.NewRecorder()
handler.handleRemoveMapping(w, req)
if w.Code != http.StatusOK {
t.Errorf("状态码不正确,期望 200得到 %d", w.Code)
}
// 验证映射已删除
mapping, _ := database.GetMapping(15000)
if mapping != nil {
t.Error("映射应该已被删除")
}
}
// TestHandleRemoveMappingNotExist 测试删除不存在的映射
func TestHandleRemoveMappingNotExist(t *testing.T) {
handler, _, cleanup := setupTestHandler(t, false)
defer cleanup()
reqBody := RemoveMappingRequest{
Port: 15000,
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest(http.MethodPost, "/api/mapping/remove", bytes.NewReader(body))
w := httptest.NewRecorder()
handler.handleRemoveMapping(w, req)
if w.Code != http.StatusNotFound {
t.Errorf("状态码不正确,期望 404得到 %d", w.Code)
}
}
// TestHandleListMappings 测试列出映射
func TestHandleListMappings(t *testing.T) {
handler, database, cleanup := setupTestHandler(t, false)
defer cleanup()
// 添加一些映射
database.AddMapping(15000, "192.168.1.100", 15000, false)
database.AddMapping(15001, "192.168.1.101", 15001, true)
database.AddMapping(15002, "192.168.1.102", 15002, false)
req := httptest.NewRequest(http.MethodGet, "/api/mapping/list", nil)
w := httptest.NewRecorder()
handler.handleListMappings(w, req)
if w.Code != http.StatusOK {
t.Errorf("状态码不正确,期望 200得到 %d", w.Code)
}
var result Response
json.NewDecoder(w.Body).Decode(&result)
if !result.Success {
t.Errorf("列出映射失败: %s", result.Message)
}
data := result.Data.(map[string]interface{})
count := int(data["count"].(float64))
if count != 3 {
t.Errorf("映射数量不正确,期望 3得到 %d", count)
}
}
// TestHandleListMappingsEmpty 测试列出空映射列表
func TestHandleListMappingsEmpty(t *testing.T) {
handler, _, cleanup := setupTestHandler(t, false)
defer cleanup()
req := httptest.NewRequest(http.MethodGet, "/api/mapping/list", nil)
w := httptest.NewRecorder()
handler.handleListMappings(w, req)
if w.Code != http.StatusOK {
t.Errorf("状态码不正确,期望 200得到 %d", w.Code)
}
var result Response
json.NewDecoder(w.Body).Decode(&result)
data := result.Data.(map[string]interface{})
count := int(data["count"].(float64))
if count != 0 {
t.Errorf("映射数量不正确,期望 0得到 %d", count)
}
}
// TestHandleMethodNotAllowed 测试不允许的 HTTP 方法
func TestHandleMethodNotAllowed(t *testing.T) {
handler, _, cleanup := setupTestHandler(t, false)
defer cleanup()
tests := []struct {
name string
handler func(http.ResponseWriter, *http.Request)
method string
}{
{"创建映射 GET", handler.handleCreateMapping, http.MethodGet},
{"删除映射 GET", handler.handleRemoveMapping, http.MethodGet},
{"列出映射 POST", handler.handleListMappings, http.MethodPost},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest(tt.method, "/test", nil)
w := httptest.NewRecorder()
tt.handler(w, req)
if w.Code != http.StatusMethodNotAllowed {
t.Errorf("状态码不正确,期望 405得到 %d", w.Code)
}
})
}
}
// TestRegisterRoutes 测试路由注册
func TestRegisterRoutes(t *testing.T) {
handler, _, cleanup := setupTestHandler(t, false)
defer cleanup()
mux := http.NewServeMux()
handler.RegisterRoutes(mux)
// 测试路由是否注册
routes := []string{
"/api/mapping/create",
"/api/mapping/remove",
"/api/mapping/list",
"/health",
}
for _, route := range routes {
req := httptest.NewRequest(http.MethodGet, route, nil)
w := httptest.NewRecorder()
mux.ServeHTTP(w, req)
// 如果路由不存在,应该返回 404
if w.Code == http.StatusNotFound {
t.Errorf("路由 %s 未注册", route)
}
}
}
// TestWriteSuccess 测试成功响应
func TestWriteSuccess(t *testing.T) {
handler, _, cleanup := setupTestHandler(t, false)
defer cleanup()
w := httptest.NewRecorder()
handler.writeSuccess(w, "测试成功", map[string]string{"key": "value"})
if w.Code != http.StatusOK {
t.Errorf("状态码不正确,期望 200得到 %d", w.Code)
}
var result Response
json.NewDecoder(w.Body).Decode(&result)
if !result.Success {
t.Error("Success 应该为 true")
}
if result.Message != "测试成功" {
t.Errorf("消息不正确,期望 '测试成功',得到 '%s'", result.Message)
}
}
// TestWriteError 测试错误响应
func TestWriteError(t *testing.T) {
handler, _, cleanup := setupTestHandler(t, false)
defer cleanup()
w := httptest.NewRecorder()
handler.writeError(w, http.StatusBadRequest, "测试错误")
if w.Code != http.StatusBadRequest {
t.Errorf("状态码不正确,期望 400得到 %d", w.Code)
}
var result Response
json.NewDecoder(w.Body).Decode(&result)
if result.Success {
t.Error("Success 应该为 false")
}
if result.Message != "测试错误" {
t.Errorf("消息不正确,期望 '测试错误',得到 '%s'", result.Message)
}
}
// BenchmarkHandleHealth 基准测试健康检查
func BenchmarkHandleHealth(b *testing.B) {
tmpDir := b.TempDir()
dbPath := filepath.Join(tmpDir, "bench.db")
database, _ := db.New(dbPath)
defer database.Close()
fwdMgr := forwarder.NewManager()
handler := NewHandler(database, fwdMgr, nil, 10000, 20000)
req := httptest.NewRequest(http.MethodGet, "/health", nil)
b.ResetTimer()
for i := 0; i < b.N; i++ {
w := httptest.NewRecorder()
handler.handleHealth(w, req)
}
}
// BenchmarkHandleListMappings 基准测试列出映射
func BenchmarkHandleListMappings(b *testing.B) {
tmpDir := b.TempDir()
dbPath := filepath.Join(tmpDir, "bench.db")
database, _ := db.New(dbPath)
defer database.Close()
// 添加一些映射
for i := 0; i < 100; i++ {
useTunnel := i%2 == 0 // 偶数使用隧道模式
database.AddMapping(10000+i, "192.168.1.1", 10000+i, useTunnel)
}
fwdMgr := forwarder.NewManager()
handler := NewHandler(database, fwdMgr, nil, 10000, 20000)
req := httptest.NewRequest(http.MethodGet, "/api/mapping/list", nil)
b.ResetTimer()
for i := 0; i < b.N; i++ {
w := httptest.NewRecorder()
handler.handleListMappings(w, req)
}
}

View File

@ -57,6 +57,110 @@ database:
}
}
// TestLoadConfigFileNotFound 测试配置文件不存在
func TestLoadConfigFileNotFound(t *testing.T) {
_, err := Load("/nonexistent/config.yaml")
if err == nil {
t.Error("应该返回文件不存在错误")
}
}
// TestLoadConfigInvalidYAML 测试无效的 YAML 格式
func TestLoadConfigInvalidYAML(t *testing.T) {
tmpFile, err := os.CreateTemp("", "config_invalid_*.yaml")
if err != nil {
t.Fatalf("创建临时文件失败: %v", err)
}
defer os.Remove(tmpFile.Name())
// 写入无效的 YAML
invalidYAML := `
port_range:
from: 10000
end: invalid_number
`
tmpFile.Write([]byte(invalidYAML))
tmpFile.Close()
_, err = Load(tmpFile.Name())
if err == nil {
t.Error("应该返回 YAML 解析错误")
}
}
// TestConfigEdgeCases 测试配置边界情况
func TestConfigEdgeCases(t *testing.T) {
tmpFile, err := os.CreateTemp("", "config_edge_*.yaml")
if err != nil {
t.Fatalf("创建临时文件失败: %v", err)
}
defer os.Remove(tmpFile.Name())
// 测试刚好 10000 个端口(边界值)
configContent := `
port_range:
from: 1
end: 10000
tunnel:
enabled: false
api:
listen_port: 8080
database:
path: ./data/mappings.db
`
tmpFile.Write([]byte(configContent))
tmpFile.Close()
cfg, err := Load(tmpFile.Name())
if err != nil {
t.Errorf("边界值配置应该有效: %v", err)
}
if cfg != nil && (cfg.PortRange.End-cfg.PortRange.From) != 9999 {
t.Errorf("端口范围计算不正确")
}
}
// BenchmarkLoadConfig 基准测试配置加载
func BenchmarkLoadConfig(b *testing.B) {
tmpFile, _ := os.CreateTemp("", "config_bench_*.yaml")
defer os.Remove(tmpFile.Name())
configContent := `
port_range:
from: 10000
end: 20000
tunnel:
enabled: true
listen_port: 9000
api:
listen_port: 8080
database:
path: ./data/mappings.db
`
tmpFile.Write([]byte(configContent))
tmpFile.Close()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = Load(tmpFile.Name())
}
}
// BenchmarkValidateConfig 基准测试配置验证
func BenchmarkValidateConfig(b *testing.B) {
cfg := &Config{
PortRange: PortRangeConfig{From: 10000, End: 20000},
Tunnel: TunnelConfig{Enabled: true, ListenPort: 9000},
API: APIConfig{ListenPort: 8080},
Database: DatabaseConfig{Path: "./data/db.sqlite"},
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = cfg.Validate()
}
}
func TestValidateConfig(t *testing.T) {
tests := []struct {
name string

View File

@ -16,6 +16,7 @@ type Mapping struct {
SourcePort int `json:"source_port"`
TargetIP string `json:"target_ip"`
TargetPort int `json:"target_port"`
UseTunnel bool `json:"use_tunnel"`
CreatedAt string `json:"created_at"`
}
@ -55,12 +56,14 @@ func New(dbPath string) (*Database, error) {
// initTables 初始化数据库表
func (d *Database) initTables() error {
// 创建表结构
query := `
CREATE TABLE IF NOT EXISTS mappings (
id INTEGER PRIMARY KEY AUTOINCREMENT,
source_port INTEGER NOT NULL UNIQUE,
target_ip TEXT NOT NULL,
target_port INTEGER NOT NULL,
use_tunnel BOOLEAN NOT NULL DEFAULT 0,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX IF NOT EXISTS idx_source_port ON mappings(source_port);
@ -71,16 +74,59 @@ func (d *Database) initTables() error {
return fmt.Errorf("初始化数据库表失败: %w", err)
}
// 检查是否需要迁移现有数据
if err := d.migrateDatabase(); err != nil {
return fmt.Errorf("数据库迁移失败: %w", err)
}
return nil
}
// migrateDatabase 迁移现有数据库结构
func (d *Database) migrateDatabase() error {
// 检查 use_tunnel 列是否存在
rows, err := d.db.Query("PRAGMA table_info(mappings)")
if err != nil {
return fmt.Errorf("检查表结构失败: %w", err)
}
defer rows.Close()
hasUseTunnel := false
for rows.Next() {
var cid int
var name, dataType string
var notNull, hasDefault int
var defaultValue interface{}
err := rows.Scan(&cid, &name, &dataType, &notNull, &defaultValue, &hasDefault)
if err != nil {
return fmt.Errorf("扫描表结构失败: %w", err)
}
if name == "use_tunnel" {
hasUseTunnel = true
break
}
}
// 如果不存在 use_tunnel 列,则添加它
if !hasUseTunnel {
_, err := d.db.Exec("ALTER TABLE mappings ADD COLUMN use_tunnel BOOLEAN NOT NULL DEFAULT 0")
if err != nil {
return fmt.Errorf("添加 use_tunnel 列失败: %w", err)
}
}
return nil
}
// AddMapping 添加端口映射
func (d *Database) AddMapping(sourcePort int, targetIP string, targetPort int) error {
func (d *Database) AddMapping(sourcePort int, targetIP string, targetPort int, useTunnel bool) error {
d.mu.Lock()
defer d.mu.Unlock()
query := `INSERT INTO mappings (source_port, target_ip, target_port) VALUES (?, ?, ?)`
_, err := d.db.Exec(query, sourcePort, targetIP, targetPort)
query := `INSERT INTO mappings (source_port, target_ip, target_port, use_tunnel) VALUES (?, ?, ?, ?)`
_, err := d.db.Exec(query, sourcePort, targetIP, targetPort, useTunnel)
if err != nil {
return fmt.Errorf("添加端口映射失败: %w", err)
}
@ -116,7 +162,7 @@ func (d *Database) GetMapping(sourcePort int) (*Mapping, error) {
d.mu.RLock()
defer d.mu.RUnlock()
query := `SELECT id, source_port, target_ip, target_port, created_at FROM mappings WHERE source_port = ?`
query := `SELECT id, source_port, target_ip, target_port, use_tunnel, created_at FROM mappings WHERE source_port = ?`
var mapping Mapping
err := d.db.QueryRow(query, sourcePort).Scan(
@ -124,6 +170,7 @@ func (d *Database) GetMapping(sourcePort int) (*Mapping, error) {
&mapping.SourcePort,
&mapping.TargetIP,
&mapping.TargetPort,
&mapping.UseTunnel,
&mapping.CreatedAt,
)
@ -142,7 +189,7 @@ func (d *Database) GetAllMappings() ([]*Mapping, error) {
d.mu.RLock()
defer d.mu.RUnlock()
query := `SELECT id, source_port, target_ip, target_port, created_at FROM mappings ORDER BY source_port`
query := `SELECT id, source_port, target_ip, target_port, use_tunnel, created_at FROM mappings ORDER BY source_port`
rows, err := d.db.Query(query)
if err != nil {
@ -158,6 +205,7 @@ func (d *Database) GetAllMappings() ([]*Mapping, error) {
&mapping.SourcePort,
&mapping.TargetIP,
&mapping.TargetPort,
&mapping.UseTunnel,
&mapping.CreatedAt,
); err != nil {
return nil, fmt.Errorf("扫描映射记录失败: %w", err)

View File

@ -18,7 +18,7 @@ func TestDatabase(t *testing.T) {
defer db.Close()
t.Run("添加映射", func(t *testing.T) {
err := db.AddMapping(10001, "192.168.1.100", 22)
err := db.AddMapping(10001, "192.168.1.100", 22, false)
if err != nil {
t.Errorf("添加映射失败: %v", err)
}
@ -44,7 +44,7 @@ func TestDatabase(t *testing.T) {
})
t.Run("添加重复映射应该失败", func(t *testing.T) {
err := db.AddMapping(10001, "192.168.1.101", 22)
err := db.AddMapping(10001, "192.168.1.101", 22, true)
if err == nil {
t.Error("添加重复映射应该失败")
}
@ -52,8 +52,8 @@ func TestDatabase(t *testing.T) {
t.Run("获取所有映射", func(t *testing.T) {
// 添加更多映射
db.AddMapping(10002, "192.168.1.101", 22)
db.AddMapping(10003, "192.168.1.102", 22)
db.AddMapping(10002, "192.168.1.101", 22, true)
db.AddMapping(10003, "192.168.1.102", 22, false)
mappings, err := db.GetAllMappings()
if err != nil {
@ -101,7 +101,8 @@ func TestDatabaseConcurrency(t *testing.T) {
done := make(chan bool)
for i := 0; i < 10; i++ {
go func(port int) {
err := db.AddMapping(10000+port, "192.168.1.100", port)
useTunnel := port%2 == 0 // 偶数端口使用隧道模式
err := db.AddMapping(10000+port, "192.168.1.100", port, useTunnel)
if err != nil {
t.Logf("添加映射失败 (端口 %d): %v", 10000+port, err)
}

View File

@ -10,17 +10,23 @@ import (
"time"
)
// TunnelServer 隧道服务器接口
type TunnelServer interface {
ForwardConnection(clientConn net.Conn, targetPort int) error
IsConnected() bool
}
// Forwarder 端口转发器
type Forwarder struct {
sourcePort int
targetAddr string
listener net.Listener
cancel context.CancelFunc
ctx context.Context
wg sync.WaitGroup
tunnelConn net.Conn
useTunnel bool
mu sync.RWMutex
sourcePort int
targetPort int
targetAddr string
listener net.Listener
cancel context.CancelFunc
ctx context.Context
wg sync.WaitGroup
tunnelServer TunnelServer
useTunnel bool
}
// NewForwarder 创建新的端口转发器
@ -28,6 +34,7 @@ func NewForwarder(sourcePort int, targetIP string, targetPort int) *Forwarder {
ctx, cancel := context.WithCancel(context.Background())
return &Forwarder{
sourcePort: sourcePort,
targetPort: targetPort,
targetAddr: fmt.Sprintf("%s:%d", targetIP, targetPort),
cancel: cancel,
ctx: ctx,
@ -36,15 +43,15 @@ func NewForwarder(sourcePort int, targetIP string, targetPort int) *Forwarder {
}
// NewTunnelForwarder 创建使用隧道的端口转发器
func NewTunnelForwarder(sourcePort int, targetPort int, tunnelConn net.Conn) *Forwarder {
func NewTunnelForwarder(sourcePort int, targetPort int, tunnelServer TunnelServer) *Forwarder {
ctx, cancel := context.WithCancel(context.Background())
return &Forwarder{
sourcePort: sourcePort,
targetAddr: fmt.Sprintf("127.0.0.1:%d", targetPort),
tunnelConn: tunnelConn,
useTunnel: true,
cancel: cancel,
ctx: ctx,
sourcePort: sourcePort,
targetPort: targetPort,
tunnelServer: tunnelServer,
useTunnel: true,
cancel: cancel,
ctx: ctx,
}
}
@ -100,36 +107,37 @@ func (f *Forwarder) acceptLoop() {
// handleConnection 处理单个连接
func (f *Forwarder) handleConnection(clientConn net.Conn) {
defer f.wg.Done()
defer clientConn.Close()
var targetConn net.Conn
var err error
if f.useTunnel {
// 使用隧道连接
f.mu.RLock()
targetConn = f.tunnelConn
f.mu.RUnlock()
// 使用隧道转发
if f.tunnelServer == nil || !f.tunnelServer.IsConnected() {
log.Printf("隧道服务器不可用 (端口 %d)", f.sourcePort)
clientConn.Close()
return
}
if targetConn == nil {
log.Printf("隧道连接不可用 (端口 %d)", f.sourcePort)
return
// 将连接转发到隧道ForwardConnection 会处理连接关闭
if err := f.tunnelServer.ForwardConnection(clientConn, f.targetPort); err != nil {
log.Printf("隧道转发失败 (端口 %d -> %d): %v", f.sourcePort, f.targetPort, err)
}
} else {
// 直接连接目标
dialer := &net.Dialer{
Timeout: 10 * time.Second,
KeepAlive: 30 * time.Second,
}
targetConn, err = dialer.DialContext(f.ctx, "tcp", f.targetAddr)
if err != nil {
log.Printf("连接目标失败 (端口 %d -> %s): %v", f.sourcePort, f.targetAddr, err)
return
}
defer targetConn.Close()
return
}
// 直接连接目标
defer clientConn.Close()
dialer := &net.Dialer{
Timeout: 10 * time.Second,
KeepAlive: 30 * time.Second,
}
targetConn, err := dialer.DialContext(f.ctx, "tcp", f.targetAddr)
if err != nil {
log.Printf("连接目标失败 (端口 %d -> %s): %v", f.sourcePort, f.targetAddr, err)
return
}
defer targetConn.Close()
// 双向转发
errChan := make(chan error, 2)
@ -181,13 +189,6 @@ func (f *Forwarder) Stop() error {
return nil
}
// SetTunnelConn 设置隧道连接
func (f *Forwarder) SetTunnelConn(conn net.Conn) {
f.mu.Lock()
defer f.mu.Unlock()
f.tunnelConn = conn
}
// Manager 转发器管理器
type Manager struct {
forwarders map[int]*Forwarder
@ -220,7 +221,7 @@ func (m *Manager) Add(sourcePort int, targetIP string, targetPort int) error {
}
// AddTunnel 添加使用隧道的转发器
func (m *Manager) AddTunnel(sourcePort int, targetPort int, tunnelConn net.Conn) error {
func (m *Manager) AddTunnel(sourcePort int, targetPort int, tunnelServer TunnelServer) error {
m.mu.Lock()
defer m.mu.Unlock()
@ -228,7 +229,7 @@ func (m *Manager) AddTunnel(sourcePort int, targetPort int, tunnelConn net.Conn)
return fmt.Errorf("端口 %d 已被占用", sourcePort)
}
forwarder := NewTunnelForwarder(sourcePort, targetPort, tunnelConn)
forwarder := NewTunnelForwarder(sourcePort, targetPort, tunnelServer)
if err := forwarder.Start(); err != nil {
return err
}

View File

@ -0,0 +1,432 @@
package forwarder
import (
"fmt"
"io"
"net"
"testing"
"time"
)
// TestNewForwarder 测试创建转发器
func TestNewForwarder(t *testing.T) {
fwd := NewForwarder(8080, "192.168.1.100", 80)
if fwd == nil {
t.Fatal("创建转发器失败")
}
if fwd.sourcePort != 8080 {
t.Errorf("源端口不正确,期望 8080得到 %d", fwd.sourcePort)
}
if fwd.targetAddr != "192.168.1.100:80" {
t.Errorf("目标地址不正确,期望 192.168.1.100:80得到 %s", fwd.targetAddr)
}
if fwd.useTunnel {
t.Error("普通转发器不应使用隧道")
}
}
// mockTunnelServer 模拟隧道服务器
type mockTunnelServer struct {
connected bool
}
func (m *mockTunnelServer) ForwardConnection(clientConn net.Conn, targetPort int) error {
// 简单的模拟实现
defer clientConn.Close()
return nil
}
func (m *mockTunnelServer) IsConnected() bool {
return m.connected
}
// TestNewTunnelForwarder 测试创建隧道转发器
func TestNewTunnelForwarder(t *testing.T) {
// 创建模拟隧道服务器
mockServer := &mockTunnelServer{connected: true}
fwd := NewTunnelForwarder(8080, 80, mockServer)
if fwd == nil {
t.Fatal("创建隧道转发器失败")
}
if !fwd.useTunnel {
t.Error("隧道转发器应使用隧道")
}
if fwd.tunnelServer == nil {
t.Error("隧道服务器未设置")
}
}
// TestForwarderStartStop 测试转发器启动和停止
func TestForwarderStartStop(t *testing.T) {
// 创建模拟目标服务器
targetListener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("创建目标服务器失败: %v", err)
}
defer targetListener.Close()
targetPort := targetListener.Addr().(*net.TCPAddr).Port
// 启动转发器到一个随机端口
fwd := NewForwarder(0, "127.0.0.1", targetPort)
// 创建监听器
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("创建监听器失败: %v", err)
}
fwd.listener = listener
fwd.sourcePort = listener.Addr().(*net.TCPAddr).Port
// 启动接受循环
fwd.wg.Add(1)
go fwd.acceptLoop()
// 等待一段时间
time.Sleep(100 * time.Millisecond)
// 停止转发器
err = fwd.Stop()
if err != nil {
t.Errorf("停止转发器失败: %v", err)
}
}
// TestForwarderConnection 测试转发器连接处理
func TestForwarderConnection(t *testing.T) {
// 创建模拟目标服务器
targetListener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("创建目标服务器失败: %v", err)
}
defer targetListener.Close()
targetPort := targetListener.Addr().(*net.TCPAddr).Port
// 在后台处理连接
go func() {
conn, err := targetListener.Accept()
if err != nil {
return
}
defer conn.Close()
// 回显服务器
io.Copy(conn, conn)
}()
// 创建并启动转发器
fwd := NewForwarder(0, "127.0.0.1", targetPort)
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("创建监听器失败: %v", err)
}
fwd.listener = listener
fwd.sourcePort = listener.Addr().(*net.TCPAddr).Port
fwd.wg.Add(1)
go fwd.acceptLoop()
defer fwd.Stop()
time.Sleep(100 * time.Millisecond)
// 连接到转发器
client, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", fwd.sourcePort))
if err != nil {
t.Fatalf("连接转发器失败: %v", err)
}
defer client.Close()
// 发送数据
testData := []byte("Hello, World!")
_, err = client.Write(testData)
if err != nil {
t.Fatalf("发送数据失败: %v", err)
}
// 读取响应
buf := make([]byte, len(testData))
client.SetReadDeadline(time.Now().Add(2 * time.Second))
n, err := io.ReadFull(client, buf)
if err != nil {
t.Fatalf("读取响应失败: %v", err)
}
if n != len(testData) {
t.Errorf("读取数据长度不正确,期望 %d得到 %d", len(testData), n)
}
if string(buf) != string(testData) {
t.Errorf("数据不匹配,期望 %s得到 %s", testData, buf)
}
}
// TestManagerAdd 测试管理器添加转发器
func TestManagerAdd(t *testing.T) {
mgr := NewManager()
// 创建模拟目标服务器
targetListener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("创建目标服务器失败: %v", err)
}
defer targetListener.Close()
targetPort := targetListener.Addr().(*net.TCPAddr).Port
// 添加转发器到一个随机可用端口
fwdListener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("创建转发监听器失败: %v", err)
}
sourcePort := fwdListener.Addr().(*net.TCPAddr).Port
fwdListener.Close() // 关闭以便转发器可以使用这个端口
err = mgr.Add(sourcePort, "127.0.0.1", targetPort)
if err != nil {
t.Fatalf("添加转发器失败: %v", err)
}
// 验证转发器已添加
if !mgr.Exists(sourcePort) {
t.Error("转发器应该存在")
}
// 清理
mgr.Remove(sourcePort)
}
// TestManagerAddDuplicate 测试添加重复转发器
func TestManagerAddDuplicate(t *testing.T) {
mgr := NewManager()
// 获取一个随机端口
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("创建监听器失败: %v", err)
}
sourcePort := listener.Addr().(*net.TCPAddr).Port
listener.Close()
// 添加第一个转发器
err = mgr.Add(sourcePort, "127.0.0.1", 80)
if err != nil {
t.Fatalf("添加第一个转发器失败: %v", err)
}
defer mgr.Remove(sourcePort)
// 尝试添加重复端口
err = mgr.Add(sourcePort, "127.0.0.1", 81)
if err == nil {
t.Error("应该返回端口已占用错误")
}
}
// TestManagerRemove 测试移除转发器
func TestManagerRemove(t *testing.T) {
mgr := NewManager()
// 获取一个随机端口
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("创建监听器失败: %v", err)
}
sourcePort := listener.Addr().(*net.TCPAddr).Port
listener.Close()
// 添加转发器
err = mgr.Add(sourcePort, "127.0.0.1", 80)
if err != nil {
t.Fatalf("添加转发器失败: %v", err)
}
// 移除转发器
err = mgr.Remove(sourcePort)
if err != nil {
t.Errorf("移除转发器失败: %v", err)
}
// 验证转发器已移除
if mgr.Exists(sourcePort) {
t.Error("转发器应该已被移除")
}
}
// TestManagerRemoveNonExistent 测试移除不存在的转发器
func TestManagerRemoveNonExistent(t *testing.T) {
mgr := NewManager()
err := mgr.Remove(9999)
if err == nil {
t.Error("应该返回转发器不存在错误")
}
}
// TestManagerExists 测试检查转发器是否存在
func TestManagerExists(t *testing.T) {
mgr := NewManager()
// 检查不存在的转发器
if mgr.Exists(8080) {
t.Error("转发器不应该存在")
}
// 获取一个随机端口
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("创建监听器失败: %v", err)
}
sourcePort := listener.Addr().(*net.TCPAddr).Port
listener.Close()
// 添加转发器
err = mgr.Add(sourcePort, "127.0.0.1", 80)
if err != nil {
t.Fatalf("添加转发器失败: %v", err)
}
defer mgr.Remove(sourcePort)
// 检查存在的转发器
if !mgr.Exists(sourcePort) {
t.Error("转发器应该存在")
}
}
// TestManagerStopAll 测试停止所有转发器
func TestManagerStopAll(t *testing.T) {
mgr := NewManager()
// 添加多个转发器
ports := make([]int, 0)
for i := 0; i < 3; i++ {
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("创建监听器失败: %v", err)
}
port := listener.Addr().(*net.TCPAddr).Port
listener.Close()
err = mgr.Add(port, "127.0.0.1", 80+i)
if err != nil {
t.Fatalf("添加转发器 %d 失败: %v", i, err)
}
ports = append(ports, port)
}
// 停止所有转发器
mgr.StopAll()
// 验证所有转发器已停止
for _, port := range ports {
if mgr.Exists(port) {
t.Errorf("端口 %d 的转发器应该已停止", port)
}
}
}
// TestForwarderContextCancellation 测试上下文取消
func TestForwarderContextCancellation(t *testing.T) {
fwd := NewForwarder(0, "127.0.0.1", 80)
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("创建监听器失败: %v", err)
}
fwd.listener = listener
fwd.wg.Add(1)
go fwd.acceptLoop()
// 取消上下文
fwd.cancel()
// 等待 goroutine 退出
done := make(chan struct{})
go func() {
fwd.wg.Wait()
close(done)
}()
select {
case <-done:
// 成功退出
case <-time.After(2 * time.Second):
t.Error("上下文取消后 goroutine 未退出")
}
}
// BenchmarkForwarderConnection 基准测试转发器连接
func BenchmarkForwarderConnection(b *testing.B) {
// 创建模拟目标服务器
targetListener, _ := net.Listen("tcp", "127.0.0.1:0")
defer targetListener.Close()
targetPort := targetListener.Addr().(*net.TCPAddr).Port
// 后台处理连接
go func() {
for {
conn, err := targetListener.Accept()
if err != nil {
return
}
go func(c net.Conn) {
defer c.Close()
io.Copy(io.Discard, c)
}(conn)
}
}()
// 创建转发器
fwd := NewForwarder(0, "127.0.0.1", targetPort)
listener, _ := net.Listen("tcp", "127.0.0.1:0")
fwd.listener = listener
fwd.sourcePort = listener.Addr().(*net.TCPAddr).Port
fwd.wg.Add(1)
go fwd.acceptLoop()
defer fwd.Stop()
time.Sleep(100 * time.Millisecond)
b.ResetTimer()
for i := 0; i < b.N; i++ {
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", fwd.sourcePort))
if err != nil {
b.Fatalf("连接失败: %v", err)
}
conn.Write([]byte("test"))
conn.Close()
}
}
// BenchmarkManagerOperations 基准测试管理器操作
func BenchmarkManagerOperations(b *testing.B) {
mgr := NewManager()
b.Run("Add", func(b *testing.B) {
for i := 0; i < b.N; i++ {
listener, _ := net.Listen("tcp", "127.0.0.1:0")
port := listener.Addr().(*net.TCPAddr).Port
listener.Close()
mgr.Add(port, "127.0.0.1", 80)
}
})
b.Run("Exists", func(b *testing.B) {
for i := 0; i < b.N; i++ {
mgr.Exists(8080)
}
})
}

View File

@ -64,8 +64,22 @@ func main() {
continue
}
log.Printf("恢复端口映射: %d -> %s:%d", mapping.SourcePort, mapping.TargetIP, mapping.TargetPort)
if err := fwdManager.Add(mapping.SourcePort, mapping.TargetIP, mapping.TargetPort); err != nil {
log.Printf("恢复端口映射: %d -> %s:%d (tunnel: %v)", mapping.SourcePort, mapping.TargetIP, mapping.TargetPort, mapping.UseTunnel)
var err error
if mapping.UseTunnel {
// 隧道模式:检查隧道服务器是否可用
if !cfg.Tunnel.Enabled || tunnelServer == nil {
log.Printf("警告: 端口 %d 需要隧道模式但隧道服务未启用,跳过", mapping.SourcePort)
continue
}
err = fwdManager.AddTunnel(mapping.SourcePort, mapping.TargetPort, tunnelServer)
} else {
// 直接模式
err = fwdManager.Add(mapping.SourcePort, mapping.TargetIP, mapping.TargetPort)
}
if err != nil {
log.Printf("警告: 启动端口 %d 的转发失败: %v", mapping.SourcePort, err)
}
}
@ -80,7 +94,6 @@ func main() {
tunnelServer,
cfg.PortRange.From,
cfg.PortRange.End,
cfg.Tunnel.Enabled,
)
// 启动 HTTP API 服务器
@ -107,7 +120,7 @@ func main() {
log.Println("\n接收到关闭信号正在优雅关闭...")
// 创建关闭上下文
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
// 停止所有转发器

View File

@ -11,50 +11,113 @@ import (
"time"
)
// Protocol 定义隧道协议
// 消息格式: [4字节长度][4字节端口][数据]
// Tunnel协议定义
// 消息格式: | 版本(1B) | 类型(1B) | 长度(4B) | 数据 |
const (
// HeaderSize 消息头大小(长度+端口)
HeaderSize = 8
// MaxPacketSize 最大包大小 (1MB)
// 协议版本
ProtocolVersion = 0x01
// 消息头大小
HeaderSize = 6 // 版本(1) + 类型(1) + 长度(4)
// 最大包大小 (1MB)
MaxPacketSize = 1024 * 1024
// 消息类型
MsgTypeConnectRequest = 0x01 // 连接请求
MsgTypeConnectResponse = 0x02 // 连接响应
MsgTypeData = 0x03 // 数据传输
MsgTypeClose = 0x04 // 关闭连接
MsgTypeKeepAlive = 0x05 // 心跳
// 连接响应状态
ConnStatusSuccess = 0x00 // 连接成功
ConnStatusFailed = 0x01 // 连接失败
// 超时设置
ConnectTimeout = 10 * time.Second // 连接超时
ReadTimeout = 30 * time.Second // 读取超时
)
// TunnelMessage 隧道消息
type TunnelMessage struct {
Version byte
Type byte
Length uint32
Data []byte
}
// ConnectRequestData 连接请求数据
type ConnectRequestData struct {
ConnID uint32 // 连接ID
TargetPort uint16 // 目标端口
}
// ConnectResponseData 连接响应数据
type ConnectResponseData struct {
ConnID uint32 // 连接ID
Status byte // 状态码
}
// DataMessage 数据消息
type DataMessage struct {
ConnID uint32 // 连接ID
Data []byte // 数据
}
// CloseMessage 关闭消息
type CloseMessage struct {
ConnID uint32 // 连接ID
}
// PendingConnection 待处理连接
type PendingConnection struct {
ID uint32
ClientConn net.Conn
TargetPort int
Created time.Time
ResponseChan chan bool // 用于接收连接响应
}
// ActiveConnection 活跃连接
type ActiveConnection struct {
ID uint32
ClientConn net.Conn
TargetPort int
Created time.Time
}
// Server 内网穿透服务器
type Server struct {
listenPort int
listener net.Listener
client net.Conn
tunnelConn net.Conn
cancel context.CancelFunc
ctx context.Context
wg sync.WaitGroup
mu sync.RWMutex
// 连接管理
connections map[uint32]*Connection
connMu sync.RWMutex
nextConnID uint32
}
// Connection 表示一个客户端连接
type Connection struct {
ID uint32
TargetPort int
ClientConn net.Conn
readChan chan []byte
writeChan chan []byte
closeChan chan struct{}
pendingConns map[uint32]*PendingConnection // 待确认连接
activeConns map[uint32]*ActiveConnection // 活跃连接
connMu sync.RWMutex
nextConnID uint32
// 消息队列
sendChan chan *TunnelMessage
}
// NewServer 创建新的隧道服务器
func NewServer(listenPort int) *Server {
ctx, cancel := context.WithCancel(context.Background())
return &Server{
listenPort: listenPort,
cancel: cancel,
ctx: ctx,
connections: make(map[uint32]*Connection),
listenPort: listenPort,
cancel: cancel,
ctx: ctx,
pendingConns: make(map[uint32]*PendingConnection),
activeConns: make(map[uint32]*ActiveConnection),
sendChan: make(chan *TunnelMessage, 1000),
}
}
@ -102,42 +165,47 @@ func (s *Server) acceptLoop() {
// 只允许一个客户端连接
s.mu.Lock()
if s.client != nil {
if s.tunnelConn != nil {
log.Printf("拒绝额外的隧道连接: %s", conn.RemoteAddr())
conn.Close()
s.mu.Unlock()
continue
}
s.client = conn
s.tunnelConn = conn
s.mu.Unlock()
log.Printf("隧道客户端已连接: %s", conn.RemoteAddr())
s.wg.Add(1)
go s.handleClient(conn)
s.wg.Add(2)
go s.handleTunnelRead(conn)
go s.handleTunnelWrite(conn)
}
}
// handleClient 处理客户端连接
func (s *Server) handleClient(conn net.Conn) {
// handleTunnelRead 处理隧道读取
func (s *Server) handleTunnelRead(conn net.Conn) {
defer s.wg.Done()
defer func() {
conn.Close()
s.mu.Lock()
s.client = nil
s.tunnelConn = nil
s.mu.Unlock()
log.Printf("隧道客户端已断开")
// 关闭所有活动连接
s.connMu.Lock()
for _, c := range s.connections {
close(c.closeChan)
for _, c := range s.pendingConns {
c.ClientConn.Close()
close(c.ResponseChan)
}
s.connections = make(map[uint32]*Connection)
for _, c := range s.activeConns {
c.ClientConn.Close()
}
s.pendingConns = make(map[uint32]*PendingConnection)
s.activeConns = make(map[uint32]*ActiveConnection)
s.connMu.Unlock()
}()
// 读取来自客户端的数据
for {
select {
case <-s.ctx.Done():
@ -145,147 +213,374 @@ func (s *Server) handleClient(conn net.Conn) {
default:
}
// 读取消息头
header := make([]byte, HeaderSize)
if _, err := io.ReadFull(conn, header); err != nil {
msg, err := s.readTunnelMessage(conn)
if err != nil {
if err != io.EOF {
log.Printf("读取隧道消息失败: %v", err)
log.Printf("读取隧道消息失败: %v", err)
}
return
}
dataLen := binary.BigEndian.Uint32(header[0:4])
connID := binary.BigEndian.Uint32(header[4:8])
s.handleTunnelMessage(msg)
}
}
if dataLen > MaxPacketSize {
log.Printf("数据包过大: %d bytes", dataLen)
// handleTunnelWrite 处理隧道写入
func (s *Server) handleTunnelWrite(conn net.Conn) {
defer s.wg.Done()
for {
select {
case <-s.ctx.Done():
return
}
// 读取数据
data := make([]byte, dataLen)
if _, err := io.ReadFull(conn, data); err != nil {
log.Printf("读取隧道数据失败: %v", err)
return
}
// 将数据发送到对应的连接
s.connMu.RLock()
connection, exists := s.connections[connID]
s.connMu.RUnlock()
if exists {
select {
case connection.readChan <- data:
case <-connection.closeChan:
case <-time.After(5 * time.Second):
log.Printf("向连接 %d 发送数据超时", connID)
case msg := <-s.sendChan:
if err := s.writeTunnelMessage(conn, msg); err != nil {
log.Printf("写入隧道消息失败: %v", err)
return
}
}
}
}
// ForwardConnection 转发连接到隧道
// readTunnelMessage 读取隧道消息
func (s *Server) readTunnelMessage(conn net.Conn) (*TunnelMessage, error) {
// 读取消息头
header := make([]byte, HeaderSize)
if _, err := io.ReadFull(conn, header); err != nil {
return nil, err
}
version := header[0]
msgType := header[1]
dataLen := binary.BigEndian.Uint32(header[2:6])
if version != ProtocolVersion {
return nil, fmt.Errorf("不支持的协议版本: %d", version)
}
if dataLen > MaxPacketSize {
return nil, fmt.Errorf("数据包过大: %d bytes", dataLen)
}
// 读取数据
var data []byte
if dataLen > 0 {
data = make([]byte, dataLen)
if _, err := io.ReadFull(conn, data); err != nil {
return nil, err
}
}
return &TunnelMessage{
Version: version,
Type: msgType,
Length: dataLen,
Data: data,
}, nil
}
// writeTunnelMessage 写入隧道消息
func (s *Server) writeTunnelMessage(conn net.Conn, msg *TunnelMessage) error {
// 构建消息头
header := make([]byte, HeaderSize)
header[0] = msg.Version
header[1] = msg.Type
binary.BigEndian.PutUint32(header[2:6], msg.Length)
// 写入消息头
if _, err := conn.Write(header); err != nil {
return err
}
// 写入数据
if msg.Length > 0 && msg.Data != nil {
if _, err := conn.Write(msg.Data); err != nil {
return err
}
}
return nil
}
// handleTunnelMessage 处理隧道消息
func (s *Server) handleTunnelMessage(msg *TunnelMessage) {
switch msg.Type {
case MsgTypeConnectResponse:
s.handleConnectResponse(msg)
case MsgTypeData:
s.handleDataMessage(msg)
case MsgTypeClose:
s.handleCloseMessage(msg)
case MsgTypeKeepAlive:
s.handleKeepAlive(msg)
default:
log.Printf("未知消息类型: %d", msg.Type)
}
}
// handleConnectResponse 处理连接响应
func (s *Server) handleConnectResponse(msg *TunnelMessage) {
if len(msg.Data) < 5 {
log.Printf("连接响应数据太短")
return
}
connID := binary.BigEndian.Uint32(msg.Data[0:4])
status := msg.Data[4]
s.connMu.Lock()
pending, exists := s.pendingConns[connID]
if !exists {
s.connMu.Unlock()
log.Printf("收到未知连接的响应: %d", connID)
return
}
delete(s.pendingConns, connID)
s.connMu.Unlock()
if status == ConnStatusSuccess {
// 连接成功,移到活跃连接
active := &ActiveConnection{
ID: connID,
ClientConn: pending.ClientConn,
TargetPort: pending.TargetPort,
Created: time.Now(),
}
s.connMu.Lock()
s.activeConns[connID] = active
s.connMu.Unlock()
log.Printf("连接已建立: ID=%d, 端口=%d", connID, pending.TargetPort)
// 启动数据转发
s.wg.Add(1)
go s.forwardData(active)
// 通知等待的goroutine
select {
case pending.ResponseChan <- true:
default:
}
} else {
// 连接失败
log.Printf("连接失败: ID=%d, 状态=%d", connID, status)
pending.ClientConn.Close()
// 通知等待的goroutine
select {
case pending.ResponseChan <- false:
default:
}
}
close(pending.ResponseChan)
}
// handleDataMessage 处理数据消息
func (s *Server) handleDataMessage(msg *TunnelMessage) {
if len(msg.Data) < 4 {
log.Printf("数据消息太短")
return
}
connID := binary.BigEndian.Uint32(msg.Data[0:4])
data := msg.Data[4:]
s.connMu.RLock()
active, exists := s.activeConns[connID]
s.connMu.RUnlock()
if !exists {
log.Printf("收到未知连接的数据: %d", connID)
return
}
// 写入到客户端连接
if _, err := active.ClientConn.Write(data); err != nil {
log.Printf("写入客户端连接失败 (ID=%d): %v", connID, err)
s.closeConnection(connID)
}
}
// handleCloseMessage 处理关闭消息
func (s *Server) handleCloseMessage(msg *TunnelMessage) {
if len(msg.Data) < 4 {
log.Printf("关闭消息数据太短")
return
}
connID := binary.BigEndian.Uint32(msg.Data[0:4])
s.closeConnection(connID)
}
// handleKeepAlive 处理心跳消息
func (s *Server) handleKeepAlive(msg *TunnelMessage) {
// 回应心跳
response := &TunnelMessage{
Version: ProtocolVersion,
Type: MsgTypeKeepAlive,
Length: 0,
Data: nil,
}
select {
case s.sendChan <- response:
default:
log.Printf("发送心跳响应失败: 发送队列已满")
}
}
// forwardData 转发数据
func (s *Server) forwardData(active *ActiveConnection) {
defer s.wg.Done()
defer s.closeConnection(active.ID)
buffer := make([]byte, 32*1024)
for {
select {
case <-s.ctx.Done():
return
default:
}
active.ClientConn.SetReadDeadline(time.Now().Add(ReadTimeout))
n, err := active.ClientConn.Read(buffer)
if err != nil {
if err != io.EOF && !isTimeout(err) {
log.Printf("读取客户端连接失败 (ID=%d): %v", active.ID, err)
}
return
}
// 发送数据到隧道
dataMsg := make([]byte, 4+n)
binary.BigEndian.PutUint32(dataMsg[0:4], active.ID)
copy(dataMsg[4:], buffer[:n])
msg := &TunnelMessage{
Version: ProtocolVersion,
Type: MsgTypeData,
Length: uint32(len(dataMsg)),
Data: dataMsg,
}
select {
case s.sendChan <- msg:
case <-time.After(5 * time.Second):
log.Printf("发送数据超时 (ID=%d)", active.ID)
return
case <-s.ctx.Done():
return
}
}
}
// closeConnection 关闭连接
func (s *Server) closeConnection(connID uint32) {
s.connMu.Lock()
active, exists := s.activeConns[connID]
if exists {
delete(s.activeConns, connID)
active.ClientConn.Close()
}
s.connMu.Unlock()
// 发送关闭消息
closeData := make([]byte, 4)
binary.BigEndian.PutUint32(closeData, connID)
msg := &TunnelMessage{
Version: ProtocolVersion,
Type: MsgTypeClose,
Length: 4,
Data: closeData,
}
select {
case s.sendChan <- msg:
default:
// 发送队列满,忽略
}
if exists {
log.Printf("连接已关闭: ID=%d", connID)
}
}
// ForwardConnection 转发连接到隧道(新的透明代理实现)
func (s *Server) ForwardConnection(clientConn net.Conn, targetPort int) error {
s.mu.RLock()
tunnelConn := s.client
tunnelConnected := s.tunnelConn != nil
s.mu.RUnlock()
if tunnelConn == nil {
if !tunnelConnected {
return fmt.Errorf("隧道连接不可用")
}
// 创建连接对象
// 创建待处理连接
s.connMu.Lock()
connID := s.nextConnID
s.nextConnID++
connection := &Connection{
ID: connID,
TargetPort: targetPort,
ClientConn: clientConn,
readChan: make(chan []byte, 100),
writeChan: make(chan []byte, 100),
closeChan: make(chan struct{}),
pending := &PendingConnection{
ID: connID,
ClientConn: clientConn,
TargetPort: targetPort,
Created: time.Now(),
ResponseChan: make(chan bool, 1),
}
s.connections[connID] = connection
s.pendingConns[connID] = pending
s.connMu.Unlock()
defer func() {
s.connMu.Lock()
delete(s.connections, connID)
s.connMu.Unlock()
close(connection.closeChan)
clientConn.Close()
}()
// 发送连接请求
reqData := make([]byte, 6)
binary.BigEndian.PutUint32(reqData[0:4], connID)
binary.BigEndian.PutUint16(reqData[4:6], uint16(targetPort))
// 启动读写协程
errChan := make(chan error, 2)
// 从客户端读取并发送到隧道
go func() {
buffer := make([]byte, 32*1024)
for {
select {
case <-connection.closeChan:
return
default:
}
clientConn.SetReadDeadline(time.Now().Add(30 * time.Second))
n, err := clientConn.Read(buffer)
if err != nil {
errChan <- err
return
}
// 发送到隧道
data := make([]byte, HeaderSize+n)
binary.BigEndian.PutUint32(data[0:4], uint32(n))
binary.BigEndian.PutUint32(data[4:8], connID)
copy(data[HeaderSize:], buffer[:n])
s.mu.RLock()
_, err = tunnelConn.Write(data)
s.mu.RUnlock()
if err != nil {
errChan <- err
return
}
}
}()
// 从隧道读取并发送到客户端
go func() {
for {
select {
case data := <-connection.readChan:
if _, err := clientConn.Write(data); err != nil {
errChan <- err
return
}
case <-connection.closeChan:
return
}
}
}()
// 等待错误或关闭
select {
case <-errChan:
case <-connection.closeChan:
case <-s.ctx.Done():
msg := &TunnelMessage{
Version: ProtocolVersion,
Type: MsgTypeConnectRequest,
Length: 6,
Data: reqData,
}
return nil
select {
case s.sendChan <- msg:
case <-time.After(5 * time.Second):
s.connMu.Lock()
delete(s.pendingConns, connID)
s.connMu.Unlock()
return fmt.Errorf("发送连接请求超时")
}
log.Printf("发送连接请求: ID=%d, 端口=%d", connID, targetPort)
// 等待连接响应
select {
case success := <-pending.ResponseChan:
if success {
return nil // 连接建立成功forwardData会处理后续的数据转发
} else {
return fmt.Errorf("远程连接失败")
}
case <-time.After(ConnectTimeout):
s.connMu.Lock()
delete(s.pendingConns, connID)
s.connMu.Unlock()
clientConn.Close()
return fmt.Errorf("连接超时")
case <-s.ctx.Done():
return fmt.Errorf("服务器关闭")
}
}
// IsConnected 检查隧道是否已连接
func (s *Server) IsConnected() bool {
s.mu.RLock()
defer s.mu.RUnlock()
return s.client != nil
return s.tunnelConn != nil
}
// Stop 停止隧道服务器
@ -297,8 +592,8 @@ func (s *Server) Stop() error {
}
s.mu.Lock()
if s.client != nil {
s.client.Close()
if s.tunnelConn != nil {
s.tunnelConn.Close()
}
s.mu.Unlock()
@ -317,4 +612,12 @@ func (s *Server) Stop() error {
}
return nil
}
// isTimeout 检查是否为超时错误
func isTimeout(err error) bool {
if netErr, ok := err.(net.Error); ok {
return netErr.Timeout()
}
return false
}

View File

@ -0,0 +1,500 @@
package tunnel
import (
"encoding/binary"
"io"
"net"
"sync"
"testing"
"time"
)
// TestNewServer 测试创建隧道服务器
func TestNewServer(t *testing.T) {
server := NewServer(9000)
if server == nil {
t.Fatal("创建隧道服务器失败")
}
if server.listenPort != 9000 {
t.Errorf("监听端口不正确,期望 9000得到 %d", server.listenPort)
}
if server.pendingConns == nil {
t.Error("待处理连接映射未初始化")
}
if server.activeConns == nil {
t.Error("活跃连接映射未初始化")
}
if server.sendChan == nil {
t.Error("发送通道未初始化")
}
if server.ctx == nil {
t.Error("上下文未初始化")
}
}
// TestServerStartStop 测试服务器启动和停止
func TestServerStartStop(t *testing.T) {
// 使用随机端口
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("获取随机端口失败: %v", err)
}
port := listener.Addr().(*net.TCPAddr).Port
listener.Close()
server := NewServer(port)
err = server.Start()
if err != nil {
t.Fatalf("启动服务器失败: %v", err)
}
// 等待服务器启动
time.Sleep(100 * time.Millisecond)
// 验证服务器是否监听端口
conn, err := net.Dial("tcp", server.listener.Addr().String())
if err != nil {
t.Errorf("无法连接到服务器: %v", err)
} else {
conn.Close()
}
// 停止服务器
err = server.Stop()
if err != nil {
t.Errorf("停止服务器失败: %v", err)
}
}
// TestTunnelMessage 测试隧道消息的序列化和反序列化
func TestTunnelMessage(t *testing.T) {
// 创建测试消息
data := []byte("test data")
msg := &TunnelMessage{
Version: ProtocolVersion,
Type: MsgTypeData,
Length: uint32(len(data)),
Data: data,
}
// 创建模拟连接
serverConn, clientConn := net.Pipe()
defer serverConn.Close()
defer clientConn.Close()
server := NewServer(9000)
// 测试写入消息
go func() {
err := server.writeTunnelMessage(serverConn, msg)
if err != nil {
t.Errorf("写入隧道消息失败: %v", err)
}
}()
// 测试读取消息
receivedMsg, err := server.readTunnelMessage(clientConn)
if err != nil {
t.Fatalf("读取隧道消息失败: %v", err)
}
// 验证消息内容
if receivedMsg.Version != msg.Version {
t.Errorf("版本不匹配,期望 %d得到 %d", msg.Version, receivedMsg.Version)
}
if receivedMsg.Type != msg.Type {
t.Errorf("类型不匹配,期望 %d得到 %d", msg.Type, receivedMsg.Type)
}
if receivedMsg.Length != msg.Length {
t.Errorf("长度不匹配,期望 %d得到 %d", msg.Length, receivedMsg.Length)
}
if string(receivedMsg.Data) != string(msg.Data) {
t.Errorf("数据不匹配,期望 %s得到 %s", string(msg.Data), string(receivedMsg.Data))
}
}
// TestConnectRequest 测试连接请求处理
func TestConnectRequest(t *testing.T) {
// 启动一个测试HTTP服务器
testListener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("启动测试服务器失败: %v", err)
}
defer testListener.Close()
testPort := testListener.Addr().(*net.TCPAddr).Port
// 启动一个简单的echo服务
go func() {
for {
conn, err := testListener.Accept()
if err != nil {
return
}
go func(c net.Conn) {
defer c.Close()
io.Copy(c, c) // echo服务
}(conn)
}
}()
// 模拟测试(实际测试需要更复杂的设置)
t.Logf("测试服务器运行在端口: %d", testPort)
// 创建隧道服务器,使用随机端口
tunnelListener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("创建隧道监听器失败: %v", err)
}
defer tunnelListener.Close()
tunnelPort := tunnelListener.Addr().(*net.TCPAddr).Port
tunnelListener.Close() // 关闭以便服务器重新绑定
server := NewServer(tunnelPort)
err = server.Start()
if err != nil {
t.Fatalf("启动隧道服务器失败: %v", err)
}
defer server.Stop()
// 等待服务器启动
time.Sleep(100 * time.Millisecond)
// 模拟客户端连接
tunnelConn, err := net.Dial("tcp", server.listener.Addr().String())
if err != nil {
t.Fatalf("连接隧道服务器失败: %v", err)
}
defer tunnelConn.Close()
// 等待隧道连接被处理
time.Sleep(100 * time.Millisecond)
// 验证隧道是否已连接
if !server.IsConnected() {
t.Error("隧道未连接")
}
}
// TestProtocolVersionCheck 测试协议版本检查
func TestProtocolVersionCheck(t *testing.T) {
server := NewServer(9000)
// 创建模拟连接
serverConn, clientConn := net.Pipe()
defer serverConn.Close()
defer clientConn.Close()
// 发送错误版本的消息
wrongVersionHeader := make([]byte, HeaderSize)
wrongVersionHeader[0] = 0xFF // 错误版本
wrongVersionHeader[1] = MsgTypeData
binary.BigEndian.PutUint32(wrongVersionHeader[2:6], 0)
go func() {
clientConn.Write(wrongVersionHeader)
}()
// 尝试读取消息,应该返回错误
_, err := server.readTunnelMessage(serverConn)
if err == nil {
t.Error("期望版本检查失败,但成功了")
}
if err.Error() != "不支持的协议版本: 255" {
t.Errorf("错误消息不正确: %v", err)
}
}
// TestMaxPacketSizeCheck 测试最大包大小检查
func TestMaxPacketSizeCheck(t *testing.T) {
server := NewServer(9000)
// 创建模拟连接
serverConn, clientConn := net.Pipe()
defer serverConn.Close()
defer clientConn.Close()
// 发送超大包
oversizedHeader := make([]byte, HeaderSize)
oversizedHeader[0] = ProtocolVersion
oversizedHeader[1] = MsgTypeData
binary.BigEndian.PutUint32(oversizedHeader[2:6], MaxPacketSize+1)
go func() {
clientConn.Write(oversizedHeader)
}()
// 尝试读取消息,应该返回错误
_, err := server.readTunnelMessage(serverConn)
if err == nil {
t.Error("期望包大小检查失败,但成功了")
}
}
// TestConcurrentConnections 测试并发连接处理
func TestConcurrentConnections(t *testing.T) {
// 使用随机端口
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("获取随机端口失败: %v", err)
}
port := listener.Addr().(*net.TCPAddr).Port
listener.Close()
server := NewServer(port)
err = server.Start()
if err != nil {
t.Fatalf("启动服务器失败: %v", err)
}
defer server.Stop()
// 等待服务器启动
time.Sleep(100 * time.Millisecond)
var wg sync.WaitGroup
connCount := 5
// 并发创建多个连接
for i := 0; i < connCount; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
conn, err := net.Dial("tcp", server.listener.Addr().String())
if err != nil {
t.Errorf("连接 %d 失败: %v", id, err)
return
}
defer conn.Close()
// 保持连接一段时间
time.Sleep(200 * time.Millisecond)
}(i)
}
wg.Wait()
// 验证只有一个隧道连接被接受
if server.IsConnected() {
// 应该只有一个连接被接受
t.Log("隧道连接已建立")
}
}
// TestKeepAlive 测试心跳消息
func TestKeepAlive(t *testing.T) {
server := NewServer(9000)
// 创建模拟连接
serverConn, clientConn := net.Pipe()
defer serverConn.Close()
defer clientConn.Close()
// 创建心跳消息
keepAliveMsg := &TunnelMessage{
Version: ProtocolVersion,
Type: MsgTypeKeepAlive,
Length: 0,
Data: nil,
}
// 启动消息处理
go func() {
for {
msg, err := server.readTunnelMessage(clientConn)
if err != nil {
return
}
server.handleTunnelMessage(msg)
}
}()
// 启动发送处理
go func() {
for {
select {
case msg := <-server.sendChan:
server.writeTunnelMessage(serverConn, msg)
case <-time.After(time.Second):
return
}
}
}()
// 发送心跳
err := server.writeTunnelMessage(serverConn, keepAliveMsg)
if err != nil {
t.Fatalf("发送心跳失败: %v", err)
}
// 读取响应
response, err := server.readTunnelMessage(clientConn)
if err != nil {
t.Fatalf("读取心跳响应失败: %v", err)
}
if response.Type != MsgTypeKeepAlive {
t.Errorf("心跳响应类型不正确,期望 %d得到 %d", MsgTypeKeepAlive, response.Type)
}
}
// MockClient 模拟客户端用于测试
type MockClient struct {
conn net.Conn
}
func (mc *MockClient) sendConnectResponse(connID uint32, status byte) error {
responseData := make([]byte, 5)
binary.BigEndian.PutUint32(responseData[0:4], connID)
responseData[4] = status
msg := &TunnelMessage{
Version: ProtocolVersion,
Type: MsgTypeConnectResponse,
Length: 5,
Data: responseData,
}
// 构建消息头
header := make([]byte, HeaderSize)
header[0] = msg.Version
header[1] = msg.Type
binary.BigEndian.PutUint32(header[2:6], msg.Length)
// 写入消息
if _, err := mc.conn.Write(header); err != nil {
return err
}
if msg.Length > 0 && msg.Data != nil {
if _, err := mc.conn.Write(msg.Data); err != nil {
return err
}
}
return nil
}
// TestForwardConnection 测试连接转发
func TestForwardConnection(t *testing.T) {
// 启动测试目标服务
targetListener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("启动目标服务失败: %v", err)
}
defer targetListener.Close()
targetPort := targetListener.Addr().(*net.TCPAddr).Port
// 启动简单的echo服务
go func() {
for {
conn, err := targetListener.Accept()
if err != nil {
return
}
go func(c net.Conn) {
defer c.Close()
io.Copy(c, c)
}(conn)
}
}()
// 创建隧道服务器,使用随机端口
tunnelListener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("创建隧道监听器失败: %v", err)
}
tunnelPort := tunnelListener.Addr().(*net.TCPAddr).Port
tunnelListener.Close() // 关闭以便服务器重新绑定
server := NewServer(tunnelPort)
err = server.Start()
if err != nil {
t.Fatalf("启动隧道服务器失败: %v", err)
}
defer server.Stop()
// 等待服务器启动
time.Sleep(100 * time.Millisecond)
// 连接到隧道服务器(模拟客户端)
tunnelConn, err := net.Dial("tcp", server.listener.Addr().String())
if err != nil {
t.Fatalf("连接隧道服务器失败: %v", err)
}
defer tunnelConn.Close()
// 等待隧道连接建立
time.Sleep(100 * time.Millisecond)
if !server.IsConnected() {
t.Fatal("隧道未连接")
}
// 创建模拟客户端连接
clientConn, serverSideConn := net.Pipe()
defer clientConn.Close()
defer serverSideConn.Close()
// 创建模拟客户端
mockClient := &MockClient{conn: tunnelConn}
// 启动连接转发
go func() {
err := server.ForwardConnection(serverSideConn, targetPort)
if err != nil {
t.Errorf("转发连接失败: %v", err)
}
}()
// 读取连接请求
header := make([]byte, HeaderSize)
_, err = io.ReadFull(tunnelConn, header)
if err != nil {
t.Fatalf("读取连接请求头失败: %v", err)
}
if header[1] != MsgTypeConnectRequest {
t.Fatalf("期望连接请求,得到消息类型: %d", header[1])
}
dataLen := binary.BigEndian.Uint32(header[2:6])
data := make([]byte, dataLen)
_, err = io.ReadFull(tunnelConn, data)
if err != nil {
t.Fatalf("读取连接请求数据失败: %v", err)
}
connID := binary.BigEndian.Uint32(data[0:4])
requestedPort := binary.BigEndian.Uint16(data[4:6])
if int(requestedPort) != targetPort {
t.Errorf("请求端口不匹配,期望 %d得到 %d", targetPort, requestedPort)
}
// 发送连接成功响应
err = mockClient.sendConnectResponse(connID, ConnStatusSuccess)
if err != nil {
t.Fatalf("发送连接响应失败: %v", err)
}
// 等待连接建立
time.Sleep(100 * time.Millisecond)
t.Log("连接转发测试完成")
}

225
src/test/README.md Normal file
View File

@ -0,0 +1,225 @@
# 集成测试文档
本目录包含 go-tunnel 项目的集成测试,用于测试服务器和客户端之间的端口转发功能。
## 前置条件
在运行测试之前,需要:
1. **启动服务器**
```bash
cd src
make run-server
# 或者
./bin/server -config config.yaml
```
2. **启动客户端**
```bash
cd src
make run-client
# 或者
./bin/client -server localhost:9000
```
## 测试配置
测试使用的默认配置如下(可以在 `integration_test.go` 中修改):
```go
ServerAPIAddr: "http://localhost:8080", // 服务器 API 地址
ServerTunnelPort: 9000, // 服务器隧道端口
TestPort: 30001, // 测试使用的端口
LocalServicePort: 8888, // 本地 echo 服务端口
```
如果你的服务器使用不同的端口,请修改这些配置。
## 运行测试
### 运行所有测试
```bash
cd src/test
go test -v
```
### 运行特定测试
```bash
# 测试基本转发功能
go test -v -run TestForwardingBasic
# 测试多端口转发
go test -v -run TestMultipleForwards
# 测试动态添加/删除映射
go test -v -run TestAddAndRemoveMapping
# 测试并发请求
go test -v -run TestConcurrentRequests
# 测试映射列表
go test -v -run TestListMappings
# 测试健康检查
go test -v -run TestHealthCheck
```
## 测试说明
### 1. TestForwardingBasic - 基本转发功能测试
测试流程:
1. 在本地 8888 端口启动一个 echo 服务器
2. 通过 API 创建端口映射30001 -> 8888
3. 向服务器的 30001 端口发送消息
4. 验证能否通过隧道转发到客户端的 8888 端口并收到响应
**预期结果**:发送 "Hello, Tunnel!" 能收到 "ECHO: Hello, Tunnel!"
### 2. TestMultipleForwards - 多端口转发测试
测试流程:
1. 创建多个端口映射30002, 30003, 30004
2. 同时向这些端口发送请求
3. 验证所有端口都能正确转发
**预期结果**:所有端口都能正常转发并收到正确响应
### 3. TestAddAndRemoveMapping - 动态管理测试
测试流程:
1. 创建端口映射并验证工作
2. 删除端口映射并验证无法连接
3. 重新创建映射并验证恢复工作
**预期结果**:映射的创建、删除、重建都能正确工作
### 4. TestConcurrentRequests - 并发请求测试
测试流程:
1. 创建一个端口映射
2. 并发发送 10 个请求
3. 验证所有请求都能正确处理
**预期结果**10 个并发请求全部成功
### 5. TestListMappings - 映射列表测试
测试流程:
1. 创建多个端口映射
2. 通过 API 获取映射列表
3. 验证所有创建的映射都在列表中
**预期结果**API 返回完整的映射列表
### 6. TestHealthCheck - 健康检查测试
测试流程:
1. 向服务器发送健康检查请求
**预期结果**:返回 200 状态码
## 测试端口使用
测试使用以下端口,请确保这些端口未被占用:
- 8888: 本地 echo 测试服务
- 30001: TestForwardingBasic
- 30002-30004: TestMultipleForwards
- 30005: TestAddAndRemoveMapping
- 30006: TestConcurrentRequests
- 30010-30012: TestListMappings
## 自定义测试服务器地址
如果你的服务器运行在不同的地址,可以在测试代码中修改 `defaultConfig`
```go
var defaultConfig = TestConfig{
ServerAPIAddr: "http://your-server:8080", // 修改这里
ServerTunnelPort: 9000,
TestPort: 30001,
LocalServicePort: 8888,
}
```
或者创建环境变量:
```bash
export TEST_SERVER_API="http://your-server:8080"
export TEST_SERVER_TUNNEL_PORT="9000"
```
## 故障排查
### 测试失败:连接被拒绝
**原因**:服务器或客户端未启动
**解决**:确保服务器和客户端都在运行
### 测试失败:隧道未连接
**原因**:客户端未成功连接到服务器
**解决**
1. 检查客户端日志,确认连接成功
2. 检查服务器地址配置是否正确
### 测试失败:端口已被占用
**原因**:测试端口被其他程序占用
**解决**
1. 使用 `netstat -tuln | grep <port>` 查看端口占用
2. 关闭占用端口的程序
3. 或在测试代码中修改端口号
### 测试超时
**原因**:网络延迟或服务响应慢
**解决**
1. 增加测试中的 `time.Sleep` 时间
2. 检查服务器和客户端是否正常运行
3. 查看服务器和客户端日志
## 示例输出
成功运行测试的输出示例:
```
=== RUN TestForwardingBasic
integration_test.go:54: 启动本地测试服务...
integration_test.go:493: Echo 服务器启动在端口 8888
integration_test.go:63: 创建端口映射: 30001 -> localhost:8888
integration_test.go:68: 端口映射创建成功
integration_test.go:82: 发送测试消息: Hello, Tunnel!
integration_test.go:95: ✓ 转发成功,收到响应: ECHO: Hello, Tunnel!
integration_test.go:73: 清理端口映射...
--- PASS: TestForwardingBasic (1.52s)
=== RUN TestHealthCheck
integration_test.go:346: ✓ 健康检查成功: {"success":true,"message":"服务器运行正常"}
--- PASS: TestHealthCheck (0.01s)
PASS
ok test 1.534s
```
## 扩展测试
你可以根据需要添加更多测试用例:
1. 测试不同大小的数据传输
2. 测试长时间连接
3. 测试异常情况处理
4. 测试性能和吞吐量
5. 测试错误恢复机制
## 注意事项
1. 测试会自动清理创建的端口映射,但如果测试中断可能需要手动清理
2. 每个测试都是独立的,可以单独运行
3. 测试内置了 echo 服务器,不需要额外准备测试服务
4. 建议在开发环境运行测试,避免影响生产环境

View File

@ -0,0 +1,554 @@
package test
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"testing"
"time"
)
// TestConfig 测试配置
type TestConfig struct {
ServerAPIAddr string // 服务器 API 地址
ServerTunnelPort int // 服务器隧道端口
TestPort int // 测试用的端口
LocalServicePort int // 本地服务端口(客户端监听)
}
// 默认测试配置
var defaultConfig = TestConfig{
ServerAPIAddr: "http://localhost:8080",
ServerTunnelPort: 9000,
TestPort: 30001,
LocalServicePort: 8888,
}
// CreateMappingRequest 创建映射请求
type CreateMappingRequest struct {
Port int `json:"port"`
TargetIP string `json:"target_ip,omitempty"`
}
// RemoveMappingRequest 删除映射请求
type RemoveMappingRequest struct {
Port int `json:"port"`
}
// Response API 响应
type Response struct {
Success bool `json:"success"`
Message string `json:"message"`
Data map[string]interface{} `json:"data,omitempty"`
}
// TestForwardingBasic 测试基本转发功能
// 前置条件:服务器和客户端已启动,客户端在本地 8888 端口运行了一个 echo 服务
func TestForwardingBasic(t *testing.T) {
config := defaultConfig
// 1. 启动本地测试服务(模拟客户端本地服务)
t.Log("启动本地测试服务...")
stopChan := make(chan struct{})
go startEchoServer(config.LocalServicePort, stopChan, t)
defer close(stopChan)
// 等待服务启动
time.Sleep(1 * time.Second)
// 2. 创建端口映射
t.Logf("创建端口映射: %d -> localhost:%d", config.TestPort, config.LocalServicePort)
err := createMapping(config.ServerAPIAddr, config.TestPort, "")
if err != nil {
t.Fatalf("创建端口映射失败: %v", err)
}
t.Log("端口映射创建成功")
// 清理:测试结束后删除映射
defer func() {
t.Log("清理端口映射...")
err := removeMapping(config.ServerAPIAddr, config.TestPort)
if err != nil {
t.Logf("清理端口映射失败: %v", err)
}
// 等待清理完成
time.Sleep(2 * time.Second)
}()
// 等待映射生效
time.Sleep(2 * time.Second)
// 3. 通过服务器端口发送请求
testMessage := "Hello, Tunnel!"
t.Logf("发送测试消息: %s", testMessage)
response, err := sendTCPRequest(fmt.Sprintf("localhost:%d", config.TestPort), testMessage)
if err != nil {
t.Fatalf("发送请求失败: %v", err)
}
// 4. 验证响应
expectedResponse := "ECHO: " + testMessage
if response != expectedResponse {
t.Fatalf("响应不匹配.\n期望: %s\n实际: %s", expectedResponse, response)
}
t.Logf("✓ 转发成功,收到响应: %s", response)
}
// TestMultipleForwards 测试多个端口同时转发
func TestMultipleForwards(t *testing.T) {
config := defaultConfig
// 启动本地测试服务
t.Log("启动本地测试服务...")
stopChan := make(chan struct{})
go startEchoServer(config.LocalServicePort, stopChan, t)
defer close(stopChan)
time.Sleep(1 * time.Second)
// 测试多个端口
testPorts := []int{30002, 30003, 30004}
// 创建多个映射
for _, port := range testPorts {
t.Logf("创建端口映射: %d", port)
err := createMapping(config.ServerAPIAddr, port, "")
if err != nil {
t.Fatalf("创建端口 %d 映射失败: %v", port, err)
}
defer func(p int) {
err := removeMapping(config.ServerAPIAddr, p)
if err != nil {
t.Logf("删除端口 %d 映射失败: %v", p, err)
}
}(port)
}
time.Sleep(2 * time.Second)
// 并发测试所有端口
for _, port := range testPorts {
t.Run(fmt.Sprintf("Port_%d", port), func(t *testing.T) {
testMessage := fmt.Sprintf("Test message for port %d", port)
response, err := sendTCPRequest(fmt.Sprintf("localhost:%d", port), testMessage)
if err != nil {
t.Errorf("端口 %d 请求失败: %v", port, err)
return
}
expectedResponse := "ECHO: " + testMessage
if response != expectedResponse {
t.Errorf("端口 %d 响应不匹配.\n期望: %s\n实际: %s", port, expectedResponse, response)
return
}
t.Logf("✓ 端口 %d 转发成功", port)
})
}
// 等待清理完成
time.Sleep(2 * time.Second)
}
// TestAddAndRemoveMapping 测试动态添加和删除映射
func TestAddAndRemoveMapping(t *testing.T) {
config := defaultConfig
// 启动本地测试服务
t.Log("启动本地测试服务...")
stopChan := make(chan struct{})
go startEchoServer(config.LocalServicePort, stopChan, t)
defer close(stopChan)
time.Sleep(1 * time.Second)
port := 30005
// 1. 创建映射
t.Logf("创建端口映射: %d", port)
err := createMapping(config.ServerAPIAddr, port, "")
if err != nil {
t.Fatalf("创建端口映射失败: %v", err)
}
time.Sleep(2 * time.Second)
// 2. 验证映射工作
t.Log("验证映射工作...")
response, err := sendTCPRequest(fmt.Sprintf("localhost:%d", port), "test1")
if err != nil {
t.Fatalf("映射创建后请求失败: %v", err)
}
if response != "ECHO: test1" {
t.Fatalf("响应不匹配: %s", response)
}
t.Log("✓ 映射工作正常")
// 3. 删除映射
t.Log("删除端口映射...")
err = removeMapping(config.ServerAPIAddr, port)
if err != nil {
t.Fatalf("删除端口映射失败: %v", err)
}
time.Sleep(2 * time.Second)
// 4. 验证映射已删除(连接应该失败)
t.Log("验证映射已删除...")
_, err = sendTCPRequest(fmt.Sprintf("localhost:%d", port), "test2")
if err == nil {
t.Fatal("映射删除后请求应该失败,但成功了")
}
t.Logf("✓ 映射已删除,连接失败(符合预期): %v", err)
// 5. 重新创建映射
t.Log("重新创建端口映射...")
err = createMapping(config.ServerAPIAddr, port, "")
if err != nil {
t.Fatalf("重新创建端口映射失败: %v", err)
}
defer func() {
err := removeMapping(config.ServerAPIAddr, port)
if err != nil {
t.Logf("清理端口映射失败: %v", err)
}
time.Sleep(2 * time.Second)
}()
time.Sleep(2 * time.Second)
// 6. 验证映射恢复工作
t.Log("验证映射恢复工作...")
response, err = sendTCPRequest(fmt.Sprintf("localhost:%d", port), "test3")
if err != nil {
t.Fatalf("映射重新创建后请求失败: %v", err)
}
if response != "ECHO: test3" {
t.Fatalf("响应不匹配: %s", response)
}
t.Log("✓ 映射恢复工作正常")
}
// TestConcurrentRequests 测试并发请求
func TestConcurrentRequests(t *testing.T) {
config := defaultConfig
// 启动本地测试服务
t.Log("启动本地测试服务...")
stopChan := make(chan struct{})
go startEchoServer(config.LocalServicePort, stopChan, t)
defer close(stopChan)
time.Sleep(1 * time.Second)
port := 30006
// 创建映射
t.Logf("创建端口映射: %d", port)
err := createMapping(config.ServerAPIAddr, port, "")
if err != nil {
t.Fatalf("创建端口映射失败: %v", err)
}
defer func() {
err := removeMapping(config.ServerAPIAddr, port)
if err != nil {
t.Logf("清理端口映射失败: %v", err)
}
time.Sleep(2 * time.Second)
}()
time.Sleep(2 * time.Second)
// 并发发送请求
concurrency := 10
results := make(chan error, concurrency)
t.Logf("并发发送 %d 个请求...", concurrency)
for i := 0; i < concurrency; i++ {
go func(index int) {
message := fmt.Sprintf("concurrent_%d", index)
response, err := sendTCPRequest(fmt.Sprintf("localhost:%d", port), message)
if err != nil {
results <- fmt.Errorf("请求 %d 失败: %w", index, err)
return
}
expectedResponse := "ECHO: " + message
if response != expectedResponse {
results <- fmt.Errorf("请求 %d 响应不匹配: 期望=%s, 实际=%s", index, expectedResponse, response)
return
}
results <- nil
}(i)
}
// 收集结果
successCount := 0
for i := 0; i < concurrency; i++ {
err := <-results
if err != nil {
t.Errorf("%v", err)
} else {
successCount++
}
}
t.Logf("✓ 并发测试完成: %d/%d 成功", successCount, concurrency)
if successCount < concurrency {
t.Fatalf("部分并发请求失败")
}
}
// TestListMappings 测试列出所有映射
func TestListMappings(t *testing.T) {
config := defaultConfig
// 创建几个映射
testPorts := []int{30010, 30011, 30012}
for _, port := range testPorts {
t.Logf("创建端口映射: %d", port)
err := createMapping(config.ServerAPIAddr, port, "")
if err != nil {
t.Fatalf("创建端口 %d 映射失败: %v", port, err)
}
defer removeMapping(config.ServerAPIAddr, port)
}
time.Sleep(500 * time.Millisecond)
// 列出所有映射
t.Log("列出所有映射...")
mappings, err := listMappings(config.ServerAPIAddr)
if err != nil {
t.Fatalf("列出映射失败: %v", err)
}
t.Logf("当前映射数量: %d", len(mappings))
// 验证创建的映射都在列表中
for _, port := range testPorts {
found := false
for _, mapping := range mappings {
if sourcePort, ok := mapping["source_port"].(float64); ok && int(sourcePort) == port {
found = true
t.Logf("✓ 找到映射: 端口 %d", port)
break
}
}
if !found {
t.Errorf("映射端口 %d 未在列表中找到", port)
}
}
}
// TestHealthCheck 测试健康检查
func TestHealthCheck(t *testing.T) {
config := defaultConfig
url := config.ServerAPIAddr + "/health"
resp, err := http.Get(url)
if err != nil {
t.Fatalf("健康检查请求失败: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("健康检查返回非 200 状态码: %d", resp.StatusCode)
}
body, _ := io.ReadAll(resp.Body)
t.Logf("✓ 健康检查成功: %s", string(body))
}
// ============ 辅助函数 ============
// createMapping 创建端口映射
func createMapping(apiAddr string, port int, targetIP string) error {
url := apiAddr + "/api/mapping/create"
reqBody := CreateMappingRequest{
Port: port,
TargetIP: targetIP,
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return fmt.Errorf("序列化请求失败: %w", err)
}
resp, err := http.Post(url, "application/json", bytes.NewBuffer(jsonData))
if err != nil {
return fmt.Errorf("发送请求失败: %w", err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
var result Response
if err := json.Unmarshal(body, &result); err != nil {
return fmt.Errorf("解析响应失败: %w", err)
}
if !result.Success {
return fmt.Errorf("API 返回失败: %s", result.Message)
}
return nil
}
// removeMapping 删除端口映射
func removeMapping(apiAddr string, port int) error {
url := apiAddr + "/api/mapping/remove"
reqBody := RemoveMappingRequest{
Port: port,
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return fmt.Errorf("序列化请求失败: %w", err)
}
resp, err := http.Post(url, "application/json", bytes.NewBuffer(jsonData))
if err != nil {
return fmt.Errorf("发送请求失败: %w", err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
var result Response
if err := json.Unmarshal(body, &result); err != nil {
return fmt.Errorf("解析响应失败: %w", err)
}
if !result.Success {
return fmt.Errorf("API 返回失败: %s", result.Message)
}
return nil
}
// listMappings 列出所有映射
func listMappings(apiAddr string) ([]map[string]interface{}, error) {
url := apiAddr + "/api/mapping/list"
resp, err := http.Get(url)
if err != nil {
return nil, fmt.Errorf("发送请求失败: %w", err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
var result Response
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("解析响应失败: %w", err)
}
if !result.Success {
return nil, fmt.Errorf("API 返回失败: %s", result.Message)
}
// 提取映射列表
if mappingsData, ok := result.Data["mappings"].([]interface{}); ok {
mappings := make([]map[string]interface{}, len(mappingsData))
for i, item := range mappingsData {
if mapping, ok := item.(map[string]interface{}); ok {
mappings[i] = mapping
}
}
return mappings, nil
}
return []map[string]interface{}{}, nil
}
// sendTCPRequest 发送 TCP 请求并接收响应
func sendTCPRequest(addr string, message string) (string, error) {
conn, err := net.DialTimeout("tcp", addr, 10*time.Second)
if err != nil {
return "", fmt.Errorf("连接失败: %w", err)
}
defer conn.Close()
// 设置超时 - 增加到10秒
conn.SetDeadline(time.Now().Add(10 * time.Second))
// 发送消息
_, err = conn.Write([]byte(message + "\n"))
if err != nil {
return "", fmt.Errorf("发送失败: %w", err)
}
// 接收响应
buffer := make([]byte, 1024)
n, err := conn.Read(buffer)
if err != nil {
return "", fmt.Errorf("接收失败: %w", err)
}
return string(buffer[:n]), nil
}
// startEchoServer 启动简单的 echo 服务器(用于测试)
func startEchoServer(port int, stopChan chan struct{}, t *testing.T) {
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
if err != nil {
t.Logf("启动 echo 服务器失败: %v", err)
return
}
defer listener.Close()
t.Logf("Echo 服务器启动在端口 %d", port)
go func() {
<-stopChan
listener.Close()
}()
for {
conn, err := listener.Accept()
if err != nil {
select {
case <-stopChan:
return
default:
continue
}
}
go handleEchoConnection(conn, t)
}
}
// handleEchoConnection 处理 echo 连接
func handleEchoConnection(conn net.Conn, t *testing.T) {
defer conn.Close()
buffer := make([]byte, 1024)
n, err := conn.Read(buffer)
if err != nil {
if err != io.EOF {
t.Logf("读取数据失败: %v", err)
}
return
}
message := string(buffer[:n])
message = string(bytes.TrimSpace([]byte(message)))
response := fmt.Sprintf("ECHO: %s", message)
_, err = conn.Write([]byte(response))
if err != nil {
t.Logf("发送响应失败: %v", err)
}
}

83
src/test/run_tests.sh Normal file
View File

@ -0,0 +1,83 @@
#!/bin/bash
# 集成测试运行脚本
# 此脚本用于运行 go-tunnel 的集成测试
set -e
# 颜色输出
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
NC='\033[0m' # No Color
# 配置
SERVER_API=${SERVER_API:-"http://localhost:8080"}
SERVER_TUNNEL_PORT=${SERVER_TUNNEL_PORT:-9000}
TEST_TIMEOUT=${TEST_TIMEOUT:-30s}
echo -e "${BLUE}================================${NC}"
echo -e "${BLUE}Go-Tunnel 集成测试${NC}"
echo -e "${BLUE}================================${NC}"
echo ""
# 检查服务器是否运行
echo -e "${YELLOW}检查服务器状态...${NC}"
if curl -s "${SERVER_API}/health" > /dev/null 2>&1; then
echo -e "${GREEN}✓ 服务器运行正常${NC}"
else
echo -e "${RED}✗ 服务器未响应${NC}"
echo -e "${YELLOW}请确保服务器已启动:${NC}"
echo " cd ../src && make run-server"
exit 1
fi
# 检查客户端是否连接(通过创建一个测试映射来验证)
echo -e "${YELLOW}检查隧道连接...${NC}"
if curl -s -X POST "${SERVER_API}/api/mapping/create" \
-H "Content-Type: application/json" \
-d '{"port": 39999}' | grep -q "success"; then
echo -e "${GREEN}✓ 隧道连接正常${NC}"
# 清理测试映射
curl -s -X POST "${SERVER_API}/api/mapping/remove" \
-H "Content-Type: application/json" \
-d '{"port": 39999}' > /dev/null 2>&1 || true
else
echo -e "${YELLOW}⚠ 无法验证隧道连接,测试可能失败${NC}"
echo -e "${YELLOW}请确保客户端已启动:${NC}"
echo " cd ../src && make run-client"
fi
echo ""
echo -e "${BLUE}开始运行测试...${NC}"
echo ""
# 切换到测试目录
cd "$(dirname "$0")"
# 运行测试
if [ "$1" = "verbose" ] || [ "$1" = "-v" ]; then
go test -v -timeout ${TEST_TIMEOUT}
elif [ -n "$1" ]; then
# 运行特定测试
echo -e "${BLUE}运行测试: $1${NC}"
go test -v -timeout ${TEST_TIMEOUT} -run "$1"
else
go test -timeout ${TEST_TIMEOUT}
fi
TEST_EXIT_CODE=$?
echo ""
if [ $TEST_EXIT_CODE -eq 0 ]; then
echo -e "${GREEN}================================${NC}"
echo -e "${GREEN}✓ 所有测试通过${NC}"
echo -e "${GREEN}================================${NC}"
else
echo -e "${RED}================================${NC}"
echo -e "${RED}✗ 测试失败${NC}"
echo -e "${RED}================================${NC}"
fi
exit $TEST_EXIT_CODE

View File

@ -0,0 +1,118 @@
#!/bin/bash
# 简化的集成测试运行脚本
# 说明:此脚本用于逐个运行集成测试,避免端口冲突
set -e
echo "========================================="
echo "集成测试开始"
echo "========================================="
# 获取脚本所在目录
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
cd "$SCRIPT_DIR"
# 清理函数
cleanup_ports() {
echo "清理测试端口..."
for port in 30001 30002 30003 30004 30005 30006 30010 30011 30012; do
# 尝试清理可能残留的端口映射
curl -s -X POST http://localhost:8080/api/mapping/remove \
-H "Content-Type: application/json" \
-d "{\"port\": $port}" > /dev/null 2>&1 || true
done
# 等待端口完全释放
sleep 3
}
# 测试前清理
echo "测试前清理..."
cleanup_ports
# 检查服务器是否运行
echo "检查服务器状态..."
if ! curl -s http://localhost:8080/health > /dev/null 2>&1; then
echo "错误:服务器未运行,请先启动服务器"
echo "运行: cd .. && make run-server"
exit 1
fi
echo "✓ 服务器正在运行"
# 检查客户端是否连接
echo "检查客户端连接..."
MAPPING_RESPONSE=$(curl -s http://localhost:8080/api/mapping/list)
if [ -z "$MAPPING_RESPONSE" ]; then
echo "警告:无法获取映射列表"
else
echo "✓ 可以访问服务器 API"
fi
# 运行测试
echo "========================================="
echo "运行集成测试(逐个测试)..."
echo "========================================="
TEST_FAILED=0
# 逐个运行测试,避免端口冲突
echo ""
echo "测试 1: TestForwardingBasic"
echo "-----------------------------------------"
if go test -v -timeout 30s -run TestForwardingBasic; then
echo "✓ TestForwardingBasic 通过"
else
echo "✗ TestForwardingBasic 失败"
TEST_FAILED=1
fi
cleanup_ports
echo ""
echo "测试 2: TestMultipleForwards"
echo "-----------------------------------------"
if go test -v -timeout 30s -run TestMultipleForwards; then
echo "✓ TestMultipleForwards 通过"
else
echo "✗ TestMultipleForwards 失败"
TEST_FAILED=1
fi
cleanup_ports
echo ""
echo "测试 3: TestAddAndRemoveMapping"
echo "-----------------------------------------"
if go test -v -timeout 30s -run TestAddAndRemoveMapping; then
echo "✓ TestAddAndRemoveMapping 通过"
else
echo "✗ TestAddAndRemoveMapping 失败"
TEST_FAILED=1
fi
cleanup_ports
echo ""
echo "测试 4: TestConcurrentRequests"
echo "-----------------------------------------"
if go test -v -timeout 30s -run TestConcurrentRequests; then
echo "✓ TestConcurrentRequests 通过"
else
echo "✗ TestConcurrentRequests 失败"
TEST_FAILED=1
fi
cleanup_ports
# 测试后清理
echo ""
echo "测试后清理..."
cleanup_ports
echo "========================================="
if [ $TEST_FAILED -eq 0 ]; then
echo "✓ 所有测试通过"
echo "========================================="
exit 0
else
echo "✗ 部分测试失败"
echo "========================================="
exit 1
fi