feat: 去除port范围限制,在添加时动态检查操作系统内的占用情况
This commit is contained in:
parent
a30a7e38b4
commit
35f7eaf3c8
|
|
@ -10,6 +10,7 @@ import (
|
|||
"port-forward/server/forwarder"
|
||||
"port-forward/server/stats"
|
||||
"port-forward/server/tunnel"
|
||||
"port-forward/server/utils"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
|
@ -24,13 +25,13 @@ type Handler struct {
|
|||
}
|
||||
|
||||
// NewHandler 创建新的 API 处理器
|
||||
func NewHandler(database *db.Database, fwdMgr *forwarder.Manager, ts *tunnel.Server, portFrom, portEnd int) *Handler {
|
||||
func NewHandler(database *db.Database, fwdMgr *forwarder.Manager, ts *tunnel.Server) *Handler {
|
||||
return &Handler{
|
||||
db: database,
|
||||
forwarderMgr: fwdMgr,
|
||||
tunnelServer: ts,
|
||||
portRangeFrom: portFrom,
|
||||
portRangeEnd: portEnd,
|
||||
// portRangeFrom: portFrom,
|
||||
// portRangeEnd: portEnd,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -109,6 +110,12 @@ func (h *Handler) handleCreateMapping(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
used := utils.PortCheck(req.SourcePort)
|
||||
if used {
|
||||
h.writeError(w, http.StatusConflict, "端口已被占用")
|
||||
return
|
||||
}
|
||||
|
||||
// 根据请求决定使用哪种模式
|
||||
if req.UseTunnel {
|
||||
// 隧道模式,检查隧道服务器是否可用
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ func setupTestHandler(t *testing.T, useTunnel bool) (*Handler, *db.Database, fun
|
|||
tunnelServer.Start()
|
||||
}
|
||||
|
||||
handler := NewHandler(database, fwdMgr, tunnelServer, 10000, 20000)
|
||||
handler := NewHandler(database, fwdMgr, tunnelServer)
|
||||
|
||||
cleanup := func() {
|
||||
fwdMgr.StopAll()
|
||||
|
|
@ -543,7 +543,7 @@ func BenchmarkHandleHealth(b *testing.B) {
|
|||
defer database.Close()
|
||||
|
||||
fwdMgr := forwarder.NewManager()
|
||||
handler := NewHandler(database, fwdMgr, nil, 10000, 20000)
|
||||
handler := NewHandler(database, fwdMgr, nil)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/health", nil)
|
||||
|
||||
|
|
@ -568,7 +568,7 @@ func BenchmarkHandleListMappings(b *testing.B) {
|
|||
}
|
||||
|
||||
fwdMgr := forwarder.NewManager()
|
||||
handler := NewHandler(database, fwdMgr, nil, 10000, 20000)
|
||||
handler := NewHandler(database, fwdMgr, nil)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/mapping/list", nil)
|
||||
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ import (
|
|||
|
||||
// Config 应用配置结构
|
||||
type Config struct {
|
||||
PortRange PortRangeConfig `yaml:"port_range"`
|
||||
// PortRange PortRangeConfig `yaml:"port_range"`
|
||||
Tunnel TunnelConfig `yaml:"tunnel"`
|
||||
API APIConfig `yaml:"api"`
|
||||
Database DatabaseConfig `yaml:"database"`
|
||||
|
|
@ -59,15 +59,15 @@ func Load(path string) (*Config, error) {
|
|||
|
||||
// Validate 验证配置的有效性
|
||||
func (c *Config) Validate() error {
|
||||
if c.PortRange.From <= 0 || c.PortRange.End <= 0 {
|
||||
return fmt.Errorf("端口范围必须大于 0")
|
||||
}
|
||||
if c.PortRange.From > c.PortRange.End {
|
||||
return fmt.Errorf("起始端口不能大于结束端口")
|
||||
}
|
||||
if c.PortRange.End-c.PortRange.From > 30000 {
|
||||
return fmt.Errorf("端口范围过大,最多支持 30000 个端口")
|
||||
}
|
||||
// if c.PortRange.From <= 0 || c.PortRange.End <= 0 {
|
||||
// return fmt.Errorf("端口范围必须大于 0")
|
||||
// }
|
||||
// if c.PortRange.From > c.PortRange.End {
|
||||
// return fmt.Errorf("起始端口不能大于结束端口")
|
||||
// }
|
||||
// if c.PortRange.End-c.PortRange.From > 30000 {
|
||||
// return fmt.Errorf("端口范围过大,最多支持 30000 个端口")
|
||||
// }
|
||||
if c.Tunnel.Enabled && c.Tunnel.ListenPort <= 0 {
|
||||
return fmt.Errorf("内网穿透端口必须大于 0")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -40,12 +40,12 @@ database:
|
|||
}
|
||||
|
||||
// 验证配置
|
||||
if cfg.PortRange.From != 10000 {
|
||||
t.Errorf("期望起始端口为 10000,得到 %d", cfg.PortRange.From)
|
||||
}
|
||||
if cfg.PortRange.End != 10100 {
|
||||
t.Errorf("期望结束端口为 10100,得到 %d", cfg.PortRange.End)
|
||||
}
|
||||
// if cfg.PortRange.From != 10000 {
|
||||
// t.Errorf("期望起始端口为 10000,得到 %d", cfg.PortRange.From)
|
||||
// }
|
||||
// if cfg.PortRange.End != 10100 {
|
||||
// t.Errorf("期望结束端口为 10100,得到 %d", cfg.PortRange.End)
|
||||
// }
|
||||
if !cfg.Tunnel.Enabled {
|
||||
t.Error("期望隧道启用")
|
||||
}
|
||||
|
|
@ -111,14 +111,14 @@ database:
|
|||
tmpFile.Write([]byte(configContent))
|
||||
tmpFile.Close()
|
||||
|
||||
cfg, err := Load(tmpFile.Name())
|
||||
if err != nil {
|
||||
t.Errorf("边界值配置应该有效: %v", err)
|
||||
}
|
||||
// cfg, err := Load(tmpFile.Name())
|
||||
// if err != nil {
|
||||
// t.Errorf("边界值配置应该有效: %v", err)
|
||||
// }
|
||||
|
||||
if cfg != nil && (cfg.PortRange.End-cfg.PortRange.From) != 9999 {
|
||||
t.Errorf("端口范围计算不正确")
|
||||
}
|
||||
// if cfg != nil && (cfg.PortRange.End-cfg.PortRange.From) != 9999 {
|
||||
// t.Errorf("端口范围计算不正确")
|
||||
// }
|
||||
}
|
||||
|
||||
// BenchmarkLoadConfig 基准测试配置加载
|
||||
|
|
@ -150,7 +150,7 @@ database:
|
|||
// BenchmarkValidateConfig 基准测试配置验证
|
||||
func BenchmarkValidateConfig(b *testing.B) {
|
||||
cfg := &Config{
|
||||
PortRange: PortRangeConfig{From: 10000, End: 20000},
|
||||
// PortRange: PortRangeConfig{From: 10000, End: 20000},
|
||||
Tunnel: TunnelConfig{Enabled: true, ListenPort: 9000},
|
||||
API: APIConfig{ListenPort: 8080},
|
||||
Database: DatabaseConfig{Path: "./data/db.sqlite"},
|
||||
|
|
@ -170,7 +170,7 @@ func TestValidateConfig(t *testing.T) {
|
|||
{
|
||||
name: "有效配置",
|
||||
config: Config{
|
||||
PortRange: PortRangeConfig{From: 10000, End: 10100},
|
||||
// PortRange: PortRangeConfig{From: 10000, End: 10100},
|
||||
Tunnel: TunnelConfig{Enabled: true, ListenPort: 9000},
|
||||
API: APIConfig{ListenPort: 8080},
|
||||
Database: DatabaseConfig{Path: "./data/test.db"},
|
||||
|
|
@ -180,7 +180,7 @@ func TestValidateConfig(t *testing.T) {
|
|||
{
|
||||
name: "无效端口范围 - 起始端口为0",
|
||||
config: Config{
|
||||
PortRange: PortRangeConfig{From: 0, End: 10100},
|
||||
// PortRange: PortRangeConfig{From: 0, End: 10100},
|
||||
Tunnel: TunnelConfig{Enabled: false, ListenPort: 0},
|
||||
API: APIConfig{ListenPort: 8080},
|
||||
Database: DatabaseConfig{Path: "./data/test.db"},
|
||||
|
|
@ -190,7 +190,7 @@ func TestValidateConfig(t *testing.T) {
|
|||
{
|
||||
name: "无效端口范围 - 起始大于结束",
|
||||
config: Config{
|
||||
PortRange: PortRangeConfig{From: 10100, End: 10000},
|
||||
// PortRange: PortRangeConfig{From: 10100, End: 10000},
|
||||
Tunnel: TunnelConfig{Enabled: false, ListenPort: 0},
|
||||
API: APIConfig{ListenPort: 8080},
|
||||
Database: DatabaseConfig{Path: "./data/test.db"},
|
||||
|
|
@ -200,7 +200,7 @@ func TestValidateConfig(t *testing.T) {
|
|||
{
|
||||
name: "端口范围过大",
|
||||
config: Config{
|
||||
PortRange: PortRangeConfig{From: 1, End: 40000},
|
||||
// PortRange: PortRangeConfig{From: 1, End: 40000},
|
||||
Tunnel: TunnelConfig{Enabled: false, ListenPort: 0},
|
||||
API: APIConfig{ListenPort: 8080},
|
||||
Database: DatabaseConfig{Path: "./data/test.db"},
|
||||
|
|
@ -210,7 +210,7 @@ func TestValidateConfig(t *testing.T) {
|
|||
{
|
||||
name: "启用隧道但端口无效",
|
||||
config: Config{
|
||||
PortRange: PortRangeConfig{From: 10000, End: 10100},
|
||||
// PortRange: PortRangeConfig{From: 10000, End: 10100},
|
||||
Tunnel: TunnelConfig{Enabled: true, ListenPort: 0},
|
||||
API: APIConfig{ListenPort: 8080},
|
||||
Database: DatabaseConfig{Path: "./data/test.db"},
|
||||
|
|
@ -220,7 +220,7 @@ func TestValidateConfig(t *testing.T) {
|
|||
{
|
||||
name: "数据库路径为空",
|
||||
config: Config{
|
||||
PortRange: PortRangeConfig{From: 10000, End: 10100},
|
||||
// PortRange: PortRangeConfig{From: 10000, End: 10100},
|
||||
Tunnel: TunnelConfig{Enabled: false, ListenPort: 0},
|
||||
API: APIConfig{ListenPort: 8080},
|
||||
Database: DatabaseConfig{Path: ""},
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ import (
|
|||
"port-forward/server/db"
|
||||
"port-forward/server/forwarder"
|
||||
"port-forward/server/tunnel"
|
||||
"port-forward/server/utils"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
|
@ -60,8 +61,16 @@ func main() {
|
|||
|
||||
for _, mapping := range mappings {
|
||||
// 验证端口在范围内
|
||||
if mapping.SourcePort < cfg.PortRange.From || mapping.SourcePort > cfg.PortRange.End {
|
||||
log.Printf("警告: 端口 %d 超出范围,跳过", mapping.SourcePort)
|
||||
// if mapping.SourcePort < cfg.PortRange.From || mapping.SourcePort > cfg.PortRange.End {
|
||||
// log.Printf("警告: 端口 %d 超出范围,跳过", mapping.SourcePort)
|
||||
// continue
|
||||
// }
|
||||
|
||||
used := utils.PortCheck(mapping.SourcePort)
|
||||
|
||||
if used {
|
||||
log.Printf("警告: 端口 %d 已被占!", mapping.SourcePort)
|
||||
os.Exit(1)
|
||||
continue
|
||||
}
|
||||
|
||||
|
|
@ -91,8 +100,8 @@ func main() {
|
|||
database,
|
||||
fwdManager,
|
||||
tunnelServer,
|
||||
cfg.PortRange.From,
|
||||
cfg.PortRange.End,
|
||||
// cfg.PortRange.From,
|
||||
// cfg.PortRange.End,
|
||||
)
|
||||
|
||||
// 启动 HTTP API 服务器
|
||||
|
|
@ -123,7 +132,7 @@ func main() {
|
|||
|
||||
log.Println("===========================================")
|
||||
log.Printf("服务器启动成功!")
|
||||
log.Printf("端口范围: %d-%d", cfg.PortRange.From, cfg.PortRange.End)
|
||||
// log.Printf("端口范围: %d-%d", cfg.PortRange.From, cfg.PortRange.End)
|
||||
log.Printf("HTTP API: http://localhost:%d", cfg.API.ListenPort)
|
||||
// log.Printf("调试接口: http://localhost:%d/debug/pprof/", pprofPort)
|
||||
if cfg.Tunnel.Enabled {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,18 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// PortCheck 检查端口是否可用,可用-true 不可用-false
|
||||
func PortCheck(port int) bool {
|
||||
l, err := net.Listen("tcp", fmt.Sprintf(":%s", strconv.Itoa(port)))
|
||||
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer l.Close()
|
||||
return true
|
||||
}
|
||||
Loading…
Reference in New Issue