feat: 新增api keys
This commit is contained in:
parent
550043f3f4
commit
a3ddc33d17
|
|
@ -0,0 +1,136 @@
|
||||||
|
# API 认证说明
|
||||||
|
|
||||||
|
## 概述
|
||||||
|
|
||||||
|
所有 API 接口(除了 `/health` 健康检查接口)都需要提供有效的 API 密钥才能访问。
|
||||||
|
|
||||||
|
## 配置 API 密钥
|
||||||
|
|
||||||
|
在配置文件 `config.yaml` 中设置 API 密钥:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# HTTP API 配置
|
||||||
|
api:
|
||||||
|
listen_port: 8080
|
||||||
|
api_key: "your-secret-api-key-here" # 修改为你的密钥
|
||||||
|
```
|
||||||
|
|
||||||
|
**安全建议:**
|
||||||
|
- 使用强密码生成器生成复杂的 API 密钥
|
||||||
|
- 定期更换 API 密钥
|
||||||
|
- 不要在公共代码库中提交包含真实密钥的配置文件
|
||||||
|
|
||||||
|
## 使用 API 密钥
|
||||||
|
|
||||||
|
### 方式 1: 通过 HTTP 请求头(推荐)
|
||||||
|
|
||||||
|
在请求头中添加 `X-API-Key`:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -X POST http://localhost:8080/api/mapping/create \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-H "X-API-Key: your-secret-api-key-here" \
|
||||||
|
-d '{
|
||||||
|
"source_port": 30001,
|
||||||
|
"target_host": "192.168.1.100",
|
||||||
|
"target_port": 3306,
|
||||||
|
"use_tunnel": false
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
### 方式 2: 通过 URL 查询参数
|
||||||
|
|
||||||
|
在 URL 中添加 `api_key` 参数:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -X POST "http://localhost:8080/api/mapping/create?api_key=your-secret-api-key-here" \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"source_port": 30001,
|
||||||
|
"target_host": "192.168.1.100",
|
||||||
|
"target_port": 3306,
|
||||||
|
"use_tunnel": false
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
### 使用 PowerShell
|
||||||
|
|
||||||
|
```powershell
|
||||||
|
$headers = @{
|
||||||
|
"Content-Type" = "application/json"
|
||||||
|
"X-API-Key" = "your-secret-api-key-here"
|
||||||
|
}
|
||||||
|
|
||||||
|
$body = @{
|
||||||
|
source_port = 30001
|
||||||
|
target_host = "192.168.1.100"
|
||||||
|
target_port = 3306
|
||||||
|
use_tunnel = $false
|
||||||
|
} | ConvertTo-Json
|
||||||
|
|
||||||
|
Invoke-RestMethod -Uri "http://localhost:8080/api/mapping/create" `
|
||||||
|
-Method Post `
|
||||||
|
-Headers $headers `
|
||||||
|
-Body $body
|
||||||
|
```
|
||||||
|
|
||||||
|
### 使用 Python
|
||||||
|
|
||||||
|
```python
|
||||||
|
import requests
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
'X-API-Key': 'your-secret-api-key-here'
|
||||||
|
}
|
||||||
|
|
||||||
|
data = {
|
||||||
|
'source_port': 30001,
|
||||||
|
'target_host': '192.168.1.100',
|
||||||
|
'target_port': 3306,
|
||||||
|
'use_tunnel': False
|
||||||
|
}
|
||||||
|
|
||||||
|
response = requests.post(
|
||||||
|
'http://localhost:8080/api/mapping/create',
|
||||||
|
headers=headers,
|
||||||
|
json=data
|
||||||
|
)
|
||||||
|
|
||||||
|
print(response.json())
|
||||||
|
```
|
||||||
|
|
||||||
|
## 需要认证的 API 接口
|
||||||
|
|
||||||
|
以下接口都需要提供有效的 API 密钥:
|
||||||
|
|
||||||
|
- `POST /api/mapping/create` - 创建端口映射
|
||||||
|
- `POST /api/mapping/remove` - 删除端口映射
|
||||||
|
- `GET /api/mapping/list` - 列出所有映射
|
||||||
|
- `GET /api/stats/traffic` - 获取流量统计
|
||||||
|
- `GET /api/stats/monitor` - 流量监控页面
|
||||||
|
- `GET /admin` - 管理页面
|
||||||
|
|
||||||
|
## 不需要认证的接口
|
||||||
|
|
||||||
|
- `GET /health` - 健康检查接口(公开访问)
|
||||||
|
|
||||||
|
## 错误响应
|
||||||
|
|
||||||
|
如果 API 密钥无效或缺失,服务器将返回 401 状态码:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"success": false,
|
||||||
|
"message": "无效的 API 密钥"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## 浏览器访问
|
||||||
|
|
||||||
|
对于需要通过浏览器访问的页面(如 `/admin` 和 `/api/stats/monitor`),可以在 URL 中添加 `api_key` 参数:
|
||||||
|
|
||||||
|
```
|
||||||
|
http://localhost:8080/admin?api_key=your-secret-api-key-here
|
||||||
|
http://localhost:8080/api/stats/monitor?api_key=your-secret-api-key-here
|
||||||
|
```
|
||||||
|
|
@ -13,6 +13,7 @@ tunnel:
|
||||||
# HTTP API 配置
|
# HTTP API 配置
|
||||||
api:
|
api:
|
||||||
listen_port: 8080
|
listen_port: 8080
|
||||||
|
api_key: "your-secret-api-key-here"
|
||||||
|
|
||||||
# 数据库配置
|
# 数据库配置
|
||||||
database:
|
database:
|
||||||
|
|
|
||||||
|
|
@ -17,19 +17,21 @@ import (
|
||||||
|
|
||||||
// Handler HTTP API 处理器
|
// Handler HTTP API 处理器
|
||||||
type Handler struct {
|
type Handler struct {
|
||||||
db *db.Database
|
db *db.Database
|
||||||
forwarderMgr *forwarder.Manager
|
forwarderMgr *forwarder.Manager
|
||||||
tunnelServer *tunnel.Server
|
tunnelServer *tunnel.Server
|
||||||
|
apiKey string
|
||||||
// portRangeFrom int
|
// portRangeFrom int
|
||||||
// portRangeEnd int
|
// portRangeEnd int
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewHandler 创建新的 API 处理器
|
// NewHandler 创建新的 API 处理器
|
||||||
func NewHandler(database *db.Database, fwdMgr *forwarder.Manager, ts *tunnel.Server) *Handler {
|
func NewHandler(database *db.Database, fwdMgr *forwarder.Manager, ts *tunnel.Server, apiKey string) *Handler {
|
||||||
return &Handler{
|
return &Handler{
|
||||||
db: database,
|
db: database,
|
||||||
forwarderMgr: fwdMgr,
|
forwarderMgr: fwdMgr,
|
||||||
tunnelServer: ts,
|
tunnelServer: ts,
|
||||||
|
apiKey: apiKey,
|
||||||
// portRangeFrom: portFrom,
|
// portRangeFrom: portFrom,
|
||||||
// portRangeEnd: portEnd,
|
// portRangeEnd: portEnd,
|
||||||
}
|
}
|
||||||
|
|
@ -57,15 +59,37 @@ type Response struct {
|
||||||
|
|
||||||
// RegisterRoutes 注册路由
|
// RegisterRoutes 注册路由
|
||||||
func (h *Handler) RegisterRoutes(mux *http.ServeMux) {
|
func (h *Handler) RegisterRoutes(mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/mapping/create", h.handleCreateMapping)
|
mux.HandleFunc("/api/mapping/create", h.authMiddleware(h.handleCreateMapping))
|
||||||
mux.HandleFunc("/api/mapping/remove", h.handleRemoveMapping)
|
mux.HandleFunc("/api/mapping/remove", h.authMiddleware(h.handleRemoveMapping))
|
||||||
mux.HandleFunc("/api/mapping/list", h.handleListMappings)
|
mux.HandleFunc("/api/mapping/list", h.authMiddleware(h.handleListMappings))
|
||||||
mux.HandleFunc("/api/stats/traffic", h.handleGetTrafficStats)
|
mux.HandleFunc("/api/stats/traffic", h.authMiddleware(h.handleGetTrafficStats))
|
||||||
mux.HandleFunc("/api/stats/monitor", h.handleTrafficMonitor)
|
mux.HandleFunc("/api/stats/monitor", h.authMiddleware(h.handleTrafficMonitor))
|
||||||
mux.HandleFunc("/admin", h.handleManagement)
|
mux.HandleFunc("/admin", h.authMiddleware(h.handleManagement))
|
||||||
mux.HandleFunc("/health", h.handleHealth)
|
mux.HandleFunc("/health", h.handleHealth)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// authMiddleware 认证中间件
|
||||||
|
func (h *Handler) authMiddleware(next http.HandlerFunc) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// 从请求头中获取 API Key
|
||||||
|
apiKey := r.Header.Get("X-API-Key")
|
||||||
|
|
||||||
|
// 如果请求头中没有,尝试从查询参数中获取
|
||||||
|
if apiKey == "" {
|
||||||
|
apiKey = r.URL.Query().Get("api_key")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证 API Key
|
||||||
|
if apiKey != h.apiKey {
|
||||||
|
h.writeError(w, http.StatusUnauthorized, "无效的 API 密钥")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 认证通过,继续处理请求
|
||||||
|
next(w, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// validateHostOrIP 验证主机名或IP地址
|
// validateHostOrIP 验证主机名或IP地址
|
||||||
func (h *Handler) validateHostOrIP(hostOrIP string) error {
|
func (h *Handler) validateHostOrIP(hostOrIP string) error {
|
||||||
if hostOrIP == "" {
|
if hostOrIP == "" {
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,8 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const testAPIKey = "test-api-key-12345"
|
||||||
|
|
||||||
// setupTestHandler 创建测试用的 Handler
|
// setupTestHandler 创建测试用的 Handler
|
||||||
func setupTestHandler(t *testing.T, useTunnel bool) (*Handler, *db.Database, func()) {
|
func setupTestHandler(t *testing.T, useTunnel bool) (*Handler, *db.Database, func()) {
|
||||||
// 创建临时数据库
|
// 创建临时数据库
|
||||||
|
|
@ -40,7 +42,7 @@ func setupTestHandler(t *testing.T, useTunnel bool) (*Handler, *db.Database, fun
|
||||||
tunnelServer.Start()
|
tunnelServer.Start()
|
||||||
}
|
}
|
||||||
|
|
||||||
handler := NewHandler(database, fwdMgr, tunnelServer)
|
handler := NewHandler(database, fwdMgr, tunnelServer, testAPIKey)
|
||||||
|
|
||||||
cleanup := func() {
|
cleanup := func() {
|
||||||
fwdMgr.StopAll()
|
fwdMgr.StopAll()
|
||||||
|
|
@ -54,6 +56,11 @@ func setupTestHandler(t *testing.T, useTunnel bool) (*Handler, *db.Database, fun
|
||||||
return handler, database, cleanup
|
return handler, database, cleanup
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// addAuthHeader 添加认证头到请求
|
||||||
|
func addAuthHeader(req *http.Request) {
|
||||||
|
req.Header.Set("X-API-Key", testAPIKey)
|
||||||
|
}
|
||||||
|
|
||||||
// TestNewHandler 测试创建处理器
|
// TestNewHandler 测试创建处理器
|
||||||
func TestNewHandler(t *testing.T) {
|
func TestNewHandler(t *testing.T) {
|
||||||
handler, _, cleanup := setupTestHandler(t, false)
|
handler, _, cleanup := setupTestHandler(t, false)
|
||||||
|
|
@ -138,6 +145,7 @@ func TestHandleCreateMapping(t *testing.T) {
|
||||||
|
|
||||||
body, _ := json.Marshal(reqBody)
|
body, _ := json.Marshal(reqBody)
|
||||||
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
|
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
|
||||||
|
addAuthHeader(req)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
handler.handleCreateMapping(w, req)
|
handler.handleCreateMapping(w, req)
|
||||||
|
|
@ -192,6 +200,7 @@ func TestHandleCreateMappingInvalidPort(t *testing.T) {
|
||||||
|
|
||||||
body, _ := json.Marshal(reqBody)
|
body, _ := json.Marshal(reqBody)
|
||||||
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
|
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
|
||||||
|
addAuthHeader(req)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
handler.handleCreateMapping(w, req)
|
handler.handleCreateMapping(w, req)
|
||||||
|
|
@ -218,6 +227,7 @@ func TestHandleCreateMappingDuplicate(t *testing.T) {
|
||||||
// 第一次创建
|
// 第一次创建
|
||||||
body, _ := json.Marshal(reqBody)
|
body, _ := json.Marshal(reqBody)
|
||||||
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
|
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
|
||||||
|
addAuthHeader(req)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
handler.handleCreateMapping(w, req)
|
handler.handleCreateMapping(w, req)
|
||||||
|
|
||||||
|
|
@ -228,6 +238,7 @@ func TestHandleCreateMappingDuplicate(t *testing.T) {
|
||||||
// 第二次创建(应该失败)
|
// 第二次创建(应该失败)
|
||||||
body, _ = json.Marshal(reqBody)
|
body, _ = json.Marshal(reqBody)
|
||||||
req = httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
|
req = httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
|
||||||
|
addAuthHeader(req)
|
||||||
w = httptest.NewRecorder()
|
w = httptest.NewRecorder()
|
||||||
handler.handleCreateMapping(w, req)
|
handler.handleCreateMapping(w, req)
|
||||||
|
|
||||||
|
|
@ -242,6 +253,7 @@ func TestHandleCreateMappingInvalidJSON(t *testing.T) {
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader([]byte("invalid json")))
|
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader([]byte("invalid json")))
|
||||||
|
addAuthHeader(req)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
handler.handleCreateMapping(w, req)
|
handler.handleCreateMapping(w, req)
|
||||||
|
|
@ -265,6 +277,7 @@ func TestHandleCreateMappingInvalidIP(t *testing.T) {
|
||||||
|
|
||||||
body, _ := json.Marshal(reqBody)
|
body, _ := json.Marshal(reqBody)
|
||||||
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
|
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
|
||||||
|
addAuthHeader(req)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
handler.handleCreateMapping(w, req)
|
handler.handleCreateMapping(w, req)
|
||||||
|
|
@ -288,6 +301,7 @@ func TestHandleCreateMappingEmptyIP(t *testing.T) {
|
||||||
|
|
||||||
body, _ := json.Marshal(reqBody)
|
body, _ := json.Marshal(reqBody)
|
||||||
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
|
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
|
||||||
|
addAuthHeader(req)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
handler.handleCreateMapping(w, req)
|
handler.handleCreateMapping(w, req)
|
||||||
|
|
@ -306,11 +320,12 @@ func TestHandleCreateMappingTunnelNotConnected(t *testing.T) {
|
||||||
// Port: 15000,
|
// Port: 15000,
|
||||||
SourcePort: 15000,
|
SourcePort: 15000,
|
||||||
TargetPort: 15000,
|
TargetPort: 15000,
|
||||||
UseTunnel: true, // 明确指定使用隧道模式
|
UseTunnel: true, // 明确指定使用隧道模式
|
||||||
}
|
}
|
||||||
|
|
||||||
body, _ := json.Marshal(reqBody)
|
body, _ := json.Marshal(reqBody)
|
||||||
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
|
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
|
||||||
|
addAuthHeader(req)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
handler.handleCreateMapping(w, req)
|
handler.handleCreateMapping(w, req)
|
||||||
|
|
@ -381,6 +396,7 @@ func TestHandleListMappings(t *testing.T) {
|
||||||
database.AddMapping(15002, "192.168.1.102", 15002, false)
|
database.AddMapping(15002, "192.168.1.102", 15002, false)
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/api/mapping/list", nil)
|
req := httptest.NewRequest(http.MethodGet, "/api/mapping/list", nil)
|
||||||
|
addAuthHeader(req)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
handler.handleListMappings(w, req)
|
handler.handleListMappings(w, req)
|
||||||
|
|
@ -410,6 +426,7 @@ func TestHandleListMappingsEmpty(t *testing.T) {
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/api/mapping/list", nil)
|
req := httptest.NewRequest(http.MethodGet, "/api/mapping/list", nil)
|
||||||
|
addAuthHeader(req)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
handler.handleListMappings(w, req)
|
handler.handleListMappings(w, req)
|
||||||
|
|
@ -447,6 +464,7 @@ func TestHandleMethodNotAllowed(t *testing.T) {
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
req := httptest.NewRequest(tt.method, "/test", nil)
|
req := httptest.NewRequest(tt.method, "/test", nil)
|
||||||
|
addAuthHeader(req)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
tt.handler(w, req)
|
tt.handler(w, req)
|
||||||
|
|
@ -543,7 +561,7 @@ func BenchmarkHandleHealth(b *testing.B) {
|
||||||
defer database.Close()
|
defer database.Close()
|
||||||
|
|
||||||
fwdMgr := forwarder.NewManager()
|
fwdMgr := forwarder.NewManager()
|
||||||
handler := NewHandler(database, fwdMgr, nil)
|
handler := NewHandler(database, fwdMgr, nil, "test-api-key")
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/health", nil)
|
req := httptest.NewRequest(http.MethodGet, "/health", nil)
|
||||||
|
|
||||||
|
|
@ -568,9 +586,10 @@ func BenchmarkHandleListMappings(b *testing.B) {
|
||||||
}
|
}
|
||||||
|
|
||||||
fwdMgr := forwarder.NewManager()
|
fwdMgr := forwarder.NewManager()
|
||||||
handler := NewHandler(database, fwdMgr, nil)
|
handler := NewHandler(database, fwdMgr, nil, "test-api-key")
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/api/mapping/list", nil)
|
req := httptest.NewRequest(http.MethodGet, "/api/mapping/list", nil)
|
||||||
|
req.Header.Set("X-API-Key", "test-api-key")
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
|
|
@ -578,3 +597,73 @@ func BenchmarkHandleListMappings(b *testing.B) {
|
||||||
handler.handleListMappings(w, req)
|
handler.handleListMappings(w, req)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestAuthMiddleware 测试认证中间件
|
||||||
|
func TestAuthMiddleware(t *testing.T) {
|
||||||
|
handler, _, cleanup := setupTestHandler(t, false)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
apiKey string
|
||||||
|
useHeader bool
|
||||||
|
useQueryParam bool
|
||||||
|
expectedStatus int
|
||||||
|
expectedMsg string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "有效的API密钥(请求头)",
|
||||||
|
apiKey: testAPIKey,
|
||||||
|
useHeader: true,
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "有效的API密钥(查询参数)",
|
||||||
|
apiKey: testAPIKey,
|
||||||
|
useQueryParam: true,
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "无效的API密钥",
|
||||||
|
apiKey: "invalid-key",
|
||||||
|
useHeader: true,
|
||||||
|
expectedStatus: http.StatusUnauthorized,
|
||||||
|
expectedMsg: "无效的 API 密钥",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "缺少API密钥",
|
||||||
|
apiKey: "",
|
||||||
|
expectedStatus: http.StatusUnauthorized,
|
||||||
|
expectedMsg: "无效的 API 密钥",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
url := "/api/mapping/list"
|
||||||
|
if tt.useQueryParam {
|
||||||
|
url += "?api_key=" + tt.apiKey
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, url, nil)
|
||||||
|
if tt.useHeader && tt.apiKey != "" {
|
||||||
|
req.Header.Set("X-API-Key", tt.apiKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
handler.handleListMappings(w, req)
|
||||||
|
|
||||||
|
if w.Code != tt.expectedStatus {
|
||||||
|
t.Errorf("状态码不正确,期望 %d,得到 %d", tt.expectedStatus, w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.expectedMsg != "" {
|
||||||
|
var result Response
|
||||||
|
json.NewDecoder(w.Body).Decode(&result)
|
||||||
|
if result.Message != tt.expectedMsg {
|
||||||
|
t.Errorf("错误消息不正确,期望 '%s',得到 '%s'", tt.expectedMsg, result.Message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -10,9 +10,9 @@ import (
|
||||||
// Config 应用配置结构
|
// Config 应用配置结构
|
||||||
type Config struct {
|
type Config struct {
|
||||||
// PortRange PortRangeConfig `yaml:"port_range"`
|
// PortRange PortRangeConfig `yaml:"port_range"`
|
||||||
Tunnel TunnelConfig `yaml:"tunnel"`
|
Tunnel TunnelConfig `yaml:"tunnel"`
|
||||||
API APIConfig `yaml:"api"`
|
API APIConfig `yaml:"api"`
|
||||||
Database DatabaseConfig `yaml:"database"`
|
Database DatabaseConfig `yaml:"database"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// PortRangeConfig 端口范围配置
|
// PortRangeConfig 端口范围配置
|
||||||
|
|
@ -29,7 +29,8 @@ type TunnelConfig struct {
|
||||||
|
|
||||||
// APIConfig HTTP API 配置
|
// APIConfig HTTP API 配置
|
||||||
type APIConfig struct {
|
type APIConfig struct {
|
||||||
ListenPort int `yaml:"listen_port"`
|
ListenPort int `yaml:"listen_port"`
|
||||||
|
APIKey string `yaml:"api_key"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// DatabaseConfig 数据库配置
|
// DatabaseConfig 数据库配置
|
||||||
|
|
@ -74,6 +75,9 @@ func (c *Config) Validate() error {
|
||||||
if c.API.ListenPort <= 0 {
|
if c.API.ListenPort <= 0 {
|
||||||
return fmt.Errorf("API 端口必须大于 0")
|
return fmt.Errorf("API 端口必须大于 0")
|
||||||
}
|
}
|
||||||
|
if c.API.APIKey == "" {
|
||||||
|
return fmt.Errorf("API 密钥不能为空")
|
||||||
|
}
|
||||||
if c.Database.Path == "" {
|
if c.Database.Path == "" {
|
||||||
return fmt.Errorf("数据库路径不能为空")
|
return fmt.Errorf("数据库路径不能为空")
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -96,7 +96,7 @@ func (s *serverService) Start() error {
|
||||||
|
|
||||||
// 创建 HTTP API 处理器
|
// 创建 HTTP API 处理器
|
||||||
log.Println("初始化 HTTP API...")
|
log.Println("初始化 HTTP API...")
|
||||||
s.apiHandler = api.NewHandler(database, s.fwdManager, s.tunnelServer)
|
s.apiHandler = api.NewHandler(database, s.fwdManager, s.tunnelServer, cfg.API.APIKey)
|
||||||
|
|
||||||
// 启动 HTTP API 服务器
|
// 启动 HTTP API 服务器
|
||||||
go func() {
|
go func() {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue