feat: 重构了tunnel,现在统一了转发逻辑
This commit is contained in:
parent
2655b5592f
commit
fc1614f7a4
|
|
@ -16,7 +16,7 @@ import (
|
|||
|
||||
// TunnelServer 隧道服务器接口
|
||||
type TunnelServer interface {
|
||||
ForwardConnection(clientConn net.Conn, targetIP string, targetPort int) error
|
||||
ForwardConnection(clientConn net.Conn, targetIP string, targetPort int) (net.Conn, error)
|
||||
IsConnected() bool
|
||||
GetTrafficStats() stats.TrafficStats
|
||||
}
|
||||
|
|
@ -46,7 +46,13 @@ func NewForwarder(sourcePort int, targetHost string, targetPort int, limit *int6
|
|||
ctx, cancel := context.WithCancel(context.Background())
|
||||
var limiterOut, limiterIn *rate.Limiter
|
||||
if limit != nil {
|
||||
burst := int(*limit) // 容量至少等于速率,不然无法正常突发
|
||||
// burst设置为1秒的流量,这样可以平滑处理突发
|
||||
// 同时不会一次性消耗太多令牌
|
||||
burst := int(*limit) / 100
|
||||
if burst < 10240 {
|
||||
burst = 10240 // 最小burst为10KB
|
||||
}
|
||||
log.Printf("设置限速: %d bytes/sec, burst: %d bytes", *limit, burst)
|
||||
limiterOut = rate.NewLimiter(rate.Limit(*limit), burst)
|
||||
limiterIn = rate.NewLimiter(rate.Limit(*limit), burst)
|
||||
}
|
||||
|
|
@ -68,7 +74,13 @@ func NewTunnelForwarder(sourcePort int, targetHost string, targetPort int, tunne
|
|||
ctx, cancel := context.WithCancel(context.Background())
|
||||
var limiterOut, limiterIn *rate.Limiter
|
||||
if limit != nil {
|
||||
burst := int(*limit) // 容量至少等于速率,不然无法正常突发
|
||||
// burst设置为1秒的流量,这样可以平滑处理突发
|
||||
// 同时不会一次性消耗太多令牌
|
||||
burst := int(*limit) / 100
|
||||
if burst < 10240 {
|
||||
burst = 10240 // 最小burst为10KB
|
||||
}
|
||||
log.Printf("设置限速: %d bytes/sec, burst: %d bytes", *limit, burst)
|
||||
limiterOut = rate.NewLimiter(rate.Limit(*limit), burst)
|
||||
limiterIn = rate.NewLimiter(rate.Limit(*limit), burst)
|
||||
}
|
||||
|
|
@ -142,54 +154,76 @@ type rateLimitedReader struct {
|
|||
}
|
||||
|
||||
func (rlr *rateLimitedReader) Read(p []byte) (int, error) {
|
||||
if rlr.limiter != nil {
|
||||
maxReq := rlr.limiter.Burst()
|
||||
reqSize := len(p)
|
||||
if reqSize > maxReq {
|
||||
reqSize = maxReq // 避免一次申请超过桶容量导致错误
|
||||
}
|
||||
err := rlr.limiter.WaitN(rlr.ctx, reqSize)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if rlr.limiter == nil {
|
||||
return rlr.r.Read(p)
|
||||
}
|
||||
return rlr.r.Read(p)
|
||||
|
||||
// 使用更小的块大小以实现更平滑的限流
|
||||
// 2KB是一个合理的值,既不会太频繁调用,也能保持流量平滑
|
||||
chunkSize := 2048
|
||||
|
||||
// 不超过缓冲区大小
|
||||
if len(p) < chunkSize {
|
||||
chunkSize = len(p)
|
||||
}
|
||||
|
||||
// 预先申请令牌
|
||||
if err := rlr.limiter.WaitN(rlr.ctx, chunkSize); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// 限制实际读取大小
|
||||
if len(p) > chunkSize {
|
||||
p = p[:chunkSize]
|
||||
}
|
||||
|
||||
// 执行读取
|
||||
n, err := rlr.r.Read(p)
|
||||
|
||||
// 如果实际读取少于申请的令牌,归还多余的令牌
|
||||
// 注意:rate.Limiter 不支持归还令牌,所以这里只能接受这个损耗
|
||||
// 或者改用先读后限的方式
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
// handleConnection 处理单个连接
|
||||
func (f *Forwarder) handleConnection(clientConn net.Conn) {
|
||||
defer f.wg.Done()
|
||||
defer clientConn.Close()
|
||||
|
||||
var targetConn net.Conn
|
||||
var err error
|
||||
|
||||
if f.useTunnel {
|
||||
// 使用隧道转发
|
||||
// 使用隧道转发,获取 TunnelConn
|
||||
if f.tunnelServer == nil || !f.tunnelServer.IsConnected() {
|
||||
log.Printf("隧道服务器不可用 (端口 %d)", f.sourcePort)
|
||||
clientConn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
// 将连接转发到隧道,ForwardConnection 会处理连接关闭
|
||||
if err := f.tunnelServer.ForwardConnection(clientConn, f.targetHost, f.targetPort); err != nil {
|
||||
// 获取隧道连接
|
||||
targetConn, err = f.tunnelServer.ForwardConnection(clientConn, f.targetHost, f.targetPort)
|
||||
if err != nil {
|
||||
log.Printf("隧道转发失败 (端口 %d -> %s:%d): %v", f.sourcePort, f.targetHost, f.targetPort, err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// 直接连接目标
|
||||
dialer := &net.Dialer{
|
||||
Timeout: 10 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}
|
||||
|
||||
// 动态解析域名并连接
|
||||
targetAddr := fmt.Sprintf("%s:%d", f.targetHost, f.targetPort)
|
||||
targetConn, err = dialer.DialContext(f.ctx, "tcp", targetAddr)
|
||||
if err != nil {
|
||||
log.Printf("连接目标失败 (端口 %d -> %s): %v", f.sourcePort, targetAddr, err)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 直接连接目标
|
||||
defer clientConn.Close()
|
||||
|
||||
dialer := &net.Dialer{
|
||||
Timeout: 10 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}
|
||||
|
||||
// 动态解析域名并连接
|
||||
targetAddr := fmt.Sprintf("%s:%d", f.targetHost, f.targetPort)
|
||||
targetConn, err := dialer.DialContext(f.ctx, "tcp", targetAddr)
|
||||
if err != nil {
|
||||
log.Printf("连接目标失败 (端口 %d -> %s): %v", f.sourcePort, targetAddr, err)
|
||||
return
|
||||
}
|
||||
defer targetConn.Close()
|
||||
|
||||
// 双向转发
|
||||
|
|
|
|||
|
|
@ -39,10 +39,10 @@ type mockTunnelServer struct {
|
|||
connected bool
|
||||
}
|
||||
|
||||
func (m *mockTunnelServer) ForwardConnection(clientConn net.Conn, targetIp string, targetPort int) error {
|
||||
func (m *mockTunnelServer) ForwardConnection(clientConn net.Conn, targetIp string, targetPort int) (net.Conn, error) {
|
||||
// 简单的模拟实现
|
||||
defer clientConn.Close()
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockTunnelServer) IsConnected() bool {
|
||||
|
|
|
|||
|
|
@ -492,6 +492,22 @@
|
|||
let currentMappings = [];
|
||||
let apiKey = '';
|
||||
|
||||
// 格式化带宽限制为可读格式
|
||||
function formatBandwidth(bytes) {
|
||||
if (!bytes || bytes === 0) {
|
||||
return '无限制';
|
||||
}
|
||||
if (bytes < 1024) {
|
||||
return bytes + ' B/s';
|
||||
} else if (bytes < 1024 * 1024) {
|
||||
return (bytes / 1024).toFixed(2) + ' KB/s';
|
||||
} else if (bytes < 1024 * 1024 * 1024) {
|
||||
return (bytes / (1024 * 1024)).toFixed(2) + ' MB/s';
|
||||
} else {
|
||||
return (bytes / (1024 * 1024 * 1024)).toFixed(2) + ' GB/s';
|
||||
}
|
||||
}
|
||||
|
||||
// 检查是否已有 API Key
|
||||
function checkAuth() {
|
||||
const savedKey = sessionStorage.getItem('apiKey');
|
||||
|
|
@ -733,6 +749,7 @@
|
|||
'<span class="status-badge ' + (mapping.use_tunnel ? 'status-tunnel' : 'status-direct') + '">' +
|
||||
(mapping.use_tunnel ? '隧道模式' : '直连模式') +
|
||||
'</span><br>' +
|
||||
'<strong>带宽限制:</strong> ' + formatBandwidth(mapping.bandwidth_limit) + '<br>' +
|
||||
'<strong>创建时间:</strong> ' + new Date(mapping.created_at).toLocaleString('zh-CN') +
|
||||
'</div>' +
|
||||
'</div>' +
|
||||
|
|
@ -800,7 +817,8 @@
|
|||
'<strong>目标:</strong> ' + mapping.target_host + ':' + mapping.target_port + ' | ' +
|
||||
'<span class="status-badge ' + (mapping.use_tunnel ? 'status-tunnel' : 'status-direct') + '">' +
|
||||
(mapping.use_tunnel ? '隧道' : '直连') +
|
||||
'</span>' +
|
||||
'</span> | ' +
|
||||
'<strong>带宽:</strong> ' + formatBandwidth(mapping.bandwidth_limit) +
|
||||
'</div>' +
|
||||
'</div>' +
|
||||
'</div>'
|
||||
|
|
|
|||
|
|
@ -20,28 +20,28 @@ import (
|
|||
const (
|
||||
// 协议版本
|
||||
ProtocolVersion = 0x01
|
||||
|
||||
|
||||
// 消息头大小
|
||||
HeaderSize = 6 // 版本(1) + 类型(1) + 长度(4)
|
||||
|
||||
|
||||
// 最大包大小 (1MB)
|
||||
MaxPacketSize = 1024 * 1024
|
||||
|
||||
|
||||
// 消息类型
|
||||
MsgTypeConnectRequest = 0x01 // 连接请求
|
||||
MsgTypeConnectResponse = 0x02 // 连接响应
|
||||
MsgTypeData = 0x03 // 数据传输
|
||||
MsgTypeClose = 0x04 // 关闭连接
|
||||
MsgTypeKeepAlive = 0x05 // 心跳
|
||||
|
||||
MsgTypeData = 0x03 // 数据传输
|
||||
MsgTypeClose = 0x04 // 关闭连接
|
||||
MsgTypeKeepAlive = 0x05 // 心跳
|
||||
|
||||
// 连接响应状态
|
||||
ConnStatusSuccess = 0x00 // 连接成功
|
||||
ConnStatusFailed = 0x01 // 连接失败
|
||||
|
||||
|
||||
// 超时设置
|
||||
ConnectTimeout = 10 * time.Second // 连接超时
|
||||
ReadTimeout = 300 * time.Second // 读取超时,统一为60秒
|
||||
KeepAliveInterval = 15 * time.Second // 心跳间隔
|
||||
ConnectTimeout = 10 * time.Second // 连接超时
|
||||
ReadTimeout = 300 * time.Second // 读取超时,统一为60秒
|
||||
KeepAliveInterval = 15 * time.Second // 心跳间隔
|
||||
)
|
||||
|
||||
// TunnelMessage 隧道消息
|
||||
|
|
@ -94,7 +94,115 @@ type ActiveConnection struct {
|
|||
TargetHost string // 支持IP或域名
|
||||
TargetIP string
|
||||
Created time.Time
|
||||
Closing int32 // 原子操作标志,表示连接正在关闭
|
||||
Closing int32 // 原子操作标志,表示连接正在关闭
|
||||
RecvChan chan []byte // 接收数据的通道
|
||||
}
|
||||
|
||||
// TunnelConn 实现 net.Conn 接口的隧道连接
|
||||
type TunnelConn struct {
|
||||
server *Server
|
||||
connID uint32
|
||||
targetHost string
|
||||
targetPort int
|
||||
recvChan chan []byte // 接收数据的通道
|
||||
readBuffer []byte // 读取缓冲区
|
||||
closed int32 // 原子操作标志,表示连接已关闭
|
||||
closeOnce sync.Once // 确保只关闭一次
|
||||
}
|
||||
|
||||
// Read 实现 net.Conn 接口
|
||||
func (tc *TunnelConn) Read(b []byte) (n int, err error) {
|
||||
if atomic.LoadInt32(&tc.closed) == 1 {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
// 如果有缓冲数据,先返回缓冲数据
|
||||
if len(tc.readBuffer) > 0 {
|
||||
n = copy(b, tc.readBuffer)
|
||||
tc.readBuffer = tc.readBuffer[n:]
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// 从通道读取数据
|
||||
select {
|
||||
case data, ok := <-tc.recvChan:
|
||||
if !ok {
|
||||
return 0, io.EOF
|
||||
}
|
||||
n = copy(b, data)
|
||||
// 如果数据太多,保存剩余部分到缓冲区
|
||||
if n < len(data) {
|
||||
tc.readBuffer = data[n:]
|
||||
}
|
||||
return n, nil
|
||||
case <-tc.server.ctx.Done():
|
||||
return 0, io.EOF
|
||||
}
|
||||
}
|
||||
|
||||
// Write 实现 net.Conn 接口
|
||||
func (tc *TunnelConn) Write(b []byte) (n int, err error) {
|
||||
if atomic.LoadInt32(&tc.closed) == 1 {
|
||||
return 0, fmt.Errorf("connection closed")
|
||||
}
|
||||
|
||||
// 发送数据到隧道
|
||||
dataMsg := make([]byte, 4+len(b))
|
||||
binary.BigEndian.PutUint32(dataMsg[0:4], tc.connID)
|
||||
copy(dataMsg[4:], b)
|
||||
|
||||
msg := &TunnelMessage{
|
||||
Version: ProtocolVersion,
|
||||
Type: MsgTypeData,
|
||||
Length: uint32(len(dataMsg)),
|
||||
Data: dataMsg,
|
||||
}
|
||||
|
||||
select {
|
||||
case tc.server.sendChan <- msg:
|
||||
return len(b), nil
|
||||
case <-time.After(2 * time.Second):
|
||||
return 0, fmt.Errorf("write timeout")
|
||||
case <-tc.server.ctx.Done():
|
||||
return 0, fmt.Errorf("server closed")
|
||||
}
|
||||
}
|
||||
|
||||
// Close 实现 net.Conn 接口
|
||||
func (tc *TunnelConn) Close() error {
|
||||
tc.closeOnce.Do(func() {
|
||||
atomic.StoreInt32(&tc.closed, 1)
|
||||
tc.server.closeConnection(tc.connID)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// LocalAddr 实现 net.Conn 接口
|
||||
func (tc *TunnelConn) LocalAddr() net.Addr {
|
||||
return &net.TCPAddr{IP: net.IPv4zero, Port: 0}
|
||||
}
|
||||
|
||||
// RemoteAddr 实现 net.Conn 接口
|
||||
func (tc *TunnelConn) RemoteAddr() net.Addr {
|
||||
return &net.TCPAddr{IP: net.ParseIP(tc.targetHost), Port: tc.targetPort}
|
||||
}
|
||||
|
||||
// SetDeadline 实现 net.Conn 接口
|
||||
func (tc *TunnelConn) SetDeadline(t time.Time) error {
|
||||
// 隧道连接不支持 deadline,可以根据需要实现
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetReadDeadline 实现 net.Conn 接口
|
||||
func (tc *TunnelConn) SetReadDeadline(t time.Time) error {
|
||||
// 隧道连接不支持 deadline,可以根据需要实现
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetWriteDeadline 实现 net.Conn 接口
|
||||
func (tc *TunnelConn) SetWriteDeadline(t time.Time) error {
|
||||
// 隧道连接不支持 deadline,可以根据需要实现
|
||||
return nil
|
||||
}
|
||||
|
||||
// Server 内网穿透服务器
|
||||
|
|
@ -106,17 +214,17 @@ type Server struct {
|
|||
ctx context.Context
|
||||
wg sync.WaitGroup
|
||||
mu sync.RWMutex
|
||||
|
||||
|
||||
// 连接管理
|
||||
pendingConns map[uint32]*PendingConnection // 待确认连接
|
||||
activeConns map[uint32]*ActiveConnection // 活跃连接
|
||||
closingConns map[uint32]time.Time // 正在关闭的连接,避免重复处理
|
||||
connMu sync.RWMutex
|
||||
nextConnID uint32
|
||||
|
||||
|
||||
// 消息队列
|
||||
sendChan chan *TunnelMessage
|
||||
|
||||
|
||||
// 流量统计(使用原子操作)
|
||||
bytesSent uint64 // 通过隧道发送的总字节数
|
||||
bytesReceived uint64 // 通过隧道接收的总字节数
|
||||
|
|
@ -134,10 +242,10 @@ func NewServer(listenPort int) *Server {
|
|||
closingConns: make(map[uint32]time.Time),
|
||||
sendChan: make(chan *TunnelMessage, 10000), // 增加到10000
|
||||
}
|
||||
|
||||
|
||||
// 启动清理器,定期清理过期的关闭连接记录
|
||||
go server.cleanupClosingConns()
|
||||
|
||||
|
||||
return server
|
||||
}
|
||||
|
||||
|
|
@ -211,7 +319,7 @@ func (s *Server) handleTunnelRead(conn net.Conn) {
|
|||
s.tunnelConn = nil
|
||||
s.mu.Unlock()
|
||||
log.Printf("隧道客户端已断开")
|
||||
|
||||
|
||||
// 关闭所有活动连接
|
||||
s.connMu.Lock()
|
||||
for _, c := range s.pendingConns {
|
||||
|
|
@ -293,7 +401,7 @@ func (s *Server) readTunnelMessage(conn net.Conn) (*TunnelMessage, error) {
|
|||
if _, err := io.ReadFull(conn, header); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
// 统计接收字节数
|
||||
s.addBytesReceived(uint64(HeaderSize))
|
||||
|
||||
|
|
@ -333,7 +441,7 @@ func (s *Server) writeTunnelMessage(conn net.Conn, msg *TunnelMessage) error {
|
|||
// 设置写入超时,防止阻塞
|
||||
conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||||
defer conn.SetWriteDeadline(time.Time{}) // 重置超时
|
||||
|
||||
|
||||
// 构建消息头
|
||||
header := make([]byte, HeaderSize)
|
||||
header[0] = msg.Version
|
||||
|
|
@ -344,7 +452,7 @@ func (s *Server) writeTunnelMessage(conn net.Conn, msg *TunnelMessage) error {
|
|||
if _, err := conn.Write(header); err != nil {
|
||||
return fmt.Errorf("写入消息头失败: %w", err)
|
||||
}
|
||||
|
||||
|
||||
// 统计发送字节数
|
||||
s.addBytesSent(uint64(HeaderSize))
|
||||
|
||||
|
|
@ -398,13 +506,16 @@ func (s *Server) handleConnectResponse(msg *TunnelMessage) {
|
|||
s.connMu.Unlock()
|
||||
|
||||
if status == ConnStatusSuccess {
|
||||
// 连接成功,移到活跃连接
|
||||
// 连接成功,创建接收通道
|
||||
recvChan := make(chan []byte, 100)
|
||||
|
||||
active := &ActiveConnection{
|
||||
ID: connID,
|
||||
ClientConn: pending.ClientConn,
|
||||
TargetPort: pending.TargetPort,
|
||||
TargetHost: pending.TargetHost,
|
||||
Created: time.Now(),
|
||||
RecvChan: recvChan,
|
||||
}
|
||||
|
||||
s.connMu.Lock()
|
||||
|
|
@ -413,10 +524,6 @@ func (s *Server) handleConnectResponse(msg *TunnelMessage) {
|
|||
|
||||
log.Printf("连接已建立: ID=%d, 地址=%s:%d", connID, pending.TargetHost, pending.TargetPort)
|
||||
|
||||
// 启动数据转发
|
||||
s.wg.Add(1)
|
||||
go s.forwardData(active)
|
||||
|
||||
// 通知等待的goroutine
|
||||
select {
|
||||
case pending.ResponseChan <- true:
|
||||
|
|
@ -463,7 +570,7 @@ func (s *Server) handleDataMessage(msg *TunnelMessage) {
|
|||
s.connMu.Lock()
|
||||
s.closingConns[connID] = time.Now()
|
||||
s.connMu.Unlock()
|
||||
|
||||
|
||||
// 连接不存在,发送关闭消息通知对端
|
||||
closeData := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(closeData, connID)
|
||||
|
|
@ -486,10 +593,18 @@ func (s *Server) handleDataMessage(msg *TunnelMessage) {
|
|||
return
|
||||
}
|
||||
|
||||
// 写入到客户端连接
|
||||
if _, err := active.ClientConn.Write(data); err != nil {
|
||||
log.Printf("写入客户端连接失败 (ID=%d): %v", connID, err)
|
||||
// 发送数据到接收通道
|
||||
dataCopy := make([]byte, len(data))
|
||||
copy(dataCopy, data)
|
||||
|
||||
select {
|
||||
case active.RecvChan <- dataCopy:
|
||||
// 数据已发送到接收通道
|
||||
case <-time.After(2 * time.Second):
|
||||
log.Printf("发送数据到接收通道超时 (ID=%d)", connID)
|
||||
s.closeConnection(connID)
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -527,162 +642,26 @@ func (s *Server) handleKeepAlive(msg *TunnelMessage) {
|
|||
}
|
||||
}
|
||||
|
||||
// forwardData 转发数据
|
||||
func (s *Server) forwardData(active *ActiveConnection) {
|
||||
defer s.wg.Done()
|
||||
defer s.closeConnection(active.ID)
|
||||
|
||||
buffer := make([]byte, 32*1024)
|
||||
for {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// 设置读取超时
|
||||
active.ClientConn.SetReadDeadline(time.Now().Add(ReadTimeout))
|
||||
n, err := active.ClientConn.Read(buffer)
|
||||
|
||||
if err != nil {
|
||||
// 检查是否是正在关闭的连接,避免记录无关错误
|
||||
if atomic.LoadInt32(&active.Closing) == 1 {
|
||||
return // 连接正在关闭,正常退出
|
||||
}
|
||||
|
||||
// 任何错误都应该终止转发,包括超时
|
||||
if err == io.EOF {
|
||||
log.Printf("客户端连接正常关闭 (ID=%d)", active.ID)
|
||||
} else if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
log.Printf("客户端连接超时 (ID=%d)", active.ID)
|
||||
} else {
|
||||
// 只记录非关闭相关的错误
|
||||
if !isConnectionClosed(err) {
|
||||
log.Printf("读取客户端连接失败 (ID=%d): %v", active.ID, err)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 读取到0字节,连接已关闭
|
||||
if n == 0 {
|
||||
log.Printf("客户端连接已关闭 (ID=%d, 读取0字节)", active.ID)
|
||||
return
|
||||
}
|
||||
|
||||
// 重置读取超时
|
||||
active.ClientConn.SetReadDeadline(time.Time{})
|
||||
|
||||
// 检查连接是否正在关闭
|
||||
if atomic.LoadInt32(&active.Closing) == 1 {
|
||||
return
|
||||
}
|
||||
|
||||
// 发送数据到隧道
|
||||
dataMsg := make([]byte, 4+n)
|
||||
binary.BigEndian.PutUint32(dataMsg[0:4], active.ID)
|
||||
copy(dataMsg[4:], buffer[:n])
|
||||
|
||||
msg := &TunnelMessage{
|
||||
Version: ProtocolVersion,
|
||||
Type: MsgTypeData,
|
||||
Length: uint32(len(dataMsg)),
|
||||
Data: dataMsg,
|
||||
}
|
||||
|
||||
select {
|
||||
case s.sendChan <- msg:
|
||||
// 数据已发送
|
||||
case <-time.After(2 * time.Second): // 减少超时时间
|
||||
queueLen := len(s.sendChan)
|
||||
log.Printf("发送数据超时 (ID=%d), 队列长度: %d/10000", active.ID, queueLen)
|
||||
return
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// closeConnection 关闭连接
|
||||
func (s *Server) closeConnection(connID uint32) {
|
||||
s.connMu.Lock()
|
||||
active, exists := s.activeConns[connID]
|
||||
if exists {
|
||||
// 使用原子操作标记连接正在关闭
|
||||
if !atomic.CompareAndSwapInt32(&active.Closing, 0, 1) {
|
||||
// 连接已经在关闭中,避免重复处理
|
||||
s.connMu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
delete(s.activeConns, connID)
|
||||
// 记录关闭时间,避免重复发送关闭消息
|
||||
s.closingConns[connID] = time.Now()
|
||||
|
||||
// 确保连接被关闭
|
||||
if active.ClientConn != nil {
|
||||
active.ClientConn.Close()
|
||||
}
|
||||
}
|
||||
s.connMu.Unlock()
|
||||
|
||||
if !exists {
|
||||
// 连接不存在,检查是否已经在关闭列表中
|
||||
s.connMu.RLock()
|
||||
_, isClosing := s.closingConns[connID]
|
||||
s.connMu.RUnlock()
|
||||
|
||||
if isClosing {
|
||||
return // 已经处理过了
|
||||
}
|
||||
|
||||
// 标记为正在关闭
|
||||
s.connMu.Lock()
|
||||
s.closingConns[connID] = time.Now()
|
||||
s.connMu.Unlock()
|
||||
}
|
||||
|
||||
// 发送关闭消息
|
||||
closeData := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(closeData, connID)
|
||||
|
||||
msg := &TunnelMessage{
|
||||
Version: ProtocolVersion,
|
||||
Type: MsgTypeClose,
|
||||
Length: 4,
|
||||
Data: closeData,
|
||||
}
|
||||
|
||||
select {
|
||||
case s.sendChan <- msg:
|
||||
log.Printf("连接已关闭: ID=%d", connID)
|
||||
case <-time.After(1 * time.Second):
|
||||
log.Printf("发送关闭消息超时: ID=%d", connID)
|
||||
case <-s.ctx.Done():
|
||||
log.Printf("服务器关闭,跳过发送关闭消息: ID=%d", connID)
|
||||
}
|
||||
}
|
||||
|
||||
// ForwardConnection 转发连接到隧道(新的透明代理实现)
|
||||
func (s *Server) ForwardConnection(clientConn net.Conn, targetHost string, targetPort int) error {
|
||||
// ForwardConnection 创建隧道连接(返回 net.Conn 接口)
|
||||
func (s *Server) ForwardConnection(clientConn net.Conn, targetHost string, targetPort int) (net.Conn, error) {
|
||||
s.mu.RLock()
|
||||
tunnelConnected := s.tunnelConn != nil
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !tunnelConnected {
|
||||
return fmt.Errorf("隧道连接不可用")
|
||||
return nil, fmt.Errorf("隧道连接不可用")
|
||||
}
|
||||
|
||||
// 创建待处理连接
|
||||
s.connMu.Lock()
|
||||
connID := s.nextConnID
|
||||
s.nextConnID++
|
||||
|
||||
|
||||
pending := &PendingConnection{
|
||||
ID: connID,
|
||||
ClientConn: clientConn,
|
||||
TargetPort: targetPort,
|
||||
TargetHost: targetHost, // 现在支持域名
|
||||
TargetHost: targetHost,
|
||||
Created: time.Now(),
|
||||
ResponseChan: make(chan bool, 1),
|
||||
}
|
||||
|
|
@ -711,7 +690,7 @@ func (s *Server) ForwardConnection(clientConn net.Conn, targetHost string, targe
|
|||
s.connMu.Lock()
|
||||
delete(s.pendingConns, connID)
|
||||
s.connMu.Unlock()
|
||||
return fmt.Errorf("发送连接请求超时")
|
||||
return nil, fmt.Errorf("发送连接请求超时")
|
||||
}
|
||||
|
||||
log.Printf("发送连接请求: ID=%d, 地址=%s:%d", connID, targetHost, targetPort)
|
||||
|
|
@ -720,18 +699,95 @@ func (s *Server) ForwardConnection(clientConn net.Conn, targetHost string, targe
|
|||
select {
|
||||
case success := <-pending.ResponseChan:
|
||||
if success {
|
||||
return nil // 连接建立成功,forwardData会处理后续的数据转发
|
||||
// 连接建立成功,获取活动连接并创建 TunnelConn
|
||||
s.connMu.RLock()
|
||||
active, exists := s.activeConns[connID]
|
||||
s.connMu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("活动连接不存在")
|
||||
}
|
||||
|
||||
// 创建 TunnelConn
|
||||
tunnelConn := &TunnelConn{
|
||||
server: s,
|
||||
connID: connID,
|
||||
targetHost: targetHost,
|
||||
targetPort: targetPort,
|
||||
recvChan: active.RecvChan,
|
||||
}
|
||||
|
||||
return tunnelConn, nil
|
||||
} else {
|
||||
return fmt.Errorf("远程连接失败")
|
||||
return nil, fmt.Errorf("远程连接失败")
|
||||
}
|
||||
case <-time.After(ConnectTimeout):
|
||||
s.connMu.Lock()
|
||||
delete(s.pendingConns, connID)
|
||||
s.connMu.Unlock()
|
||||
clientConn.Close()
|
||||
return fmt.Errorf("连接超时")
|
||||
return nil, fmt.Errorf("连接超时")
|
||||
case <-s.ctx.Done():
|
||||
return fmt.Errorf("服务器关闭")
|
||||
return nil, fmt.Errorf("服务器关闭")
|
||||
}
|
||||
}
|
||||
|
||||
// closeConnection 关闭连接
|
||||
func (s *Server) closeConnection(connID uint32) {
|
||||
s.connMu.Lock()
|
||||
active, exists := s.activeConns[connID]
|
||||
if exists {
|
||||
// 使用原子操作标记连接正在关闭
|
||||
if !atomic.CompareAndSwapInt32(&active.Closing, 0, 1) {
|
||||
// 连接已经在关闭中,避免重复处理
|
||||
s.connMu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
delete(s.activeConns, connID)
|
||||
// 记录关闭时间,避免重复发送关闭消息
|
||||
s.closingConns[connID] = time.Now()
|
||||
|
||||
// 关闭接收通道
|
||||
if active.RecvChan != nil {
|
||||
close(active.RecvChan)
|
||||
}
|
||||
}
|
||||
s.connMu.Unlock()
|
||||
|
||||
if !exists {
|
||||
// 连接不存在,检查是否已经在关闭列表中
|
||||
s.connMu.RLock()
|
||||
_, isClosing := s.closingConns[connID]
|
||||
s.connMu.RUnlock()
|
||||
|
||||
if isClosing {
|
||||
return // 已经处理过了
|
||||
}
|
||||
|
||||
// 标记为正在关闭
|
||||
s.connMu.Lock()
|
||||
s.closingConns[connID] = time.Now()
|
||||
s.connMu.Unlock()
|
||||
}
|
||||
|
||||
// 发送关闭消息
|
||||
closeData := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(closeData, connID)
|
||||
|
||||
msg := &TunnelMessage{
|
||||
Version: ProtocolVersion,
|
||||
Type: MsgTypeClose,
|
||||
Length: 4,
|
||||
Data: closeData,
|
||||
}
|
||||
|
||||
select {
|
||||
case s.sendChan <- msg:
|
||||
log.Printf("连接已关闭: ID=%d", connID)
|
||||
case <-time.After(1 * time.Second):
|
||||
log.Printf("发送关闭消息超时: ID=%d", connID)
|
||||
case <-s.ctx.Done():
|
||||
log.Printf("服务器关闭,跳过发送关闭消息: ID=%d", connID)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -784,7 +840,7 @@ func isTimeout(err error) bool {
|
|||
// keepAliveLoop 心跳循环
|
||||
func (s *Server) keepAliveLoop(conn net.Conn) {
|
||||
defer s.wg.Done()
|
||||
|
||||
|
||||
ticker := time.NewTicker(KeepAliveInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
|
|
@ -848,31 +904,31 @@ func (s *Server) cleanupClosingConns() {
|
|||
case <-ticker.C:
|
||||
now := time.Now()
|
||||
s.connMu.Lock()
|
||||
|
||||
|
||||
// 按时间清理过期记录
|
||||
for connID, closeTime := range s.closingConns {
|
||||
if now.Sub(closeTime) > maxAge {
|
||||
delete(s.closingConns, connID)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 如果记录数量仍然过多,删除最旧的记录
|
||||
if len(s.closingConns) > maxClosingRecords {
|
||||
// 删除一半的最旧记录,避免频繁清理
|
||||
deleteCount := len(s.closingConns) - maxClosingRecords/2
|
||||
deletedCount := 0
|
||||
|
||||
|
||||
for connID, closeTime := range s.closingConns {
|
||||
if deletedCount >= deleteCount {
|
||||
break
|
||||
}
|
||||
if closeTime.Before(now.Add(-maxAge/2)) {
|
||||
if closeTime.Before(now.Add(-maxAge / 2)) {
|
||||
delete(s.closingConns, connID)
|
||||
deletedCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
s.connMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
|
@ -885,6 +941,6 @@ func isConnectionClosed(err error) bool {
|
|||
}
|
||||
errStr := err.Error()
|
||||
return strings.Contains(errStr, "use of closed network connection") ||
|
||||
strings.Contains(errStr, "connection reset by peer") ||
|
||||
strings.Contains(errStr, "broken pipe")
|
||||
}
|
||||
strings.Contains(errStr, "connection reset by peer") ||
|
||||
strings.Contains(errStr, "broken pipe")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,27 +12,27 @@ import (
|
|||
// TestNewServer 测试创建隧道服务器
|
||||
func TestNewServer(t *testing.T) {
|
||||
server := NewServer(9000)
|
||||
|
||||
|
||||
if server == nil {
|
||||
t.Fatal("创建隧道服务器失败")
|
||||
}
|
||||
|
||||
|
||||
if server.listenPort != 9000 {
|
||||
t.Errorf("监听端口不正确,期望 9000,得到 %d", server.listenPort)
|
||||
}
|
||||
|
||||
|
||||
if server.pendingConns == nil {
|
||||
t.Error("待处理连接映射未初始化")
|
||||
}
|
||||
|
||||
|
||||
if server.activeConns == nil {
|
||||
t.Error("活跃连接映射未初始化")
|
||||
}
|
||||
|
||||
|
||||
if server.sendChan == nil {
|
||||
t.Error("发送通道未初始化")
|
||||
}
|
||||
|
||||
|
||||
if server.ctx == nil {
|
||||
t.Error("上下文未初始化")
|
||||
}
|
||||
|
|
@ -47,17 +47,17 @@ func TestServerStartStop(t *testing.T) {
|
|||
}
|
||||
port := listener.Addr().(*net.TCPAddr).Port
|
||||
listener.Close()
|
||||
|
||||
|
||||
server := NewServer(port)
|
||||
|
||||
|
||||
err = server.Start()
|
||||
if err != nil {
|
||||
t.Fatalf("启动服务器失败: %v", err)
|
||||
}
|
||||
|
||||
|
||||
// 等待服务器启动
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
|
||||
// 验证服务器是否监听端口
|
||||
conn, err := net.Dial("tcp", server.listener.Addr().String())
|
||||
if err != nil {
|
||||
|
|
@ -65,7 +65,7 @@ func TestServerStartStop(t *testing.T) {
|
|||
} else {
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
|
||||
// 停止服务器
|
||||
err = server.Stop()
|
||||
if err != nil {
|
||||
|
|
@ -83,14 +83,14 @@ func TestTunnelMessage(t *testing.T) {
|
|||
Length: uint32(len(data)),
|
||||
Data: data,
|
||||
}
|
||||
|
||||
|
||||
// 创建模拟连接
|
||||
serverConn, clientConn := net.Pipe()
|
||||
defer serverConn.Close()
|
||||
defer clientConn.Close()
|
||||
|
||||
|
||||
server := NewServer(9000)
|
||||
|
||||
|
||||
// 测试写入消息
|
||||
go func() {
|
||||
err := server.writeTunnelMessage(serverConn, msg)
|
||||
|
|
@ -98,26 +98,26 @@ func TestTunnelMessage(t *testing.T) {
|
|||
t.Errorf("写入隧道消息失败: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
|
||||
// 测试读取消息
|
||||
receivedMsg, err := server.readTunnelMessage(clientConn)
|
||||
if err != nil {
|
||||
t.Fatalf("读取隧道消息失败: %v", err)
|
||||
}
|
||||
|
||||
|
||||
// 验证消息内容
|
||||
if receivedMsg.Version != msg.Version {
|
||||
t.Errorf("版本不匹配,期望 %d,得到 %d", msg.Version, receivedMsg.Version)
|
||||
}
|
||||
|
||||
|
||||
if receivedMsg.Type != msg.Type {
|
||||
t.Errorf("类型不匹配,期望 %d,得到 %d", msg.Type, receivedMsg.Type)
|
||||
}
|
||||
|
||||
|
||||
if receivedMsg.Length != msg.Length {
|
||||
t.Errorf("长度不匹配,期望 %d,得到 %d", msg.Length, receivedMsg.Length)
|
||||
}
|
||||
|
||||
|
||||
if string(receivedMsg.Data) != string(msg.Data) {
|
||||
t.Errorf("数据不匹配,期望 %s,得到 %s", string(msg.Data), string(receivedMsg.Data))
|
||||
}
|
||||
|
|
@ -131,9 +131,9 @@ func TestConnectRequest(t *testing.T) {
|
|||
t.Fatalf("启动测试服务器失败: %v", err)
|
||||
}
|
||||
defer testListener.Close()
|
||||
|
||||
|
||||
testPort := testListener.Addr().(*net.TCPAddr).Port
|
||||
|
||||
|
||||
// 启动一个简单的echo服务
|
||||
go func() {
|
||||
for {
|
||||
|
|
@ -147,41 +147,41 @@ func TestConnectRequest(t *testing.T) {
|
|||
}(conn)
|
||||
}
|
||||
}()
|
||||
|
||||
|
||||
// 模拟测试(实际测试需要更复杂的设置)
|
||||
t.Logf("测试服务器运行在端口: %d", testPort)
|
||||
|
||||
|
||||
// 创建隧道服务器,使用随机端口
|
||||
tunnelListener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("创建隧道监听器失败: %v", err)
|
||||
}
|
||||
defer tunnelListener.Close()
|
||||
|
||||
|
||||
tunnelPort := tunnelListener.Addr().(*net.TCPAddr).Port
|
||||
tunnelListener.Close() // 关闭以便服务器重新绑定
|
||||
|
||||
|
||||
server := NewServer(tunnelPort)
|
||||
|
||||
|
||||
err = server.Start()
|
||||
if err != nil {
|
||||
t.Fatalf("启动隧道服务器失败: %v", err)
|
||||
}
|
||||
defer server.Stop()
|
||||
|
||||
|
||||
// 等待服务器启动
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
|
||||
// 模拟客户端连接
|
||||
tunnelConn, err := net.Dial("tcp", server.listener.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("连接隧道服务器失败: %v", err)
|
||||
}
|
||||
defer tunnelConn.Close()
|
||||
|
||||
|
||||
// 等待隧道连接被处理
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
|
||||
// 验证隧道是否已连接
|
||||
if !server.IsConnected() {
|
||||
t.Error("隧道未连接")
|
||||
|
|
@ -191,28 +191,28 @@ func TestConnectRequest(t *testing.T) {
|
|||
// TestProtocolVersionCheck 测试协议版本检查
|
||||
func TestProtocolVersionCheck(t *testing.T) {
|
||||
server := NewServer(9000)
|
||||
|
||||
|
||||
// 创建模拟连接
|
||||
serverConn, clientConn := net.Pipe()
|
||||
defer serverConn.Close()
|
||||
defer clientConn.Close()
|
||||
|
||||
|
||||
// 发送错误版本的消息
|
||||
wrongVersionHeader := make([]byte, HeaderSize)
|
||||
wrongVersionHeader[0] = 0xFF // 错误版本
|
||||
wrongVersionHeader[1] = MsgTypeData
|
||||
binary.BigEndian.PutUint32(wrongVersionHeader[2:6], 0)
|
||||
|
||||
|
||||
go func() {
|
||||
clientConn.Write(wrongVersionHeader)
|
||||
}()
|
||||
|
||||
|
||||
// 尝试读取消息,应该返回错误
|
||||
_, err := server.readTunnelMessage(serverConn)
|
||||
if err == nil {
|
||||
t.Error("期望版本检查失败,但成功了")
|
||||
}
|
||||
|
||||
|
||||
if err.Error() != "不支持的协议版本: 255" {
|
||||
t.Errorf("错误消息不正确: %v", err)
|
||||
}
|
||||
|
|
@ -221,22 +221,22 @@ func TestProtocolVersionCheck(t *testing.T) {
|
|||
// TestMaxPacketSizeCheck 测试最大包大小检查
|
||||
func TestMaxPacketSizeCheck(t *testing.T) {
|
||||
server := NewServer(9000)
|
||||
|
||||
|
||||
// 创建模拟连接
|
||||
serverConn, clientConn := net.Pipe()
|
||||
defer serverConn.Close()
|
||||
defer clientConn.Close()
|
||||
|
||||
|
||||
// 发送超大包
|
||||
oversizedHeader := make([]byte, HeaderSize)
|
||||
oversizedHeader[0] = ProtocolVersion
|
||||
oversizedHeader[1] = MsgTypeData
|
||||
binary.BigEndian.PutUint32(oversizedHeader[2:6], MaxPacketSize+1)
|
||||
|
||||
|
||||
go func() {
|
||||
clientConn.Write(oversizedHeader)
|
||||
}()
|
||||
|
||||
|
||||
// 尝试读取消息,应该返回错误
|
||||
_, err := server.readTunnelMessage(serverConn)
|
||||
if err == nil {
|
||||
|
|
@ -253,41 +253,41 @@ func TestConcurrentConnections(t *testing.T) {
|
|||
}
|
||||
port := listener.Addr().(*net.TCPAddr).Port
|
||||
listener.Close()
|
||||
|
||||
|
||||
server := NewServer(port)
|
||||
|
||||
|
||||
err = server.Start()
|
||||
if err != nil {
|
||||
t.Fatalf("启动服务器失败: %v", err)
|
||||
}
|
||||
defer server.Stop()
|
||||
|
||||
|
||||
// 等待服务器启动
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
|
||||
var wg sync.WaitGroup
|
||||
connCount := 5
|
||||
|
||||
|
||||
// 并发创建多个连接
|
||||
for i := 0; i < connCount; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
|
||||
|
||||
conn, err := net.Dial("tcp", server.listener.Addr().String())
|
||||
if err != nil {
|
||||
t.Errorf("连接 %d 失败: %v", id, err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
|
||||
// 保持连接一段时间
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}(i)
|
||||
}
|
||||
|
||||
|
||||
wg.Wait()
|
||||
|
||||
|
||||
// 验证只有一个隧道连接被接受
|
||||
if server.IsConnected() {
|
||||
// 应该只有一个连接被接受
|
||||
|
|
@ -298,12 +298,12 @@ func TestConcurrentConnections(t *testing.T) {
|
|||
// TestKeepAlive 测试心跳消息
|
||||
func TestKeepAlive(t *testing.T) {
|
||||
server := NewServer(9000)
|
||||
|
||||
|
||||
// 创建模拟连接
|
||||
serverConn, clientConn := net.Pipe()
|
||||
defer serverConn.Close()
|
||||
defer clientConn.Close()
|
||||
|
||||
|
||||
// 创建心跳消息
|
||||
keepAliveMsg := &TunnelMessage{
|
||||
Version: ProtocolVersion,
|
||||
|
|
@ -311,7 +311,7 @@ func TestKeepAlive(t *testing.T) {
|
|||
Length: 0,
|
||||
Data: nil,
|
||||
}
|
||||
|
||||
|
||||
// 启动消息处理
|
||||
go func() {
|
||||
for {
|
||||
|
|
@ -322,7 +322,7 @@ func TestKeepAlive(t *testing.T) {
|
|||
server.handleTunnelMessage(msg)
|
||||
}
|
||||
}()
|
||||
|
||||
|
||||
// 启动发送处理
|
||||
go func() {
|
||||
for {
|
||||
|
|
@ -334,19 +334,19 @@ func TestKeepAlive(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}()
|
||||
|
||||
|
||||
// 发送心跳
|
||||
err := server.writeTunnelMessage(serverConn, keepAliveMsg)
|
||||
if err != nil {
|
||||
t.Fatalf("发送心跳失败: %v", err)
|
||||
}
|
||||
|
||||
|
||||
// 读取响应
|
||||
response, err := server.readTunnelMessage(clientConn)
|
||||
if err != nil {
|
||||
t.Fatalf("读取心跳响应失败: %v", err)
|
||||
}
|
||||
|
||||
|
||||
if response.Type != MsgTypeKeepAlive {
|
||||
t.Errorf("心跳响应类型不正确,期望 %d,得到 %d", MsgTypeKeepAlive, response.Type)
|
||||
}
|
||||
|
|
@ -361,31 +361,31 @@ func (mc *MockClient) sendConnectResponse(connID uint32, status byte) error {
|
|||
responseData := make([]byte, 5)
|
||||
binary.BigEndian.PutUint32(responseData[0:4], connID)
|
||||
responseData[4] = status
|
||||
|
||||
|
||||
msg := &TunnelMessage{
|
||||
Version: ProtocolVersion,
|
||||
Type: MsgTypeConnectResponse,
|
||||
Length: 5,
|
||||
Data: responseData,
|
||||
}
|
||||
|
||||
|
||||
// 构建消息头
|
||||
header := make([]byte, HeaderSize)
|
||||
header[0] = msg.Version
|
||||
header[1] = msg.Type
|
||||
binary.BigEndian.PutUint32(header[2:6], msg.Length)
|
||||
|
||||
|
||||
// 写入消息
|
||||
if _, err := mc.conn.Write(header); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
if msg.Length > 0 && msg.Data != nil {
|
||||
if _, err := mc.conn.Write(msg.Data); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -397,9 +397,9 @@ func TestForwardConnection(t *testing.T) {
|
|||
t.Fatalf("启动目标服务失败: %v", err)
|
||||
}
|
||||
defer targetListener.Close()
|
||||
|
||||
|
||||
targetPort := targetListener.Addr().(*net.TCPAddr).Port
|
||||
|
||||
|
||||
// 启动简单的echo服务
|
||||
go func() {
|
||||
for {
|
||||
|
|
@ -413,7 +413,7 @@ func TestForwardConnection(t *testing.T) {
|
|||
}(conn)
|
||||
}
|
||||
}()
|
||||
|
||||
|
||||
// 创建隧道服务器,使用随机端口
|
||||
tunnelListener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
|
|
@ -421,80 +421,81 @@ func TestForwardConnection(t *testing.T) {
|
|||
}
|
||||
tunnelPort := tunnelListener.Addr().(*net.TCPAddr).Port
|
||||
tunnelListener.Close() // 关闭以便服务器重新绑定
|
||||
|
||||
|
||||
server := NewServer(tunnelPort)
|
||||
err = server.Start()
|
||||
if err != nil {
|
||||
t.Fatalf("启动隧道服务器失败: %v", err)
|
||||
}
|
||||
defer server.Stop()
|
||||
|
||||
|
||||
// 等待服务器启动
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
|
||||
// 连接到隧道服务器(模拟客户端)
|
||||
tunnelConn, err := net.Dial("tcp", server.listener.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("连接隧道服务器失败: %v", err)
|
||||
}
|
||||
defer tunnelConn.Close()
|
||||
|
||||
|
||||
// 等待隧道连接建立
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
|
||||
if !server.IsConnected() {
|
||||
t.Fatal("隧道未连接")
|
||||
}
|
||||
|
||||
|
||||
// 创建模拟客户端连接
|
||||
clientConn, serverSideConn := net.Pipe()
|
||||
defer clientConn.Close()
|
||||
defer serverSideConn.Close()
|
||||
|
||||
|
||||
// 创建模拟客户端
|
||||
mockClient := &MockClient{conn: tunnelConn}
|
||||
|
||||
|
||||
// 启动连接转发
|
||||
go func() {
|
||||
err := server.ForwardConnection(serverSideConn, "127.0.0.1", targetPort)
|
||||
conn, err := server.ForwardConnection(serverSideConn, "127.0.0.1", targetPort)
|
||||
if err != nil {
|
||||
t.Errorf("转发连接失败: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
}()
|
||||
|
||||
|
||||
// 读取连接请求
|
||||
header := make([]byte, HeaderSize)
|
||||
_, err = io.ReadFull(tunnelConn, header)
|
||||
if err != nil {
|
||||
t.Fatalf("读取连接请求头失败: %v", err)
|
||||
}
|
||||
|
||||
|
||||
if header[1] != MsgTypeConnectRequest {
|
||||
t.Fatalf("期望连接请求,得到消息类型: %d", header[1])
|
||||
}
|
||||
|
||||
|
||||
dataLen := binary.BigEndian.Uint32(header[2:6])
|
||||
data := make([]byte, dataLen)
|
||||
_, err = io.ReadFull(tunnelConn, data)
|
||||
if err != nil {
|
||||
t.Fatalf("读取连接请求数据失败: %v", err)
|
||||
}
|
||||
|
||||
|
||||
connID := binary.BigEndian.Uint32(data[0:4])
|
||||
requestedPort := binary.BigEndian.Uint16(data[4:6])
|
||||
|
||||
|
||||
if int(requestedPort) != targetPort {
|
||||
t.Errorf("请求端口不匹配,期望 %d,得到 %d", targetPort, requestedPort)
|
||||
}
|
||||
|
||||
|
||||
// 发送连接成功响应
|
||||
err = mockClient.sendConnectResponse(connID, ConnStatusSuccess)
|
||||
if err != nil {
|
||||
t.Fatalf("发送连接响应失败: %v", err)
|
||||
}
|
||||
|
||||
|
||||
// 等待连接建立
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
|
||||
t.Log("连接转发测试完成")
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue