feat: 支持了记录统计信息到数据库
This commit is contained in:
parent
fc1614f7a4
commit
0f47e943f4
|
|
@ -17,4 +17,10 @@ api:
|
|||
|
||||
# 数据库配置
|
||||
database:
|
||||
path: "./data/mappings.db"
|
||||
path: "./data/mappings.db"
|
||||
|
||||
# 统计信息记录配置
|
||||
stats:
|
||||
enabled: true # 是否启用统计记录
|
||||
record_interval: 60 # 记录间隔(秒),默认60秒记录一次
|
||||
retention_days: 7 # 数据保留天数,默认保留7天
|
||||
|
|
@ -11,6 +11,7 @@ import (
|
|||
"port-forward/server/stats"
|
||||
"port-forward/server/tunnel"
|
||||
"port-forward/server/utils"
|
||||
"sort"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
|
@ -64,6 +65,7 @@ func (h *Handler) RegisterRoutes(mux *http.ServeMux) {
|
|||
mux.HandleFunc("/api/mapping/remove", h.authMiddleware(h.handleRemoveMapping))
|
||||
mux.HandleFunc("/api/mapping/list", h.authMiddleware(h.handleListMappings))
|
||||
mux.HandleFunc("/api/stats/traffic", h.authMiddleware(h.handleGetTrafficStats))
|
||||
mux.HandleFunc("/api/stats/history", h.authMiddleware(h.handleGetTrafficHistory))
|
||||
mux.HandleFunc("/api/stats/monitor", h.authMiddleware(h.handleTrafficMonitor))
|
||||
mux.HandleFunc("/admin", h.handleManagement)
|
||||
mux.HandleFunc("/health", h.handleHealth)
|
||||
|
|
@ -359,6 +361,12 @@ func (h *Handler) handleGetTrafficStats(w http.ResponseWriter, r *http.Request)
|
|||
totalSent += tunnelStats.BytesSent
|
||||
totalReceived += tunnelStats.BytesReceived
|
||||
|
||||
// mappings 根据端口号排序
|
||||
sort.Slice(mappings, func(i, j int) bool {
|
||||
return mappings[i].Port < mappings[j].Port
|
||||
})
|
||||
|
||||
// 构建最终响应
|
||||
response := stats.AllTrafficStats{
|
||||
Tunnel: tunnelStats,
|
||||
Mappings: mappings,
|
||||
|
|
@ -381,3 +389,50 @@ func (h *Handler) handleManagement(w http.ResponseWriter, r *http.Request) {
|
|||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
fmt.Fprint(w, GetManagementHTML())
|
||||
}
|
||||
|
||||
// handleGetTrafficHistory 获取历史流量统计
|
||||
func (h *Handler) handleGetTrafficHistory(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
h.writeError(w, http.StatusMethodNotAllowed, "只支持 GET 方法")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取查询参数
|
||||
portStr := r.URL.Query().Get("port")
|
||||
limitStr := r.URL.Query().Get("limit")
|
||||
|
||||
port := -1 // -1 表示所有端口
|
||||
if portStr != "" {
|
||||
var err error
|
||||
port, err = strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
h.writeError(w, http.StatusBadRequest, "无效的端口号")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
limit := 100 // 默认返回最近100条
|
||||
if limitStr != "" {
|
||||
var err error
|
||||
limit, err = strconv.Atoi(limitStr)
|
||||
if err != nil {
|
||||
h.writeError(w, http.StatusBadRequest, "无效的limit参数")
|
||||
return
|
||||
}
|
||||
if limit > 1000 {
|
||||
limit = 1000 // 最多返回1000条
|
||||
}
|
||||
}
|
||||
|
||||
// 查询历史记录
|
||||
records, err := h.db.GetTrafficRecords(port, limit)
|
||||
if err != nil {
|
||||
h.writeError(w, http.StatusInternalServerError, "查询历史记录失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.writeSuccess(w, "获取历史流量统计成功", map[string]interface{}{
|
||||
"records": records,
|
||||
"count": len(records),
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ type Config struct {
|
|||
Tunnel TunnelConfig `yaml:"tunnel"`
|
||||
API APIConfig `yaml:"api"`
|
||||
Database DatabaseConfig `yaml:"database"`
|
||||
Stats StatsConfig `yaml:"stats"`
|
||||
}
|
||||
|
||||
// PortRangeConfig 端口范围配置
|
||||
|
|
@ -38,6 +39,13 @@ type DatabaseConfig struct {
|
|||
Path string `yaml:"path"`
|
||||
}
|
||||
|
||||
// StatsConfig 统计记录配置
|
||||
type StatsConfig struct {
|
||||
Enabled bool `yaml:"enabled"` // 是否启用统计记录
|
||||
RecordInterval int `yaml:"record_interval"` // 记录间隔(秒)
|
||||
RetentionDays int `yaml:"retention_days"` // 数据保留天数
|
||||
}
|
||||
|
||||
// Load 从文件加载配置
|
||||
func Load(path string) (*Config, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
|
|
|
|||
|
|
@ -21,6 +21,16 @@ type Mapping struct {
|
|||
CreatedAt string `json:"created_at"`
|
||||
}
|
||||
|
||||
// TrafficRecord 流量统计记录
|
||||
type TrafficRecord struct {
|
||||
ID int64 `json:"id"`
|
||||
Port int `json:"port"` // 端口号
|
||||
IsTunnel bool `json:"is_tunnel"` // 是否为隧道整体流量
|
||||
BytesSent uint64 `json:"bytes_sent"` // 发送字节数
|
||||
BytesReceived uint64 `json:"bytes_received"` // 接收字节数
|
||||
RecordedAt string `json:"recorded_at"` // 记录时间
|
||||
}
|
||||
|
||||
// Database 数据库管理器
|
||||
type Database struct {
|
||||
db *sql.DB
|
||||
|
|
@ -69,6 +79,18 @@ func (d *Database) initTables() error {
|
|||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_source_port ON mappings(source_port);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS traffic_records (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
port INTEGER NOT NULL,
|
||||
is_tunnel BOOLEAN NOT NULL DEFAULT 0,
|
||||
bytes_sent INTEGER NOT NULL DEFAULT 0,
|
||||
bytes_received INTEGER NOT NULL DEFAULT 0,
|
||||
recorded_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_port ON traffic_records(port);
|
||||
CREATE INDEX IF NOT EXISTS idx_is_tunnel ON traffic_records(is_tunnel);
|
||||
CREATE INDEX IF NOT EXISTS idx_recorded_at ON traffic_records(recorded_at);
|
||||
`
|
||||
|
||||
_, err := d.db.Exec(query)
|
||||
|
|
@ -176,6 +198,37 @@ func (d *Database) migrateDatabase() error {
|
|||
}
|
||||
}
|
||||
|
||||
// 检查 traffic_records 表的 is_tunnel 列
|
||||
rows3, err := d.db.Query("PRAGMA table_info(traffic_records)")
|
||||
if err == nil {
|
||||
defer rows3.Close()
|
||||
hasIsTunnel := false
|
||||
for rows3.Next() {
|
||||
var cid int
|
||||
var name, dataType string
|
||||
var notNull, hasDefault int
|
||||
var defaultValue interface{}
|
||||
|
||||
err := rows3.Scan(&cid, &name, &dataType, ¬Null, &defaultValue, &hasDefault)
|
||||
if err != nil {
|
||||
return fmt.Errorf("扫描 traffic_records 表结构失败: %w", err)
|
||||
}
|
||||
|
||||
if name == "is_tunnel" {
|
||||
hasIsTunnel = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 如果不存在 is_tunnel 列,则添加它
|
||||
if !hasIsTunnel {
|
||||
_, err := d.db.Exec("ALTER TABLE traffic_records ADD COLUMN is_tunnel BOOLEAN NOT NULL DEFAULT 0")
|
||||
if err != nil {
|
||||
return fmt.Errorf("添加 is_tunnel 列失败: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -285,3 +338,90 @@ func (d *Database) GetAllMappings() ([]*Mapping, error) {
|
|||
func (d *Database) Close() error {
|
||||
return d.db.Close()
|
||||
}
|
||||
|
||||
// AddTrafficRecord 添加流量统计记录
|
||||
func (d *Database) AddTrafficRecord(port int, isTunnel bool, bytesSent, bytesReceived uint64) error {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
query := `INSERT INTO traffic_records (port, is_tunnel, bytes_sent, bytes_received) VALUES (?, ?, ?, ?)`
|
||||
_, err := d.db.Exec(query, port, isTunnel, bytesSent, bytesReceived)
|
||||
if err != nil {
|
||||
return fmt.Errorf("添加流量记录失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CleanOldTrafficRecords 清理旧的流量记录
|
||||
func (d *Database) CleanOldTrafficRecords(retentionDays int) error {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
query := `DELETE FROM traffic_records WHERE recorded_at < datetime('now', '-' || ? || ' days')`
|
||||
result, err := d.db.Exec(query, retentionDays)
|
||||
if err != nil {
|
||||
return fmt.Errorf("清理旧流量记录失败: %w", err)
|
||||
}
|
||||
|
||||
rows, _ := result.RowsAffected()
|
||||
if rows > 0 {
|
||||
fmt.Printf("已清理 %d 条旧流量记录(保留 %d 天)\n", rows, retentionDays)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetTrafficRecords 获取流量记录
|
||||
func (d *Database) GetTrafficRecords(port int, limit int) ([]*TrafficRecord, error) {
|
||||
d.mu.RLock()
|
||||
defer d.mu.RUnlock()
|
||||
|
||||
var query string
|
||||
var rows *sql.Rows
|
||||
var err error
|
||||
|
||||
if port == -1 {
|
||||
// 获取所有记录
|
||||
query = `SELECT id, port, is_tunnel, bytes_sent, bytes_received, recorded_at
|
||||
FROM traffic_records
|
||||
ORDER BY recorded_at DESC
|
||||
LIMIT ?`
|
||||
rows, err = d.db.Query(query, limit)
|
||||
} else {
|
||||
// 获取指定端口的记录
|
||||
query = `SELECT id, port, is_tunnel, bytes_sent, bytes_received, recorded_at
|
||||
FROM traffic_records
|
||||
WHERE port = ?
|
||||
ORDER BY recorded_at DESC
|
||||
LIMIT ?`
|
||||
rows, err = d.db.Query(query, port, limit)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询流量记录失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var records []*TrafficRecord
|
||||
for rows.Next() {
|
||||
var record TrafficRecord
|
||||
if err := rows.Scan(
|
||||
&record.ID,
|
||||
&record.Port,
|
||||
&record.IsTunnel,
|
||||
&record.BytesSent,
|
||||
&record.BytesReceived,
|
||||
&record.RecordedAt,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("扫描流量记录失败: %w", err)
|
||||
}
|
||||
records = append(records, &record)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("遍历流量记录失败: %w", err)
|
||||
}
|
||||
|
||||
return records, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -187,11 +187,28 @@ func (rlr *rateLimitedReader) Read(p []byte) (int, error) {
|
|||
return n, err
|
||||
}
|
||||
|
||||
// countingWriter 带统计的 Writer
|
||||
type countingWriter struct {
|
||||
w io.Writer
|
||||
counter *uint64
|
||||
port int
|
||||
}
|
||||
|
||||
func (cw *countingWriter) Write(p []byte) (int, error) {
|
||||
n, err := cw.w.Write(p)
|
||||
if n > 0 {
|
||||
atomic.AddUint64(cw.counter, uint64(n))
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
// handleConnection 处理单个连接
|
||||
func (f *Forwarder) handleConnection(clientConn net.Conn) {
|
||||
defer f.wg.Done()
|
||||
defer clientConn.Close()
|
||||
|
||||
log.Printf("端口 %d 收到新连接: %s", f.sourcePort, clientConn.RemoteAddr())
|
||||
|
||||
var targetConn net.Conn
|
||||
var err error
|
||||
|
||||
|
|
@ -238,8 +255,13 @@ func (f *Forwarder) handleConnection(clientConn net.Conn) {
|
|||
limiter: f.limiterOut,
|
||||
ctx: f.ctx,
|
||||
}
|
||||
n, _ := io.Copy(targetConn, reader)
|
||||
atomic.AddUint64(&f.bytesSent, uint64(n))
|
||||
writer := &countingWriter{
|
||||
w: targetConn,
|
||||
counter: &f.bytesSent,
|
||||
port: f.sourcePort,
|
||||
}
|
||||
n, _ := io.Copy(writer, reader)
|
||||
log.Printf("端口 %d: 客户端->目标传输完成,本次发送 %d 字节 (总发送: %d)", f.sourcePort, n, atomic.LoadUint64(&f.bytesSent))
|
||||
// 关闭目标连接的写入端,通知对方不会再发送数据
|
||||
if tcpConn, ok := targetConn.(*net.TCPConn); ok {
|
||||
tcpConn.CloseWrite()
|
||||
|
|
@ -254,8 +276,13 @@ func (f *Forwarder) handleConnection(clientConn net.Conn) {
|
|||
limiter: f.limiterIn,
|
||||
ctx: f.ctx,
|
||||
}
|
||||
n, _ := io.Copy(clientConn, reader)
|
||||
atomic.AddUint64(&f.bytesReceived, uint64(n))
|
||||
writer := &countingWriter{
|
||||
w: clientConn,
|
||||
counter: &f.bytesReceived,
|
||||
port: f.sourcePort,
|
||||
}
|
||||
n, _ := io.Copy(writer, reader)
|
||||
log.Printf("端口 %d: 目标->客户端传输完成,本次接收 %d 字节 (总接收: %d)", f.sourcePort, n, atomic.LoadUint64(&f.bytesReceived))
|
||||
// 关闭客户端连接的写入端
|
||||
if tcpConn, ok := clientConn.(*net.TCPConn); ok {
|
||||
tcpConn.CloseWrite()
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ import (
|
|||
"port-forward/server/config"
|
||||
"port-forward/server/db"
|
||||
"port-forward/server/forwarder"
|
||||
"port-forward/server/recorder"
|
||||
"port-forward/server/tunnel"
|
||||
"port-forward/server/utils"
|
||||
"strings"
|
||||
|
|
@ -20,13 +21,14 @@ import (
|
|||
|
||||
// serverService 服务实例
|
||||
type serverService struct {
|
||||
configPath string
|
||||
cfg *config.Config
|
||||
database *db.Database
|
||||
fwdManager *forwarder.Manager
|
||||
tunnelServer *tunnel.Server
|
||||
apiHandler *api.Handler
|
||||
sigChan chan os.Signal
|
||||
configPath string
|
||||
cfg *config.Config
|
||||
database *db.Database
|
||||
fwdManager *forwarder.Manager
|
||||
tunnelServer *tunnel.Server
|
||||
apiHandler *api.Handler
|
||||
statsRecorder *recorder.Recorder
|
||||
sigChan chan os.Signal
|
||||
}
|
||||
|
||||
func (s *serverService) Start() error {
|
||||
|
|
@ -105,12 +107,22 @@ func (s *serverService) Start() error {
|
|||
}
|
||||
}()
|
||||
|
||||
// 启动流量统计记录器(如果启用)
|
||||
if cfg.Stats.Enabled {
|
||||
log.Println("启动流量统计记录器...")
|
||||
s.statsRecorder = recorder.New(database, s.fwdManager, s.tunnelServer, cfg.Stats.RecordInterval, cfg.Stats.RetentionDays)
|
||||
s.statsRecorder.Start()
|
||||
}
|
||||
|
||||
log.Println("===========================================")
|
||||
log.Printf("服务器启动成功!")
|
||||
log.Printf("HTTP API: http://localhost:%d", cfg.API.ListenPort)
|
||||
if cfg.Tunnel.Enabled {
|
||||
log.Printf("隧道服务: 端口 %d", cfg.Tunnel.ListenPort)
|
||||
}
|
||||
if cfg.Stats.Enabled {
|
||||
log.Printf("流量统计: 每 %d 秒记录一次,保留 %d 天", cfg.Stats.RecordInterval, cfg.Stats.RetentionDays)
|
||||
}
|
||||
log.Println("===========================================")
|
||||
|
||||
// 等待中断信号
|
||||
|
|
@ -124,6 +136,11 @@ func (s *serverService) Start() error {
|
|||
func (s *serverService) Stop() error {
|
||||
log.Println("接收到关闭信号,正在优雅关闭...")
|
||||
|
||||
// 停止流量统计记录器
|
||||
if s.statsRecorder != nil {
|
||||
s.statsRecorder.Stop()
|
||||
}
|
||||
|
||||
// 停止所有转发器
|
||||
if s.fwdManager != nil {
|
||||
log.Println("停止所有端口转发...")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,131 @@
|
|||
package recorder
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"port-forward/server/db"
|
||||
"port-forward/server/forwarder"
|
||||
"port-forward/server/tunnel"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Recorder 流量统计记录器
|
||||
type Recorder struct {
|
||||
db *db.Database
|
||||
forwarderMgr *forwarder.Manager
|
||||
tunnelServer *tunnel.Server
|
||||
interval time.Duration
|
||||
retentionDay int
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// New 创建新的记录器
|
||||
func New(database *db.Database, fwdMgr *forwarder.Manager, ts *tunnel.Server, recordInterval, retentionDays int) *Recorder {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &Recorder{
|
||||
db: database,
|
||||
forwarderMgr: fwdMgr,
|
||||
tunnelServer: ts,
|
||||
interval: time.Duration(recordInterval) * time.Second,
|
||||
retentionDay: retentionDays,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动记录器
|
||||
func (r *Recorder) Start() {
|
||||
log.Printf("流量统计记录器启动: 记录间隔=%v, 数据保留=%d天", r.interval, r.retentionDay)
|
||||
|
||||
// 启动定时记录任务
|
||||
r.wg.Add(1)
|
||||
go r.recordLoop()
|
||||
|
||||
// 启动定时清理任务(每小时执行一次)
|
||||
r.wg.Add(1)
|
||||
go r.cleanupLoop()
|
||||
}
|
||||
|
||||
// Stop 停止记录器
|
||||
func (r *Recorder) Stop() {
|
||||
log.Println("正在停止流量统计记录器...")
|
||||
r.cancel()
|
||||
r.wg.Wait()
|
||||
log.Println("流量统计记录器已停止")
|
||||
}
|
||||
|
||||
// recordLoop 定时记录循环
|
||||
func (r *Recorder) recordLoop() {
|
||||
defer r.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(r.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
r.recordStats()
|
||||
case <-r.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupLoop 定时清理循环
|
||||
func (r *Recorder) cleanupLoop() {
|
||||
defer r.wg.Done()
|
||||
|
||||
// 每小时执行一次清理
|
||||
ticker := time.NewTicker(1 * time.Hour)
|
||||
defer ticker.Stop()
|
||||
|
||||
// 启动时先执行一次清理
|
||||
r.cleanup()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
r.cleanup()
|
||||
case <-r.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// recordStats 记录当前统计信息
|
||||
func (r *Recorder) recordStats() {
|
||||
// 记录隧道流量(is_tunnel=true, port可以是0或任意标识)
|
||||
if r.tunnelServer != nil {
|
||||
tunnelStats := r.tunnelServer.GetTrafficStats()
|
||||
if err := r.db.AddTrafficRecord(0, true, tunnelStats.BytesSent, tunnelStats.BytesReceived); err != nil {
|
||||
log.Printf("记录隧道流量失败: %v", err)
|
||||
} else {
|
||||
log.Printf("记录隧道流量: 发送=%d, 接收=%d", tunnelStats.BytesSent, tunnelStats.BytesReceived)
|
||||
}
|
||||
}
|
||||
|
||||
// 记录各端口映射的流量(is_tunnel=false)
|
||||
forwarderStats := r.forwarderMgr.GetAllTrafficStats()
|
||||
recordCount := 0
|
||||
for port, stats := range forwarderStats {
|
||||
if err := r.db.AddTrafficRecord(port, false, stats.BytesSent, stats.BytesReceived); err != nil {
|
||||
log.Printf("记录端口 %d 流量失败: %v", port, err)
|
||||
} else {
|
||||
recordCount++
|
||||
}
|
||||
}
|
||||
|
||||
if recordCount > 0 {
|
||||
log.Printf("已记录 %d 个端口的流量统计", recordCount)
|
||||
}
|
||||
}
|
||||
|
||||
// cleanup 清理旧数据
|
||||
func (r *Recorder) cleanup() {
|
||||
if err := r.db.CleanOldTrafficRecords(r.retentionDay); err != nil {
|
||||
log.Printf("清理旧流量记录失败: %v", err)
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue