diff --git a/src/go.mod b/src/go.mod index 03f99d3..9d154fb 100644 --- a/src/go.mod +++ b/src/go.mod @@ -6,7 +6,8 @@ toolchain go1.24.4 require ( github.com/mattn/go-sqlite3 v1.14.22 + golang.org/x/time v0.14.0 gopkg.in/yaml.v3 v3.0.1 ) -require golang.org/x/sys v0.37.0 // indirect +require golang.org/x/sys v0.37.0 diff --git a/src/go.sum b/src/go.sum index 19b802d..a690b8e 100644 --- a/src/go.sum +++ b/src/go.sum @@ -2,6 +2,8 @@ github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ= golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= +golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/src/server/api/api.go b/src/server/api/api.go index 38c5cb0..eec44cf 100644 --- a/src/server/api/api.go +++ b/src/server/api/api.go @@ -39,10 +39,11 @@ func NewHandler(database *db.Database, fwdMgr *forwarder.Manager, ts *tunnel.Ser // CreateMappingRequest 创建映射请求 type CreateMappingRequest struct { - SourcePort int `json:"source_port"` // 源端口(本地监听端口) - TargetPort int `json:"target_port"` // 目标端口(远程服务端口) - TargetHost string `json:"target_host"` // 目标主机(支持IP或域名) - UseTunnel bool `json:"use_tunnel"` // 是否使用隧道模式 + SourcePort int `json:"source_port"` // 源端口(本地监听端口) + TargetPort int `json:"target_port"` // 目标端口(远程服务端口) + TargetHost string `json:"target_host"` // 目标主机(支持IP或域名) + UseTunnel bool `json:"use_tunnel"` // 是否使用隧道模式 + BandwidthLimit *int64 `json:"bandwidth_limit,omitempty"` // 带宽限制,字节/秒,可为空 } // RemoveMappingRequest 删除映射请求 @@ -168,8 +169,14 @@ func (h *Handler) handleCreateMapping(w http.ResponseWriter, r *http.Request) { } } + //BandwidthLimit 合理范围不小于0 + if req.BandwidthLimit != nil && *req.BandwidthLimit < 0 { + h.writeError(w, http.StatusBadRequest, "带宽限制必须大于等于0") + return + } + // 添加到数据库 - if err := h.db.AddMapping(req.SourcePort, req.TargetHost, req.TargetPort, req.UseTunnel); err != nil { + if err := h.db.AddMapping(req.SourcePort, req.TargetHost, req.TargetPort, req.UseTunnel, req.BandwidthLimit); err != nil { h.writeError(w, http.StatusInternalServerError, "保存映射失败: "+err.Error()) return } @@ -178,10 +185,10 @@ func (h *Handler) handleCreateMapping(w http.ResponseWriter, r *http.Request) { var err error if req.UseTunnel { // 隧道模式:使用隧道转发 - err = h.forwarderMgr.AddTunnel(req.SourcePort, req.TargetHost, req.TargetPort, h.tunnelServer) + err = h.forwarderMgr.AddTunnel(req.SourcePort, req.TargetHost, req.TargetPort, h.tunnelServer, req.BandwidthLimit) } else { // 直接模式:直接TCP转发 - err = h.forwarderMgr.Add(req.SourcePort, req.TargetHost, req.TargetPort) + err = h.forwarderMgr.Add(req.SourcePort, req.TargetHost, req.TargetPort, req.BandwidthLimit) } if err != nil { diff --git a/src/server/api/api_test.go b/src/server/api/api_test.go index 23f1c97..c6716f7 100644 --- a/src/server/api/api_test.go +++ b/src/server/api/api_test.go @@ -341,8 +341,8 @@ func TestHandleRemoveMapping(t *testing.T) { defer cleanup() // 先创建一个映射 - database.AddMapping(15000, "192.168.1.100", 15000, false) - handler.forwarderMgr.Add(15000, "192.168.1.100", 15000) + database.AddMapping(15000, "192.168.1.100", 15000, false, nil) + handler.forwarderMgr.Add(15000, "192.168.1.100", 15000, nil) reqBody := RemoveMappingRequest{ Port: 15000, @@ -391,9 +391,9 @@ func TestHandleListMappings(t *testing.T) { defer cleanup() // 添加一些映射 - database.AddMapping(15000, "192.168.1.100", 15000, false) - database.AddMapping(15001, "192.168.1.101", 15001, true) - database.AddMapping(15002, "192.168.1.102", 15002, false) + database.AddMapping(15000, "192.168.1.100", 15000, false, nil) + database.AddMapping(15001, "192.168.1.101", 15001, true, nil) + database.AddMapping(15002, "192.168.1.102", 15002, false, nil) req := httptest.NewRequest(http.MethodGet, "/api/mapping/list", nil) addAuthHeader(req) @@ -582,7 +582,7 @@ func BenchmarkHandleListMappings(b *testing.B) { // 添加一些映射 for i := 0; i < 100; i++ { useTunnel := i%2 == 0 // 偶数使用隧道模式 - database.AddMapping(10000+i, "192.168.1.1", 10000+i, useTunnel) + database.AddMapping(10000+i, "192.168.1.1", 10000+i, useTunnel, nil) } fwdMgr := forwarder.NewManager() diff --git a/src/server/db/database.go b/src/server/db/database.go index 62bcbd5..ab9d2d9 100644 --- a/src/server/db/database.go +++ b/src/server/db/database.go @@ -12,12 +12,13 @@ import ( // Mapping 端口映射结构 type Mapping struct { - ID int64 `json:"id"` - SourcePort int `json:"source_port"` - TargetHost string `json:"target_host"` // 支持IP或域名 - TargetPort int `json:"target_port"` - UseTunnel bool `json:"use_tunnel"` - CreatedAt string `json:"created_at"` + ID int64 `json:"id"` + SourcePort int `json:"source_port"` + TargetHost string `json:"target_host"` // 支持IP或域名 + TargetPort int `json:"target_port"` + UseTunnel bool `json:"use_tunnel"` + BandwidthLimit *int64 `json:"bandwidth_limit,omitempty"` // 带宽限制,字节/秒,可为空 + CreatedAt string `json:"created_at"` } // Database 数据库管理器 @@ -64,21 +65,22 @@ func (d *Database) initTables() error { target_host TEXT NOT NULL, target_port INTEGER NOT NULL, use_tunnel BOOLEAN NOT NULL DEFAULT 0, + bandwidth_limit INTEGER, created_at DATETIME DEFAULT CURRENT_TIMESTAMP ); CREATE INDEX IF NOT EXISTS idx_source_port ON mappings(source_port); ` - + _, err := d.db.Exec(query) if err != nil { return fmt.Errorf("初始化数据库表失败: %w", err) } - + // 检查是否需要迁移现有数据 if err := d.migrateDatabase(); err != nil { return fmt.Errorf("数据库迁移失败: %w", err) } - + return nil } @@ -93,23 +95,27 @@ func (d *Database) migrateDatabase() error { hasUseTunnel := false hasTargetHost := false + hasBandwidthLimit := false for rows.Next() { var cid int var name, dataType string var notNull, hasDefault int var defaultValue interface{} - + err := rows.Scan(&cid, &name, &dataType, ¬Null, &defaultValue, &hasDefault) if err != nil { return fmt.Errorf("扫描表结构失败: %w", err) } - + if name == "use_tunnel" { hasUseTunnel = true } if name == "target_host" { hasTargetHost = true } + if name == "bandwidth_limit" { + hasBandwidthLimit = true + } } // 如果不存在 use_tunnel 列,则添加它 @@ -135,12 +141,12 @@ func (d *Database) migrateDatabase() error { var name, dataType string var notNull, hasDefault int var defaultValue interface{} - + err := rows2.Scan(&cid, &name, &dataType, ¬Null, &defaultValue, &hasDefault) if err != nil { return fmt.Errorf("扫描表结构失败: %w", err) } - + if name == "target_ip" { hasTargetIP = true break @@ -162,20 +168,28 @@ func (d *Database) migrateDatabase() error { } } + // 如果不存在 bandwidth_limit 列,则添加它 + if !hasBandwidthLimit { + _, err := d.db.Exec("ALTER TABLE mappings ADD COLUMN bandwidth_limit INTEGER") + if err != nil { + return fmt.Errorf("添加 bandwidth_limit 列失败: %w", err) + } + } + return nil } -// AddMapping 添加端口映射 -func (d *Database) AddMapping(sourcePort int, targetHost string, targetPort int, useTunnel bool) error { +// AddMapping 添加带宽限制的端口映射 +func (d *Database) AddMapping(sourcePort int, targetHost string, targetPort int, useTunnel bool, bandwidthLimit *int64) error { d.mu.Lock() defer d.mu.Unlock() - query := `INSERT INTO mappings (source_port, target_host, target_port, use_tunnel) VALUES (?, ?, ?, ?)` - _, err := d.db.Exec(query, sourcePort, targetHost, targetPort, useTunnel) + query := `INSERT INTO mappings (source_port, target_host, target_port, use_tunnel, bandwidth_limit) VALUES (?, ?, ?, ?, ?)` + _, err := d.db.Exec(query, sourcePort, targetHost, targetPort, useTunnel, bandwidthLimit) if err != nil { return fmt.Errorf("添加端口映射失败: %w", err) } - + return nil } @@ -189,16 +203,16 @@ func (d *Database) RemoveMapping(sourcePort int) error { if err != nil { return fmt.Errorf("删除端口映射失败: %w", err) } - + rows, err := result.RowsAffected() if err != nil { return fmt.Errorf("获取影响行数失败: %w", err) } - + if rows == 0 { return fmt.Errorf("端口映射不存在") } - + return nil } @@ -207,8 +221,8 @@ func (d *Database) GetMapping(sourcePort int) (*Mapping, error) { d.mu.RLock() defer d.mu.RUnlock() - query := `SELECT id, source_port, target_host, target_port, use_tunnel, created_at FROM mappings WHERE source_port = ?` - + query := `SELECT id, source_port, target_host, target_port, use_tunnel, bandwidth_limit, created_at FROM mappings WHERE source_port = ?` + var mapping Mapping err := d.db.QueryRow(query, sourcePort).Scan( &mapping.ID, @@ -216,16 +230,17 @@ func (d *Database) GetMapping(sourcePort int) (*Mapping, error) { &mapping.TargetHost, &mapping.TargetPort, &mapping.UseTunnel, + &mapping.BandwidthLimit, &mapping.CreatedAt, ) - + if err == sql.ErrNoRows { return nil, nil } if err != nil { return nil, fmt.Errorf("查询端口映射失败: %w", err) } - + return &mapping, nil } @@ -234,14 +249,14 @@ func (d *Database) GetAllMappings() ([]*Mapping, error) { d.mu.RLock() defer d.mu.RUnlock() - query := `SELECT id, source_port, target_host, target_port, use_tunnel, created_at FROM mappings ORDER BY source_port` - + query := `SELECT id, source_port, target_host, target_port, use_tunnel, bandwidth_limit, created_at FROM mappings ORDER BY source_port` + rows, err := d.db.Query(query) if err != nil { return nil, fmt.Errorf("查询所有映射失败: %w", err) } defer rows.Close() - + var mappings []*Mapping for rows.Next() { var mapping Mapping @@ -251,21 +266,22 @@ func (d *Database) GetAllMappings() ([]*Mapping, error) { &mapping.TargetHost, &mapping.TargetPort, &mapping.UseTunnel, + &mapping.BandwidthLimit, &mapping.CreatedAt, ); err != nil { return nil, fmt.Errorf("扫描映射记录失败: %w", err) } mappings = append(mappings, &mapping) } - + if err := rows.Err(); err != nil { return nil, fmt.Errorf("遍历映射记录失败: %w", err) } - + return mappings, nil } // Close 关闭数据库连接 func (d *Database) Close() error { return d.db.Close() -} \ No newline at end of file +} diff --git a/src/server/db/database_test.go b/src/server/db/database_test.go index 31b36b5..68c65f5 100644 --- a/src/server/db/database_test.go +++ b/src/server/db/database_test.go @@ -18,7 +18,7 @@ func TestDatabase(t *testing.T) { defer db.Close() t.Run("添加映射", func(t *testing.T) { - err := db.AddMapping(10001, "192.168.1.100", 22, false) + err := db.AddMapping(10001, "192.168.1.100", 22, false, nil) if err != nil { t.Errorf("添加映射失败: %v", err) } @@ -44,7 +44,7 @@ func TestDatabase(t *testing.T) { }) t.Run("添加重复映射应该失败", func(t *testing.T) { - err := db.AddMapping(10001, "192.168.1.101", 22, true) + err := db.AddMapping(10001, "192.168.1.101", 22, true, nil) if err == nil { t.Error("添加重复映射应该失败") } @@ -52,8 +52,8 @@ func TestDatabase(t *testing.T) { t.Run("获取所有映射", func(t *testing.T) { // 添加更多映射 - db.AddMapping(10002, "192.168.1.101", 22, true) - db.AddMapping(10003, "192.168.1.102", 22, false) + db.AddMapping(10002, "192.168.1.101", 22, true, nil) + db.AddMapping(10003, "192.168.1.102", 22, false, nil) mappings, err := db.GetAllMappings() if err != nil { @@ -102,7 +102,7 @@ func TestDatabaseConcurrency(t *testing.T) { for i := 0; i < 10; i++ { go func(port int) { useTunnel := port%2 == 0 // 偶数端口使用隧道模式 - err := db.AddMapping(10000+port, "192.168.1.100", port, useTunnel) + err := db.AddMapping(10000+port, "192.168.1.100", port, useTunnel, nil) if err != nil { t.Logf("添加映射失败 (端口 %d): %v", 10000+port, err) } @@ -123,4 +123,4 @@ func TestDatabaseConcurrency(t *testing.T) { if len(mappings) == 0 { t.Error("应该至少有一些映射") } -} \ No newline at end of file +} diff --git a/src/server/forwarder/forwarder.go b/src/server/forwarder/forwarder.go index b11c066..c93d882 100644 --- a/src/server/forwarder/forwarder.go +++ b/src/server/forwarder/forwarder.go @@ -10,6 +10,8 @@ import ( "sync" "sync/atomic" "time" + + "golang.org/x/time/rate" ) // TunnelServer 隧道服务器接口 @@ -30,15 +32,24 @@ type Forwarder struct { wg sync.WaitGroup tunnelServer TunnelServer useTunnel bool - + limit *int64 + // 流量统计(使用原子操作) - bytesSent uint64 // 发送字节数 - bytesReceived uint64 // 接收字节数 + bytesSent uint64 // 发送字节数 + bytesReceived uint64 // 接收字节数 + limiterOut *rate.Limiter // 限速器(出方向) + limiterIn *rate.Limiter // 限速器(入方向) } // NewForwarder 创建新的端口转发器 -func NewForwarder(sourcePort int, targetHost string, targetPort int) *Forwarder { +func NewForwarder(sourcePort int, targetHost string, targetPort int, limit *int64) *Forwarder { ctx, cancel := context.WithCancel(context.Background()) + var limiterOut, limiterIn *rate.Limiter + if limit != nil { + burst := int(*limit) // 容量至少等于速率,不然无法正常突发 + limiterOut = rate.NewLimiter(rate.Limit(*limit), burst) + limiterIn = rate.NewLimiter(rate.Limit(*limit), burst) + } return &Forwarder{ sourcePort: sourcePort, targetPort: targetPort, @@ -46,12 +57,21 @@ func NewForwarder(sourcePort int, targetHost string, targetPort int) *Forwarder cancel: cancel, ctx: ctx, useTunnel: false, + limit: limit, + limiterOut: limiterOut, + limiterIn: limiterIn, } } // NewTunnelForwarder 创建使用隧道的端口转发器 -func NewTunnelForwarder(sourcePort int, targetHost string, targetPort int, tunnelServer TunnelServer) *Forwarder { +func NewTunnelForwarder(sourcePort int, targetHost string, targetPort int, tunnelServer TunnelServer, limit *int64) *Forwarder { ctx, cancel := context.WithCancel(context.Background()) + var limiterOut, limiterIn *rate.Limiter + if limit != nil { + burst := int(*limit) // 容量至少等于速率,不然无法正常突发 + limiterOut = rate.NewLimiter(rate.Limit(*limit), burst) + limiterIn = rate.NewLimiter(rate.Limit(*limit), burst) + } return &Forwarder{ sourcePort: sourcePort, targetPort: targetPort, @@ -60,6 +80,9 @@ func NewTunnelForwarder(sourcePort int, targetHost string, targetPort int, tunne useTunnel: true, cancel: cancel, ctx: ctx, + limit: limit, + limiterOut: limiterOut, + limiterIn: limiterIn, } } @@ -92,7 +115,7 @@ func (f *Forwarder) acceptLoop() { // 设置接受超时,避免阻塞关闭 f.listener.(*net.TCPListener).SetDeadline(time.Now().Add(time.Second)) - + conn, err := f.listener.Accept() if err != nil { if netErr, ok := err.(net.Error); ok && netErr.Timeout() { @@ -112,6 +135,27 @@ func (f *Forwarder) acceptLoop() { } } +type rateLimitedReader struct { + r io.Reader + limiter *rate.Limiter + ctx context.Context +} + +func (rlr *rateLimitedReader) Read(p []byte) (int, error) { + if rlr.limiter != nil { + maxReq := rlr.limiter.Burst() + reqSize := len(p) + if reqSize > maxReq { + reqSize = maxReq // 避免一次申请超过桶容量导致错误 + } + err := rlr.limiter.WaitN(rlr.ctx, reqSize) + if err != nil { + return 0, err + } + } + return rlr.r.Read(p) +} + // handleConnection 处理单个连接 func (f *Forwarder) handleConnection(clientConn net.Conn) { defer f.wg.Done() @@ -138,7 +182,7 @@ func (f *Forwarder) handleConnection(clientConn net.Conn) { Timeout: 10 * time.Second, KeepAlive: 30 * time.Second, } - + // 动态解析域名并连接 targetAddr := fmt.Sprintf("%s:%d", f.targetHost, f.targetPort) targetConn, err := dialer.DialContext(f.ctx, "tcp", targetAddr) @@ -155,7 +199,12 @@ func (f *Forwarder) handleConnection(clientConn net.Conn) { // 客户端 -> 目标 go func() { defer wg.Done() - n, _ := io.Copy(targetConn, clientConn) + reader := &rateLimitedReader{ + r: clientConn, + limiter: f.limiterOut, + ctx: f.ctx, + } + n, _ := io.Copy(targetConn, reader) atomic.AddUint64(&f.bytesSent, uint64(n)) // 关闭目标连接的写入端,通知对方不会再发送数据 if tcpConn, ok := targetConn.(*net.TCPConn); ok { @@ -166,7 +215,12 @@ func (f *Forwarder) handleConnection(clientConn net.Conn) { // 目标 -> 客户端 go func() { defer wg.Done() - n, _ := io.Copy(clientConn, targetConn) + reader := &rateLimitedReader{ + r: targetConn, + limiter: f.limiterIn, + ctx: f.ctx, + } + n, _ := io.Copy(clientConn, reader) atomic.AddUint64(&f.bytesReceived, uint64(n)) // 关闭客户端连接的写入端 if tcpConn, ok := clientConn.(*net.TCPConn); ok { @@ -193,7 +247,7 @@ func (f *Forwarder) handleConnection(clientConn net.Conn) { // Stop 停止端口转发 func (f *Forwarder) Stop() error { f.cancel() - + if f.listener != nil { if err := f.listener.Close(); err != nil { log.Printf("关闭监听器失败 (端口 %d): %v", f.sourcePort, err) @@ -231,7 +285,7 @@ func NewManager() *Manager { } // Add 添加并启动转发器 -func (m *Manager) Add(sourcePort int, targetHost string, targetPort int) error { +func (m *Manager) Add(sourcePort int, targetHost string, targetPort int, limit *int64) error { m.mu.Lock() defer m.mu.Unlock() @@ -239,7 +293,7 @@ func (m *Manager) Add(sourcePort int, targetHost string, targetPort int) error { return fmt.Errorf("端口 %d 已被占用", sourcePort) } - forwarder := NewForwarder(sourcePort, targetHost, targetPort) + forwarder := NewForwarder(sourcePort, targetHost, targetPort, limit) if err := forwarder.Start(); err != nil { return err } @@ -249,7 +303,7 @@ func (m *Manager) Add(sourcePort int, targetHost string, targetPort int) error { } // AddTunnel 添加使用隧道的转发器 -func (m *Manager) AddTunnel(sourcePort int, targetHost string, targetPort int, tunnelServer TunnelServer) error { +func (m *Manager) AddTunnel(sourcePort int, targetHost string, targetPort int, tunnelServer TunnelServer, limit *int64) error { m.mu.Lock() defer m.mu.Unlock() @@ -257,7 +311,7 @@ func (m *Manager) AddTunnel(sourcePort int, targetHost string, targetPort int, t return fmt.Errorf("端口 %d 已被占用", sourcePort) } - forwarder := NewTunnelForwarder(sourcePort, targetHost, targetPort, tunnelServer) + forwarder := NewTunnelForwarder(sourcePort, targetHost, targetPort, tunnelServer, limit) if err := forwarder.Start(); err != nil { return err } @@ -319,11 +373,11 @@ func (f *Forwarder) GetTrafficStats() stats.TrafficStats { func (m *Manager) GetAllTrafficStats() map[int]stats.TrafficStats { m.mu.RLock() defer m.mu.RUnlock() - + statsMap := make(map[int]stats.TrafficStats) for port, forwarder := range m.forwarders { statsMap[port] = forwarder.GetTrafficStats() } - + return statsMap -} \ No newline at end of file +} diff --git a/src/server/forwarder/forwarder_test.go b/src/server/forwarder/forwarder_test.go index ac1b346..afc87e7 100644 --- a/src/server/forwarder/forwarder_test.go +++ b/src/server/forwarder/forwarder_test.go @@ -11,24 +11,24 @@ import ( // TestNewForwarder 测试创建转发器 func TestNewForwarder(t *testing.T) { - fwd := NewForwarder(8080, "192.168.1.100", 80) - + fwd := NewForwarder(8080, "192.168.1.100", 80, nil) + if fwd == nil { t.Fatal("创建转发器失败") } - + if fwd.sourcePort != 8080 { t.Errorf("源端口不正确,期望 8080,得到 %d", fwd.sourcePort) } - + if fwd.targetHost != "192.168.1.100" { t.Errorf("目标主机不正确,期望 192.168.1.100,得到 %s", fwd.targetHost) } - + if fwd.targetPort != 80 { t.Errorf("目标端口不正确,期望 80,得到 %d", fwd.targetPort) } - + if fwd.useTunnel { t.Error("普通转发器不应使用隧道") } @@ -48,8 +48,8 @@ func (m *mockTunnelServer) ForwardConnection(clientConn net.Conn, targetIp strin func (m *mockTunnelServer) IsConnected() bool { return m.connected } - -func (m *mockTunnelServer) GetTrafficStats() stats.TrafficStats { + +func (m *mockTunnelServer) GetTrafficStats() stats.TrafficStats { return stats.TrafficStats{} } @@ -57,17 +57,17 @@ func (m *mockTunnelServer) GetTrafficStats() stats.TrafficStats { func TestNewTunnelForwarder(t *testing.T) { // 创建模拟隧道服务器 mockServer := &mockTunnelServer{connected: true} - - fwd := NewTunnelForwarder(8080, "127.0.0.1", 80, mockServer) - + + fwd := NewTunnelForwarder(8080, "127.0.0.1", 80, mockServer, nil) + if fwd == nil { t.Fatal("创建隧道转发器失败") } - + if !fwd.useTunnel { t.Error("隧道转发器应使用隧道") } - + if fwd.tunnelServer == nil { t.Error("隧道服务器未设置") } @@ -81,28 +81,28 @@ func TestForwarderStartStop(t *testing.T) { t.Fatalf("创建目标服务器失败: %v", err) } defer targetListener.Close() - + targetPort := targetListener.Addr().(*net.TCPAddr).Port - + // 启动转发器到一个随机端口 - fwd := NewForwarder(0, "127.0.0.1", targetPort) - + fwd := NewForwarder(0, "127.0.0.1", targetPort, nil) + // 创建监听器 listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("创建监听器失败: %v", err) } - + fwd.listener = listener fwd.sourcePort = listener.Addr().(*net.TCPAddr).Port - + // 启动接受循环 fwd.wg.Add(1) go fwd.acceptLoop() - + // 等待一段时间 time.Sleep(100 * time.Millisecond) - + // 停止转发器 err = fwd.Stop() if err != nil { @@ -118,9 +118,9 @@ func TestForwarderConnection(t *testing.T) { t.Fatalf("创建目标服务器失败: %v", err) } defer targetListener.Close() - + targetPort := targetListener.Addr().(*net.TCPAddr).Port - + // 在后台处理连接 go func() { conn, err := targetListener.Accept() @@ -128,41 +128,41 @@ func TestForwarderConnection(t *testing.T) { return } defer conn.Close() - + // 回显服务器 io.Copy(conn, conn) }() - + // 创建并启动转发器 - fwd := NewForwarder(0, "127.0.0.1", targetPort) - + fwd := NewForwarder(0, "127.0.0.1", targetPort, nil) + listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("创建监听器失败: %v", err) } fwd.listener = listener fwd.sourcePort = listener.Addr().(*net.TCPAddr).Port - + fwd.wg.Add(1) go fwd.acceptLoop() defer fwd.Stop() - + time.Sleep(100 * time.Millisecond) - + // 连接到转发器 client, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", fwd.sourcePort)) if err != nil { t.Fatalf("连接转发器失败: %v", err) } defer client.Close() - + // 发送数据 testData := []byte("Hello, World!") _, err = client.Write(testData) if err != nil { t.Fatalf("发送数据失败: %v", err) } - + // 读取响应 buf := make([]byte, len(testData)) client.SetReadDeadline(time.Now().Add(2 * time.Second)) @@ -170,11 +170,11 @@ func TestForwarderConnection(t *testing.T) { if err != nil { t.Fatalf("读取响应失败: %v", err) } - + if n != len(testData) { t.Errorf("读取数据长度不正确,期望 %d,得到 %d", len(testData), n) } - + if string(buf) != string(testData) { t.Errorf("数据不匹配,期望 %s,得到 %s", testData, buf) } @@ -183,16 +183,16 @@ func TestForwarderConnection(t *testing.T) { // TestManagerAdd 测试管理器添加转发器 func TestManagerAdd(t *testing.T) { mgr := NewManager() - + // 创建模拟目标服务器 targetListener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("创建目标服务器失败: %v", err) } defer targetListener.Close() - + targetPort := targetListener.Addr().(*net.TCPAddr).Port - + // 添加转发器到一个随机可用端口 fwdListener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { @@ -200,17 +200,17 @@ func TestManagerAdd(t *testing.T) { } sourcePort := fwdListener.Addr().(*net.TCPAddr).Port fwdListener.Close() // 关闭以便转发器可以使用这个端口 - - err = mgr.Add(sourcePort, "127.0.0.1", targetPort) + + err = mgr.Add(sourcePort, "127.0.0.1", targetPort, nil) if err != nil { t.Fatalf("添加转发器失败: %v", err) } - + // 验证转发器已添加 if !mgr.Exists(sourcePort) { t.Error("转发器应该存在") } - + // 清理 mgr.Remove(sourcePort) } @@ -218,7 +218,7 @@ func TestManagerAdd(t *testing.T) { // TestManagerAddDuplicate 测试添加重复转发器 func TestManagerAddDuplicate(t *testing.T) { mgr := NewManager() - + // 获取一个随机端口 listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { @@ -226,16 +226,16 @@ func TestManagerAddDuplicate(t *testing.T) { } sourcePort := listener.Addr().(*net.TCPAddr).Port listener.Close() - + // 添加第一个转发器 - err = mgr.Add(sourcePort, "127.0.0.1", 80) + err = mgr.Add(sourcePort, "127.0.0.1", 80, nil) if err != nil { t.Fatalf("添加第一个转发器失败: %v", err) } defer mgr.Remove(sourcePort) - + // 尝试添加重复端口 - err = mgr.Add(sourcePort, "127.0.0.1", 81) + err = mgr.Add(sourcePort, "127.0.0.1", 81, nil) if err == nil { t.Error("应该返回端口已占用错误") } @@ -244,7 +244,7 @@ func TestManagerAddDuplicate(t *testing.T) { // TestManagerRemove 测试移除转发器 func TestManagerRemove(t *testing.T) { mgr := NewManager() - + // 获取一个随机端口 listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { @@ -252,19 +252,19 @@ func TestManagerRemove(t *testing.T) { } sourcePort := listener.Addr().(*net.TCPAddr).Port listener.Close() - + // 添加转发器 - err = mgr.Add(sourcePort, "127.0.0.1", 80) + err = mgr.Add(sourcePort, "127.0.0.1", 80, nil) if err != nil { t.Fatalf("添加转发器失败: %v", err) } - + // 移除转发器 err = mgr.Remove(sourcePort) if err != nil { t.Errorf("移除转发器失败: %v", err) } - + // 验证转发器已移除 if mgr.Exists(sourcePort) { t.Error("转发器应该已被移除") @@ -274,7 +274,7 @@ func TestManagerRemove(t *testing.T) { // TestManagerRemoveNonExistent 测试移除不存在的转发器 func TestManagerRemoveNonExistent(t *testing.T) { mgr := NewManager() - + err := mgr.Remove(9999) if err == nil { t.Error("应该返回转发器不存在错误") @@ -284,12 +284,12 @@ func TestManagerRemoveNonExistent(t *testing.T) { // TestManagerExists 测试检查转发器是否存在 func TestManagerExists(t *testing.T) { mgr := NewManager() - + // 检查不存在的转发器 if mgr.Exists(8080) { t.Error("转发器不应该存在") } - + // 获取一个随机端口 listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { @@ -297,14 +297,14 @@ func TestManagerExists(t *testing.T) { } sourcePort := listener.Addr().(*net.TCPAddr).Port listener.Close() - + // 添加转发器 - err = mgr.Add(sourcePort, "127.0.0.1", 80) + err = mgr.Add(sourcePort, "127.0.0.1", 80, nil) if err != nil { t.Fatalf("添加转发器失败: %v", err) } defer mgr.Remove(sourcePort) - + // 检查存在的转发器 if !mgr.Exists(sourcePort) { t.Error("转发器应该存在") @@ -314,7 +314,7 @@ func TestManagerExists(t *testing.T) { // TestManagerStopAll 测试停止所有转发器 func TestManagerStopAll(t *testing.T) { mgr := NewManager() - + // 添加多个转发器 ports := make([]int, 0) for i := 0; i < 3; i++ { @@ -324,17 +324,17 @@ func TestManagerStopAll(t *testing.T) { } port := listener.Addr().(*net.TCPAddr).Port listener.Close() - - err = mgr.Add(port, "127.0.0.1", 80+i) + + err = mgr.Add(port, "127.0.0.1", 80+i, nil) if err != nil { t.Fatalf("添加转发器 %d 失败: %v", i, err) } ports = append(ports, port) } - + // 停止所有转发器 mgr.StopAll() - + // 验证所有转发器已停止 for _, port := range ports { if mgr.Exists(port) { @@ -345,27 +345,27 @@ func TestManagerStopAll(t *testing.T) { // TestForwarderContextCancellation 测试上下文取消 func TestForwarderContextCancellation(t *testing.T) { - fwd := NewForwarder(0, "127.0.0.1", 80) - + fwd := NewForwarder(0, "127.0.0.1", 80, nil) + listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("创建监听器失败: %v", err) } fwd.listener = listener - + fwd.wg.Add(1) go fwd.acceptLoop() - + // 取消上下文 fwd.cancel() - + // 等待 goroutine 退出 done := make(chan struct{}) go func() { fwd.wg.Wait() close(done) }() - + select { case <-done: // 成功退出 @@ -379,9 +379,9 @@ func BenchmarkForwarderConnection(b *testing.B) { // 创建模拟目标服务器 targetListener, _ := net.Listen("tcp", "127.0.0.1:0") defer targetListener.Close() - + targetPort := targetListener.Addr().(*net.TCPAddr).Port - + // 后台处理连接 go func() { for { @@ -395,19 +395,19 @@ func BenchmarkForwarderConnection(b *testing.B) { }(conn) } }() - + // 创建转发器 - fwd := NewForwarder(0, "127.0.0.1", targetPort) + fwd := NewForwarder(0, "127.0.0.1", targetPort, nil) listener, _ := net.Listen("tcp", "127.0.0.1:0") fwd.listener = listener fwd.sourcePort = listener.Addr().(*net.TCPAddr).Port - + fwd.wg.Add(1) go fwd.acceptLoop() defer fwd.Stop() - + time.Sleep(100 * time.Millisecond) - + b.ResetTimer() for i := 0; i < b.N; i++ { conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", fwd.sourcePort)) @@ -422,17 +422,17 @@ func BenchmarkForwarderConnection(b *testing.B) { // BenchmarkManagerOperations 基准测试管理器操作 func BenchmarkManagerOperations(b *testing.B) { mgr := NewManager() - + b.Run("Add", func(b *testing.B) { for i := 0; i < b.N; i++ { listener, _ := net.Listen("tcp", "127.0.0.1:0") port := listener.Addr().(*net.TCPAddr).Port listener.Close() - - mgr.Add(port, "127.0.0.1", 80) + + mgr.Add(port, "127.0.0.1", 80, nil) } }) - + b.Run("Exists", func(b *testing.B) { for i := 0; i < b.N; i++ { mgr.Exists(8080) diff --git a/src/server/html/management.html b/src/server/html/management.html index c2de565..0ee6085 100644 --- a/src/server/html/management.html +++ b/src/server/html/management.html @@ -420,6 +420,15 @@ +
+ + + + 限制此映射的传输速度,单位:字节/秒。留空表示不限制
+ 示例:1048576 = 1MB/s, 10485760 = 10MB/s +
+
+
@@ -617,12 +626,18 @@ document.getElementById('create-form').addEventListener('submit', async function (e) { e.preventDefault(); + const bandwidthLimitValue = document.getElementById('bandwidth-limit').value; const formData = { source_port: parseInt(document.getElementById('source-port').value), target_host: document.getElementById('target-host').value, target_port: parseInt(document.getElementById('target-port').value), use_tunnel: document.getElementById('use-tunnel').checked }; + + // 如果填写了带宽限制,则添加到请求中 + if (bandwidthLimitValue && bandwidthLimitValue.trim() !== '') { + formData.bandwidth_limit = parseInt(bandwidthLimitValue); + } try { const response = await fetch(getApiUrl('/api/mapping/create'), { diff --git a/src/server/main.go b/src/server/main.go index 08d313e..aeddbb0 100644 --- a/src/server/main.go +++ b/src/server/main.go @@ -81,10 +81,10 @@ func (s *serverService) Start() error { log.Printf("警告: 端口 %d 需要隧道模式但隧道服务未启用,跳过", mapping.SourcePort) continue } - err = s.fwdManager.AddTunnel(mapping.SourcePort, mapping.TargetHost, mapping.TargetPort, s.tunnelServer) + err = s.fwdManager.AddTunnel(mapping.SourcePort, mapping.TargetHost, mapping.TargetPort, s.tunnelServer, mapping.BandwidthLimit) } else { // 直接模式 - err = s.fwdManager.Add(mapping.SourcePort, mapping.TargetHost, mapping.TargetPort) + err = s.fwdManager.Add(mapping.SourcePort, mapping.TargetHost, mapping.TargetPort, mapping.BandwidthLimit) } if err != nil {