feat: 支持了黑白名单功能

This commit is contained in:
qcqcqc@wsl 2026-01-09 16:44:46 +08:00
parent cdc1972ffa
commit 074f5fd8bd
9 changed files with 607 additions and 60 deletions

View File

@ -40,11 +40,13 @@ func NewHandler(database *db.Database, fwdMgr *forwarder.Manager, ts *tunnel.Ser
// CreateMappingRequest 创建映射请求
type CreateMappingRequest struct {
SourcePort int `json:"source_port"` // 源端口(本地监听端口)
TargetPort int `json:"target_port"` // 目标端口(远程服务端口)
TargetHost string `json:"target_host"` // 目标主机支持IP或域名
UseTunnel bool `json:"use_tunnel"` // 是否使用隧道模式
BandwidthLimit *int64 `json:"bandwidth_limit,omitempty"` // 带宽限制,字节/秒,可为空
SourcePort int `json:"source_port"` // 源端口(本地监听端口)
TargetPort int `json:"target_port"` // 目标端口(远程服务端口)
TargetHost string `json:"target_host"` // 目标主机支持IP或域名
UseTunnel bool `json:"use_tunnel"` // 是否使用隧道模式
BandwidthLimit *int64 `json:"bandwidth_limit,omitempty"` // 带宽限制,字节/秒,可为空
AccessRule *string `json:"access_rule,omitempty"` // 访问控制规则:"whitelist", "blacklist", "disabled"
AccessIPs *string `json:"access_ips,omitempty"` // IP列表JSON格式
}
// RemoveMappingRequest 删除映射请求
@ -52,6 +54,14 @@ type RemoveMappingRequest struct {
Port int `json:"port"`
}
// UpdateAccessRuleRequest 更新访问规则请求
type UpdateAccessRuleRequest struct {
Port int `json:"port"` // 端口号
BandwidthLimit *int64 `json:"bandwidth_limit,omitempty"` // 带宽限制,字节/秒,可为空
AccessRule *string `json:"access_rule,omitempty"` // 访问控制规则:"whitelist", "blacklist", "disabled"
AccessIPs *string `json:"access_ips,omitempty"` // IP列表JSON格式
}
// Response 统一响应格式
type Response struct {
Success bool `json:"success"`
@ -65,6 +75,7 @@ func (h *Handler) RegisterRoutes(mux *http.ServeMux) {
mux.HandleFunc("/api/mapping/create", h.authMiddleware(h.handleCreateMapping))
mux.HandleFunc("/api/mapping/remove", h.authMiddleware(h.handleRemoveMapping))
mux.HandleFunc("/api/mapping/list", h.authMiddleware(h.handleListMappings))
mux.HandleFunc("/api/mapping/update", h.authMiddleware(h.handleUpdateMapping))
mux.HandleFunc("/api/stats/traffic", h.authMiddleware(h.handleGetTrafficStats))
mux.HandleFunc("/api/stats/history", h.authMiddleware(h.handleGetTrafficHistory))
mux.HandleFunc("/api/stats/monitor", h.authMiddleware(h.handleTrafficMonitor))
@ -183,8 +194,16 @@ func (h *Handler) handleCreateMapping(w http.ResponseWriter, r *http.Request) {
return
}
// 验证AccessRule如果提供则必须是有效值
if req.AccessRule != nil {
if *req.AccessRule != "whitelist" && *req.AccessRule != "blacklist" && *req.AccessRule != "disabled" {
h.writeError(w, http.StatusBadRequest, "访问控制规则必须是 'whitelist', 'blacklist' 或 'disabled'")
return
}
}
// 添加到数据库
if err := h.db.AddMapping(req.SourcePort, req.TargetHost, req.TargetPort, req.UseTunnel, req.BandwidthLimit); err != nil {
if err := h.db.AddMapping(req.SourcePort, req.TargetHost, req.TargetPort, req.UseTunnel, req.BandwidthLimit, req.AccessRule, req.AccessIPs); err != nil {
h.writeError(w, http.StatusInternalServerError, "保存映射失败: "+err.Error())
return
}
@ -193,10 +212,10 @@ func (h *Handler) handleCreateMapping(w http.ResponseWriter, r *http.Request) {
var err error
if req.UseTunnel {
// 隧道模式:使用隧道转发
err = h.forwarderMgr.AddTunnel(req.SourcePort, req.TargetHost, req.TargetPort, h.tunnelServer, req.BandwidthLimit)
err = h.forwarderMgr.AddTunnel(req.SourcePort, req.TargetHost, req.TargetPort, h.tunnelServer, req.BandwidthLimit, req.AccessRule, req.AccessIPs)
} else {
// 直接模式直接TCP转发
err = h.forwarderMgr.Add(req.SourcePort, req.TargetHost, req.TargetPort, req.BandwidthLimit)
err = h.forwarderMgr.Add(req.SourcePort, req.TargetHost, req.TargetPort, req.BandwidthLimit, req.AccessRule, req.AccessIPs)
}
if err != nil {
@ -390,12 +409,6 @@ func (h *Handler) handleTrafficMonitor(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, GetTraffticMonitorHTML())
}
// handleConnections 连接监控页面
func (h *Handler) handleConnections(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html; charset=utf-8")
fmt.Fprint(w, GetConnectionsHTML())
}
// handleRoot 根路径重定向
func (h *Handler) handleRoot(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/" {
@ -474,6 +487,25 @@ func (h *Handler) handleGetActiveConnections(w http.ResponseWriter, r *http.Requ
// 获取所有活跃连接
connectionsStats := h.forwarderMgr.GetAllActiveConnections()
// 获取所有映射的访问规则信息
allMappings, err := h.db.GetAllMappings()
if err == nil {
// 将访问规则信息添加到连接统计中
mappingRules := make(map[int]*db.Mapping)
for _, m := range allMappings {
mappingRules[m.SourcePort] = m
}
// 合并访问规则信息到连接统计
for i := range connectionsStats {
if mapping, exists := mappingRules[connectionsStats[i].SourcePort]; exists {
connectionsStats[i].AccessRule = mapping.AccessRule
connectionsStats[i].AccessIPs = mapping.AccessIPs
connectionsStats[i].BandwidthLimit = mapping.BandwidthLimit
}
}
}
// 按端口号排序
sort.Slice(connectionsStats, func(i, j int) bool {
return connectionsStats[i].SourcePort < connectionsStats[j].SourcePort
@ -487,3 +519,53 @@ func (h *Handler) handleGetActiveConnections(w http.ResponseWriter, r *http.Requ
h.writeSuccess(w, "获取活跃连接成功", response)
}
// handleUpdateMapping 处理更新访问规则请求
func (h *Handler) handleUpdateMapping(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
h.writeError(w, http.StatusMethodNotAllowed, "只支持 POST 方法")
return
}
var req UpdateAccessRuleRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
h.writeError(w, http.StatusBadRequest, "请求格式错误: "+err.Error())
return
}
// 验证AccessRule如果提供则必须是有效值
if req.AccessRule != nil {
if *req.AccessRule != "whitelist" && *req.AccessRule != "blacklist" && *req.AccessRule != "disabled" {
h.writeError(w, http.StatusBadRequest, "访问控制规则必须是 'whitelist', 'blacklist' 或 'disabled'")
return
}
}
//BandwidthLimit 合理范围不小于0
if req.BandwidthLimit != nil && *req.BandwidthLimit < 0 {
h.writeError(w, http.StatusBadRequest, "带宽限制必须大于等于0")
return
}
// 更新数据库
if err := h.db.UpdateMapping(req.Port, req.BandwidthLimit, req.AccessRule, req.AccessIPs); err != nil {
h.writeError(w, http.StatusInternalServerError, "更新访问规则失败: "+err.Error())
return
}
log.Printf("更新端口 %d 的映射", req.Port)
h.writeSuccess(w, "映射更新成功", map[string]interface{}{
"port": req.Port,
"access_rule": req.AccessRule,
"access_ips": req.AccessIPs,
"bandwidth_limit": req.BandwidthLimit,
})
// 更新转发器的访问规则和带宽限制
fwd := h.forwarderMgr.GetForwarder(req.Port)
if fwd != nil {
fwd.UpdateAccessControl(req.AccessRule, req.AccessIPs)
fwd.UpdateBandwidthLimit(req.BandwidthLimit)
}
}

View File

@ -341,8 +341,8 @@ func TestHandleRemoveMapping(t *testing.T) {
defer cleanup()
// 先创建一个映射
database.AddMapping(15000, "192.168.1.100", 15000, false, nil)
handler.forwarderMgr.Add(15000, "192.168.1.100", 15000, nil)
database.AddMapping(15000, "192.168.1.100", 15000, false, nil, nil, nil)
handler.forwarderMgr.Add(15000, "192.168.1.100", 15000, nil, nil, nil)
reqBody := RemoveMappingRequest{
Port: 15000,
@ -391,9 +391,9 @@ func TestHandleListMappings(t *testing.T) {
defer cleanup()
// 添加一些映射
database.AddMapping(15000, "192.168.1.100", 15000, false, nil)
database.AddMapping(15001, "192.168.1.101", 15001, true, nil)
database.AddMapping(15002, "192.168.1.102", 15002, false, nil)
database.AddMapping(15000, "192.168.1.100", 15000, false, nil, nil, nil)
database.AddMapping(15001, "192.168.1.101", 15001, true, nil, nil, nil)
database.AddMapping(15002, "192.168.1.102", 15002, false, nil, nil, nil)
req := httptest.NewRequest(http.MethodGet, "/api/mapping/list", nil)
addAuthHeader(req)
@ -582,7 +582,7 @@ func BenchmarkHandleListMappings(b *testing.B) {
// 添加一些映射
for i := 0; i < 100; i++ {
useTunnel := i%2 == 0 // 偶数使用隧道模式
database.AddMapping(10000+i, "192.168.1.1", 10000+i, useTunnel, nil)
database.AddMapping(10000+i, "192.168.1.1", 10000+i, useTunnel, nil, nil, nil)
}
fwdMgr := forwarder.NewManager()

View File

@ -19,6 +19,9 @@ type Mapping struct {
UseTunnel bool `json:"use_tunnel"`
BandwidthLimit *int64 `json:"bandwidth_limit,omitempty"` // 带宽限制,字节/秒,可为空
CreatedAt string `json:"created_at"`
// 规则,白名单还是黑名单
AccessRule *string `json:"access_rule,omitempty"` // 可选访问控制规则JSON格式
AccessIPs *string `json:"access_ips,omitempty"` // 可选访问控制的IP列表JSON格式
}
// TrafficRecord 流量统计记录
@ -76,6 +79,8 @@ func (d *Database) initTables() error {
target_port INTEGER NOT NULL,
use_tunnel BOOLEAN NOT NULL DEFAULT 0,
bandwidth_limit INTEGER,
access_rule TEXT,
access_ips TEXT,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX IF NOT EXISTS idx_source_port ON mappings(source_port);
@ -118,6 +123,8 @@ func (d *Database) migrateDatabase() error {
hasUseTunnel := false
hasTargetHost := false
hasBandwidthLimit := false
hasAccessRule := false
hasAccessIPs := false
for rows.Next() {
var cid int
var name, dataType string
@ -138,6 +145,12 @@ func (d *Database) migrateDatabase() error {
if name == "bandwidth_limit" {
hasBandwidthLimit = true
}
if name == "access_rule" {
hasAccessRule = true
}
if name == "access_ips" {
hasAccessIPs = true
}
}
// 如果不存在 use_tunnel 列,则添加它
@ -198,6 +211,22 @@ func (d *Database) migrateDatabase() error {
}
}
// 如果不存在 access_rule 列,则添加它
if !hasAccessRule {
_, err := d.db.Exec("ALTER TABLE mappings ADD COLUMN access_rule TEXT")
if err != nil {
return fmt.Errorf("添加 access_rule 列失败: %w", err)
}
}
// 如果不存在 access_ips 列,则添加它
if !hasAccessIPs {
_, err := d.db.Exec("ALTER TABLE mappings ADD COLUMN access_ips TEXT")
if err != nil {
return fmt.Errorf("添加 access_ips 列失败: %w", err)
}
}
// 检查 traffic_records 表的 is_tunnel 列
rows3, err := d.db.Query("PRAGMA table_info(traffic_records)")
if err == nil {
@ -233,12 +262,12 @@ func (d *Database) migrateDatabase() error {
}
// AddMapping 添加带宽限制的端口映射
func (d *Database) AddMapping(sourcePort int, targetHost string, targetPort int, useTunnel bool, bandwidthLimit *int64) error {
func (d *Database) AddMapping(sourcePort int, targetHost string, targetPort int, useTunnel bool, bandwidthLimit *int64, accessRule *string, accessIPs *string) error {
d.mu.Lock()
defer d.mu.Unlock()
query := `INSERT INTO mappings (source_port, target_host, target_port, use_tunnel, bandwidth_limit) VALUES (?, ?, ?, ?, ?)`
_, err := d.db.Exec(query, sourcePort, targetHost, targetPort, useTunnel, bandwidthLimit)
query := `INSERT INTO mappings (source_port, target_host, target_port, use_tunnel, bandwidth_limit, access_rule, access_ips) VALUES (?, ?, ?, ?, ?, ?, ?)`
_, err := d.db.Exec(query, sourcePort, targetHost, targetPort, useTunnel, bandwidthLimit, accessRule, accessIPs)
if err != nil {
return fmt.Errorf("添加端口映射失败: %w", err)
}
@ -269,12 +298,57 @@ func (d *Database) RemoveMapping(sourcePort int) error {
return nil
}
// UpdateMapping 更新端口映射的访问控制规则
func (d *Database) UpdateMapping(sourcePort int, limit *int64, accessRule *string, accessIPs *string) error {
d.mu.Lock()
defer d.mu.Unlock()
// 先把老的查出来如果是nil的话就赋值上老的
selectQuery := `SELECT bandwidth_limit, access_rule, access_ips FROM mappings WHERE source_port = ?`
var oldLimit sql.NullInt64
var oldAccessRule sql.NullString
var oldAccessIPs sql.NullString
err := d.db.QueryRow(selectQuery, sourcePort).Scan(&oldLimit, &oldAccessRule, &oldAccessIPs)
if err != nil {
return fmt.Errorf("查询现有映射失败: %w", err)
}
if limit == nil && oldLimit.Valid {
limit = &oldLimit.Int64
}
if accessRule == nil && oldAccessRule.Valid {
accessRule = &oldAccessRule.String
}
if accessIPs == nil && oldAccessIPs.Valid {
accessIPs = &oldAccessIPs.String
}
// 然后更新
query := `UPDATE mappings SET bandwidth_limit = ?, access_rule = ?, access_ips = ? WHERE source_port = ?`
result, err := d.db.Exec(query, limit, accessRule, accessIPs, sourcePort)
if err != nil {
return fmt.Errorf("更新访问规则失败: %w", err)
}
rows, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("获取影响行数失败: %w", err)
}
if rows == 0 {
return fmt.Errorf("端口映射不存在")
}
return nil
}
// GetMapping 获取指定端口的映射
func (d *Database) GetMapping(sourcePort int) (*Mapping, error) {
d.mu.RLock()
defer d.mu.RUnlock()
query := `SELECT id, source_port, target_host, target_port, use_tunnel, bandwidth_limit, created_at FROM mappings WHERE source_port = ?`
query := `SELECT id, source_port, target_host, target_port, use_tunnel, bandwidth_limit, access_rule, access_ips, created_at FROM mappings WHERE source_port = ?`
var mapping Mapping
err := d.db.QueryRow(query, sourcePort).Scan(
@ -284,6 +358,8 @@ func (d *Database) GetMapping(sourcePort int) (*Mapping, error) {
&mapping.TargetPort,
&mapping.UseTunnel,
&mapping.BandwidthLimit,
&mapping.AccessRule,
&mapping.AccessIPs,
&mapping.CreatedAt,
)
@ -302,7 +378,7 @@ func (d *Database) GetAllMappings() ([]*Mapping, error) {
d.mu.RLock()
defer d.mu.RUnlock()
query := `SELECT id, source_port, target_host, target_port, use_tunnel, bandwidth_limit, created_at FROM mappings ORDER BY source_port`
query := `SELECT id, source_port, target_host, target_port, use_tunnel, bandwidth_limit, access_rule, access_ips, created_at FROM mappings ORDER BY source_port`
rows, err := d.db.Query(query)
if err != nil {
@ -320,6 +396,8 @@ func (d *Database) GetAllMappings() ([]*Mapping, error) {
&mapping.TargetPort,
&mapping.UseTunnel,
&mapping.BandwidthLimit,
&mapping.AccessRule,
&mapping.AccessIPs,
&mapping.CreatedAt,
); err != nil {
return nil, fmt.Errorf("扫描映射记录失败: %w", err)

View File

@ -18,7 +18,7 @@ func TestDatabase(t *testing.T) {
defer db.Close()
t.Run("添加映射", func(t *testing.T) {
err := db.AddMapping(10001, "192.168.1.100", 22, false, nil)
err := db.AddMapping(10001, "192.168.1.100", 22, false, nil, nil, nil)
if err != nil {
t.Errorf("添加映射失败: %v", err)
}
@ -44,7 +44,7 @@ func TestDatabase(t *testing.T) {
})
t.Run("添加重复映射应该失败", func(t *testing.T) {
err := db.AddMapping(10001, "192.168.1.101", 22, true, nil)
err := db.AddMapping(10001, "192.168.1.101", 22, true, nil, nil, nil)
if err == nil {
t.Error("添加重复映射应该失败")
}
@ -52,8 +52,8 @@ func TestDatabase(t *testing.T) {
t.Run("获取所有映射", func(t *testing.T) {
// 添加更多映射
db.AddMapping(10002, "192.168.1.101", 22, true, nil)
db.AddMapping(10003, "192.168.1.102", 22, false, nil)
db.AddMapping(10002, "192.168.1.101", 22, true, nil, nil, nil)
db.AddMapping(10003, "192.168.1.102", 22, false, nil, nil, nil)
mappings, err := db.GetAllMappings()
if err != nil {
@ -102,7 +102,7 @@ func TestDatabaseConcurrency(t *testing.T) {
for i := 0; i < 10; i++ {
go func(port int) {
useTunnel := port%2 == 0 // 偶数端口使用隧道模式
err := db.AddMapping(10000+port, "192.168.1.100", port, useTunnel, nil)
err := db.AddMapping(10000+port, "192.168.1.100", port, useTunnel, nil, nil, nil)
if err != nil {
t.Logf("添加映射失败 (端口 %d): %v", 10000+port, err)
}

View File

@ -2,11 +2,13 @@ package forwarder
import (
"context"
"encoding/json"
"fmt"
"io"
"log"
"net"
"port-forward/server/stats"
"strings"
"sync"
"sync/atomic"
"time"
@ -45,6 +47,40 @@ type Forwarder struct {
// 活跃连接管理
connections map[string]*activeConnection // 活跃连接映射
connMutex sync.RWMutex // 连接映射锁
// 访问控制
accessRule string // 访问控制规则: "disabled", "whitelist", "blacklist"
accessIPs []string // IP列表CIDR或单个IP
}
func (f *Forwarder) UpdateBandwidthLimit(limit *int64) {
if limit == nil {
f.limiterOut = nil
f.limiterIn = nil
f.limit = nil
log.Printf("取消端口 %d 的带宽限制", f.sourcePort)
return
}
f.limit = limit
// burst设置为1秒的流量这样可以平滑处理突发
burst := int(*limit) / 100
if burst < 10240 {
burst = 10240 // 最小burst为10KB
}
log.Printf("更新端口 %d 的带宽限制: %d bytes/sec, burst: %d bytes", f.sourcePort, *limit, burst)
f.limiterOut = rate.NewLimiter(rate.Limit(*limit), burst)
f.limiterIn = rate.NewLimiter(rate.Limit(*limit), burst)
}
func (f *Forwarder) UpdateAccessControl(rule *string, ps *string) {
if rule == nil || *rule == "disabled" {
f.accessRule = "disabled"
f.accessIPs = nil
log.Printf("禁用端口 %d 的访问控制", f.sourcePort)
return
}
f.accessRule = *rule
f.accessIPs = parseIPList(ps)
}
// activeConnection 活跃连接信息
@ -56,8 +92,77 @@ type activeConnection struct {
connectedAt int64
}
// parseIPList 解析IP列表JSON字符串为切片
func parseIPList(ipListJSON *string) []string {
if ipListJSON == nil || *ipListJSON == "" {
return nil
}
var ips []string
if err := json.Unmarshal([]byte(*ipListJSON), &ips); err != nil {
log.Printf("解析IP列表失败: %v", err)
return nil
}
return ips
}
// checkIPAllowed 检查IP是否允许访问
func (f *Forwarder) checkIPAllowed(remoteAddr string) bool {
// 如果规则被禁用,允许所有连接
if f.accessRule == "" || f.accessRule == "disabled" {
return true
}
// 提取IP地址去除端口
host, _, err := net.SplitHostPort(remoteAddr)
if err != nil {
log.Printf("解析远程地址失败: %v", err)
return false
}
clientIP := net.ParseIP(host)
if clientIP == nil {
log.Printf("无效的IP地址: %s", host)
return false
}
// 检查IP是否在列表中
inList := false
for _, ipStr := range f.accessIPs {
// 检查是否为CIDR网段
if strings.Contains(ipStr, "/") {
_, ipNet, err := net.ParseCIDR(ipStr)
if err != nil {
log.Printf("无效的CIDR: %s", ipStr)
continue
}
if ipNet.Contains(clientIP) {
inList = true
break
}
} else {
// 单个IP地址
ip := net.ParseIP(ipStr)
if ip != nil && ip.Equal(clientIP) {
inList = true
break
}
}
}
// 根据规则类型返回结果
if f.accessRule == "whitelist" {
return inList // 白名单只允许列表中的IP
} else if f.accessRule == "blacklist" {
return !inList // 黑名单拒绝列表中的IP
}
return true
}
// NewForwarder 创建新的端口转发器
func NewForwarder(sourcePort int, targetHost string, targetPort int, limit *int64) *Forwarder {
func NewForwarder(sourcePort int, targetHost string, targetPort int, limit *int64, accessRule *string, accessIPs *string) *Forwarder {
ctx, cancel := context.WithCancel(context.Background())
var limiterOut, limiterIn *rate.Limiter
if limit != nil {
@ -71,6 +176,13 @@ func NewForwarder(sourcePort int, targetHost string, targetPort int, limit *int6
limiterOut = rate.NewLimiter(rate.Limit(*limit), burst)
limiterIn = rate.NewLimiter(rate.Limit(*limit), burst)
}
// 解析访问规则
rule := "disabled"
if accessRule != nil {
rule = *accessRule
}
return &Forwarder{
sourcePort: sourcePort,
targetPort: targetPort,
@ -82,11 +194,13 @@ func NewForwarder(sourcePort int, targetHost string, targetPort int, limit *int6
limiterOut: limiterOut,
limiterIn: limiterIn,
connections: make(map[string]*activeConnection),
accessRule: rule,
accessIPs: parseIPList(accessIPs),
}
}
// NewTunnelForwarder 创建使用隧道的端口转发器
func NewTunnelForwarder(sourcePort int, targetHost string, targetPort int, tunnelServer TunnelServer, limit *int64) *Forwarder {
func NewTunnelForwarder(sourcePort int, targetHost string, targetPort int, tunnelServer TunnelServer, limit *int64, accessRule *string, accessIPs *string) *Forwarder {
ctx, cancel := context.WithCancel(context.Background())
var limiterOut, limiterIn *rate.Limiter
if limit != nil {
@ -100,6 +214,13 @@ func NewTunnelForwarder(sourcePort int, targetHost string, targetPort int, tunne
limiterOut = rate.NewLimiter(rate.Limit(*limit), burst)
limiterIn = rate.NewLimiter(rate.Limit(*limit), burst)
}
// 解析访问规则
rule := "disabled"
if accessRule != nil {
rule = *accessRule
}
return &Forwarder{
sourcePort: sourcePort,
targetPort: targetPort,
@ -112,6 +233,8 @@ func NewTunnelForwarder(sourcePort int, targetHost string, targetPort int, tunne
limiterOut: limiterOut,
limiterIn: limiterIn,
connections: make(map[string]*activeConnection),
accessRule: rule,
accessIPs: parseIPList(accessIPs),
}
}
@ -232,6 +355,19 @@ func (f *Forwarder) handleConnection(clientConn net.Conn) {
connID := uuid.New().String()
clientAddr := clientConn.RemoteAddr().String()
// 检查IP访问控制
if !f.checkIPAllowed(clientAddr) {
log.Printf("端口 %d 拒绝连接: %s (访问规则: %s, 连接ID: %s)", f.sourcePort, clientAddr, f.accessRule, connID)
// 白名单模式:返回简单提示(便于调试配置错误)
// 黑名单模式:静默拒绝(不暴露信息给恶意访问者)
if f.accessRule == "whitelist" {
clientConn.Write([]byte("Access denied: IP not in whitelist\n"))
}
return
}
log.Printf("端口 %d 收到新连接: %s (连接ID: %s)", f.sourcePort, clientAddr, connID)
var targetConn net.Conn
@ -389,6 +525,12 @@ type Manager struct {
mu sync.RWMutex
}
func (m *Manager) GetForwarder(port int) *Forwarder {
m.mu.RLock()
defer m.mu.RUnlock()
return m.forwarders[port]
}
// NewManager 创建新的转发器管理器
func NewManager() *Manager {
return &Manager{
@ -397,7 +539,7 @@ func NewManager() *Manager {
}
// Add 添加并启动转发器
func (m *Manager) Add(sourcePort int, targetHost string, targetPort int, limit *int64) error {
func (m *Manager) Add(sourcePort int, targetHost string, targetPort int, limit *int64, accessRule *string, accessIPs *string) error {
m.mu.Lock()
defer m.mu.Unlock()
@ -405,7 +547,7 @@ func (m *Manager) Add(sourcePort int, targetHost string, targetPort int, limit *
return fmt.Errorf("端口 %d 已被占用", sourcePort)
}
forwarder := NewForwarder(sourcePort, targetHost, targetPort, limit)
forwarder := NewForwarder(sourcePort, targetHost, targetPort, limit, accessRule, accessIPs)
if err := forwarder.Start(); err != nil {
return err
}
@ -415,7 +557,7 @@ func (m *Manager) Add(sourcePort int, targetHost string, targetPort int, limit *
}
// AddTunnel 添加使用隧道的转发器
func (m *Manager) AddTunnel(sourcePort int, targetHost string, targetPort int, tunnelServer TunnelServer, limit *int64) error {
func (m *Manager) AddTunnel(sourcePort int, targetHost string, targetPort int, tunnelServer TunnelServer, limit *int64, accessRule *string, accessIPs *string) error {
m.mu.Lock()
defer m.mu.Unlock()
@ -423,7 +565,7 @@ func (m *Manager) AddTunnel(sourcePort int, targetHost string, targetPort int, t
return fmt.Errorf("端口 %d 已被占用", sourcePort)
}
forwarder := NewTunnelForwarder(sourcePort, targetHost, targetPort, tunnelServer, limit)
forwarder := NewTunnelForwarder(sourcePort, targetHost, targetPort, tunnelServer, limit, accessRule, accessIPs)
if err := forwarder.Start(); err != nil {
return err
}

View File

@ -11,7 +11,7 @@ import (
// TestNewForwarder 测试创建转发器
func TestNewForwarder(t *testing.T) {
fwd := NewForwarder(8080, "192.168.1.100", 80, nil)
fwd := NewForwarder(8080, "192.168.1.100", 80, nil, nil, nil)
if fwd == nil {
t.Fatal("创建转发器失败")
@ -58,7 +58,7 @@ func TestNewTunnelForwarder(t *testing.T) {
// 创建模拟隧道服务器
mockServer := &mockTunnelServer{connected: true}
fwd := NewTunnelForwarder(8080, "127.0.0.1", 80, mockServer, nil)
fwd := NewTunnelForwarder(8080, "127.0.0.1", 80, mockServer, nil, nil, nil)
if fwd == nil {
t.Fatal("创建隧道转发器失败")
@ -85,7 +85,7 @@ func TestForwarderStartStop(t *testing.T) {
targetPort := targetListener.Addr().(*net.TCPAddr).Port
// 启动转发器到一个随机端口
fwd := NewForwarder(0, "127.0.0.1", targetPort, nil)
fwd := NewForwarder(0, "127.0.0.1", targetPort, nil, nil, nil)
// 创建监听器
listener, err := net.Listen("tcp", "127.0.0.1:0")
@ -134,7 +134,7 @@ func TestForwarderConnection(t *testing.T) {
}()
// 创建并启动转发器
fwd := NewForwarder(0, "127.0.0.1", targetPort, nil)
fwd := NewForwarder(0, "127.0.0.1", targetPort, nil, nil, nil)
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
@ -201,7 +201,7 @@ func TestManagerAdd(t *testing.T) {
sourcePort := fwdListener.Addr().(*net.TCPAddr).Port
fwdListener.Close() // 关闭以便转发器可以使用这个端口
err = mgr.Add(sourcePort, "127.0.0.1", targetPort, nil)
err = mgr.Add(sourcePort, "127.0.0.1", targetPort, nil, nil, nil)
if err != nil {
t.Fatalf("添加转发器失败: %v", err)
}
@ -228,14 +228,14 @@ func TestManagerAddDuplicate(t *testing.T) {
listener.Close()
// 添加第一个转发器
err = mgr.Add(sourcePort, "127.0.0.1", 80, nil)
err = mgr.Add(sourcePort, "127.0.0.1", 80, nil, nil, nil)
if err != nil {
t.Fatalf("添加第一个转发器失败: %v", err)
}
defer mgr.Remove(sourcePort)
// 尝试添加重复端口
err = mgr.Add(sourcePort, "127.0.0.1", 81, nil)
err = mgr.Add(sourcePort, "127.0.0.1", 81, nil, nil, nil)
if err == nil {
t.Error("应该返回端口已占用错误")
}
@ -254,7 +254,7 @@ func TestManagerRemove(t *testing.T) {
listener.Close()
// 添加转发器
err = mgr.Add(sourcePort, "127.0.0.1", 80, nil)
err = mgr.Add(sourcePort, "127.0.0.1", 80, nil, nil, nil)
if err != nil {
t.Fatalf("添加转发器失败: %v", err)
}
@ -299,7 +299,7 @@ func TestManagerExists(t *testing.T) {
listener.Close()
// 添加转发器
err = mgr.Add(sourcePort, "127.0.0.1", 80, nil)
err = mgr.Add(sourcePort, "127.0.0.1", 80, nil, nil, nil)
if err != nil {
t.Fatalf("添加转发器失败: %v", err)
}
@ -325,7 +325,7 @@ func TestManagerStopAll(t *testing.T) {
port := listener.Addr().(*net.TCPAddr).Port
listener.Close()
err = mgr.Add(port, "127.0.0.1", 80+i, nil)
err = mgr.Add(port, "127.0.0.1", 80+i, nil, nil, nil)
if err != nil {
t.Fatalf("添加转发器 %d 失败: %v", i, err)
}
@ -345,7 +345,7 @@ func TestManagerStopAll(t *testing.T) {
// TestForwarderContextCancellation 测试上下文取消
func TestForwarderContextCancellation(t *testing.T) {
fwd := NewForwarder(0, "127.0.0.1", 80, nil)
fwd := NewForwarder(0, "127.0.0.1", 80, nil, nil, nil)
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
@ -397,7 +397,7 @@ func BenchmarkForwarderConnection(b *testing.B) {
}()
// 创建转发器
fwd := NewForwarder(0, "127.0.0.1", targetPort, nil)
fwd := NewForwarder(0, "127.0.0.1", targetPort, nil, nil, nil)
listener, _ := net.Listen("tcp", "127.0.0.1:0")
fwd.listener = listener
fwd.sourcePort = listener.Addr().(*net.TCPAddr).Port
@ -429,7 +429,7 @@ func BenchmarkManagerOperations(b *testing.B) {
port := listener.Addr().(*net.TCPAddr).Port
listener.Close()
mgr.Add(port, "127.0.0.1", 80, nil)
mgr.Add(port, "127.0.0.1", 80, nil, nil, nil)
}
})

View File

@ -760,6 +760,7 @@
<th>目标端口</th>
<th>模式</th>
<th>带宽限制</th>
<th>访问规则</th>
<th>操作</th>
</tr>
</thead>
@ -774,6 +775,17 @@
</span>
</td>
<td>{{ mapping.bandwidth_limit ? formatBytes(mapping.bandwidth_limit) + '/s' : '无限制' }}</td>
<td>
<span v-if="!mapping.access_rule || mapping.access_rule === 'disabled'" class="badge" style="background: #6c757d; color: white;">
无限制
</span>
<span v-else-if="mapping.access_rule === 'whitelist'" class="badge" style="background: #28a745; color: white;">
🛡️ 白名单
</span>
<span v-else-if="mapping.access_rule === 'blacklist'" class="badge" style="background: #dc3545; color: white;">
🚫 黑名单
</span>
</td>
<td>
<button class="btn btn-danger btn-sm" @click="deleteMapping(mapping.source_port)">
删除
@ -827,6 +839,28 @@
<small class="form-help">单位:字节/秒,例如 1048576 表示 1MB/s</small>
</div>
<div class="form-group">
<label>访问控制规则</label>
<select v-model="newMapping.accessRule">
<option value="disabled">禁用(不限制访问)</option>
<option value="whitelist">白名单仅允许列表中的IP</option>
<option value="blacklist">黑名单拒绝列表中的IP</option>
</select>
</div>
<div class="form-group" v-if="newMapping.accessRule !== 'disabled'">
<label>IP列表</label>
<textarea
v-model="newMapping.accessIPsText"
rows="4"
placeholder="每行一个IP地址或CIDR网段"
style="width: 100%; padding: 12px 16px; border: 2px solid #e0e0e0; border-radius: 8px; font-size: 1em; font-family: monospace; resize: vertical;"
></textarea>
<small class="form-help">
{{ newMapping.accessRule === 'whitelist' ? '只有这些IP可以连接' : '这些IP将被拒绝连接' }}
</small>
</div>
<div style="display: flex; gap: 10px; margin-top: 25px;">
<button type="submit" class="btn btn-primary" style="flex: 1;">创建</button>
<button type="button" class="btn" style="flex: 1; background: #6c757d; color: white;" @click="showCreateModal = false">取消</button>
@ -847,7 +881,9 @@
targetHost: '',
targetPort: '',
useTunnel: false,
bandwidthLimit: null
bandwidthLimit: null,
accessRule: 'disabled',
accessIPsText: ''
}
};
},
@ -885,6 +921,20 @@
data.bandwidth_limit = this.newMapping.bandwidthLimit;
}
// 添加访问规则
if (this.newMapping.accessRule && this.newMapping.accessRule !== 'disabled') {
data.access_rule = this.newMapping.accessRule;
// 处理IP列表
if (this.newMapping.accessIPsText) {
const ips = this.newMapping.accessIPsText
.split('\n')
.map(ip => ip.trim())
.filter(ip => ip.length > 0);
data.access_ips = JSON.stringify(ips);
}
}
const response = await axios.post('/api/mapping/create', data, {
headers: { 'X-API-Key': this.apiKey }
});
@ -924,7 +974,9 @@
targetHost: '',
targetPort: '',
useTunnel: false,
bandwidthLimit: null
bandwidthLimit: null,
accessRule: 'disabled',
accessIPsText: ''
};
},
formatBytes(bytes) {
@ -1033,6 +1085,9 @@
</span>
</div>
<div style="display: flex; gap: 10px; align-items: center;">
<button class="btn btn-sm" @click="showRuleModal(mapping)" style="background: #17a2b8; color: white;">
{{ getRuleButtonText(mapping) }}
</button>
<div class="badge" style="background: #667eea; color: white; font-size: 1em;">
{{ mapping.total_connections }} 个连接
</div>
@ -1090,6 +1145,62 @@
<p style="color: #999; margin: 0;">暂无活跃连接</p>
</div>
</div>
<!-- 规则管理模态框 -->
<div v-if="showRuleModalFlag" class="modal-overlay" @click.self="closeRuleModal">
<div class="modal">
<div class="modal-header">
<h3 class="modal-title">访问规则管理 - 端口 {{ currentMapping.source_port }}</h3>
<button class="close-btn" @click="closeRuleModal">×</button>
</div>
<form @submit.prevent="updateRule">
<div class="form-group">
<label>访问控制规则</label>
<select v-model="ruleForm.accessRule" required>
<option value="disabled">禁用(不限制访问)</option>
<option value="whitelist">白名单仅允许列表中的IP</option>
<option value="blacklist">黑名单拒绝列表中的IP</option>
</select>
<small class="form-help">选择访问控制模式</small>
</div>
<div class="form-group">
<label>带宽限制(可选)</label>
<input
type="number"
v-model.number="ruleForm.bandwidthLimit"
min="0"
placeholder="留空表示无限制"
>
<small class="form-help">单位:字节/秒,例如 1048576 表示 1MB/s。当前: {{ formatBandwidth(ruleForm.bandwidthLimit) }}</small>
</div>
<div class="form-group" v-if="ruleForm.accessRule !== 'disabled'">
<label>IP列表</label>
<textarea
v-model="ruleForm.accessIPsText"
rows="6"
placeholder="每行一个IP地址或CIDR网段&#10;例如:&#10;192.168.1.100&#10;10.0.0.0/24&#10;172.16.0.1"
style="width: 100%; padding: 12px 16px; border: 2px solid #e0e0e0; border-radius: 8px; font-size: 1em; font-family: monospace; resize: vertical; transition: all 0.3s;"
></textarea>
<small class="form-help">
{{ ruleForm.accessRule === 'whitelist' ? '只有这些IP可以连接' : '这些IP将被拒绝连接' }}
</small>
</div>
<div v-if="ruleForm.accessRule !== 'disabled' && ruleForm.accessIPsText"
style="margin-bottom: 20px; padding: 12px; background: #f0f9ff; border-radius: 8px; border-left: 4px solid #17a2b8;">
<strong>当前IP数量:</strong> {{ getIPCount() }}
</div>
<div style="display: flex; gap: 10px; margin-top: 25px;">
<button type="submit" class="btn btn-primary" style="flex: 1;">保存规则</button>
<button type="button" class="btn" style="flex: 1; background: #6c757d; color: white;" @click="closeRuleModal">取消</button>
</div>
</form>
</div>
</div>
</div>
`,
props: ['apiKey', 'refreshTrigger'],
@ -1098,7 +1209,14 @@
return {
mappings: [],
autoRefresh: false,
refreshInterval: null
refreshInterval: null,
showRuleModalFlag: false,
currentMapping: {},
ruleForm: {
accessRule: 'disabled',
accessIPsText: '',
bandwidthLimit: 0
}
};
},
computed: {
@ -1200,6 +1318,13 @@
const i = Math.floor(Math.log(bytes) / Math.log(k));
return parseFloat((bytes / Math.pow(k, i)).toFixed(2)) + ' ' + sizes[i];
},
formatBandwidth(bytes) {
if (!bytes) return '无限制';
const k = 1024;
const sizes = ['B/s', 'KB/s', 'MB/s', 'GB/s', 'TB/s'];
const i = Math.floor(Math.log(bytes) / Math.log(k));
return parseFloat((bytes / Math.pow(k, i)).toFixed(2)) + ' ' + sizes[i];
},
formatDuration(timestamp) {
if (!timestamp) return '-';
const now = Math.floor(Date.now() / 1000);
@ -1211,6 +1336,95 @@
if (duration < 3600) return Math.floor(duration / 60) + ' 分钟';
if (duration < 86400) return Math.floor(duration / 3600) + ' 小时';
return Math.floor(duration / 86400) + ' 天';
},
showRuleModal(mapping) {
this.currentMapping = mapping;
this.showRuleModalFlag = true;
// 设置当前规则
this.ruleForm.accessRule = mapping.access_rule || 'disabled';
// 设置带宽限制
this.ruleForm.bandwidthLimit = mapping.bandwidth_limit || 0;
// 解析IP列表
if (mapping.access_ips) {
try {
const ips = JSON.parse(mapping.access_ips);
this.ruleForm.accessIPsText = Array.isArray(ips) ? ips.join('\n') : '';
} catch (e) {
this.ruleForm.accessIPsText = '';
}
} else {
this.ruleForm.accessIPsText = '';
}
},
closeRuleModal() {
this.showRuleModalFlag = false;
this.currentMapping = {};
this.ruleForm = {
accessRule: 'disabled',
accessIPsText: '',
bandwidthLimit: 0
};
},
async updateRule() {
try {
const data = {
port: this.currentMapping.source_port,
access_rule: this.ruleForm.accessRule,
bandwidth_limit: this.ruleForm.bandwidthLimit || 0
};
// 如果不是禁用模式则处理IP列表
if (this.ruleForm.accessRule !== 'disabled') {
const ips = this.ruleForm.accessIPsText
.split('\n')
.map(ip => ip.trim())
.filter(ip => ip.length > 0);
data.access_ips = JSON.stringify(ips);
} else {
data.access_ips = null;
}
const response = await axios.post('/api/mapping/update', data, {
headers: { 'X-API-Key': this.apiKey }
});
if (response.data.success) {
this.$emit('notify', 'success', '规则更新成功');
this.closeRuleModal();
this.loadConnections();
}
} catch (error) {
this.$emit('notify', 'error', error.response?.data?.message || '更新规则失败');
}
},
getIPCount() {
if (!this.ruleForm.accessIPsText) return 0;
return this.ruleForm.accessIPsText
.split('\n')
.map(ip => ip.trim())
.filter(ip => ip.length > 0).length;
},
getRuleButtonText(mapping) {
if (!mapping.access_rule || mapping.access_rule === 'disabled') {
return '🛡️ 规则管理';
}
let ipCount = 0;
if (mapping.access_ips) {
try {
const ips = JSON.parse(mapping.access_ips);
ipCount = Array.isArray(ips) ? ips.length : 0;
} catch (e) {
ipCount = 0;
}
}
const ruleText = mapping.access_rule === 'whitelist' ? '白名单' : '黑名单';
return `🛡️ ${ruleText} (${ipCount}个IP)`;
}
}
};
@ -1262,8 +1476,15 @@
},
mounted() {
this.checkAuth();
this.initPageFromHash();
this.loadStats();
setInterval(() => this.loadStats(), 5000);
// 监听hash变化
window.addEventListener('hashchange', this.handleHashChange);
},
beforeUnmount() {
window.removeEventListener('hashchange', this.handleHashChange);
},
methods: {
checkAuth() {
@ -1274,6 +1495,27 @@
},
changePage(page) {
this.currentPage = page;
// 更新URL hash
window.location.hash = page;
},
initPageFromHash() {
// 从URL hash读取当前页面
const hash = window.location.hash.slice(1); // 去掉 # 号
const validPages = this.menuItems.map(item => item.id);
if (hash && validPages.includes(hash)) {
this.currentPage = hash;
} else {
// 默认页面设置hash
window.location.hash = this.currentPage;
}
},
handleHashChange() {
// 处理hash变化
const hash = window.location.hash.slice(1);
const validPages = this.menuItems.map(item => item.id);
if (hash && validPages.includes(hash)) {
this.currentPage = hash;
}
},
async loadStats() {
try {

View File

@ -83,10 +83,10 @@ func (s *serverService) Start() error {
log.Printf("警告: 端口 %d 需要隧道模式但隧道服务未启用,跳过", mapping.SourcePort)
continue
}
err = s.fwdManager.AddTunnel(mapping.SourcePort, mapping.TargetHost, mapping.TargetPort, s.tunnelServer, mapping.BandwidthLimit)
err = s.fwdManager.AddTunnel(mapping.SourcePort, mapping.TargetHost, mapping.TargetPort, s.tunnelServer, mapping.BandwidthLimit, mapping.AccessRule, mapping.AccessIPs)
} else {
// 直接模式
err = s.fwdManager.Add(mapping.SourcePort, mapping.TargetHost, mapping.TargetPort, mapping.BandwidthLimit)
err = s.fwdManager.Add(mapping.SourcePort, mapping.TargetHost, mapping.TargetPort, mapping.BandwidthLimit, mapping.AccessRule, mapping.AccessIPs)
}
if err != nil {

View File

@ -35,12 +35,15 @@ type ConnectionInfo struct {
// PortConnectionStats 端口连接统计
type PortConnectionStats struct {
SourcePort int `json:"source_port"` // 源端口
TargetHost string `json:"target_host"` // 目标主机
TargetPort int `json:"target_port"` // 目标端口
UseTunnel bool `json:"use_tunnel"` // 是否使用隧道
ActiveConnections []ConnectionInfo `json:"active_connections"` // 活跃连接列表
TotalConnections int `json:"total_connections"` // 总连接数
SourcePort int `json:"source_port"` // 源端口
TargetHost string `json:"target_host"` // 目标主机
TargetPort int `json:"target_port"` // 目标端口
UseTunnel bool `json:"use_tunnel"` // 是否使用隧道
ActiveConnections []ConnectionInfo `json:"active_connections"` // 活跃连接列表
TotalConnections int `json:"total_connections"` // 总连接数
BandwidthLimit *int64 `json:"bandwidth_limit,omitempty"` // 带宽限制
AccessRule *string `json:"access_rule,omitempty"` // 访问控制规则
AccessIPs *string `json:"access_ips,omitempty"` // 访问控制IP列表
}
// AllConnectionsStats 所有连接统计