feat: 端口转发已实现
This commit is contained in:
parent
6b02def8de
commit
8c2f803b77
|
|
@ -15,24 +15,24 @@ build: build-server build-client
|
|||
# 编译服务器
|
||||
build-server:
|
||||
@echo "编译服务器..."
|
||||
go build -o bin/server ./server
|
||||
go build -o ../bin/server ./server
|
||||
@echo "服务器编译完成: bin/server"
|
||||
|
||||
# 编译客户端
|
||||
build-client:
|
||||
@echo "编译客户端..."
|
||||
go build -o bin/client ./client
|
||||
go build -o ../bin/client ./client
|
||||
@echo "客户端编译完成: bin/client"
|
||||
|
||||
# 运行服务器
|
||||
run-server: build-server
|
||||
@echo "启动服务器..."
|
||||
./bin/server -config config.yaml
|
||||
../bin/server -config config.yaml
|
||||
|
||||
# 运行客户端
|
||||
run-client: build-client
|
||||
@echo "启动客户端..."
|
||||
./bin/client -server localhost:9000
|
||||
../bin/client -server localhost:9000
|
||||
|
||||
# 清理编译文件
|
||||
clean:
|
||||
|
|
@ -56,8 +56,8 @@ init:
|
|||
# 交叉编译 Linux
|
||||
build-linux:
|
||||
@echo "交叉编译 Linux 版本..."
|
||||
GOOS=linux GOARCH=amd64 go build -o bin/server-linux ./server
|
||||
GOOS=linux GOARCH=amd64 go build -o bin/client-linux ./client
|
||||
GOOS=linux GOARCH=amd64 go build -o ../bin/server-linux ./server
|
||||
GOOS=linux GOARCH=amd64 go build -o ../bin/client-linux ./client
|
||||
@echo "Linux 版本编译完成"
|
||||
|
||||
# 格式化代码
|
||||
|
|
@ -12,14 +12,51 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
// HeaderSize 消息头大小
|
||||
HeaderSize = 8
|
||||
// MaxPacketSize 最大包大小
|
||||
// 协议版本
|
||||
ProtocolVersion = 0x01
|
||||
|
||||
// 消息头大小
|
||||
HeaderSize = 6 // 版本(1) + 类型(1) + 长度(4)
|
||||
|
||||
// 最大包大小
|
||||
MaxPacketSize = 1024 * 1024
|
||||
// ReconnectDelay 重连延迟
|
||||
|
||||
// 重连延迟
|
||||
ReconnectDelay = 5 * time.Second
|
||||
|
||||
// 消息类型
|
||||
MsgTypeConnectRequest = 0x01 // 连接请求
|
||||
MsgTypeConnectResponse = 0x02 // 连接响应
|
||||
MsgTypeData = 0x03 // 数据传输
|
||||
MsgTypeClose = 0x04 // 关闭连接
|
||||
MsgTypeKeepAlive = 0x05 // 心跳
|
||||
|
||||
// 连接响应状态
|
||||
ConnStatusSuccess = 0x00 // 连接成功
|
||||
ConnStatusFailed = 0x01 // 连接失败
|
||||
|
||||
// 超时设置
|
||||
ConnectTimeout = 10 * time.Second // 连接超时
|
||||
ReadTimeout = 30 * time.Second // 读取超时
|
||||
)
|
||||
|
||||
// TunnelMessage 隧道消息
|
||||
type TunnelMessage struct {
|
||||
Version byte
|
||||
Type byte
|
||||
Length uint32
|
||||
Data []byte
|
||||
}
|
||||
|
||||
// LocalConnection 本地连接
|
||||
type LocalConnection struct {
|
||||
ID uint32
|
||||
TargetAddr string
|
||||
Conn net.Conn
|
||||
closeChan chan struct{}
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
// Client 内网穿透客户端
|
||||
type Client struct {
|
||||
serverAddr string
|
||||
|
|
@ -32,14 +69,9 @@ type Client struct {
|
|||
// 连接管理
|
||||
connections map[uint32]*LocalConnection
|
||||
connMu sync.RWMutex
|
||||
}
|
||||
|
||||
// LocalConnection 本地连接
|
||||
type LocalConnection struct {
|
||||
ID uint32
|
||||
TargetAddr string
|
||||
Conn net.Conn
|
||||
closeChan chan struct{}
|
||||
|
||||
// 消息队列
|
||||
sendChan chan *TunnelMessage
|
||||
}
|
||||
|
||||
// NewClient 创建新的隧道客户端
|
||||
|
|
@ -50,6 +82,7 @@ func NewClient(serverAddr string) *Client {
|
|||
cancel: cancel,
|
||||
ctx: ctx,
|
||||
connections: make(map[uint32]*LocalConnection),
|
||||
sendChan: make(chan *TunnelMessage, 1000),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -88,11 +121,19 @@ func (c *Client) connectLoop() {
|
|||
c.mu.Unlock()
|
||||
|
||||
// 处理连接
|
||||
if err := c.handleServerConnection(conn); err != nil {
|
||||
if err != io.EOF {
|
||||
log.Printf("处理服务器连接出错: %v", err)
|
||||
}
|
||||
}
|
||||
var connWg sync.WaitGroup
|
||||
connWg.Add(2)
|
||||
go func() {
|
||||
defer connWg.Done()
|
||||
c.handleServerRead(conn)
|
||||
}()
|
||||
go func() {
|
||||
defer connWg.Done()
|
||||
c.handleServerWrite(conn)
|
||||
}()
|
||||
|
||||
// 等待连接断开
|
||||
connWg.Wait()
|
||||
|
||||
c.mu.Lock()
|
||||
c.serverConn = nil
|
||||
|
|
@ -101,7 +142,9 @@ func (c *Client) connectLoop() {
|
|||
// 关闭所有本地连接
|
||||
c.connMu.Lock()
|
||||
for _, conn := range c.connections {
|
||||
close(conn.closeChan)
|
||||
conn.closeOnce.Do(func() {
|
||||
close(conn.closeChan)
|
||||
})
|
||||
if conn.Conn != nil {
|
||||
conn.Conn.Close()
|
||||
}
|
||||
|
|
@ -114,140 +157,317 @@ func (c *Client) connectLoop() {
|
|||
}
|
||||
}
|
||||
|
||||
// handleServerConnection 处理服务器连接
|
||||
func (c *Client) handleServerConnection(conn net.Conn) error {
|
||||
for {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
|
||||
// 读取消息头
|
||||
header := make([]byte, HeaderSize)
|
||||
if _, err := io.ReadFull(conn, header); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dataLen := binary.BigEndian.Uint32(header[0:4])
|
||||
connID := binary.BigEndian.Uint32(header[4:8])
|
||||
|
||||
if dataLen > MaxPacketSize {
|
||||
return fmt.Errorf("数据包过大: %d bytes", dataLen)
|
||||
}
|
||||
|
||||
// 读取数据
|
||||
data := make([]byte, dataLen)
|
||||
if _, err := io.ReadFull(conn, data); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 处理数据
|
||||
c.handleData(connID, data)
|
||||
}
|
||||
}
|
||||
|
||||
// handleData 处理接收到的数据
|
||||
func (c *Client) handleData(connID uint32, data []byte) {
|
||||
c.connMu.Lock()
|
||||
localConn, exists := c.connections[connID]
|
||||
|
||||
if !exists {
|
||||
// 新连接,需要建立到本地服务的连接
|
||||
// 从数据中解析目标端口(这里简化处理,实际应该从协议中获取)
|
||||
localConn = &LocalConnection{
|
||||
ID: connID,
|
||||
closeChan: make(chan struct{}),
|
||||
}
|
||||
c.connections[connID] = localConn
|
||||
c.connMu.Unlock()
|
||||
|
||||
// 启动本地连接处理
|
||||
c.wg.Add(1)
|
||||
go c.handleLocalConnection(localConn)
|
||||
|
||||
// 重新获取锁并发送数据
|
||||
c.connMu.Lock()
|
||||
}
|
||||
c.connMu.Unlock()
|
||||
|
||||
// 发送数据到本地连接
|
||||
if localConn.Conn != nil {
|
||||
localConn.Conn.Write(data)
|
||||
}
|
||||
}
|
||||
|
||||
// handleLocalConnection 处理本地连接
|
||||
func (c *Client) handleLocalConnection(localConn *LocalConnection) {
|
||||
defer c.wg.Done()
|
||||
defer func() {
|
||||
c.connMu.Lock()
|
||||
delete(c.connections, localConn.ID)
|
||||
c.connMu.Unlock()
|
||||
|
||||
if localConn.Conn != nil {
|
||||
localConn.Conn.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
// 连接到本地目标服务
|
||||
// 这里使用固定的本地地址,实际应该根据映射配置
|
||||
targetAddr := localConn.TargetAddr
|
||||
if targetAddr == "" {
|
||||
targetAddr = "127.0.0.1:22" // 默认 SSH
|
||||
}
|
||||
|
||||
conn, err := net.DialTimeout("tcp", targetAddr, 5*time.Second)
|
||||
if err != nil {
|
||||
log.Printf("连接本地服务失败 (连接 %d -> %s): %v", localConn.ID, targetAddr, err)
|
||||
return
|
||||
}
|
||||
// handleServerRead 处理服务器读取
|
||||
func (c *Client) handleServerRead(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
localConn.Conn = conn
|
||||
log.Printf("建立本地连接: %d -> %s", localConn.ID, targetAddr)
|
||||
|
||||
// 从本地服务读取数据并发送到服务器
|
||||
buffer := make([]byte, 32*1024)
|
||||
for {
|
||||
select {
|
||||
case <-localConn.closeChan:
|
||||
return
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
conn.SetReadDeadline(time.Now().Add(30 * time.Second))
|
||||
n, err := conn.Read(buffer)
|
||||
msg, err := c.readTunnelMessage(conn)
|
||||
if err != nil {
|
||||
if err != io.EOF && !isTimeout(err) {
|
||||
log.Printf("读取本地连接失败 (连接 %d): %v", localConn.ID, err)
|
||||
if err != io.EOF {
|
||||
log.Printf("读取隧道消息失败: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 发送到服务器
|
||||
c.mu.RLock()
|
||||
serverConn := c.serverConn
|
||||
c.mu.RUnlock()
|
||||
c.handleTunnelMessage(msg)
|
||||
}
|
||||
}
|
||||
|
||||
if serverConn == nil {
|
||||
// handleServerWrite 处理服务器写入
|
||||
func (c *Client) handleServerWrite(conn net.Conn) {
|
||||
for {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
case msg := <-c.sendChan:
|
||||
if err := c.writeTunnelMessage(conn, msg); err != nil {
|
||||
log.Printf("写入隧道消息失败: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// readTunnelMessage 读取隧道消息
|
||||
func (c *Client) readTunnelMessage(conn net.Conn) (*TunnelMessage, error) {
|
||||
// 读取消息头
|
||||
header := make([]byte, HeaderSize)
|
||||
if _, err := io.ReadFull(conn, header); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
version := header[0]
|
||||
msgType := header[1]
|
||||
dataLen := binary.BigEndian.Uint32(header[2:6])
|
||||
|
||||
if version != ProtocolVersion {
|
||||
return nil, fmt.Errorf("不支持的协议版本: %d", version)
|
||||
}
|
||||
|
||||
if dataLen > MaxPacketSize {
|
||||
return nil, fmt.Errorf("数据包过大: %d bytes", dataLen)
|
||||
}
|
||||
|
||||
// 读取数据
|
||||
var data []byte
|
||||
if dataLen > 0 {
|
||||
data = make([]byte, dataLen)
|
||||
if _, err := io.ReadFull(conn, data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &TunnelMessage{
|
||||
Version: version,
|
||||
Type: msgType,
|
||||
Length: dataLen,
|
||||
Data: data,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// writeTunnelMessage 写入隧道消息
|
||||
func (c *Client) writeTunnelMessage(conn net.Conn, msg *TunnelMessage) error {
|
||||
// 构建消息头
|
||||
header := make([]byte, HeaderSize)
|
||||
header[0] = msg.Version
|
||||
header[1] = msg.Type
|
||||
binary.BigEndian.PutUint32(header[2:6], msg.Length)
|
||||
|
||||
// 写入消息头
|
||||
if _, err := conn.Write(header); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 写入数据
|
||||
if msg.Length > 0 && msg.Data != nil {
|
||||
if _, err := conn.Write(msg.Data); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleTunnelMessage 处理隧道消息
|
||||
func (c *Client) handleTunnelMessage(msg *TunnelMessage) {
|
||||
switch msg.Type {
|
||||
case MsgTypeConnectRequest:
|
||||
c.handleConnectRequest(msg)
|
||||
case MsgTypeData:
|
||||
c.handleDataMessage(msg)
|
||||
case MsgTypeClose:
|
||||
c.handleCloseMessage(msg)
|
||||
case MsgTypeKeepAlive:
|
||||
c.handleKeepAlive(msg)
|
||||
default:
|
||||
log.Printf("未知消息类型: %d", msg.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// handleConnectRequest 处理连接请求
|
||||
func (c *Client) handleConnectRequest(msg *TunnelMessage) {
|
||||
if len(msg.Data) < 6 {
|
||||
log.Printf("连接请求数据太短")
|
||||
return
|
||||
}
|
||||
|
||||
connID := binary.BigEndian.Uint32(msg.Data[0:4])
|
||||
targetPort := binary.BigEndian.Uint16(msg.Data[4:6])
|
||||
targetAddr := fmt.Sprintf("127.0.0.1:%d", targetPort)
|
||||
|
||||
log.Printf("收到连接请求: ID=%d, 端口=%d", connID, targetPort)
|
||||
|
||||
// 尝试连接到本地服务
|
||||
localConn, err := net.DialTimeout("tcp", targetAddr, ConnectTimeout)
|
||||
if err != nil {
|
||||
log.Printf("连接本地服务失败 (ID=%d -> %s): %v", connID, targetAddr, err)
|
||||
c.sendConnectResponse(connID, ConnStatusFailed)
|
||||
return
|
||||
}
|
||||
|
||||
// 创建本地连接对象
|
||||
connection := &LocalConnection{
|
||||
ID: connID,
|
||||
TargetAddr: targetAddr,
|
||||
Conn: localConn,
|
||||
closeChan: make(chan struct{}),
|
||||
}
|
||||
|
||||
c.connMu.Lock()
|
||||
c.connections[connID] = connection
|
||||
c.connMu.Unlock()
|
||||
|
||||
log.Printf("建立本地连接: ID=%d -> %s", connID, targetAddr)
|
||||
|
||||
// 发送连接成功响应
|
||||
c.sendConnectResponse(connID, ConnStatusSuccess)
|
||||
|
||||
// 启动数据转发
|
||||
go c.forwardData(connection)
|
||||
}
|
||||
|
||||
// handleDataMessage 处理数据消息
|
||||
func (c *Client) handleDataMessage(msg *TunnelMessage) {
|
||||
if len(msg.Data) < 4 {
|
||||
log.Printf("数据消息太短")
|
||||
return
|
||||
}
|
||||
|
||||
connID := binary.BigEndian.Uint32(msg.Data[0:4])
|
||||
data := msg.Data[4:]
|
||||
|
||||
c.connMu.RLock()
|
||||
connection, exists := c.connections[connID]
|
||||
c.connMu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
log.Printf("收到未知连接的数据: %d", connID)
|
||||
return
|
||||
}
|
||||
|
||||
// 写入到本地连接
|
||||
if _, err := connection.Conn.Write(data); err != nil {
|
||||
log.Printf("写入本地连接失败 (ID=%d): %v", connID, err)
|
||||
c.closeConnection(connID)
|
||||
}
|
||||
}
|
||||
|
||||
// handleCloseMessage 处理关闭消息
|
||||
func (c *Client) handleCloseMessage(msg *TunnelMessage) {
|
||||
if len(msg.Data) < 4 {
|
||||
log.Printf("关闭消息数据太短")
|
||||
return
|
||||
}
|
||||
|
||||
connID := binary.BigEndian.Uint32(msg.Data[0:4])
|
||||
c.closeConnection(connID)
|
||||
}
|
||||
|
||||
// handleKeepAlive 处理心跳消息
|
||||
func (c *Client) handleKeepAlive(msg *TunnelMessage) {
|
||||
// 回应心跳
|
||||
response := &TunnelMessage{
|
||||
Version: ProtocolVersion,
|
||||
Type: MsgTypeKeepAlive,
|
||||
Length: 0,
|
||||
Data: nil,
|
||||
}
|
||||
|
||||
select {
|
||||
case c.sendChan <- response:
|
||||
default:
|
||||
log.Printf("发送心跳响应失败: 发送队列已满")
|
||||
}
|
||||
}
|
||||
|
||||
// sendConnectResponse 发送连接响应
|
||||
func (c *Client) sendConnectResponse(connID uint32, status byte) {
|
||||
responseData := make([]byte, 5)
|
||||
binary.BigEndian.PutUint32(responseData[0:4], connID)
|
||||
responseData[4] = status
|
||||
|
||||
msg := &TunnelMessage{
|
||||
Version: ProtocolVersion,
|
||||
Type: MsgTypeConnectResponse,
|
||||
Length: 5,
|
||||
Data: responseData,
|
||||
}
|
||||
|
||||
select {
|
||||
case c.sendChan <- msg:
|
||||
default:
|
||||
log.Printf("发送连接响应失败: 发送队列已满")
|
||||
}
|
||||
}
|
||||
|
||||
// forwardData 转发数据
|
||||
func (c *Client) forwardData(connection *LocalConnection) {
|
||||
defer c.closeConnection(connection.ID)
|
||||
|
||||
buffer := make([]byte, 32*1024)
|
||||
for {
|
||||
select {
|
||||
case <-connection.closeChan:
|
||||
return
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
connection.Conn.SetReadDeadline(time.Now().Add(ReadTimeout))
|
||||
n, err := connection.Conn.Read(buffer)
|
||||
if err != nil {
|
||||
if err != io.EOF && !isTimeout(err) {
|
||||
log.Printf("读取本地连接失败 (ID=%d): %v", connection.ID, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
data := make([]byte, HeaderSize+n)
|
||||
binary.BigEndian.PutUint32(data[0:4], uint32(n))
|
||||
binary.BigEndian.PutUint32(data[4:8], localConn.ID)
|
||||
copy(data[HeaderSize:], buffer[:n])
|
||||
// 发送数据到隧道
|
||||
dataMsg := make([]byte, 4+n)
|
||||
binary.BigEndian.PutUint32(dataMsg[0:4], connection.ID)
|
||||
copy(dataMsg[4:], buffer[:n])
|
||||
|
||||
if _, err := serverConn.Write(data); err != nil {
|
||||
log.Printf("发送数据到服务器失败 (连接 %d): %v", localConn.ID, err)
|
||||
msg := &TunnelMessage{
|
||||
Version: ProtocolVersion,
|
||||
Type: MsgTypeData,
|
||||
Length: uint32(len(dataMsg)),
|
||||
Data: dataMsg,
|
||||
}
|
||||
|
||||
select {
|
||||
case c.sendChan <- msg:
|
||||
case <-time.After(5 * time.Second):
|
||||
log.Printf("发送数据超时 (ID=%d)", connection.ID)
|
||||
return
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// closeConnection 关闭连接
|
||||
func (c *Client) closeConnection(connID uint32) {
|
||||
c.connMu.Lock()
|
||||
connection, exists := c.connections[connID]
|
||||
if exists {
|
||||
delete(c.connections, connID)
|
||||
connection.closeOnce.Do(func() {
|
||||
close(connection.closeChan)
|
||||
})
|
||||
connection.Conn.Close()
|
||||
}
|
||||
c.connMu.Unlock()
|
||||
|
||||
// 发送关闭消息
|
||||
closeData := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(closeData, connID)
|
||||
|
||||
msg := &TunnelMessage{
|
||||
Version: ProtocolVersion,
|
||||
Type: MsgTypeClose,
|
||||
Length: 4,
|
||||
Data: closeData,
|
||||
}
|
||||
|
||||
select {
|
||||
case c.sendChan <- msg:
|
||||
default:
|
||||
// 发送队列满,忽略
|
||||
}
|
||||
|
||||
if exists {
|
||||
log.Printf("连接已关闭: ID=%d", connID)
|
||||
}
|
||||
}
|
||||
|
||||
// Stop 停止隧道客户端
|
||||
func (c *Client) Stop() error {
|
||||
log.Println("正在停止隧道客户端...")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,482 @@
|
|||
package tunnel
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestNewClient 测试创建隧道客户端
|
||||
func TestNewClient(t *testing.T) {
|
||||
client := NewClient("127.0.0.1:9000")
|
||||
|
||||
if client == nil {
|
||||
t.Fatal("创建隧道客户端失败")
|
||||
}
|
||||
|
||||
if client.serverAddr != "127.0.0.1:9000" {
|
||||
t.Errorf("服务器地址不正确,期望 127.0.0.1:9000,得到 %s", client.serverAddr)
|
||||
}
|
||||
|
||||
if client.connections == nil {
|
||||
t.Error("连接映射未初始化")
|
||||
}
|
||||
|
||||
if client.sendChan == nil {
|
||||
t.Error("发送通道未初始化")
|
||||
}
|
||||
|
||||
if client.ctx == nil {
|
||||
t.Error("上下文未初始化")
|
||||
}
|
||||
}
|
||||
|
||||
// TestClientStartStop 测试客户端启动和停止
|
||||
func TestClientStartStop(t *testing.T) {
|
||||
// 创建模拟服务器
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("创建模拟服务器失败: %v", err)
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
serverAddr := listener.Addr().String()
|
||||
|
||||
// 启动模拟服务器
|
||||
go func() {
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// 保持连接短暂时间然后关闭
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
conn.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
client := NewClient(serverAddr)
|
||||
|
||||
err = client.Start()
|
||||
if err != nil {
|
||||
t.Fatalf("启动客户端失败: %v", err)
|
||||
}
|
||||
|
||||
// 等待客户端尝试连接
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
|
||||
// 停止客户端
|
||||
err = client.Stop()
|
||||
if err != nil {
|
||||
t.Errorf("停止客户端失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestClientTunnelMessage 测试客户端隧道消息处理
|
||||
func TestClientTunnelMessage(t *testing.T) {
|
||||
client := NewClient("127.0.0.1:9000")
|
||||
|
||||
// 创建测试消息
|
||||
data := []byte("test data")
|
||||
msg := &TunnelMessage{
|
||||
Version: ProtocolVersion,
|
||||
Type: MsgTypeData,
|
||||
Length: uint32(len(data)),
|
||||
Data: data,
|
||||
}
|
||||
|
||||
// 创建模拟连接
|
||||
serverConn, clientConn := net.Pipe()
|
||||
defer serverConn.Close()
|
||||
defer clientConn.Close()
|
||||
|
||||
// 测试写入消息
|
||||
go func() {
|
||||
err := client.writeTunnelMessage(serverConn, msg)
|
||||
if err != nil {
|
||||
t.Errorf("写入隧道消息失败: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// 测试读取消息
|
||||
receivedMsg, err := client.readTunnelMessage(clientConn)
|
||||
if err != nil {
|
||||
t.Fatalf("读取隧道消息失败: %v", err)
|
||||
}
|
||||
|
||||
// 验证消息内容
|
||||
if receivedMsg.Version != msg.Version {
|
||||
t.Errorf("版本不匹配,期望 %d,得到 %d", msg.Version, receivedMsg.Version)
|
||||
}
|
||||
|
||||
if receivedMsg.Type != msg.Type {
|
||||
t.Errorf("类型不匹配,期望 %d,得到 %d", msg.Type, receivedMsg.Type)
|
||||
}
|
||||
|
||||
if receivedMsg.Length != msg.Length {
|
||||
t.Errorf("长度不匹配,期望 %d,得到 %d", msg.Length, receivedMsg.Length)
|
||||
}
|
||||
|
||||
if string(receivedMsg.Data) != string(msg.Data) {
|
||||
t.Errorf("数据不匹配,期望 %s,得到 %s", string(msg.Data), string(receivedMsg.Data))
|
||||
}
|
||||
}
|
||||
|
||||
// TestClientHandleConnectRequest 测试客户端处理连接请求
|
||||
func TestClientHandleConnectRequest(t *testing.T) {
|
||||
// 启动本地测试服务
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("启动本地测试服务失败: %v", err)
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
localPort := listener.Addr().(*net.TCPAddr).Port
|
||||
|
||||
// 启动echo服务
|
||||
go func() {
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go func(c net.Conn) {
|
||||
defer c.Close()
|
||||
io.Copy(c, c) // echo服务
|
||||
}(conn)
|
||||
}
|
||||
}()
|
||||
|
||||
client := NewClient("127.0.0.1:9000")
|
||||
|
||||
// 创建连接请求消息
|
||||
connID := uint32(12345)
|
||||
reqData := make([]byte, 6)
|
||||
binary.BigEndian.PutUint32(reqData[0:4], connID)
|
||||
binary.BigEndian.PutUint16(reqData[4:6], uint16(localPort))
|
||||
|
||||
msg := &TunnelMessage{
|
||||
Version: ProtocolVersion,
|
||||
Type: MsgTypeConnectRequest,
|
||||
Length: 6,
|
||||
Data: reqData,
|
||||
}
|
||||
|
||||
// 创建模拟服务器连接
|
||||
serverConn, clientServerConn := net.Pipe()
|
||||
defer serverConn.Close()
|
||||
defer clientServerConn.Close()
|
||||
|
||||
// 启动发送处理
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case msg := <-client.sendChan:
|
||||
client.writeTunnelMessage(clientServerConn, msg)
|
||||
case <-time.After(time.Second):
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// 处理连接请求
|
||||
client.handleTunnelMessage(msg)
|
||||
|
||||
// 等待连接建立
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 验证连接是否建立
|
||||
client.connMu.RLock()
|
||||
_, exists := client.connections[connID]
|
||||
client.connMu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
t.Error("连接未建立")
|
||||
}
|
||||
|
||||
// 读取连接响应
|
||||
response, err := client.readTunnelMessage(serverConn)
|
||||
if err != nil {
|
||||
t.Fatalf("读取连接响应失败: %v", err)
|
||||
}
|
||||
|
||||
if response.Type != MsgTypeConnectResponse {
|
||||
t.Errorf("响应类型不正确,期望 %d,得到 %d", MsgTypeConnectResponse, response.Type)
|
||||
}
|
||||
|
||||
if len(response.Data) < 5 {
|
||||
t.Fatal("响应数据太短")
|
||||
}
|
||||
|
||||
responseConnID := binary.BigEndian.Uint32(response.Data[0:4])
|
||||
status := response.Data[4]
|
||||
|
||||
if responseConnID != connID {
|
||||
t.Errorf("响应连接ID不匹配,期望 %d,得到 %d", connID, responseConnID)
|
||||
}
|
||||
|
||||
if status != ConnStatusSuccess {
|
||||
t.Errorf("连接状态不正确,期望 %d,得到 %d", ConnStatusSuccess, status)
|
||||
}
|
||||
|
||||
// 清理连接
|
||||
client.closeConnection(connID)
|
||||
}
|
||||
|
||||
// TestClientHandleDataMessage 测试客户端处理数据消息
|
||||
func TestClientHandleDataMessage(t *testing.T) {
|
||||
client := NewClient("127.0.0.1:9000")
|
||||
|
||||
// 创建模拟本地连接
|
||||
localConn, remoteConn := net.Pipe()
|
||||
defer localConn.Close()
|
||||
defer remoteConn.Close()
|
||||
|
||||
connID := uint32(12345)
|
||||
connection := &LocalConnection{
|
||||
ID: connID,
|
||||
TargetAddr: "127.0.0.1:8080",
|
||||
Conn: localConn,
|
||||
closeChan: make(chan struct{}),
|
||||
}
|
||||
|
||||
client.connMu.Lock()
|
||||
client.connections[connID] = connection
|
||||
client.connMu.Unlock()
|
||||
|
||||
// 创建数据消息
|
||||
testData := []byte("Hello, World!")
|
||||
dataMsg := make([]byte, 4+len(testData))
|
||||
binary.BigEndian.PutUint32(dataMsg[0:4], connID)
|
||||
copy(dataMsg[4:], testData)
|
||||
|
||||
msg := &TunnelMessage{
|
||||
Version: ProtocolVersion,
|
||||
Type: MsgTypeData,
|
||||
Length: uint32(len(dataMsg)),
|
||||
Data: dataMsg,
|
||||
}
|
||||
|
||||
// 处理数据消息
|
||||
go client.handleTunnelMessage(msg)
|
||||
|
||||
// 从远程连接读取数据
|
||||
buffer := make([]byte, len(testData))
|
||||
remoteConn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||
n, err := remoteConn.Read(buffer)
|
||||
if err != nil {
|
||||
t.Fatalf("读取数据失败: %v", err)
|
||||
}
|
||||
|
||||
if n != len(testData) {
|
||||
t.Errorf("数据长度不匹配,期望 %d,得到 %d", len(testData), n)
|
||||
}
|
||||
|
||||
if string(buffer[:n]) != string(testData) {
|
||||
t.Errorf("数据内容不匹配,期望 %s,得到 %s", string(testData), string(buffer[:n]))
|
||||
}
|
||||
|
||||
// 清理连接
|
||||
client.closeConnection(connID)
|
||||
}
|
||||
|
||||
// TestClientHandleCloseMessage 测试客户端处理关闭消息
|
||||
func TestClientHandleCloseMessage(t *testing.T) {
|
||||
client := NewClient("127.0.0.1:9000")
|
||||
|
||||
// 创建模拟本地连接
|
||||
localConn, _ := net.Pipe()
|
||||
defer localConn.Close()
|
||||
|
||||
connID := uint32(12345)
|
||||
connection := &LocalConnection{
|
||||
ID: connID,
|
||||
TargetAddr: "127.0.0.1:8080",
|
||||
Conn: localConn,
|
||||
closeChan: make(chan struct{}),
|
||||
}
|
||||
|
||||
client.connMu.Lock()
|
||||
client.connections[connID] = connection
|
||||
client.connMu.Unlock()
|
||||
|
||||
// 创建关闭消息
|
||||
closeData := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(closeData, connID)
|
||||
|
||||
msg := &TunnelMessage{
|
||||
Version: ProtocolVersion,
|
||||
Type: MsgTypeClose,
|
||||
Length: 4,
|
||||
Data: closeData,
|
||||
}
|
||||
|
||||
// 处理关闭消息
|
||||
client.handleTunnelMessage(msg)
|
||||
|
||||
// 等待连接关闭
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 验证连接是否被移除
|
||||
client.connMu.RLock()
|
||||
_, exists := client.connections[connID]
|
||||
client.connMu.RUnlock()
|
||||
|
||||
if exists {
|
||||
t.Error("连接未被移除")
|
||||
}
|
||||
}
|
||||
|
||||
// TestClientKeepAlive 测试客户端心跳处理
|
||||
func TestClientKeepAlive(t *testing.T) {
|
||||
client := NewClient("127.0.0.1:9000")
|
||||
|
||||
// 创建模拟连接
|
||||
serverConn, clientConn := net.Pipe()
|
||||
defer serverConn.Close()
|
||||
defer clientConn.Close()
|
||||
|
||||
// 启动发送处理
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case msg := <-client.sendChan:
|
||||
client.writeTunnelMessage(clientConn, msg)
|
||||
case <-time.After(time.Second):
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// 创建心跳消息
|
||||
keepAliveMsg := &TunnelMessage{
|
||||
Version: ProtocolVersion,
|
||||
Type: MsgTypeKeepAlive,
|
||||
Length: 0,
|
||||
Data: nil,
|
||||
}
|
||||
|
||||
// 处理心跳消息
|
||||
client.handleTunnelMessage(keepAliveMsg)
|
||||
|
||||
// 读取心跳响应
|
||||
response, err := client.readTunnelMessage(serverConn)
|
||||
if err != nil {
|
||||
t.Fatalf("读取心跳响应失败: %v", err)
|
||||
}
|
||||
|
||||
if response.Type != MsgTypeKeepAlive {
|
||||
t.Errorf("心跳响应类型不正确,期望 %d,得到 %d", MsgTypeKeepAlive, response.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// TestClientReconnect 测试客户端重连功能
|
||||
func TestClientReconnect(t *testing.T) {
|
||||
// 这个测试比较复杂,需要模拟服务器关闭和重启
|
||||
// 简化版本只测试客户端的创建和基本功能
|
||||
client := NewClient("127.0.0.1:19999") // 使用不存在的端口
|
||||
|
||||
err := client.Start()
|
||||
if err != nil {
|
||||
t.Fatalf("启动客户端失败: %v", err)
|
||||
}
|
||||
|
||||
// 等待客户端尝试连接
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
err = client.Stop()
|
||||
if err != nil {
|
||||
t.Errorf("停止客户端失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestClientConcurrentConnections 测试客户端并发连接处理
|
||||
func TestClientConcurrentConnections(t *testing.T) {
|
||||
client := NewClient("127.0.0.1:9000")
|
||||
|
||||
var wg sync.WaitGroup
|
||||
connCount := 5
|
||||
|
||||
// 创建多个模拟连接
|
||||
for i := 0; i < connCount; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
|
||||
// 创建模拟本地连接
|
||||
localConn, _ := net.Pipe()
|
||||
defer localConn.Close()
|
||||
|
||||
connID := uint32(id + 1)
|
||||
connection := &LocalConnection{
|
||||
ID: connID,
|
||||
TargetAddr: "127.0.0.1:8080",
|
||||
Conn: localConn,
|
||||
closeChan: make(chan struct{}),
|
||||
}
|
||||
|
||||
client.connMu.Lock()
|
||||
client.connections[connID] = connection
|
||||
client.connMu.Unlock()
|
||||
|
||||
// 保持连接一段时间
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 清理连接
|
||||
client.closeConnection(connID)
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// 验证所有连接都被清理
|
||||
client.connMu.RLock()
|
||||
remainingConns := len(client.connections)
|
||||
client.connMu.RUnlock()
|
||||
|
||||
if remainingConns != 0 {
|
||||
t.Errorf("还有 %d 个连接未清理", remainingConns)
|
||||
}
|
||||
}
|
||||
|
||||
// TestProtocolConstants 测试协议常量
|
||||
func TestProtocolConstants(t *testing.T) {
|
||||
if ProtocolVersion != 0x01 {
|
||||
t.Errorf("协议版本不正确,期望 0x01,得到 0x%02x", ProtocolVersion)
|
||||
}
|
||||
|
||||
if HeaderSize != 6 {
|
||||
t.Errorf("消息头大小不正确,期望 6,得到 %d", HeaderSize)
|
||||
}
|
||||
|
||||
if MaxPacketSize != 1024*1024 {
|
||||
t.Errorf("最大包大小不正确,期望 %d,得到 %d", 1024*1024, MaxPacketSize)
|
||||
}
|
||||
|
||||
// 验证消息类型常量
|
||||
expectedTypes := map[string]byte{
|
||||
"MsgTypeConnectRequest": MsgTypeConnectRequest,
|
||||
"MsgTypeConnectResponse": MsgTypeConnectResponse,
|
||||
"MsgTypeData": MsgTypeData,
|
||||
"MsgTypeClose": MsgTypeClose,
|
||||
"MsgTypeKeepAlive": MsgTypeKeepAlive,
|
||||
}
|
||||
|
||||
for name, value := range expectedTypes {
|
||||
if value == 0 {
|
||||
t.Errorf("消息类型 %s 未定义", name)
|
||||
}
|
||||
}
|
||||
|
||||
// 验证连接状态常量
|
||||
if ConnStatusSuccess != 0x00 {
|
||||
t.Errorf("连接成功状态不正确,期望 0x00,得到 0x%02x", ConnStatusSuccess)
|
||||
}
|
||||
|
||||
if ConnStatusFailed != 0x01 {
|
||||
t.Errorf("连接失败状态不正确,期望 0x01,得到 0x%02x", ConnStatusFailed)
|
||||
}
|
||||
}
|
||||
|
|
@ -20,25 +20,25 @@ type Handler struct {
|
|||
tunnelServer *tunnel.Server
|
||||
portRangeFrom int
|
||||
portRangeEnd int
|
||||
useTunnel bool
|
||||
}
|
||||
|
||||
// NewHandler 创建新的 API 处理器
|
||||
func NewHandler(database *db.Database, fwdMgr *forwarder.Manager, ts *tunnel.Server, portFrom, portEnd int, useTunnel bool) *Handler {
|
||||
func NewHandler(database *db.Database, fwdMgr *forwarder.Manager, ts *tunnel.Server, portFrom, portEnd int) *Handler {
|
||||
return &Handler{
|
||||
db: database,
|
||||
forwarderMgr: fwdMgr,
|
||||
tunnelServer: ts,
|
||||
portRangeFrom: portFrom,
|
||||
portRangeEnd: portEnd,
|
||||
useTunnel: useTunnel,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateMappingRequest 创建映射请求
|
||||
type CreateMappingRequest struct {
|
||||
Port int `json:"port"` // 源端口和目标端口(相同)
|
||||
TargetIP string `json:"target_ip"` // 目标 IP(非隧道模式使用)
|
||||
SourcePort int `json:"source_port"` // 源端口(本地监听端口)
|
||||
TargetPort int `json:"target_port"` // 目标端口(远程服务端口)
|
||||
TargetIP string `json:"target_ip"` // 目标 IP(非隧道模式使用)
|
||||
UseTunnel bool `json:"use_tunnel"` // 是否使用隧道模式
|
||||
}
|
||||
|
||||
// RemoveMappingRequest 删除映射请求
|
||||
|
|
@ -75,57 +75,72 @@ func (h *Handler) handleCreateMapping(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// 验证端口范围
|
||||
if req.Port < h.portRangeFrom || req.Port > h.portRangeEnd {
|
||||
if req.SourcePort < h.portRangeFrom || req.SourcePort > h.portRangeEnd {
|
||||
h.writeError(w, http.StatusBadRequest, fmt.Sprintf("端口必须在 %d-%d 范围内", h.portRangeFrom, h.portRangeEnd))
|
||||
return
|
||||
}
|
||||
|
||||
// 检查端口是否已被使用
|
||||
if h.forwarderMgr.Exists(req.Port) {
|
||||
if h.forwarderMgr.Exists(req.SourcePort) {
|
||||
h.writeError(w, http.StatusConflict, "端口已被占用")
|
||||
return
|
||||
}
|
||||
|
||||
// 非隧道模式需要验证 IP
|
||||
if !h.useTunnel {
|
||||
if req.TargetIP == "" {
|
||||
h.writeError(w, http.StatusBadRequest, "target_ip 不能为空")
|
||||
// 根据请求决定使用哪种模式
|
||||
if req.UseTunnel {
|
||||
// 隧道模式,检查隧道服务器是否可用
|
||||
if h.tunnelServer == nil {
|
||||
h.writeError(w, http.StatusServiceUnavailable, "隧道服务未启用")
|
||||
return
|
||||
}
|
||||
if net.ParseIP(req.TargetIP) == nil {
|
||||
h.writeError(w, http.StatusBadRequest, "target_ip 格式无效")
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// 隧道模式,检查隧道是否连接
|
||||
if !h.tunnelServer.IsConnected() {
|
||||
h.writeError(w, http.StatusServiceUnavailable, "隧道未连接")
|
||||
return
|
||||
}
|
||||
// 隧道模式使用本地地址
|
||||
req.TargetIP = "127.0.0.1"
|
||||
} else {
|
||||
// 直接模式需要验证 IP
|
||||
if req.TargetIP == "" {
|
||||
h.writeError(w, http.StatusBadRequest, "非隧道模式下 target_ip 不能为空")
|
||||
return
|
||||
}
|
||||
if net.ParseIP(req.TargetIP) == nil {
|
||||
h.writeError(w, http.StatusBadRequest, "target_ip 格式无效")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 添加到数据库
|
||||
if err := h.db.AddMapping(req.Port, req.TargetIP, req.Port); err != nil {
|
||||
if err := h.db.AddMapping(req.SourcePort, req.TargetIP, req.TargetPort, req.UseTunnel); err != nil {
|
||||
h.writeError(w, http.StatusInternalServerError, "保存映射失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 启动转发器
|
||||
if err := h.forwarderMgr.Add(req.Port, req.TargetIP, req.Port); err != nil {
|
||||
var err error
|
||||
if req.UseTunnel {
|
||||
// 隧道模式:使用隧道转发
|
||||
err = h.forwarderMgr.AddTunnel(req.SourcePort, req.SourcePort, h.tunnelServer)
|
||||
} else {
|
||||
// 直接模式:直接TCP转发
|
||||
err = h.forwarderMgr.Add(req.SourcePort, req.TargetIP, req.TargetPort)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
// 回滚数据库操作
|
||||
h.db.RemoveMapping(req.Port)
|
||||
h.db.RemoveMapping(req.SourcePort)
|
||||
h.writeError(w, http.StatusInternalServerError, "启动转发失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("创建端口映射: %d -> %s:%d", req.Port, req.TargetIP, req.Port)
|
||||
log.Printf("创建端口映射: %d -> %s:%d (tunnel: %v)", req.SourcePort, req.TargetIP, req.TargetPort, req.UseTunnel)
|
||||
|
||||
h.writeSuccess(w, "端口映射创建成功", map[string]interface{}{
|
||||
"port": req.Port,
|
||||
"target_ip": req.TargetIP,
|
||||
"use_tunnel": h.useTunnel,
|
||||
"source_port": req.SourcePort,
|
||||
"target_ip": req.TargetIP,
|
||||
"target_port": req.TargetPort,
|
||||
"use_tunnel": req.UseTunnel,
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -187,9 +202,8 @@ func (h *Handler) handleListMappings(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
h.writeSuccess(w, "获取映射列表成功", map[string]interface{}{
|
||||
"mappings": mappings,
|
||||
"count": len(mappings),
|
||||
"use_tunnel": h.useTunnel,
|
||||
"mappings": mappings,
|
||||
"count": len(mappings),
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -197,11 +211,11 @@ func (h *Handler) handleListMappings(w http.ResponseWriter, r *http.Request) {
|
|||
func (h *Handler) handleHealth(w http.ResponseWriter, r *http.Request) {
|
||||
status := map[string]interface{}{
|
||||
"status": "ok",
|
||||
"tunnel_enabled": h.useTunnel,
|
||||
"tunnel_enabled": h.tunnelServer != nil,
|
||||
"tunnel_connected": false,
|
||||
}
|
||||
|
||||
if h.useTunnel {
|
||||
if h.tunnelServer != nil {
|
||||
status["tunnel_connected"] = h.tunnelServer.IsConnected()
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,580 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"port-forward/server/db"
|
||||
"port-forward/server/forwarder"
|
||||
"port-forward/server/tunnel"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// setupTestHandler 创建测试用的 Handler
|
||||
func setupTestHandler(t *testing.T, useTunnel bool) (*Handler, *db.Database, func()) {
|
||||
// 创建临时数据库
|
||||
tmpDir := t.TempDir()
|
||||
dbPath := filepath.Join(tmpDir, "test.db")
|
||||
|
||||
database, err := db.New(dbPath)
|
||||
if err != nil {
|
||||
t.Fatalf("创建数据库失败: %v", err)
|
||||
}
|
||||
|
||||
// 创建转发器管理器
|
||||
fwdMgr := forwarder.NewManager()
|
||||
|
||||
// 创建隧道服务器(如果启用)
|
||||
var tunnelServer *tunnel.Server
|
||||
if useTunnel {
|
||||
// 使用随机端口
|
||||
listener, _ := net.Listen("tcp", "127.0.0.1:0")
|
||||
port := listener.Addr().(*net.TCPAddr).Port
|
||||
listener.Close()
|
||||
|
||||
tunnelServer = tunnel.NewServer(port)
|
||||
tunnelServer.Start()
|
||||
}
|
||||
|
||||
handler := NewHandler(database, fwdMgr, tunnelServer, 10000, 20000)
|
||||
|
||||
cleanup := func() {
|
||||
fwdMgr.StopAll()
|
||||
if tunnelServer != nil {
|
||||
tunnelServer.Stop()
|
||||
}
|
||||
database.Close()
|
||||
os.RemoveAll(tmpDir)
|
||||
}
|
||||
|
||||
return handler, database, cleanup
|
||||
}
|
||||
|
||||
// TestNewHandler 测试创建处理器
|
||||
func TestNewHandler(t *testing.T) {
|
||||
handler, _, cleanup := setupTestHandler(t, false)
|
||||
defer cleanup()
|
||||
|
||||
if handler == nil {
|
||||
t.Fatal("创建处理器失败")
|
||||
}
|
||||
|
||||
if handler.portRangeFrom != 10000 {
|
||||
t.Errorf("起始端口不正确,期望 10000,得到 %d", handler.portRangeFrom)
|
||||
}
|
||||
|
||||
if handler.portRangeEnd != 20000 {
|
||||
t.Errorf("结束端口不正确,期望 20000,得到 %d", handler.portRangeEnd)
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleHealth 测试健康检查
|
||||
func TestHandleHealth(t *testing.T) {
|
||||
handler, _, cleanup := setupTestHandler(t, false)
|
||||
defer cleanup()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/health", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.handleHealth(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("状态码不正确,期望 200,得到 %d", w.Code)
|
||||
}
|
||||
|
||||
var result map[string]interface{}
|
||||
err := json.NewDecoder(w.Body).Decode(&result)
|
||||
if err != nil {
|
||||
t.Fatalf("解析响应失败: %v", err)
|
||||
}
|
||||
|
||||
if result["status"] != "ok" {
|
||||
t.Errorf("健康状态不正确,期望 ok,得到 %v", result["status"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleHealthWithTunnel 测试带隧道的健康检查
|
||||
func TestHandleHealthWithTunnel(t *testing.T) {
|
||||
handler, _, cleanup := setupTestHandler(t, true)
|
||||
defer cleanup()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/health", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.handleHealth(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("状态码不正确,期望 200,得到 %d", w.Code)
|
||||
}
|
||||
|
||||
var result map[string]interface{}
|
||||
json.NewDecoder(w.Body).Decode(&result)
|
||||
|
||||
if result["tunnel_enabled"] != true {
|
||||
t.Error("隧道应该启用")
|
||||
}
|
||||
|
||||
// 隧道未连接客户端时应该为 false
|
||||
if result["tunnel_connected"] != false {
|
||||
t.Error("隧道应该未连接")
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleCreateMapping 测试创建映射
|
||||
func TestHandleCreateMapping(t *testing.T) {
|
||||
handler, database, cleanup := setupTestHandler(t, false)
|
||||
defer cleanup()
|
||||
|
||||
reqBody := CreateMappingRequest{
|
||||
// Port: 15000,
|
||||
SourcePort: 15000,
|
||||
TargetPort: 15000,
|
||||
TargetIP: "192.168.1.100",
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(reqBody)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.handleCreateMapping(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("状态码不正确,期望 200,得到 %d", w.Code)
|
||||
}
|
||||
|
||||
var result Response
|
||||
json.NewDecoder(w.Body).Decode(&result)
|
||||
|
||||
if !result.Success {
|
||||
t.Errorf("创建映射失败: %s", result.Message)
|
||||
}
|
||||
|
||||
// 验证数据库中存在映射
|
||||
mapping, err := database.GetMapping(15000)
|
||||
if err != nil {
|
||||
t.Fatalf("获取映射失败: %v", err)
|
||||
}
|
||||
|
||||
if mapping == nil {
|
||||
t.Fatal("映射不存在")
|
||||
}
|
||||
|
||||
if mapping.TargetIP != "192.168.1.100" {
|
||||
t.Errorf("目标 IP 不正确,期望 192.168.1.100,得到 %s", mapping.TargetIP)
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleCreateMappingInvalidPort 测试创建映射时端口无效
|
||||
func TestHandleCreateMappingInvalidPort(t *testing.T) {
|
||||
handler, _, cleanup := setupTestHandler(t, false)
|
||||
defer cleanup()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
port int
|
||||
}{
|
||||
{"端口太小", 5000},
|
||||
{"端口太大", 25000},
|
||||
{"端口为0", 0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
reqBody := CreateMappingRequest{
|
||||
SourcePort: tt.port,
|
||||
TargetPort: tt.port,
|
||||
TargetIP: "192.168.1.100",
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(reqBody)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.handleCreateMapping(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("状态码不正确,期望 400,得到 %d", w.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleCreateMappingDuplicate 测试创建重复映射
|
||||
func TestHandleCreateMappingDuplicate(t *testing.T) {
|
||||
handler, _, cleanup := setupTestHandler(t, false)
|
||||
defer cleanup()
|
||||
|
||||
reqBody := CreateMappingRequest{
|
||||
// Port: 15000,
|
||||
SourcePort: 15000,
|
||||
TargetPort: 15000,
|
||||
TargetIP: "192.168.1.100",
|
||||
}
|
||||
|
||||
// 第一次创建
|
||||
body, _ := json.Marshal(reqBody)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
handler.handleCreateMapping(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("第一次创建失败")
|
||||
}
|
||||
|
||||
// 第二次创建(应该失败)
|
||||
body, _ = json.Marshal(reqBody)
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
|
||||
w = httptest.NewRecorder()
|
||||
handler.handleCreateMapping(w, req)
|
||||
|
||||
if w.Code != http.StatusConflict {
|
||||
t.Errorf("状态码不正确,期望 409,得到 %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleCreateMappingInvalidJSON 测试无效的 JSON
|
||||
func TestHandleCreateMappingInvalidJSON(t *testing.T) {
|
||||
handler, _, cleanup := setupTestHandler(t, false)
|
||||
defer cleanup()
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader([]byte("invalid json")))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.handleCreateMapping(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("状态码不正确,期望 400,得到 %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleCreateMappingInvalidIP 测试无效的 IP 地址
|
||||
func TestHandleCreateMappingInvalidIP(t *testing.T) {
|
||||
handler, _, cleanup := setupTestHandler(t, false)
|
||||
defer cleanup()
|
||||
|
||||
reqBody := CreateMappingRequest{
|
||||
// Port: 15000,
|
||||
SourcePort: 15000,
|
||||
TargetPort: 15000,
|
||||
TargetIP: "invalid-ip",
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(reqBody)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.handleCreateMapping(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("状态码不正确,期望 400,得到 %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleCreateMappingEmptyIP 测试空 IP 地址
|
||||
func TestHandleCreateMappingEmptyIP(t *testing.T) {
|
||||
handler, _, cleanup := setupTestHandler(t, false)
|
||||
defer cleanup()
|
||||
|
||||
reqBody := CreateMappingRequest{
|
||||
// Port: 15000,
|
||||
SourcePort: 15000,
|
||||
TargetPort: 15000,
|
||||
TargetIP: "",
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(reqBody)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.handleCreateMapping(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("状态码不正确,期望 400,得到 %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleCreateMappingTunnelNotConnected 测试隧道未连接时创建映射
|
||||
func TestHandleCreateMappingTunnelNotConnected(t *testing.T) {
|
||||
handler, _, cleanup := setupTestHandler(t, true)
|
||||
defer cleanup()
|
||||
|
||||
reqBody := CreateMappingRequest{
|
||||
// Port: 15000,
|
||||
SourcePort: 15000,
|
||||
TargetPort: 15000,
|
||||
UseTunnel: true, // 明确指定使用隧道模式
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(reqBody)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.handleCreateMapping(w, req)
|
||||
|
||||
if w.Code != http.StatusServiceUnavailable {
|
||||
t.Errorf("状态码不正确,期望 503,得到 %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleRemoveMapping 测试删除映射
|
||||
func TestHandleRemoveMapping(t *testing.T) {
|
||||
handler, database, cleanup := setupTestHandler(t, false)
|
||||
defer cleanup()
|
||||
|
||||
// 先创建一个映射
|
||||
database.AddMapping(15000, "192.168.1.100", 15000, false)
|
||||
handler.forwarderMgr.Add(15000, "192.168.1.100", 15000)
|
||||
|
||||
reqBody := RemoveMappingRequest{
|
||||
Port: 15000,
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(reqBody)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/mapping/remove", bytes.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.handleRemoveMapping(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("状态码不正确,期望 200,得到 %d", w.Code)
|
||||
}
|
||||
|
||||
// 验证映射已删除
|
||||
mapping, _ := database.GetMapping(15000)
|
||||
if mapping != nil {
|
||||
t.Error("映射应该已被删除")
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleRemoveMappingNotExist 测试删除不存在的映射
|
||||
func TestHandleRemoveMappingNotExist(t *testing.T) {
|
||||
handler, _, cleanup := setupTestHandler(t, false)
|
||||
defer cleanup()
|
||||
|
||||
reqBody := RemoveMappingRequest{
|
||||
Port: 15000,
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(reqBody)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/mapping/remove", bytes.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.handleRemoveMapping(w, req)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("状态码不正确,期望 404,得到 %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleListMappings 测试列出映射
|
||||
func TestHandleListMappings(t *testing.T) {
|
||||
handler, database, cleanup := setupTestHandler(t, false)
|
||||
defer cleanup()
|
||||
|
||||
// 添加一些映射
|
||||
database.AddMapping(15000, "192.168.1.100", 15000, false)
|
||||
database.AddMapping(15001, "192.168.1.101", 15001, true)
|
||||
database.AddMapping(15002, "192.168.1.102", 15002, false)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/mapping/list", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.handleListMappings(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("状态码不正确,期望 200,得到 %d", w.Code)
|
||||
}
|
||||
|
||||
var result Response
|
||||
json.NewDecoder(w.Body).Decode(&result)
|
||||
|
||||
if !result.Success {
|
||||
t.Errorf("列出映射失败: %s", result.Message)
|
||||
}
|
||||
|
||||
data := result.Data.(map[string]interface{})
|
||||
count := int(data["count"].(float64))
|
||||
|
||||
if count != 3 {
|
||||
t.Errorf("映射数量不正确,期望 3,得到 %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleListMappingsEmpty 测试列出空映射列表
|
||||
func TestHandleListMappingsEmpty(t *testing.T) {
|
||||
handler, _, cleanup := setupTestHandler(t, false)
|
||||
defer cleanup()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/mapping/list", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.handleListMappings(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("状态码不正确,期望 200,得到 %d", w.Code)
|
||||
}
|
||||
|
||||
var result Response
|
||||
json.NewDecoder(w.Body).Decode(&result)
|
||||
|
||||
data := result.Data.(map[string]interface{})
|
||||
count := int(data["count"].(float64))
|
||||
|
||||
if count != 0 {
|
||||
t.Errorf("映射数量不正确,期望 0,得到 %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleMethodNotAllowed 测试不允许的 HTTP 方法
|
||||
func TestHandleMethodNotAllowed(t *testing.T) {
|
||||
handler, _, cleanup := setupTestHandler(t, false)
|
||||
defer cleanup()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
handler func(http.ResponseWriter, *http.Request)
|
||||
method string
|
||||
}{
|
||||
{"创建映射 GET", handler.handleCreateMapping, http.MethodGet},
|
||||
{"删除映射 GET", handler.handleRemoveMapping, http.MethodGet},
|
||||
{"列出映射 POST", handler.handleListMappings, http.MethodPost},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(tt.method, "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
tt.handler(w, req)
|
||||
|
||||
if w.Code != http.StatusMethodNotAllowed {
|
||||
t.Errorf("状态码不正确,期望 405,得到 %d", w.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRegisterRoutes 测试路由注册
|
||||
func TestRegisterRoutes(t *testing.T) {
|
||||
handler, _, cleanup := setupTestHandler(t, false)
|
||||
defer cleanup()
|
||||
|
||||
mux := http.NewServeMux()
|
||||
handler.RegisterRoutes(mux)
|
||||
|
||||
// 测试路由是否注册
|
||||
routes := []string{
|
||||
"/api/mapping/create",
|
||||
"/api/mapping/remove",
|
||||
"/api/mapping/list",
|
||||
"/health",
|
||||
}
|
||||
|
||||
for _, route := range routes {
|
||||
req := httptest.NewRequest(http.MethodGet, route, nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
mux.ServeHTTP(w, req)
|
||||
|
||||
// 如果路由不存在,应该返回 404
|
||||
if w.Code == http.StatusNotFound {
|
||||
t.Errorf("路由 %s 未注册", route)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteSuccess 测试成功响应
|
||||
func TestWriteSuccess(t *testing.T) {
|
||||
handler, _, cleanup := setupTestHandler(t, false)
|
||||
defer cleanup()
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
handler.writeSuccess(w, "测试成功", map[string]string{"key": "value"})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("状态码不正确,期望 200,得到 %d", w.Code)
|
||||
}
|
||||
|
||||
var result Response
|
||||
json.NewDecoder(w.Body).Decode(&result)
|
||||
|
||||
if !result.Success {
|
||||
t.Error("Success 应该为 true")
|
||||
}
|
||||
|
||||
if result.Message != "测试成功" {
|
||||
t.Errorf("消息不正确,期望 '测试成功',得到 '%s'", result.Message)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteError 测试错误响应
|
||||
func TestWriteError(t *testing.T) {
|
||||
handler, _, cleanup := setupTestHandler(t, false)
|
||||
defer cleanup()
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
handler.writeError(w, http.StatusBadRequest, "测试错误")
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("状态码不正确,期望 400,得到 %d", w.Code)
|
||||
}
|
||||
|
||||
var result Response
|
||||
json.NewDecoder(w.Body).Decode(&result)
|
||||
|
||||
if result.Success {
|
||||
t.Error("Success 应该为 false")
|
||||
}
|
||||
|
||||
if result.Message != "测试错误" {
|
||||
t.Errorf("消息不正确,期望 '测试错误',得到 '%s'", result.Message)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkHandleHealth 基准测试健康检查
|
||||
func BenchmarkHandleHealth(b *testing.B) {
|
||||
tmpDir := b.TempDir()
|
||||
dbPath := filepath.Join(tmpDir, "bench.db")
|
||||
database, _ := db.New(dbPath)
|
||||
defer database.Close()
|
||||
|
||||
fwdMgr := forwarder.NewManager()
|
||||
handler := NewHandler(database, fwdMgr, nil, 10000, 20000)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/health", nil)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
w := httptest.NewRecorder()
|
||||
handler.handleHealth(w, req)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkHandleListMappings 基准测试列出映射
|
||||
func BenchmarkHandleListMappings(b *testing.B) {
|
||||
tmpDir := b.TempDir()
|
||||
dbPath := filepath.Join(tmpDir, "bench.db")
|
||||
database, _ := db.New(dbPath)
|
||||
defer database.Close()
|
||||
|
||||
// 添加一些映射
|
||||
for i := 0; i < 100; i++ {
|
||||
useTunnel := i%2 == 0 // 偶数使用隧道模式
|
||||
database.AddMapping(10000+i, "192.168.1.1", 10000+i, useTunnel)
|
||||
}
|
||||
|
||||
fwdMgr := forwarder.NewManager()
|
||||
handler := NewHandler(database, fwdMgr, nil, 10000, 20000)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/mapping/list", nil)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
w := httptest.NewRecorder()
|
||||
handler.handleListMappings(w, req)
|
||||
}
|
||||
}
|
||||
|
|
@ -57,6 +57,110 @@ database:
|
|||
}
|
||||
}
|
||||
|
||||
// TestLoadConfigFileNotFound 测试配置文件不存在
|
||||
func TestLoadConfigFileNotFound(t *testing.T) {
|
||||
_, err := Load("/nonexistent/config.yaml")
|
||||
if err == nil {
|
||||
t.Error("应该返回文件不存在错误")
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadConfigInvalidYAML 测试无效的 YAML 格式
|
||||
func TestLoadConfigInvalidYAML(t *testing.T) {
|
||||
tmpFile, err := os.CreateTemp("", "config_invalid_*.yaml")
|
||||
if err != nil {
|
||||
t.Fatalf("创建临时文件失败: %v", err)
|
||||
}
|
||||
defer os.Remove(tmpFile.Name())
|
||||
|
||||
// 写入无效的 YAML
|
||||
invalidYAML := `
|
||||
port_range:
|
||||
from: 10000
|
||||
end: invalid_number
|
||||
`
|
||||
tmpFile.Write([]byte(invalidYAML))
|
||||
tmpFile.Close()
|
||||
|
||||
_, err = Load(tmpFile.Name())
|
||||
if err == nil {
|
||||
t.Error("应该返回 YAML 解析错误")
|
||||
}
|
||||
}
|
||||
|
||||
// TestConfigEdgeCases 测试配置边界情况
|
||||
func TestConfigEdgeCases(t *testing.T) {
|
||||
tmpFile, err := os.CreateTemp("", "config_edge_*.yaml")
|
||||
if err != nil {
|
||||
t.Fatalf("创建临时文件失败: %v", err)
|
||||
}
|
||||
defer os.Remove(tmpFile.Name())
|
||||
|
||||
// 测试刚好 10000 个端口(边界值)
|
||||
configContent := `
|
||||
port_range:
|
||||
from: 1
|
||||
end: 10000
|
||||
tunnel:
|
||||
enabled: false
|
||||
api:
|
||||
listen_port: 8080
|
||||
database:
|
||||
path: ./data/mappings.db
|
||||
`
|
||||
tmpFile.Write([]byte(configContent))
|
||||
tmpFile.Close()
|
||||
|
||||
cfg, err := Load(tmpFile.Name())
|
||||
if err != nil {
|
||||
t.Errorf("边界值配置应该有效: %v", err)
|
||||
}
|
||||
|
||||
if cfg != nil && (cfg.PortRange.End-cfg.PortRange.From) != 9999 {
|
||||
t.Errorf("端口范围计算不正确")
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkLoadConfig 基准测试配置加载
|
||||
func BenchmarkLoadConfig(b *testing.B) {
|
||||
tmpFile, _ := os.CreateTemp("", "config_bench_*.yaml")
|
||||
defer os.Remove(tmpFile.Name())
|
||||
|
||||
configContent := `
|
||||
port_range:
|
||||
from: 10000
|
||||
end: 20000
|
||||
tunnel:
|
||||
enabled: true
|
||||
listen_port: 9000
|
||||
api:
|
||||
listen_port: 8080
|
||||
database:
|
||||
path: ./data/mappings.db
|
||||
`
|
||||
tmpFile.Write([]byte(configContent))
|
||||
tmpFile.Close()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = Load(tmpFile.Name())
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkValidateConfig 基准测试配置验证
|
||||
func BenchmarkValidateConfig(b *testing.B) {
|
||||
cfg := &Config{
|
||||
PortRange: PortRangeConfig{From: 10000, End: 20000},
|
||||
Tunnel: TunnelConfig{Enabled: true, ListenPort: 9000},
|
||||
API: APIConfig{ListenPort: 8080},
|
||||
Database: DatabaseConfig{Path: "./data/db.sqlite"},
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = cfg.Validate()
|
||||
}
|
||||
}
|
||||
func TestValidateConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ type Mapping struct {
|
|||
SourcePort int `json:"source_port"`
|
||||
TargetIP string `json:"target_ip"`
|
||||
TargetPort int `json:"target_port"`
|
||||
UseTunnel bool `json:"use_tunnel"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
}
|
||||
|
||||
|
|
@ -55,12 +56,14 @@ func New(dbPath string) (*Database, error) {
|
|||
|
||||
// initTables 初始化数据库表
|
||||
func (d *Database) initTables() error {
|
||||
// 创建表结构
|
||||
query := `
|
||||
CREATE TABLE IF NOT EXISTS mappings (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
source_port INTEGER NOT NULL UNIQUE,
|
||||
target_ip TEXT NOT NULL,
|
||||
target_port INTEGER NOT NULL,
|
||||
use_tunnel BOOLEAN NOT NULL DEFAULT 0,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_source_port ON mappings(source_port);
|
||||
|
|
@ -71,16 +74,59 @@ func (d *Database) initTables() error {
|
|||
return fmt.Errorf("初始化数据库表失败: %w", err)
|
||||
}
|
||||
|
||||
// 检查是否需要迁移现有数据
|
||||
if err := d.migrateDatabase(); err != nil {
|
||||
return fmt.Errorf("数据库迁移失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// migrateDatabase 迁移现有数据库结构
|
||||
func (d *Database) migrateDatabase() error {
|
||||
// 检查 use_tunnel 列是否存在
|
||||
rows, err := d.db.Query("PRAGMA table_info(mappings)")
|
||||
if err != nil {
|
||||
return fmt.Errorf("检查表结构失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
hasUseTunnel := false
|
||||
for rows.Next() {
|
||||
var cid int
|
||||
var name, dataType string
|
||||
var notNull, hasDefault int
|
||||
var defaultValue interface{}
|
||||
|
||||
err := rows.Scan(&cid, &name, &dataType, ¬Null, &defaultValue, &hasDefault)
|
||||
if err != nil {
|
||||
return fmt.Errorf("扫描表结构失败: %w", err)
|
||||
}
|
||||
|
||||
if name == "use_tunnel" {
|
||||
hasUseTunnel = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 如果不存在 use_tunnel 列,则添加它
|
||||
if !hasUseTunnel {
|
||||
_, err := d.db.Exec("ALTER TABLE mappings ADD COLUMN use_tunnel BOOLEAN NOT NULL DEFAULT 0")
|
||||
if err != nil {
|
||||
return fmt.Errorf("添加 use_tunnel 列失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddMapping 添加端口映射
|
||||
func (d *Database) AddMapping(sourcePort int, targetIP string, targetPort int) error {
|
||||
func (d *Database) AddMapping(sourcePort int, targetIP string, targetPort int, useTunnel bool) error {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
query := `INSERT INTO mappings (source_port, target_ip, target_port) VALUES (?, ?, ?)`
|
||||
_, err := d.db.Exec(query, sourcePort, targetIP, targetPort)
|
||||
query := `INSERT INTO mappings (source_port, target_ip, target_port, use_tunnel) VALUES (?, ?, ?, ?)`
|
||||
_, err := d.db.Exec(query, sourcePort, targetIP, targetPort, useTunnel)
|
||||
if err != nil {
|
||||
return fmt.Errorf("添加端口映射失败: %w", err)
|
||||
}
|
||||
|
|
@ -116,7 +162,7 @@ func (d *Database) GetMapping(sourcePort int) (*Mapping, error) {
|
|||
d.mu.RLock()
|
||||
defer d.mu.RUnlock()
|
||||
|
||||
query := `SELECT id, source_port, target_ip, target_port, created_at FROM mappings WHERE source_port = ?`
|
||||
query := `SELECT id, source_port, target_ip, target_port, use_tunnel, created_at FROM mappings WHERE source_port = ?`
|
||||
|
||||
var mapping Mapping
|
||||
err := d.db.QueryRow(query, sourcePort).Scan(
|
||||
|
|
@ -124,6 +170,7 @@ func (d *Database) GetMapping(sourcePort int) (*Mapping, error) {
|
|||
&mapping.SourcePort,
|
||||
&mapping.TargetIP,
|
||||
&mapping.TargetPort,
|
||||
&mapping.UseTunnel,
|
||||
&mapping.CreatedAt,
|
||||
)
|
||||
|
||||
|
|
@ -142,7 +189,7 @@ func (d *Database) GetAllMappings() ([]*Mapping, error) {
|
|||
d.mu.RLock()
|
||||
defer d.mu.RUnlock()
|
||||
|
||||
query := `SELECT id, source_port, target_ip, target_port, created_at FROM mappings ORDER BY source_port`
|
||||
query := `SELECT id, source_port, target_ip, target_port, use_tunnel, created_at FROM mappings ORDER BY source_port`
|
||||
|
||||
rows, err := d.db.Query(query)
|
||||
if err != nil {
|
||||
|
|
@ -158,6 +205,7 @@ func (d *Database) GetAllMappings() ([]*Mapping, error) {
|
|||
&mapping.SourcePort,
|
||||
&mapping.TargetIP,
|
||||
&mapping.TargetPort,
|
||||
&mapping.UseTunnel,
|
||||
&mapping.CreatedAt,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("扫描映射记录失败: %w", err)
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ func TestDatabase(t *testing.T) {
|
|||
defer db.Close()
|
||||
|
||||
t.Run("添加映射", func(t *testing.T) {
|
||||
err := db.AddMapping(10001, "192.168.1.100", 22)
|
||||
err := db.AddMapping(10001, "192.168.1.100", 22, false)
|
||||
if err != nil {
|
||||
t.Errorf("添加映射失败: %v", err)
|
||||
}
|
||||
|
|
@ -44,7 +44,7 @@ func TestDatabase(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("添加重复映射应该失败", func(t *testing.T) {
|
||||
err := db.AddMapping(10001, "192.168.1.101", 22)
|
||||
err := db.AddMapping(10001, "192.168.1.101", 22, true)
|
||||
if err == nil {
|
||||
t.Error("添加重复映射应该失败")
|
||||
}
|
||||
|
|
@ -52,8 +52,8 @@ func TestDatabase(t *testing.T) {
|
|||
|
||||
t.Run("获取所有映射", func(t *testing.T) {
|
||||
// 添加更多映射
|
||||
db.AddMapping(10002, "192.168.1.101", 22)
|
||||
db.AddMapping(10003, "192.168.1.102", 22)
|
||||
db.AddMapping(10002, "192.168.1.101", 22, true)
|
||||
db.AddMapping(10003, "192.168.1.102", 22, false)
|
||||
|
||||
mappings, err := db.GetAllMappings()
|
||||
if err != nil {
|
||||
|
|
@ -101,7 +101,8 @@ func TestDatabaseConcurrency(t *testing.T) {
|
|||
done := make(chan bool)
|
||||
for i := 0; i < 10; i++ {
|
||||
go func(port int) {
|
||||
err := db.AddMapping(10000+port, "192.168.1.100", port)
|
||||
useTunnel := port%2 == 0 // 偶数端口使用隧道模式
|
||||
err := db.AddMapping(10000+port, "192.168.1.100", port, useTunnel)
|
||||
if err != nil {
|
||||
t.Logf("添加映射失败 (端口 %d): %v", 10000+port, err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,17 +10,23 @@ import (
|
|||
"time"
|
||||
)
|
||||
|
||||
// TunnelServer 隧道服务器接口
|
||||
type TunnelServer interface {
|
||||
ForwardConnection(clientConn net.Conn, targetPort int) error
|
||||
IsConnected() bool
|
||||
}
|
||||
|
||||
// Forwarder 端口转发器
|
||||
type Forwarder struct {
|
||||
sourcePort int
|
||||
targetAddr string
|
||||
listener net.Listener
|
||||
cancel context.CancelFunc
|
||||
ctx context.Context
|
||||
wg sync.WaitGroup
|
||||
tunnelConn net.Conn
|
||||
useTunnel bool
|
||||
mu sync.RWMutex
|
||||
sourcePort int
|
||||
targetPort int
|
||||
targetAddr string
|
||||
listener net.Listener
|
||||
cancel context.CancelFunc
|
||||
ctx context.Context
|
||||
wg sync.WaitGroup
|
||||
tunnelServer TunnelServer
|
||||
useTunnel bool
|
||||
}
|
||||
|
||||
// NewForwarder 创建新的端口转发器
|
||||
|
|
@ -28,6 +34,7 @@ func NewForwarder(sourcePort int, targetIP string, targetPort int) *Forwarder {
|
|||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &Forwarder{
|
||||
sourcePort: sourcePort,
|
||||
targetPort: targetPort,
|
||||
targetAddr: fmt.Sprintf("%s:%d", targetIP, targetPort),
|
||||
cancel: cancel,
|
||||
ctx: ctx,
|
||||
|
|
@ -36,15 +43,15 @@ func NewForwarder(sourcePort int, targetIP string, targetPort int) *Forwarder {
|
|||
}
|
||||
|
||||
// NewTunnelForwarder 创建使用隧道的端口转发器
|
||||
func NewTunnelForwarder(sourcePort int, targetPort int, tunnelConn net.Conn) *Forwarder {
|
||||
func NewTunnelForwarder(sourcePort int, targetPort int, tunnelServer TunnelServer) *Forwarder {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &Forwarder{
|
||||
sourcePort: sourcePort,
|
||||
targetAddr: fmt.Sprintf("127.0.0.1:%d", targetPort),
|
||||
tunnelConn: tunnelConn,
|
||||
useTunnel: true,
|
||||
cancel: cancel,
|
||||
ctx: ctx,
|
||||
sourcePort: sourcePort,
|
||||
targetPort: targetPort,
|
||||
tunnelServer: tunnelServer,
|
||||
useTunnel: true,
|
||||
cancel: cancel,
|
||||
ctx: ctx,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -100,36 +107,37 @@ func (f *Forwarder) acceptLoop() {
|
|||
// handleConnection 处理单个连接
|
||||
func (f *Forwarder) handleConnection(clientConn net.Conn) {
|
||||
defer f.wg.Done()
|
||||
defer clientConn.Close()
|
||||
|
||||
var targetConn net.Conn
|
||||
var err error
|
||||
|
||||
if f.useTunnel {
|
||||
// 使用隧道连接
|
||||
f.mu.RLock()
|
||||
targetConn = f.tunnelConn
|
||||
f.mu.RUnlock()
|
||||
// 使用隧道转发
|
||||
if f.tunnelServer == nil || !f.tunnelServer.IsConnected() {
|
||||
log.Printf("隧道服务器不可用 (端口 %d)", f.sourcePort)
|
||||
clientConn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
if targetConn == nil {
|
||||
log.Printf("隧道连接不可用 (端口 %d)", f.sourcePort)
|
||||
return
|
||||
// 将连接转发到隧道,ForwardConnection 会处理连接关闭
|
||||
if err := f.tunnelServer.ForwardConnection(clientConn, f.targetPort); err != nil {
|
||||
log.Printf("隧道转发失败 (端口 %d -> %d): %v", f.sourcePort, f.targetPort, err)
|
||||
}
|
||||
} else {
|
||||
// 直接连接目标
|
||||
dialer := &net.Dialer{
|
||||
Timeout: 10 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}
|
||||
|
||||
targetConn, err = dialer.DialContext(f.ctx, "tcp", f.targetAddr)
|
||||
if err != nil {
|
||||
log.Printf("连接目标失败 (端口 %d -> %s): %v", f.sourcePort, f.targetAddr, err)
|
||||
return
|
||||
}
|
||||
defer targetConn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
// 直接连接目标
|
||||
defer clientConn.Close()
|
||||
|
||||
dialer := &net.Dialer{
|
||||
Timeout: 10 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}
|
||||
|
||||
targetConn, err := dialer.DialContext(f.ctx, "tcp", f.targetAddr)
|
||||
if err != nil {
|
||||
log.Printf("连接目标失败 (端口 %d -> %s): %v", f.sourcePort, f.targetAddr, err)
|
||||
return
|
||||
}
|
||||
defer targetConn.Close()
|
||||
|
||||
// 双向转发
|
||||
errChan := make(chan error, 2)
|
||||
|
||||
|
|
@ -181,13 +189,6 @@ func (f *Forwarder) Stop() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// SetTunnelConn 设置隧道连接
|
||||
func (f *Forwarder) SetTunnelConn(conn net.Conn) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.tunnelConn = conn
|
||||
}
|
||||
|
||||
// Manager 转发器管理器
|
||||
type Manager struct {
|
||||
forwarders map[int]*Forwarder
|
||||
|
|
@ -220,7 +221,7 @@ func (m *Manager) Add(sourcePort int, targetIP string, targetPort int) error {
|
|||
}
|
||||
|
||||
// AddTunnel 添加使用隧道的转发器
|
||||
func (m *Manager) AddTunnel(sourcePort int, targetPort int, tunnelConn net.Conn) error {
|
||||
func (m *Manager) AddTunnel(sourcePort int, targetPort int, tunnelServer TunnelServer) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
|
|
@ -228,7 +229,7 @@ func (m *Manager) AddTunnel(sourcePort int, targetPort int, tunnelConn net.Conn)
|
|||
return fmt.Errorf("端口 %d 已被占用", sourcePort)
|
||||
}
|
||||
|
||||
forwarder := NewTunnelForwarder(sourcePort, targetPort, tunnelConn)
|
||||
forwarder := NewTunnelForwarder(sourcePort, targetPort, tunnelServer)
|
||||
if err := forwarder.Start(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,432 @@
|
|||
package forwarder
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestNewForwarder 测试创建转发器
|
||||
func TestNewForwarder(t *testing.T) {
|
||||
fwd := NewForwarder(8080, "192.168.1.100", 80)
|
||||
|
||||
if fwd == nil {
|
||||
t.Fatal("创建转发器失败")
|
||||
}
|
||||
|
||||
if fwd.sourcePort != 8080 {
|
||||
t.Errorf("源端口不正确,期望 8080,得到 %d", fwd.sourcePort)
|
||||
}
|
||||
|
||||
if fwd.targetAddr != "192.168.1.100:80" {
|
||||
t.Errorf("目标地址不正确,期望 192.168.1.100:80,得到 %s", fwd.targetAddr)
|
||||
}
|
||||
|
||||
if fwd.useTunnel {
|
||||
t.Error("普通转发器不应使用隧道")
|
||||
}
|
||||
}
|
||||
|
||||
// mockTunnelServer 模拟隧道服务器
|
||||
type mockTunnelServer struct {
|
||||
connected bool
|
||||
}
|
||||
|
||||
func (m *mockTunnelServer) ForwardConnection(clientConn net.Conn, targetPort int) error {
|
||||
// 简单的模拟实现
|
||||
defer clientConn.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockTunnelServer) IsConnected() bool {
|
||||
return m.connected
|
||||
}
|
||||
|
||||
// TestNewTunnelForwarder 测试创建隧道转发器
|
||||
func TestNewTunnelForwarder(t *testing.T) {
|
||||
// 创建模拟隧道服务器
|
||||
mockServer := &mockTunnelServer{connected: true}
|
||||
|
||||
fwd := NewTunnelForwarder(8080, 80, mockServer)
|
||||
|
||||
if fwd == nil {
|
||||
t.Fatal("创建隧道转发器失败")
|
||||
}
|
||||
|
||||
if !fwd.useTunnel {
|
||||
t.Error("隧道转发器应使用隧道")
|
||||
}
|
||||
|
||||
if fwd.tunnelServer == nil {
|
||||
t.Error("隧道服务器未设置")
|
||||
}
|
||||
}
|
||||
|
||||
// TestForwarderStartStop 测试转发器启动和停止
|
||||
func TestForwarderStartStop(t *testing.T) {
|
||||
// 创建模拟目标服务器
|
||||
targetListener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("创建目标服务器失败: %v", err)
|
||||
}
|
||||
defer targetListener.Close()
|
||||
|
||||
targetPort := targetListener.Addr().(*net.TCPAddr).Port
|
||||
|
||||
// 启动转发器到一个随机端口
|
||||
fwd := NewForwarder(0, "127.0.0.1", targetPort)
|
||||
|
||||
// 创建监听器
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("创建监听器失败: %v", err)
|
||||
}
|
||||
|
||||
fwd.listener = listener
|
||||
fwd.sourcePort = listener.Addr().(*net.TCPAddr).Port
|
||||
|
||||
// 启动接受循环
|
||||
fwd.wg.Add(1)
|
||||
go fwd.acceptLoop()
|
||||
|
||||
// 等待一段时间
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 停止转发器
|
||||
err = fwd.Stop()
|
||||
if err != nil {
|
||||
t.Errorf("停止转发器失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestForwarderConnection 测试转发器连接处理
|
||||
func TestForwarderConnection(t *testing.T) {
|
||||
// 创建模拟目标服务器
|
||||
targetListener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("创建目标服务器失败: %v", err)
|
||||
}
|
||||
defer targetListener.Close()
|
||||
|
||||
targetPort := targetListener.Addr().(*net.TCPAddr).Port
|
||||
|
||||
// 在后台处理连接
|
||||
go func() {
|
||||
conn, err := targetListener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// 回显服务器
|
||||
io.Copy(conn, conn)
|
||||
}()
|
||||
|
||||
// 创建并启动转发器
|
||||
fwd := NewForwarder(0, "127.0.0.1", targetPort)
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("创建监听器失败: %v", err)
|
||||
}
|
||||
fwd.listener = listener
|
||||
fwd.sourcePort = listener.Addr().(*net.TCPAddr).Port
|
||||
|
||||
fwd.wg.Add(1)
|
||||
go fwd.acceptLoop()
|
||||
defer fwd.Stop()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 连接到转发器
|
||||
client, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", fwd.sourcePort))
|
||||
if err != nil {
|
||||
t.Fatalf("连接转发器失败: %v", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
// 发送数据
|
||||
testData := []byte("Hello, World!")
|
||||
_, err = client.Write(testData)
|
||||
if err != nil {
|
||||
t.Fatalf("发送数据失败: %v", err)
|
||||
}
|
||||
|
||||
// 读取响应
|
||||
buf := make([]byte, len(testData))
|
||||
client.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||
n, err := io.ReadFull(client, buf)
|
||||
if err != nil {
|
||||
t.Fatalf("读取响应失败: %v", err)
|
||||
}
|
||||
|
||||
if n != len(testData) {
|
||||
t.Errorf("读取数据长度不正确,期望 %d,得到 %d", len(testData), n)
|
||||
}
|
||||
|
||||
if string(buf) != string(testData) {
|
||||
t.Errorf("数据不匹配,期望 %s,得到 %s", testData, buf)
|
||||
}
|
||||
}
|
||||
|
||||
// TestManagerAdd 测试管理器添加转发器
|
||||
func TestManagerAdd(t *testing.T) {
|
||||
mgr := NewManager()
|
||||
|
||||
// 创建模拟目标服务器
|
||||
targetListener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("创建目标服务器失败: %v", err)
|
||||
}
|
||||
defer targetListener.Close()
|
||||
|
||||
targetPort := targetListener.Addr().(*net.TCPAddr).Port
|
||||
|
||||
// 添加转发器到一个随机可用端口
|
||||
fwdListener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("创建转发监听器失败: %v", err)
|
||||
}
|
||||
sourcePort := fwdListener.Addr().(*net.TCPAddr).Port
|
||||
fwdListener.Close() // 关闭以便转发器可以使用这个端口
|
||||
|
||||
err = mgr.Add(sourcePort, "127.0.0.1", targetPort)
|
||||
if err != nil {
|
||||
t.Fatalf("添加转发器失败: %v", err)
|
||||
}
|
||||
|
||||
// 验证转发器已添加
|
||||
if !mgr.Exists(sourcePort) {
|
||||
t.Error("转发器应该存在")
|
||||
}
|
||||
|
||||
// 清理
|
||||
mgr.Remove(sourcePort)
|
||||
}
|
||||
|
||||
// TestManagerAddDuplicate 测试添加重复转发器
|
||||
func TestManagerAddDuplicate(t *testing.T) {
|
||||
mgr := NewManager()
|
||||
|
||||
// 获取一个随机端口
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("创建监听器失败: %v", err)
|
||||
}
|
||||
sourcePort := listener.Addr().(*net.TCPAddr).Port
|
||||
listener.Close()
|
||||
|
||||
// 添加第一个转发器
|
||||
err = mgr.Add(sourcePort, "127.0.0.1", 80)
|
||||
if err != nil {
|
||||
t.Fatalf("添加第一个转发器失败: %v", err)
|
||||
}
|
||||
defer mgr.Remove(sourcePort)
|
||||
|
||||
// 尝试添加重复端口
|
||||
err = mgr.Add(sourcePort, "127.0.0.1", 81)
|
||||
if err == nil {
|
||||
t.Error("应该返回端口已占用错误")
|
||||
}
|
||||
}
|
||||
|
||||
// TestManagerRemove 测试移除转发器
|
||||
func TestManagerRemove(t *testing.T) {
|
||||
mgr := NewManager()
|
||||
|
||||
// 获取一个随机端口
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("创建监听器失败: %v", err)
|
||||
}
|
||||
sourcePort := listener.Addr().(*net.TCPAddr).Port
|
||||
listener.Close()
|
||||
|
||||
// 添加转发器
|
||||
err = mgr.Add(sourcePort, "127.0.0.1", 80)
|
||||
if err != nil {
|
||||
t.Fatalf("添加转发器失败: %v", err)
|
||||
}
|
||||
|
||||
// 移除转发器
|
||||
err = mgr.Remove(sourcePort)
|
||||
if err != nil {
|
||||
t.Errorf("移除转发器失败: %v", err)
|
||||
}
|
||||
|
||||
// 验证转发器已移除
|
||||
if mgr.Exists(sourcePort) {
|
||||
t.Error("转发器应该已被移除")
|
||||
}
|
||||
}
|
||||
|
||||
// TestManagerRemoveNonExistent 测试移除不存在的转发器
|
||||
func TestManagerRemoveNonExistent(t *testing.T) {
|
||||
mgr := NewManager()
|
||||
|
||||
err := mgr.Remove(9999)
|
||||
if err == nil {
|
||||
t.Error("应该返回转发器不存在错误")
|
||||
}
|
||||
}
|
||||
|
||||
// TestManagerExists 测试检查转发器是否存在
|
||||
func TestManagerExists(t *testing.T) {
|
||||
mgr := NewManager()
|
||||
|
||||
// 检查不存在的转发器
|
||||
if mgr.Exists(8080) {
|
||||
t.Error("转发器不应该存在")
|
||||
}
|
||||
|
||||
// 获取一个随机端口
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("创建监听器失败: %v", err)
|
||||
}
|
||||
sourcePort := listener.Addr().(*net.TCPAddr).Port
|
||||
listener.Close()
|
||||
|
||||
// 添加转发器
|
||||
err = mgr.Add(sourcePort, "127.0.0.1", 80)
|
||||
if err != nil {
|
||||
t.Fatalf("添加转发器失败: %v", err)
|
||||
}
|
||||
defer mgr.Remove(sourcePort)
|
||||
|
||||
// 检查存在的转发器
|
||||
if !mgr.Exists(sourcePort) {
|
||||
t.Error("转发器应该存在")
|
||||
}
|
||||
}
|
||||
|
||||
// TestManagerStopAll 测试停止所有转发器
|
||||
func TestManagerStopAll(t *testing.T) {
|
||||
mgr := NewManager()
|
||||
|
||||
// 添加多个转发器
|
||||
ports := make([]int, 0)
|
||||
for i := 0; i < 3; i++ {
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("创建监听器失败: %v", err)
|
||||
}
|
||||
port := listener.Addr().(*net.TCPAddr).Port
|
||||
listener.Close()
|
||||
|
||||
err = mgr.Add(port, "127.0.0.1", 80+i)
|
||||
if err != nil {
|
||||
t.Fatalf("添加转发器 %d 失败: %v", i, err)
|
||||
}
|
||||
ports = append(ports, port)
|
||||
}
|
||||
|
||||
// 停止所有转发器
|
||||
mgr.StopAll()
|
||||
|
||||
// 验证所有转发器已停止
|
||||
for _, port := range ports {
|
||||
if mgr.Exists(port) {
|
||||
t.Errorf("端口 %d 的转发器应该已停止", port)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestForwarderContextCancellation 测试上下文取消
|
||||
func TestForwarderContextCancellation(t *testing.T) {
|
||||
fwd := NewForwarder(0, "127.0.0.1", 80)
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("创建监听器失败: %v", err)
|
||||
}
|
||||
fwd.listener = listener
|
||||
|
||||
fwd.wg.Add(1)
|
||||
go fwd.acceptLoop()
|
||||
|
||||
// 取消上下文
|
||||
fwd.cancel()
|
||||
|
||||
// 等待 goroutine 退出
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
fwd.wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// 成功退出
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Error("上下文取消后 goroutine 未退出")
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkForwarderConnection 基准测试转发器连接
|
||||
func BenchmarkForwarderConnection(b *testing.B) {
|
||||
// 创建模拟目标服务器
|
||||
targetListener, _ := net.Listen("tcp", "127.0.0.1:0")
|
||||
defer targetListener.Close()
|
||||
|
||||
targetPort := targetListener.Addr().(*net.TCPAddr).Port
|
||||
|
||||
// 后台处理连接
|
||||
go func() {
|
||||
for {
|
||||
conn, err := targetListener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go func(c net.Conn) {
|
||||
defer c.Close()
|
||||
io.Copy(io.Discard, c)
|
||||
}(conn)
|
||||
}
|
||||
}()
|
||||
|
||||
// 创建转发器
|
||||
fwd := NewForwarder(0, "127.0.0.1", targetPort)
|
||||
listener, _ := net.Listen("tcp", "127.0.0.1:0")
|
||||
fwd.listener = listener
|
||||
fwd.sourcePort = listener.Addr().(*net.TCPAddr).Port
|
||||
|
||||
fwd.wg.Add(1)
|
||||
go fwd.acceptLoop()
|
||||
defer fwd.Stop()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", fwd.sourcePort))
|
||||
if err != nil {
|
||||
b.Fatalf("连接失败: %v", err)
|
||||
}
|
||||
conn.Write([]byte("test"))
|
||||
conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkManagerOperations 基准测试管理器操作
|
||||
func BenchmarkManagerOperations(b *testing.B) {
|
||||
mgr := NewManager()
|
||||
|
||||
b.Run("Add", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
listener, _ := net.Listen("tcp", "127.0.0.1:0")
|
||||
port := listener.Addr().(*net.TCPAddr).Port
|
||||
listener.Close()
|
||||
|
||||
mgr.Add(port, "127.0.0.1", 80)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("Exists", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
mgr.Exists(8080)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
@ -64,8 +64,22 @@ func main() {
|
|||
continue
|
||||
}
|
||||
|
||||
log.Printf("恢复端口映射: %d -> %s:%d", mapping.SourcePort, mapping.TargetIP, mapping.TargetPort)
|
||||
if err := fwdManager.Add(mapping.SourcePort, mapping.TargetIP, mapping.TargetPort); err != nil {
|
||||
log.Printf("恢复端口映射: %d -> %s:%d (tunnel: %v)", mapping.SourcePort, mapping.TargetIP, mapping.TargetPort, mapping.UseTunnel)
|
||||
|
||||
var err error
|
||||
if mapping.UseTunnel {
|
||||
// 隧道模式:检查隧道服务器是否可用
|
||||
if !cfg.Tunnel.Enabled || tunnelServer == nil {
|
||||
log.Printf("警告: 端口 %d 需要隧道模式但隧道服务未启用,跳过", mapping.SourcePort)
|
||||
continue
|
||||
}
|
||||
err = fwdManager.AddTunnel(mapping.SourcePort, mapping.TargetPort, tunnelServer)
|
||||
} else {
|
||||
// 直接模式
|
||||
err = fwdManager.Add(mapping.SourcePort, mapping.TargetIP, mapping.TargetPort)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Printf("警告: 启动端口 %d 的转发失败: %v", mapping.SourcePort, err)
|
||||
}
|
||||
}
|
||||
|
|
@ -80,7 +94,6 @@ func main() {
|
|||
tunnelServer,
|
||||
cfg.PortRange.From,
|
||||
cfg.PortRange.End,
|
||||
cfg.Tunnel.Enabled,
|
||||
)
|
||||
|
||||
// 启动 HTTP API 服务器
|
||||
|
|
@ -107,7 +120,7 @@ func main() {
|
|||
log.Println("\n接收到关闭信号,正在优雅关闭...")
|
||||
|
||||
// 创建关闭上下文
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// 停止所有转发器
|
||||
|
|
|
|||
|
|
@ -11,50 +11,113 @@ import (
|
|||
"time"
|
||||
)
|
||||
|
||||
// Protocol 定义隧道协议
|
||||
// 消息格式: [4字节长度][4字节端口][数据]
|
||||
// Tunnel协议定义
|
||||
// 消息格式: | 版本(1B) | 类型(1B) | 长度(4B) | 数据 |
|
||||
|
||||
const (
|
||||
// HeaderSize 消息头大小(长度+端口)
|
||||
HeaderSize = 8
|
||||
// MaxPacketSize 最大包大小 (1MB)
|
||||
// 协议版本
|
||||
ProtocolVersion = 0x01
|
||||
|
||||
// 消息头大小
|
||||
HeaderSize = 6 // 版本(1) + 类型(1) + 长度(4)
|
||||
|
||||
// 最大包大小 (1MB)
|
||||
MaxPacketSize = 1024 * 1024
|
||||
|
||||
// 消息类型
|
||||
MsgTypeConnectRequest = 0x01 // 连接请求
|
||||
MsgTypeConnectResponse = 0x02 // 连接响应
|
||||
MsgTypeData = 0x03 // 数据传输
|
||||
MsgTypeClose = 0x04 // 关闭连接
|
||||
MsgTypeKeepAlive = 0x05 // 心跳
|
||||
|
||||
// 连接响应状态
|
||||
ConnStatusSuccess = 0x00 // 连接成功
|
||||
ConnStatusFailed = 0x01 // 连接失败
|
||||
|
||||
// 超时设置
|
||||
ConnectTimeout = 10 * time.Second // 连接超时
|
||||
ReadTimeout = 30 * time.Second // 读取超时
|
||||
)
|
||||
|
||||
// TunnelMessage 隧道消息
|
||||
type TunnelMessage struct {
|
||||
Version byte
|
||||
Type byte
|
||||
Length uint32
|
||||
Data []byte
|
||||
}
|
||||
|
||||
// ConnectRequestData 连接请求数据
|
||||
type ConnectRequestData struct {
|
||||
ConnID uint32 // 连接ID
|
||||
TargetPort uint16 // 目标端口
|
||||
}
|
||||
|
||||
// ConnectResponseData 连接响应数据
|
||||
type ConnectResponseData struct {
|
||||
ConnID uint32 // 连接ID
|
||||
Status byte // 状态码
|
||||
}
|
||||
|
||||
// DataMessage 数据消息
|
||||
type DataMessage struct {
|
||||
ConnID uint32 // 连接ID
|
||||
Data []byte // 数据
|
||||
}
|
||||
|
||||
// CloseMessage 关闭消息
|
||||
type CloseMessage struct {
|
||||
ConnID uint32 // 连接ID
|
||||
}
|
||||
|
||||
// PendingConnection 待处理连接
|
||||
type PendingConnection struct {
|
||||
ID uint32
|
||||
ClientConn net.Conn
|
||||
TargetPort int
|
||||
Created time.Time
|
||||
ResponseChan chan bool // 用于接收连接响应
|
||||
}
|
||||
|
||||
// ActiveConnection 活跃连接
|
||||
type ActiveConnection struct {
|
||||
ID uint32
|
||||
ClientConn net.Conn
|
||||
TargetPort int
|
||||
Created time.Time
|
||||
}
|
||||
|
||||
// Server 内网穿透服务器
|
||||
type Server struct {
|
||||
listenPort int
|
||||
listener net.Listener
|
||||
client net.Conn
|
||||
tunnelConn net.Conn
|
||||
cancel context.CancelFunc
|
||||
ctx context.Context
|
||||
wg sync.WaitGroup
|
||||
mu sync.RWMutex
|
||||
|
||||
// 连接管理
|
||||
connections map[uint32]*Connection
|
||||
connMu sync.RWMutex
|
||||
nextConnID uint32
|
||||
}
|
||||
|
||||
// Connection 表示一个客户端连接
|
||||
type Connection struct {
|
||||
ID uint32
|
||||
TargetPort int
|
||||
ClientConn net.Conn
|
||||
readChan chan []byte
|
||||
writeChan chan []byte
|
||||
closeChan chan struct{}
|
||||
pendingConns map[uint32]*PendingConnection // 待确认连接
|
||||
activeConns map[uint32]*ActiveConnection // 活跃连接
|
||||
connMu sync.RWMutex
|
||||
nextConnID uint32
|
||||
|
||||
// 消息队列
|
||||
sendChan chan *TunnelMessage
|
||||
}
|
||||
|
||||
// NewServer 创建新的隧道服务器
|
||||
func NewServer(listenPort int) *Server {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &Server{
|
||||
listenPort: listenPort,
|
||||
cancel: cancel,
|
||||
ctx: ctx,
|
||||
connections: make(map[uint32]*Connection),
|
||||
listenPort: listenPort,
|
||||
cancel: cancel,
|
||||
ctx: ctx,
|
||||
pendingConns: make(map[uint32]*PendingConnection),
|
||||
activeConns: make(map[uint32]*ActiveConnection),
|
||||
sendChan: make(chan *TunnelMessage, 1000),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -102,42 +165,47 @@ func (s *Server) acceptLoop() {
|
|||
|
||||
// 只允许一个客户端连接
|
||||
s.mu.Lock()
|
||||
if s.client != nil {
|
||||
if s.tunnelConn != nil {
|
||||
log.Printf("拒绝额外的隧道连接: %s", conn.RemoteAddr())
|
||||
conn.Close()
|
||||
s.mu.Unlock()
|
||||
continue
|
||||
}
|
||||
s.client = conn
|
||||
s.tunnelConn = conn
|
||||
s.mu.Unlock()
|
||||
|
||||
log.Printf("隧道客户端已连接: %s", conn.RemoteAddr())
|
||||
|
||||
s.wg.Add(1)
|
||||
go s.handleClient(conn)
|
||||
s.wg.Add(2)
|
||||
go s.handleTunnelRead(conn)
|
||||
go s.handleTunnelWrite(conn)
|
||||
}
|
||||
}
|
||||
|
||||
// handleClient 处理客户端连接
|
||||
func (s *Server) handleClient(conn net.Conn) {
|
||||
// handleTunnelRead 处理隧道读取
|
||||
func (s *Server) handleTunnelRead(conn net.Conn) {
|
||||
defer s.wg.Done()
|
||||
defer func() {
|
||||
conn.Close()
|
||||
s.mu.Lock()
|
||||
s.client = nil
|
||||
s.tunnelConn = nil
|
||||
s.mu.Unlock()
|
||||
log.Printf("隧道客户端已断开")
|
||||
|
||||
// 关闭所有活动连接
|
||||
s.connMu.Lock()
|
||||
for _, c := range s.connections {
|
||||
close(c.closeChan)
|
||||
for _, c := range s.pendingConns {
|
||||
c.ClientConn.Close()
|
||||
close(c.ResponseChan)
|
||||
}
|
||||
s.connections = make(map[uint32]*Connection)
|
||||
for _, c := range s.activeConns {
|
||||
c.ClientConn.Close()
|
||||
}
|
||||
s.pendingConns = make(map[uint32]*PendingConnection)
|
||||
s.activeConns = make(map[uint32]*ActiveConnection)
|
||||
s.connMu.Unlock()
|
||||
}()
|
||||
|
||||
// 读取来自客户端的数据
|
||||
for {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
|
|
@ -145,147 +213,374 @@ func (s *Server) handleClient(conn net.Conn) {
|
|||
default:
|
||||
}
|
||||
|
||||
// 读取消息头
|
||||
header := make([]byte, HeaderSize)
|
||||
if _, err := io.ReadFull(conn, header); err != nil {
|
||||
msg, err := s.readTunnelMessage(conn)
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
log.Printf("读取隧道消息头失败: %v", err)
|
||||
log.Printf("读取隧道消息失败: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
dataLen := binary.BigEndian.Uint32(header[0:4])
|
||||
connID := binary.BigEndian.Uint32(header[4:8])
|
||||
s.handleTunnelMessage(msg)
|
||||
}
|
||||
}
|
||||
|
||||
if dataLen > MaxPacketSize {
|
||||
log.Printf("数据包过大: %d bytes", dataLen)
|
||||
// handleTunnelWrite 处理隧道写入
|
||||
func (s *Server) handleTunnelWrite(conn net.Conn) {
|
||||
defer s.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
}
|
||||
|
||||
// 读取数据
|
||||
data := make([]byte, dataLen)
|
||||
if _, err := io.ReadFull(conn, data); err != nil {
|
||||
log.Printf("读取隧道数据失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 将数据发送到对应的连接
|
||||
s.connMu.RLock()
|
||||
connection, exists := s.connections[connID]
|
||||
s.connMu.RUnlock()
|
||||
|
||||
if exists {
|
||||
select {
|
||||
case connection.readChan <- data:
|
||||
case <-connection.closeChan:
|
||||
case <-time.After(5 * time.Second):
|
||||
log.Printf("向连接 %d 发送数据超时", connID)
|
||||
case msg := <-s.sendChan:
|
||||
if err := s.writeTunnelMessage(conn, msg); err != nil {
|
||||
log.Printf("写入隧道消息失败: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ForwardConnection 转发连接到隧道
|
||||
// readTunnelMessage 读取隧道消息
|
||||
func (s *Server) readTunnelMessage(conn net.Conn) (*TunnelMessage, error) {
|
||||
// 读取消息头
|
||||
header := make([]byte, HeaderSize)
|
||||
if _, err := io.ReadFull(conn, header); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
version := header[0]
|
||||
msgType := header[1]
|
||||
dataLen := binary.BigEndian.Uint32(header[2:6])
|
||||
|
||||
if version != ProtocolVersion {
|
||||
return nil, fmt.Errorf("不支持的协议版本: %d", version)
|
||||
}
|
||||
|
||||
if dataLen > MaxPacketSize {
|
||||
return nil, fmt.Errorf("数据包过大: %d bytes", dataLen)
|
||||
}
|
||||
|
||||
// 读取数据
|
||||
var data []byte
|
||||
if dataLen > 0 {
|
||||
data = make([]byte, dataLen)
|
||||
if _, err := io.ReadFull(conn, data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &TunnelMessage{
|
||||
Version: version,
|
||||
Type: msgType,
|
||||
Length: dataLen,
|
||||
Data: data,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// writeTunnelMessage 写入隧道消息
|
||||
func (s *Server) writeTunnelMessage(conn net.Conn, msg *TunnelMessage) error {
|
||||
// 构建消息头
|
||||
header := make([]byte, HeaderSize)
|
||||
header[0] = msg.Version
|
||||
header[1] = msg.Type
|
||||
binary.BigEndian.PutUint32(header[2:6], msg.Length)
|
||||
|
||||
// 写入消息头
|
||||
if _, err := conn.Write(header); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 写入数据
|
||||
if msg.Length > 0 && msg.Data != nil {
|
||||
if _, err := conn.Write(msg.Data); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleTunnelMessage 处理隧道消息
|
||||
func (s *Server) handleTunnelMessage(msg *TunnelMessage) {
|
||||
switch msg.Type {
|
||||
case MsgTypeConnectResponse:
|
||||
s.handleConnectResponse(msg)
|
||||
case MsgTypeData:
|
||||
s.handleDataMessage(msg)
|
||||
case MsgTypeClose:
|
||||
s.handleCloseMessage(msg)
|
||||
case MsgTypeKeepAlive:
|
||||
s.handleKeepAlive(msg)
|
||||
default:
|
||||
log.Printf("未知消息类型: %d", msg.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// handleConnectResponse 处理连接响应
|
||||
func (s *Server) handleConnectResponse(msg *TunnelMessage) {
|
||||
if len(msg.Data) < 5 {
|
||||
log.Printf("连接响应数据太短")
|
||||
return
|
||||
}
|
||||
|
||||
connID := binary.BigEndian.Uint32(msg.Data[0:4])
|
||||
status := msg.Data[4]
|
||||
|
||||
s.connMu.Lock()
|
||||
pending, exists := s.pendingConns[connID]
|
||||
if !exists {
|
||||
s.connMu.Unlock()
|
||||
log.Printf("收到未知连接的响应: %d", connID)
|
||||
return
|
||||
}
|
||||
|
||||
delete(s.pendingConns, connID)
|
||||
s.connMu.Unlock()
|
||||
|
||||
if status == ConnStatusSuccess {
|
||||
// 连接成功,移到活跃连接
|
||||
active := &ActiveConnection{
|
||||
ID: connID,
|
||||
ClientConn: pending.ClientConn,
|
||||
TargetPort: pending.TargetPort,
|
||||
Created: time.Now(),
|
||||
}
|
||||
|
||||
s.connMu.Lock()
|
||||
s.activeConns[connID] = active
|
||||
s.connMu.Unlock()
|
||||
|
||||
log.Printf("连接已建立: ID=%d, 端口=%d", connID, pending.TargetPort)
|
||||
|
||||
// 启动数据转发
|
||||
s.wg.Add(1)
|
||||
go s.forwardData(active)
|
||||
|
||||
// 通知等待的goroutine
|
||||
select {
|
||||
case pending.ResponseChan <- true:
|
||||
default:
|
||||
}
|
||||
} else {
|
||||
// 连接失败
|
||||
log.Printf("连接失败: ID=%d, 状态=%d", connID, status)
|
||||
pending.ClientConn.Close()
|
||||
|
||||
// 通知等待的goroutine
|
||||
select {
|
||||
case pending.ResponseChan <- false:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
close(pending.ResponseChan)
|
||||
}
|
||||
|
||||
// handleDataMessage 处理数据消息
|
||||
func (s *Server) handleDataMessage(msg *TunnelMessage) {
|
||||
if len(msg.Data) < 4 {
|
||||
log.Printf("数据消息太短")
|
||||
return
|
||||
}
|
||||
|
||||
connID := binary.BigEndian.Uint32(msg.Data[0:4])
|
||||
data := msg.Data[4:]
|
||||
|
||||
s.connMu.RLock()
|
||||
active, exists := s.activeConns[connID]
|
||||
s.connMu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
log.Printf("收到未知连接的数据: %d", connID)
|
||||
return
|
||||
}
|
||||
|
||||
// 写入到客户端连接
|
||||
if _, err := active.ClientConn.Write(data); err != nil {
|
||||
log.Printf("写入客户端连接失败 (ID=%d): %v", connID, err)
|
||||
s.closeConnection(connID)
|
||||
}
|
||||
}
|
||||
|
||||
// handleCloseMessage 处理关闭消息
|
||||
func (s *Server) handleCloseMessage(msg *TunnelMessage) {
|
||||
if len(msg.Data) < 4 {
|
||||
log.Printf("关闭消息数据太短")
|
||||
return
|
||||
}
|
||||
|
||||
connID := binary.BigEndian.Uint32(msg.Data[0:4])
|
||||
s.closeConnection(connID)
|
||||
}
|
||||
|
||||
// handleKeepAlive 处理心跳消息
|
||||
func (s *Server) handleKeepAlive(msg *TunnelMessage) {
|
||||
// 回应心跳
|
||||
response := &TunnelMessage{
|
||||
Version: ProtocolVersion,
|
||||
Type: MsgTypeKeepAlive,
|
||||
Length: 0,
|
||||
Data: nil,
|
||||
}
|
||||
|
||||
select {
|
||||
case s.sendChan <- response:
|
||||
default:
|
||||
log.Printf("发送心跳响应失败: 发送队列已满")
|
||||
}
|
||||
}
|
||||
|
||||
// forwardData 转发数据
|
||||
func (s *Server) forwardData(active *ActiveConnection) {
|
||||
defer s.wg.Done()
|
||||
defer s.closeConnection(active.ID)
|
||||
|
||||
buffer := make([]byte, 32*1024)
|
||||
for {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
active.ClientConn.SetReadDeadline(time.Now().Add(ReadTimeout))
|
||||
n, err := active.ClientConn.Read(buffer)
|
||||
if err != nil {
|
||||
if err != io.EOF && !isTimeout(err) {
|
||||
log.Printf("读取客户端连接失败 (ID=%d): %v", active.ID, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 发送数据到隧道
|
||||
dataMsg := make([]byte, 4+n)
|
||||
binary.BigEndian.PutUint32(dataMsg[0:4], active.ID)
|
||||
copy(dataMsg[4:], buffer[:n])
|
||||
|
||||
msg := &TunnelMessage{
|
||||
Version: ProtocolVersion,
|
||||
Type: MsgTypeData,
|
||||
Length: uint32(len(dataMsg)),
|
||||
Data: dataMsg,
|
||||
}
|
||||
|
||||
select {
|
||||
case s.sendChan <- msg:
|
||||
case <-time.After(5 * time.Second):
|
||||
log.Printf("发送数据超时 (ID=%d)", active.ID)
|
||||
return
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// closeConnection 关闭连接
|
||||
func (s *Server) closeConnection(connID uint32) {
|
||||
s.connMu.Lock()
|
||||
active, exists := s.activeConns[connID]
|
||||
if exists {
|
||||
delete(s.activeConns, connID)
|
||||
active.ClientConn.Close()
|
||||
}
|
||||
s.connMu.Unlock()
|
||||
|
||||
// 发送关闭消息
|
||||
closeData := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(closeData, connID)
|
||||
|
||||
msg := &TunnelMessage{
|
||||
Version: ProtocolVersion,
|
||||
Type: MsgTypeClose,
|
||||
Length: 4,
|
||||
Data: closeData,
|
||||
}
|
||||
|
||||
select {
|
||||
case s.sendChan <- msg:
|
||||
default:
|
||||
// 发送队列满,忽略
|
||||
}
|
||||
|
||||
if exists {
|
||||
log.Printf("连接已关闭: ID=%d", connID)
|
||||
}
|
||||
}
|
||||
|
||||
// ForwardConnection 转发连接到隧道(新的透明代理实现)
|
||||
func (s *Server) ForwardConnection(clientConn net.Conn, targetPort int) error {
|
||||
s.mu.RLock()
|
||||
tunnelConn := s.client
|
||||
tunnelConnected := s.tunnelConn != nil
|
||||
s.mu.RUnlock()
|
||||
|
||||
if tunnelConn == nil {
|
||||
if !tunnelConnected {
|
||||
return fmt.Errorf("隧道连接不可用")
|
||||
}
|
||||
|
||||
// 创建连接对象
|
||||
// 创建待处理连接
|
||||
s.connMu.Lock()
|
||||
connID := s.nextConnID
|
||||
s.nextConnID++
|
||||
|
||||
connection := &Connection{
|
||||
ID: connID,
|
||||
TargetPort: targetPort,
|
||||
ClientConn: clientConn,
|
||||
readChan: make(chan []byte, 100),
|
||||
writeChan: make(chan []byte, 100),
|
||||
closeChan: make(chan struct{}),
|
||||
pending := &PendingConnection{
|
||||
ID: connID,
|
||||
ClientConn: clientConn,
|
||||
TargetPort: targetPort,
|
||||
Created: time.Now(),
|
||||
ResponseChan: make(chan bool, 1),
|
||||
}
|
||||
s.connections[connID] = connection
|
||||
s.pendingConns[connID] = pending
|
||||
s.connMu.Unlock()
|
||||
|
||||
defer func() {
|
||||
s.connMu.Lock()
|
||||
delete(s.connections, connID)
|
||||
s.connMu.Unlock()
|
||||
close(connection.closeChan)
|
||||
clientConn.Close()
|
||||
}()
|
||||
// 发送连接请求
|
||||
reqData := make([]byte, 6)
|
||||
binary.BigEndian.PutUint32(reqData[0:4], connID)
|
||||
binary.BigEndian.PutUint16(reqData[4:6], uint16(targetPort))
|
||||
|
||||
// 启动读写协程
|
||||
errChan := make(chan error, 2)
|
||||
|
||||
// 从客户端读取并发送到隧道
|
||||
go func() {
|
||||
buffer := make([]byte, 32*1024)
|
||||
for {
|
||||
select {
|
||||
case <-connection.closeChan:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
clientConn.SetReadDeadline(time.Now().Add(30 * time.Second))
|
||||
n, err := clientConn.Read(buffer)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
// 发送到隧道
|
||||
data := make([]byte, HeaderSize+n)
|
||||
binary.BigEndian.PutUint32(data[0:4], uint32(n))
|
||||
binary.BigEndian.PutUint32(data[4:8], connID)
|
||||
copy(data[HeaderSize:], buffer[:n])
|
||||
|
||||
s.mu.RLock()
|
||||
_, err = tunnelConn.Write(data)
|
||||
s.mu.RUnlock()
|
||||
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// 从隧道读取并发送到客户端
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case data := <-connection.readChan:
|
||||
if _, err := clientConn.Write(data); err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
case <-connection.closeChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// 等待错误或关闭
|
||||
select {
|
||||
case <-errChan:
|
||||
case <-connection.closeChan:
|
||||
case <-s.ctx.Done():
|
||||
msg := &TunnelMessage{
|
||||
Version: ProtocolVersion,
|
||||
Type: MsgTypeConnectRequest,
|
||||
Length: 6,
|
||||
Data: reqData,
|
||||
}
|
||||
|
||||
return nil
|
||||
select {
|
||||
case s.sendChan <- msg:
|
||||
case <-time.After(5 * time.Second):
|
||||
s.connMu.Lock()
|
||||
delete(s.pendingConns, connID)
|
||||
s.connMu.Unlock()
|
||||
return fmt.Errorf("发送连接请求超时")
|
||||
}
|
||||
|
||||
log.Printf("发送连接请求: ID=%d, 端口=%d", connID, targetPort)
|
||||
|
||||
// 等待连接响应
|
||||
select {
|
||||
case success := <-pending.ResponseChan:
|
||||
if success {
|
||||
return nil // 连接建立成功,forwardData会处理后续的数据转发
|
||||
} else {
|
||||
return fmt.Errorf("远程连接失败")
|
||||
}
|
||||
case <-time.After(ConnectTimeout):
|
||||
s.connMu.Lock()
|
||||
delete(s.pendingConns, connID)
|
||||
s.connMu.Unlock()
|
||||
clientConn.Close()
|
||||
return fmt.Errorf("连接超时")
|
||||
case <-s.ctx.Done():
|
||||
return fmt.Errorf("服务器关闭")
|
||||
}
|
||||
}
|
||||
|
||||
// IsConnected 检查隧道是否已连接
|
||||
func (s *Server) IsConnected() bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.client != nil
|
||||
return s.tunnelConn != nil
|
||||
}
|
||||
|
||||
// Stop 停止隧道服务器
|
||||
|
|
@ -297,8 +592,8 @@ func (s *Server) Stop() error {
|
|||
}
|
||||
|
||||
s.mu.Lock()
|
||||
if s.client != nil {
|
||||
s.client.Close()
|
||||
if s.tunnelConn != nil {
|
||||
s.tunnelConn.Close()
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
|
|
@ -317,4 +612,12 @@ func (s *Server) Stop() error {
|
|||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isTimeout 检查是否为超时错误
|
||||
func isTimeout(err error) bool {
|
||||
if netErr, ok := err.(net.Error); ok {
|
||||
return netErr.Timeout()
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
|
@ -0,0 +1,500 @@
|
|||
package tunnel
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestNewServer 测试创建隧道服务器
|
||||
func TestNewServer(t *testing.T) {
|
||||
server := NewServer(9000)
|
||||
|
||||
if server == nil {
|
||||
t.Fatal("创建隧道服务器失败")
|
||||
}
|
||||
|
||||
if server.listenPort != 9000 {
|
||||
t.Errorf("监听端口不正确,期望 9000,得到 %d", server.listenPort)
|
||||
}
|
||||
|
||||
if server.pendingConns == nil {
|
||||
t.Error("待处理连接映射未初始化")
|
||||
}
|
||||
|
||||
if server.activeConns == nil {
|
||||
t.Error("活跃连接映射未初始化")
|
||||
}
|
||||
|
||||
if server.sendChan == nil {
|
||||
t.Error("发送通道未初始化")
|
||||
}
|
||||
|
||||
if server.ctx == nil {
|
||||
t.Error("上下文未初始化")
|
||||
}
|
||||
}
|
||||
|
||||
// TestServerStartStop 测试服务器启动和停止
|
||||
func TestServerStartStop(t *testing.T) {
|
||||
// 使用随机端口
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("获取随机端口失败: %v", err)
|
||||
}
|
||||
port := listener.Addr().(*net.TCPAddr).Port
|
||||
listener.Close()
|
||||
|
||||
server := NewServer(port)
|
||||
|
||||
err = server.Start()
|
||||
if err != nil {
|
||||
t.Fatalf("启动服务器失败: %v", err)
|
||||
}
|
||||
|
||||
// 等待服务器启动
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 验证服务器是否监听端口
|
||||
conn, err := net.Dial("tcp", server.listener.Addr().String())
|
||||
if err != nil {
|
||||
t.Errorf("无法连接到服务器: %v", err)
|
||||
} else {
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
// 停止服务器
|
||||
err = server.Stop()
|
||||
if err != nil {
|
||||
t.Errorf("停止服务器失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTunnelMessage 测试隧道消息的序列化和反序列化
|
||||
func TestTunnelMessage(t *testing.T) {
|
||||
// 创建测试消息
|
||||
data := []byte("test data")
|
||||
msg := &TunnelMessage{
|
||||
Version: ProtocolVersion,
|
||||
Type: MsgTypeData,
|
||||
Length: uint32(len(data)),
|
||||
Data: data,
|
||||
}
|
||||
|
||||
// 创建模拟连接
|
||||
serverConn, clientConn := net.Pipe()
|
||||
defer serverConn.Close()
|
||||
defer clientConn.Close()
|
||||
|
||||
server := NewServer(9000)
|
||||
|
||||
// 测试写入消息
|
||||
go func() {
|
||||
err := server.writeTunnelMessage(serverConn, msg)
|
||||
if err != nil {
|
||||
t.Errorf("写入隧道消息失败: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// 测试读取消息
|
||||
receivedMsg, err := server.readTunnelMessage(clientConn)
|
||||
if err != nil {
|
||||
t.Fatalf("读取隧道消息失败: %v", err)
|
||||
}
|
||||
|
||||
// 验证消息内容
|
||||
if receivedMsg.Version != msg.Version {
|
||||
t.Errorf("版本不匹配,期望 %d,得到 %d", msg.Version, receivedMsg.Version)
|
||||
}
|
||||
|
||||
if receivedMsg.Type != msg.Type {
|
||||
t.Errorf("类型不匹配,期望 %d,得到 %d", msg.Type, receivedMsg.Type)
|
||||
}
|
||||
|
||||
if receivedMsg.Length != msg.Length {
|
||||
t.Errorf("长度不匹配,期望 %d,得到 %d", msg.Length, receivedMsg.Length)
|
||||
}
|
||||
|
||||
if string(receivedMsg.Data) != string(msg.Data) {
|
||||
t.Errorf("数据不匹配,期望 %s,得到 %s", string(msg.Data), string(receivedMsg.Data))
|
||||
}
|
||||
}
|
||||
|
||||
// TestConnectRequest 测试连接请求处理
|
||||
func TestConnectRequest(t *testing.T) {
|
||||
// 启动一个测试HTTP服务器
|
||||
testListener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("启动测试服务器失败: %v", err)
|
||||
}
|
||||
defer testListener.Close()
|
||||
|
||||
testPort := testListener.Addr().(*net.TCPAddr).Port
|
||||
|
||||
// 启动一个简单的echo服务
|
||||
go func() {
|
||||
for {
|
||||
conn, err := testListener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go func(c net.Conn) {
|
||||
defer c.Close()
|
||||
io.Copy(c, c) // echo服务
|
||||
}(conn)
|
||||
}
|
||||
}()
|
||||
|
||||
// 模拟测试(实际测试需要更复杂的设置)
|
||||
t.Logf("测试服务器运行在端口: %d", testPort)
|
||||
|
||||
// 创建隧道服务器,使用随机端口
|
||||
tunnelListener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("创建隧道监听器失败: %v", err)
|
||||
}
|
||||
defer tunnelListener.Close()
|
||||
|
||||
tunnelPort := tunnelListener.Addr().(*net.TCPAddr).Port
|
||||
tunnelListener.Close() // 关闭以便服务器重新绑定
|
||||
|
||||
server := NewServer(tunnelPort)
|
||||
|
||||
err = server.Start()
|
||||
if err != nil {
|
||||
t.Fatalf("启动隧道服务器失败: %v", err)
|
||||
}
|
||||
defer server.Stop()
|
||||
|
||||
// 等待服务器启动
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 模拟客户端连接
|
||||
tunnelConn, err := net.Dial("tcp", server.listener.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("连接隧道服务器失败: %v", err)
|
||||
}
|
||||
defer tunnelConn.Close()
|
||||
|
||||
// 等待隧道连接被处理
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 验证隧道是否已连接
|
||||
if !server.IsConnected() {
|
||||
t.Error("隧道未连接")
|
||||
}
|
||||
}
|
||||
|
||||
// TestProtocolVersionCheck 测试协议版本检查
|
||||
func TestProtocolVersionCheck(t *testing.T) {
|
||||
server := NewServer(9000)
|
||||
|
||||
// 创建模拟连接
|
||||
serverConn, clientConn := net.Pipe()
|
||||
defer serverConn.Close()
|
||||
defer clientConn.Close()
|
||||
|
||||
// 发送错误版本的消息
|
||||
wrongVersionHeader := make([]byte, HeaderSize)
|
||||
wrongVersionHeader[0] = 0xFF // 错误版本
|
||||
wrongVersionHeader[1] = MsgTypeData
|
||||
binary.BigEndian.PutUint32(wrongVersionHeader[2:6], 0)
|
||||
|
||||
go func() {
|
||||
clientConn.Write(wrongVersionHeader)
|
||||
}()
|
||||
|
||||
// 尝试读取消息,应该返回错误
|
||||
_, err := server.readTunnelMessage(serverConn)
|
||||
if err == nil {
|
||||
t.Error("期望版本检查失败,但成功了")
|
||||
}
|
||||
|
||||
if err.Error() != "不支持的协议版本: 255" {
|
||||
t.Errorf("错误消息不正确: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMaxPacketSizeCheck 测试最大包大小检查
|
||||
func TestMaxPacketSizeCheck(t *testing.T) {
|
||||
server := NewServer(9000)
|
||||
|
||||
// 创建模拟连接
|
||||
serverConn, clientConn := net.Pipe()
|
||||
defer serverConn.Close()
|
||||
defer clientConn.Close()
|
||||
|
||||
// 发送超大包
|
||||
oversizedHeader := make([]byte, HeaderSize)
|
||||
oversizedHeader[0] = ProtocolVersion
|
||||
oversizedHeader[1] = MsgTypeData
|
||||
binary.BigEndian.PutUint32(oversizedHeader[2:6], MaxPacketSize+1)
|
||||
|
||||
go func() {
|
||||
clientConn.Write(oversizedHeader)
|
||||
}()
|
||||
|
||||
// 尝试读取消息,应该返回错误
|
||||
_, err := server.readTunnelMessage(serverConn)
|
||||
if err == nil {
|
||||
t.Error("期望包大小检查失败,但成功了")
|
||||
}
|
||||
}
|
||||
|
||||
// TestConcurrentConnections 测试并发连接处理
|
||||
func TestConcurrentConnections(t *testing.T) {
|
||||
// 使用随机端口
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("获取随机端口失败: %v", err)
|
||||
}
|
||||
port := listener.Addr().(*net.TCPAddr).Port
|
||||
listener.Close()
|
||||
|
||||
server := NewServer(port)
|
||||
|
||||
err = server.Start()
|
||||
if err != nil {
|
||||
t.Fatalf("启动服务器失败: %v", err)
|
||||
}
|
||||
defer server.Stop()
|
||||
|
||||
// 等待服务器启动
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
connCount := 5
|
||||
|
||||
// 并发创建多个连接
|
||||
for i := 0; i < connCount; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
|
||||
conn, err := net.Dial("tcp", server.listener.Addr().String())
|
||||
if err != nil {
|
||||
t.Errorf("连接 %d 失败: %v", id, err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// 保持连接一段时间
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// 验证只有一个隧道连接被接受
|
||||
if server.IsConnected() {
|
||||
// 应该只有一个连接被接受
|
||||
t.Log("隧道连接已建立")
|
||||
}
|
||||
}
|
||||
|
||||
// TestKeepAlive 测试心跳消息
|
||||
func TestKeepAlive(t *testing.T) {
|
||||
server := NewServer(9000)
|
||||
|
||||
// 创建模拟连接
|
||||
serverConn, clientConn := net.Pipe()
|
||||
defer serverConn.Close()
|
||||
defer clientConn.Close()
|
||||
|
||||
// 创建心跳消息
|
||||
keepAliveMsg := &TunnelMessage{
|
||||
Version: ProtocolVersion,
|
||||
Type: MsgTypeKeepAlive,
|
||||
Length: 0,
|
||||
Data: nil,
|
||||
}
|
||||
|
||||
// 启动消息处理
|
||||
go func() {
|
||||
for {
|
||||
msg, err := server.readTunnelMessage(clientConn)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
server.handleTunnelMessage(msg)
|
||||
}
|
||||
}()
|
||||
|
||||
// 启动发送处理
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case msg := <-server.sendChan:
|
||||
server.writeTunnelMessage(serverConn, msg)
|
||||
case <-time.After(time.Second):
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// 发送心跳
|
||||
err := server.writeTunnelMessage(serverConn, keepAliveMsg)
|
||||
if err != nil {
|
||||
t.Fatalf("发送心跳失败: %v", err)
|
||||
}
|
||||
|
||||
// 读取响应
|
||||
response, err := server.readTunnelMessage(clientConn)
|
||||
if err != nil {
|
||||
t.Fatalf("读取心跳响应失败: %v", err)
|
||||
}
|
||||
|
||||
if response.Type != MsgTypeKeepAlive {
|
||||
t.Errorf("心跳响应类型不正确,期望 %d,得到 %d", MsgTypeKeepAlive, response.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// MockClient 模拟客户端用于测试
|
||||
type MockClient struct {
|
||||
conn net.Conn
|
||||
}
|
||||
|
||||
func (mc *MockClient) sendConnectResponse(connID uint32, status byte) error {
|
||||
responseData := make([]byte, 5)
|
||||
binary.BigEndian.PutUint32(responseData[0:4], connID)
|
||||
responseData[4] = status
|
||||
|
||||
msg := &TunnelMessage{
|
||||
Version: ProtocolVersion,
|
||||
Type: MsgTypeConnectResponse,
|
||||
Length: 5,
|
||||
Data: responseData,
|
||||
}
|
||||
|
||||
// 构建消息头
|
||||
header := make([]byte, HeaderSize)
|
||||
header[0] = msg.Version
|
||||
header[1] = msg.Type
|
||||
binary.BigEndian.PutUint32(header[2:6], msg.Length)
|
||||
|
||||
// 写入消息
|
||||
if _, err := mc.conn.Write(header); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if msg.Length > 0 && msg.Data != nil {
|
||||
if _, err := mc.conn.Write(msg.Data); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestForwardConnection 测试连接转发
|
||||
func TestForwardConnection(t *testing.T) {
|
||||
// 启动测试目标服务
|
||||
targetListener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("启动目标服务失败: %v", err)
|
||||
}
|
||||
defer targetListener.Close()
|
||||
|
||||
targetPort := targetListener.Addr().(*net.TCPAddr).Port
|
||||
|
||||
// 启动简单的echo服务
|
||||
go func() {
|
||||
for {
|
||||
conn, err := targetListener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go func(c net.Conn) {
|
||||
defer c.Close()
|
||||
io.Copy(c, c)
|
||||
}(conn)
|
||||
}
|
||||
}()
|
||||
|
||||
// 创建隧道服务器,使用随机端口
|
||||
tunnelListener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("创建隧道监听器失败: %v", err)
|
||||
}
|
||||
tunnelPort := tunnelListener.Addr().(*net.TCPAddr).Port
|
||||
tunnelListener.Close() // 关闭以便服务器重新绑定
|
||||
|
||||
server := NewServer(tunnelPort)
|
||||
err = server.Start()
|
||||
if err != nil {
|
||||
t.Fatalf("启动隧道服务器失败: %v", err)
|
||||
}
|
||||
defer server.Stop()
|
||||
|
||||
// 等待服务器启动
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 连接到隧道服务器(模拟客户端)
|
||||
tunnelConn, err := net.Dial("tcp", server.listener.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("连接隧道服务器失败: %v", err)
|
||||
}
|
||||
defer tunnelConn.Close()
|
||||
|
||||
// 等待隧道连接建立
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
if !server.IsConnected() {
|
||||
t.Fatal("隧道未连接")
|
||||
}
|
||||
|
||||
// 创建模拟客户端连接
|
||||
clientConn, serverSideConn := net.Pipe()
|
||||
defer clientConn.Close()
|
||||
defer serverSideConn.Close()
|
||||
|
||||
// 创建模拟客户端
|
||||
mockClient := &MockClient{conn: tunnelConn}
|
||||
|
||||
// 启动连接转发
|
||||
go func() {
|
||||
err := server.ForwardConnection(serverSideConn, targetPort)
|
||||
if err != nil {
|
||||
t.Errorf("转发连接失败: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// 读取连接请求
|
||||
header := make([]byte, HeaderSize)
|
||||
_, err = io.ReadFull(tunnelConn, header)
|
||||
if err != nil {
|
||||
t.Fatalf("读取连接请求头失败: %v", err)
|
||||
}
|
||||
|
||||
if header[1] != MsgTypeConnectRequest {
|
||||
t.Fatalf("期望连接请求,得到消息类型: %d", header[1])
|
||||
}
|
||||
|
||||
dataLen := binary.BigEndian.Uint32(header[2:6])
|
||||
data := make([]byte, dataLen)
|
||||
_, err = io.ReadFull(tunnelConn, data)
|
||||
if err != nil {
|
||||
t.Fatalf("读取连接请求数据失败: %v", err)
|
||||
}
|
||||
|
||||
connID := binary.BigEndian.Uint32(data[0:4])
|
||||
requestedPort := binary.BigEndian.Uint16(data[4:6])
|
||||
|
||||
if int(requestedPort) != targetPort {
|
||||
t.Errorf("请求端口不匹配,期望 %d,得到 %d", targetPort, requestedPort)
|
||||
}
|
||||
|
||||
// 发送连接成功响应
|
||||
err = mockClient.sendConnectResponse(connID, ConnStatusSuccess)
|
||||
if err != nil {
|
||||
t.Fatalf("发送连接响应失败: %v", err)
|
||||
}
|
||||
|
||||
// 等待连接建立
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
t.Log("连接转发测试完成")
|
||||
}
|
||||
|
|
@ -0,0 +1,225 @@
|
|||
# 集成测试文档
|
||||
|
||||
本目录包含 go-tunnel 项目的集成测试,用于测试服务器和客户端之间的端口转发功能。
|
||||
|
||||
## 前置条件
|
||||
|
||||
在运行测试之前,需要:
|
||||
|
||||
1. **启动服务器**
|
||||
```bash
|
||||
cd src
|
||||
make run-server
|
||||
# 或者
|
||||
./bin/server -config config.yaml
|
||||
```
|
||||
|
||||
2. **启动客户端**
|
||||
```bash
|
||||
cd src
|
||||
make run-client
|
||||
# 或者
|
||||
./bin/client -server localhost:9000
|
||||
```
|
||||
|
||||
## 测试配置
|
||||
|
||||
测试使用的默认配置如下(可以在 `integration_test.go` 中修改):
|
||||
|
||||
```go
|
||||
ServerAPIAddr: "http://localhost:8080", // 服务器 API 地址
|
||||
ServerTunnelPort: 9000, // 服务器隧道端口
|
||||
TestPort: 30001, // 测试使用的端口
|
||||
LocalServicePort: 8888, // 本地 echo 服务端口
|
||||
```
|
||||
|
||||
如果你的服务器使用不同的端口,请修改这些配置。
|
||||
|
||||
## 运行测试
|
||||
|
||||
### 运行所有测试
|
||||
|
||||
```bash
|
||||
cd src/test
|
||||
go test -v
|
||||
```
|
||||
|
||||
### 运行特定测试
|
||||
|
||||
```bash
|
||||
# 测试基本转发功能
|
||||
go test -v -run TestForwardingBasic
|
||||
|
||||
# 测试多端口转发
|
||||
go test -v -run TestMultipleForwards
|
||||
|
||||
# 测试动态添加/删除映射
|
||||
go test -v -run TestAddAndRemoveMapping
|
||||
|
||||
# 测试并发请求
|
||||
go test -v -run TestConcurrentRequests
|
||||
|
||||
# 测试映射列表
|
||||
go test -v -run TestListMappings
|
||||
|
||||
# 测试健康检查
|
||||
go test -v -run TestHealthCheck
|
||||
```
|
||||
|
||||
## 测试说明
|
||||
|
||||
### 1. TestForwardingBasic - 基本转发功能测试
|
||||
|
||||
测试流程:
|
||||
1. 在本地 8888 端口启动一个 echo 服务器
|
||||
2. 通过 API 创建端口映射(30001 -> 8888)
|
||||
3. 向服务器的 30001 端口发送消息
|
||||
4. 验证能否通过隧道转发到客户端的 8888 端口并收到响应
|
||||
|
||||
**预期结果**:发送 "Hello, Tunnel!" 能收到 "ECHO: Hello, Tunnel!"
|
||||
|
||||
### 2. TestMultipleForwards - 多端口转发测试
|
||||
|
||||
测试流程:
|
||||
1. 创建多个端口映射(30002, 30003, 30004)
|
||||
2. 同时向这些端口发送请求
|
||||
3. 验证所有端口都能正确转发
|
||||
|
||||
**预期结果**:所有端口都能正常转发并收到正确响应
|
||||
|
||||
### 3. TestAddAndRemoveMapping - 动态管理测试
|
||||
|
||||
测试流程:
|
||||
1. 创建端口映射并验证工作
|
||||
2. 删除端口映射并验证无法连接
|
||||
3. 重新创建映射并验证恢复工作
|
||||
|
||||
**预期结果**:映射的创建、删除、重建都能正确工作
|
||||
|
||||
### 4. TestConcurrentRequests - 并发请求测试
|
||||
|
||||
测试流程:
|
||||
1. 创建一个端口映射
|
||||
2. 并发发送 10 个请求
|
||||
3. 验证所有请求都能正确处理
|
||||
|
||||
**预期结果**:10 个并发请求全部成功
|
||||
|
||||
### 5. TestListMappings - 映射列表测试
|
||||
|
||||
测试流程:
|
||||
1. 创建多个端口映射
|
||||
2. 通过 API 获取映射列表
|
||||
3. 验证所有创建的映射都在列表中
|
||||
|
||||
**预期结果**:API 返回完整的映射列表
|
||||
|
||||
### 6. TestHealthCheck - 健康检查测试
|
||||
|
||||
测试流程:
|
||||
1. 向服务器发送健康检查请求
|
||||
|
||||
**预期结果**:返回 200 状态码
|
||||
|
||||
## 测试端口使用
|
||||
|
||||
测试使用以下端口,请确保这些端口未被占用:
|
||||
|
||||
- 8888: 本地 echo 测试服务
|
||||
- 30001: TestForwardingBasic
|
||||
- 30002-30004: TestMultipleForwards
|
||||
- 30005: TestAddAndRemoveMapping
|
||||
- 30006: TestConcurrentRequests
|
||||
- 30010-30012: TestListMappings
|
||||
|
||||
## 自定义测试服务器地址
|
||||
|
||||
如果你的服务器运行在不同的地址,可以在测试代码中修改 `defaultConfig`:
|
||||
|
||||
```go
|
||||
var defaultConfig = TestConfig{
|
||||
ServerAPIAddr: "http://your-server:8080", // 修改这里
|
||||
ServerTunnelPort: 9000,
|
||||
TestPort: 30001,
|
||||
LocalServicePort: 8888,
|
||||
}
|
||||
```
|
||||
|
||||
或者创建环境变量:
|
||||
|
||||
```bash
|
||||
export TEST_SERVER_API="http://your-server:8080"
|
||||
export TEST_SERVER_TUNNEL_PORT="9000"
|
||||
```
|
||||
|
||||
## 故障排查
|
||||
|
||||
### 测试失败:连接被拒绝
|
||||
|
||||
**原因**:服务器或客户端未启动
|
||||
|
||||
**解决**:确保服务器和客户端都在运行
|
||||
|
||||
### 测试失败:隧道未连接
|
||||
|
||||
**原因**:客户端未成功连接到服务器
|
||||
|
||||
**解决**:
|
||||
1. 检查客户端日志,确认连接成功
|
||||
2. 检查服务器地址配置是否正确
|
||||
|
||||
### 测试失败:端口已被占用
|
||||
|
||||
**原因**:测试端口被其他程序占用
|
||||
|
||||
**解决**:
|
||||
1. 使用 `netstat -tuln | grep <port>` 查看端口占用
|
||||
2. 关闭占用端口的程序
|
||||
3. 或在测试代码中修改端口号
|
||||
|
||||
### 测试超时
|
||||
|
||||
**原因**:网络延迟或服务响应慢
|
||||
|
||||
**解决**:
|
||||
1. 增加测试中的 `time.Sleep` 时间
|
||||
2. 检查服务器和客户端是否正常运行
|
||||
3. 查看服务器和客户端日志
|
||||
|
||||
## 示例输出
|
||||
|
||||
成功运行测试的输出示例:
|
||||
|
||||
```
|
||||
=== RUN TestForwardingBasic
|
||||
integration_test.go:54: 启动本地测试服务...
|
||||
integration_test.go:493: Echo 服务器启动在端口 8888
|
||||
integration_test.go:63: 创建端口映射: 30001 -> localhost:8888
|
||||
integration_test.go:68: 端口映射创建成功
|
||||
integration_test.go:82: 发送测试消息: Hello, Tunnel!
|
||||
integration_test.go:95: ✓ 转发成功,收到响应: ECHO: Hello, Tunnel!
|
||||
integration_test.go:73: 清理端口映射...
|
||||
--- PASS: TestForwardingBasic (1.52s)
|
||||
=== RUN TestHealthCheck
|
||||
integration_test.go:346: ✓ 健康检查成功: {"success":true,"message":"服务器运行正常"}
|
||||
--- PASS: TestHealthCheck (0.01s)
|
||||
PASS
|
||||
ok test 1.534s
|
||||
```
|
||||
|
||||
## 扩展测试
|
||||
|
||||
你可以根据需要添加更多测试用例:
|
||||
|
||||
1. 测试不同大小的数据传输
|
||||
2. 测试长时间连接
|
||||
3. 测试异常情况处理
|
||||
4. 测试性能和吞吐量
|
||||
5. 测试错误恢复机制
|
||||
|
||||
## 注意事项
|
||||
|
||||
1. 测试会自动清理创建的端口映射,但如果测试中断可能需要手动清理
|
||||
2. 每个测试都是独立的,可以单独运行
|
||||
3. 测试内置了 echo 服务器,不需要额外准备测试服务
|
||||
4. 建议在开发环境运行测试,避免影响生产环境
|
||||
|
|
@ -0,0 +1,554 @@
|
|||
package test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestConfig 测试配置
|
||||
type TestConfig struct {
|
||||
ServerAPIAddr string // 服务器 API 地址
|
||||
ServerTunnelPort int // 服务器隧道端口
|
||||
TestPort int // 测试用的端口
|
||||
LocalServicePort int // 本地服务端口(客户端监听)
|
||||
}
|
||||
|
||||
// 默认测试配置
|
||||
var defaultConfig = TestConfig{
|
||||
ServerAPIAddr: "http://localhost:8080",
|
||||
ServerTunnelPort: 9000,
|
||||
TestPort: 30001,
|
||||
LocalServicePort: 8888,
|
||||
}
|
||||
|
||||
// CreateMappingRequest 创建映射请求
|
||||
type CreateMappingRequest struct {
|
||||
Port int `json:"port"`
|
||||
TargetIP string `json:"target_ip,omitempty"`
|
||||
}
|
||||
|
||||
// RemoveMappingRequest 删除映射请求
|
||||
type RemoveMappingRequest struct {
|
||||
Port int `json:"port"`
|
||||
}
|
||||
|
||||
// Response API 响应
|
||||
type Response struct {
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message"`
|
||||
Data map[string]interface{} `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// TestForwardingBasic 测试基本转发功能
|
||||
// 前置条件:服务器和客户端已启动,客户端在本地 8888 端口运行了一个 echo 服务
|
||||
func TestForwardingBasic(t *testing.T) {
|
||||
config := defaultConfig
|
||||
|
||||
// 1. 启动本地测试服务(模拟客户端本地服务)
|
||||
t.Log("启动本地测试服务...")
|
||||
stopChan := make(chan struct{})
|
||||
go startEchoServer(config.LocalServicePort, stopChan, t)
|
||||
defer close(stopChan)
|
||||
|
||||
// 等待服务启动
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
// 2. 创建端口映射
|
||||
t.Logf("创建端口映射: %d -> localhost:%d", config.TestPort, config.LocalServicePort)
|
||||
err := createMapping(config.ServerAPIAddr, config.TestPort, "")
|
||||
if err != nil {
|
||||
t.Fatalf("创建端口映射失败: %v", err)
|
||||
}
|
||||
t.Log("端口映射创建成功")
|
||||
|
||||
// 清理:测试结束后删除映射
|
||||
defer func() {
|
||||
t.Log("清理端口映射...")
|
||||
err := removeMapping(config.ServerAPIAddr, config.TestPort)
|
||||
if err != nil {
|
||||
t.Logf("清理端口映射失败: %v", err)
|
||||
}
|
||||
// 等待清理完成
|
||||
time.Sleep(2 * time.Second)
|
||||
}()
|
||||
|
||||
// 等待映射生效
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
// 3. 通过服务器端口发送请求
|
||||
testMessage := "Hello, Tunnel!"
|
||||
t.Logf("发送测试消息: %s", testMessage)
|
||||
|
||||
response, err := sendTCPRequest(fmt.Sprintf("localhost:%d", config.TestPort), testMessage)
|
||||
if err != nil {
|
||||
t.Fatalf("发送请求失败: %v", err)
|
||||
}
|
||||
|
||||
// 4. 验证响应
|
||||
expectedResponse := "ECHO: " + testMessage
|
||||
if response != expectedResponse {
|
||||
t.Fatalf("响应不匹配.\n期望: %s\n实际: %s", expectedResponse, response)
|
||||
}
|
||||
|
||||
t.Logf("✓ 转发成功,收到响应: %s", response)
|
||||
}
|
||||
|
||||
// TestMultipleForwards 测试多个端口同时转发
|
||||
func TestMultipleForwards(t *testing.T) {
|
||||
config := defaultConfig
|
||||
|
||||
// 启动本地测试服务
|
||||
t.Log("启动本地测试服务...")
|
||||
stopChan := make(chan struct{})
|
||||
go startEchoServer(config.LocalServicePort, stopChan, t)
|
||||
defer close(stopChan)
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
// 测试多个端口
|
||||
testPorts := []int{30002, 30003, 30004}
|
||||
|
||||
// 创建多个映射
|
||||
for _, port := range testPorts {
|
||||
t.Logf("创建端口映射: %d", port)
|
||||
err := createMapping(config.ServerAPIAddr, port, "")
|
||||
if err != nil {
|
||||
t.Fatalf("创建端口 %d 映射失败: %v", port, err)
|
||||
}
|
||||
defer func(p int) {
|
||||
err := removeMapping(config.ServerAPIAddr, p)
|
||||
if err != nil {
|
||||
t.Logf("删除端口 %d 映射失败: %v", p, err)
|
||||
}
|
||||
}(port)
|
||||
}
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
// 并发测试所有端口
|
||||
for _, port := range testPorts {
|
||||
t.Run(fmt.Sprintf("Port_%d", port), func(t *testing.T) {
|
||||
testMessage := fmt.Sprintf("Test message for port %d", port)
|
||||
response, err := sendTCPRequest(fmt.Sprintf("localhost:%d", port), testMessage)
|
||||
if err != nil {
|
||||
t.Errorf("端口 %d 请求失败: %v", port, err)
|
||||
return
|
||||
}
|
||||
|
||||
expectedResponse := "ECHO: " + testMessage
|
||||
if response != expectedResponse {
|
||||
t.Errorf("端口 %d 响应不匹配.\n期望: %s\n实际: %s", port, expectedResponse, response)
|
||||
return
|
||||
}
|
||||
|
||||
t.Logf("✓ 端口 %d 转发成功", port)
|
||||
})
|
||||
}
|
||||
|
||||
// 等待清理完成
|
||||
time.Sleep(2 * time.Second)
|
||||
}
|
||||
|
||||
// TestAddAndRemoveMapping 测试动态添加和删除映射
|
||||
func TestAddAndRemoveMapping(t *testing.T) {
|
||||
config := defaultConfig
|
||||
|
||||
// 启动本地测试服务
|
||||
t.Log("启动本地测试服务...")
|
||||
stopChan := make(chan struct{})
|
||||
go startEchoServer(config.LocalServicePort, stopChan, t)
|
||||
defer close(stopChan)
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
port := 30005
|
||||
|
||||
// 1. 创建映射
|
||||
t.Logf("创建端口映射: %d", port)
|
||||
err := createMapping(config.ServerAPIAddr, port, "")
|
||||
if err != nil {
|
||||
t.Fatalf("创建端口映射失败: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
// 2. 验证映射工作
|
||||
t.Log("验证映射工作...")
|
||||
response, err := sendTCPRequest(fmt.Sprintf("localhost:%d", port), "test1")
|
||||
if err != nil {
|
||||
t.Fatalf("映射创建后请求失败: %v", err)
|
||||
}
|
||||
if response != "ECHO: test1" {
|
||||
t.Fatalf("响应不匹配: %s", response)
|
||||
}
|
||||
t.Log("✓ 映射工作正常")
|
||||
|
||||
// 3. 删除映射
|
||||
t.Log("删除端口映射...")
|
||||
err = removeMapping(config.ServerAPIAddr, port)
|
||||
if err != nil {
|
||||
t.Fatalf("删除端口映射失败: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
// 4. 验证映射已删除(连接应该失败)
|
||||
t.Log("验证映射已删除...")
|
||||
_, err = sendTCPRequest(fmt.Sprintf("localhost:%d", port), "test2")
|
||||
if err == nil {
|
||||
t.Fatal("映射删除后请求应该失败,但成功了")
|
||||
}
|
||||
t.Logf("✓ 映射已删除,连接失败(符合预期): %v", err)
|
||||
|
||||
// 5. 重新创建映射
|
||||
t.Log("重新创建端口映射...")
|
||||
err = createMapping(config.ServerAPIAddr, port, "")
|
||||
if err != nil {
|
||||
t.Fatalf("重新创建端口映射失败: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
err := removeMapping(config.ServerAPIAddr, port)
|
||||
if err != nil {
|
||||
t.Logf("清理端口映射失败: %v", err)
|
||||
}
|
||||
time.Sleep(2 * time.Second)
|
||||
}()
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
// 6. 验证映射恢复工作
|
||||
t.Log("验证映射恢复工作...")
|
||||
response, err = sendTCPRequest(fmt.Sprintf("localhost:%d", port), "test3")
|
||||
if err != nil {
|
||||
t.Fatalf("映射重新创建后请求失败: %v", err)
|
||||
}
|
||||
if response != "ECHO: test3" {
|
||||
t.Fatalf("响应不匹配: %s", response)
|
||||
}
|
||||
t.Log("✓ 映射恢复工作正常")
|
||||
}
|
||||
|
||||
// TestConcurrentRequests 测试并发请求
|
||||
func TestConcurrentRequests(t *testing.T) {
|
||||
config := defaultConfig
|
||||
|
||||
// 启动本地测试服务
|
||||
t.Log("启动本地测试服务...")
|
||||
stopChan := make(chan struct{})
|
||||
go startEchoServer(config.LocalServicePort, stopChan, t)
|
||||
defer close(stopChan)
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
port := 30006
|
||||
|
||||
// 创建映射
|
||||
t.Logf("创建端口映射: %d", port)
|
||||
err := createMapping(config.ServerAPIAddr, port, "")
|
||||
if err != nil {
|
||||
t.Fatalf("创建端口映射失败: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
err := removeMapping(config.ServerAPIAddr, port)
|
||||
if err != nil {
|
||||
t.Logf("清理端口映射失败: %v", err)
|
||||
}
|
||||
time.Sleep(2 * time.Second)
|
||||
}()
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
// 并发发送请求
|
||||
concurrency := 10
|
||||
results := make(chan error, concurrency)
|
||||
|
||||
t.Logf("并发发送 %d 个请求...", concurrency)
|
||||
for i := 0; i < concurrency; i++ {
|
||||
go func(index int) {
|
||||
message := fmt.Sprintf("concurrent_%d", index)
|
||||
response, err := sendTCPRequest(fmt.Sprintf("localhost:%d", port), message)
|
||||
if err != nil {
|
||||
results <- fmt.Errorf("请求 %d 失败: %w", index, err)
|
||||
return
|
||||
}
|
||||
|
||||
expectedResponse := "ECHO: " + message
|
||||
if response != expectedResponse {
|
||||
results <- fmt.Errorf("请求 %d 响应不匹配: 期望=%s, 实际=%s", index, expectedResponse, response)
|
||||
return
|
||||
}
|
||||
|
||||
results <- nil
|
||||
}(i)
|
||||
}
|
||||
|
||||
// 收集结果
|
||||
successCount := 0
|
||||
for i := 0; i < concurrency; i++ {
|
||||
err := <-results
|
||||
if err != nil {
|
||||
t.Errorf("%v", err)
|
||||
} else {
|
||||
successCount++
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("✓ 并发测试完成: %d/%d 成功", successCount, concurrency)
|
||||
|
||||
if successCount < concurrency {
|
||||
t.Fatalf("部分并发请求失败")
|
||||
}
|
||||
}
|
||||
|
||||
// TestListMappings 测试列出所有映射
|
||||
func TestListMappings(t *testing.T) {
|
||||
config := defaultConfig
|
||||
|
||||
// 创建几个映射
|
||||
testPorts := []int{30010, 30011, 30012}
|
||||
|
||||
for _, port := range testPorts {
|
||||
t.Logf("创建端口映射: %d", port)
|
||||
err := createMapping(config.ServerAPIAddr, port, "")
|
||||
if err != nil {
|
||||
t.Fatalf("创建端口 %d 映射失败: %v", port, err)
|
||||
}
|
||||
defer removeMapping(config.ServerAPIAddr, port)
|
||||
}
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// 列出所有映射
|
||||
t.Log("列出所有映射...")
|
||||
mappings, err := listMappings(config.ServerAPIAddr)
|
||||
if err != nil {
|
||||
t.Fatalf("列出映射失败: %v", err)
|
||||
}
|
||||
|
||||
t.Logf("当前映射数量: %d", len(mappings))
|
||||
|
||||
// 验证创建的映射都在列表中
|
||||
for _, port := range testPorts {
|
||||
found := false
|
||||
for _, mapping := range mappings {
|
||||
if sourcePort, ok := mapping["source_port"].(float64); ok && int(sourcePort) == port {
|
||||
found = true
|
||||
t.Logf("✓ 找到映射: 端口 %d", port)
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("映射端口 %d 未在列表中找到", port)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestHealthCheck 测试健康检查
|
||||
func TestHealthCheck(t *testing.T) {
|
||||
config := defaultConfig
|
||||
|
||||
url := config.ServerAPIAddr + "/health"
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
t.Fatalf("健康检查请求失败: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("健康检查返回非 200 状态码: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
t.Logf("✓ 健康检查成功: %s", string(body))
|
||||
}
|
||||
|
||||
// ============ 辅助函数 ============
|
||||
|
||||
// createMapping 创建端口映射
|
||||
func createMapping(apiAddr string, port int, targetIP string) error {
|
||||
url := apiAddr + "/api/mapping/create"
|
||||
|
||||
reqBody := CreateMappingRequest{
|
||||
Port: port,
|
||||
TargetIP: targetIP,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return fmt.Errorf("序列化请求失败: %w", err)
|
||||
}
|
||||
|
||||
resp, err := http.Post(url, "application/json", bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return fmt.Errorf("发送请求失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
var result Response
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return fmt.Errorf("解析响应失败: %w", err)
|
||||
}
|
||||
|
||||
if !result.Success {
|
||||
return fmt.Errorf("API 返回失败: %s", result.Message)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeMapping 删除端口映射
|
||||
func removeMapping(apiAddr string, port int) error {
|
||||
url := apiAddr + "/api/mapping/remove"
|
||||
|
||||
reqBody := RemoveMappingRequest{
|
||||
Port: port,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return fmt.Errorf("序列化请求失败: %w", err)
|
||||
}
|
||||
|
||||
resp, err := http.Post(url, "application/json", bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return fmt.Errorf("发送请求失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
var result Response
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return fmt.Errorf("解析响应失败: %w", err)
|
||||
}
|
||||
|
||||
if !result.Success {
|
||||
return fmt.Errorf("API 返回失败: %s", result.Message)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// listMappings 列出所有映射
|
||||
func listMappings(apiAddr string) ([]map[string]interface{}, error) {
|
||||
url := apiAddr + "/api/mapping/list"
|
||||
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("发送请求失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
var result Response
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("解析响应失败: %w", err)
|
||||
}
|
||||
|
||||
if !result.Success {
|
||||
return nil, fmt.Errorf("API 返回失败: %s", result.Message)
|
||||
}
|
||||
|
||||
// 提取映射列表
|
||||
if mappingsData, ok := result.Data["mappings"].([]interface{}); ok {
|
||||
mappings := make([]map[string]interface{}, len(mappingsData))
|
||||
for i, item := range mappingsData {
|
||||
if mapping, ok := item.(map[string]interface{}); ok {
|
||||
mappings[i] = mapping
|
||||
}
|
||||
}
|
||||
return mappings, nil
|
||||
}
|
||||
|
||||
return []map[string]interface{}{}, nil
|
||||
}
|
||||
|
||||
// sendTCPRequest 发送 TCP 请求并接收响应
|
||||
func sendTCPRequest(addr string, message string) (string, error) {
|
||||
conn, err := net.DialTimeout("tcp", addr, 10*time.Second)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("连接失败: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// 设置超时 - 增加到10秒
|
||||
conn.SetDeadline(time.Now().Add(10 * time.Second))
|
||||
|
||||
// 发送消息
|
||||
_, err = conn.Write([]byte(message + "\n"))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("发送失败: %w", err)
|
||||
}
|
||||
|
||||
// 接收响应
|
||||
buffer := make([]byte, 1024)
|
||||
n, err := conn.Read(buffer)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("接收失败: %w", err)
|
||||
}
|
||||
|
||||
return string(buffer[:n]), nil
|
||||
}
|
||||
|
||||
// startEchoServer 启动简单的 echo 服务器(用于测试)
|
||||
func startEchoServer(port int, stopChan chan struct{}, t *testing.T) {
|
||||
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
|
||||
if err != nil {
|
||||
t.Logf("启动 echo 服务器失败: %v", err)
|
||||
return
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
t.Logf("Echo 服务器启动在端口 %d", port)
|
||||
|
||||
go func() {
|
||||
<-stopChan
|
||||
listener.Close()
|
||||
}()
|
||||
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-stopChan:
|
||||
return
|
||||
default:
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
go handleEchoConnection(conn, t)
|
||||
}
|
||||
}
|
||||
|
||||
// handleEchoConnection 处理 echo 连接
|
||||
func handleEchoConnection(conn net.Conn, t *testing.T) {
|
||||
defer conn.Close()
|
||||
|
||||
buffer := make([]byte, 1024)
|
||||
n, err := conn.Read(buffer)
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
t.Logf("读取数据失败: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
message := string(buffer[:n])
|
||||
message = string(bytes.TrimSpace([]byte(message)))
|
||||
|
||||
response := fmt.Sprintf("ECHO: %s", message)
|
||||
_, err = conn.Write([]byte(response))
|
||||
if err != nil {
|
||||
t.Logf("发送响应失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,83 @@
|
|||
#!/bin/bash
|
||||
|
||||
# 集成测试运行脚本
|
||||
# 此脚本用于运行 go-tunnel 的集成测试
|
||||
|
||||
set -e
|
||||
|
||||
# 颜色输出
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
YELLOW='\033[1;33m'
|
||||
BLUE='\033[0;34m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
# 配置
|
||||
SERVER_API=${SERVER_API:-"http://localhost:8080"}
|
||||
SERVER_TUNNEL_PORT=${SERVER_TUNNEL_PORT:-9000}
|
||||
TEST_TIMEOUT=${TEST_TIMEOUT:-30s}
|
||||
|
||||
echo -e "${BLUE}================================${NC}"
|
||||
echo -e "${BLUE}Go-Tunnel 集成测试${NC}"
|
||||
echo -e "${BLUE}================================${NC}"
|
||||
echo ""
|
||||
|
||||
# 检查服务器是否运行
|
||||
echo -e "${YELLOW}检查服务器状态...${NC}"
|
||||
if curl -s "${SERVER_API}/health" > /dev/null 2>&1; then
|
||||
echo -e "${GREEN}✓ 服务器运行正常${NC}"
|
||||
else
|
||||
echo -e "${RED}✗ 服务器未响应${NC}"
|
||||
echo -e "${YELLOW}请确保服务器已启动:${NC}"
|
||||
echo " cd ../src && make run-server"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# 检查客户端是否连接(通过创建一个测试映射来验证)
|
||||
echo -e "${YELLOW}检查隧道连接...${NC}"
|
||||
if curl -s -X POST "${SERVER_API}/api/mapping/create" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"port": 39999}' | grep -q "success"; then
|
||||
echo -e "${GREEN}✓ 隧道连接正常${NC}"
|
||||
# 清理测试映射
|
||||
curl -s -X POST "${SERVER_API}/api/mapping/remove" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"port": 39999}' > /dev/null 2>&1 || true
|
||||
else
|
||||
echo -e "${YELLOW}⚠ 无法验证隧道连接,测试可能失败${NC}"
|
||||
echo -e "${YELLOW}请确保客户端已启动:${NC}"
|
||||
echo " cd ../src && make run-client"
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo -e "${BLUE}开始运行测试...${NC}"
|
||||
echo ""
|
||||
|
||||
# 切换到测试目录
|
||||
cd "$(dirname "$0")"
|
||||
|
||||
# 运行测试
|
||||
if [ "$1" = "verbose" ] || [ "$1" = "-v" ]; then
|
||||
go test -v -timeout ${TEST_TIMEOUT}
|
||||
elif [ -n "$1" ]; then
|
||||
# 运行特定测试
|
||||
echo -e "${BLUE}运行测试: $1${NC}"
|
||||
go test -v -timeout ${TEST_TIMEOUT} -run "$1"
|
||||
else
|
||||
go test -timeout ${TEST_TIMEOUT}
|
||||
fi
|
||||
|
||||
TEST_EXIT_CODE=$?
|
||||
|
||||
echo ""
|
||||
if [ $TEST_EXIT_CODE -eq 0 ]; then
|
||||
echo -e "${GREEN}================================${NC}"
|
||||
echo -e "${GREEN}✓ 所有测试通过${NC}"
|
||||
echo -e "${GREEN}================================${NC}"
|
||||
else
|
||||
echo -e "${RED}================================${NC}"
|
||||
echo -e "${RED}✗ 测试失败${NC}"
|
||||
echo -e "${RED}================================${NC}"
|
||||
fi
|
||||
|
||||
exit $TEST_EXIT_CODE
|
||||
|
|
@ -0,0 +1,118 @@
|
|||
#!/bin/bash
|
||||
|
||||
# 简化的集成测试运行脚本
|
||||
# 说明:此脚本用于逐个运行集成测试,避免端口冲突
|
||||
|
||||
set -e
|
||||
|
||||
echo "========================================="
|
||||
echo "集成测试开始"
|
||||
echo "========================================="
|
||||
|
||||
# 获取脚本所在目录
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
cd "$SCRIPT_DIR"
|
||||
|
||||
# 清理函数
|
||||
cleanup_ports() {
|
||||
echo "清理测试端口..."
|
||||
for port in 30001 30002 30003 30004 30005 30006 30010 30011 30012; do
|
||||
# 尝试清理可能残留的端口映射
|
||||
curl -s -X POST http://localhost:8080/api/mapping/remove \
|
||||
-H "Content-Type: application/json" \
|
||||
-d "{\"port\": $port}" > /dev/null 2>&1 || true
|
||||
done
|
||||
|
||||
# 等待端口完全释放
|
||||
sleep 3
|
||||
}
|
||||
|
||||
# 测试前清理
|
||||
echo "测试前清理..."
|
||||
cleanup_ports
|
||||
|
||||
# 检查服务器是否运行
|
||||
echo "检查服务器状态..."
|
||||
if ! curl -s http://localhost:8080/health > /dev/null 2>&1; then
|
||||
echo "错误:服务器未运行,请先启动服务器"
|
||||
echo "运行: cd .. && make run-server"
|
||||
exit 1
|
||||
fi
|
||||
echo "✓ 服务器正在运行"
|
||||
|
||||
# 检查客户端是否连接
|
||||
echo "检查客户端连接..."
|
||||
MAPPING_RESPONSE=$(curl -s http://localhost:8080/api/mapping/list)
|
||||
if [ -z "$MAPPING_RESPONSE" ]; then
|
||||
echo "警告:无法获取映射列表"
|
||||
else
|
||||
echo "✓ 可以访问服务器 API"
|
||||
fi
|
||||
|
||||
# 运行测试
|
||||
echo "========================================="
|
||||
echo "运行集成测试(逐个测试)..."
|
||||
echo "========================================="
|
||||
|
||||
TEST_FAILED=0
|
||||
|
||||
# 逐个运行测试,避免端口冲突
|
||||
echo ""
|
||||
echo "测试 1: TestForwardingBasic"
|
||||
echo "-----------------------------------------"
|
||||
if go test -v -timeout 30s -run TestForwardingBasic; then
|
||||
echo "✓ TestForwardingBasic 通过"
|
||||
else
|
||||
echo "✗ TestForwardingBasic 失败"
|
||||
TEST_FAILED=1
|
||||
fi
|
||||
cleanup_ports
|
||||
|
||||
echo ""
|
||||
echo "测试 2: TestMultipleForwards"
|
||||
echo "-----------------------------------------"
|
||||
if go test -v -timeout 30s -run TestMultipleForwards; then
|
||||
echo "✓ TestMultipleForwards 通过"
|
||||
else
|
||||
echo "✗ TestMultipleForwards 失败"
|
||||
TEST_FAILED=1
|
||||
fi
|
||||
cleanup_ports
|
||||
|
||||
echo ""
|
||||
echo "测试 3: TestAddAndRemoveMapping"
|
||||
echo "-----------------------------------------"
|
||||
if go test -v -timeout 30s -run TestAddAndRemoveMapping; then
|
||||
echo "✓ TestAddAndRemoveMapping 通过"
|
||||
else
|
||||
echo "✗ TestAddAndRemoveMapping 失败"
|
||||
TEST_FAILED=1
|
||||
fi
|
||||
cleanup_ports
|
||||
|
||||
echo ""
|
||||
echo "测试 4: TestConcurrentRequests"
|
||||
echo "-----------------------------------------"
|
||||
if go test -v -timeout 30s -run TestConcurrentRequests; then
|
||||
echo "✓ TestConcurrentRequests 通过"
|
||||
else
|
||||
echo "✗ TestConcurrentRequests 失败"
|
||||
TEST_FAILED=1
|
||||
fi
|
||||
cleanup_ports
|
||||
|
||||
# 测试后清理
|
||||
echo ""
|
||||
echo "测试后清理..."
|
||||
cleanup_ports
|
||||
|
||||
echo "========================================="
|
||||
if [ $TEST_FAILED -eq 0 ]; then
|
||||
echo "✓ 所有测试通过"
|
||||
echo "========================================="
|
||||
exit 0
|
||||
else
|
||||
echo "✗ 部分测试失败"
|
||||
echo "========================================="
|
||||
exit 1
|
||||
fi
|
||||
Loading…
Reference in New Issue