feat: 重构了tunnel,现在统一了转发逻辑
This commit is contained in:
parent
2655b5592f
commit
fc1614f7a4
|
|
@ -16,7 +16,7 @@ import (
|
|||
|
||||
// TunnelServer 隧道服务器接口
|
||||
type TunnelServer interface {
|
||||
ForwardConnection(clientConn net.Conn, targetIP string, targetPort int) error
|
||||
ForwardConnection(clientConn net.Conn, targetIP string, targetPort int) (net.Conn, error)
|
||||
IsConnected() bool
|
||||
GetTrafficStats() stats.TrafficStats
|
||||
}
|
||||
|
|
@ -46,7 +46,13 @@ func NewForwarder(sourcePort int, targetHost string, targetPort int, limit *int6
|
|||
ctx, cancel := context.WithCancel(context.Background())
|
||||
var limiterOut, limiterIn *rate.Limiter
|
||||
if limit != nil {
|
||||
burst := int(*limit) // 容量至少等于速率,不然无法正常突发
|
||||
// burst设置为1秒的流量,这样可以平滑处理突发
|
||||
// 同时不会一次性消耗太多令牌
|
||||
burst := int(*limit) / 100
|
||||
if burst < 10240 {
|
||||
burst = 10240 // 最小burst为10KB
|
||||
}
|
||||
log.Printf("设置限速: %d bytes/sec, burst: %d bytes", *limit, burst)
|
||||
limiterOut = rate.NewLimiter(rate.Limit(*limit), burst)
|
||||
limiterIn = rate.NewLimiter(rate.Limit(*limit), burst)
|
||||
}
|
||||
|
|
@ -68,7 +74,13 @@ func NewTunnelForwarder(sourcePort int, targetHost string, targetPort int, tunne
|
|||
ctx, cancel := context.WithCancel(context.Background())
|
||||
var limiterOut, limiterIn *rate.Limiter
|
||||
if limit != nil {
|
||||
burst := int(*limit) // 容量至少等于速率,不然无法正常突发
|
||||
// burst设置为1秒的流量,这样可以平滑处理突发
|
||||
// 同时不会一次性消耗太多令牌
|
||||
burst := int(*limit) / 100
|
||||
if burst < 10240 {
|
||||
burst = 10240 // 最小burst为10KB
|
||||
}
|
||||
log.Printf("设置限速: %d bytes/sec, burst: %d bytes", *limit, burst)
|
||||
limiterOut = rate.NewLimiter(rate.Limit(*limit), burst)
|
||||
limiterIn = rate.NewLimiter(rate.Limit(*limit), burst)
|
||||
}
|
||||
|
|
@ -142,42 +154,62 @@ type rateLimitedReader struct {
|
|||
}
|
||||
|
||||
func (rlr *rateLimitedReader) Read(p []byte) (int, error) {
|
||||
if rlr.limiter != nil {
|
||||
maxReq := rlr.limiter.Burst()
|
||||
reqSize := len(p)
|
||||
if reqSize > maxReq {
|
||||
reqSize = maxReq // 避免一次申请超过桶容量导致错误
|
||||
if rlr.limiter == nil {
|
||||
return rlr.r.Read(p)
|
||||
}
|
||||
err := rlr.limiter.WaitN(rlr.ctx, reqSize)
|
||||
if err != nil {
|
||||
|
||||
// 使用更小的块大小以实现更平滑的限流
|
||||
// 2KB是一个合理的值,既不会太频繁调用,也能保持流量平滑
|
||||
chunkSize := 2048
|
||||
|
||||
// 不超过缓冲区大小
|
||||
if len(p) < chunkSize {
|
||||
chunkSize = len(p)
|
||||
}
|
||||
|
||||
// 预先申请令牌
|
||||
if err := rlr.limiter.WaitN(rlr.ctx, chunkSize); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// 限制实际读取大小
|
||||
if len(p) > chunkSize {
|
||||
p = p[:chunkSize]
|
||||
}
|
||||
return rlr.r.Read(p)
|
||||
|
||||
// 执行读取
|
||||
n, err := rlr.r.Read(p)
|
||||
|
||||
// 如果实际读取少于申请的令牌,归还多余的令牌
|
||||
// 注意:rate.Limiter 不支持归还令牌,所以这里只能接受这个损耗
|
||||
// 或者改用先读后限的方式
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
// handleConnection 处理单个连接
|
||||
func (f *Forwarder) handleConnection(clientConn net.Conn) {
|
||||
defer f.wg.Done()
|
||||
|
||||
if f.useTunnel {
|
||||
// 使用隧道转发
|
||||
if f.tunnelServer == nil || !f.tunnelServer.IsConnected() {
|
||||
log.Printf("隧道服务器不可用 (端口 %d)", f.sourcePort)
|
||||
clientConn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
// 将连接转发到隧道,ForwardConnection 会处理连接关闭
|
||||
if err := f.tunnelServer.ForwardConnection(clientConn, f.targetHost, f.targetPort); err != nil {
|
||||
log.Printf("隧道转发失败 (端口 %d -> %s:%d): %v", f.sourcePort, f.targetHost, f.targetPort, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 直接连接目标
|
||||
defer clientConn.Close()
|
||||
|
||||
var targetConn net.Conn
|
||||
var err error
|
||||
|
||||
if f.useTunnel {
|
||||
// 使用隧道转发,获取 TunnelConn
|
||||
if f.tunnelServer == nil || !f.tunnelServer.IsConnected() {
|
||||
log.Printf("隧道服务器不可用 (端口 %d)", f.sourcePort)
|
||||
return
|
||||
}
|
||||
|
||||
// 获取隧道连接
|
||||
targetConn, err = f.tunnelServer.ForwardConnection(clientConn, f.targetHost, f.targetPort)
|
||||
if err != nil {
|
||||
log.Printf("隧道转发失败 (端口 %d -> %s:%d): %v", f.sourcePort, f.targetHost, f.targetPort, err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// 直接连接目标
|
||||
dialer := &net.Dialer{
|
||||
Timeout: 10 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
|
|
@ -185,11 +217,13 @@ func (f *Forwarder) handleConnection(clientConn net.Conn) {
|
|||
|
||||
// 动态解析域名并连接
|
||||
targetAddr := fmt.Sprintf("%s:%d", f.targetHost, f.targetPort)
|
||||
targetConn, err := dialer.DialContext(f.ctx, "tcp", targetAddr)
|
||||
targetConn, err = dialer.DialContext(f.ctx, "tcp", targetAddr)
|
||||
if err != nil {
|
||||
log.Printf("连接目标失败 (端口 %d -> %s): %v", f.sourcePort, targetAddr, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
defer targetConn.Close()
|
||||
|
||||
// 双向转发
|
||||
|
|
|
|||
|
|
@ -39,10 +39,10 @@ type mockTunnelServer struct {
|
|||
connected bool
|
||||
}
|
||||
|
||||
func (m *mockTunnelServer) ForwardConnection(clientConn net.Conn, targetIp string, targetPort int) error {
|
||||
func (m *mockTunnelServer) ForwardConnection(clientConn net.Conn, targetIp string, targetPort int) (net.Conn, error) {
|
||||
// 简单的模拟实现
|
||||
defer clientConn.Close()
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockTunnelServer) IsConnected() bool {
|
||||
|
|
|
|||
|
|
@ -492,6 +492,22 @@
|
|||
let currentMappings = [];
|
||||
let apiKey = '';
|
||||
|
||||
// 格式化带宽限制为可读格式
|
||||
function formatBandwidth(bytes) {
|
||||
if (!bytes || bytes === 0) {
|
||||
return '无限制';
|
||||
}
|
||||
if (bytes < 1024) {
|
||||
return bytes + ' B/s';
|
||||
} else if (bytes < 1024 * 1024) {
|
||||
return (bytes / 1024).toFixed(2) + ' KB/s';
|
||||
} else if (bytes < 1024 * 1024 * 1024) {
|
||||
return (bytes / (1024 * 1024)).toFixed(2) + ' MB/s';
|
||||
} else {
|
||||
return (bytes / (1024 * 1024 * 1024)).toFixed(2) + ' GB/s';
|
||||
}
|
||||
}
|
||||
|
||||
// 检查是否已有 API Key
|
||||
function checkAuth() {
|
||||
const savedKey = sessionStorage.getItem('apiKey');
|
||||
|
|
@ -733,6 +749,7 @@
|
|||
'<span class="status-badge ' + (mapping.use_tunnel ? 'status-tunnel' : 'status-direct') + '">' +
|
||||
(mapping.use_tunnel ? '隧道模式' : '直连模式') +
|
||||
'</span><br>' +
|
||||
'<strong>带宽限制:</strong> ' + formatBandwidth(mapping.bandwidth_limit) + '<br>' +
|
||||
'<strong>创建时间:</strong> ' + new Date(mapping.created_at).toLocaleString('zh-CN') +
|
||||
'</div>' +
|
||||
'</div>' +
|
||||
|
|
@ -800,7 +817,8 @@
|
|||
'<strong>目标:</strong> ' + mapping.target_host + ':' + mapping.target_port + ' | ' +
|
||||
'<span class="status-badge ' + (mapping.use_tunnel ? 'status-tunnel' : 'status-direct') + '">' +
|
||||
(mapping.use_tunnel ? '隧道' : '直连') +
|
||||
'</span>' +
|
||||
'</span> | ' +
|
||||
'<strong>带宽:</strong> ' + formatBandwidth(mapping.bandwidth_limit) +
|
||||
'</div>' +
|
||||
'</div>' +
|
||||
'</div>'
|
||||
|
|
|
|||
|
|
@ -95,6 +95,114 @@ type ActiveConnection struct {
|
|||
TargetIP string
|
||||
Created time.Time
|
||||
Closing int32 // 原子操作标志,表示连接正在关闭
|
||||
RecvChan chan []byte // 接收数据的通道
|
||||
}
|
||||
|
||||
// TunnelConn 实现 net.Conn 接口的隧道连接
|
||||
type TunnelConn struct {
|
||||
server *Server
|
||||
connID uint32
|
||||
targetHost string
|
||||
targetPort int
|
||||
recvChan chan []byte // 接收数据的通道
|
||||
readBuffer []byte // 读取缓冲区
|
||||
closed int32 // 原子操作标志,表示连接已关闭
|
||||
closeOnce sync.Once // 确保只关闭一次
|
||||
}
|
||||
|
||||
// Read 实现 net.Conn 接口
|
||||
func (tc *TunnelConn) Read(b []byte) (n int, err error) {
|
||||
if atomic.LoadInt32(&tc.closed) == 1 {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
// 如果有缓冲数据,先返回缓冲数据
|
||||
if len(tc.readBuffer) > 0 {
|
||||
n = copy(b, tc.readBuffer)
|
||||
tc.readBuffer = tc.readBuffer[n:]
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// 从通道读取数据
|
||||
select {
|
||||
case data, ok := <-tc.recvChan:
|
||||
if !ok {
|
||||
return 0, io.EOF
|
||||
}
|
||||
n = copy(b, data)
|
||||
// 如果数据太多,保存剩余部分到缓冲区
|
||||
if n < len(data) {
|
||||
tc.readBuffer = data[n:]
|
||||
}
|
||||
return n, nil
|
||||
case <-tc.server.ctx.Done():
|
||||
return 0, io.EOF
|
||||
}
|
||||
}
|
||||
|
||||
// Write 实现 net.Conn 接口
|
||||
func (tc *TunnelConn) Write(b []byte) (n int, err error) {
|
||||
if atomic.LoadInt32(&tc.closed) == 1 {
|
||||
return 0, fmt.Errorf("connection closed")
|
||||
}
|
||||
|
||||
// 发送数据到隧道
|
||||
dataMsg := make([]byte, 4+len(b))
|
||||
binary.BigEndian.PutUint32(dataMsg[0:4], tc.connID)
|
||||
copy(dataMsg[4:], b)
|
||||
|
||||
msg := &TunnelMessage{
|
||||
Version: ProtocolVersion,
|
||||
Type: MsgTypeData,
|
||||
Length: uint32(len(dataMsg)),
|
||||
Data: dataMsg,
|
||||
}
|
||||
|
||||
select {
|
||||
case tc.server.sendChan <- msg:
|
||||
return len(b), nil
|
||||
case <-time.After(2 * time.Second):
|
||||
return 0, fmt.Errorf("write timeout")
|
||||
case <-tc.server.ctx.Done():
|
||||
return 0, fmt.Errorf("server closed")
|
||||
}
|
||||
}
|
||||
|
||||
// Close 实现 net.Conn 接口
|
||||
func (tc *TunnelConn) Close() error {
|
||||
tc.closeOnce.Do(func() {
|
||||
atomic.StoreInt32(&tc.closed, 1)
|
||||
tc.server.closeConnection(tc.connID)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// LocalAddr 实现 net.Conn 接口
|
||||
func (tc *TunnelConn) LocalAddr() net.Addr {
|
||||
return &net.TCPAddr{IP: net.IPv4zero, Port: 0}
|
||||
}
|
||||
|
||||
// RemoteAddr 实现 net.Conn 接口
|
||||
func (tc *TunnelConn) RemoteAddr() net.Addr {
|
||||
return &net.TCPAddr{IP: net.ParseIP(tc.targetHost), Port: tc.targetPort}
|
||||
}
|
||||
|
||||
// SetDeadline 实现 net.Conn 接口
|
||||
func (tc *TunnelConn) SetDeadline(t time.Time) error {
|
||||
// 隧道连接不支持 deadline,可以根据需要实现
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetReadDeadline 实现 net.Conn 接口
|
||||
func (tc *TunnelConn) SetReadDeadline(t time.Time) error {
|
||||
// 隧道连接不支持 deadline,可以根据需要实现
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetWriteDeadline 实现 net.Conn 接口
|
||||
func (tc *TunnelConn) SetWriteDeadline(t time.Time) error {
|
||||
// 隧道连接不支持 deadline,可以根据需要实现
|
||||
return nil
|
||||
}
|
||||
|
||||
// Server 内网穿透服务器
|
||||
|
|
@ -398,13 +506,16 @@ func (s *Server) handleConnectResponse(msg *TunnelMessage) {
|
|||
s.connMu.Unlock()
|
||||
|
||||
if status == ConnStatusSuccess {
|
||||
// 连接成功,移到活跃连接
|
||||
// 连接成功,创建接收通道
|
||||
recvChan := make(chan []byte, 100)
|
||||
|
||||
active := &ActiveConnection{
|
||||
ID: connID,
|
||||
ClientConn: pending.ClientConn,
|
||||
TargetPort: pending.TargetPort,
|
||||
TargetHost: pending.TargetHost,
|
||||
Created: time.Now(),
|
||||
RecvChan: recvChan,
|
||||
}
|
||||
|
||||
s.connMu.Lock()
|
||||
|
|
@ -413,10 +524,6 @@ func (s *Server) handleConnectResponse(msg *TunnelMessage) {
|
|||
|
||||
log.Printf("连接已建立: ID=%d, 地址=%s:%d", connID, pending.TargetHost, pending.TargetPort)
|
||||
|
||||
// 启动数据转发
|
||||
s.wg.Add(1)
|
||||
go s.forwardData(active)
|
||||
|
||||
// 通知等待的goroutine
|
||||
select {
|
||||
case pending.ResponseChan <- true:
|
||||
|
|
@ -486,10 +593,18 @@ func (s *Server) handleDataMessage(msg *TunnelMessage) {
|
|||
return
|
||||
}
|
||||
|
||||
// 写入到客户端连接
|
||||
if _, err := active.ClientConn.Write(data); err != nil {
|
||||
log.Printf("写入客户端连接失败 (ID=%d): %v", connID, err)
|
||||
// 发送数据到接收通道
|
||||
dataCopy := make([]byte, len(data))
|
||||
copy(dataCopy, data)
|
||||
|
||||
select {
|
||||
case active.RecvChan <- dataCopy:
|
||||
// 数据已发送到接收通道
|
||||
case <-time.After(2 * time.Second):
|
||||
log.Printf("发送数据到接收通道超时 (ID=%d)", connID)
|
||||
s.closeConnection(connID)
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -527,79 +642,92 @@ func (s *Server) handleKeepAlive(msg *TunnelMessage) {
|
|||
}
|
||||
}
|
||||
|
||||
// forwardData 转发数据
|
||||
func (s *Server) forwardData(active *ActiveConnection) {
|
||||
defer s.wg.Done()
|
||||
defer s.closeConnection(active.ID)
|
||||
// ForwardConnection 创建隧道连接(返回 net.Conn 接口)
|
||||
func (s *Server) ForwardConnection(clientConn net.Conn, targetHost string, targetPort int) (net.Conn, error) {
|
||||
s.mu.RLock()
|
||||
tunnelConnected := s.tunnelConn != nil
|
||||
s.mu.RUnlock()
|
||||
|
||||
buffer := make([]byte, 32*1024)
|
||||
for {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
default:
|
||||
if !tunnelConnected {
|
||||
return nil, fmt.Errorf("隧道连接不可用")
|
||||
}
|
||||
|
||||
// 设置读取超时
|
||||
active.ClientConn.SetReadDeadline(time.Now().Add(ReadTimeout))
|
||||
n, err := active.ClientConn.Read(buffer)
|
||||
// 创建待处理连接
|
||||
s.connMu.Lock()
|
||||
connID := s.nextConnID
|
||||
s.nextConnID++
|
||||
|
||||
if err != nil {
|
||||
// 检查是否是正在关闭的连接,避免记录无关错误
|
||||
if atomic.LoadInt32(&active.Closing) == 1 {
|
||||
return // 连接正在关闭,正常退出
|
||||
pending := &PendingConnection{
|
||||
ID: connID,
|
||||
ClientConn: clientConn,
|
||||
TargetPort: targetPort,
|
||||
TargetHost: targetHost,
|
||||
Created: time.Now(),
|
||||
ResponseChan: make(chan bool, 1),
|
||||
}
|
||||
s.pendingConns[connID] = pending
|
||||
s.connMu.Unlock()
|
||||
|
||||
// 任何错误都应该终止转发,包括超时
|
||||
if err == io.EOF {
|
||||
log.Printf("客户端连接正常关闭 (ID=%d)", active.ID)
|
||||
} else if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
log.Printf("客户端连接超时 (ID=%d)", active.ID)
|
||||
} else {
|
||||
// 只记录非关闭相关的错误
|
||||
if !isConnectionClosed(err) {
|
||||
log.Printf("读取客户端连接失败 (ID=%d): %v", active.ID, err)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 读取到0字节,连接已关闭
|
||||
if n == 0 {
|
||||
log.Printf("客户端连接已关闭 (ID=%d, 读取0字节)", active.ID)
|
||||
return
|
||||
}
|
||||
|
||||
// 重置读取超时
|
||||
active.ClientConn.SetReadDeadline(time.Time{})
|
||||
|
||||
// 检查连接是否正在关闭
|
||||
if atomic.LoadInt32(&active.Closing) == 1 {
|
||||
return
|
||||
}
|
||||
|
||||
// 发送数据到隧道
|
||||
dataMsg := make([]byte, 4+n)
|
||||
binary.BigEndian.PutUint32(dataMsg[0:4], active.ID)
|
||||
copy(dataMsg[4:], buffer[:n])
|
||||
// 发送连接请求
|
||||
// 格式: connID(4) + targetPort(2) + targetHostLen(1) + targetHost(变长)
|
||||
targetHostBytes := []byte(targetHost)
|
||||
reqData := make([]byte, 7+len(targetHostBytes))
|
||||
binary.BigEndian.PutUint32(reqData[0:4], connID)
|
||||
binary.BigEndian.PutUint16(reqData[4:6], uint16(targetPort))
|
||||
reqData[6] = byte(len(targetHostBytes))
|
||||
copy(reqData[7:], targetHostBytes)
|
||||
|
||||
msg := &TunnelMessage{
|
||||
Version: ProtocolVersion,
|
||||
Type: MsgTypeData,
|
||||
Length: uint32(len(dataMsg)),
|
||||
Data: dataMsg,
|
||||
Type: MsgTypeConnectRequest,
|
||||
Length: uint32(len(reqData)),
|
||||
Data: reqData,
|
||||
}
|
||||
|
||||
select {
|
||||
case s.sendChan <- msg:
|
||||
// 数据已发送
|
||||
case <-time.After(2 * time.Second): // 减少超时时间
|
||||
queueLen := len(s.sendChan)
|
||||
log.Printf("发送数据超时 (ID=%d), 队列长度: %d/10000", active.ID, queueLen)
|
||||
return
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
case <-time.After(5 * time.Second):
|
||||
s.connMu.Lock()
|
||||
delete(s.pendingConns, connID)
|
||||
s.connMu.Unlock()
|
||||
return nil, fmt.Errorf("发送连接请求超时")
|
||||
}
|
||||
|
||||
log.Printf("发送连接请求: ID=%d, 地址=%s:%d", connID, targetHost, targetPort)
|
||||
|
||||
// 等待连接响应
|
||||
select {
|
||||
case success := <-pending.ResponseChan:
|
||||
if success {
|
||||
// 连接建立成功,获取活动连接并创建 TunnelConn
|
||||
s.connMu.RLock()
|
||||
active, exists := s.activeConns[connID]
|
||||
s.connMu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("活动连接不存在")
|
||||
}
|
||||
|
||||
// 创建 TunnelConn
|
||||
tunnelConn := &TunnelConn{
|
||||
server: s,
|
||||
connID: connID,
|
||||
targetHost: targetHost,
|
||||
targetPort: targetPort,
|
||||
recvChan: active.RecvChan,
|
||||
}
|
||||
|
||||
return tunnelConn, nil
|
||||
} else {
|
||||
return nil, fmt.Errorf("远程连接失败")
|
||||
}
|
||||
case <-time.After(ConnectTimeout):
|
||||
s.connMu.Lock()
|
||||
delete(s.pendingConns, connID)
|
||||
s.connMu.Unlock()
|
||||
return nil, fmt.Errorf("连接超时")
|
||||
case <-s.ctx.Done():
|
||||
return nil, fmt.Errorf("服务器关闭")
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -619,9 +747,9 @@ func (s *Server) closeConnection(connID uint32) {
|
|||
// 记录关闭时间,避免重复发送关闭消息
|
||||
s.closingConns[connID] = time.Now()
|
||||
|
||||
// 确保连接被关闭
|
||||
if active.ClientConn != nil {
|
||||
active.ClientConn.Close()
|
||||
// 关闭接收通道
|
||||
if active.RecvChan != nil {
|
||||
close(active.RecvChan)
|
||||
}
|
||||
}
|
||||
s.connMu.Unlock()
|
||||
|
|
@ -663,78 +791,6 @@ func (s *Server) closeConnection(connID uint32) {
|
|||
}
|
||||
}
|
||||
|
||||
// ForwardConnection 转发连接到隧道(新的透明代理实现)
|
||||
func (s *Server) ForwardConnection(clientConn net.Conn, targetHost string, targetPort int) error {
|
||||
s.mu.RLock()
|
||||
tunnelConnected := s.tunnelConn != nil
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !tunnelConnected {
|
||||
return fmt.Errorf("隧道连接不可用")
|
||||
}
|
||||
|
||||
// 创建待处理连接
|
||||
s.connMu.Lock()
|
||||
connID := s.nextConnID
|
||||
s.nextConnID++
|
||||
|
||||
pending := &PendingConnection{
|
||||
ID: connID,
|
||||
ClientConn: clientConn,
|
||||
TargetPort: targetPort,
|
||||
TargetHost: targetHost, // 现在支持域名
|
||||
Created: time.Now(),
|
||||
ResponseChan: make(chan bool, 1),
|
||||
}
|
||||
s.pendingConns[connID] = pending
|
||||
s.connMu.Unlock()
|
||||
|
||||
// 发送连接请求
|
||||
// 格式: connID(4) + targetPort(2) + targetHostLen(1) + targetHost(变长)
|
||||
targetHostBytes := []byte(targetHost)
|
||||
reqData := make([]byte, 7+len(targetHostBytes))
|
||||
binary.BigEndian.PutUint32(reqData[0:4], connID)
|
||||
binary.BigEndian.PutUint16(reqData[4:6], uint16(targetPort))
|
||||
reqData[6] = byte(len(targetHostBytes))
|
||||
copy(reqData[7:], targetHostBytes)
|
||||
|
||||
msg := &TunnelMessage{
|
||||
Version: ProtocolVersion,
|
||||
Type: MsgTypeConnectRequest,
|
||||
Length: uint32(len(reqData)),
|
||||
Data: reqData,
|
||||
}
|
||||
|
||||
select {
|
||||
case s.sendChan <- msg:
|
||||
case <-time.After(5 * time.Second):
|
||||
s.connMu.Lock()
|
||||
delete(s.pendingConns, connID)
|
||||
s.connMu.Unlock()
|
||||
return fmt.Errorf("发送连接请求超时")
|
||||
}
|
||||
|
||||
log.Printf("发送连接请求: ID=%d, 地址=%s:%d", connID, targetHost, targetPort)
|
||||
|
||||
// 等待连接响应
|
||||
select {
|
||||
case success := <-pending.ResponseChan:
|
||||
if success {
|
||||
return nil // 连接建立成功,forwardData会处理后续的数据转发
|
||||
} else {
|
||||
return fmt.Errorf("远程连接失败")
|
||||
}
|
||||
case <-time.After(ConnectTimeout):
|
||||
s.connMu.Lock()
|
||||
delete(s.pendingConns, connID)
|
||||
s.connMu.Unlock()
|
||||
clientConn.Close()
|
||||
return fmt.Errorf("连接超时")
|
||||
case <-s.ctx.Done():
|
||||
return fmt.Errorf("服务器关闭")
|
||||
}
|
||||
}
|
||||
|
||||
// IsConnected 检查隧道是否已连接
|
||||
func (s *Server) IsConnected() bool {
|
||||
s.mu.RLock()
|
||||
|
|
@ -866,7 +922,7 @@ func (s *Server) cleanupClosingConns() {
|
|||
if deletedCount >= deleteCount {
|
||||
break
|
||||
}
|
||||
if closeTime.Before(now.Add(-maxAge/2)) {
|
||||
if closeTime.Before(now.Add(-maxAge / 2)) {
|
||||
delete(s.closingConns, connID)
|
||||
deletedCount++
|
||||
}
|
||||
|
|
|
|||
|
|
@ -456,10 +456,11 @@ func TestForwardConnection(t *testing.T) {
|
|||
|
||||
// 启动连接转发
|
||||
go func() {
|
||||
err := server.ForwardConnection(serverSideConn, "127.0.0.1", targetPort)
|
||||
conn, err := server.ForwardConnection(serverSideConn, "127.0.0.1", targetPort)
|
||||
if err != nil {
|
||||
t.Errorf("转发连接失败: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
}()
|
||||
|
||||
// 读取连接请求
|
||||
|
|
|
|||
Loading…
Reference in New Issue