feat: 重构了tunnel,现在统一了转发逻辑

This commit is contained in:
qcqcqc@wsl 2026-01-09 13:08:13 +08:00
parent 2655b5592f
commit fc1614f7a4
5 changed files with 413 additions and 304 deletions

View File

@ -16,7 +16,7 @@ import (
// TunnelServer 隧道服务器接口
type TunnelServer interface {
ForwardConnection(clientConn net.Conn, targetIP string, targetPort int) error
ForwardConnection(clientConn net.Conn, targetIP string, targetPort int) (net.Conn, error)
IsConnected() bool
GetTrafficStats() stats.TrafficStats
}
@ -46,7 +46,13 @@ func NewForwarder(sourcePort int, targetHost string, targetPort int, limit *int6
ctx, cancel := context.WithCancel(context.Background())
var limiterOut, limiterIn *rate.Limiter
if limit != nil {
burst := int(*limit) // 容量至少等于速率,不然无法正常突发
// burst设置为1秒的流量这样可以平滑处理突发
// 同时不会一次性消耗太多令牌
burst := int(*limit) / 100
if burst < 10240 {
burst = 10240 // 最小burst为10KB
}
log.Printf("设置限速: %d bytes/sec, burst: %d bytes", *limit, burst)
limiterOut = rate.NewLimiter(rate.Limit(*limit), burst)
limiterIn = rate.NewLimiter(rate.Limit(*limit), burst)
}
@ -68,7 +74,13 @@ func NewTunnelForwarder(sourcePort int, targetHost string, targetPort int, tunne
ctx, cancel := context.WithCancel(context.Background())
var limiterOut, limiterIn *rate.Limiter
if limit != nil {
burst := int(*limit) // 容量至少等于速率,不然无法正常突发
// burst设置为1秒的流量这样可以平滑处理突发
// 同时不会一次性消耗太多令牌
burst := int(*limit) / 100
if burst < 10240 {
burst = 10240 // 最小burst为10KB
}
log.Printf("设置限速: %d bytes/sec, burst: %d bytes", *limit, burst)
limiterOut = rate.NewLimiter(rate.Limit(*limit), burst)
limiterIn = rate.NewLimiter(rate.Limit(*limit), burst)
}
@ -142,42 +154,62 @@ type rateLimitedReader struct {
}
func (rlr *rateLimitedReader) Read(p []byte) (int, error) {
if rlr.limiter != nil {
maxReq := rlr.limiter.Burst()
reqSize := len(p)
if reqSize > maxReq {
reqSize = maxReq // 避免一次申请超过桶容量导致错误
if rlr.limiter == nil {
return rlr.r.Read(p)
}
err := rlr.limiter.WaitN(rlr.ctx, reqSize)
if err != nil {
// 使用更小的块大小以实现更平滑的限流
// 2KB是一个合理的值既不会太频繁调用也能保持流量平滑
chunkSize := 2048
// 不超过缓冲区大小
if len(p) < chunkSize {
chunkSize = len(p)
}
// 预先申请令牌
if err := rlr.limiter.WaitN(rlr.ctx, chunkSize); err != nil {
return 0, err
}
// 限制实际读取大小
if len(p) > chunkSize {
p = p[:chunkSize]
}
return rlr.r.Read(p)
// 执行读取
n, err := rlr.r.Read(p)
// 如果实际读取少于申请的令牌,归还多余的令牌
// 注意rate.Limiter 不支持归还令牌,所以这里只能接受这个损耗
// 或者改用先读后限的方式
return n, err
}
// handleConnection 处理单个连接
func (f *Forwarder) handleConnection(clientConn net.Conn) {
defer f.wg.Done()
if f.useTunnel {
// 使用隧道转发
if f.tunnelServer == nil || !f.tunnelServer.IsConnected() {
log.Printf("隧道服务器不可用 (端口 %d)", f.sourcePort)
clientConn.Close()
return
}
// 将连接转发到隧道ForwardConnection 会处理连接关闭
if err := f.tunnelServer.ForwardConnection(clientConn, f.targetHost, f.targetPort); err != nil {
log.Printf("隧道转发失败 (端口 %d -> %s:%d): %v", f.sourcePort, f.targetHost, f.targetPort, err)
}
return
}
// 直接连接目标
defer clientConn.Close()
var targetConn net.Conn
var err error
if f.useTunnel {
// 使用隧道转发,获取 TunnelConn
if f.tunnelServer == nil || !f.tunnelServer.IsConnected() {
log.Printf("隧道服务器不可用 (端口 %d)", f.sourcePort)
return
}
// 获取隧道连接
targetConn, err = f.tunnelServer.ForwardConnection(clientConn, f.targetHost, f.targetPort)
if err != nil {
log.Printf("隧道转发失败 (端口 %d -> %s:%d): %v", f.sourcePort, f.targetHost, f.targetPort, err)
return
}
} else {
// 直接连接目标
dialer := &net.Dialer{
Timeout: 10 * time.Second,
KeepAlive: 30 * time.Second,
@ -185,11 +217,13 @@ func (f *Forwarder) handleConnection(clientConn net.Conn) {
// 动态解析域名并连接
targetAddr := fmt.Sprintf("%s:%d", f.targetHost, f.targetPort)
targetConn, err := dialer.DialContext(f.ctx, "tcp", targetAddr)
targetConn, err = dialer.DialContext(f.ctx, "tcp", targetAddr)
if err != nil {
log.Printf("连接目标失败 (端口 %d -> %s): %v", f.sourcePort, targetAddr, err)
return
}
}
defer targetConn.Close()
// 双向转发

View File

@ -39,10 +39,10 @@ type mockTunnelServer struct {
connected bool
}
func (m *mockTunnelServer) ForwardConnection(clientConn net.Conn, targetIp string, targetPort int) error {
func (m *mockTunnelServer) ForwardConnection(clientConn net.Conn, targetIp string, targetPort int) (net.Conn, error) {
// 简单的模拟实现
defer clientConn.Close()
return nil
return nil, nil
}
func (m *mockTunnelServer) IsConnected() bool {

View File

@ -492,6 +492,22 @@
let currentMappings = [];
let apiKey = '';
// 格式化带宽限制为可读格式
function formatBandwidth(bytes) {
if (!bytes || bytes === 0) {
return '无限制';
}
if (bytes < 1024) {
return bytes + ' B/s';
} else if (bytes < 1024 * 1024) {
return (bytes / 1024).toFixed(2) + ' KB/s';
} else if (bytes < 1024 * 1024 * 1024) {
return (bytes / (1024 * 1024)).toFixed(2) + ' MB/s';
} else {
return (bytes / (1024 * 1024 * 1024)).toFixed(2) + ' GB/s';
}
}
// 检查是否已有 API Key
function checkAuth() {
const savedKey = sessionStorage.getItem('apiKey');
@ -733,6 +749,7 @@
'<span class="status-badge ' + (mapping.use_tunnel ? 'status-tunnel' : 'status-direct') + '">' +
(mapping.use_tunnel ? '隧道模式' : '直连模式') +
'</span><br>' +
'<strong>带宽限制:</strong> ' + formatBandwidth(mapping.bandwidth_limit) + '<br>' +
'<strong>创建时间:</strong> ' + new Date(mapping.created_at).toLocaleString('zh-CN') +
'</div>' +
'</div>' +
@ -800,7 +817,8 @@
'<strong>目标:</strong> ' + mapping.target_host + ':' + mapping.target_port + ' | ' +
'<span class="status-badge ' + (mapping.use_tunnel ? 'status-tunnel' : 'status-direct') + '">' +
(mapping.use_tunnel ? '隧道' : '直连') +
'</span>' +
'</span> | ' +
'<strong>带宽:</strong> ' + formatBandwidth(mapping.bandwidth_limit) +
'</div>' +
'</div>' +
'</div>'

View File

@ -95,6 +95,114 @@ type ActiveConnection struct {
TargetIP string
Created time.Time
Closing int32 // 原子操作标志,表示连接正在关闭
RecvChan chan []byte // 接收数据的通道
}
// TunnelConn 实现 net.Conn 接口的隧道连接
type TunnelConn struct {
server *Server
connID uint32
targetHost string
targetPort int
recvChan chan []byte // 接收数据的通道
readBuffer []byte // 读取缓冲区
closed int32 // 原子操作标志,表示连接已关闭
closeOnce sync.Once // 确保只关闭一次
}
// Read 实现 net.Conn 接口
func (tc *TunnelConn) Read(b []byte) (n int, err error) {
if atomic.LoadInt32(&tc.closed) == 1 {
return 0, io.EOF
}
// 如果有缓冲数据,先返回缓冲数据
if len(tc.readBuffer) > 0 {
n = copy(b, tc.readBuffer)
tc.readBuffer = tc.readBuffer[n:]
return n, nil
}
// 从通道读取数据
select {
case data, ok := <-tc.recvChan:
if !ok {
return 0, io.EOF
}
n = copy(b, data)
// 如果数据太多,保存剩余部分到缓冲区
if n < len(data) {
tc.readBuffer = data[n:]
}
return n, nil
case <-tc.server.ctx.Done():
return 0, io.EOF
}
}
// Write 实现 net.Conn 接口
func (tc *TunnelConn) Write(b []byte) (n int, err error) {
if atomic.LoadInt32(&tc.closed) == 1 {
return 0, fmt.Errorf("connection closed")
}
// 发送数据到隧道
dataMsg := make([]byte, 4+len(b))
binary.BigEndian.PutUint32(dataMsg[0:4], tc.connID)
copy(dataMsg[4:], b)
msg := &TunnelMessage{
Version: ProtocolVersion,
Type: MsgTypeData,
Length: uint32(len(dataMsg)),
Data: dataMsg,
}
select {
case tc.server.sendChan <- msg:
return len(b), nil
case <-time.After(2 * time.Second):
return 0, fmt.Errorf("write timeout")
case <-tc.server.ctx.Done():
return 0, fmt.Errorf("server closed")
}
}
// Close 实现 net.Conn 接口
func (tc *TunnelConn) Close() error {
tc.closeOnce.Do(func() {
atomic.StoreInt32(&tc.closed, 1)
tc.server.closeConnection(tc.connID)
})
return nil
}
// LocalAddr 实现 net.Conn 接口
func (tc *TunnelConn) LocalAddr() net.Addr {
return &net.TCPAddr{IP: net.IPv4zero, Port: 0}
}
// RemoteAddr 实现 net.Conn 接口
func (tc *TunnelConn) RemoteAddr() net.Addr {
return &net.TCPAddr{IP: net.ParseIP(tc.targetHost), Port: tc.targetPort}
}
// SetDeadline 实现 net.Conn 接口
func (tc *TunnelConn) SetDeadline(t time.Time) error {
// 隧道连接不支持 deadline可以根据需要实现
return nil
}
// SetReadDeadline 实现 net.Conn 接口
func (tc *TunnelConn) SetReadDeadline(t time.Time) error {
// 隧道连接不支持 deadline可以根据需要实现
return nil
}
// SetWriteDeadline 实现 net.Conn 接口
func (tc *TunnelConn) SetWriteDeadline(t time.Time) error {
// 隧道连接不支持 deadline可以根据需要实现
return nil
}
// Server 内网穿透服务器
@ -398,13 +506,16 @@ func (s *Server) handleConnectResponse(msg *TunnelMessage) {
s.connMu.Unlock()
if status == ConnStatusSuccess {
// 连接成功,移到活跃连接
// 连接成功,创建接收通道
recvChan := make(chan []byte, 100)
active := &ActiveConnection{
ID: connID,
ClientConn: pending.ClientConn,
TargetPort: pending.TargetPort,
TargetHost: pending.TargetHost,
Created: time.Now(),
RecvChan: recvChan,
}
s.connMu.Lock()
@ -413,10 +524,6 @@ func (s *Server) handleConnectResponse(msg *TunnelMessage) {
log.Printf("连接已建立: ID=%d, 地址=%s:%d", connID, pending.TargetHost, pending.TargetPort)
// 启动数据转发
s.wg.Add(1)
go s.forwardData(active)
// 通知等待的goroutine
select {
case pending.ResponseChan <- true:
@ -486,10 +593,18 @@ func (s *Server) handleDataMessage(msg *TunnelMessage) {
return
}
// 写入到客户端连接
if _, err := active.ClientConn.Write(data); err != nil {
log.Printf("写入客户端连接失败 (ID=%d): %v", connID, err)
// 发送数据到接收通道
dataCopy := make([]byte, len(data))
copy(dataCopy, data)
select {
case active.RecvChan <- dataCopy:
// 数据已发送到接收通道
case <-time.After(2 * time.Second):
log.Printf("发送数据到接收通道超时 (ID=%d)", connID)
s.closeConnection(connID)
case <-s.ctx.Done():
return
}
}
@ -527,79 +642,92 @@ func (s *Server) handleKeepAlive(msg *TunnelMessage) {
}
}
// forwardData 转发数据
func (s *Server) forwardData(active *ActiveConnection) {
defer s.wg.Done()
defer s.closeConnection(active.ID)
// ForwardConnection 创建隧道连接(返回 net.Conn 接口)
func (s *Server) ForwardConnection(clientConn net.Conn, targetHost string, targetPort int) (net.Conn, error) {
s.mu.RLock()
tunnelConnected := s.tunnelConn != nil
s.mu.RUnlock()
buffer := make([]byte, 32*1024)
for {
select {
case <-s.ctx.Done():
return
default:
if !tunnelConnected {
return nil, fmt.Errorf("隧道连接不可用")
}
// 设置读取超时
active.ClientConn.SetReadDeadline(time.Now().Add(ReadTimeout))
n, err := active.ClientConn.Read(buffer)
// 创建待处理连接
s.connMu.Lock()
connID := s.nextConnID
s.nextConnID++
if err != nil {
// 检查是否是正在关闭的连接,避免记录无关错误
if atomic.LoadInt32(&active.Closing) == 1 {
return // 连接正在关闭,正常退出
pending := &PendingConnection{
ID: connID,
ClientConn: clientConn,
TargetPort: targetPort,
TargetHost: targetHost,
Created: time.Now(),
ResponseChan: make(chan bool, 1),
}
s.pendingConns[connID] = pending
s.connMu.Unlock()
// 任何错误都应该终止转发,包括超时
if err == io.EOF {
log.Printf("客户端连接正常关闭 (ID=%d)", active.ID)
} else if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
log.Printf("客户端连接超时 (ID=%d)", active.ID)
} else {
// 只记录非关闭相关的错误
if !isConnectionClosed(err) {
log.Printf("读取客户端连接失败 (ID=%d): %v", active.ID, err)
}
}
return
}
// 读取到0字节连接已关闭
if n == 0 {
log.Printf("客户端连接已关闭 (ID=%d, 读取0字节)", active.ID)
return
}
// 重置读取超时
active.ClientConn.SetReadDeadline(time.Time{})
// 检查连接是否正在关闭
if atomic.LoadInt32(&active.Closing) == 1 {
return
}
// 发送数据到隧道
dataMsg := make([]byte, 4+n)
binary.BigEndian.PutUint32(dataMsg[0:4], active.ID)
copy(dataMsg[4:], buffer[:n])
// 发送连接请求
// 格式: connID(4) + targetPort(2) + targetHostLen(1) + targetHost(变长)
targetHostBytes := []byte(targetHost)
reqData := make([]byte, 7+len(targetHostBytes))
binary.BigEndian.PutUint32(reqData[0:4], connID)
binary.BigEndian.PutUint16(reqData[4:6], uint16(targetPort))
reqData[6] = byte(len(targetHostBytes))
copy(reqData[7:], targetHostBytes)
msg := &TunnelMessage{
Version: ProtocolVersion,
Type: MsgTypeData,
Length: uint32(len(dataMsg)),
Data: dataMsg,
Type: MsgTypeConnectRequest,
Length: uint32(len(reqData)),
Data: reqData,
}
select {
case s.sendChan <- msg:
// 数据已发送
case <-time.After(2 * time.Second): // 减少超时时间
queueLen := len(s.sendChan)
log.Printf("发送数据超时 (ID=%d), 队列长度: %d/10000", active.ID, queueLen)
return
case <-s.ctx.Done():
return
case <-time.After(5 * time.Second):
s.connMu.Lock()
delete(s.pendingConns, connID)
s.connMu.Unlock()
return nil, fmt.Errorf("发送连接请求超时")
}
log.Printf("发送连接请求: ID=%d, 地址=%s:%d", connID, targetHost, targetPort)
// 等待连接响应
select {
case success := <-pending.ResponseChan:
if success {
// 连接建立成功,获取活动连接并创建 TunnelConn
s.connMu.RLock()
active, exists := s.activeConns[connID]
s.connMu.RUnlock()
if !exists {
return nil, fmt.Errorf("活动连接不存在")
}
// 创建 TunnelConn
tunnelConn := &TunnelConn{
server: s,
connID: connID,
targetHost: targetHost,
targetPort: targetPort,
recvChan: active.RecvChan,
}
return tunnelConn, nil
} else {
return nil, fmt.Errorf("远程连接失败")
}
case <-time.After(ConnectTimeout):
s.connMu.Lock()
delete(s.pendingConns, connID)
s.connMu.Unlock()
return nil, fmt.Errorf("连接超时")
case <-s.ctx.Done():
return nil, fmt.Errorf("服务器关闭")
}
}
@ -619,9 +747,9 @@ func (s *Server) closeConnection(connID uint32) {
// 记录关闭时间,避免重复发送关闭消息
s.closingConns[connID] = time.Now()
// 确保连接被关闭
if active.ClientConn != nil {
active.ClientConn.Close()
// 关闭接收通道
if active.RecvChan != nil {
close(active.RecvChan)
}
}
s.connMu.Unlock()
@ -663,78 +791,6 @@ func (s *Server) closeConnection(connID uint32) {
}
}
// ForwardConnection 转发连接到隧道(新的透明代理实现)
func (s *Server) ForwardConnection(clientConn net.Conn, targetHost string, targetPort int) error {
s.mu.RLock()
tunnelConnected := s.tunnelConn != nil
s.mu.RUnlock()
if !tunnelConnected {
return fmt.Errorf("隧道连接不可用")
}
// 创建待处理连接
s.connMu.Lock()
connID := s.nextConnID
s.nextConnID++
pending := &PendingConnection{
ID: connID,
ClientConn: clientConn,
TargetPort: targetPort,
TargetHost: targetHost, // 现在支持域名
Created: time.Now(),
ResponseChan: make(chan bool, 1),
}
s.pendingConns[connID] = pending
s.connMu.Unlock()
// 发送连接请求
// 格式: connID(4) + targetPort(2) + targetHostLen(1) + targetHost(变长)
targetHostBytes := []byte(targetHost)
reqData := make([]byte, 7+len(targetHostBytes))
binary.BigEndian.PutUint32(reqData[0:4], connID)
binary.BigEndian.PutUint16(reqData[4:6], uint16(targetPort))
reqData[6] = byte(len(targetHostBytes))
copy(reqData[7:], targetHostBytes)
msg := &TunnelMessage{
Version: ProtocolVersion,
Type: MsgTypeConnectRequest,
Length: uint32(len(reqData)),
Data: reqData,
}
select {
case s.sendChan <- msg:
case <-time.After(5 * time.Second):
s.connMu.Lock()
delete(s.pendingConns, connID)
s.connMu.Unlock()
return fmt.Errorf("发送连接请求超时")
}
log.Printf("发送连接请求: ID=%d, 地址=%s:%d", connID, targetHost, targetPort)
// 等待连接响应
select {
case success := <-pending.ResponseChan:
if success {
return nil // 连接建立成功forwardData会处理后续的数据转发
} else {
return fmt.Errorf("远程连接失败")
}
case <-time.After(ConnectTimeout):
s.connMu.Lock()
delete(s.pendingConns, connID)
s.connMu.Unlock()
clientConn.Close()
return fmt.Errorf("连接超时")
case <-s.ctx.Done():
return fmt.Errorf("服务器关闭")
}
}
// IsConnected 检查隧道是否已连接
func (s *Server) IsConnected() bool {
s.mu.RLock()
@ -866,7 +922,7 @@ func (s *Server) cleanupClosingConns() {
if deletedCount >= deleteCount {
break
}
if closeTime.Before(now.Add(-maxAge/2)) {
if closeTime.Before(now.Add(-maxAge / 2)) {
delete(s.closingConns, connID)
deletedCount++
}

View File

@ -456,10 +456,11 @@ func TestForwardConnection(t *testing.T) {
// 启动连接转发
go func() {
err := server.ForwardConnection(serverSideConn, "127.0.0.1", targetPort)
conn, err := server.ForwardConnection(serverSideConn, "127.0.0.1", targetPort)
if err != nil {
t.Errorf("转发连接失败: %v", err)
}
defer conn.Close()
}()
// 读取连接请求