feat: 直接tcp转发模式支持了限速

This commit is contained in:
qcqcqc@wsl 2026-01-08 17:55:21 +08:00
parent 2d40357789
commit 2655b5592f
10 changed files with 244 additions and 149 deletions

View File

@ -6,7 +6,8 @@ toolchain go1.24.4
require (
github.com/mattn/go-sqlite3 v1.14.22
golang.org/x/time v0.14.0
gopkg.in/yaml.v3 v3.0.1
)
require golang.org/x/sys v0.37.0 // indirect
require golang.org/x/sys v0.37.0

View File

@ -2,6 +2,8 @@ github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=

View File

@ -39,10 +39,11 @@ func NewHandler(database *db.Database, fwdMgr *forwarder.Manager, ts *tunnel.Ser
// CreateMappingRequest 创建映射请求
type CreateMappingRequest struct {
SourcePort int `json:"source_port"` // 源端口(本地监听端口)
TargetPort int `json:"target_port"` // 目标端口(远程服务端口)
TargetHost string `json:"target_host"` // 目标主机支持IP或域名
UseTunnel bool `json:"use_tunnel"` // 是否使用隧道模式
SourcePort int `json:"source_port"` // 源端口(本地监听端口)
TargetPort int `json:"target_port"` // 目标端口(远程服务端口)
TargetHost string `json:"target_host"` // 目标主机支持IP或域名
UseTunnel bool `json:"use_tunnel"` // 是否使用隧道模式
BandwidthLimit *int64 `json:"bandwidth_limit,omitempty"` // 带宽限制,字节/秒,可为空
}
// RemoveMappingRequest 删除映射请求
@ -168,8 +169,14 @@ func (h *Handler) handleCreateMapping(w http.ResponseWriter, r *http.Request) {
}
}
//BandwidthLimit 合理范围不小于0
if req.BandwidthLimit != nil && *req.BandwidthLimit < 0 {
h.writeError(w, http.StatusBadRequest, "带宽限制必须大于等于0")
return
}
// 添加到数据库
if err := h.db.AddMapping(req.SourcePort, req.TargetHost, req.TargetPort, req.UseTunnel); err != nil {
if err := h.db.AddMapping(req.SourcePort, req.TargetHost, req.TargetPort, req.UseTunnel, req.BandwidthLimit); err != nil {
h.writeError(w, http.StatusInternalServerError, "保存映射失败: "+err.Error())
return
}
@ -178,10 +185,10 @@ func (h *Handler) handleCreateMapping(w http.ResponseWriter, r *http.Request) {
var err error
if req.UseTunnel {
// 隧道模式:使用隧道转发
err = h.forwarderMgr.AddTunnel(req.SourcePort, req.TargetHost, req.TargetPort, h.tunnelServer)
err = h.forwarderMgr.AddTunnel(req.SourcePort, req.TargetHost, req.TargetPort, h.tunnelServer, req.BandwidthLimit)
} else {
// 直接模式直接TCP转发
err = h.forwarderMgr.Add(req.SourcePort, req.TargetHost, req.TargetPort)
err = h.forwarderMgr.Add(req.SourcePort, req.TargetHost, req.TargetPort, req.BandwidthLimit)
}
if err != nil {

View File

@ -341,8 +341,8 @@ func TestHandleRemoveMapping(t *testing.T) {
defer cleanup()
// 先创建一个映射
database.AddMapping(15000, "192.168.1.100", 15000, false)
handler.forwarderMgr.Add(15000, "192.168.1.100", 15000)
database.AddMapping(15000, "192.168.1.100", 15000, false, nil)
handler.forwarderMgr.Add(15000, "192.168.1.100", 15000, nil)
reqBody := RemoveMappingRequest{
Port: 15000,
@ -391,9 +391,9 @@ func TestHandleListMappings(t *testing.T) {
defer cleanup()
// 添加一些映射
database.AddMapping(15000, "192.168.1.100", 15000, false)
database.AddMapping(15001, "192.168.1.101", 15001, true)
database.AddMapping(15002, "192.168.1.102", 15002, false)
database.AddMapping(15000, "192.168.1.100", 15000, false, nil)
database.AddMapping(15001, "192.168.1.101", 15001, true, nil)
database.AddMapping(15002, "192.168.1.102", 15002, false, nil)
req := httptest.NewRequest(http.MethodGet, "/api/mapping/list", nil)
addAuthHeader(req)
@ -582,7 +582,7 @@ func BenchmarkHandleListMappings(b *testing.B) {
// 添加一些映射
for i := 0; i < 100; i++ {
useTunnel := i%2 == 0 // 偶数使用隧道模式
database.AddMapping(10000+i, "192.168.1.1", 10000+i, useTunnel)
database.AddMapping(10000+i, "192.168.1.1", 10000+i, useTunnel, nil)
}
fwdMgr := forwarder.NewManager()

View File

@ -12,12 +12,13 @@ import (
// Mapping 端口映射结构
type Mapping struct {
ID int64 `json:"id"`
SourcePort int `json:"source_port"`
TargetHost string `json:"target_host"` // 支持IP或域名
TargetPort int `json:"target_port"`
UseTunnel bool `json:"use_tunnel"`
CreatedAt string `json:"created_at"`
ID int64 `json:"id"`
SourcePort int `json:"source_port"`
TargetHost string `json:"target_host"` // 支持IP或域名
TargetPort int `json:"target_port"`
UseTunnel bool `json:"use_tunnel"`
BandwidthLimit *int64 `json:"bandwidth_limit,omitempty"` // 带宽限制,字节/秒,可为空
CreatedAt string `json:"created_at"`
}
// Database 数据库管理器
@ -64,21 +65,22 @@ func (d *Database) initTables() error {
target_host TEXT NOT NULL,
target_port INTEGER NOT NULL,
use_tunnel BOOLEAN NOT NULL DEFAULT 0,
bandwidth_limit INTEGER,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX IF NOT EXISTS idx_source_port ON mappings(source_port);
`
_, err := d.db.Exec(query)
if err != nil {
return fmt.Errorf("初始化数据库表失败: %w", err)
}
// 检查是否需要迁移现有数据
if err := d.migrateDatabase(); err != nil {
return fmt.Errorf("数据库迁移失败: %w", err)
}
return nil
}
@ -93,23 +95,27 @@ func (d *Database) migrateDatabase() error {
hasUseTunnel := false
hasTargetHost := false
hasBandwidthLimit := false
for rows.Next() {
var cid int
var name, dataType string
var notNull, hasDefault int
var defaultValue interface{}
err := rows.Scan(&cid, &name, &dataType, &notNull, &defaultValue, &hasDefault)
if err != nil {
return fmt.Errorf("扫描表结构失败: %w", err)
}
if name == "use_tunnel" {
hasUseTunnel = true
}
if name == "target_host" {
hasTargetHost = true
}
if name == "bandwidth_limit" {
hasBandwidthLimit = true
}
}
// 如果不存在 use_tunnel 列,则添加它
@ -135,12 +141,12 @@ func (d *Database) migrateDatabase() error {
var name, dataType string
var notNull, hasDefault int
var defaultValue interface{}
err := rows2.Scan(&cid, &name, &dataType, &notNull, &defaultValue, &hasDefault)
if err != nil {
return fmt.Errorf("扫描表结构失败: %w", err)
}
if name == "target_ip" {
hasTargetIP = true
break
@ -162,20 +168,28 @@ func (d *Database) migrateDatabase() error {
}
}
// 如果不存在 bandwidth_limit 列,则添加它
if !hasBandwidthLimit {
_, err := d.db.Exec("ALTER TABLE mappings ADD COLUMN bandwidth_limit INTEGER")
if err != nil {
return fmt.Errorf("添加 bandwidth_limit 列失败: %w", err)
}
}
return nil
}
// AddMapping 添加端口映射
func (d *Database) AddMapping(sourcePort int, targetHost string, targetPort int, useTunnel bool) error {
// AddMapping 添加带宽限制的端口映射
func (d *Database) AddMapping(sourcePort int, targetHost string, targetPort int, useTunnel bool, bandwidthLimit *int64) error {
d.mu.Lock()
defer d.mu.Unlock()
query := `INSERT INTO mappings (source_port, target_host, target_port, use_tunnel) VALUES (?, ?, ?, ?)`
_, err := d.db.Exec(query, sourcePort, targetHost, targetPort, useTunnel)
query := `INSERT INTO mappings (source_port, target_host, target_port, use_tunnel, bandwidth_limit) VALUES (?, ?, ?, ?, ?)`
_, err := d.db.Exec(query, sourcePort, targetHost, targetPort, useTunnel, bandwidthLimit)
if err != nil {
return fmt.Errorf("添加端口映射失败: %w", err)
}
return nil
}
@ -189,16 +203,16 @@ func (d *Database) RemoveMapping(sourcePort int) error {
if err != nil {
return fmt.Errorf("删除端口映射失败: %w", err)
}
rows, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("获取影响行数失败: %w", err)
}
if rows == 0 {
return fmt.Errorf("端口映射不存在")
}
return nil
}
@ -207,8 +221,8 @@ func (d *Database) GetMapping(sourcePort int) (*Mapping, error) {
d.mu.RLock()
defer d.mu.RUnlock()
query := `SELECT id, source_port, target_host, target_port, use_tunnel, created_at FROM mappings WHERE source_port = ?`
query := `SELECT id, source_port, target_host, target_port, use_tunnel, bandwidth_limit, created_at FROM mappings WHERE source_port = ?`
var mapping Mapping
err := d.db.QueryRow(query, sourcePort).Scan(
&mapping.ID,
@ -216,16 +230,17 @@ func (d *Database) GetMapping(sourcePort int) (*Mapping, error) {
&mapping.TargetHost,
&mapping.TargetPort,
&mapping.UseTunnel,
&mapping.BandwidthLimit,
&mapping.CreatedAt,
)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, fmt.Errorf("查询端口映射失败: %w", err)
}
return &mapping, nil
}
@ -234,14 +249,14 @@ func (d *Database) GetAllMappings() ([]*Mapping, error) {
d.mu.RLock()
defer d.mu.RUnlock()
query := `SELECT id, source_port, target_host, target_port, use_tunnel, created_at FROM mappings ORDER BY source_port`
query := `SELECT id, source_port, target_host, target_port, use_tunnel, bandwidth_limit, created_at FROM mappings ORDER BY source_port`
rows, err := d.db.Query(query)
if err != nil {
return nil, fmt.Errorf("查询所有映射失败: %w", err)
}
defer rows.Close()
var mappings []*Mapping
for rows.Next() {
var mapping Mapping
@ -251,21 +266,22 @@ func (d *Database) GetAllMappings() ([]*Mapping, error) {
&mapping.TargetHost,
&mapping.TargetPort,
&mapping.UseTunnel,
&mapping.BandwidthLimit,
&mapping.CreatedAt,
); err != nil {
return nil, fmt.Errorf("扫描映射记录失败: %w", err)
}
mappings = append(mappings, &mapping)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("遍历映射记录失败: %w", err)
}
return mappings, nil
}
// Close 关闭数据库连接
func (d *Database) Close() error {
return d.db.Close()
}
}

View File

@ -18,7 +18,7 @@ func TestDatabase(t *testing.T) {
defer db.Close()
t.Run("添加映射", func(t *testing.T) {
err := db.AddMapping(10001, "192.168.1.100", 22, false)
err := db.AddMapping(10001, "192.168.1.100", 22, false, nil)
if err != nil {
t.Errorf("添加映射失败: %v", err)
}
@ -44,7 +44,7 @@ func TestDatabase(t *testing.T) {
})
t.Run("添加重复映射应该失败", func(t *testing.T) {
err := db.AddMapping(10001, "192.168.1.101", 22, true)
err := db.AddMapping(10001, "192.168.1.101", 22, true, nil)
if err == nil {
t.Error("添加重复映射应该失败")
}
@ -52,8 +52,8 @@ func TestDatabase(t *testing.T) {
t.Run("获取所有映射", func(t *testing.T) {
// 添加更多映射
db.AddMapping(10002, "192.168.1.101", 22, true)
db.AddMapping(10003, "192.168.1.102", 22, false)
db.AddMapping(10002, "192.168.1.101", 22, true, nil)
db.AddMapping(10003, "192.168.1.102", 22, false, nil)
mappings, err := db.GetAllMappings()
if err != nil {
@ -102,7 +102,7 @@ func TestDatabaseConcurrency(t *testing.T) {
for i := 0; i < 10; i++ {
go func(port int) {
useTunnel := port%2 == 0 // 偶数端口使用隧道模式
err := db.AddMapping(10000+port, "192.168.1.100", port, useTunnel)
err := db.AddMapping(10000+port, "192.168.1.100", port, useTunnel, nil)
if err != nil {
t.Logf("添加映射失败 (端口 %d): %v", 10000+port, err)
}
@ -123,4 +123,4 @@ func TestDatabaseConcurrency(t *testing.T) {
if len(mappings) == 0 {
t.Error("应该至少有一些映射")
}
}
}

View File

@ -10,6 +10,8 @@ import (
"sync"
"sync/atomic"
"time"
"golang.org/x/time/rate"
)
// TunnelServer 隧道服务器接口
@ -30,15 +32,24 @@ type Forwarder struct {
wg sync.WaitGroup
tunnelServer TunnelServer
useTunnel bool
limit *int64
// 流量统计(使用原子操作)
bytesSent uint64 // 发送字节数
bytesReceived uint64 // 接收字节数
bytesSent uint64 // 发送字节数
bytesReceived uint64 // 接收字节数
limiterOut *rate.Limiter // 限速器(出方向)
limiterIn *rate.Limiter // 限速器(入方向)
}
// NewForwarder 创建新的端口转发器
func NewForwarder(sourcePort int, targetHost string, targetPort int) *Forwarder {
func NewForwarder(sourcePort int, targetHost string, targetPort int, limit *int64) *Forwarder {
ctx, cancel := context.WithCancel(context.Background())
var limiterOut, limiterIn *rate.Limiter
if limit != nil {
burst := int(*limit) // 容量至少等于速率,不然无法正常突发
limiterOut = rate.NewLimiter(rate.Limit(*limit), burst)
limiterIn = rate.NewLimiter(rate.Limit(*limit), burst)
}
return &Forwarder{
sourcePort: sourcePort,
targetPort: targetPort,
@ -46,12 +57,21 @@ func NewForwarder(sourcePort int, targetHost string, targetPort int) *Forwarder
cancel: cancel,
ctx: ctx,
useTunnel: false,
limit: limit,
limiterOut: limiterOut,
limiterIn: limiterIn,
}
}
// NewTunnelForwarder 创建使用隧道的端口转发器
func NewTunnelForwarder(sourcePort int, targetHost string, targetPort int, tunnelServer TunnelServer) *Forwarder {
func NewTunnelForwarder(sourcePort int, targetHost string, targetPort int, tunnelServer TunnelServer, limit *int64) *Forwarder {
ctx, cancel := context.WithCancel(context.Background())
var limiterOut, limiterIn *rate.Limiter
if limit != nil {
burst := int(*limit) // 容量至少等于速率,不然无法正常突发
limiterOut = rate.NewLimiter(rate.Limit(*limit), burst)
limiterIn = rate.NewLimiter(rate.Limit(*limit), burst)
}
return &Forwarder{
sourcePort: sourcePort,
targetPort: targetPort,
@ -60,6 +80,9 @@ func NewTunnelForwarder(sourcePort int, targetHost string, targetPort int, tunne
useTunnel: true,
cancel: cancel,
ctx: ctx,
limit: limit,
limiterOut: limiterOut,
limiterIn: limiterIn,
}
}
@ -92,7 +115,7 @@ func (f *Forwarder) acceptLoop() {
// 设置接受超时,避免阻塞关闭
f.listener.(*net.TCPListener).SetDeadline(time.Now().Add(time.Second))
conn, err := f.listener.Accept()
if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
@ -112,6 +135,27 @@ func (f *Forwarder) acceptLoop() {
}
}
type rateLimitedReader struct {
r io.Reader
limiter *rate.Limiter
ctx context.Context
}
func (rlr *rateLimitedReader) Read(p []byte) (int, error) {
if rlr.limiter != nil {
maxReq := rlr.limiter.Burst()
reqSize := len(p)
if reqSize > maxReq {
reqSize = maxReq // 避免一次申请超过桶容量导致错误
}
err := rlr.limiter.WaitN(rlr.ctx, reqSize)
if err != nil {
return 0, err
}
}
return rlr.r.Read(p)
}
// handleConnection 处理单个连接
func (f *Forwarder) handleConnection(clientConn net.Conn) {
defer f.wg.Done()
@ -138,7 +182,7 @@ func (f *Forwarder) handleConnection(clientConn net.Conn) {
Timeout: 10 * time.Second,
KeepAlive: 30 * time.Second,
}
// 动态解析域名并连接
targetAddr := fmt.Sprintf("%s:%d", f.targetHost, f.targetPort)
targetConn, err := dialer.DialContext(f.ctx, "tcp", targetAddr)
@ -155,7 +199,12 @@ func (f *Forwarder) handleConnection(clientConn net.Conn) {
// 客户端 -> 目标
go func() {
defer wg.Done()
n, _ := io.Copy(targetConn, clientConn)
reader := &rateLimitedReader{
r: clientConn,
limiter: f.limiterOut,
ctx: f.ctx,
}
n, _ := io.Copy(targetConn, reader)
atomic.AddUint64(&f.bytesSent, uint64(n))
// 关闭目标连接的写入端,通知对方不会再发送数据
if tcpConn, ok := targetConn.(*net.TCPConn); ok {
@ -166,7 +215,12 @@ func (f *Forwarder) handleConnection(clientConn net.Conn) {
// 目标 -> 客户端
go func() {
defer wg.Done()
n, _ := io.Copy(clientConn, targetConn)
reader := &rateLimitedReader{
r: targetConn,
limiter: f.limiterIn,
ctx: f.ctx,
}
n, _ := io.Copy(clientConn, reader)
atomic.AddUint64(&f.bytesReceived, uint64(n))
// 关闭客户端连接的写入端
if tcpConn, ok := clientConn.(*net.TCPConn); ok {
@ -193,7 +247,7 @@ func (f *Forwarder) handleConnection(clientConn net.Conn) {
// Stop 停止端口转发
func (f *Forwarder) Stop() error {
f.cancel()
if f.listener != nil {
if err := f.listener.Close(); err != nil {
log.Printf("关闭监听器失败 (端口 %d): %v", f.sourcePort, err)
@ -231,7 +285,7 @@ func NewManager() *Manager {
}
// Add 添加并启动转发器
func (m *Manager) Add(sourcePort int, targetHost string, targetPort int) error {
func (m *Manager) Add(sourcePort int, targetHost string, targetPort int, limit *int64) error {
m.mu.Lock()
defer m.mu.Unlock()
@ -239,7 +293,7 @@ func (m *Manager) Add(sourcePort int, targetHost string, targetPort int) error {
return fmt.Errorf("端口 %d 已被占用", sourcePort)
}
forwarder := NewForwarder(sourcePort, targetHost, targetPort)
forwarder := NewForwarder(sourcePort, targetHost, targetPort, limit)
if err := forwarder.Start(); err != nil {
return err
}
@ -249,7 +303,7 @@ func (m *Manager) Add(sourcePort int, targetHost string, targetPort int) error {
}
// AddTunnel 添加使用隧道的转发器
func (m *Manager) AddTunnel(sourcePort int, targetHost string, targetPort int, tunnelServer TunnelServer) error {
func (m *Manager) AddTunnel(sourcePort int, targetHost string, targetPort int, tunnelServer TunnelServer, limit *int64) error {
m.mu.Lock()
defer m.mu.Unlock()
@ -257,7 +311,7 @@ func (m *Manager) AddTunnel(sourcePort int, targetHost string, targetPort int, t
return fmt.Errorf("端口 %d 已被占用", sourcePort)
}
forwarder := NewTunnelForwarder(sourcePort, targetHost, targetPort, tunnelServer)
forwarder := NewTunnelForwarder(sourcePort, targetHost, targetPort, tunnelServer, limit)
if err := forwarder.Start(); err != nil {
return err
}
@ -319,11 +373,11 @@ func (f *Forwarder) GetTrafficStats() stats.TrafficStats {
func (m *Manager) GetAllTrafficStats() map[int]stats.TrafficStats {
m.mu.RLock()
defer m.mu.RUnlock()
statsMap := make(map[int]stats.TrafficStats)
for port, forwarder := range m.forwarders {
statsMap[port] = forwarder.GetTrafficStats()
}
return statsMap
}
}

View File

@ -11,24 +11,24 @@ import (
// TestNewForwarder 测试创建转发器
func TestNewForwarder(t *testing.T) {
fwd := NewForwarder(8080, "192.168.1.100", 80)
fwd := NewForwarder(8080, "192.168.1.100", 80, nil)
if fwd == nil {
t.Fatal("创建转发器失败")
}
if fwd.sourcePort != 8080 {
t.Errorf("源端口不正确,期望 8080得到 %d", fwd.sourcePort)
}
if fwd.targetHost != "192.168.1.100" {
t.Errorf("目标主机不正确,期望 192.168.1.100,得到 %s", fwd.targetHost)
}
if fwd.targetPort != 80 {
t.Errorf("目标端口不正确,期望 80得到 %d", fwd.targetPort)
}
if fwd.useTunnel {
t.Error("普通转发器不应使用隧道")
}
@ -48,8 +48,8 @@ func (m *mockTunnelServer) ForwardConnection(clientConn net.Conn, targetIp strin
func (m *mockTunnelServer) IsConnected() bool {
return m.connected
}
func (m *mockTunnelServer) GetTrafficStats() stats.TrafficStats {
func (m *mockTunnelServer) GetTrafficStats() stats.TrafficStats {
return stats.TrafficStats{}
}
@ -57,17 +57,17 @@ func (m *mockTunnelServer) GetTrafficStats() stats.TrafficStats {
func TestNewTunnelForwarder(t *testing.T) {
// 创建模拟隧道服务器
mockServer := &mockTunnelServer{connected: true}
fwd := NewTunnelForwarder(8080, "127.0.0.1", 80, mockServer)
fwd := NewTunnelForwarder(8080, "127.0.0.1", 80, mockServer, nil)
if fwd == nil {
t.Fatal("创建隧道转发器失败")
}
if !fwd.useTunnel {
t.Error("隧道转发器应使用隧道")
}
if fwd.tunnelServer == nil {
t.Error("隧道服务器未设置")
}
@ -81,28 +81,28 @@ func TestForwarderStartStop(t *testing.T) {
t.Fatalf("创建目标服务器失败: %v", err)
}
defer targetListener.Close()
targetPort := targetListener.Addr().(*net.TCPAddr).Port
// 启动转发器到一个随机端口
fwd := NewForwarder(0, "127.0.0.1", targetPort)
fwd := NewForwarder(0, "127.0.0.1", targetPort, nil)
// 创建监听器
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("创建监听器失败: %v", err)
}
fwd.listener = listener
fwd.sourcePort = listener.Addr().(*net.TCPAddr).Port
// 启动接受循环
fwd.wg.Add(1)
go fwd.acceptLoop()
// 等待一段时间
time.Sleep(100 * time.Millisecond)
// 停止转发器
err = fwd.Stop()
if err != nil {
@ -118,9 +118,9 @@ func TestForwarderConnection(t *testing.T) {
t.Fatalf("创建目标服务器失败: %v", err)
}
defer targetListener.Close()
targetPort := targetListener.Addr().(*net.TCPAddr).Port
// 在后台处理连接
go func() {
conn, err := targetListener.Accept()
@ -128,41 +128,41 @@ func TestForwarderConnection(t *testing.T) {
return
}
defer conn.Close()
// 回显服务器
io.Copy(conn, conn)
}()
// 创建并启动转发器
fwd := NewForwarder(0, "127.0.0.1", targetPort)
fwd := NewForwarder(0, "127.0.0.1", targetPort, nil)
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("创建监听器失败: %v", err)
}
fwd.listener = listener
fwd.sourcePort = listener.Addr().(*net.TCPAddr).Port
fwd.wg.Add(1)
go fwd.acceptLoop()
defer fwd.Stop()
time.Sleep(100 * time.Millisecond)
// 连接到转发器
client, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", fwd.sourcePort))
if err != nil {
t.Fatalf("连接转发器失败: %v", err)
}
defer client.Close()
// 发送数据
testData := []byte("Hello, World!")
_, err = client.Write(testData)
if err != nil {
t.Fatalf("发送数据失败: %v", err)
}
// 读取响应
buf := make([]byte, len(testData))
client.SetReadDeadline(time.Now().Add(2 * time.Second))
@ -170,11 +170,11 @@ func TestForwarderConnection(t *testing.T) {
if err != nil {
t.Fatalf("读取响应失败: %v", err)
}
if n != len(testData) {
t.Errorf("读取数据长度不正确,期望 %d得到 %d", len(testData), n)
}
if string(buf) != string(testData) {
t.Errorf("数据不匹配,期望 %s得到 %s", testData, buf)
}
@ -183,16 +183,16 @@ func TestForwarderConnection(t *testing.T) {
// TestManagerAdd 测试管理器添加转发器
func TestManagerAdd(t *testing.T) {
mgr := NewManager()
// 创建模拟目标服务器
targetListener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("创建目标服务器失败: %v", err)
}
defer targetListener.Close()
targetPort := targetListener.Addr().(*net.TCPAddr).Port
// 添加转发器到一个随机可用端口
fwdListener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
@ -200,17 +200,17 @@ func TestManagerAdd(t *testing.T) {
}
sourcePort := fwdListener.Addr().(*net.TCPAddr).Port
fwdListener.Close() // 关闭以便转发器可以使用这个端口
err = mgr.Add(sourcePort, "127.0.0.1", targetPort)
err = mgr.Add(sourcePort, "127.0.0.1", targetPort, nil)
if err != nil {
t.Fatalf("添加转发器失败: %v", err)
}
// 验证转发器已添加
if !mgr.Exists(sourcePort) {
t.Error("转发器应该存在")
}
// 清理
mgr.Remove(sourcePort)
}
@ -218,7 +218,7 @@ func TestManagerAdd(t *testing.T) {
// TestManagerAddDuplicate 测试添加重复转发器
func TestManagerAddDuplicate(t *testing.T) {
mgr := NewManager()
// 获取一个随机端口
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
@ -226,16 +226,16 @@ func TestManagerAddDuplicate(t *testing.T) {
}
sourcePort := listener.Addr().(*net.TCPAddr).Port
listener.Close()
// 添加第一个转发器
err = mgr.Add(sourcePort, "127.0.0.1", 80)
err = mgr.Add(sourcePort, "127.0.0.1", 80, nil)
if err != nil {
t.Fatalf("添加第一个转发器失败: %v", err)
}
defer mgr.Remove(sourcePort)
// 尝试添加重复端口
err = mgr.Add(sourcePort, "127.0.0.1", 81)
err = mgr.Add(sourcePort, "127.0.0.1", 81, nil)
if err == nil {
t.Error("应该返回端口已占用错误")
}
@ -244,7 +244,7 @@ func TestManagerAddDuplicate(t *testing.T) {
// TestManagerRemove 测试移除转发器
func TestManagerRemove(t *testing.T) {
mgr := NewManager()
// 获取一个随机端口
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
@ -252,19 +252,19 @@ func TestManagerRemove(t *testing.T) {
}
sourcePort := listener.Addr().(*net.TCPAddr).Port
listener.Close()
// 添加转发器
err = mgr.Add(sourcePort, "127.0.0.1", 80)
err = mgr.Add(sourcePort, "127.0.0.1", 80, nil)
if err != nil {
t.Fatalf("添加转发器失败: %v", err)
}
// 移除转发器
err = mgr.Remove(sourcePort)
if err != nil {
t.Errorf("移除转发器失败: %v", err)
}
// 验证转发器已移除
if mgr.Exists(sourcePort) {
t.Error("转发器应该已被移除")
@ -274,7 +274,7 @@ func TestManagerRemove(t *testing.T) {
// TestManagerRemoveNonExistent 测试移除不存在的转发器
func TestManagerRemoveNonExistent(t *testing.T) {
mgr := NewManager()
err := mgr.Remove(9999)
if err == nil {
t.Error("应该返回转发器不存在错误")
@ -284,12 +284,12 @@ func TestManagerRemoveNonExistent(t *testing.T) {
// TestManagerExists 测试检查转发器是否存在
func TestManagerExists(t *testing.T) {
mgr := NewManager()
// 检查不存在的转发器
if mgr.Exists(8080) {
t.Error("转发器不应该存在")
}
// 获取一个随机端口
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
@ -297,14 +297,14 @@ func TestManagerExists(t *testing.T) {
}
sourcePort := listener.Addr().(*net.TCPAddr).Port
listener.Close()
// 添加转发器
err = mgr.Add(sourcePort, "127.0.0.1", 80)
err = mgr.Add(sourcePort, "127.0.0.1", 80, nil)
if err != nil {
t.Fatalf("添加转发器失败: %v", err)
}
defer mgr.Remove(sourcePort)
// 检查存在的转发器
if !mgr.Exists(sourcePort) {
t.Error("转发器应该存在")
@ -314,7 +314,7 @@ func TestManagerExists(t *testing.T) {
// TestManagerStopAll 测试停止所有转发器
func TestManagerStopAll(t *testing.T) {
mgr := NewManager()
// 添加多个转发器
ports := make([]int, 0)
for i := 0; i < 3; i++ {
@ -324,17 +324,17 @@ func TestManagerStopAll(t *testing.T) {
}
port := listener.Addr().(*net.TCPAddr).Port
listener.Close()
err = mgr.Add(port, "127.0.0.1", 80+i)
err = mgr.Add(port, "127.0.0.1", 80+i, nil)
if err != nil {
t.Fatalf("添加转发器 %d 失败: %v", i, err)
}
ports = append(ports, port)
}
// 停止所有转发器
mgr.StopAll()
// 验证所有转发器已停止
for _, port := range ports {
if mgr.Exists(port) {
@ -345,27 +345,27 @@ func TestManagerStopAll(t *testing.T) {
// TestForwarderContextCancellation 测试上下文取消
func TestForwarderContextCancellation(t *testing.T) {
fwd := NewForwarder(0, "127.0.0.1", 80)
fwd := NewForwarder(0, "127.0.0.1", 80, nil)
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("创建监听器失败: %v", err)
}
fwd.listener = listener
fwd.wg.Add(1)
go fwd.acceptLoop()
// 取消上下文
fwd.cancel()
// 等待 goroutine 退出
done := make(chan struct{})
go func() {
fwd.wg.Wait()
close(done)
}()
select {
case <-done:
// 成功退出
@ -379,9 +379,9 @@ func BenchmarkForwarderConnection(b *testing.B) {
// 创建模拟目标服务器
targetListener, _ := net.Listen("tcp", "127.0.0.1:0")
defer targetListener.Close()
targetPort := targetListener.Addr().(*net.TCPAddr).Port
// 后台处理连接
go func() {
for {
@ -395,19 +395,19 @@ func BenchmarkForwarderConnection(b *testing.B) {
}(conn)
}
}()
// 创建转发器
fwd := NewForwarder(0, "127.0.0.1", targetPort)
fwd := NewForwarder(0, "127.0.0.1", targetPort, nil)
listener, _ := net.Listen("tcp", "127.0.0.1:0")
fwd.listener = listener
fwd.sourcePort = listener.Addr().(*net.TCPAddr).Port
fwd.wg.Add(1)
go fwd.acceptLoop()
defer fwd.Stop()
time.Sleep(100 * time.Millisecond)
b.ResetTimer()
for i := 0; i < b.N; i++ {
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", fwd.sourcePort))
@ -422,17 +422,17 @@ func BenchmarkForwarderConnection(b *testing.B) {
// BenchmarkManagerOperations 基准测试管理器操作
func BenchmarkManagerOperations(b *testing.B) {
mgr := NewManager()
b.Run("Add", func(b *testing.B) {
for i := 0; i < b.N; i++ {
listener, _ := net.Listen("tcp", "127.0.0.1:0")
port := listener.Addr().(*net.TCPAddr).Port
listener.Close()
mgr.Add(port, "127.0.0.1", 80)
mgr.Add(port, "127.0.0.1", 80, nil)
}
})
b.Run("Exists", func(b *testing.B) {
for i := 0; i < b.N; i++ {
mgr.Exists(8080)

View File

@ -420,6 +420,15 @@
<input type="number" id="target-port" min="1" max="65535" required placeholder="例如: 3000">
</div>
<div class="form-group">
<label for="bandwidth-limit">带宽限制 (可选)</label>
<input type="number" id="bandwidth-limit" min="0" placeholder="例如: 1048576 (1MB/s)">
<small style="color: #666; margin-top: 5px; display: block;">
限制此映射的传输速度,单位:字节/秒。留空表示不限制<br>
示例1048576 = 1MB/s, 10485760 = 10MB/s
</small>
</div>
<div class="form-group">
<label>连接模式</label>
<div class="checkbox-group">
@ -617,12 +626,18 @@
document.getElementById('create-form').addEventListener('submit', async function (e) {
e.preventDefault();
const bandwidthLimitValue = document.getElementById('bandwidth-limit').value;
const formData = {
source_port: parseInt(document.getElementById('source-port').value),
target_host: document.getElementById('target-host').value,
target_port: parseInt(document.getElementById('target-port').value),
use_tunnel: document.getElementById('use-tunnel').checked
};
// 如果填写了带宽限制,则添加到请求中
if (bandwidthLimitValue && bandwidthLimitValue.trim() !== '') {
formData.bandwidth_limit = parseInt(bandwidthLimitValue);
}
try {
const response = await fetch(getApiUrl('/api/mapping/create'), {

View File

@ -81,10 +81,10 @@ func (s *serverService) Start() error {
log.Printf("警告: 端口 %d 需要隧道模式但隧道服务未启用,跳过", mapping.SourcePort)
continue
}
err = s.fwdManager.AddTunnel(mapping.SourcePort, mapping.TargetHost, mapping.TargetPort, s.tunnelServer)
err = s.fwdManager.AddTunnel(mapping.SourcePort, mapping.TargetHost, mapping.TargetPort, s.tunnelServer, mapping.BandwidthLimit)
} else {
// 直接模式
err = s.fwdManager.Add(mapping.SourcePort, mapping.TargetHost, mapping.TargetPort)
err = s.fwdManager.Add(mapping.SourcePort, mapping.TargetHost, mapping.TargetPort, mapping.BandwidthLimit)
}
if err != nil {