fix: 修复了两端重复ack导致的过度cpu占用和带宽占用

fix: 修复了读取超时小于ssh默认的60s导致连接中断的问题
feat: 新增了流量统计页面与接口
This commit is contained in:
Pan Qiancheng 2025-10-16 16:41:46 +08:00
parent 5c7fa0ff17
commit a30a7e38b4
15 changed files with 1459 additions and 58 deletions

360
DEBUG_GUIDE.md Normal file
View File

@ -0,0 +1,360 @@
# Go Tunnel 调试指南
## 概述
本指南介绍如何使用 Go 的内置性能分析工具来调试和诊断 CPU 占用、内存泄漏和 goroutine 泄漏等问题。
## 已添加的调试功能
### 1. pprof HTTP 接口
- **服务器**: `http://localhost:6060/debug/pprof/`
- **客户端**: `http://localhost:6061/debug/pprof/`
### 2. Goroutine 监控
每10秒自动打印当前 goroutine 数量到日志。
## 使用方法
### 方法一:实时查看 Goroutine 堆栈(最有用!)
这是找出 CPU 占用问题的最直接方法。
#### 1. 启动服务器和客户端
```bash
# 终端1启动服务器
cd /home/qcqcqc/workspace/go-tunnel/src
make run-server
# 终端2启动客户端
make run-client
# 终端3建立 SSH 连接测试
ssh root@localhost -p 30009
# 执行一些操作后断开
```
#### 2. SSH 断开后,立即查看 goroutine 堆栈
**查看服务器的 goroutine**
```bash
# 在浏览器打开或使用 curl
curl http://localhost:6060/debug/pprof/goroutine?debug=2
# 或者保存到文件
curl http://localhost:6060/debug/pprof/goroutine?debug=2 > server_goroutines.txt
# 使用 less 查看(方便搜索)
curl http://localhost:6060/debug/pprof/goroutine?debug=2 | less
```
**查看客户端的 goroutine**
```bash
curl http://localhost:6061/debug/pprof/goroutine?debug=2 > client_goroutines.txt
```
#### 3. 分析 goroutine 堆栈
查看输出中的重复模式:
- **正常情况**: 应该只有几个基础 goroutine监听器、心跳等
- **异常情况**: 如果有大量相同的堆栈,说明有 goroutine 泄漏
**关键搜索词**
```bash
# 在堆栈文件中搜索
grep -n "forwardData" server_goroutines.txt
grep -n "Read" server_goroutines.txt
grep -n "runtime.gopark" server_goroutines.txt
```
**如何解读堆栈**
```
goroutine 123 [running]:
port-forward/server/tunnel.(*Server).forwardData(0xc000120000, 0xc000130000)
/path/to/tunnel.go:456 +0x123
这表示 goroutine 123 正在执行 forwardData 函数的第456行
```
### 方法二CPU Profile找出高 CPU 占用的函数)
#### 1. 收集 30 秒的 CPU profile
```bash
# 服务器
curl http://localhost:6060/debug/pprof/profile?seconds=30 > server_cpu.prof
# 客户端
curl http://localhost:6061/debug/pprof/profile?seconds=30 > client_cpu.prof
```
**注意**: 在这30秒内让程序保持在高CPU占用状态SSH连接和断开
#### 2. 分析 CPU profile
```bash
# 使用 go tool pprof 交互式分析
go tool pprof server_cpu.prof
# 进入交互式界面后的常用命令:
(pprof) top # 显示占用CPU最多的函数
(pprof) top -cum # 按累计时间排序
(pprof) list forwardData # 显示 forwardData 函数的详细分析
(pprof) web # 生成可视化图表(需要 graphviz
(pprof) quit # 退出
```
#### 3. Web界面分析推荐
```bash
# 启动 Web UI
go tool pprof -http=:8080 server_cpu.prof
# 然后在浏览器打开: http://localhost:8080
# 可以看到火焰图、调用图等可视化分析
```
### 方法三:实时 CPU Profile找出正在运行的热点
```bash
# 查看当前正在执行的函数
curl http://localhost:6060/debug/pprof/profile?seconds=5 | go tool pprof -http=:8080 -
```
### 方法四:查看所有 goroutine 数量
```bash
# 服务器
curl http://localhost:6060/debug/pprof/goroutine
# 或者在浏览器访问,会显示友好的界面
```
### 方法五:内存分析(如果怀疑内存泄漏)
```bash
# 堆内存 profile
curl http://localhost:6060/debug/pprof/heap > heap.prof
go tool pprof -http=:8080 heap.prof
# 内存分配统计
curl http://localhost:6060/debug/pprof/allocs > allocs.prof
go tool pprof -http=:8080 allocs.prof
```
### 方法六:使用 trace 进行详细跟踪
```bash
# 收集 5 秒的执行跟踪
curl http://localhost:6060/debug/pprof/trace?seconds=5 > trace.out
# 查看跟踪
go tool trace trace.out
# 这会启动一个 Web 界面,可以看到:
# - 每个 goroutine 的时间线
# - 系统调用
# - GC 事件
# - goroutine 创建和销毁
```
## 典型问题诊断流程
### 场景SSH 断开后 CPU 仍然高占用
#### 步骤 1确认问题
```bash
# 观察日志中的 goroutine 数量
# 正常情况下应该在 10 个以内
# 如果持续增长或高于 50说明有泄漏
```
#### 步骤 2抓取 goroutine 堆栈
```bash
# SSH 连接前
curl http://localhost:6060/debug/pprof/goroutine?debug=2 > before.txt
# SSH 连接中
curl http://localhost:6060/debug/pprof/goroutine?debug=2 > during.txt
# SSH 断开后(等待 10 秒)
curl http://localhost:6060/debug/pprof/goroutine?debug=2 > after.txt
# 比较文件
diff before.txt after.txt | less
```
#### 步骤 3查找泄漏的 goroutine
```bash
# 统计每种堆栈的数量
grep -A 20 "^goroutine" after.txt | grep "port-forward" | sort | uniq -c | sort -rn
```
#### 步骤 4分析特定函数
```bash
# 如果发现大量 forwardData goroutine
grep -A 30 "forwardData" after.txt | head -50
```
#### 步骤 5CPU Profile
```bash
# 在 CPU 高占用时收集
curl http://localhost:6060/debug/pprof/profile?seconds=10 > high_cpu.prof
# 分析
go tool pprof -http=:8080 high_cpu.prof
# 查看 top 函数,通常会看到:
# - Read 操作占用高 -> 说明在忙循环读取
# - syscall 占用高 -> 可能是网络调用问题
# - runtime.schedule -> goroutine 调度开销大goroutine 太多)
```
## 预期的正常状态
### Goroutine 数量
- **服务器空闲**: 8-12 个 goroutine
- acceptLoop (1)
- handleTunnelRead (1)
- handleTunnelWrite (1)
- goroutine 监控 (1)
- pprof HTTP (1)
- API HTTP (1)
- 其他系统 goroutine (2-6)
- **有 SSH 连接时**: +2 个 goroutine
- forwardData (1) 服务器端
- forwardData (1) 客户端
- **SSH 断开后**: 应该回到空闲状态的数量
### CPU 占用
- **空闲**: < 5%
- **传输数据时**: 取决于传输速度,但不应该持续高占用
- **连接断开后**: 立即回到 < 5%
## 常见 CPU 占用原因及特征
### 1. 忙循环读取
**特征**:
- CPU 占用 80-100%
- goroutine 堆栈卡在 `Read()``conn.Read()`
- pprof 显示大量时间在 `syscall.Read`
**原因**: 连接已关闭,但循环继续调用 Read立即返回错误
### 2. Channel 阻塞循环
**特征**:
- CPU 占用 20-40%
- goroutine 堆栈在 `select` 或 channel 操作
- 大量 goroutine 在 `runtime.gopark`
**原因**: channel 满了或没有接收者,发送阻塞
### 3. Goroutine 泄漏
**特征**:
- goroutine 数量持续增长
- 内存占用增长
- CPU 占用逐渐升高
**原因**: goroutine 没有正确退出
## 快速命令参考
```bash
# 查看 goroutine 数量
curl -s http://localhost:6060/debug/pprof/ | grep goroutine
# 查看详细的 goroutine 堆栈
curl http://localhost:6060/debug/pprof/goroutine?debug=2 | less
# 收集 CPU profile 并分析
curl http://localhost:6060/debug/pprof/profile?seconds=10 | go tool pprof -http=:8080 -
# 查看实时监控日志
tail -f /path/to/server.log | grep "监控"
# 统计当前各种 goroutine 的数量
curl -s http://localhost:6060/debug/pprof/goroutine?debug=2 | \
grep -E "^goroutine|^\w+\(" | \
paste - - | \
awk '{print $NF}' | \
sort | uniq -c | sort -rn
# 只看 port-forward 相关的 goroutine
curl -s http://localhost:6060/debug/pprof/goroutine?debug=2 | \
grep -A 15 "port-forward"
```
## 监控脚本
创建一个监控脚本 `monitor.sh`:
```bash
#!/bin/bash
echo "开始监控 go-tunnel..."
echo "时间戳 | Goroutines | CPU%"
echo "------|------------|-----"
while true; do
# 获取 goroutine 数量
GOROUTINES=$(curl -s http://localhost:6060/debug/pprof/goroutine | \
grep -oP 'goroutine profile: total \K\d+' || echo "N/A")
# 获取进程 CPU 占用
PID=$(pgrep -f "server.*9000" | head -1)
if [ ! -z "$PID" ]; then
CPU=$(ps -p $PID -o %cpu= || echo "N/A")
else
CPU="N/A"
fi
# 输出
echo "$(date +%H:%M:%S) | $GOROUTINES | $CPU%"
sleep 2
done
```
使用方法:
```bash
chmod +x monitor.sh
./monitor.sh
```
## 故障排查清单
- [ ] 检查日志中的 goroutine 数量是否异常增长
- [ ] 抓取 goroutine 堆栈,查看是否有重复的堆栈
- [ ] 使用 CPU profile 找出占用最多的函数
- [ ] 检查是否有 goroutine 卡在 Read/Write 操作
- [ ] 验证连接关闭后,相关 goroutine 是否退出
- [ ] 检查 channel 是否有阻塞
- [ ] 使用 trace 查看详细的执行流程
## 其他有用的命令
```bash
# 查看所有可用的 pprof endpoints
curl http://localhost:6060/debug/pprof/
# 查看程序版本和编译信息
curl http://localhost:6060/debug/pprof/cmdline
# 查看当前的调度器状态
GODEBUG=schedtrace=1000 ./server
# 查看 GC 统计
curl http://localhost:6060/debug/pprof/heap?debug=1 | head -30
```
## 参考资料
- [Go pprof 官方文档](https://golang.org/pkg/net/http/pprof/)
- [Go tool pprof 使用指南](https://github.com/google/pprof/blob/master/doc/README.md)
- [Go 性能分析最佳实践](https://go.dev/blog/pprof)
## 下一步
在发现问题后:
1. 记录问题出现时的堆栈信息
2. 确认是哪个函数/循环导致的
3. 检查该位置的退出条件
4. 验证错误处理逻辑
5. 添加更多的日志和退出检查

278
TRAFFIC_MONITORING.md Normal file
View File

@ -0,0 +1,278 @@
# 流量监控功能文档
## 概述
新增了实时流量监控功能,可以查看每个端口映射和隧道的流量统计,并以可视化图表的形式展示。
## 新增功能
### 1. 流量统计
系统自动统计以下数据:
- **隧道流量**: 通过tunnel发送/接收的总字节数
- **端口映射流量**: 每个端口映射单独的流量统计
- **总流量**: 所有流量的汇总
### 2. API 接口
#### GET /api/stats/traffic
获取当前最新的流量统计数据。
**响应示例**:
```json
{
"success": true,
"message": "获取流量统计成功",
"data": {
"tunnel": {
"bytes_sent": 1048576,
"bytes_received": 2097152,
"last_update": 1697462400
},
"mappings": [
{
"port": 30009,
"bytes_sent": 524288,
"bytes_received": 1048576,
"last_update": 1697462400
}
],
"total_sent": 1572864,
"total_received": 3145728,
"timestamp": 1697462400
}
}
```
#### GET /api/stats/monitor
访问流量监控Web页面提供可视化的实时监控界面。
**特性**:
- 📊 实时流量趋势图表
- 📈 流量速率计算 (KB/s)
- 🔄 每3秒自动刷新
- 📱 响应式设计,支持移动设备
- 🎨 美观的UI界面
### 3. Web监控页面
访问地址: `http://localhost:8080/api/stats/monitor`
**页面内容**:
1. **总览卡片**
- 总发送流量
- 总接收流量
- 隧道发送流量
- 隧道接收流量
2. **实时流量趋势图**
- 显示最近20个数据点
- 发送/接收速率曲线
- 单位: KB/s
3. **端口映射详情**
- 每个端口的流量统计
- 实时更新
## 使用方法
### 1. 启动服务
```bash
cd /home/qcqcqc/workspace/go-tunnel/src
make build
make run-server
```
### 2. 访问监控页面
在浏览器中打开:
```
http://localhost:8080/api/stats/monitor
```
### 3. API调用示例
使用curl获取流量数据:
```bash
# 获取流量统计
curl http://localhost:8080/api/stats/traffic
# 格式化输出
curl http://localhost:8080/api/stats/traffic | jq .
```
使用JavaScript获取数据:
```javascript
fetch('http://localhost:8080/api/stats/traffic')
.then(res => res.json())
.then(data => {
console.log('总发送:', data.data.total_sent);
console.log('总接收:', data.data.total_received);
});
```
### 4. 集成到自己的前端
你可以定期调用 `/api/stats/traffic` 接口获取数据,然后在自己的前端页面中展示:
```javascript
// 每3秒获取一次数据
setInterval(async () => {
const response = await fetch('http://localhost:8080/api/stats/traffic');
const result = await response.json();
if (result.success) {
const data = result.data;
// 更新UI
updateTotalSent(data.total_sent);
updateTotalReceived(data.total_received);
// 更新图表
updateChart(data);
// 更新端口列表
updatePortList(data.mappings);
}
}, 3000);
```
## 实现细节
### 流量统计机制
1. **Tunnel层统计**
- 在 `writeTunnelMessage` 中记录发送字节数
- 在 `readTunnelMessage` 中记录接收字节数
- 使用 `atomic.AddUint64` 保证并发安全
2. **Forwarder层统计**
- 在 `io.Copy` 返回后记录传输字节数
- 分别统计客户端→目标和目标→客户端的流量
3. **数据结构**
```go
type TrafficStats struct {
BytesSent uint64 // 发送字节数
BytesReceived uint64 // 接收字节数
LastUpdate int64 // 最后更新时间
}
```
### 性能考虑
- 使用原子操作 (`sync/atomic`) 避免锁竞争
- 统计开销极小,几乎不影响转发性能
- 只在需要时才计算速率
## 测试验证
### 1. 基础测试
```bash
# 1. 启动服务器和客户端
make run-server # 终端1
make run-client # 终端2
# 2. 创建端口映射
curl -X POST http://localhost:8080/api/mapping/create \
-H "Content-Type: application/json" \
-d '{
"source_port": 30009,
"target_host": "127.0.0.1",
"target_port": 22,
"use_tunnel": true
}'
# 3. 通过映射传输一些数据
ssh root@localhost -p 30009
# 4. 查看流量统计
curl http://localhost:8080/api/stats/traffic | jq .
```
### 2. 压力测试
```bash
# 使用 scp 传输大文件测试流量统计
dd if=/dev/zero of=test.dat bs=1M count=100
scp -P 30009 test.dat root@localhost:/tmp/
# 观察监控页面的实时变化
```
### 3. 长时间运行测试
```bash
# 保持监控页面打开,观察:
# - 图表是否正常更新
# - 数据是否累积正确
# - 内存占用是否稳定
```
## 故障排查
### 问题1: 流量统计不准确
**可能原因**:
- 连接未正常关闭
- 统计溢出极少见uint64可存储18EB
**解决方法**:
```bash
# 重启服务器重置统计
```
### 问题2: 监控页面无法访问
**检查**:
```bash
# 确认服务器正在运行
curl http://localhost:8080/health
# 检查端口是否被占用
netstat -tlnp | grep 8080
```
### 问题3: 图表不更新
**检查**:
1. 打开浏览器开发者工具 (F12)
2. 查看 Console 是否有错误
3. 查看 Network 标签,确认 API 请求成功
## 未来改进
可能的增强功能:
- [ ] 流量统计持久化(保存到数据库)
- [ ] 历史流量查询
- [ ] 流量告警(超过阈值时通知)
- [ ] 导出流量报表
- [ ] 按时间段统计(小时/天/月)
- [ ] WebSocket实时推送减少轮询
## 相关文件
- `src/server/stats/stats.go` - 流量统计数据结构
- `src/server/tunnel/tunnel.go` - Tunnel流量统计实现
- `src/server/forwarder/forwarder.go` - Forwarder流量统计实现
- `src/server/api/api.go` - API接口和监控页面
## 技术栈
- **后端**: Go 1.x
- **前端**: 原生 JavaScript + Chart.js 4.4.0
- **图表库**: Chart.js (CDN)
- **样式**: CSS3 (渐变、动画)
## 许可证
与主项目相同
## 更新日期
2025-10-16

122
monitor.sh Normal file
View File

@ -0,0 +1,122 @@
#!/bin/bash
# Go Tunnel 监控脚本
# 用于实时监控 goroutine 数量和 CPU 占用
SERVER_PPROF="http://localhost:6060"
CLIENT_PPROF="http://localhost:6061"
# 颜色定义
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m' # No Color
echo "========================================"
echo " Go Tunnel 实时监控"
echo "========================================"
echo ""
echo "监控项:"
echo " - Goroutine 数量(正常应该 < 15"
echo " - CPU 占用率"
echo " - 进程状态"
echo ""
echo "按 Ctrl+C 停止监控"
echo "========================================"
echo ""
printf "%-12s | %-20s | %-20s | %-10s\n" "时间" "服务器" "客户端" "CPU占用"
printf "%-12s | %-20s | %-20s | %-10s\n" "--------" "--------------------" "--------------------" "----------"
while true; do
TIMESTAMP=$(date +%H:%M:%S)
# 获取服务器 goroutine 数量
SERVER_GOROUTINES=$(curl -s --connect-timeout 1 $SERVER_PPROF/debug/pprof/goroutine 2>/dev/null | \
grep -oP 'goroutine profile: total \K\d+' || echo "N/A")
# 获取客户端 goroutine 数量
CLIENT_GOROUTINES=$(curl -s --connect-timeout 1 $CLIENT_PPROF/debug/pprof/goroutine 2>/dev/null | \
grep -oP 'goroutine profile: total \K\d+' || echo "N/A")
# 获取服务器进程 CPU 占用
SERVER_PID=$(pgrep -f "bin/server" | head -1)
if [ ! -z "$SERVER_PID" ]; then
SERVER_CPU=$(ps -p $SERVER_PID -o %cpu= | tr -d ' ')
else
SERVER_CPU="未运行"
fi
# 获取客户端进程 CPU 占用
CLIENT_PID=$(pgrep -f "bin/client" | head -1)
if [ ! -z "$CLIENT_PID" ]; then
CLIENT_CPU=$(ps -p $CLIENT_PID -o %cpu= | tr -d ' ')
else
CLIENT_CPU="未运行"
fi
# 格式化 goroutine 信息
if [ "$SERVER_GOROUTINES" != "N/A" ]; then
SERVER_INFO="Goroutines: $SERVER_GOROUTINES"
# 如果 goroutine 数量异常,高亮显示
if [ "$SERVER_GOROUTINES" -gt 20 ]; then
SERVER_INFO="${RED}Goroutines: $SERVER_GOROUTINES${NC}"
elif [ "$SERVER_GOROUTINES" -gt 15 ]; then
SERVER_INFO="${YELLOW}Goroutines: $SERVER_GOROUTINES${NC}"
else
SERVER_INFO="${GREEN}Goroutines: $SERVER_GOROUTINES${NC}"
fi
else
SERVER_INFO="${RED}离线${NC}"
fi
if [ "$CLIENT_GOROUTINES" != "N/A" ]; then
CLIENT_INFO="Goroutines: $CLIENT_GOROUTINES"
if [ "$CLIENT_GOROUTINES" -gt 20 ]; then
CLIENT_INFO="${RED}Goroutines: $CLIENT_GOROUTINES${NC}"
elif [ "$CLIENT_GOROUTINES" -gt 15 ]; then
CLIENT_INFO="${YELLOW}Goroutines: $CLIENT_GOROUTINES${NC}"
else
CLIENT_INFO="${GREEN}Goroutines: $CLIENT_GOROUTINES${NC}"
fi
else
CLIENT_INFO="${RED}离线${NC}"
fi
# CPU 信息
if [ "$SERVER_CPU" != "未运行" ]; then
CPU_INFO="S:${SERVER_CPU}%"
# 检查 CPU 占用是否过高
CPU_VALUE=$(echo $SERVER_CPU | cut -d. -f1)
if [ "$CPU_VALUE" -gt 50 ] 2>/dev/null; then
CPU_INFO="${RED}S:${SERVER_CPU}% ⚠${NC}"
elif [ "$CPU_VALUE" -gt 10 ] 2>/dev/null; then
CPU_INFO="${YELLOW}S:${SERVER_CPU}%${NC}"
else
CPU_INFO="${GREEN}S:${SERVER_CPU}%${NC}"
fi
else
CPU_INFO="${RED}未运行${NC}"
fi
if [ "$CLIENT_CPU" != "未运行" ]; then
if [ "$SERVER_CPU" != "未运行" ]; then
CPU_INFO="${CPU_INFO} "
else
CPU_INFO=""
fi
CLIENT_CPU_VALUE=$(echo $CLIENT_CPU | cut -d. -f1)
if [ "$CLIENT_CPU_VALUE" -gt 50 ] 2>/dev/null; then
CPU_INFO="${CPU_INFO}${RED}C:${CLIENT_CPU}% ⚠${NC}"
elif [ "$CLIENT_CPU_VALUE" -gt 10 ] 2>/dev/null; then
CPU_INFO="${CPU_INFO}${YELLOW}C:${CLIENT_CPU}%${NC}"
else
CPU_INFO="${CPU_INFO}${GREEN}C:${CLIENT_CPU}%${NC}"
fi
fi
# 输出监控信息
printf "%-12s | %-35s | %-35s | %-30s\n" "$TIMESTAMP" "$SERVER_INFO" "$CLIENT_INFO" "$CPU_INFO"
sleep 2
done

View File

@ -3,6 +3,7 @@ package main
import (
"flag"
"log"
_ "net/http/pprof" // 导入pprof用于性能分析
"os"
"os/signal"
"port-forward/client/tunnel"
@ -27,8 +28,28 @@ func main() {
log.Fatalf("启动隧道客户端失败: %v", err)
}
// // 启动 pprof 调试服务器(用于性能分析和调试)
// pprofPort := 6061
// go func() {
// log.Printf("启动 pprof 调试服务器: http://localhost:%d/debug/pprof/", pprofPort)
// if err := http.ListenAndServe(":6061", nil); err != nil {
// log.Printf("pprof 服务器启动失败: %v", err)
// }
// }()
// // 启动 goroutine 监控
// go func() {
// ticker := time.NewTicker(10 * time.Second)
// defer ticker.Stop()
// for range ticker.C {
// numGoroutines := runtime.NumGoroutine()
// log.Printf("[监控] 当前 Goroutine 数量: %d", numGoroutines)
// }
// }()
log.Println("===========================================")
log.Println("隧道客户端运行中...")
// log.Printf("调试接口: http://localhost:%d/debug/pprof/", pprofPort)
log.Println("按 Ctrl+C 退出")
log.Println("===========================================")

View File

@ -185,6 +185,7 @@ func (c *Client) handleServerRead(conn net.Conn, connCtx context.Context, connCa
defer conn.Close()
for {
// 检查是否应该退出
select {
case <-c.ctx.Done():
return
@ -193,15 +194,19 @@ func (c *Client) handleServerRead(conn net.Conn, connCtx context.Context, connCa
default:
}
// 设置读取超时,避免无限阻塞
conn.SetReadDeadline(time.Now().Add(ReadTimeout))
msg, err := c.readTunnelMessage(conn)
if err != nil {
if err != io.EOF {
if err != io.EOF && !isTimeout(err) {
log.Printf("读取隧道消息失败: %v", err)
}
connCancel() // 通知其他协程退出
return
}
// 重置读取超时
conn.SetReadDeadline(time.Time{})
c.handleTunnelMessage(msg)
}
}
@ -265,6 +270,10 @@ func (c *Client) readTunnelMessage(conn net.Conn) (*TunnelMessage, error) {
// writeTunnelMessage 写入隧道消息
func (c *Client) writeTunnelMessage(conn net.Conn, msg *TunnelMessage) error {
// 设置写入超时,防止阻塞
conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
defer conn.SetWriteDeadline(time.Time{}) // 重置超时
// 构建消息头
header := make([]byte, HeaderSize)
header[0] = msg.Version
@ -273,13 +282,13 @@ func (c *Client) writeTunnelMessage(conn net.Conn, msg *TunnelMessage) error {
// 写入消息头
if _, err := conn.Write(header); err != nil {
return err
return fmt.Errorf("写入消息头失败: %w", err)
}
// 写入数据
if msg.Length > 0 && msg.Data != nil {
if _, err := conn.Write(msg.Data); err != nil {
return err
return fmt.Errorf("写入消息数据失败: %w", err)
}
}
@ -368,11 +377,24 @@ func (c *Client) handleDataMessage(msg *TunnelMessage) {
c.connMu.RUnlock()
if !exists {
log.Printf("收到未知连接的数据: %d", connID)
log.Printf("收到未知连接的数据: %d发送关闭消息", connID)
// 连接不存在,发送关闭消息通知对端
closeData := make([]byte, 4)
binary.BigEndian.PutUint32(closeData, connID)
closeMsg := &TunnelMessage{
Version: ProtocolVersion,
Type: MsgTypeClose,
Length: 4,
Data: closeData,
}
select {
case c.sendChan <- closeMsg:
default:
}
return
}
if _, err := connection.Conn.Write(data); err != nil {
if _, err := connection.Conn.Write(data); err != nil {
log.Printf("写入目标连接失败 (ID=%d): %v", connID, err)
c.closeConnection(connID)
}
@ -391,19 +413,9 @@ func (c *Client) handleCloseMessage(msg *TunnelMessage) {
// handleKeepAlive 处理心跳消息
func (c *Client) handleKeepAlive(msg *TunnelMessage) {
// 回应心跳
response := &TunnelMessage{
Version: ProtocolVersion,
Type: MsgTypeKeepAlive,
Length: 0,
Data: nil,
}
select {
case c.sendChan <- response:
default:
log.Printf("发送心跳响应失败: 发送队列已满")
}
// 客户端收到服务器的心跳响应,不需要再回应
// 这样避免心跳消息的无限循环
// log.Printf("收到服务器心跳响应")
}
// sendConnectResponse 发送连接响应
@ -433,22 +445,38 @@ func (c *Client) forwardData(connection *LocalConnection) {
buffer := make([]byte, 32*1024)
for {
select {
case <-connection.closeChan:
return
case <-c.ctx.Done():
return
case <-connection.closeChan:
return
default:
}
// 设置读取超时
connection.Conn.SetReadDeadline(time.Now().Add(ReadTimeout))
n, err := connection.Conn.Read(buffer)
if err != nil {
if err != io.EOF && !isTimeout(err) {
// 任何错误都应该终止转发
if err == io.EOF {
log.Printf("目标连接正常关闭 (ID=%d)", connection.ID)
} else if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
log.Printf("目标连接超时 (ID=%d)", connection.ID)
} else {
log.Printf("读取目标连接失败 (ID=%d): %v", connection.ID, err)
}
return
}
// 读取到0字节连接已关闭
if n == 0 {
log.Printf("目标连接已关闭 (ID=%d, 读取0字节)", connection.ID)
return
}
// 重置读取超时
connection.Conn.SetReadDeadline(time.Time{})
// 发送数据到隧道
dataMsg := make([]byte, 4+n)
binary.BigEndian.PutUint32(dataMsg[0:4], connection.ID)
@ -463,11 +491,14 @@ func (c *Client) forwardData(connection *LocalConnection) {
select {
case c.sendChan <- msg:
// 数据已发送
case <-time.After(5 * time.Second):
log.Printf("发送数据超时 (ID=%d)", connection.ID)
return
case <-c.ctx.Done():
return
case <-connection.closeChan:
return
}
}
}
@ -481,10 +512,18 @@ func (c *Client) closeConnection(connID uint32) {
connection.closeOnce.Do(func() {
close(connection.closeChan)
})
connection.Conn.Close()
// 确保连接被关闭
if connection.Conn != nil {
connection.Conn.Close()
}
}
c.connMu.Unlock()
if !exists {
// 连接不存在,无需发送关闭消息
return
}
// 发送关闭消息
closeData := make([]byte, 4)
binary.BigEndian.PutUint32(closeData, connID)
@ -498,12 +537,11 @@ func (c *Client) closeConnection(connID uint32) {
select {
case c.sendChan <- msg:
default:
// 发送队列满,忽略
}
if exists {
log.Printf("连接已关闭: ID=%d", connID)
case <-time.After(1 * time.Second):
log.Printf("发送关闭消息超时: ID=%d", connID)
case <-c.ctx.Done():
log.Printf("客户端关闭,跳过发送关闭消息: ID=%d", connID)
}
}

View File

@ -152,16 +152,20 @@ func TestClientHandleConnectRequest(t *testing.T) {
client := NewClient("127.0.0.1:9000")
// 创建连接请求消息
// 创建连接请求消息新格式connID + port + hostLen + host
connID := uint32(12345)
reqData := make([]byte, 6)
targetHost := "127.0.0.1"
targetHostBytes := []byte(targetHost)
reqData := make([]byte, 7+len(targetHostBytes))
binary.BigEndian.PutUint32(reqData[0:4], connID)
binary.BigEndian.PutUint16(reqData[4:6], uint16(localPort))
reqData[6] = byte(len(targetHostBytes))
copy(reqData[7:], targetHostBytes)
msg := &TunnelMessage{
Version: ProtocolVersion,
Type: MsgTypeConnectRequest,
Length: 6,
Length: uint32(len(reqData)),
Data: reqData,
}

View File

@ -8,6 +8,7 @@ import (
"net/http"
"port-forward/server/db"
"port-forward/server/forwarder"
"port-forward/server/stats"
"port-forward/server/tunnel"
"strconv"
"time"
@ -58,6 +59,8 @@ func (h *Handler) RegisterRoutes(mux *http.ServeMux) {
mux.HandleFunc("/api/mapping/create", h.handleCreateMapping)
mux.HandleFunc("/api/mapping/remove", h.handleRemoveMapping)
mux.HandleFunc("/api/mapping/list", h.handleListMappings)
mux.HandleFunc("/api/stats/traffic", h.handleGetTrafficStats)
mux.HandleFunc("/api/stats/monitor", h.handleTrafficMonitor)
mux.HandleFunc("/health", h.handleHealth)
}
@ -277,4 +280,56 @@ func Start(handler *Handler, port int) error {
log.Printf("HTTP API 服务启动: 端口 %d", port)
return server.ListenAndServe()
}
// handleGetTrafficStats 获取流量统计
func (h *Handler) handleGetTrafficStats(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
h.writeError(w, http.StatusMethodNotAllowed, "只支持 GET 方法")
return
}
// 获取隧道流量统计
var tunnelStats stats.TrafficStats
if h.tunnelServer != nil {
tunnelStats = h.tunnelServer.GetTrafficStats()
}
// 获取所有端口映射的流量统计
forwarderStats := h.forwarderMgr.GetAllTrafficStats()
// 构建响应
mappings := make([]stats.PortTrafficStats, 0, len(forwarderStats))
var totalSent, totalReceived uint64
for port, stat := range forwarderStats {
mappings = append(mappings, stats.PortTrafficStats{
Port: port,
BytesSent: stat.BytesSent,
BytesReceived: stat.BytesReceived,
LastUpdate: stat.LastUpdate,
})
totalSent += stat.BytesSent
totalReceived += stat.BytesReceived
}
// 加上隧道的流量
totalSent += tunnelStats.BytesSent
totalReceived += tunnelStats.BytesReceived
response := stats.AllTrafficStats{
Tunnel: tunnelStats,
Mappings: mappings,
TotalSent: totalSent,
TotalReceived: totalReceived,
Timestamp: time.Now().Unix(),
}
h.writeSuccess(w, "获取流量统计成功", response)
}
// handleTrafficMonitor 流量监控页面
func (h *Handler) handleTrafficMonitor(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html; charset=utf-8")
fmt.Fprint(w, html)
}

View File

@ -260,7 +260,7 @@ func TestHandleCreateMappingInvalidIP(t *testing.T) {
// Port: 15000,
SourcePort: 15000,
TargetPort: 15000,
TargetHost: "invalid-ip",
TargetHost: "", // 使用空字符串而不是无效域名避免DNS查询超时
}
body, _ := json.Marshal(reqBody)

341
src/server/api/html.go Normal file
View File

@ -0,0 +1,341 @@
package api
const html = `
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>流量监控 - Go Tunnel</title>
<script src="https://cdn.jsdelivr.net/npm/chart.js@4.4.0/dist/chart.umd.min.js"></script>
<style>
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
body {
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
min-height: 100vh;
padding: 20px;
}
.container {
max-width: 1400px;
margin: 0 auto;
}
h1 {
color: white;
text-align: center;
margin-bottom: 30px;
font-size: 2.5em;
text-shadow: 2px 2px 4px rgba(0,0,0,0.3);
}
.stats-grid {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(300px, 1fr));
gap: 20px;
margin-bottom: 30px;
}
.stat-card {
background: rgba(255, 255, 255, 0.95);
border-radius: 15px;
padding: 25px;
box-shadow: 0 8px 32px rgba(0,0,0,0.1);
transition: transform 0.3s;
}
.stat-card:hover {
transform: translateY(-5px);
}
.stat-card h3 {
color: #667eea;
margin-bottom: 15px;
font-size: 1.3em;
border-bottom: 2px solid #667eea;
padding-bottom: 10px;
}
.stat-value {
font-size: 2em;
font-weight: bold;
color: #333;
margin: 10px 0;
}
.stat-label {
color: #666;
font-size: 0.9em;
margin-top: 5px;
}
.chart-container {
background: rgba(255, 255, 255, 0.95);
border-radius: 15px;
padding: 25px;
margin-bottom: 20px;
box-shadow: 0 8px 32px rgba(0,0,0,0.1);
}
.chart-container h2 {
color: #667eea;
margin-bottom: 20px;
font-size: 1.5em;
}
.mapping-list {
background: rgba(255, 255, 255, 0.95);
border-radius: 15px;
padding: 25px;
box-shadow: 0 8px 32px rgba(0,0,0,0.1);
}
.mapping-item {
background: #f8f9fa;
border-radius: 10px;
padding: 15px;
margin-bottom: 15px;
border-left: 4px solid #667eea;
}
.mapping-item h4 {
color: #667eea;
margin-bottom: 10px;
}
.mapping-stats {
display: flex;
justify-content: space-between;
color: #666;
font-size: 0.9em;
}
.status-indicator {
display: inline-block;
width: 12px;
height: 12px;
border-radius: 50%;
background: #4CAF50;
animation: pulse 2s infinite;
}
@keyframes pulse {
0%, 100% {
opacity: 1;
}
50% {
opacity: 0.5;
}
}
.update-time {
text-align: center;
color: white;
margin-top: 20px;
font-size: 0.9em;
}
</style>
</head>
<body>
<div class="container">
<h1><span class="status-indicator"></span> 流量监控面板</h1>
<div class="stats-grid">
<div class="stat-card">
<h3>总发送流量</h3>
<div class="stat-value" id="total-sent">0 B</div>
<div class="stat-label">Total Sent</div>
</div>
<div class="stat-card">
<h3>总接收流量</h3>
<div class="stat-value" id="total-received">0 B</div>
<div class="stat-label">Total Received</div>
</div>
<div class="stat-card">
<h3>隧道发送</h3>
<div class="stat-value" id="tunnel-sent">0 B</div>
<div class="stat-label">Tunnel Sent</div>
</div>
<div class="stat-card">
<h3>隧道接收</h3>
<div class="stat-value" id="tunnel-received">0 B</div>
<div class="stat-label">Tunnel Received</div>
</div>
</div>
<div class="chart-container">
<h2>实时流量趋势</h2>
<canvas id="trafficChart"></canvas>
</div>
<div class="mapping-list">
<h2 style="color: #667eea; margin-bottom: 20px;">端口映射流量</h2>
<div id="mapping-list"></div>
</div>
<div class="update-time">
最后更新: <span id="update-time">-</span> | 3 秒自动刷新
</div>
</div>
<script>
// 初始化图表
const ctx = document.getElementById('trafficChart');
const chart = new Chart(ctx, {
type: 'line',
data: {
labels: [],
datasets: [
{
label: '发送 (KB/s)',
data: [],
borderColor: '#667eea',
backgroundColor: 'rgba(102, 126, 234, 0.1)',
tension: 0.4,
fill: true
},
{
label: '接收 (KB/s)',
data: [],
borderColor: '#764ba2',
backgroundColor: 'rgba(118, 75, 162, 0.1)',
tension: 0.4,
fill: true
}
]
},
options: {
responsive: true,
maintainAspectRatio: true,
aspectRatio: 2.5,
plugins: {
legend: {
display: true,
position: 'top',
}
},
scales: {
y: {
beginAtZero: true,
ticks: {
callback: function(value) {
return value.toFixed(2) + ' KB/s';
}
}
},
x: {
display: true
}
}
}
});
let lastData = null;
const maxDataPoints = 20;
// 格式化字节数
function formatBytes(bytes) {
if (bytes === 0) return '0 B';
const k = 1024;
const sizes = ['B', 'KB', 'MB', 'GB', 'TB'];
const i = Math.floor(Math.log(bytes) / Math.log(k));
return parseFloat((bytes / Math.pow(k, i)).toFixed(2)) + ' ' + sizes[i];
}
// 更新统计数据
function updateStats(data) {
document.getElementById('total-sent').textContent = formatBytes(data.total_sent);
document.getElementById('total-received').textContent = formatBytes(data.total_received);
document.getElementById('tunnel-sent').textContent = formatBytes(data.tunnel.bytes_sent);
document.getElementById('tunnel-received').textContent = formatBytes(data.tunnel.bytes_received);
// 更新时间
const now = new Date();
document.getElementById('update-time').textContent = now.toLocaleTimeString('zh-CN');
}
// 更新图表
function updateChart(data) {
const now = new Date().toLocaleTimeString('zh-CN');
// 计算速率 (如果有上次数据)
let sendRate = 0;
let recvRate = 0;
if (lastData) {
const timeDiff = 3; // 3秒间隔
sendRate = (data.total_sent - lastData.total_sent) / timeDiff / 1024; // KB/s
recvRate = (data.total_received - lastData.total_received) / timeDiff / 1024; // KB/s
}
lastData = data;
// 添加新数据点
chart.data.labels.push(now);
chart.data.datasets[0].data.push(sendRate);
chart.data.datasets[1].data.push(recvRate);
// 限制数据点数量
if (chart.data.labels.length > maxDataPoints) {
chart.data.labels.shift();
chart.data.datasets[0].data.shift();
chart.data.datasets[1].data.shift();
}
chart.update('none'); // 无动画更新,更流畅
}
// 更新端口映射列表
function updateMappings(mappings) {
const container = document.getElementById('mapping-list');
if (mappings.length === 0) {
container.innerHTML = '<p style="color: #999; text-align: center;">暂无端口映射</p>';
return;
}
container.innerHTML = mappings.map(m =>
'<div class="mapping-item">' +
'<h4>端口 ' + m.port + '</h4>' +
'<div class="mapping-stats">' +
'<span>发送: ' + formatBytes(m.bytes_sent) + '</span>' +
'<span>接收: ' + formatBytes(m.bytes_received) + '</span>' +
'</div>' +
'</div>'
).join('');
}
// 获取流量数据
async function fetchTrafficData() {
try {
const response = await fetch('/api/stats/traffic');
const result = await response.json();
if (result.success) {
updateStats(result.data);
updateChart(result.data);
updateMappings(result.data.mappings || []);
}
} catch (error) {
console.error('获取流量数据失败:', error);
}
}
// 初始加载
fetchTrafficData();
// 定时刷新 (每3秒)
setInterval(fetchTrafficData, 3000);
</script>
</body>
</html>
`

View File

@ -200,7 +200,7 @@ func TestValidateConfig(t *testing.T) {
{
name: "端口范围过大",
config: Config{
PortRange: PortRangeConfig{From: 1, End: 20000},
PortRange: PortRangeConfig{From: 1, End: 40000},
Tunnel: TunnelConfig{Enabled: false, ListenPort: 0},
API: APIConfig{ListenPort: 8080},
Database: DatabaseConfig{Path: "./data/test.db"},

View File

@ -6,7 +6,9 @@ import (
"io"
"log"
"net"
"port-forward/server/stats"
"sync"
"sync/atomic"
"time"
)
@ -14,6 +16,7 @@ import (
type TunnelServer interface {
ForwardConnection(clientConn net.Conn, targetIP string, targetPort int) error
IsConnected() bool
GetTrafficStats() stats.TrafficStats
}
// Forwarder 端口转发器
@ -27,6 +30,10 @@ type Forwarder struct {
wg sync.WaitGroup
tunnelServer TunnelServer
useTunnel bool
// 流量统计(使用原子操作)
bytesSent uint64 // 发送字节数
bytesReceived uint64 // 接收字节数
}
// NewForwarder 创建新的端口转发器
@ -142,26 +149,44 @@ func (f *Forwarder) handleConnection(clientConn net.Conn) {
defer targetConn.Close()
// 双向转发
errChan := make(chan error, 2)
var wg sync.WaitGroup
wg.Add(2)
// 客户端 -> 目标
go func() {
_, err := io.Copy(targetConn, clientConn)
errChan <- err
defer wg.Done()
n, _ := io.Copy(targetConn, clientConn)
atomic.AddUint64(&f.bytesSent, uint64(n))
// 关闭目标连接的写入端,通知对方不会再发送数据
if tcpConn, ok := targetConn.(*net.TCPConn); ok {
tcpConn.CloseWrite()
}
}()
// 目标 -> 客户端
go func() {
_, err := io.Copy(clientConn, targetConn)
errChan <- err
defer wg.Done()
n, _ := io.Copy(clientConn, targetConn)
atomic.AddUint64(&f.bytesReceived, uint64(n))
// 关闭客户端连接的写入端
if tcpConn, ok := clientConn.(*net.TCPConn); ok {
tcpConn.CloseWrite()
}
}()
// 等待任一方向完成或出错
// 创建一个 channel 来等待完成
done := make(chan struct{})
go func() {
wg.Wait()
close(done)
}()
// 等待两个方向都完成或上下文取消
select {
case <-errChan:
// 连接已关闭或出错
case <-done:
// 两个方向都已完成
case <-f.ctx.Done():
// 转发器被停止
// 转发器被停止,连接会在 defer 中关闭
}
}
@ -279,4 +304,26 @@ func (m *Manager) StopAll() {
}
m.forwarders = make(map[int]*Forwarder)
}
// GetTrafficStats 获取流量统计信息
func (f *Forwarder) GetTrafficStats() stats.TrafficStats {
return stats.TrafficStats{
BytesSent: atomic.LoadUint64(&f.bytesSent),
BytesReceived: atomic.LoadUint64(&f.bytesReceived),
LastUpdate: time.Now().Unix(),
}
}
// GetAllTrafficStats 获取所有转发器的流量统计
func (m *Manager) GetAllTrafficStats() map[int]stats.TrafficStats {
m.mu.RLock()
defer m.mu.RUnlock()
statsMap := make(map[int]stats.TrafficStats)
for port, forwarder := range m.forwarders {
statsMap[port] = forwarder.GetTrafficStats()
}
return statsMap
}

View File

@ -4,6 +4,7 @@ import (
"fmt"
"io"
"net"
"port-forward/server/stats"
"testing"
"time"
)
@ -20,8 +21,12 @@ func TestNewForwarder(t *testing.T) {
t.Errorf("源端口不正确,期望 8080得到 %d", fwd.sourcePort)
}
if fwd.targetHost != "192.168.1.100:80" {
t.Errorf("目标地址不正确,期望 192.168.1.100:80得到 %s", fwd.targetHost)
if fwd.targetHost != "192.168.1.100" {
t.Errorf("目标主机不正确,期望 192.168.1.100,得到 %s", fwd.targetHost)
}
if fwd.targetPort != 80 {
t.Errorf("目标端口不正确,期望 80得到 %d", fwd.targetPort)
}
if fwd.useTunnel {
@ -43,6 +48,10 @@ func (m *mockTunnelServer) ForwardConnection(clientConn net.Conn, targetIp strin
func (m *mockTunnelServer) IsConnected() bool {
return m.connected
}
func (m *mockTunnelServer) GetTrafficStats() stats.TrafficStats {
return stats.TrafficStats{}
}
// TestNewTunnelForwarder 测试创建隧道转发器
func TestNewTunnelForwarder(t *testing.T) {

View File

@ -4,6 +4,7 @@ import (
"context"
"flag"
"log"
_ "net/http/pprof" // 导入pprof用于性能分析
"os"
"os/signal"
"port-forward/server/api"
@ -64,8 +65,6 @@ func main() {
continue
}
log.Printf("恢复端口映射: %d -> %s:%d (tunnel: %v)", mapping.SourcePort, mapping.TargetHost, mapping.TargetPort, mapping.UseTunnel)
var err error
if mapping.UseTunnel {
// 隧道模式:检查隧道服务器是否可用
@ -103,10 +102,30 @@ func main() {
}
}()
// // 启动 pprof 调试服务器(用于性能分析和调试)
// pprofPort := 6060
// go func() {
// log.Printf("启动 pprof 调试服务器: http://localhost:%d/debug/pprof/", pprofPort)
// if err := http.ListenAndServe(":6060", nil); err != nil {
// log.Printf("pprof 服务器启动失败: %v", err)
// }
// }()
// // 启动 goroutine 监控
// go func() {
// ticker := time.NewTicker(10 * time.Second)
// defer ticker.Stop()
// for range ticker.C {
// numGoroutines := runtime.NumGoroutine()
// log.Printf("[监控] 当前 Goroutine 数量: %d", numGoroutines)
// }
// }()
log.Println("===========================================")
log.Printf("服务器启动成功!")
log.Printf("端口范围: %d-%d", cfg.PortRange.From, cfg.PortRange.End)
log.Printf("HTTP API: http://localhost:%d", cfg.API.ListenPort)
// log.Printf("调试接口: http://localhost:%d/debug/pprof/", pprofPort)
if cfg.Tunnel.Enabled {
log.Printf("隧道服务: 端口 %d", cfg.Tunnel.ListenPort)
}

25
src/server/stats/stats.go Normal file
View File

@ -0,0 +1,25 @@
package stats
// TrafficStats 流量统计
type TrafficStats struct {
BytesSent uint64 `json:"bytes_sent"` // 发送字节数
BytesReceived uint64 `json:"bytes_received"` // 接收字节数
LastUpdate int64 `json:"last_update"` // 最后更新时间Unix时间戳
}
// PortTrafficStats 端口流量统计
type PortTrafficStats struct {
Port int `json:"port"`
BytesSent uint64 `json:"bytes_sent"`
BytesReceived uint64 `json:"bytes_received"`
LastUpdate int64 `json:"last_update"`
}
// AllTrafficStats 所有流量统计
type AllTrafficStats struct {
Tunnel TrafficStats `json:"tunnel"` // 隧道流量
Mappings []PortTrafficStats `json:"mappings"` // 端口映射流量
TotalSent uint64 `json:"total_sent"` // 总发送
TotalReceived uint64 `json:"total_received"` // 总接收
Timestamp int64 `json:"timestamp"` // 时间戳
}

View File

@ -7,7 +7,9 @@ import (
"io"
"log"
"net"
"port-forward/server/stats"
"sync"
"sync/atomic"
"time"
)
@ -37,7 +39,7 @@ const (
// 超时设置
ConnectTimeout = 10 * time.Second // 连接超时
ReadTimeout = 30 * time.Second // 读取超时
ReadTimeout = 300 * time.Second // 读取超时
KeepAliveInterval = 15 * time.Second // 心跳间隔
)
@ -111,6 +113,10 @@ type Server struct {
// 消息队列
sendChan chan *TunnelMessage
// 流量统计(使用原子操作)
bytesSent uint64 // 通过隧道发送的总字节数
bytesReceived uint64 // 通过隧道接收的总字节数
}
// NewServer 创建新的隧道服务器
@ -214,20 +220,25 @@ func (s *Server) handleTunnelRead(conn net.Conn) {
}()
for {
// 检查是否应该退出
select {
case <-s.ctx.Done():
return
default:
}
// 设置读取超时,避免无限阻塞
conn.SetReadDeadline(time.Now().Add(ReadTimeout))
msg, err := s.readTunnelMessage(conn)
if err != nil {
if err != io.EOF {
if err != io.EOF && !isTimeout(err) {
log.Printf("读取隧道消息失败: %v", err)
}
return
}
// 重置读取超时
conn.SetReadDeadline(time.Time{})
s.handleTunnelMessage(msg)
}
}
@ -256,6 +267,9 @@ func (s *Server) readTunnelMessage(conn net.Conn) (*TunnelMessage, error) {
if _, err := io.ReadFull(conn, header); err != nil {
return nil, err
}
// 统计接收字节数
s.addBytesReceived(uint64(HeaderSize))
version := header[0]
msgType := header[1]
@ -276,6 +290,8 @@ func (s *Server) readTunnelMessage(conn net.Conn) (*TunnelMessage, error) {
if _, err := io.ReadFull(conn, data); err != nil {
return nil, err
}
// 统计接收字节数
s.addBytesReceived(uint64(dataLen))
}
return &TunnelMessage{
@ -288,6 +304,10 @@ func (s *Server) readTunnelMessage(conn net.Conn) (*TunnelMessage, error) {
// writeTunnelMessage 写入隧道消息
func (s *Server) writeTunnelMessage(conn net.Conn, msg *TunnelMessage) error {
// 设置写入超时,防止阻塞
conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
defer conn.SetWriteDeadline(time.Time{}) // 重置超时
// 构建消息头
header := make([]byte, HeaderSize)
header[0] = msg.Version
@ -296,14 +316,19 @@ func (s *Server) writeTunnelMessage(conn net.Conn, msg *TunnelMessage) error {
// 写入消息头
if _, err := conn.Write(header); err != nil {
return err
return fmt.Errorf("写入消息头失败: %w", err)
}
// 统计发送字节数
s.addBytesSent(uint64(HeaderSize))
// 写入数据
if msg.Length > 0 && msg.Data != nil {
if _, err := conn.Write(msg.Data); err != nil {
return err
return fmt.Errorf("写入消息数据失败: %w", err)
}
// 统计发送字节数
s.addBytesSent(uint64(msg.Length))
}
return nil
@ -401,7 +426,20 @@ func (s *Server) handleDataMessage(msg *TunnelMessage) {
s.connMu.RUnlock()
if !exists {
log.Printf("收到未知连接的数据: %d", connID)
log.Printf("收到未知连接的数据: %d发送关闭消息", connID)
// 连接不存在,发送关闭消息通知对端
closeData := make([]byte, 4)
binary.BigEndian.PutUint32(closeData, connID)
closeMsg := &TunnelMessage{
Version: ProtocolVersion,
Type: MsgTypeClose,
Length: 4,
Data: closeData,
}
select {
case s.sendChan <- closeMsg:
default:
}
return
}
@ -425,7 +463,8 @@ func (s *Server) handleCloseMessage(msg *TunnelMessage) {
// handleKeepAlive 处理心跳消息
func (s *Server) handleKeepAlive(msg *TunnelMessage) {
// 回应心跳
// 服务器收到客户端的心跳请求,回应一次即可
// 不要形成心跳循环
response := &TunnelMessage{
Version: ProtocolVersion,
Type: MsgTypeKeepAlive,
@ -454,15 +493,31 @@ func (s *Server) forwardData(active *ActiveConnection) {
default:
}
// 设置读取超时
active.ClientConn.SetReadDeadline(time.Now().Add(ReadTimeout))
n, err := active.ClientConn.Read(buffer)
if err != nil {
if err != io.EOF && !isTimeout(err) {
// 任何错误都应该终止转发,包括超时
if err == io.EOF {
log.Printf("客户端连接正常关闭 (ID=%d)", active.ID)
} else if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
log.Printf("客户端连接超时 (ID=%d)", active.ID)
} else {
log.Printf("读取客户端连接失败 (ID=%d): %v", active.ID, err)
}
return
}
// 读取到0字节连接已关闭
if n == 0 {
log.Printf("客户端连接已关闭 (ID=%d, 读取0字节)", active.ID)
return
}
// 重置读取超时
active.ClientConn.SetReadDeadline(time.Time{})
// 发送数据到隧道
dataMsg := make([]byte, 4+n)
binary.BigEndian.PutUint32(dataMsg[0:4], active.ID)
@ -477,6 +532,7 @@ func (s *Server) forwardData(active *ActiveConnection) {
select {
case s.sendChan <- msg:
// 数据已发送
case <-time.After(5 * time.Second):
log.Printf("发送数据超时 (ID=%d)", active.ID)
return
@ -492,10 +548,18 @@ func (s *Server) closeConnection(connID uint32) {
active, exists := s.activeConns[connID]
if exists {
delete(s.activeConns, connID)
active.ClientConn.Close()
// 确保连接被关闭
if active.ClientConn != nil {
active.ClientConn.Close()
}
}
s.connMu.Unlock()
if !exists {
// 连接不存在,无需发送关闭消息
return
}
// 发送关闭消息
closeData := make([]byte, 4)
binary.BigEndian.PutUint32(closeData, connID)
@ -509,12 +573,11 @@ func (s *Server) closeConnection(connID uint32) {
select {
case s.sendChan <- msg:
default:
// 发送队列满,忽略
}
if exists {
log.Printf("连接已关闭: ID=%d", connID)
case <-time.After(1 * time.Second):
log.Printf("发送关闭消息超时: ID=%d", connID)
case <-s.ctx.Done():
log.Printf("服务器关闭,跳过发送关闭消息: ID=%d", connID)
}
}
@ -667,4 +730,23 @@ func (s *Server) keepAliveLoop(conn net.Conn) {
}
}
}
}
// GetTrafficStats 获取流量统计信息
func (s *Server) GetTrafficStats() stats.TrafficStats {
return stats.TrafficStats{
BytesSent: atomic.LoadUint64(&s.bytesSent),
BytesReceived: atomic.LoadUint64(&s.bytesReceived),
LastUpdate: time.Now().Unix(),
}
}
// addBytesSent 增加发送字节数
func (s *Server) addBytesSent(bytes uint64) {
atomic.AddUint64(&s.bytesSent, bytes)
}
// addBytesReceived 增加接收字节数
func (s *Server) addBytesReceived(bytes uint64) {
atomic.AddUint64(&s.bytesReceived, bytes)
}