feat: 新增api keys

This commit is contained in:
pqcqaq 2025-10-21 22:20:26 +08:00
parent 550043f3f4
commit a3ddc33d17
6 changed files with 388 additions and 134 deletions

136
API_AUTH.md Normal file
View File

@ -0,0 +1,136 @@
# API 认证说明
## 概述
所有 API 接口(除了 `/health` 健康检查接口)都需要提供有效的 API 密钥才能访问。
## 配置 API 密钥
在配置文件 `config.yaml` 中设置 API 密钥:
```yaml
# HTTP API 配置
api:
listen_port: 8080
api_key: "your-secret-api-key-here" # 修改为你的密钥
```
**安全建议:**
- 使用强密码生成器生成复杂的 API 密钥
- 定期更换 API 密钥
- 不要在公共代码库中提交包含真实密钥的配置文件
## 使用 API 密钥
### 方式 1: 通过 HTTP 请求头(推荐)
在请求头中添加 `X-API-Key`
```bash
curl -X POST http://localhost:8080/api/mapping/create \
-H "Content-Type: application/json" \
-H "X-API-Key: your-secret-api-key-here" \
-d '{
"source_port": 30001,
"target_host": "192.168.1.100",
"target_port": 3306,
"use_tunnel": false
}'
```
### 方式 2: 通过 URL 查询参数
在 URL 中添加 `api_key` 参数:
```bash
curl -X POST "http://localhost:8080/api/mapping/create?api_key=your-secret-api-key-here" \
-H "Content-Type: application/json" \
-d '{
"source_port": 30001,
"target_host": "192.168.1.100",
"target_port": 3306,
"use_tunnel": false
}'
```
### 使用 PowerShell
```powershell
$headers = @{
"Content-Type" = "application/json"
"X-API-Key" = "your-secret-api-key-here"
}
$body = @{
source_port = 30001
target_host = "192.168.1.100"
target_port = 3306
use_tunnel = $false
} | ConvertTo-Json
Invoke-RestMethod -Uri "http://localhost:8080/api/mapping/create" `
-Method Post `
-Headers $headers `
-Body $body
```
### 使用 Python
```python
import requests
headers = {
'Content-Type': 'application/json',
'X-API-Key': 'your-secret-api-key-here'
}
data = {
'source_port': 30001,
'target_host': '192.168.1.100',
'target_port': 3306,
'use_tunnel': False
}
response = requests.post(
'http://localhost:8080/api/mapping/create',
headers=headers,
json=data
)
print(response.json())
```
## 需要认证的 API 接口
以下接口都需要提供有效的 API 密钥:
- `POST /api/mapping/create` - 创建端口映射
- `POST /api/mapping/remove` - 删除端口映射
- `GET /api/mapping/list` - 列出所有映射
- `GET /api/stats/traffic` - 获取流量统计
- `GET /api/stats/monitor` - 流量监控页面
- `GET /admin` - 管理页面
## 不需要认证的接口
- `GET /health` - 健康检查接口(公开访问)
## 错误响应
如果 API 密钥无效或缺失,服务器将返回 401 状态码:
```json
{
"success": false,
"message": "无效的 API 密钥"
}
```
## 浏览器访问
对于需要通过浏览器访问的页面(如 `/admin``/api/stats/monitor`),可以在 URL 中添加 `api_key` 参数:
```
http://localhost:8080/admin?api_key=your-secret-api-key-here
http://localhost:8080/api/stats/monitor?api_key=your-secret-api-key-here
```

View File

@ -13,6 +13,7 @@ tunnel:
# HTTP API 配置
api:
listen_port: 8080
api_key: "your-secret-api-key-here"
# 数据库配置
database:

View File

@ -17,19 +17,21 @@ import (
// Handler HTTP API 处理器
type Handler struct {
db *db.Database
forwarderMgr *forwarder.Manager
tunnelServer *tunnel.Server
db *db.Database
forwarderMgr *forwarder.Manager
tunnelServer *tunnel.Server
apiKey string
// portRangeFrom int
// portRangeEnd int
}
// NewHandler 创建新的 API 处理器
func NewHandler(database *db.Database, fwdMgr *forwarder.Manager, ts *tunnel.Server) *Handler {
func NewHandler(database *db.Database, fwdMgr *forwarder.Manager, ts *tunnel.Server, apiKey string) *Handler {
return &Handler{
db: database,
forwarderMgr: fwdMgr,
tunnelServer: ts,
db: database,
forwarderMgr: fwdMgr,
tunnelServer: ts,
apiKey: apiKey,
// portRangeFrom: portFrom,
// portRangeEnd: portEnd,
}
@ -57,15 +59,37 @@ type Response struct {
// RegisterRoutes 注册路由
func (h *Handler) RegisterRoutes(mux *http.ServeMux) {
mux.HandleFunc("/api/mapping/create", h.handleCreateMapping)
mux.HandleFunc("/api/mapping/remove", h.handleRemoveMapping)
mux.HandleFunc("/api/mapping/list", h.handleListMappings)
mux.HandleFunc("/api/stats/traffic", h.handleGetTrafficStats)
mux.HandleFunc("/api/stats/monitor", h.handleTrafficMonitor)
mux.HandleFunc("/admin", h.handleManagement)
mux.HandleFunc("/api/mapping/create", h.authMiddleware(h.handleCreateMapping))
mux.HandleFunc("/api/mapping/remove", h.authMiddleware(h.handleRemoveMapping))
mux.HandleFunc("/api/mapping/list", h.authMiddleware(h.handleListMappings))
mux.HandleFunc("/api/stats/traffic", h.authMiddleware(h.handleGetTrafficStats))
mux.HandleFunc("/api/stats/monitor", h.authMiddleware(h.handleTrafficMonitor))
mux.HandleFunc("/admin", h.authMiddleware(h.handleManagement))
mux.HandleFunc("/health", h.handleHealth)
}
// authMiddleware 认证中间件
func (h *Handler) authMiddleware(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// 从请求头中获取 API Key
apiKey := r.Header.Get("X-API-Key")
// 如果请求头中没有,尝试从查询参数中获取
if apiKey == "" {
apiKey = r.URL.Query().Get("api_key")
}
// 验证 API Key
if apiKey != h.apiKey {
h.writeError(w, http.StatusUnauthorized, "无效的 API 密钥")
return
}
// 认证通过,继续处理请求
next(w, r)
}
}
// validateHostOrIP 验证主机名或IP地址
func (h *Handler) validateHostOrIP(hostOrIP string) error {
if hostOrIP == "" {
@ -156,7 +180,7 @@ func (h *Handler) handleCreateMapping(w http.ResponseWriter, r *http.Request) {
// 直接模式直接TCP转发
err = h.forwarderMgr.Add(req.SourcePort, req.TargetHost, req.TargetPort)
}
if err != nil {
// 回滚数据库操作
h.db.RemoveMapping(req.SourcePort)
@ -305,11 +329,11 @@ func (h *Handler) handleGetTrafficStats(w http.ResponseWriter, r *http.Request)
// 获取所有端口映射的流量统计
forwarderStats := h.forwarderMgr.GetAllTrafficStats()
// 构建响应
mappings := make([]stats.PortTrafficStats, 0, len(forwarderStats))
var totalSent, totalReceived uint64
for port, stat := range forwarderStats {
mappings = append(mappings, stats.PortTrafficStats{
Port: port,
@ -320,11 +344,11 @@ func (h *Handler) handleGetTrafficStats(w http.ResponseWriter, r *http.Request)
totalSent += stat.BytesSent
totalReceived += stat.BytesReceived
}
// 加上隧道的流量
totalSent += tunnelStats.BytesSent
totalReceived += tunnelStats.BytesReceived
response := stats.AllTrafficStats{
Tunnel: tunnelStats,
Mappings: mappings,
@ -346,4 +370,4 @@ func (h *Handler) handleTrafficMonitor(w http.ResponseWriter, r *http.Request) {
func (h *Handler) handleManagement(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html; charset=utf-8")
fmt.Fprint(w, managementHTML)
}
}

View File

@ -14,20 +14,22 @@ import (
"testing"
)
const testAPIKey = "test-api-key-12345"
// setupTestHandler 创建测试用的 Handler
func setupTestHandler(t *testing.T, useTunnel bool) (*Handler, *db.Database, func()) {
// 创建临时数据库
tmpDir := t.TempDir()
dbPath := filepath.Join(tmpDir, "test.db")
database, err := db.New(dbPath)
if err != nil {
t.Fatalf("创建数据库失败: %v", err)
}
// 创建转发器管理器
fwdMgr := forwarder.NewManager()
// 创建隧道服务器(如果启用)
var tunnelServer *tunnel.Server
if useTunnel {
@ -35,13 +37,13 @@ func setupTestHandler(t *testing.T, useTunnel bool) (*Handler, *db.Database, fun
listener, _ := net.Listen("tcp", "127.0.0.1:0")
port := listener.Addr().(*net.TCPAddr).Port
listener.Close()
tunnelServer = tunnel.NewServer(port)
tunnelServer.Start()
}
handler := NewHandler(database, fwdMgr, tunnelServer)
handler := NewHandler(database, fwdMgr, tunnelServer, testAPIKey)
cleanup := func() {
fwdMgr.StopAll()
if tunnelServer != nil {
@ -50,23 +52,28 @@ func setupTestHandler(t *testing.T, useTunnel bool) (*Handler, *db.Database, fun
database.Close()
os.RemoveAll(tmpDir)
}
return handler, database, cleanup
}
// addAuthHeader 添加认证头到请求
func addAuthHeader(req *http.Request) {
req.Header.Set("X-API-Key", testAPIKey)
}
// TestNewHandler 测试创建处理器
func TestNewHandler(t *testing.T) {
handler, _, cleanup := setupTestHandler(t, false)
defer cleanup()
if handler == nil {
t.Fatal("创建处理器失败")
}
// if handler.portRangeFrom != 10000 {
// t.Errorf("起始端口不正确,期望 10000得到 %d", handler.portRangeFrom)
// }
// if handler.portRangeEnd != 20000 {
// t.Errorf("结束端口不正确,期望 20000得到 %d", handler.portRangeEnd)
// }
@ -76,22 +83,22 @@ func TestNewHandler(t *testing.T) {
func TestHandleHealth(t *testing.T) {
handler, _, cleanup := setupTestHandler(t, false)
defer cleanup()
req := httptest.NewRequest(http.MethodGet, "/health", nil)
w := httptest.NewRecorder()
handler.handleHealth(w, req)
if w.Code != http.StatusOK {
t.Errorf("状态码不正确,期望 200得到 %d", w.Code)
}
var result map[string]interface{}
err := json.NewDecoder(w.Body).Decode(&result)
if err != nil {
t.Fatalf("解析响应失败: %v", err)
}
if result["status"] != "ok" {
t.Errorf("健康状态不正确,期望 ok得到 %v", result["status"])
}
@ -101,23 +108,23 @@ func TestHandleHealth(t *testing.T) {
func TestHandleHealthWithTunnel(t *testing.T) {
handler, _, cleanup := setupTestHandler(t, true)
defer cleanup()
req := httptest.NewRequest(http.MethodGet, "/health", nil)
w := httptest.NewRecorder()
handler.handleHealth(w, req)
if w.Code != http.StatusOK {
t.Errorf("状态码不正确,期望 200得到 %d", w.Code)
}
var result map[string]interface{}
json.NewDecoder(w.Body).Decode(&result)
if result["tunnel_enabled"] != true {
t.Error("隧道应该启用")
}
// 隧道未连接客户端时应该为 false
if result["tunnel_connected"] != false {
t.Error("隧道应该未连接")
@ -128,41 +135,42 @@ func TestHandleHealthWithTunnel(t *testing.T) {
func TestHandleCreateMapping(t *testing.T) {
handler, database, cleanup := setupTestHandler(t, false)
defer cleanup()
reqBody := CreateMappingRequest{
// Port: 15000,
SourcePort: 15000,
TargetPort: 15000,
TargetHost: "192.168.1.100",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
addAuthHeader(req)
w := httptest.NewRecorder()
handler.handleCreateMapping(w, req)
if w.Code != http.StatusOK {
t.Errorf("状态码不正确,期望 200得到 %d", w.Code)
}
var result Response
json.NewDecoder(w.Body).Decode(&result)
if !result.Success {
t.Errorf("创建映射失败: %s", result.Message)
}
// 验证数据库中存在映射
mapping, err := database.GetMapping(15000)
if err != nil {
t.Fatalf("获取映射失败: %v", err)
}
if mapping == nil {
t.Fatal("映射不存在")
}
if mapping.TargetHost != "192.168.1.100" {
t.Errorf("目标 IP 不正确,期望 192.168.1.100,得到 %s", mapping.TargetHost)
}
@ -172,7 +180,7 @@ func TestHandleCreateMapping(t *testing.T) {
func TestHandleCreateMappingInvalidPort(t *testing.T) {
handler, _, cleanup := setupTestHandler(t, false)
defer cleanup()
tests := []struct {
name string
port int
@ -181,7 +189,7 @@ func TestHandleCreateMappingInvalidPort(t *testing.T) {
{"端口太大", 25000},
{"端口为0", 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
reqBody := CreateMappingRequest{
@ -189,13 +197,14 @@ func TestHandleCreateMappingInvalidPort(t *testing.T) {
TargetPort: tt.port,
TargetHost: "192.168.1.100",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
addAuthHeader(req)
w := httptest.NewRecorder()
handler.handleCreateMapping(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("状态码不正确,期望 400得到 %d", w.Code)
}
@ -207,30 +216,32 @@ func TestHandleCreateMappingInvalidPort(t *testing.T) {
func TestHandleCreateMappingDuplicate(t *testing.T) {
handler, _, cleanup := setupTestHandler(t, false)
defer cleanup()
reqBody := CreateMappingRequest{
// Port: 15000,
SourcePort: 15000,
TargetPort: 15000,
TargetHost: "192.168.1.100",
}
// 第一次创建
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
addAuthHeader(req)
w := httptest.NewRecorder()
handler.handleCreateMapping(w, req)
if w.Code != http.StatusOK {
t.Fatalf("第一次创建失败")
}
// 第二次创建(应该失败)
body, _ = json.Marshal(reqBody)
req = httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
addAuthHeader(req)
w = httptest.NewRecorder()
handler.handleCreateMapping(w, req)
if w.Code != http.StatusConflict {
t.Errorf("状态码不正确,期望 409得到 %d", w.Code)
}
@ -240,12 +251,13 @@ func TestHandleCreateMappingDuplicate(t *testing.T) {
func TestHandleCreateMappingInvalidJSON(t *testing.T) {
handler, _, cleanup := setupTestHandler(t, false)
defer cleanup()
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader([]byte("invalid json")))
addAuthHeader(req)
w := httptest.NewRecorder()
handler.handleCreateMapping(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("状态码不正确,期望 400得到 %d", w.Code)
}
@ -255,20 +267,21 @@ func TestHandleCreateMappingInvalidJSON(t *testing.T) {
func TestHandleCreateMappingInvalidIP(t *testing.T) {
handler, _, cleanup := setupTestHandler(t, false)
defer cleanup()
reqBody := CreateMappingRequest{
// Port: 15000,
SourcePort: 15000,
TargetPort: 15000,
TargetHost: "", // 使用空字符串而不是无效域名避免DNS查询超时
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
addAuthHeader(req)
w := httptest.NewRecorder()
handler.handleCreateMapping(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("状态码不正确,期望 400得到 %d", w.Code)
}
@ -278,20 +291,21 @@ func TestHandleCreateMappingInvalidIP(t *testing.T) {
func TestHandleCreateMappingEmptyIP(t *testing.T) {
handler, _, cleanup := setupTestHandler(t, false)
defer cleanup()
reqBody := CreateMappingRequest{
// Port: 15000,
SourcePort: 15000,
TargetPort: 15000,
TargetHost: "",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
addAuthHeader(req)
w := httptest.NewRecorder()
handler.handleCreateMapping(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("状态码不正确,期望 400得到 %d", w.Code)
}
@ -301,20 +315,21 @@ func TestHandleCreateMappingEmptyIP(t *testing.T) {
func TestHandleCreateMappingTunnelNotConnected(t *testing.T) {
handler, _, cleanup := setupTestHandler(t, true)
defer cleanup()
reqBody := CreateMappingRequest{
// Port: 15000,
SourcePort: 15000,
TargetPort: 15000,
UseTunnel: true, // 明确指定使用隧道模式
UseTunnel: true, // 明确指定使用隧道模式
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
addAuthHeader(req)
w := httptest.NewRecorder()
handler.handleCreateMapping(w, req)
if w.Code != http.StatusServiceUnavailable {
t.Errorf("状态码不正确,期望 503得到 %d", w.Code)
}
@ -324,25 +339,25 @@ func TestHandleCreateMappingTunnelNotConnected(t *testing.T) {
func TestHandleRemoveMapping(t *testing.T) {
handler, database, cleanup := setupTestHandler(t, false)
defer cleanup()
// 先创建一个映射
database.AddMapping(15000, "192.168.1.100", 15000, false)
handler.forwarderMgr.Add(15000, "192.168.1.100", 15000)
reqBody := RemoveMappingRequest{
Port: 15000,
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest(http.MethodPost, "/api/mapping/remove", bytes.NewReader(body))
w := httptest.NewRecorder()
handler.handleRemoveMapping(w, req)
if w.Code != http.StatusOK {
t.Errorf("状态码不正确,期望 200得到 %d", w.Code)
}
// 验证映射已删除
mapping, _ := database.GetMapping(15000)
if mapping != nil {
@ -354,17 +369,17 @@ func TestHandleRemoveMapping(t *testing.T) {
func TestHandleRemoveMappingNotExist(t *testing.T) {
handler, _, cleanup := setupTestHandler(t, false)
defer cleanup()
reqBody := RemoveMappingRequest{
Port: 15000,
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest(http.MethodPost, "/api/mapping/remove", bytes.NewReader(body))
w := httptest.NewRecorder()
handler.handleRemoveMapping(w, req)
if w.Code != http.StatusNotFound {
t.Errorf("状态码不正确,期望 404得到 %d", w.Code)
}
@ -374,31 +389,32 @@ func TestHandleRemoveMappingNotExist(t *testing.T) {
func TestHandleListMappings(t *testing.T) {
handler, database, cleanup := setupTestHandler(t, false)
defer cleanup()
// 添加一些映射
database.AddMapping(15000, "192.168.1.100", 15000, false)
database.AddMapping(15001, "192.168.1.101", 15001, true)
database.AddMapping(15002, "192.168.1.102", 15002, false)
req := httptest.NewRequest(http.MethodGet, "/api/mapping/list", nil)
addAuthHeader(req)
w := httptest.NewRecorder()
handler.handleListMappings(w, req)
if w.Code != http.StatusOK {
t.Errorf("状态码不正确,期望 200得到 %d", w.Code)
}
var result Response
json.NewDecoder(w.Body).Decode(&result)
if !result.Success {
t.Errorf("列出映射失败: %s", result.Message)
}
data := result.Data.(map[string]interface{})
count := int(data["count"].(float64))
if count != 3 {
t.Errorf("映射数量不正确,期望 3得到 %d", count)
}
@ -408,22 +424,23 @@ func TestHandleListMappings(t *testing.T) {
func TestHandleListMappingsEmpty(t *testing.T) {
handler, _, cleanup := setupTestHandler(t, false)
defer cleanup()
req := httptest.NewRequest(http.MethodGet, "/api/mapping/list", nil)
addAuthHeader(req)
w := httptest.NewRecorder()
handler.handleListMappings(w, req)
if w.Code != http.StatusOK {
t.Errorf("状态码不正确,期望 200得到 %d", w.Code)
}
var result Response
json.NewDecoder(w.Body).Decode(&result)
data := result.Data.(map[string]interface{})
count := int(data["count"].(float64))
if count != 0 {
t.Errorf("映射数量不正确,期望 0得到 %d", count)
}
@ -433,7 +450,7 @@ func TestHandleListMappingsEmpty(t *testing.T) {
func TestHandleMethodNotAllowed(t *testing.T) {
handler, _, cleanup := setupTestHandler(t, false)
defer cleanup()
tests := []struct {
name string
handler func(http.ResponseWriter, *http.Request)
@ -443,14 +460,15 @@ func TestHandleMethodNotAllowed(t *testing.T) {
{"删除映射 GET", handler.handleRemoveMapping, http.MethodGet},
{"列出映射 POST", handler.handleListMappings, http.MethodPost},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest(tt.method, "/test", nil)
addAuthHeader(req)
w := httptest.NewRecorder()
tt.handler(w, req)
if w.Code != http.StatusMethodNotAllowed {
t.Errorf("状态码不正确,期望 405得到 %d", w.Code)
}
@ -462,10 +480,10 @@ func TestHandleMethodNotAllowed(t *testing.T) {
func TestRegisterRoutes(t *testing.T) {
handler, _, cleanup := setupTestHandler(t, false)
defer cleanup()
mux := http.NewServeMux()
handler.RegisterRoutes(mux)
// 测试路由是否注册
routes := []string{
"/api/mapping/create",
@ -473,13 +491,13 @@ func TestRegisterRoutes(t *testing.T) {
"/api/mapping/list",
"/health",
}
for _, route := range routes {
req := httptest.NewRequest(http.MethodGet, route, nil)
w := httptest.NewRecorder()
mux.ServeHTTP(w, req)
// 如果路由不存在,应该返回 404
if w.Code == http.StatusNotFound {
t.Errorf("路由 %s 未注册", route)
@ -491,21 +509,21 @@ func TestRegisterRoutes(t *testing.T) {
func TestWriteSuccess(t *testing.T) {
handler, _, cleanup := setupTestHandler(t, false)
defer cleanup()
w := httptest.NewRecorder()
handler.writeSuccess(w, "测试成功", map[string]string{"key": "value"})
if w.Code != http.StatusOK {
t.Errorf("状态码不正确,期望 200得到 %d", w.Code)
}
var result Response
json.NewDecoder(w.Body).Decode(&result)
if !result.Success {
t.Error("Success 应该为 true")
}
if result.Message != "测试成功" {
t.Errorf("消息不正确,期望 '测试成功',得到 '%s'", result.Message)
}
@ -515,21 +533,21 @@ func TestWriteSuccess(t *testing.T) {
func TestWriteError(t *testing.T) {
handler, _, cleanup := setupTestHandler(t, false)
defer cleanup()
w := httptest.NewRecorder()
handler.writeError(w, http.StatusBadRequest, "测试错误")
if w.Code != http.StatusBadRequest {
t.Errorf("状态码不正确,期望 400得到 %d", w.Code)
}
var result Response
json.NewDecoder(w.Body).Decode(&result)
if result.Success {
t.Error("Success 应该为 false")
}
if result.Message != "测试错误" {
t.Errorf("消息不正确,期望 '测试错误',得到 '%s'", result.Message)
}
@ -541,12 +559,12 @@ func BenchmarkHandleHealth(b *testing.B) {
dbPath := filepath.Join(tmpDir, "bench.db")
database, _ := db.New(dbPath)
defer database.Close()
fwdMgr := forwarder.NewManager()
handler := NewHandler(database, fwdMgr, nil)
handler := NewHandler(database, fwdMgr, nil, "test-api-key")
req := httptest.NewRequest(http.MethodGet, "/health", nil)
b.ResetTimer()
for i := 0; i < b.N; i++ {
w := httptest.NewRecorder()
@ -560,21 +578,92 @@ func BenchmarkHandleListMappings(b *testing.B) {
dbPath := filepath.Join(tmpDir, "bench.db")
database, _ := db.New(dbPath)
defer database.Close()
// 添加一些映射
for i := 0; i < 100; i++ {
useTunnel := i%2 == 0 // 偶数使用隧道模式
database.AddMapping(10000+i, "192.168.1.1", 10000+i, useTunnel)
}
fwdMgr := forwarder.NewManager()
handler := NewHandler(database, fwdMgr, nil)
handler := NewHandler(database, fwdMgr, nil, "test-api-key")
req := httptest.NewRequest(http.MethodGet, "/api/mapping/list", nil)
req.Header.Set("X-API-Key", "test-api-key")
b.ResetTimer()
for i := 0; i < b.N; i++ {
w := httptest.NewRecorder()
handler.handleListMappings(w, req)
}
}
// TestAuthMiddleware 测试认证中间件
func TestAuthMiddleware(t *testing.T) {
handler, _, cleanup := setupTestHandler(t, false)
defer cleanup()
tests := []struct {
name string
apiKey string
useHeader bool
useQueryParam bool
expectedStatus int
expectedMsg string
}{
{
name: "有效的API密钥(请求头)",
apiKey: testAPIKey,
useHeader: true,
expectedStatus: http.StatusOK,
},
{
name: "有效的API密钥(查询参数)",
apiKey: testAPIKey,
useQueryParam: true,
expectedStatus: http.StatusOK,
},
{
name: "无效的API密钥",
apiKey: "invalid-key",
useHeader: true,
expectedStatus: http.StatusUnauthorized,
expectedMsg: "无效的 API 密钥",
},
{
name: "缺少API密钥",
apiKey: "",
expectedStatus: http.StatusUnauthorized,
expectedMsg: "无效的 API 密钥",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
url := "/api/mapping/list"
if tt.useQueryParam {
url += "?api_key=" + tt.apiKey
}
req := httptest.NewRequest(http.MethodGet, url, nil)
if tt.useHeader && tt.apiKey != "" {
req.Header.Set("X-API-Key", tt.apiKey)
}
w := httptest.NewRecorder()
handler.handleListMappings(w, req)
if w.Code != tt.expectedStatus {
t.Errorf("状态码不正确,期望 %d得到 %d", tt.expectedStatus, w.Code)
}
if tt.expectedMsg != "" {
var result Response
json.NewDecoder(w.Body).Decode(&result)
if result.Message != tt.expectedMsg {
t.Errorf("错误消息不正确,期望 '%s',得到 '%s'", tt.expectedMsg, result.Message)
}
}
})
}
}

View File

@ -10,9 +10,9 @@ import (
// Config 应用配置结构
type Config struct {
// PortRange PortRangeConfig `yaml:"port_range"`
Tunnel TunnelConfig `yaml:"tunnel"`
API APIConfig `yaml:"api"`
Database DatabaseConfig `yaml:"database"`
Tunnel TunnelConfig `yaml:"tunnel"`
API APIConfig `yaml:"api"`
Database DatabaseConfig `yaml:"database"`
}
// PortRangeConfig 端口范围配置
@ -29,7 +29,8 @@ type TunnelConfig struct {
// APIConfig HTTP API 配置
type APIConfig struct {
ListenPort int `yaml:"listen_port"`
ListenPort int `yaml:"listen_port"`
APIKey string `yaml:"api_key"`
}
// DatabaseConfig 数据库配置
@ -74,8 +75,11 @@ func (c *Config) Validate() error {
if c.API.ListenPort <= 0 {
return fmt.Errorf("API 端口必须大于 0")
}
if c.API.APIKey == "" {
return fmt.Errorf("API 密钥不能为空")
}
if c.Database.Path == "" {
return fmt.Errorf("数据库路径不能为空")
}
return nil
}
}

View File

@ -96,7 +96,7 @@ func (s *serverService) Start() error {
// 创建 HTTP API 处理器
log.Println("初始化 HTTP API...")
s.apiHandler = api.NewHandler(database, s.fwdManager, s.tunnelServer)
s.apiHandler = api.NewHandler(database, s.fwdManager, s.tunnelServer, cfg.API.APIKey)
// 启动 HTTP API 服务器
go func() {