diff --git a/API_AUTH.md b/API_AUTH.md new file mode 100644 index 0000000..ae661e1 --- /dev/null +++ b/API_AUTH.md @@ -0,0 +1,136 @@ +# API 认证说明 + +## 概述 + +所有 API 接口(除了 `/health` 健康检查接口)都需要提供有效的 API 密钥才能访问。 + +## 配置 API 密钥 + +在配置文件 `config.yaml` 中设置 API 密钥: + +```yaml +# HTTP API 配置 +api: + listen_port: 8080 + api_key: "your-secret-api-key-here" # 修改为你的密钥 +``` + +**安全建议:** +- 使用强密码生成器生成复杂的 API 密钥 +- 定期更换 API 密钥 +- 不要在公共代码库中提交包含真实密钥的配置文件 + +## 使用 API 密钥 + +### 方式 1: 通过 HTTP 请求头(推荐) + +在请求头中添加 `X-API-Key`: + +```bash +curl -X POST http://localhost:8080/api/mapping/create \ + -H "Content-Type: application/json" \ + -H "X-API-Key: your-secret-api-key-here" \ + -d '{ + "source_port": 30001, + "target_host": "192.168.1.100", + "target_port": 3306, + "use_tunnel": false + }' +``` + +### 方式 2: 通过 URL 查询参数 + +在 URL 中添加 `api_key` 参数: + +```bash +curl -X POST "http://localhost:8080/api/mapping/create?api_key=your-secret-api-key-here" \ + -H "Content-Type: application/json" \ + -d '{ + "source_port": 30001, + "target_host": "192.168.1.100", + "target_port": 3306, + "use_tunnel": false + }' +``` + +### 使用 PowerShell + +```powershell +$headers = @{ + "Content-Type" = "application/json" + "X-API-Key" = "your-secret-api-key-here" +} + +$body = @{ + source_port = 30001 + target_host = "192.168.1.100" + target_port = 3306 + use_tunnel = $false +} | ConvertTo-Json + +Invoke-RestMethod -Uri "http://localhost:8080/api/mapping/create" ` + -Method Post ` + -Headers $headers ` + -Body $body +``` + +### 使用 Python + +```python +import requests + +headers = { + 'Content-Type': 'application/json', + 'X-API-Key': 'your-secret-api-key-here' +} + +data = { + 'source_port': 30001, + 'target_host': '192.168.1.100', + 'target_port': 3306, + 'use_tunnel': False +} + +response = requests.post( + 'http://localhost:8080/api/mapping/create', + headers=headers, + json=data +) + +print(response.json()) +``` + +## 需要认证的 API 接口 + +以下接口都需要提供有效的 API 密钥: + +- `POST /api/mapping/create` - 创建端口映射 +- `POST /api/mapping/remove` - 删除端口映射 +- `GET /api/mapping/list` - 列出所有映射 +- `GET /api/stats/traffic` - 获取流量统计 +- `GET /api/stats/monitor` - 流量监控页面 +- `GET /admin` - 管理页面 + +## 不需要认证的接口 + +- `GET /health` - 健康检查接口(公开访问) + +## 错误响应 + +如果 API 密钥无效或缺失,服务器将返回 401 状态码: + +```json +{ + "success": false, + "message": "无效的 API 密钥" +} +``` + +## 浏览器访问 + +对于需要通过浏览器访问的页面(如 `/admin` 和 `/api/stats/monitor`),可以在 URL 中添加 `api_key` 参数: + +``` +http://localhost:8080/admin?api_key=your-secret-api-key-here +http://localhost:8080/api/stats/monitor?api_key=your-secret-api-key-here +``` diff --git a/src/config.yaml b/src/config.yaml index 8514b18..81dfef5 100644 --- a/src/config.yaml +++ b/src/config.yaml @@ -13,6 +13,7 @@ tunnel: # HTTP API 配置 api: listen_port: 8080 + api_key: "your-secret-api-key-here" # 数据库配置 database: diff --git a/src/server/api/api.go b/src/server/api/api.go index 3fcce0d..6a9c573 100644 --- a/src/server/api/api.go +++ b/src/server/api/api.go @@ -17,19 +17,21 @@ import ( // Handler HTTP API 处理器 type Handler struct { - db *db.Database - forwarderMgr *forwarder.Manager - tunnelServer *tunnel.Server + db *db.Database + forwarderMgr *forwarder.Manager + tunnelServer *tunnel.Server + apiKey string // portRangeFrom int // portRangeEnd int } // NewHandler 创建新的 API 处理器 -func NewHandler(database *db.Database, fwdMgr *forwarder.Manager, ts *tunnel.Server) *Handler { +func NewHandler(database *db.Database, fwdMgr *forwarder.Manager, ts *tunnel.Server, apiKey string) *Handler { return &Handler{ - db: database, - forwarderMgr: fwdMgr, - tunnelServer: ts, + db: database, + forwarderMgr: fwdMgr, + tunnelServer: ts, + apiKey: apiKey, // portRangeFrom: portFrom, // portRangeEnd: portEnd, } @@ -57,15 +59,37 @@ type Response struct { // RegisterRoutes 注册路由 func (h *Handler) RegisterRoutes(mux *http.ServeMux) { - mux.HandleFunc("/api/mapping/create", h.handleCreateMapping) - mux.HandleFunc("/api/mapping/remove", h.handleRemoveMapping) - mux.HandleFunc("/api/mapping/list", h.handleListMappings) - mux.HandleFunc("/api/stats/traffic", h.handleGetTrafficStats) - mux.HandleFunc("/api/stats/monitor", h.handleTrafficMonitor) - mux.HandleFunc("/admin", h.handleManagement) + mux.HandleFunc("/api/mapping/create", h.authMiddleware(h.handleCreateMapping)) + mux.HandleFunc("/api/mapping/remove", h.authMiddleware(h.handleRemoveMapping)) + mux.HandleFunc("/api/mapping/list", h.authMiddleware(h.handleListMappings)) + mux.HandleFunc("/api/stats/traffic", h.authMiddleware(h.handleGetTrafficStats)) + mux.HandleFunc("/api/stats/monitor", h.authMiddleware(h.handleTrafficMonitor)) + mux.HandleFunc("/admin", h.authMiddleware(h.handleManagement)) mux.HandleFunc("/health", h.handleHealth) } +// authMiddleware 认证中间件 +func (h *Handler) authMiddleware(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + // 从请求头中获取 API Key + apiKey := r.Header.Get("X-API-Key") + + // 如果请求头中没有,尝试从查询参数中获取 + if apiKey == "" { + apiKey = r.URL.Query().Get("api_key") + } + + // 验证 API Key + if apiKey != h.apiKey { + h.writeError(w, http.StatusUnauthorized, "无效的 API 密钥") + return + } + + // 认证通过,继续处理请求 + next(w, r) + } +} + // validateHostOrIP 验证主机名或IP地址 func (h *Handler) validateHostOrIP(hostOrIP string) error { if hostOrIP == "" { @@ -156,7 +180,7 @@ func (h *Handler) handleCreateMapping(w http.ResponseWriter, r *http.Request) { // 直接模式:直接TCP转发 err = h.forwarderMgr.Add(req.SourcePort, req.TargetHost, req.TargetPort) } - + if err != nil { // 回滚数据库操作 h.db.RemoveMapping(req.SourcePort) @@ -305,11 +329,11 @@ func (h *Handler) handleGetTrafficStats(w http.ResponseWriter, r *http.Request) // 获取所有端口映射的流量统计 forwarderStats := h.forwarderMgr.GetAllTrafficStats() - + // 构建响应 mappings := make([]stats.PortTrafficStats, 0, len(forwarderStats)) var totalSent, totalReceived uint64 - + for port, stat := range forwarderStats { mappings = append(mappings, stats.PortTrafficStats{ Port: port, @@ -320,11 +344,11 @@ func (h *Handler) handleGetTrafficStats(w http.ResponseWriter, r *http.Request) totalSent += stat.BytesSent totalReceived += stat.BytesReceived } - + // 加上隧道的流量 totalSent += tunnelStats.BytesSent totalReceived += tunnelStats.BytesReceived - + response := stats.AllTrafficStats{ Tunnel: tunnelStats, Mappings: mappings, @@ -346,4 +370,4 @@ func (h *Handler) handleTrafficMonitor(w http.ResponseWriter, r *http.Request) { func (h *Handler) handleManagement(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/html; charset=utf-8") fmt.Fprint(w, managementHTML) -} \ No newline at end of file +} diff --git a/src/server/api/api_test.go b/src/server/api/api_test.go index 4799b41..23f1c97 100644 --- a/src/server/api/api_test.go +++ b/src/server/api/api_test.go @@ -14,20 +14,22 @@ import ( "testing" ) +const testAPIKey = "test-api-key-12345" + // setupTestHandler 创建测试用的 Handler func setupTestHandler(t *testing.T, useTunnel bool) (*Handler, *db.Database, func()) { // 创建临时数据库 tmpDir := t.TempDir() dbPath := filepath.Join(tmpDir, "test.db") - + database, err := db.New(dbPath) if err != nil { t.Fatalf("创建数据库失败: %v", err) } - + // 创建转发器管理器 fwdMgr := forwarder.NewManager() - + // 创建隧道服务器(如果启用) var tunnelServer *tunnel.Server if useTunnel { @@ -35,13 +37,13 @@ func setupTestHandler(t *testing.T, useTunnel bool) (*Handler, *db.Database, fun listener, _ := net.Listen("tcp", "127.0.0.1:0") port := listener.Addr().(*net.TCPAddr).Port listener.Close() - + tunnelServer = tunnel.NewServer(port) tunnelServer.Start() } - - handler := NewHandler(database, fwdMgr, tunnelServer) - + + handler := NewHandler(database, fwdMgr, tunnelServer, testAPIKey) + cleanup := func() { fwdMgr.StopAll() if tunnelServer != nil { @@ -50,23 +52,28 @@ func setupTestHandler(t *testing.T, useTunnel bool) (*Handler, *db.Database, fun database.Close() os.RemoveAll(tmpDir) } - + return handler, database, cleanup } +// addAuthHeader 添加认证头到请求 +func addAuthHeader(req *http.Request) { + req.Header.Set("X-API-Key", testAPIKey) +} + // TestNewHandler 测试创建处理器 func TestNewHandler(t *testing.T) { handler, _, cleanup := setupTestHandler(t, false) defer cleanup() - + if handler == nil { t.Fatal("创建处理器失败") } - + // if handler.portRangeFrom != 10000 { // t.Errorf("起始端口不正确,期望 10000,得到 %d", handler.portRangeFrom) // } - + // if handler.portRangeEnd != 20000 { // t.Errorf("结束端口不正确,期望 20000,得到 %d", handler.portRangeEnd) // } @@ -76,22 +83,22 @@ func TestNewHandler(t *testing.T) { func TestHandleHealth(t *testing.T) { handler, _, cleanup := setupTestHandler(t, false) defer cleanup() - + req := httptest.NewRequest(http.MethodGet, "/health", nil) w := httptest.NewRecorder() - + handler.handleHealth(w, req) - + if w.Code != http.StatusOK { t.Errorf("状态码不正确,期望 200,得到 %d", w.Code) } - + var result map[string]interface{} err := json.NewDecoder(w.Body).Decode(&result) if err != nil { t.Fatalf("解析响应失败: %v", err) } - + if result["status"] != "ok" { t.Errorf("健康状态不正确,期望 ok,得到 %v", result["status"]) } @@ -101,23 +108,23 @@ func TestHandleHealth(t *testing.T) { func TestHandleHealthWithTunnel(t *testing.T) { handler, _, cleanup := setupTestHandler(t, true) defer cleanup() - + req := httptest.NewRequest(http.MethodGet, "/health", nil) w := httptest.NewRecorder() - + handler.handleHealth(w, req) - + if w.Code != http.StatusOK { t.Errorf("状态码不正确,期望 200,得到 %d", w.Code) } - + var result map[string]interface{} json.NewDecoder(w.Body).Decode(&result) - + if result["tunnel_enabled"] != true { t.Error("隧道应该启用") } - + // 隧道未连接客户端时应该为 false if result["tunnel_connected"] != false { t.Error("隧道应该未连接") @@ -128,41 +135,42 @@ func TestHandleHealthWithTunnel(t *testing.T) { func TestHandleCreateMapping(t *testing.T) { handler, database, cleanup := setupTestHandler(t, false) defer cleanup() - + reqBody := CreateMappingRequest{ // Port: 15000, SourcePort: 15000, TargetPort: 15000, TargetHost: "192.168.1.100", } - + body, _ := json.Marshal(reqBody) req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body)) + addAuthHeader(req) w := httptest.NewRecorder() - + handler.handleCreateMapping(w, req) - + if w.Code != http.StatusOK { t.Errorf("状态码不正确,期望 200,得到 %d", w.Code) } - + var result Response json.NewDecoder(w.Body).Decode(&result) - + if !result.Success { t.Errorf("创建映射失败: %s", result.Message) } - + // 验证数据库中存在映射 mapping, err := database.GetMapping(15000) if err != nil { t.Fatalf("获取映射失败: %v", err) } - + if mapping == nil { t.Fatal("映射不存在") } - + if mapping.TargetHost != "192.168.1.100" { t.Errorf("目标 IP 不正确,期望 192.168.1.100,得到 %s", mapping.TargetHost) } @@ -172,7 +180,7 @@ func TestHandleCreateMapping(t *testing.T) { func TestHandleCreateMappingInvalidPort(t *testing.T) { handler, _, cleanup := setupTestHandler(t, false) defer cleanup() - + tests := []struct { name string port int @@ -181,7 +189,7 @@ func TestHandleCreateMappingInvalidPort(t *testing.T) { {"端口太大", 25000}, {"端口为0", 0}, } - + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { reqBody := CreateMappingRequest{ @@ -189,13 +197,14 @@ func TestHandleCreateMappingInvalidPort(t *testing.T) { TargetPort: tt.port, TargetHost: "192.168.1.100", } - + body, _ := json.Marshal(reqBody) req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body)) + addAuthHeader(req) w := httptest.NewRecorder() - + handler.handleCreateMapping(w, req) - + if w.Code != http.StatusBadRequest { t.Errorf("状态码不正确,期望 400,得到 %d", w.Code) } @@ -207,30 +216,32 @@ func TestHandleCreateMappingInvalidPort(t *testing.T) { func TestHandleCreateMappingDuplicate(t *testing.T) { handler, _, cleanup := setupTestHandler(t, false) defer cleanup() - + reqBody := CreateMappingRequest{ // Port: 15000, SourcePort: 15000, TargetPort: 15000, TargetHost: "192.168.1.100", } - + // 第一次创建 body, _ := json.Marshal(reqBody) req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body)) + addAuthHeader(req) w := httptest.NewRecorder() handler.handleCreateMapping(w, req) - + if w.Code != http.StatusOK { t.Fatalf("第一次创建失败") } - + // 第二次创建(应该失败) body, _ = json.Marshal(reqBody) req = httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body)) + addAuthHeader(req) w = httptest.NewRecorder() handler.handleCreateMapping(w, req) - + if w.Code != http.StatusConflict { t.Errorf("状态码不正确,期望 409,得到 %d", w.Code) } @@ -240,12 +251,13 @@ func TestHandleCreateMappingDuplicate(t *testing.T) { func TestHandleCreateMappingInvalidJSON(t *testing.T) { handler, _, cleanup := setupTestHandler(t, false) defer cleanup() - + req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader([]byte("invalid json"))) + addAuthHeader(req) w := httptest.NewRecorder() - + handler.handleCreateMapping(w, req) - + if w.Code != http.StatusBadRequest { t.Errorf("状态码不正确,期望 400,得到 %d", w.Code) } @@ -255,20 +267,21 @@ func TestHandleCreateMappingInvalidJSON(t *testing.T) { func TestHandleCreateMappingInvalidIP(t *testing.T) { handler, _, cleanup := setupTestHandler(t, false) defer cleanup() - + reqBody := CreateMappingRequest{ // Port: 15000, SourcePort: 15000, TargetPort: 15000, TargetHost: "", // 使用空字符串而不是无效域名,避免DNS查询超时 } - + body, _ := json.Marshal(reqBody) req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body)) + addAuthHeader(req) w := httptest.NewRecorder() - + handler.handleCreateMapping(w, req) - + if w.Code != http.StatusBadRequest { t.Errorf("状态码不正确,期望 400,得到 %d", w.Code) } @@ -278,20 +291,21 @@ func TestHandleCreateMappingInvalidIP(t *testing.T) { func TestHandleCreateMappingEmptyIP(t *testing.T) { handler, _, cleanup := setupTestHandler(t, false) defer cleanup() - + reqBody := CreateMappingRequest{ // Port: 15000, SourcePort: 15000, TargetPort: 15000, TargetHost: "", } - + body, _ := json.Marshal(reqBody) req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body)) + addAuthHeader(req) w := httptest.NewRecorder() - + handler.handleCreateMapping(w, req) - + if w.Code != http.StatusBadRequest { t.Errorf("状态码不正确,期望 400,得到 %d", w.Code) } @@ -301,20 +315,21 @@ func TestHandleCreateMappingEmptyIP(t *testing.T) { func TestHandleCreateMappingTunnelNotConnected(t *testing.T) { handler, _, cleanup := setupTestHandler(t, true) defer cleanup() - + reqBody := CreateMappingRequest{ // Port: 15000, SourcePort: 15000, TargetPort: 15000, - UseTunnel: true, // 明确指定使用隧道模式 + UseTunnel: true, // 明确指定使用隧道模式 } - + body, _ := json.Marshal(reqBody) req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body)) + addAuthHeader(req) w := httptest.NewRecorder() - + handler.handleCreateMapping(w, req) - + if w.Code != http.StatusServiceUnavailable { t.Errorf("状态码不正确,期望 503,得到 %d", w.Code) } @@ -324,25 +339,25 @@ func TestHandleCreateMappingTunnelNotConnected(t *testing.T) { func TestHandleRemoveMapping(t *testing.T) { handler, database, cleanup := setupTestHandler(t, false) defer cleanup() - + // 先创建一个映射 database.AddMapping(15000, "192.168.1.100", 15000, false) handler.forwarderMgr.Add(15000, "192.168.1.100", 15000) - + reqBody := RemoveMappingRequest{ Port: 15000, } - + body, _ := json.Marshal(reqBody) req := httptest.NewRequest(http.MethodPost, "/api/mapping/remove", bytes.NewReader(body)) w := httptest.NewRecorder() - + handler.handleRemoveMapping(w, req) - + if w.Code != http.StatusOK { t.Errorf("状态码不正确,期望 200,得到 %d", w.Code) } - + // 验证映射已删除 mapping, _ := database.GetMapping(15000) if mapping != nil { @@ -354,17 +369,17 @@ func TestHandleRemoveMapping(t *testing.T) { func TestHandleRemoveMappingNotExist(t *testing.T) { handler, _, cleanup := setupTestHandler(t, false) defer cleanup() - + reqBody := RemoveMappingRequest{ Port: 15000, } - + body, _ := json.Marshal(reqBody) req := httptest.NewRequest(http.MethodPost, "/api/mapping/remove", bytes.NewReader(body)) w := httptest.NewRecorder() - + handler.handleRemoveMapping(w, req) - + if w.Code != http.StatusNotFound { t.Errorf("状态码不正确,期望 404,得到 %d", w.Code) } @@ -374,31 +389,32 @@ func TestHandleRemoveMappingNotExist(t *testing.T) { func TestHandleListMappings(t *testing.T) { handler, database, cleanup := setupTestHandler(t, false) defer cleanup() - + // 添加一些映射 database.AddMapping(15000, "192.168.1.100", 15000, false) database.AddMapping(15001, "192.168.1.101", 15001, true) database.AddMapping(15002, "192.168.1.102", 15002, false) - + req := httptest.NewRequest(http.MethodGet, "/api/mapping/list", nil) + addAuthHeader(req) w := httptest.NewRecorder() - + handler.handleListMappings(w, req) - + if w.Code != http.StatusOK { t.Errorf("状态码不正确,期望 200,得到 %d", w.Code) } - + var result Response json.NewDecoder(w.Body).Decode(&result) - + if !result.Success { t.Errorf("列出映射失败: %s", result.Message) } - + data := result.Data.(map[string]interface{}) count := int(data["count"].(float64)) - + if count != 3 { t.Errorf("映射数量不正确,期望 3,得到 %d", count) } @@ -408,22 +424,23 @@ func TestHandleListMappings(t *testing.T) { func TestHandleListMappingsEmpty(t *testing.T) { handler, _, cleanup := setupTestHandler(t, false) defer cleanup() - + req := httptest.NewRequest(http.MethodGet, "/api/mapping/list", nil) + addAuthHeader(req) w := httptest.NewRecorder() - + handler.handleListMappings(w, req) - + if w.Code != http.StatusOK { t.Errorf("状态码不正确,期望 200,得到 %d", w.Code) } - + var result Response json.NewDecoder(w.Body).Decode(&result) - + data := result.Data.(map[string]interface{}) count := int(data["count"].(float64)) - + if count != 0 { t.Errorf("映射数量不正确,期望 0,得到 %d", count) } @@ -433,7 +450,7 @@ func TestHandleListMappingsEmpty(t *testing.T) { func TestHandleMethodNotAllowed(t *testing.T) { handler, _, cleanup := setupTestHandler(t, false) defer cleanup() - + tests := []struct { name string handler func(http.ResponseWriter, *http.Request) @@ -443,14 +460,15 @@ func TestHandleMethodNotAllowed(t *testing.T) { {"删除映射 GET", handler.handleRemoveMapping, http.MethodGet}, {"列出映射 POST", handler.handleListMappings, http.MethodPost}, } - + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { req := httptest.NewRequest(tt.method, "/test", nil) + addAuthHeader(req) w := httptest.NewRecorder() - + tt.handler(w, req) - + if w.Code != http.StatusMethodNotAllowed { t.Errorf("状态码不正确,期望 405,得到 %d", w.Code) } @@ -462,10 +480,10 @@ func TestHandleMethodNotAllowed(t *testing.T) { func TestRegisterRoutes(t *testing.T) { handler, _, cleanup := setupTestHandler(t, false) defer cleanup() - + mux := http.NewServeMux() handler.RegisterRoutes(mux) - + // 测试路由是否注册 routes := []string{ "/api/mapping/create", @@ -473,13 +491,13 @@ func TestRegisterRoutes(t *testing.T) { "/api/mapping/list", "/health", } - + for _, route := range routes { req := httptest.NewRequest(http.MethodGet, route, nil) w := httptest.NewRecorder() - + mux.ServeHTTP(w, req) - + // 如果路由不存在,应该返回 404 if w.Code == http.StatusNotFound { t.Errorf("路由 %s 未注册", route) @@ -491,21 +509,21 @@ func TestRegisterRoutes(t *testing.T) { func TestWriteSuccess(t *testing.T) { handler, _, cleanup := setupTestHandler(t, false) defer cleanup() - + w := httptest.NewRecorder() handler.writeSuccess(w, "测试成功", map[string]string{"key": "value"}) - + if w.Code != http.StatusOK { t.Errorf("状态码不正确,期望 200,得到 %d", w.Code) } - + var result Response json.NewDecoder(w.Body).Decode(&result) - + if !result.Success { t.Error("Success 应该为 true") } - + if result.Message != "测试成功" { t.Errorf("消息不正确,期望 '测试成功',得到 '%s'", result.Message) } @@ -515,21 +533,21 @@ func TestWriteSuccess(t *testing.T) { func TestWriteError(t *testing.T) { handler, _, cleanup := setupTestHandler(t, false) defer cleanup() - + w := httptest.NewRecorder() handler.writeError(w, http.StatusBadRequest, "测试错误") - + if w.Code != http.StatusBadRequest { t.Errorf("状态码不正确,期望 400,得到 %d", w.Code) } - + var result Response json.NewDecoder(w.Body).Decode(&result) - + if result.Success { t.Error("Success 应该为 false") } - + if result.Message != "测试错误" { t.Errorf("消息不正确,期望 '测试错误',得到 '%s'", result.Message) } @@ -541,12 +559,12 @@ func BenchmarkHandleHealth(b *testing.B) { dbPath := filepath.Join(tmpDir, "bench.db") database, _ := db.New(dbPath) defer database.Close() - + fwdMgr := forwarder.NewManager() - handler := NewHandler(database, fwdMgr, nil) - + handler := NewHandler(database, fwdMgr, nil, "test-api-key") + req := httptest.NewRequest(http.MethodGet, "/health", nil) - + b.ResetTimer() for i := 0; i < b.N; i++ { w := httptest.NewRecorder() @@ -560,21 +578,92 @@ func BenchmarkHandleListMappings(b *testing.B) { dbPath := filepath.Join(tmpDir, "bench.db") database, _ := db.New(dbPath) defer database.Close() - + // 添加一些映射 for i := 0; i < 100; i++ { useTunnel := i%2 == 0 // 偶数使用隧道模式 database.AddMapping(10000+i, "192.168.1.1", 10000+i, useTunnel) } - + fwdMgr := forwarder.NewManager() - handler := NewHandler(database, fwdMgr, nil) - + handler := NewHandler(database, fwdMgr, nil, "test-api-key") + req := httptest.NewRequest(http.MethodGet, "/api/mapping/list", nil) - + req.Header.Set("X-API-Key", "test-api-key") + b.ResetTimer() for i := 0; i < b.N; i++ { w := httptest.NewRecorder() handler.handleListMappings(w, req) } } + +// TestAuthMiddleware 测试认证中间件 +func TestAuthMiddleware(t *testing.T) { + handler, _, cleanup := setupTestHandler(t, false) + defer cleanup() + + tests := []struct { + name string + apiKey string + useHeader bool + useQueryParam bool + expectedStatus int + expectedMsg string + }{ + { + name: "有效的API密钥(请求头)", + apiKey: testAPIKey, + useHeader: true, + expectedStatus: http.StatusOK, + }, + { + name: "有效的API密钥(查询参数)", + apiKey: testAPIKey, + useQueryParam: true, + expectedStatus: http.StatusOK, + }, + { + name: "无效的API密钥", + apiKey: "invalid-key", + useHeader: true, + expectedStatus: http.StatusUnauthorized, + expectedMsg: "无效的 API 密钥", + }, + { + name: "缺少API密钥", + apiKey: "", + expectedStatus: http.StatusUnauthorized, + expectedMsg: "无效的 API 密钥", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + url := "/api/mapping/list" + if tt.useQueryParam { + url += "?api_key=" + tt.apiKey + } + + req := httptest.NewRequest(http.MethodGet, url, nil) + if tt.useHeader && tt.apiKey != "" { + req.Header.Set("X-API-Key", tt.apiKey) + } + + w := httptest.NewRecorder() + handler.handleListMappings(w, req) + + if w.Code != tt.expectedStatus { + t.Errorf("状态码不正确,期望 %d,得到 %d", tt.expectedStatus, w.Code) + } + + if tt.expectedMsg != "" { + var result Response + json.NewDecoder(w.Body).Decode(&result) + if result.Message != tt.expectedMsg { + t.Errorf("错误消息不正确,期望 '%s',得到 '%s'", tt.expectedMsg, result.Message) + } + } + }) + } +} diff --git a/src/server/config/config.go b/src/server/config/config.go index d0ba32f..8dd06d8 100644 --- a/src/server/config/config.go +++ b/src/server/config/config.go @@ -10,9 +10,9 @@ import ( // Config 应用配置结构 type Config struct { // PortRange PortRangeConfig `yaml:"port_range"` - Tunnel TunnelConfig `yaml:"tunnel"` - API APIConfig `yaml:"api"` - Database DatabaseConfig `yaml:"database"` + Tunnel TunnelConfig `yaml:"tunnel"` + API APIConfig `yaml:"api"` + Database DatabaseConfig `yaml:"database"` } // PortRangeConfig 端口范围配置 @@ -29,7 +29,8 @@ type TunnelConfig struct { // APIConfig HTTP API 配置 type APIConfig struct { - ListenPort int `yaml:"listen_port"` + ListenPort int `yaml:"listen_port"` + APIKey string `yaml:"api_key"` } // DatabaseConfig 数据库配置 @@ -74,8 +75,11 @@ func (c *Config) Validate() error { if c.API.ListenPort <= 0 { return fmt.Errorf("API 端口必须大于 0") } + if c.API.APIKey == "" { + return fmt.Errorf("API 密钥不能为空") + } if c.Database.Path == "" { return fmt.Errorf("数据库路径不能为空") } return nil -} \ No newline at end of file +} diff --git a/src/server/main.go b/src/server/main.go index cd8fc23..08d313e 100644 --- a/src/server/main.go +++ b/src/server/main.go @@ -96,7 +96,7 @@ func (s *serverService) Start() error { // 创建 HTTP API 处理器 log.Println("初始化 HTTP API...") - s.apiHandler = api.NewHandler(database, s.fwdManager, s.tunnelServer) + s.apiHandler = api.NewHandler(database, s.fwdManager, s.tunnelServer, cfg.API.APIKey) // 启动 HTTP API 服务器 go func() {