feat: 支持了记录统计信息到数据库

This commit is contained in:
qcqcqc@wsl 2026-01-09 13:29:48 +08:00
parent fc1614f7a4
commit 0f47e943f4
7 changed files with 396 additions and 12 deletions

View File

@ -17,4 +17,10 @@ api:
# 数据库配置
database:
path: "./data/mappings.db"
path: "./data/mappings.db"
# 统计信息记录配置
stats:
enabled: true # 是否启用统计记录
record_interval: 60 # 记录间隔默认60秒记录一次
retention_days: 7 # 数据保留天数默认保留7天

View File

@ -11,6 +11,7 @@ import (
"port-forward/server/stats"
"port-forward/server/tunnel"
"port-forward/server/utils"
"sort"
"strconv"
"time"
)
@ -64,6 +65,7 @@ func (h *Handler) RegisterRoutes(mux *http.ServeMux) {
mux.HandleFunc("/api/mapping/remove", h.authMiddleware(h.handleRemoveMapping))
mux.HandleFunc("/api/mapping/list", h.authMiddleware(h.handleListMappings))
mux.HandleFunc("/api/stats/traffic", h.authMiddleware(h.handleGetTrafficStats))
mux.HandleFunc("/api/stats/history", h.authMiddleware(h.handleGetTrafficHistory))
mux.HandleFunc("/api/stats/monitor", h.authMiddleware(h.handleTrafficMonitor))
mux.HandleFunc("/admin", h.handleManagement)
mux.HandleFunc("/health", h.handleHealth)
@ -359,6 +361,12 @@ func (h *Handler) handleGetTrafficStats(w http.ResponseWriter, r *http.Request)
totalSent += tunnelStats.BytesSent
totalReceived += tunnelStats.BytesReceived
// mappings 根据端口号排序
sort.Slice(mappings, func(i, j int) bool {
return mappings[i].Port < mappings[j].Port
})
// 构建最终响应
response := stats.AllTrafficStats{
Tunnel: tunnelStats,
Mappings: mappings,
@ -381,3 +389,50 @@ func (h *Handler) handleManagement(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html; charset=utf-8")
fmt.Fprint(w, GetManagementHTML())
}
// handleGetTrafficHistory 获取历史流量统计
func (h *Handler) handleGetTrafficHistory(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
h.writeError(w, http.StatusMethodNotAllowed, "只支持 GET 方法")
return
}
// 获取查询参数
portStr := r.URL.Query().Get("port")
limitStr := r.URL.Query().Get("limit")
port := -1 // -1 表示所有端口
if portStr != "" {
var err error
port, err = strconv.Atoi(portStr)
if err != nil {
h.writeError(w, http.StatusBadRequest, "无效的端口号")
return
}
}
limit := 100 // 默认返回最近100条
if limitStr != "" {
var err error
limit, err = strconv.Atoi(limitStr)
if err != nil {
h.writeError(w, http.StatusBadRequest, "无效的limit参数")
return
}
if limit > 1000 {
limit = 1000 // 最多返回1000条
}
}
// 查询历史记录
records, err := h.db.GetTrafficRecords(port, limit)
if err != nil {
h.writeError(w, http.StatusInternalServerError, "查询历史记录失败: "+err.Error())
return
}
h.writeSuccess(w, "获取历史流量统计成功", map[string]interface{}{
"records": records,
"count": len(records),
})
}

View File

@ -13,6 +13,7 @@ type Config struct {
Tunnel TunnelConfig `yaml:"tunnel"`
API APIConfig `yaml:"api"`
Database DatabaseConfig `yaml:"database"`
Stats StatsConfig `yaml:"stats"`
}
// PortRangeConfig 端口范围配置
@ -38,6 +39,13 @@ type DatabaseConfig struct {
Path string `yaml:"path"`
}
// StatsConfig 统计记录配置
type StatsConfig struct {
Enabled bool `yaml:"enabled"` // 是否启用统计记录
RecordInterval int `yaml:"record_interval"` // 记录间隔(秒)
RetentionDays int `yaml:"retention_days"` // 数据保留天数
}
// Load 从文件加载配置
func Load(path string) (*Config, error) {
data, err := os.ReadFile(path)

View File

@ -21,6 +21,16 @@ type Mapping struct {
CreatedAt string `json:"created_at"`
}
// TrafficRecord 流量统计记录
type TrafficRecord struct {
ID int64 `json:"id"`
Port int `json:"port"` // 端口号
IsTunnel bool `json:"is_tunnel"` // 是否为隧道整体流量
BytesSent uint64 `json:"bytes_sent"` // 发送字节数
BytesReceived uint64 `json:"bytes_received"` // 接收字节数
RecordedAt string `json:"recorded_at"` // 记录时间
}
// Database 数据库管理器
type Database struct {
db *sql.DB
@ -69,6 +79,18 @@ func (d *Database) initTables() error {
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX IF NOT EXISTS idx_source_port ON mappings(source_port);
CREATE TABLE IF NOT EXISTS traffic_records (
id INTEGER PRIMARY KEY AUTOINCREMENT,
port INTEGER NOT NULL,
is_tunnel BOOLEAN NOT NULL DEFAULT 0,
bytes_sent INTEGER NOT NULL DEFAULT 0,
bytes_received INTEGER NOT NULL DEFAULT 0,
recorded_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX IF NOT EXISTS idx_port ON traffic_records(port);
CREATE INDEX IF NOT EXISTS idx_is_tunnel ON traffic_records(is_tunnel);
CREATE INDEX IF NOT EXISTS idx_recorded_at ON traffic_records(recorded_at);
`
_, err := d.db.Exec(query)
@ -176,6 +198,37 @@ func (d *Database) migrateDatabase() error {
}
}
// 检查 traffic_records 表的 is_tunnel 列
rows3, err := d.db.Query("PRAGMA table_info(traffic_records)")
if err == nil {
defer rows3.Close()
hasIsTunnel := false
for rows3.Next() {
var cid int
var name, dataType string
var notNull, hasDefault int
var defaultValue interface{}
err := rows3.Scan(&cid, &name, &dataType, &notNull, &defaultValue, &hasDefault)
if err != nil {
return fmt.Errorf("扫描 traffic_records 表结构失败: %w", err)
}
if name == "is_tunnel" {
hasIsTunnel = true
break
}
}
// 如果不存在 is_tunnel 列,则添加它
if !hasIsTunnel {
_, err := d.db.Exec("ALTER TABLE traffic_records ADD COLUMN is_tunnel BOOLEAN NOT NULL DEFAULT 0")
if err != nil {
return fmt.Errorf("添加 is_tunnel 列失败: %w", err)
}
}
}
return nil
}
@ -285,3 +338,90 @@ func (d *Database) GetAllMappings() ([]*Mapping, error) {
func (d *Database) Close() error {
return d.db.Close()
}
// AddTrafficRecord 添加流量统计记录
func (d *Database) AddTrafficRecord(port int, isTunnel bool, bytesSent, bytesReceived uint64) error {
d.mu.Lock()
defer d.mu.Unlock()
query := `INSERT INTO traffic_records (port, is_tunnel, bytes_sent, bytes_received) VALUES (?, ?, ?, ?)`
_, err := d.db.Exec(query, port, isTunnel, bytesSent, bytesReceived)
if err != nil {
return fmt.Errorf("添加流量记录失败: %w", err)
}
return nil
}
// CleanOldTrafficRecords 清理旧的流量记录
func (d *Database) CleanOldTrafficRecords(retentionDays int) error {
d.mu.Lock()
defer d.mu.Unlock()
query := `DELETE FROM traffic_records WHERE recorded_at < datetime('now', '-' || ? || ' days')`
result, err := d.db.Exec(query, retentionDays)
if err != nil {
return fmt.Errorf("清理旧流量记录失败: %w", err)
}
rows, _ := result.RowsAffected()
if rows > 0 {
fmt.Printf("已清理 %d 条旧流量记录(保留 %d 天)\n", rows, retentionDays)
}
return nil
}
// GetTrafficRecords 获取流量记录
func (d *Database) GetTrafficRecords(port int, limit int) ([]*TrafficRecord, error) {
d.mu.RLock()
defer d.mu.RUnlock()
var query string
var rows *sql.Rows
var err error
if port == -1 {
// 获取所有记录
query = `SELECT id, port, is_tunnel, bytes_sent, bytes_received, recorded_at
FROM traffic_records
ORDER BY recorded_at DESC
LIMIT ?`
rows, err = d.db.Query(query, limit)
} else {
// 获取指定端口的记录
query = `SELECT id, port, is_tunnel, bytes_sent, bytes_received, recorded_at
FROM traffic_records
WHERE port = ?
ORDER BY recorded_at DESC
LIMIT ?`
rows, err = d.db.Query(query, port, limit)
}
if err != nil {
return nil, fmt.Errorf("查询流量记录失败: %w", err)
}
defer rows.Close()
var records []*TrafficRecord
for rows.Next() {
var record TrafficRecord
if err := rows.Scan(
&record.ID,
&record.Port,
&record.IsTunnel,
&record.BytesSent,
&record.BytesReceived,
&record.RecordedAt,
); err != nil {
return nil, fmt.Errorf("扫描流量记录失败: %w", err)
}
records = append(records, &record)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("遍历流量记录失败: %w", err)
}
return records, nil
}

View File

@ -187,11 +187,28 @@ func (rlr *rateLimitedReader) Read(p []byte) (int, error) {
return n, err
}
// countingWriter 带统计的 Writer
type countingWriter struct {
w io.Writer
counter *uint64
port int
}
func (cw *countingWriter) Write(p []byte) (int, error) {
n, err := cw.w.Write(p)
if n > 0 {
atomic.AddUint64(cw.counter, uint64(n))
}
return n, err
}
// handleConnection 处理单个连接
func (f *Forwarder) handleConnection(clientConn net.Conn) {
defer f.wg.Done()
defer clientConn.Close()
log.Printf("端口 %d 收到新连接: %s", f.sourcePort, clientConn.RemoteAddr())
var targetConn net.Conn
var err error
@ -238,8 +255,13 @@ func (f *Forwarder) handleConnection(clientConn net.Conn) {
limiter: f.limiterOut,
ctx: f.ctx,
}
n, _ := io.Copy(targetConn, reader)
atomic.AddUint64(&f.bytesSent, uint64(n))
writer := &countingWriter{
w: targetConn,
counter: &f.bytesSent,
port: f.sourcePort,
}
n, _ := io.Copy(writer, reader)
log.Printf("端口 %d: 客户端->目标传输完成,本次发送 %d 字节 (总发送: %d)", f.sourcePort, n, atomic.LoadUint64(&f.bytesSent))
// 关闭目标连接的写入端,通知对方不会再发送数据
if tcpConn, ok := targetConn.(*net.TCPConn); ok {
tcpConn.CloseWrite()
@ -254,8 +276,13 @@ func (f *Forwarder) handleConnection(clientConn net.Conn) {
limiter: f.limiterIn,
ctx: f.ctx,
}
n, _ := io.Copy(clientConn, reader)
atomic.AddUint64(&f.bytesReceived, uint64(n))
writer := &countingWriter{
w: clientConn,
counter: &f.bytesReceived,
port: f.sourcePort,
}
n, _ := io.Copy(writer, reader)
log.Printf("端口 %d: 目标->客户端传输完成,本次接收 %d 字节 (总接收: %d)", f.sourcePort, n, atomic.LoadUint64(&f.bytesReceived))
// 关闭客户端连接的写入端
if tcpConn, ok := clientConn.(*net.TCPConn); ok {
tcpConn.CloseWrite()

View File

@ -12,6 +12,7 @@ import (
"port-forward/server/config"
"port-forward/server/db"
"port-forward/server/forwarder"
"port-forward/server/recorder"
"port-forward/server/tunnel"
"port-forward/server/utils"
"strings"
@ -20,13 +21,14 @@ import (
// serverService 服务实例
type serverService struct {
configPath string
cfg *config.Config
database *db.Database
fwdManager *forwarder.Manager
tunnelServer *tunnel.Server
apiHandler *api.Handler
sigChan chan os.Signal
configPath string
cfg *config.Config
database *db.Database
fwdManager *forwarder.Manager
tunnelServer *tunnel.Server
apiHandler *api.Handler
statsRecorder *recorder.Recorder
sigChan chan os.Signal
}
func (s *serverService) Start() error {
@ -105,12 +107,22 @@ func (s *serverService) Start() error {
}
}()
// 启动流量统计记录器(如果启用)
if cfg.Stats.Enabled {
log.Println("启动流量统计记录器...")
s.statsRecorder = recorder.New(database, s.fwdManager, s.tunnelServer, cfg.Stats.RecordInterval, cfg.Stats.RetentionDays)
s.statsRecorder.Start()
}
log.Println("===========================================")
log.Printf("服务器启动成功!")
log.Printf("HTTP API: http://localhost:%d", cfg.API.ListenPort)
if cfg.Tunnel.Enabled {
log.Printf("隧道服务: 端口 %d", cfg.Tunnel.ListenPort)
}
if cfg.Stats.Enabled {
log.Printf("流量统计: 每 %d 秒记录一次,保留 %d 天", cfg.Stats.RecordInterval, cfg.Stats.RetentionDays)
}
log.Println("===========================================")
// 等待中断信号
@ -124,6 +136,11 @@ func (s *serverService) Start() error {
func (s *serverService) Stop() error {
log.Println("接收到关闭信号,正在优雅关闭...")
// 停止流量统计记录器
if s.statsRecorder != nil {
s.statsRecorder.Stop()
}
// 停止所有转发器
if s.fwdManager != nil {
log.Println("停止所有端口转发...")

View File

@ -0,0 +1,131 @@
package recorder
import (
"context"
"log"
"port-forward/server/db"
"port-forward/server/forwarder"
"port-forward/server/tunnel"
"sync"
"time"
)
// Recorder 流量统计记录器
type Recorder struct {
db *db.Database
forwarderMgr *forwarder.Manager
tunnelServer *tunnel.Server
interval time.Duration
retentionDay int
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
}
// New 创建新的记录器
func New(database *db.Database, fwdMgr *forwarder.Manager, ts *tunnel.Server, recordInterval, retentionDays int) *Recorder {
ctx, cancel := context.WithCancel(context.Background())
return &Recorder{
db: database,
forwarderMgr: fwdMgr,
tunnelServer: ts,
interval: time.Duration(recordInterval) * time.Second,
retentionDay: retentionDays,
ctx: ctx,
cancel: cancel,
}
}
// Start 启动记录器
func (r *Recorder) Start() {
log.Printf("流量统计记录器启动: 记录间隔=%v, 数据保留=%d天", r.interval, r.retentionDay)
// 启动定时记录任务
r.wg.Add(1)
go r.recordLoop()
// 启动定时清理任务(每小时执行一次)
r.wg.Add(1)
go r.cleanupLoop()
}
// Stop 停止记录器
func (r *Recorder) Stop() {
log.Println("正在停止流量统计记录器...")
r.cancel()
r.wg.Wait()
log.Println("流量统计记录器已停止")
}
// recordLoop 定时记录循环
func (r *Recorder) recordLoop() {
defer r.wg.Done()
ticker := time.NewTicker(r.interval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
r.recordStats()
case <-r.ctx.Done():
return
}
}
}
// cleanupLoop 定时清理循环
func (r *Recorder) cleanupLoop() {
defer r.wg.Done()
// 每小时执行一次清理
ticker := time.NewTicker(1 * time.Hour)
defer ticker.Stop()
// 启动时先执行一次清理
r.cleanup()
for {
select {
case <-ticker.C:
r.cleanup()
case <-r.ctx.Done():
return
}
}
}
// recordStats 记录当前统计信息
func (r *Recorder) recordStats() {
// 记录隧道流量is_tunnel=true, port可以是0或任意标识
if r.tunnelServer != nil {
tunnelStats := r.tunnelServer.GetTrafficStats()
if err := r.db.AddTrafficRecord(0, true, tunnelStats.BytesSent, tunnelStats.BytesReceived); err != nil {
log.Printf("记录隧道流量失败: %v", err)
} else {
log.Printf("记录隧道流量: 发送=%d, 接收=%d", tunnelStats.BytesSent, tunnelStats.BytesReceived)
}
}
// 记录各端口映射的流量is_tunnel=false
forwarderStats := r.forwarderMgr.GetAllTrafficStats()
recordCount := 0
for port, stats := range forwarderStats {
if err := r.db.AddTrafficRecord(port, false, stats.BytesSent, stats.BytesReceived); err != nil {
log.Printf("记录端口 %d 流量失败: %v", port, err)
} else {
recordCount++
}
}
if recordCount > 0 {
log.Printf("已记录 %d 个端口的流量统计", recordCount)
}
}
// cleanup 清理旧数据
func (r *Recorder) cleanup() {
if err := r.db.CleanOldTrafficRecords(r.retentionDay); err != nil {
log.Printf("清理旧流量记录失败: %v", err)
}
}