From a30a7e38b42b3f5bfc1cd1bf95c7f9646a971059 Mon Sep 17 00:00:00 2001 From: "qcqcqc@wsl" <1220204124@zust.edu.cn> Date: Thu, 16 Oct 2025 16:41:46 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E4=BA=86=E4=B8=A4?= =?UTF-8?q?=E7=AB=AF=E9=87=8D=E5=A4=8Dack=E5=AF=BC=E8=87=B4=E7=9A=84?= =?UTF-8?q?=E8=BF=87=E5=BA=A6cpu=E5=8D=A0=E7=94=A8=E5=92=8C=E5=B8=A6?= =?UTF-8?q?=E5=AE=BD=E5=8D=A0=E7=94=A8=20fix:=20=E4=BF=AE=E5=A4=8D?= =?UTF-8?q?=E4=BA=86=E8=AF=BB=E5=8F=96=E8=B6=85=E6=97=B6=E5=B0=8F=E4=BA=8E?= =?UTF-8?q?ssh=E9=BB=98=E8=AE=A4=E7=9A=8460s=E5=AF=BC=E8=87=B4=E8=BF=9E?= =?UTF-8?q?=E6=8E=A5=E4=B8=AD=E6=96=AD=E7=9A=84=E9=97=AE=E9=A2=98=20feat:?= =?UTF-8?q?=20=E6=96=B0=E5=A2=9E=E4=BA=86=E6=B5=81=E9=87=8F=E7=BB=9F?= =?UTF-8?q?=E8=AE=A1=E9=A1=B5=E9=9D=A2=E4=B8=8E=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- DEBUG_GUIDE.md | 360 +++++++++++++++++++++++++ TRAFFIC_MONITORING.md | 278 +++++++++++++++++++ monitor.sh | 122 +++++++++ src/client/main.go | 21 ++ src/client/tunnel/client.go | 92 +++++-- src/client/tunnel/client_test.go | 10 +- src/server/api/api.go | 55 ++++ src/server/api/api_test.go | 2 +- src/server/api/html.go | 341 +++++++++++++++++++++++ src/server/config/config_test.go | 2 +- src/server/forwarder/forwarder.go | 65 ++++- src/server/forwarder/forwarder_test.go | 13 +- src/server/main.go | 23 +- src/server/stats/stats.go | 25 ++ src/server/tunnel/tunnel.go | 108 +++++++- 15 files changed, 1459 insertions(+), 58 deletions(-) create mode 100644 DEBUG_GUIDE.md create mode 100644 TRAFFIC_MONITORING.md create mode 100644 monitor.sh create mode 100644 src/server/api/html.go create mode 100644 src/server/stats/stats.go diff --git a/DEBUG_GUIDE.md b/DEBUG_GUIDE.md new file mode 100644 index 0000000..546910d --- /dev/null +++ b/DEBUG_GUIDE.md @@ -0,0 +1,360 @@ +# Go Tunnel 调试指南 + +## 概述 +本指南介绍如何使用 Go 的内置性能分析工具来调试和诊断 CPU 占用、内存泄漏和 goroutine 泄漏等问题。 + +## 已添加的调试功能 + +### 1. pprof HTTP 接口 +- **服务器**: `http://localhost:6060/debug/pprof/` +- **客户端**: `http://localhost:6061/debug/pprof/` + +### 2. Goroutine 监控 +每10秒自动打印当前 goroutine 数量到日志。 + +## 使用方法 + +### 方法一:实时查看 Goroutine 堆栈(最有用!) + +这是找出 CPU 占用问题的最直接方法。 + +#### 1. 启动服务器和客户端 +```bash +# 终端1:启动服务器 +cd /home/qcqcqc/workspace/go-tunnel/src +make run-server + +# 终端2:启动客户端 +make run-client + +# 终端3:建立 SSH 连接测试 +ssh root@localhost -p 30009 +# 执行一些操作后断开 +``` + +#### 2. SSH 断开后,立即查看 goroutine 堆栈 + +**查看服务器的 goroutine:** +```bash +# 在浏览器打开或使用 curl +curl http://localhost:6060/debug/pprof/goroutine?debug=2 + +# 或者保存到文件 +curl http://localhost:6060/debug/pprof/goroutine?debug=2 > server_goroutines.txt + +# 使用 less 查看(方便搜索) +curl http://localhost:6060/debug/pprof/goroutine?debug=2 | less +``` + +**查看客户端的 goroutine:** +```bash +curl http://localhost:6061/debug/pprof/goroutine?debug=2 > client_goroutines.txt +``` + +#### 3. 分析 goroutine 堆栈 + +查看输出中的重复模式: +- **正常情况**: 应该只有几个基础 goroutine(监听器、心跳等) +- **异常情况**: 如果有大量相同的堆栈,说明有 goroutine 泄漏 + +**关键搜索词**: +```bash +# 在堆栈文件中搜索 +grep -n "forwardData" server_goroutines.txt +grep -n "Read" server_goroutines.txt +grep -n "runtime.gopark" server_goroutines.txt +``` + +**如何解读堆栈**: +``` +goroutine 123 [running]: +port-forward/server/tunnel.(*Server).forwardData(0xc000120000, 0xc000130000) + /path/to/tunnel.go:456 +0x123 + +这表示 goroutine 123 正在执行 forwardData 函数的第456行 +``` + +### 方法二:CPU Profile(找出高 CPU 占用的函数) + +#### 1. 收集 30 秒的 CPU profile +```bash +# 服务器 +curl http://localhost:6060/debug/pprof/profile?seconds=30 > server_cpu.prof + +# 客户端 +curl http://localhost:6061/debug/pprof/profile?seconds=30 > client_cpu.prof +``` + +**注意**: 在这30秒内,让程序保持在高CPU占用状态(SSH连接和断开) + +#### 2. 分析 CPU profile +```bash +# 使用 go tool pprof 交互式分析 +go tool pprof server_cpu.prof + +# 进入交互式界面后的常用命令: +(pprof) top # 显示占用CPU最多的函数 +(pprof) top -cum # 按累计时间排序 +(pprof) list forwardData # 显示 forwardData 函数的详细分析 +(pprof) web # 生成可视化图表(需要 graphviz) +(pprof) quit # 退出 +``` + +#### 3. Web界面分析(推荐) +```bash +# 启动 Web UI +go tool pprof -http=:8080 server_cpu.prof + +# 然后在浏览器打开: http://localhost:8080 +# 可以看到火焰图、调用图等可视化分析 +``` + +### 方法三:实时 CPU Profile(找出正在运行的热点) + +```bash +# 查看当前正在执行的函数 +curl http://localhost:6060/debug/pprof/profile?seconds=5 | go tool pprof -http=:8080 - +``` + +### 方法四:查看所有 goroutine 数量 + +```bash +# 服务器 +curl http://localhost:6060/debug/pprof/goroutine + +# 或者在浏览器访问,会显示友好的界面 +``` + +### 方法五:内存分析(如果怀疑内存泄漏) + +```bash +# 堆内存 profile +curl http://localhost:6060/debug/pprof/heap > heap.prof +go tool pprof -http=:8080 heap.prof + +# 内存分配统计 +curl http://localhost:6060/debug/pprof/allocs > allocs.prof +go tool pprof -http=:8080 allocs.prof +``` + +### 方法六:使用 trace 进行详细跟踪 + +```bash +# 收集 5 秒的执行跟踪 +curl http://localhost:6060/debug/pprof/trace?seconds=5 > trace.out + +# 查看跟踪 +go tool trace trace.out + +# 这会启动一个 Web 界面,可以看到: +# - 每个 goroutine 的时间线 +# - 系统调用 +# - GC 事件 +# - goroutine 创建和销毁 +``` + +## 典型问题诊断流程 + +### 场景:SSH 断开后 CPU 仍然高占用 + +#### 步骤 1:确认问题 +```bash +# 观察日志中的 goroutine 数量 +# 正常情况下应该在 10 个以内 +# 如果持续增长或高于 50,说明有泄漏 +``` + +#### 步骤 2:抓取 goroutine 堆栈 +```bash +# SSH 连接前 +curl http://localhost:6060/debug/pprof/goroutine?debug=2 > before.txt + +# SSH 连接中 +curl http://localhost:6060/debug/pprof/goroutine?debug=2 > during.txt + +# SSH 断开后(等待 10 秒) +curl http://localhost:6060/debug/pprof/goroutine?debug=2 > after.txt + +# 比较文件 +diff before.txt after.txt | less +``` + +#### 步骤 3:查找泄漏的 goroutine +```bash +# 统计每种堆栈的数量 +grep -A 20 "^goroutine" after.txt | grep "port-forward" | sort | uniq -c | sort -rn +``` + +#### 步骤 4:分析特定函数 +```bash +# 如果发现大量 forwardData goroutine +grep -A 30 "forwardData" after.txt | head -50 +``` + +#### 步骤 5:CPU Profile +```bash +# 在 CPU 高占用时收集 +curl http://localhost:6060/debug/pprof/profile?seconds=10 > high_cpu.prof + +# 分析 +go tool pprof -http=:8080 high_cpu.prof + +# 查看 top 函数,通常会看到: +# - Read 操作占用高 -> 说明在忙循环读取 +# - syscall 占用高 -> 可能是网络调用问题 +# - runtime.schedule -> goroutine 调度开销大(goroutine 太多) +``` + +## 预期的正常状态 + +### Goroutine 数量 +- **服务器空闲**: 8-12 个 goroutine + - acceptLoop (1) + - handleTunnelRead (1) + - handleTunnelWrite (1) + - goroutine 监控 (1) + - pprof HTTP (1) + - API HTTP (1) + - 其他系统 goroutine (2-6) + +- **有 SSH 连接时**: +2 个 goroutine + - forwardData (1) 服务器端 + - forwardData (1) 客户端 + +- **SSH 断开后**: 应该回到空闲状态的数量 + +### CPU 占用 +- **空闲**: < 5% +- **传输数据时**: 取决于传输速度,但不应该持续高占用 +- **连接断开后**: 立即回到 < 5% + +## 常见 CPU 占用原因及特征 + +### 1. 忙循环读取 +**特征**: +- CPU 占用 80-100% +- goroutine 堆栈卡在 `Read()` 或 `conn.Read()` +- pprof 显示大量时间在 `syscall.Read` + +**原因**: 连接已关闭,但循环继续调用 Read,立即返回错误 + +### 2. Channel 阻塞循环 +**特征**: +- CPU 占用 20-40% +- goroutine 堆栈在 `select` 或 channel 操作 +- 大量 goroutine 在 `runtime.gopark` + +**原因**: channel 满了或没有接收者,发送阻塞 + +### 3. Goroutine 泄漏 +**特征**: +- goroutine 数量持续增长 +- 内存占用增长 +- CPU 占用逐渐升高 + +**原因**: goroutine 没有正确退出 + +## 快速命令参考 + +```bash +# 查看 goroutine 数量 +curl -s http://localhost:6060/debug/pprof/ | grep goroutine + +# 查看详细的 goroutine 堆栈 +curl http://localhost:6060/debug/pprof/goroutine?debug=2 | less + +# 收集 CPU profile 并分析 +curl http://localhost:6060/debug/pprof/profile?seconds=10 | go tool pprof -http=:8080 - + +# 查看实时监控日志 +tail -f /path/to/server.log | grep "监控" + +# 统计当前各种 goroutine 的数量 +curl -s http://localhost:6060/debug/pprof/goroutine?debug=2 | \ + grep -E "^goroutine|^\w+\(" | \ + paste - - | \ + awk '{print $NF}' | \ + sort | uniq -c | sort -rn + +# 只看 port-forward 相关的 goroutine +curl -s http://localhost:6060/debug/pprof/goroutine?debug=2 | \ + grep -A 15 "port-forward" +``` + +## 监控脚本 + +创建一个监控脚本 `monitor.sh`: + +```bash +#!/bin/bash + +echo "开始监控 go-tunnel..." +echo "时间戳 | Goroutines | CPU%" +echo "------|------------|-----" + +while true; do + # 获取 goroutine 数量 + GOROUTINES=$(curl -s http://localhost:6060/debug/pprof/goroutine | \ + grep -oP 'goroutine profile: total \K\d+' || echo "N/A") + + # 获取进程 CPU 占用 + PID=$(pgrep -f "server.*9000" | head -1) + if [ ! -z "$PID" ]; then + CPU=$(ps -p $PID -o %cpu= || echo "N/A") + else + CPU="N/A" + fi + + # 输出 + echo "$(date +%H:%M:%S) | $GOROUTINES | $CPU%" + + sleep 2 +done +``` + +使用方法: +```bash +chmod +x monitor.sh +./monitor.sh +``` + +## 故障排查清单 + +- [ ] 检查日志中的 goroutine 数量是否异常增长 +- [ ] 抓取 goroutine 堆栈,查看是否有重复的堆栈 +- [ ] 使用 CPU profile 找出占用最多的函数 +- [ ] 检查是否有 goroutine 卡在 Read/Write 操作 +- [ ] 验证连接关闭后,相关 goroutine 是否退出 +- [ ] 检查 channel 是否有阻塞 +- [ ] 使用 trace 查看详细的执行流程 + +## 其他有用的命令 + +```bash +# 查看所有可用的 pprof endpoints +curl http://localhost:6060/debug/pprof/ + +# 查看程序版本和编译信息 +curl http://localhost:6060/debug/pprof/cmdline + +# 查看当前的调度器状态 +GODEBUG=schedtrace=1000 ./server + +# 查看 GC 统计 +curl http://localhost:6060/debug/pprof/heap?debug=1 | head -30 +``` + +## 参考资料 + +- [Go pprof 官方文档](https://golang.org/pkg/net/http/pprof/) +- [Go tool pprof 使用指南](https://github.com/google/pprof/blob/master/doc/README.md) +- [Go 性能分析最佳实践](https://go.dev/blog/pprof) + +## 下一步 + +在发现问题后: +1. 记录问题出现时的堆栈信息 +2. 确认是哪个函数/循环导致的 +3. 检查该位置的退出条件 +4. 验证错误处理逻辑 +5. 添加更多的日志和退出检查 diff --git a/TRAFFIC_MONITORING.md b/TRAFFIC_MONITORING.md new file mode 100644 index 0000000..e9391b0 --- /dev/null +++ b/TRAFFIC_MONITORING.md @@ -0,0 +1,278 @@ +# 流量监控功能文档 + +## 概述 + +新增了实时流量监控功能,可以查看每个端口映射和隧道的流量统计,并以可视化图表的形式展示。 + +## 新增功能 + +### 1. 流量统计 + +系统自动统计以下数据: +- **隧道流量**: 通过tunnel发送/接收的总字节数 +- **端口映射流量**: 每个端口映射单独的流量统计 +- **总流量**: 所有流量的汇总 + +### 2. API 接口 + +#### GET /api/stats/traffic + +获取当前最新的流量统计数据。 + +**响应示例**: +```json +{ + "success": true, + "message": "获取流量统计成功", + "data": { + "tunnel": { + "bytes_sent": 1048576, + "bytes_received": 2097152, + "last_update": 1697462400 + }, + "mappings": [ + { + "port": 30009, + "bytes_sent": 524288, + "bytes_received": 1048576, + "last_update": 1697462400 + } + ], + "total_sent": 1572864, + "total_received": 3145728, + "timestamp": 1697462400 + } +} +``` + +#### GET /api/stats/monitor + +访问流量监控Web页面,提供可视化的实时监控界面。 + +**特性**: +- 📊 实时流量趋势图表 +- 📈 流量速率计算 (KB/s) +- 🔄 每3秒自动刷新 +- 📱 响应式设计,支持移动设备 +- 🎨 美观的UI界面 + +### 3. Web监控页面 + +访问地址: `http://localhost:8080/api/stats/monitor` + +**页面内容**: + +1. **总览卡片** + - 总发送流量 + - 总接收流量 + - 隧道发送流量 + - 隧道接收流量 + +2. **实时流量趋势图** + - 显示最近20个数据点 + - 发送/接收速率曲线 + - 单位: KB/s + +3. **端口映射详情** + - 每个端口的流量统计 + - 实时更新 + +## 使用方法 + +### 1. 启动服务 + +```bash +cd /home/qcqcqc/workspace/go-tunnel/src +make build +make run-server +``` + +### 2. 访问监控页面 + +在浏览器中打开: +``` +http://localhost:8080/api/stats/monitor +``` + +### 3. API调用示例 + +使用curl获取流量数据: +```bash +# 获取流量统计 +curl http://localhost:8080/api/stats/traffic + +# 格式化输出 +curl http://localhost:8080/api/stats/traffic | jq . +``` + +使用JavaScript获取数据: +```javascript +fetch('http://localhost:8080/api/stats/traffic') + .then(res => res.json()) + .then(data => { + console.log('总发送:', data.data.total_sent); + console.log('总接收:', data.data.total_received); + }); +``` + +### 4. 集成到自己的前端 + +你可以定期调用 `/api/stats/traffic` 接口获取数据,然后在自己的前端页面中展示: + +```javascript +// 每3秒获取一次数据 +setInterval(async () => { + const response = await fetch('http://localhost:8080/api/stats/traffic'); + const result = await response.json(); + + if (result.success) { + const data = result.data; + + // 更新UI + updateTotalSent(data.total_sent); + updateTotalReceived(data.total_received); + + // 更新图表 + updateChart(data); + + // 更新端口列表 + updatePortList(data.mappings); + } +}, 3000); +``` + +## 实现细节 + +### 流量统计机制 + +1. **Tunnel层统计** + - 在 `writeTunnelMessage` 中记录发送字节数 + - 在 `readTunnelMessage` 中记录接收字节数 + - 使用 `atomic.AddUint64` 保证并发安全 + +2. **Forwarder层统计** + - 在 `io.Copy` 返回后记录传输字节数 + - 分别统计客户端→目标和目标→客户端的流量 + +3. **数据结构** + ```go + type TrafficStats struct { + BytesSent uint64 // 发送字节数 + BytesReceived uint64 // 接收字节数 + LastUpdate int64 // 最后更新时间 + } + ``` + +### 性能考虑 + +- 使用原子操作 (`sync/atomic`) 避免锁竞争 +- 统计开销极小,几乎不影响转发性能 +- 只在需要时才计算速率 + +## 测试验证 + +### 1. 基础测试 + +```bash +# 1. 启动服务器和客户端 +make run-server # 终端1 +make run-client # 终端2 + +# 2. 创建端口映射 +curl -X POST http://localhost:8080/api/mapping/create \ + -H "Content-Type: application/json" \ + -d '{ + "source_port": 30009, + "target_host": "127.0.0.1", + "target_port": 22, + "use_tunnel": true + }' + +# 3. 通过映射传输一些数据 +ssh root@localhost -p 30009 + +# 4. 查看流量统计 +curl http://localhost:8080/api/stats/traffic | jq . +``` + +### 2. 压力测试 + +```bash +# 使用 scp 传输大文件测试流量统计 +dd if=/dev/zero of=test.dat bs=1M count=100 +scp -P 30009 test.dat root@localhost:/tmp/ + +# 观察监控页面的实时变化 +``` + +### 3. 长时间运行测试 + +```bash +# 保持监控页面打开,观察: +# - 图表是否正常更新 +# - 数据是否累积正确 +# - 内存占用是否稳定 +``` + +## 故障排查 + +### 问题1: 流量统计不准确 + +**可能原因**: +- 连接未正常关闭 +- 统计溢出(极少见,uint64可存储18EB) + +**解决方法**: +```bash +# 重启服务器重置统计 +``` + +### 问题2: 监控页面无法访问 + +**检查**: +```bash +# 确认服务器正在运行 +curl http://localhost:8080/health + +# 检查端口是否被占用 +netstat -tlnp | grep 8080 +``` + +### 问题3: 图表不更新 + +**检查**: +1. 打开浏览器开发者工具 (F12) +2. 查看 Console 是否有错误 +3. 查看 Network 标签,确认 API 请求成功 + +## 未来改进 + +可能的增强功能: +- [ ] 流量统计持久化(保存到数据库) +- [ ] 历史流量查询 +- [ ] 流量告警(超过阈值时通知) +- [ ] 导出流量报表 +- [ ] 按时间段统计(小时/天/月) +- [ ] WebSocket实时推送(减少轮询) + +## 相关文件 + +- `src/server/stats/stats.go` - 流量统计数据结构 +- `src/server/tunnel/tunnel.go` - Tunnel流量统计实现 +- `src/server/forwarder/forwarder.go` - Forwarder流量统计实现 +- `src/server/api/api.go` - API接口和监控页面 + +## 技术栈 + +- **后端**: Go 1.x +- **前端**: 原生 JavaScript + Chart.js 4.4.0 +- **图表库**: Chart.js (CDN) +- **样式**: CSS3 (渐变、动画) + +## 许可证 + +与主项目相同 + +## 更新日期 + +2025-10-16 diff --git a/monitor.sh b/monitor.sh new file mode 100644 index 0000000..27ef7a1 --- /dev/null +++ b/monitor.sh @@ -0,0 +1,122 @@ +#!/bin/bash + +# Go Tunnel 监控脚本 +# 用于实时监控 goroutine 数量和 CPU 占用 + +SERVER_PPROF="http://localhost:6060" +CLIENT_PPROF="http://localhost:6061" + +# 颜色定义 +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +NC='\033[0m' # No Color + +echo "========================================" +echo " Go Tunnel 实时监控" +echo "========================================" +echo "" +echo "监控项:" +echo " - Goroutine 数量(正常应该 < 15)" +echo " - CPU 占用率" +echo " - 进程状态" +echo "" +echo "按 Ctrl+C 停止监控" +echo "========================================" +echo "" + +printf "%-12s | %-20s | %-20s | %-10s\n" "时间" "服务器" "客户端" "CPU占用" +printf "%-12s | %-20s | %-20s | %-10s\n" "--------" "--------------------" "--------------------" "----------" + +while true; do + TIMESTAMP=$(date +%H:%M:%S) + + # 获取服务器 goroutine 数量 + SERVER_GOROUTINES=$(curl -s --connect-timeout 1 $SERVER_PPROF/debug/pprof/goroutine 2>/dev/null | \ + grep -oP 'goroutine profile: total \K\d+' || echo "N/A") + + # 获取客户端 goroutine 数量 + CLIENT_GOROUTINES=$(curl -s --connect-timeout 1 $CLIENT_PPROF/debug/pprof/goroutine 2>/dev/null | \ + grep -oP 'goroutine profile: total \K\d+' || echo "N/A") + + # 获取服务器进程 CPU 占用 + SERVER_PID=$(pgrep -f "bin/server" | head -1) + if [ ! -z "$SERVER_PID" ]; then + SERVER_CPU=$(ps -p $SERVER_PID -o %cpu= | tr -d ' ') + else + SERVER_CPU="未运行" + fi + + # 获取客户端进程 CPU 占用 + CLIENT_PID=$(pgrep -f "bin/client" | head -1) + if [ ! -z "$CLIENT_PID" ]; then + CLIENT_CPU=$(ps -p $CLIENT_PID -o %cpu= | tr -d ' ') + else + CLIENT_CPU="未运行" + fi + + # 格式化 goroutine 信息 + if [ "$SERVER_GOROUTINES" != "N/A" ]; then + SERVER_INFO="Goroutines: $SERVER_GOROUTINES" + # 如果 goroutine 数量异常,高亮显示 + if [ "$SERVER_GOROUTINES" -gt 20 ]; then + SERVER_INFO="${RED}Goroutines: $SERVER_GOROUTINES ⚠${NC}" + elif [ "$SERVER_GOROUTINES" -gt 15 ]; then + SERVER_INFO="${YELLOW}Goroutines: $SERVER_GOROUTINES ⚠${NC}" + else + SERVER_INFO="${GREEN}Goroutines: $SERVER_GOROUTINES ✓${NC}" + fi + else + SERVER_INFO="${RED}离线${NC}" + fi + + if [ "$CLIENT_GOROUTINES" != "N/A" ]; then + CLIENT_INFO="Goroutines: $CLIENT_GOROUTINES" + if [ "$CLIENT_GOROUTINES" -gt 20 ]; then + CLIENT_INFO="${RED}Goroutines: $CLIENT_GOROUTINES ⚠${NC}" + elif [ "$CLIENT_GOROUTINES" -gt 15 ]; then + CLIENT_INFO="${YELLOW}Goroutines: $CLIENT_GOROUTINES ⚠${NC}" + else + CLIENT_INFO="${GREEN}Goroutines: $CLIENT_GOROUTINES ✓${NC}" + fi + else + CLIENT_INFO="${RED}离线${NC}" + fi + + # CPU 信息 + if [ "$SERVER_CPU" != "未运行" ]; then + CPU_INFO="S:${SERVER_CPU}%" + # 检查 CPU 占用是否过高 + CPU_VALUE=$(echo $SERVER_CPU | cut -d. -f1) + if [ "$CPU_VALUE" -gt 50 ] 2>/dev/null; then + CPU_INFO="${RED}S:${SERVER_CPU}% ⚠${NC}" + elif [ "$CPU_VALUE" -gt 10 ] 2>/dev/null; then + CPU_INFO="${YELLOW}S:${SERVER_CPU}%${NC}" + else + CPU_INFO="${GREEN}S:${SERVER_CPU}%${NC}" + fi + else + CPU_INFO="${RED}未运行${NC}" + fi + + if [ "$CLIENT_CPU" != "未运行" ]; then + if [ "$SERVER_CPU" != "未运行" ]; then + CPU_INFO="${CPU_INFO} " + else + CPU_INFO="" + fi + CLIENT_CPU_VALUE=$(echo $CLIENT_CPU | cut -d. -f1) + if [ "$CLIENT_CPU_VALUE" -gt 50 ] 2>/dev/null; then + CPU_INFO="${CPU_INFO}${RED}C:${CLIENT_CPU}% ⚠${NC}" + elif [ "$CLIENT_CPU_VALUE" -gt 10 ] 2>/dev/null; then + CPU_INFO="${CPU_INFO}${YELLOW}C:${CLIENT_CPU}%${NC}" + else + CPU_INFO="${CPU_INFO}${GREEN}C:${CLIENT_CPU}%${NC}" + fi + fi + + # 输出监控信息 + printf "%-12s | %-35s | %-35s | %-30s\n" "$TIMESTAMP" "$SERVER_INFO" "$CLIENT_INFO" "$CPU_INFO" + + sleep 2 +done diff --git a/src/client/main.go b/src/client/main.go index b374c04..abe0016 100644 --- a/src/client/main.go +++ b/src/client/main.go @@ -3,6 +3,7 @@ package main import ( "flag" "log" + _ "net/http/pprof" // 导入pprof用于性能分析 "os" "os/signal" "port-forward/client/tunnel" @@ -27,8 +28,28 @@ func main() { log.Fatalf("启动隧道客户端失败: %v", err) } + // // 启动 pprof 调试服务器(用于性能分析和调试) + // pprofPort := 6061 + // go func() { + // log.Printf("启动 pprof 调试服务器: http://localhost:%d/debug/pprof/", pprofPort) + // if err := http.ListenAndServe(":6061", nil); err != nil { + // log.Printf("pprof 服务器启动失败: %v", err) + // } + // }() + + // // 启动 goroutine 监控 + // go func() { + // ticker := time.NewTicker(10 * time.Second) + // defer ticker.Stop() + // for range ticker.C { + // numGoroutines := runtime.NumGoroutine() + // log.Printf("[监控] 当前 Goroutine 数量: %d", numGoroutines) + // } + // }() + log.Println("===========================================") log.Println("隧道客户端运行中...") + // log.Printf("调试接口: http://localhost:%d/debug/pprof/", pprofPort) log.Println("按 Ctrl+C 退出") log.Println("===========================================") diff --git a/src/client/tunnel/client.go b/src/client/tunnel/client.go index 466ce0f..229356c 100644 --- a/src/client/tunnel/client.go +++ b/src/client/tunnel/client.go @@ -185,6 +185,7 @@ func (c *Client) handleServerRead(conn net.Conn, connCtx context.Context, connCa defer conn.Close() for { + // 检查是否应该退出 select { case <-c.ctx.Done(): return @@ -193,15 +194,19 @@ func (c *Client) handleServerRead(conn net.Conn, connCtx context.Context, connCa default: } + // 设置读取超时,避免无限阻塞 + conn.SetReadDeadline(time.Now().Add(ReadTimeout)) msg, err := c.readTunnelMessage(conn) if err != nil { - if err != io.EOF { + if err != io.EOF && !isTimeout(err) { log.Printf("读取隧道消息失败: %v", err) } connCancel() // 通知其他协程退出 return } + // 重置读取超时 + conn.SetReadDeadline(time.Time{}) c.handleTunnelMessage(msg) } } @@ -265,6 +270,10 @@ func (c *Client) readTunnelMessage(conn net.Conn) (*TunnelMessage, error) { // writeTunnelMessage 写入隧道消息 func (c *Client) writeTunnelMessage(conn net.Conn, msg *TunnelMessage) error { + // 设置写入超时,防止阻塞 + conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) + defer conn.SetWriteDeadline(time.Time{}) // 重置超时 + // 构建消息头 header := make([]byte, HeaderSize) header[0] = msg.Version @@ -273,13 +282,13 @@ func (c *Client) writeTunnelMessage(conn net.Conn, msg *TunnelMessage) error { // 写入消息头 if _, err := conn.Write(header); err != nil { - return err + return fmt.Errorf("写入消息头失败: %w", err) } // 写入数据 if msg.Length > 0 && msg.Data != nil { if _, err := conn.Write(msg.Data); err != nil { - return err + return fmt.Errorf("写入消息数据失败: %w", err) } } @@ -368,11 +377,24 @@ func (c *Client) handleDataMessage(msg *TunnelMessage) { c.connMu.RUnlock() if !exists { - log.Printf("收到未知连接的数据: %d", connID) + log.Printf("收到未知连接的数据: %d,发送关闭消息", connID) + // 连接不存在,发送关闭消息通知对端 + closeData := make([]byte, 4) + binary.BigEndian.PutUint32(closeData, connID) + closeMsg := &TunnelMessage{ + Version: ProtocolVersion, + Type: MsgTypeClose, + Length: 4, + Data: closeData, + } + select { + case c.sendChan <- closeMsg: + default: + } return } - if _, err := connection.Conn.Write(data); err != nil { + if _, err := connection.Conn.Write(data); err != nil { log.Printf("写入目标连接失败 (ID=%d): %v", connID, err) c.closeConnection(connID) } @@ -391,19 +413,9 @@ func (c *Client) handleCloseMessage(msg *TunnelMessage) { // handleKeepAlive 处理心跳消息 func (c *Client) handleKeepAlive(msg *TunnelMessage) { - // 回应心跳 - response := &TunnelMessage{ - Version: ProtocolVersion, - Type: MsgTypeKeepAlive, - Length: 0, - Data: nil, - } - - select { - case c.sendChan <- response: - default: - log.Printf("发送心跳响应失败: 发送队列已满") - } + // 客户端收到服务器的心跳响应,不需要再回应 + // 这样避免心跳消息的无限循环 + // log.Printf("收到服务器心跳响应") } // sendConnectResponse 发送连接响应 @@ -433,22 +445,38 @@ func (c *Client) forwardData(connection *LocalConnection) { buffer := make([]byte, 32*1024) for { select { - case <-connection.closeChan: - return case <-c.ctx.Done(): return + case <-connection.closeChan: + return default: } + // 设置读取超时 connection.Conn.SetReadDeadline(time.Now().Add(ReadTimeout)) n, err := connection.Conn.Read(buffer) + if err != nil { - if err != io.EOF && !isTimeout(err) { + // 任何错误都应该终止转发 + if err == io.EOF { + log.Printf("目标连接正常关闭 (ID=%d)", connection.ID) + } else if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + log.Printf("目标连接超时 (ID=%d)", connection.ID) + } else { log.Printf("读取目标连接失败 (ID=%d): %v", connection.ID, err) } return } + // 读取到0字节,连接已关闭 + if n == 0 { + log.Printf("目标连接已关闭 (ID=%d, 读取0字节)", connection.ID) + return + } + + // 重置读取超时 + connection.Conn.SetReadDeadline(time.Time{}) + // 发送数据到隧道 dataMsg := make([]byte, 4+n) binary.BigEndian.PutUint32(dataMsg[0:4], connection.ID) @@ -463,11 +491,14 @@ func (c *Client) forwardData(connection *LocalConnection) { select { case c.sendChan <- msg: + // 数据已发送 case <-time.After(5 * time.Second): log.Printf("发送数据超时 (ID=%d)", connection.ID) return case <-c.ctx.Done(): return + case <-connection.closeChan: + return } } } @@ -481,10 +512,18 @@ func (c *Client) closeConnection(connID uint32) { connection.closeOnce.Do(func() { close(connection.closeChan) }) - connection.Conn.Close() + // 确保连接被关闭 + if connection.Conn != nil { + connection.Conn.Close() + } } c.connMu.Unlock() + if !exists { + // 连接不存在,无需发送关闭消息 + return + } + // 发送关闭消息 closeData := make([]byte, 4) binary.BigEndian.PutUint32(closeData, connID) @@ -498,12 +537,11 @@ func (c *Client) closeConnection(connID uint32) { select { case c.sendChan <- msg: - default: - // 发送队列满,忽略 - } - - if exists { log.Printf("连接已关闭: ID=%d", connID) + case <-time.After(1 * time.Second): + log.Printf("发送关闭消息超时: ID=%d", connID) + case <-c.ctx.Done(): + log.Printf("客户端关闭,跳过发送关闭消息: ID=%d", connID) } } diff --git a/src/client/tunnel/client_test.go b/src/client/tunnel/client_test.go index c8767c2..2eadc81 100644 --- a/src/client/tunnel/client_test.go +++ b/src/client/tunnel/client_test.go @@ -152,16 +152,20 @@ func TestClientHandleConnectRequest(t *testing.T) { client := NewClient("127.0.0.1:9000") - // 创建连接请求消息 + // 创建连接请求消息(新格式:connID + port + hostLen + host) connID := uint32(12345) - reqData := make([]byte, 6) + targetHost := "127.0.0.1" + targetHostBytes := []byte(targetHost) + reqData := make([]byte, 7+len(targetHostBytes)) binary.BigEndian.PutUint32(reqData[0:4], connID) binary.BigEndian.PutUint16(reqData[4:6], uint16(localPort)) + reqData[6] = byte(len(targetHostBytes)) + copy(reqData[7:], targetHostBytes) msg := &TunnelMessage{ Version: ProtocolVersion, Type: MsgTypeConnectRequest, - Length: 6, + Length: uint32(len(reqData)), Data: reqData, } diff --git a/src/server/api/api.go b/src/server/api/api.go index cf2ca7a..c9759d3 100644 --- a/src/server/api/api.go +++ b/src/server/api/api.go @@ -8,6 +8,7 @@ import ( "net/http" "port-forward/server/db" "port-forward/server/forwarder" + "port-forward/server/stats" "port-forward/server/tunnel" "strconv" "time" @@ -58,6 +59,8 @@ func (h *Handler) RegisterRoutes(mux *http.ServeMux) { mux.HandleFunc("/api/mapping/create", h.handleCreateMapping) mux.HandleFunc("/api/mapping/remove", h.handleRemoveMapping) mux.HandleFunc("/api/mapping/list", h.handleListMappings) + mux.HandleFunc("/api/stats/traffic", h.handleGetTrafficStats) + mux.HandleFunc("/api/stats/monitor", h.handleTrafficMonitor) mux.HandleFunc("/health", h.handleHealth) } @@ -277,4 +280,56 @@ func Start(handler *Handler, port int) error { log.Printf("HTTP API 服务启动: 端口 %d", port) return server.ListenAndServe() +} + +// handleGetTrafficStats 获取流量统计 +func (h *Handler) handleGetTrafficStats(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + h.writeError(w, http.StatusMethodNotAllowed, "只支持 GET 方法") + return + } + + // 获取隧道流量统计 + var tunnelStats stats.TrafficStats + if h.tunnelServer != nil { + tunnelStats = h.tunnelServer.GetTrafficStats() + } + + // 获取所有端口映射的流量统计 + forwarderStats := h.forwarderMgr.GetAllTrafficStats() + + // 构建响应 + mappings := make([]stats.PortTrafficStats, 0, len(forwarderStats)) + var totalSent, totalReceived uint64 + + for port, stat := range forwarderStats { + mappings = append(mappings, stats.PortTrafficStats{ + Port: port, + BytesSent: stat.BytesSent, + BytesReceived: stat.BytesReceived, + LastUpdate: stat.LastUpdate, + }) + totalSent += stat.BytesSent + totalReceived += stat.BytesReceived + } + + // 加上隧道的流量 + totalSent += tunnelStats.BytesSent + totalReceived += tunnelStats.BytesReceived + + response := stats.AllTrafficStats{ + Tunnel: tunnelStats, + Mappings: mappings, + TotalSent: totalSent, + TotalReceived: totalReceived, + Timestamp: time.Now().Unix(), + } + + h.writeSuccess(w, "获取流量统计成功", response) +} + +// handleTrafficMonitor 流量监控页面 +func (h *Handler) handleTrafficMonitor(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + fmt.Fprint(w, html) } \ No newline at end of file diff --git a/src/server/api/api_test.go b/src/server/api/api_test.go index 7719bdc..5a6390a 100644 --- a/src/server/api/api_test.go +++ b/src/server/api/api_test.go @@ -260,7 +260,7 @@ func TestHandleCreateMappingInvalidIP(t *testing.T) { // Port: 15000, SourcePort: 15000, TargetPort: 15000, - TargetHost: "invalid-ip", + TargetHost: "", // 使用空字符串而不是无效域名,避免DNS查询超时 } body, _ := json.Marshal(reqBody) diff --git a/src/server/api/html.go b/src/server/api/html.go new file mode 100644 index 0000000..b689a3b --- /dev/null +++ b/src/server/api/html.go @@ -0,0 +1,341 @@ +package api + +const html = ` + + + + + + 流量监控 - Go Tunnel + + + + +
+

流量监控面板

+ +
+
+

总发送流量

+
0 B
+
Total Sent
+
+ +
+

总接收流量

+
0 B
+
Total Received
+
+ +
+

隧道发送

+
0 B
+
Tunnel Sent
+
+ +
+

隧道接收

+
0 B
+
Tunnel Received
+
+
+ +
+

实时流量趋势

+ +
+ +
+

端口映射流量

+
+
+ +
+ 最后更新: - | 每 3 秒自动刷新 +
+
+ + + + + ` \ No newline at end of file diff --git a/src/server/config/config_test.go b/src/server/config/config_test.go index 0e4d1ee..350a1e0 100644 --- a/src/server/config/config_test.go +++ b/src/server/config/config_test.go @@ -200,7 +200,7 @@ func TestValidateConfig(t *testing.T) { { name: "端口范围过大", config: Config{ - PortRange: PortRangeConfig{From: 1, End: 20000}, + PortRange: PortRangeConfig{From: 1, End: 40000}, Tunnel: TunnelConfig{Enabled: false, ListenPort: 0}, API: APIConfig{ListenPort: 8080}, Database: DatabaseConfig{Path: "./data/test.db"}, diff --git a/src/server/forwarder/forwarder.go b/src/server/forwarder/forwarder.go index d6a6141..b11c066 100644 --- a/src/server/forwarder/forwarder.go +++ b/src/server/forwarder/forwarder.go @@ -6,7 +6,9 @@ import ( "io" "log" "net" + "port-forward/server/stats" "sync" + "sync/atomic" "time" ) @@ -14,6 +16,7 @@ import ( type TunnelServer interface { ForwardConnection(clientConn net.Conn, targetIP string, targetPort int) error IsConnected() bool + GetTrafficStats() stats.TrafficStats } // Forwarder 端口转发器 @@ -27,6 +30,10 @@ type Forwarder struct { wg sync.WaitGroup tunnelServer TunnelServer useTunnel bool + + // 流量统计(使用原子操作) + bytesSent uint64 // 发送字节数 + bytesReceived uint64 // 接收字节数 } // NewForwarder 创建新的端口转发器 @@ -142,26 +149,44 @@ func (f *Forwarder) handleConnection(clientConn net.Conn) { defer targetConn.Close() // 双向转发 - errChan := make(chan error, 2) + var wg sync.WaitGroup + wg.Add(2) // 客户端 -> 目标 go func() { - _, err := io.Copy(targetConn, clientConn) - errChan <- err + defer wg.Done() + n, _ := io.Copy(targetConn, clientConn) + atomic.AddUint64(&f.bytesSent, uint64(n)) + // 关闭目标连接的写入端,通知对方不会再发送数据 + if tcpConn, ok := targetConn.(*net.TCPConn); ok { + tcpConn.CloseWrite() + } }() // 目标 -> 客户端 go func() { - _, err := io.Copy(clientConn, targetConn) - errChan <- err + defer wg.Done() + n, _ := io.Copy(clientConn, targetConn) + atomic.AddUint64(&f.bytesReceived, uint64(n)) + // 关闭客户端连接的写入端 + if tcpConn, ok := clientConn.(*net.TCPConn); ok { + tcpConn.CloseWrite() + } }() - // 等待任一方向完成或出错 + // 创建一个 channel 来等待完成 + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + // 等待两个方向都完成或上下文取消 select { - case <-errChan: - // 连接已关闭或出错 + case <-done: + // 两个方向都已完成 case <-f.ctx.Done(): - // 转发器被停止 + // 转发器被停止,连接会在 defer 中关闭 } } @@ -279,4 +304,26 @@ func (m *Manager) StopAll() { } m.forwarders = make(map[int]*Forwarder) +} + +// GetTrafficStats 获取流量统计信息 +func (f *Forwarder) GetTrafficStats() stats.TrafficStats { + return stats.TrafficStats{ + BytesSent: atomic.LoadUint64(&f.bytesSent), + BytesReceived: atomic.LoadUint64(&f.bytesReceived), + LastUpdate: time.Now().Unix(), + } +} + +// GetAllTrafficStats 获取所有转发器的流量统计 +func (m *Manager) GetAllTrafficStats() map[int]stats.TrafficStats { + m.mu.RLock() + defer m.mu.RUnlock() + + statsMap := make(map[int]stats.TrafficStats) + for port, forwarder := range m.forwarders { + statsMap[port] = forwarder.GetTrafficStats() + } + + return statsMap } \ No newline at end of file diff --git a/src/server/forwarder/forwarder_test.go b/src/server/forwarder/forwarder_test.go index cfc2b8d..ac1b346 100644 --- a/src/server/forwarder/forwarder_test.go +++ b/src/server/forwarder/forwarder_test.go @@ -4,6 +4,7 @@ import ( "fmt" "io" "net" + "port-forward/server/stats" "testing" "time" ) @@ -20,8 +21,12 @@ func TestNewForwarder(t *testing.T) { t.Errorf("源端口不正确,期望 8080,得到 %d", fwd.sourcePort) } - if fwd.targetHost != "192.168.1.100:80" { - t.Errorf("目标地址不正确,期望 192.168.1.100:80,得到 %s", fwd.targetHost) + if fwd.targetHost != "192.168.1.100" { + t.Errorf("目标主机不正确,期望 192.168.1.100,得到 %s", fwd.targetHost) + } + + if fwd.targetPort != 80 { + t.Errorf("目标端口不正确,期望 80,得到 %d", fwd.targetPort) } if fwd.useTunnel { @@ -43,6 +48,10 @@ func (m *mockTunnelServer) ForwardConnection(clientConn net.Conn, targetIp strin func (m *mockTunnelServer) IsConnected() bool { return m.connected } + +func (m *mockTunnelServer) GetTrafficStats() stats.TrafficStats { + return stats.TrafficStats{} +} // TestNewTunnelForwarder 测试创建隧道转发器 func TestNewTunnelForwarder(t *testing.T) { diff --git a/src/server/main.go b/src/server/main.go index 5935478..47fa174 100644 --- a/src/server/main.go +++ b/src/server/main.go @@ -4,6 +4,7 @@ import ( "context" "flag" "log" + _ "net/http/pprof" // 导入pprof用于性能分析 "os" "os/signal" "port-forward/server/api" @@ -64,8 +65,6 @@ func main() { continue } - log.Printf("恢复端口映射: %d -> %s:%d (tunnel: %v)", mapping.SourcePort, mapping.TargetHost, mapping.TargetPort, mapping.UseTunnel) - var err error if mapping.UseTunnel { // 隧道模式:检查隧道服务器是否可用 @@ -103,10 +102,30 @@ func main() { } }() + // // 启动 pprof 调试服务器(用于性能分析和调试) + // pprofPort := 6060 + // go func() { + // log.Printf("启动 pprof 调试服务器: http://localhost:%d/debug/pprof/", pprofPort) + // if err := http.ListenAndServe(":6060", nil); err != nil { + // log.Printf("pprof 服务器启动失败: %v", err) + // } + // }() + + // // 启动 goroutine 监控 + // go func() { + // ticker := time.NewTicker(10 * time.Second) + // defer ticker.Stop() + // for range ticker.C { + // numGoroutines := runtime.NumGoroutine() + // log.Printf("[监控] 当前 Goroutine 数量: %d", numGoroutines) + // } + // }() + log.Println("===========================================") log.Printf("服务器启动成功!") log.Printf("端口范围: %d-%d", cfg.PortRange.From, cfg.PortRange.End) log.Printf("HTTP API: http://localhost:%d", cfg.API.ListenPort) + // log.Printf("调试接口: http://localhost:%d/debug/pprof/", pprofPort) if cfg.Tunnel.Enabled { log.Printf("隧道服务: 端口 %d", cfg.Tunnel.ListenPort) } diff --git a/src/server/stats/stats.go b/src/server/stats/stats.go new file mode 100644 index 0000000..2be1e7b --- /dev/null +++ b/src/server/stats/stats.go @@ -0,0 +1,25 @@ +package stats + +// TrafficStats 流量统计 +type TrafficStats struct { + BytesSent uint64 `json:"bytes_sent"` // 发送字节数 + BytesReceived uint64 `json:"bytes_received"` // 接收字节数 + LastUpdate int64 `json:"last_update"` // 最后更新时间(Unix时间戳) +} + +// PortTrafficStats 端口流量统计 +type PortTrafficStats struct { + Port int `json:"port"` + BytesSent uint64 `json:"bytes_sent"` + BytesReceived uint64 `json:"bytes_received"` + LastUpdate int64 `json:"last_update"` +} + +// AllTrafficStats 所有流量统计 +type AllTrafficStats struct { + Tunnel TrafficStats `json:"tunnel"` // 隧道流量 + Mappings []PortTrafficStats `json:"mappings"` // 端口映射流量 + TotalSent uint64 `json:"total_sent"` // 总发送 + TotalReceived uint64 `json:"total_received"` // 总接收 + Timestamp int64 `json:"timestamp"` // 时间戳 +} diff --git a/src/server/tunnel/tunnel.go b/src/server/tunnel/tunnel.go index 01d9aad..fea91bd 100644 --- a/src/server/tunnel/tunnel.go +++ b/src/server/tunnel/tunnel.go @@ -7,7 +7,9 @@ import ( "io" "log" "net" + "port-forward/server/stats" "sync" + "sync/atomic" "time" ) @@ -37,7 +39,7 @@ const ( // 超时设置 ConnectTimeout = 10 * time.Second // 连接超时 - ReadTimeout = 30 * time.Second // 读取超时 + ReadTimeout = 300 * time.Second // 读取超时 KeepAliveInterval = 15 * time.Second // 心跳间隔 ) @@ -111,6 +113,10 @@ type Server struct { // 消息队列 sendChan chan *TunnelMessage + + // 流量统计(使用原子操作) + bytesSent uint64 // 通过隧道发送的总字节数 + bytesReceived uint64 // 通过隧道接收的总字节数 } // NewServer 创建新的隧道服务器 @@ -214,20 +220,25 @@ func (s *Server) handleTunnelRead(conn net.Conn) { }() for { + // 检查是否应该退出 select { case <-s.ctx.Done(): return default: } + // 设置读取超时,避免无限阻塞 + conn.SetReadDeadline(time.Now().Add(ReadTimeout)) msg, err := s.readTunnelMessage(conn) if err != nil { - if err != io.EOF { + if err != io.EOF && !isTimeout(err) { log.Printf("读取隧道消息失败: %v", err) } return } + // 重置读取超时 + conn.SetReadDeadline(time.Time{}) s.handleTunnelMessage(msg) } } @@ -256,6 +267,9 @@ func (s *Server) readTunnelMessage(conn net.Conn) (*TunnelMessage, error) { if _, err := io.ReadFull(conn, header); err != nil { return nil, err } + + // 统计接收字节数 + s.addBytesReceived(uint64(HeaderSize)) version := header[0] msgType := header[1] @@ -276,6 +290,8 @@ func (s *Server) readTunnelMessage(conn net.Conn) (*TunnelMessage, error) { if _, err := io.ReadFull(conn, data); err != nil { return nil, err } + // 统计接收字节数 + s.addBytesReceived(uint64(dataLen)) } return &TunnelMessage{ @@ -288,6 +304,10 @@ func (s *Server) readTunnelMessage(conn net.Conn) (*TunnelMessage, error) { // writeTunnelMessage 写入隧道消息 func (s *Server) writeTunnelMessage(conn net.Conn, msg *TunnelMessage) error { + // 设置写入超时,防止阻塞 + conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) + defer conn.SetWriteDeadline(time.Time{}) // 重置超时 + // 构建消息头 header := make([]byte, HeaderSize) header[0] = msg.Version @@ -296,14 +316,19 @@ func (s *Server) writeTunnelMessage(conn net.Conn, msg *TunnelMessage) error { // 写入消息头 if _, err := conn.Write(header); err != nil { - return err + return fmt.Errorf("写入消息头失败: %w", err) } + + // 统计发送字节数 + s.addBytesSent(uint64(HeaderSize)) // 写入数据 if msg.Length > 0 && msg.Data != nil { if _, err := conn.Write(msg.Data); err != nil { - return err + return fmt.Errorf("写入消息数据失败: %w", err) } + // 统计发送字节数 + s.addBytesSent(uint64(msg.Length)) } return nil @@ -401,7 +426,20 @@ func (s *Server) handleDataMessage(msg *TunnelMessage) { s.connMu.RUnlock() if !exists { - log.Printf("收到未知连接的数据: %d", connID) + log.Printf("收到未知连接的数据: %d,发送关闭消息", connID) + // 连接不存在,发送关闭消息通知对端 + closeData := make([]byte, 4) + binary.BigEndian.PutUint32(closeData, connID) + closeMsg := &TunnelMessage{ + Version: ProtocolVersion, + Type: MsgTypeClose, + Length: 4, + Data: closeData, + } + select { + case s.sendChan <- closeMsg: + default: + } return } @@ -425,7 +463,8 @@ func (s *Server) handleCloseMessage(msg *TunnelMessage) { // handleKeepAlive 处理心跳消息 func (s *Server) handleKeepAlive(msg *TunnelMessage) { - // 回应心跳 + // 服务器收到客户端的心跳请求,回应一次即可 + // 不要形成心跳循环 response := &TunnelMessage{ Version: ProtocolVersion, Type: MsgTypeKeepAlive, @@ -454,15 +493,31 @@ func (s *Server) forwardData(active *ActiveConnection) { default: } + // 设置读取超时 active.ClientConn.SetReadDeadline(time.Now().Add(ReadTimeout)) n, err := active.ClientConn.Read(buffer) + if err != nil { - if err != io.EOF && !isTimeout(err) { + // 任何错误都应该终止转发,包括超时 + if err == io.EOF { + log.Printf("客户端连接正常关闭 (ID=%d)", active.ID) + } else if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + log.Printf("客户端连接超时 (ID=%d)", active.ID) + } else { log.Printf("读取客户端连接失败 (ID=%d): %v", active.ID, err) } return } + // 读取到0字节,连接已关闭 + if n == 0 { + log.Printf("客户端连接已关闭 (ID=%d, 读取0字节)", active.ID) + return + } + + // 重置读取超时 + active.ClientConn.SetReadDeadline(time.Time{}) + // 发送数据到隧道 dataMsg := make([]byte, 4+n) binary.BigEndian.PutUint32(dataMsg[0:4], active.ID) @@ -477,6 +532,7 @@ func (s *Server) forwardData(active *ActiveConnection) { select { case s.sendChan <- msg: + // 数据已发送 case <-time.After(5 * time.Second): log.Printf("发送数据超时 (ID=%d)", active.ID) return @@ -492,10 +548,18 @@ func (s *Server) closeConnection(connID uint32) { active, exists := s.activeConns[connID] if exists { delete(s.activeConns, connID) - active.ClientConn.Close() + // 确保连接被关闭 + if active.ClientConn != nil { + active.ClientConn.Close() + } } s.connMu.Unlock() + if !exists { + // 连接不存在,无需发送关闭消息 + return + } + // 发送关闭消息 closeData := make([]byte, 4) binary.BigEndian.PutUint32(closeData, connID) @@ -509,12 +573,11 @@ func (s *Server) closeConnection(connID uint32) { select { case s.sendChan <- msg: - default: - // 发送队列满,忽略 - } - - if exists { log.Printf("连接已关闭: ID=%d", connID) + case <-time.After(1 * time.Second): + log.Printf("发送关闭消息超时: ID=%d", connID) + case <-s.ctx.Done(): + log.Printf("服务器关闭,跳过发送关闭消息: ID=%d", connID) } } @@ -667,4 +730,23 @@ func (s *Server) keepAliveLoop(conn net.Conn) { } } } +} + +// GetTrafficStats 获取流量统计信息 +func (s *Server) GetTrafficStats() stats.TrafficStats { + return stats.TrafficStats{ + BytesSent: atomic.LoadUint64(&s.bytesSent), + BytesReceived: atomic.LoadUint64(&s.bytesReceived), + LastUpdate: time.Now().Unix(), + } +} + +// addBytesSent 增加发送字节数 +func (s *Server) addBytesSent(bytes uint64) { + atomic.AddUint64(&s.bytesSent, bytes) +} + +// addBytesReceived 增加接收字节数 +func (s *Server) addBytesReceived(bytes uint64) { + atomic.AddUint64(&s.bytesReceived, bytes) } \ No newline at end of file