feat: 重构了tunnel,现在统一了转发逻辑

This commit is contained in:
qcqcqc@wsl 2026-01-09 13:08:13 +08:00
parent 2655b5592f
commit fc1614f7a4
5 changed files with 413 additions and 304 deletions

View File

@ -16,7 +16,7 @@ import (
// TunnelServer 隧道服务器接口
type TunnelServer interface {
ForwardConnection(clientConn net.Conn, targetIP string, targetPort int) error
ForwardConnection(clientConn net.Conn, targetIP string, targetPort int) (net.Conn, error)
IsConnected() bool
GetTrafficStats() stats.TrafficStats
}
@ -46,7 +46,13 @@ func NewForwarder(sourcePort int, targetHost string, targetPort int, limit *int6
ctx, cancel := context.WithCancel(context.Background())
var limiterOut, limiterIn *rate.Limiter
if limit != nil {
burst := int(*limit) // 容量至少等于速率,不然无法正常突发
// burst设置为1秒的流量这样可以平滑处理突发
// 同时不会一次性消耗太多令牌
burst := int(*limit) / 100
if burst < 10240 {
burst = 10240 // 最小burst为10KB
}
log.Printf("设置限速: %d bytes/sec, burst: %d bytes", *limit, burst)
limiterOut = rate.NewLimiter(rate.Limit(*limit), burst)
limiterIn = rate.NewLimiter(rate.Limit(*limit), burst)
}
@ -68,7 +74,13 @@ func NewTunnelForwarder(sourcePort int, targetHost string, targetPort int, tunne
ctx, cancel := context.WithCancel(context.Background())
var limiterOut, limiterIn *rate.Limiter
if limit != nil {
burst := int(*limit) // 容量至少等于速率,不然无法正常突发
// burst设置为1秒的流量这样可以平滑处理突发
// 同时不会一次性消耗太多令牌
burst := int(*limit) / 100
if burst < 10240 {
burst = 10240 // 最小burst为10KB
}
log.Printf("设置限速: %d bytes/sec, burst: %d bytes", *limit, burst)
limiterOut = rate.NewLimiter(rate.Limit(*limit), burst)
limiterIn = rate.NewLimiter(rate.Limit(*limit), burst)
}
@ -142,54 +154,76 @@ type rateLimitedReader struct {
}
func (rlr *rateLimitedReader) Read(p []byte) (int, error) {
if rlr.limiter != nil {
maxReq := rlr.limiter.Burst()
reqSize := len(p)
if reqSize > maxReq {
reqSize = maxReq // 避免一次申请超过桶容量导致错误
}
err := rlr.limiter.WaitN(rlr.ctx, reqSize)
if err != nil {
return 0, err
}
if rlr.limiter == nil {
return rlr.r.Read(p)
}
return rlr.r.Read(p)
// 使用更小的块大小以实现更平滑的限流
// 2KB是一个合理的值既不会太频繁调用也能保持流量平滑
chunkSize := 2048
// 不超过缓冲区大小
if len(p) < chunkSize {
chunkSize = len(p)
}
// 预先申请令牌
if err := rlr.limiter.WaitN(rlr.ctx, chunkSize); err != nil {
return 0, err
}
// 限制实际读取大小
if len(p) > chunkSize {
p = p[:chunkSize]
}
// 执行读取
n, err := rlr.r.Read(p)
// 如果实际读取少于申请的令牌,归还多余的令牌
// 注意rate.Limiter 不支持归还令牌,所以这里只能接受这个损耗
// 或者改用先读后限的方式
return n, err
}
// handleConnection 处理单个连接
func (f *Forwarder) handleConnection(clientConn net.Conn) {
defer f.wg.Done()
defer clientConn.Close()
var targetConn net.Conn
var err error
if f.useTunnel {
// 使用隧道转发
// 使用隧道转发,获取 TunnelConn
if f.tunnelServer == nil || !f.tunnelServer.IsConnected() {
log.Printf("隧道服务器不可用 (端口 %d)", f.sourcePort)
clientConn.Close()
return
}
// 将连接转发到隧道ForwardConnection 会处理连接关闭
if err := f.tunnelServer.ForwardConnection(clientConn, f.targetHost, f.targetPort); err != nil {
// 获取隧道连接
targetConn, err = f.tunnelServer.ForwardConnection(clientConn, f.targetHost, f.targetPort)
if err != nil {
log.Printf("隧道转发失败 (端口 %d -> %s:%d): %v", f.sourcePort, f.targetHost, f.targetPort, err)
return
}
} else {
// 直接连接目标
dialer := &net.Dialer{
Timeout: 10 * time.Second,
KeepAlive: 30 * time.Second,
}
// 动态解析域名并连接
targetAddr := fmt.Sprintf("%s:%d", f.targetHost, f.targetPort)
targetConn, err = dialer.DialContext(f.ctx, "tcp", targetAddr)
if err != nil {
log.Printf("连接目标失败 (端口 %d -> %s): %v", f.sourcePort, targetAddr, err)
return
}
return
}
// 直接连接目标
defer clientConn.Close()
dialer := &net.Dialer{
Timeout: 10 * time.Second,
KeepAlive: 30 * time.Second,
}
// 动态解析域名并连接
targetAddr := fmt.Sprintf("%s:%d", f.targetHost, f.targetPort)
targetConn, err := dialer.DialContext(f.ctx, "tcp", targetAddr)
if err != nil {
log.Printf("连接目标失败 (端口 %d -> %s): %v", f.sourcePort, targetAddr, err)
return
}
defer targetConn.Close()
// 双向转发

View File

@ -39,10 +39,10 @@ type mockTunnelServer struct {
connected bool
}
func (m *mockTunnelServer) ForwardConnection(clientConn net.Conn, targetIp string, targetPort int) error {
func (m *mockTunnelServer) ForwardConnection(clientConn net.Conn, targetIp string, targetPort int) (net.Conn, error) {
// 简单的模拟实现
defer clientConn.Close()
return nil
return nil, nil
}
func (m *mockTunnelServer) IsConnected() bool {

View File

@ -492,6 +492,22 @@
let currentMappings = [];
let apiKey = '';
// 格式化带宽限制为可读格式
function formatBandwidth(bytes) {
if (!bytes || bytes === 0) {
return '无限制';
}
if (bytes < 1024) {
return bytes + ' B/s';
} else if (bytes < 1024 * 1024) {
return (bytes / 1024).toFixed(2) + ' KB/s';
} else if (bytes < 1024 * 1024 * 1024) {
return (bytes / (1024 * 1024)).toFixed(2) + ' MB/s';
} else {
return (bytes / (1024 * 1024 * 1024)).toFixed(2) + ' GB/s';
}
}
// 检查是否已有 API Key
function checkAuth() {
const savedKey = sessionStorage.getItem('apiKey');
@ -733,6 +749,7 @@
'<span class="status-badge ' + (mapping.use_tunnel ? 'status-tunnel' : 'status-direct') + '">' +
(mapping.use_tunnel ? '隧道模式' : '直连模式') +
'</span><br>' +
'<strong>带宽限制:</strong> ' + formatBandwidth(mapping.bandwidth_limit) + '<br>' +
'<strong>创建时间:</strong> ' + new Date(mapping.created_at).toLocaleString('zh-CN') +
'</div>' +
'</div>' +
@ -800,7 +817,8 @@
'<strong>目标:</strong> ' + mapping.target_host + ':' + mapping.target_port + ' | ' +
'<span class="status-badge ' + (mapping.use_tunnel ? 'status-tunnel' : 'status-direct') + '">' +
(mapping.use_tunnel ? '隧道' : '直连') +
'</span>' +
'</span> | ' +
'<strong>带宽:</strong> ' + formatBandwidth(mapping.bandwidth_limit) +
'</div>' +
'</div>' +
'</div>'

View File

@ -20,28 +20,28 @@ import (
const (
// 协议版本
ProtocolVersion = 0x01
// 消息头大小
HeaderSize = 6 // 版本(1) + 类型(1) + 长度(4)
// 最大包大小 (1MB)
MaxPacketSize = 1024 * 1024
// 消息类型
MsgTypeConnectRequest = 0x01 // 连接请求
MsgTypeConnectResponse = 0x02 // 连接响应
MsgTypeData = 0x03 // 数据传输
MsgTypeClose = 0x04 // 关闭连接
MsgTypeKeepAlive = 0x05 // 心跳
MsgTypeData = 0x03 // 数据传输
MsgTypeClose = 0x04 // 关闭连接
MsgTypeKeepAlive = 0x05 // 心跳
// 连接响应状态
ConnStatusSuccess = 0x00 // 连接成功
ConnStatusFailed = 0x01 // 连接失败
// 超时设置
ConnectTimeout = 10 * time.Second // 连接超时
ReadTimeout = 300 * time.Second // 读取超时统一为60秒
KeepAliveInterval = 15 * time.Second // 心跳间隔
ConnectTimeout = 10 * time.Second // 连接超时
ReadTimeout = 300 * time.Second // 读取超时统一为60秒
KeepAliveInterval = 15 * time.Second // 心跳间隔
)
// TunnelMessage 隧道消息
@ -94,7 +94,115 @@ type ActiveConnection struct {
TargetHost string // 支持IP或域名
TargetIP string
Created time.Time
Closing int32 // 原子操作标志,表示连接正在关闭
Closing int32 // 原子操作标志,表示连接正在关闭
RecvChan chan []byte // 接收数据的通道
}
// TunnelConn 实现 net.Conn 接口的隧道连接
type TunnelConn struct {
server *Server
connID uint32
targetHost string
targetPort int
recvChan chan []byte // 接收数据的通道
readBuffer []byte // 读取缓冲区
closed int32 // 原子操作标志,表示连接已关闭
closeOnce sync.Once // 确保只关闭一次
}
// Read 实现 net.Conn 接口
func (tc *TunnelConn) Read(b []byte) (n int, err error) {
if atomic.LoadInt32(&tc.closed) == 1 {
return 0, io.EOF
}
// 如果有缓冲数据,先返回缓冲数据
if len(tc.readBuffer) > 0 {
n = copy(b, tc.readBuffer)
tc.readBuffer = tc.readBuffer[n:]
return n, nil
}
// 从通道读取数据
select {
case data, ok := <-tc.recvChan:
if !ok {
return 0, io.EOF
}
n = copy(b, data)
// 如果数据太多,保存剩余部分到缓冲区
if n < len(data) {
tc.readBuffer = data[n:]
}
return n, nil
case <-tc.server.ctx.Done():
return 0, io.EOF
}
}
// Write 实现 net.Conn 接口
func (tc *TunnelConn) Write(b []byte) (n int, err error) {
if atomic.LoadInt32(&tc.closed) == 1 {
return 0, fmt.Errorf("connection closed")
}
// 发送数据到隧道
dataMsg := make([]byte, 4+len(b))
binary.BigEndian.PutUint32(dataMsg[0:4], tc.connID)
copy(dataMsg[4:], b)
msg := &TunnelMessage{
Version: ProtocolVersion,
Type: MsgTypeData,
Length: uint32(len(dataMsg)),
Data: dataMsg,
}
select {
case tc.server.sendChan <- msg:
return len(b), nil
case <-time.After(2 * time.Second):
return 0, fmt.Errorf("write timeout")
case <-tc.server.ctx.Done():
return 0, fmt.Errorf("server closed")
}
}
// Close 实现 net.Conn 接口
func (tc *TunnelConn) Close() error {
tc.closeOnce.Do(func() {
atomic.StoreInt32(&tc.closed, 1)
tc.server.closeConnection(tc.connID)
})
return nil
}
// LocalAddr 实现 net.Conn 接口
func (tc *TunnelConn) LocalAddr() net.Addr {
return &net.TCPAddr{IP: net.IPv4zero, Port: 0}
}
// RemoteAddr 实现 net.Conn 接口
func (tc *TunnelConn) RemoteAddr() net.Addr {
return &net.TCPAddr{IP: net.ParseIP(tc.targetHost), Port: tc.targetPort}
}
// SetDeadline 实现 net.Conn 接口
func (tc *TunnelConn) SetDeadline(t time.Time) error {
// 隧道连接不支持 deadline可以根据需要实现
return nil
}
// SetReadDeadline 实现 net.Conn 接口
func (tc *TunnelConn) SetReadDeadline(t time.Time) error {
// 隧道连接不支持 deadline可以根据需要实现
return nil
}
// SetWriteDeadline 实现 net.Conn 接口
func (tc *TunnelConn) SetWriteDeadline(t time.Time) error {
// 隧道连接不支持 deadline可以根据需要实现
return nil
}
// Server 内网穿透服务器
@ -106,17 +214,17 @@ type Server struct {
ctx context.Context
wg sync.WaitGroup
mu sync.RWMutex
// 连接管理
pendingConns map[uint32]*PendingConnection // 待确认连接
activeConns map[uint32]*ActiveConnection // 活跃连接
closingConns map[uint32]time.Time // 正在关闭的连接,避免重复处理
connMu sync.RWMutex
nextConnID uint32
// 消息队列
sendChan chan *TunnelMessage
// 流量统计(使用原子操作)
bytesSent uint64 // 通过隧道发送的总字节数
bytesReceived uint64 // 通过隧道接收的总字节数
@ -134,10 +242,10 @@ func NewServer(listenPort int) *Server {
closingConns: make(map[uint32]time.Time),
sendChan: make(chan *TunnelMessage, 10000), // 增加到10000
}
// 启动清理器,定期清理过期的关闭连接记录
go server.cleanupClosingConns()
return server
}
@ -211,7 +319,7 @@ func (s *Server) handleTunnelRead(conn net.Conn) {
s.tunnelConn = nil
s.mu.Unlock()
log.Printf("隧道客户端已断开")
// 关闭所有活动连接
s.connMu.Lock()
for _, c := range s.pendingConns {
@ -293,7 +401,7 @@ func (s *Server) readTunnelMessage(conn net.Conn) (*TunnelMessage, error) {
if _, err := io.ReadFull(conn, header); err != nil {
return nil, err
}
// 统计接收字节数
s.addBytesReceived(uint64(HeaderSize))
@ -333,7 +441,7 @@ func (s *Server) writeTunnelMessage(conn net.Conn, msg *TunnelMessage) error {
// 设置写入超时,防止阻塞
conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
defer conn.SetWriteDeadline(time.Time{}) // 重置超时
// 构建消息头
header := make([]byte, HeaderSize)
header[0] = msg.Version
@ -344,7 +452,7 @@ func (s *Server) writeTunnelMessage(conn net.Conn, msg *TunnelMessage) error {
if _, err := conn.Write(header); err != nil {
return fmt.Errorf("写入消息头失败: %w", err)
}
// 统计发送字节数
s.addBytesSent(uint64(HeaderSize))
@ -398,13 +506,16 @@ func (s *Server) handleConnectResponse(msg *TunnelMessage) {
s.connMu.Unlock()
if status == ConnStatusSuccess {
// 连接成功,移到活跃连接
// 连接成功,创建接收通道
recvChan := make(chan []byte, 100)
active := &ActiveConnection{
ID: connID,
ClientConn: pending.ClientConn,
TargetPort: pending.TargetPort,
TargetHost: pending.TargetHost,
Created: time.Now(),
RecvChan: recvChan,
}
s.connMu.Lock()
@ -413,10 +524,6 @@ func (s *Server) handleConnectResponse(msg *TunnelMessage) {
log.Printf("连接已建立: ID=%d, 地址=%s:%d", connID, pending.TargetHost, pending.TargetPort)
// 启动数据转发
s.wg.Add(1)
go s.forwardData(active)
// 通知等待的goroutine
select {
case pending.ResponseChan <- true:
@ -463,7 +570,7 @@ func (s *Server) handleDataMessage(msg *TunnelMessage) {
s.connMu.Lock()
s.closingConns[connID] = time.Now()
s.connMu.Unlock()
// 连接不存在,发送关闭消息通知对端
closeData := make([]byte, 4)
binary.BigEndian.PutUint32(closeData, connID)
@ -486,10 +593,18 @@ func (s *Server) handleDataMessage(msg *TunnelMessage) {
return
}
// 写入到客户端连接
if _, err := active.ClientConn.Write(data); err != nil {
log.Printf("写入客户端连接失败 (ID=%d): %v", connID, err)
// 发送数据到接收通道
dataCopy := make([]byte, len(data))
copy(dataCopy, data)
select {
case active.RecvChan <- dataCopy:
// 数据已发送到接收通道
case <-time.After(2 * time.Second):
log.Printf("发送数据到接收通道超时 (ID=%d)", connID)
s.closeConnection(connID)
case <-s.ctx.Done():
return
}
}
@ -527,162 +642,26 @@ func (s *Server) handleKeepAlive(msg *TunnelMessage) {
}
}
// forwardData 转发数据
func (s *Server) forwardData(active *ActiveConnection) {
defer s.wg.Done()
defer s.closeConnection(active.ID)
buffer := make([]byte, 32*1024)
for {
select {
case <-s.ctx.Done():
return
default:
}
// 设置读取超时
active.ClientConn.SetReadDeadline(time.Now().Add(ReadTimeout))
n, err := active.ClientConn.Read(buffer)
if err != nil {
// 检查是否是正在关闭的连接,避免记录无关错误
if atomic.LoadInt32(&active.Closing) == 1 {
return // 连接正在关闭,正常退出
}
// 任何错误都应该终止转发,包括超时
if err == io.EOF {
log.Printf("客户端连接正常关闭 (ID=%d)", active.ID)
} else if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
log.Printf("客户端连接超时 (ID=%d)", active.ID)
} else {
// 只记录非关闭相关的错误
if !isConnectionClosed(err) {
log.Printf("读取客户端连接失败 (ID=%d): %v", active.ID, err)
}
}
return
}
// 读取到0字节连接已关闭
if n == 0 {
log.Printf("客户端连接已关闭 (ID=%d, 读取0字节)", active.ID)
return
}
// 重置读取超时
active.ClientConn.SetReadDeadline(time.Time{})
// 检查连接是否正在关闭
if atomic.LoadInt32(&active.Closing) == 1 {
return
}
// 发送数据到隧道
dataMsg := make([]byte, 4+n)
binary.BigEndian.PutUint32(dataMsg[0:4], active.ID)
copy(dataMsg[4:], buffer[:n])
msg := &TunnelMessage{
Version: ProtocolVersion,
Type: MsgTypeData,
Length: uint32(len(dataMsg)),
Data: dataMsg,
}
select {
case s.sendChan <- msg:
// 数据已发送
case <-time.After(2 * time.Second): // 减少超时时间
queueLen := len(s.sendChan)
log.Printf("发送数据超时 (ID=%d), 队列长度: %d/10000", active.ID, queueLen)
return
case <-s.ctx.Done():
return
}
}
}
// closeConnection 关闭连接
func (s *Server) closeConnection(connID uint32) {
s.connMu.Lock()
active, exists := s.activeConns[connID]
if exists {
// 使用原子操作标记连接正在关闭
if !atomic.CompareAndSwapInt32(&active.Closing, 0, 1) {
// 连接已经在关闭中,避免重复处理
s.connMu.Unlock()
return
}
delete(s.activeConns, connID)
// 记录关闭时间,避免重复发送关闭消息
s.closingConns[connID] = time.Now()
// 确保连接被关闭
if active.ClientConn != nil {
active.ClientConn.Close()
}
}
s.connMu.Unlock()
if !exists {
// 连接不存在,检查是否已经在关闭列表中
s.connMu.RLock()
_, isClosing := s.closingConns[connID]
s.connMu.RUnlock()
if isClosing {
return // 已经处理过了
}
// 标记为正在关闭
s.connMu.Lock()
s.closingConns[connID] = time.Now()
s.connMu.Unlock()
}
// 发送关闭消息
closeData := make([]byte, 4)
binary.BigEndian.PutUint32(closeData, connID)
msg := &TunnelMessage{
Version: ProtocolVersion,
Type: MsgTypeClose,
Length: 4,
Data: closeData,
}
select {
case s.sendChan <- msg:
log.Printf("连接已关闭: ID=%d", connID)
case <-time.After(1 * time.Second):
log.Printf("发送关闭消息超时: ID=%d", connID)
case <-s.ctx.Done():
log.Printf("服务器关闭,跳过发送关闭消息: ID=%d", connID)
}
}
// ForwardConnection 转发连接到隧道(新的透明代理实现)
func (s *Server) ForwardConnection(clientConn net.Conn, targetHost string, targetPort int) error {
// ForwardConnection 创建隧道连接(返回 net.Conn 接口)
func (s *Server) ForwardConnection(clientConn net.Conn, targetHost string, targetPort int) (net.Conn, error) {
s.mu.RLock()
tunnelConnected := s.tunnelConn != nil
s.mu.RUnlock()
if !tunnelConnected {
return fmt.Errorf("隧道连接不可用")
return nil, fmt.Errorf("隧道连接不可用")
}
// 创建待处理连接
s.connMu.Lock()
connID := s.nextConnID
s.nextConnID++
pending := &PendingConnection{
ID: connID,
ClientConn: clientConn,
TargetPort: targetPort,
TargetHost: targetHost, // 现在支持域名
TargetHost: targetHost,
Created: time.Now(),
ResponseChan: make(chan bool, 1),
}
@ -711,7 +690,7 @@ func (s *Server) ForwardConnection(clientConn net.Conn, targetHost string, targe
s.connMu.Lock()
delete(s.pendingConns, connID)
s.connMu.Unlock()
return fmt.Errorf("发送连接请求超时")
return nil, fmt.Errorf("发送连接请求超时")
}
log.Printf("发送连接请求: ID=%d, 地址=%s:%d", connID, targetHost, targetPort)
@ -720,18 +699,95 @@ func (s *Server) ForwardConnection(clientConn net.Conn, targetHost string, targe
select {
case success := <-pending.ResponseChan:
if success {
return nil // 连接建立成功forwardData会处理后续的数据转发
// 连接建立成功,获取活动连接并创建 TunnelConn
s.connMu.RLock()
active, exists := s.activeConns[connID]
s.connMu.RUnlock()
if !exists {
return nil, fmt.Errorf("活动连接不存在")
}
// 创建 TunnelConn
tunnelConn := &TunnelConn{
server: s,
connID: connID,
targetHost: targetHost,
targetPort: targetPort,
recvChan: active.RecvChan,
}
return tunnelConn, nil
} else {
return fmt.Errorf("远程连接失败")
return nil, fmt.Errorf("远程连接失败")
}
case <-time.After(ConnectTimeout):
s.connMu.Lock()
delete(s.pendingConns, connID)
s.connMu.Unlock()
clientConn.Close()
return fmt.Errorf("连接超时")
return nil, fmt.Errorf("连接超时")
case <-s.ctx.Done():
return fmt.Errorf("服务器关闭")
return nil, fmt.Errorf("服务器关闭")
}
}
// closeConnection 关闭连接
func (s *Server) closeConnection(connID uint32) {
s.connMu.Lock()
active, exists := s.activeConns[connID]
if exists {
// 使用原子操作标记连接正在关闭
if !atomic.CompareAndSwapInt32(&active.Closing, 0, 1) {
// 连接已经在关闭中,避免重复处理
s.connMu.Unlock()
return
}
delete(s.activeConns, connID)
// 记录关闭时间,避免重复发送关闭消息
s.closingConns[connID] = time.Now()
// 关闭接收通道
if active.RecvChan != nil {
close(active.RecvChan)
}
}
s.connMu.Unlock()
if !exists {
// 连接不存在,检查是否已经在关闭列表中
s.connMu.RLock()
_, isClosing := s.closingConns[connID]
s.connMu.RUnlock()
if isClosing {
return // 已经处理过了
}
// 标记为正在关闭
s.connMu.Lock()
s.closingConns[connID] = time.Now()
s.connMu.Unlock()
}
// 发送关闭消息
closeData := make([]byte, 4)
binary.BigEndian.PutUint32(closeData, connID)
msg := &TunnelMessage{
Version: ProtocolVersion,
Type: MsgTypeClose,
Length: 4,
Data: closeData,
}
select {
case s.sendChan <- msg:
log.Printf("连接已关闭: ID=%d", connID)
case <-time.After(1 * time.Second):
log.Printf("发送关闭消息超时: ID=%d", connID)
case <-s.ctx.Done():
log.Printf("服务器关闭,跳过发送关闭消息: ID=%d", connID)
}
}
@ -784,7 +840,7 @@ func isTimeout(err error) bool {
// keepAliveLoop 心跳循环
func (s *Server) keepAliveLoop(conn net.Conn) {
defer s.wg.Done()
ticker := time.NewTicker(KeepAliveInterval)
defer ticker.Stop()
@ -848,31 +904,31 @@ func (s *Server) cleanupClosingConns() {
case <-ticker.C:
now := time.Now()
s.connMu.Lock()
// 按时间清理过期记录
for connID, closeTime := range s.closingConns {
if now.Sub(closeTime) > maxAge {
delete(s.closingConns, connID)
}
}
// 如果记录数量仍然过多,删除最旧的记录
if len(s.closingConns) > maxClosingRecords {
// 删除一半的最旧记录,避免频繁清理
deleteCount := len(s.closingConns) - maxClosingRecords/2
deletedCount := 0
for connID, closeTime := range s.closingConns {
if deletedCount >= deleteCount {
break
}
if closeTime.Before(now.Add(-maxAge/2)) {
if closeTime.Before(now.Add(-maxAge / 2)) {
delete(s.closingConns, connID)
deletedCount++
}
}
}
s.connMu.Unlock()
}
}
@ -885,6 +941,6 @@ func isConnectionClosed(err error) bool {
}
errStr := err.Error()
return strings.Contains(errStr, "use of closed network connection") ||
strings.Contains(errStr, "connection reset by peer") ||
strings.Contains(errStr, "broken pipe")
}
strings.Contains(errStr, "connection reset by peer") ||
strings.Contains(errStr, "broken pipe")
}

View File

@ -12,27 +12,27 @@ import (
// TestNewServer 测试创建隧道服务器
func TestNewServer(t *testing.T) {
server := NewServer(9000)
if server == nil {
t.Fatal("创建隧道服务器失败")
}
if server.listenPort != 9000 {
t.Errorf("监听端口不正确,期望 9000得到 %d", server.listenPort)
}
if server.pendingConns == nil {
t.Error("待处理连接映射未初始化")
}
if server.activeConns == nil {
t.Error("活跃连接映射未初始化")
}
if server.sendChan == nil {
t.Error("发送通道未初始化")
}
if server.ctx == nil {
t.Error("上下文未初始化")
}
@ -47,17 +47,17 @@ func TestServerStartStop(t *testing.T) {
}
port := listener.Addr().(*net.TCPAddr).Port
listener.Close()
server := NewServer(port)
err = server.Start()
if err != nil {
t.Fatalf("启动服务器失败: %v", err)
}
// 等待服务器启动
time.Sleep(100 * time.Millisecond)
// 验证服务器是否监听端口
conn, err := net.Dial("tcp", server.listener.Addr().String())
if err != nil {
@ -65,7 +65,7 @@ func TestServerStartStop(t *testing.T) {
} else {
conn.Close()
}
// 停止服务器
err = server.Stop()
if err != nil {
@ -83,14 +83,14 @@ func TestTunnelMessage(t *testing.T) {
Length: uint32(len(data)),
Data: data,
}
// 创建模拟连接
serverConn, clientConn := net.Pipe()
defer serverConn.Close()
defer clientConn.Close()
server := NewServer(9000)
// 测试写入消息
go func() {
err := server.writeTunnelMessage(serverConn, msg)
@ -98,26 +98,26 @@ func TestTunnelMessage(t *testing.T) {
t.Errorf("写入隧道消息失败: %v", err)
}
}()
// 测试读取消息
receivedMsg, err := server.readTunnelMessage(clientConn)
if err != nil {
t.Fatalf("读取隧道消息失败: %v", err)
}
// 验证消息内容
if receivedMsg.Version != msg.Version {
t.Errorf("版本不匹配,期望 %d得到 %d", msg.Version, receivedMsg.Version)
}
if receivedMsg.Type != msg.Type {
t.Errorf("类型不匹配,期望 %d得到 %d", msg.Type, receivedMsg.Type)
}
if receivedMsg.Length != msg.Length {
t.Errorf("长度不匹配,期望 %d得到 %d", msg.Length, receivedMsg.Length)
}
if string(receivedMsg.Data) != string(msg.Data) {
t.Errorf("数据不匹配,期望 %s得到 %s", string(msg.Data), string(receivedMsg.Data))
}
@ -131,9 +131,9 @@ func TestConnectRequest(t *testing.T) {
t.Fatalf("启动测试服务器失败: %v", err)
}
defer testListener.Close()
testPort := testListener.Addr().(*net.TCPAddr).Port
// 启动一个简单的echo服务
go func() {
for {
@ -147,41 +147,41 @@ func TestConnectRequest(t *testing.T) {
}(conn)
}
}()
// 模拟测试(实际测试需要更复杂的设置)
t.Logf("测试服务器运行在端口: %d", testPort)
// 创建隧道服务器,使用随机端口
tunnelListener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("创建隧道监听器失败: %v", err)
}
defer tunnelListener.Close()
tunnelPort := tunnelListener.Addr().(*net.TCPAddr).Port
tunnelListener.Close() // 关闭以便服务器重新绑定
server := NewServer(tunnelPort)
err = server.Start()
if err != nil {
t.Fatalf("启动隧道服务器失败: %v", err)
}
defer server.Stop()
// 等待服务器启动
time.Sleep(100 * time.Millisecond)
// 模拟客户端连接
tunnelConn, err := net.Dial("tcp", server.listener.Addr().String())
if err != nil {
t.Fatalf("连接隧道服务器失败: %v", err)
}
defer tunnelConn.Close()
// 等待隧道连接被处理
time.Sleep(100 * time.Millisecond)
// 验证隧道是否已连接
if !server.IsConnected() {
t.Error("隧道未连接")
@ -191,28 +191,28 @@ func TestConnectRequest(t *testing.T) {
// TestProtocolVersionCheck 测试协议版本检查
func TestProtocolVersionCheck(t *testing.T) {
server := NewServer(9000)
// 创建模拟连接
serverConn, clientConn := net.Pipe()
defer serverConn.Close()
defer clientConn.Close()
// 发送错误版本的消息
wrongVersionHeader := make([]byte, HeaderSize)
wrongVersionHeader[0] = 0xFF // 错误版本
wrongVersionHeader[1] = MsgTypeData
binary.BigEndian.PutUint32(wrongVersionHeader[2:6], 0)
go func() {
clientConn.Write(wrongVersionHeader)
}()
// 尝试读取消息,应该返回错误
_, err := server.readTunnelMessage(serverConn)
if err == nil {
t.Error("期望版本检查失败,但成功了")
}
if err.Error() != "不支持的协议版本: 255" {
t.Errorf("错误消息不正确: %v", err)
}
@ -221,22 +221,22 @@ func TestProtocolVersionCheck(t *testing.T) {
// TestMaxPacketSizeCheck 测试最大包大小检查
func TestMaxPacketSizeCheck(t *testing.T) {
server := NewServer(9000)
// 创建模拟连接
serverConn, clientConn := net.Pipe()
defer serverConn.Close()
defer clientConn.Close()
// 发送超大包
oversizedHeader := make([]byte, HeaderSize)
oversizedHeader[0] = ProtocolVersion
oversizedHeader[1] = MsgTypeData
binary.BigEndian.PutUint32(oversizedHeader[2:6], MaxPacketSize+1)
go func() {
clientConn.Write(oversizedHeader)
}()
// 尝试读取消息,应该返回错误
_, err := server.readTunnelMessage(serverConn)
if err == nil {
@ -253,41 +253,41 @@ func TestConcurrentConnections(t *testing.T) {
}
port := listener.Addr().(*net.TCPAddr).Port
listener.Close()
server := NewServer(port)
err = server.Start()
if err != nil {
t.Fatalf("启动服务器失败: %v", err)
}
defer server.Stop()
// 等待服务器启动
time.Sleep(100 * time.Millisecond)
var wg sync.WaitGroup
connCount := 5
// 并发创建多个连接
for i := 0; i < connCount; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
conn, err := net.Dial("tcp", server.listener.Addr().String())
if err != nil {
t.Errorf("连接 %d 失败: %v", id, err)
return
}
defer conn.Close()
// 保持连接一段时间
time.Sleep(200 * time.Millisecond)
}(i)
}
wg.Wait()
// 验证只有一个隧道连接被接受
if server.IsConnected() {
// 应该只有一个连接被接受
@ -298,12 +298,12 @@ func TestConcurrentConnections(t *testing.T) {
// TestKeepAlive 测试心跳消息
func TestKeepAlive(t *testing.T) {
server := NewServer(9000)
// 创建模拟连接
serverConn, clientConn := net.Pipe()
defer serverConn.Close()
defer clientConn.Close()
// 创建心跳消息
keepAliveMsg := &TunnelMessage{
Version: ProtocolVersion,
@ -311,7 +311,7 @@ func TestKeepAlive(t *testing.T) {
Length: 0,
Data: nil,
}
// 启动消息处理
go func() {
for {
@ -322,7 +322,7 @@ func TestKeepAlive(t *testing.T) {
server.handleTunnelMessage(msg)
}
}()
// 启动发送处理
go func() {
for {
@ -334,19 +334,19 @@ func TestKeepAlive(t *testing.T) {
}
}
}()
// 发送心跳
err := server.writeTunnelMessage(serverConn, keepAliveMsg)
if err != nil {
t.Fatalf("发送心跳失败: %v", err)
}
// 读取响应
response, err := server.readTunnelMessage(clientConn)
if err != nil {
t.Fatalf("读取心跳响应失败: %v", err)
}
if response.Type != MsgTypeKeepAlive {
t.Errorf("心跳响应类型不正确,期望 %d得到 %d", MsgTypeKeepAlive, response.Type)
}
@ -361,31 +361,31 @@ func (mc *MockClient) sendConnectResponse(connID uint32, status byte) error {
responseData := make([]byte, 5)
binary.BigEndian.PutUint32(responseData[0:4], connID)
responseData[4] = status
msg := &TunnelMessage{
Version: ProtocolVersion,
Type: MsgTypeConnectResponse,
Length: 5,
Data: responseData,
}
// 构建消息头
header := make([]byte, HeaderSize)
header[0] = msg.Version
header[1] = msg.Type
binary.BigEndian.PutUint32(header[2:6], msg.Length)
// 写入消息
if _, err := mc.conn.Write(header); err != nil {
return err
}
if msg.Length > 0 && msg.Data != nil {
if _, err := mc.conn.Write(msg.Data); err != nil {
return err
}
}
return nil
}
@ -397,9 +397,9 @@ func TestForwardConnection(t *testing.T) {
t.Fatalf("启动目标服务失败: %v", err)
}
defer targetListener.Close()
targetPort := targetListener.Addr().(*net.TCPAddr).Port
// 启动简单的echo服务
go func() {
for {
@ -413,7 +413,7 @@ func TestForwardConnection(t *testing.T) {
}(conn)
}
}()
// 创建隧道服务器,使用随机端口
tunnelListener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
@ -421,80 +421,81 @@ func TestForwardConnection(t *testing.T) {
}
tunnelPort := tunnelListener.Addr().(*net.TCPAddr).Port
tunnelListener.Close() // 关闭以便服务器重新绑定
server := NewServer(tunnelPort)
err = server.Start()
if err != nil {
t.Fatalf("启动隧道服务器失败: %v", err)
}
defer server.Stop()
// 等待服务器启动
time.Sleep(100 * time.Millisecond)
// 连接到隧道服务器(模拟客户端)
tunnelConn, err := net.Dial("tcp", server.listener.Addr().String())
if err != nil {
t.Fatalf("连接隧道服务器失败: %v", err)
}
defer tunnelConn.Close()
// 等待隧道连接建立
time.Sleep(100 * time.Millisecond)
if !server.IsConnected() {
t.Fatal("隧道未连接")
}
// 创建模拟客户端连接
clientConn, serverSideConn := net.Pipe()
defer clientConn.Close()
defer serverSideConn.Close()
// 创建模拟客户端
mockClient := &MockClient{conn: tunnelConn}
// 启动连接转发
go func() {
err := server.ForwardConnection(serverSideConn, "127.0.0.1", targetPort)
conn, err := server.ForwardConnection(serverSideConn, "127.0.0.1", targetPort)
if err != nil {
t.Errorf("转发连接失败: %v", err)
}
defer conn.Close()
}()
// 读取连接请求
header := make([]byte, HeaderSize)
_, err = io.ReadFull(tunnelConn, header)
if err != nil {
t.Fatalf("读取连接请求头失败: %v", err)
}
if header[1] != MsgTypeConnectRequest {
t.Fatalf("期望连接请求,得到消息类型: %d", header[1])
}
dataLen := binary.BigEndian.Uint32(header[2:6])
data := make([]byte, dataLen)
_, err = io.ReadFull(tunnelConn, data)
if err != nil {
t.Fatalf("读取连接请求数据失败: %v", err)
}
connID := binary.BigEndian.Uint32(data[0:4])
requestedPort := binary.BigEndian.Uint16(data[4:6])
if int(requestedPort) != targetPort {
t.Errorf("请求端口不匹配,期望 %d得到 %d", targetPort, requestedPort)
}
// 发送连接成功响应
err = mockClient.sendConnectResponse(connID, ConnStatusSuccess)
if err != nil {
t.Fatalf("发送连接响应失败: %v", err)
}
// 等待连接建立
time.Sleep(100 * time.Millisecond)
t.Log("连接转发测试完成")
}
}