This commit is contained in:
Pan Qiancheng 2025-10-14 16:47:39 +08:00
commit 9c40cf8450
2 changed files with 57 additions and 11 deletions

View File

@ -123,27 +123,39 @@ func (c *Client) connectLoop() {
// 处理连接
var connWg sync.WaitGroup
connCtx, connCancel := context.WithCancel(context.Background())
connWg.Add(3)
go func() {
defer connWg.Done()
c.handleServerRead(conn)
c.handleServerRead(conn, connCtx, connCancel)
}()
go func() {
defer connWg.Done()
c.handleServerWrite(conn)
c.handleServerWrite(conn, connCtx, connCancel)
}()
go func() {
defer connWg.Done()
c.keepAliveLoop(conn)
c.keepAliveLoop(conn, connCtx, connCancel)
}()
// 等待连接断开
// 等待任一协程出错
connWg.Wait()
// 确保所有协程都退出
connCancel()
// 短暂等待确保资源清理
time.Sleep(100 * time.Millisecond)
// 连接断开后立即清理资源
c.mu.Lock()
c.serverConn = nil
c.mu.Unlock()
// 清空发送队列
c.drainSendChan()
// 关闭所有本地连接
c.connMu.Lock()
for _, conn := range c.connections {
@ -158,18 +170,26 @@ func (c *Client) connectLoop() {
c.connMu.Unlock()
log.Printf("与服务器断开连接,%v 后重连", ReconnectDelay)
time.Sleep(ReconnectDelay)
// 使用带取消的sleep确保可以及时响应关闭信号
select {
case <-time.After(ReconnectDelay):
case <-c.ctx.Done():
return
}
}
}
// handleServerRead 处理服务器读取
func (c *Client) handleServerRead(conn net.Conn) {
func (c *Client) handleServerRead(conn net.Conn, connCtx context.Context, connCancel context.CancelFunc) {
defer conn.Close()
for {
select {
case <-c.ctx.Done():
return
case <-connCtx.Done():
return
default:
}
@ -178,6 +198,7 @@ func (c *Client) handleServerRead(conn net.Conn) {
if err != io.EOF {
log.Printf("读取隧道消息失败: %v", err)
}
connCancel() // 通知其他协程退出
return
}
@ -186,14 +207,19 @@ func (c *Client) handleServerRead(conn net.Conn) {
}
// handleServerWrite 处理服务器写入
func (c *Client) handleServerWrite(conn net.Conn) {
func (c *Client) handleServerWrite(conn net.Conn, connCtx context.Context, connCancel context.CancelFunc) {
for {
select {
case <-c.ctx.Done():
return
case <-connCtx.Done():
return
case msg := <-c.sendChan:
if err := c.writeTunnelMessage(conn, msg); err != nil {
log.Printf("写入隧道消息失败: %v", err)
// 清空发送队列,避免阻塞
go c.drainSendChan()
connCancel() // 通知其他协程退出
return
}
}
@ -518,7 +544,7 @@ func isTimeout(err error) bool {
}
// keepAliveLoop 心跳循环
func (c *Client) keepAliveLoop(conn net.Conn) {
func (c *Client) keepAliveLoop(conn net.Conn, connCtx context.Context, connCancel context.CancelFunc) {
ticker := time.NewTicker(KeepAliveInterval)
defer ticker.Stop()
@ -526,6 +552,8 @@ func (c *Client) keepAliveLoop(conn net.Conn) {
select {
case <-c.ctx.Done():
return
case <-connCtx.Done():
return
case <-ticker.C:
// 发送心跳消息
keepAliveMsg := &TunnelMessage{
@ -537,13 +565,29 @@ func (c *Client) keepAliveLoop(conn net.Conn) {
select {
case c.sendChan <- keepAliveMsg:
log.Printf("发送心跳消息")
// 降低心跳日志频率,避免刷屏
// log.Printf("发送心跳消息")
case <-time.After(5 * time.Second):
log.Printf("发送心跳消息超时")
connCancel() // 通知其他协程退出
return
case <-c.ctx.Done():
return
case <-connCtx.Done():
return
}
}
}
}
// drainSendChan 清空发送队列,避免阻塞
func (c *Client) drainSendChan() {
for {
select {
case <-c.sendChan:
// 丢弃消息
default:
return
}
}
}

View File

@ -180,10 +180,11 @@ func (s *Server) acceptLoop() {
log.Printf("隧道客户端已连接: %s", conn.RemoteAddr())
s.wg.Add(3)
s.wg.Add(2)
go s.handleTunnelRead(conn)
go s.handleTunnelWrite(conn)
go s.keepAliveLoop(conn)
// 注释掉服务器端主动心跳,只由客户端发送心跳
// go s.keepAliveLoop(conn)
}
}
@ -433,6 +434,7 @@ func (s *Server) handleKeepAlive(msg *TunnelMessage) {
select {
case s.sendChan <- response:
// log.Printf("回应心跳消息") // 降低日志频率
default:
log.Printf("发送心跳响应失败: 发送队列已满")
}