feat: 新增api keys

This commit is contained in:
pqcqaq 2025-10-21 22:20:26 +08:00
parent 550043f3f4
commit a3ddc33d17
6 changed files with 388 additions and 134 deletions

136
API_AUTH.md Normal file
View File

@ -0,0 +1,136 @@
# API 认证说明
## 概述
所有 API 接口(除了 `/health` 健康检查接口)都需要提供有效的 API 密钥才能访问。
## 配置 API 密钥
在配置文件 `config.yaml` 中设置 API 密钥:
```yaml
# HTTP API 配置
api:
listen_port: 8080
api_key: "your-secret-api-key-here" # 修改为你的密钥
```
**安全建议:**
- 使用强密码生成器生成复杂的 API 密钥
- 定期更换 API 密钥
- 不要在公共代码库中提交包含真实密钥的配置文件
## 使用 API 密钥
### 方式 1: 通过 HTTP 请求头(推荐)
在请求头中添加 `X-API-Key`
```bash
curl -X POST http://localhost:8080/api/mapping/create \
-H "Content-Type: application/json" \
-H "X-API-Key: your-secret-api-key-here" \
-d '{
"source_port": 30001,
"target_host": "192.168.1.100",
"target_port": 3306,
"use_tunnel": false
}'
```
### 方式 2: 通过 URL 查询参数
在 URL 中添加 `api_key` 参数:
```bash
curl -X POST "http://localhost:8080/api/mapping/create?api_key=your-secret-api-key-here" \
-H "Content-Type: application/json" \
-d '{
"source_port": 30001,
"target_host": "192.168.1.100",
"target_port": 3306,
"use_tunnel": false
}'
```
### 使用 PowerShell
```powershell
$headers = @{
"Content-Type" = "application/json"
"X-API-Key" = "your-secret-api-key-here"
}
$body = @{
source_port = 30001
target_host = "192.168.1.100"
target_port = 3306
use_tunnel = $false
} | ConvertTo-Json
Invoke-RestMethod -Uri "http://localhost:8080/api/mapping/create" `
-Method Post `
-Headers $headers `
-Body $body
```
### 使用 Python
```python
import requests
headers = {
'Content-Type': 'application/json',
'X-API-Key': 'your-secret-api-key-here'
}
data = {
'source_port': 30001,
'target_host': '192.168.1.100',
'target_port': 3306,
'use_tunnel': False
}
response = requests.post(
'http://localhost:8080/api/mapping/create',
headers=headers,
json=data
)
print(response.json())
```
## 需要认证的 API 接口
以下接口都需要提供有效的 API 密钥:
- `POST /api/mapping/create` - 创建端口映射
- `POST /api/mapping/remove` - 删除端口映射
- `GET /api/mapping/list` - 列出所有映射
- `GET /api/stats/traffic` - 获取流量统计
- `GET /api/stats/monitor` - 流量监控页面
- `GET /admin` - 管理页面
## 不需要认证的接口
- `GET /health` - 健康检查接口(公开访问)
## 错误响应
如果 API 密钥无效或缺失,服务器将返回 401 状态码:
```json
{
"success": false,
"message": "无效的 API 密钥"
}
```
## 浏览器访问
对于需要通过浏览器访问的页面(如 `/admin``/api/stats/monitor`),可以在 URL 中添加 `api_key` 参数:
```
http://localhost:8080/admin?api_key=your-secret-api-key-here
http://localhost:8080/api/stats/monitor?api_key=your-secret-api-key-here
```

View File

@ -13,6 +13,7 @@ tunnel:
# HTTP API 配置 # HTTP API 配置
api: api:
listen_port: 8080 listen_port: 8080
api_key: "your-secret-api-key-here"
# 数据库配置 # 数据库配置
database: database:

View File

@ -20,16 +20,18 @@ type Handler struct {
db *db.Database db *db.Database
forwarderMgr *forwarder.Manager forwarderMgr *forwarder.Manager
tunnelServer *tunnel.Server tunnelServer *tunnel.Server
apiKey string
// portRangeFrom int // portRangeFrom int
// portRangeEnd int // portRangeEnd int
} }
// NewHandler 创建新的 API 处理器 // NewHandler 创建新的 API 处理器
func NewHandler(database *db.Database, fwdMgr *forwarder.Manager, ts *tunnel.Server) *Handler { func NewHandler(database *db.Database, fwdMgr *forwarder.Manager, ts *tunnel.Server, apiKey string) *Handler {
return &Handler{ return &Handler{
db: database, db: database,
forwarderMgr: fwdMgr, forwarderMgr: fwdMgr,
tunnelServer: ts, tunnelServer: ts,
apiKey: apiKey,
// portRangeFrom: portFrom, // portRangeFrom: portFrom,
// portRangeEnd: portEnd, // portRangeEnd: portEnd,
} }
@ -57,15 +59,37 @@ type Response struct {
// RegisterRoutes 注册路由 // RegisterRoutes 注册路由
func (h *Handler) RegisterRoutes(mux *http.ServeMux) { func (h *Handler) RegisterRoutes(mux *http.ServeMux) {
mux.HandleFunc("/api/mapping/create", h.handleCreateMapping) mux.HandleFunc("/api/mapping/create", h.authMiddleware(h.handleCreateMapping))
mux.HandleFunc("/api/mapping/remove", h.handleRemoveMapping) mux.HandleFunc("/api/mapping/remove", h.authMiddleware(h.handleRemoveMapping))
mux.HandleFunc("/api/mapping/list", h.handleListMappings) mux.HandleFunc("/api/mapping/list", h.authMiddleware(h.handleListMappings))
mux.HandleFunc("/api/stats/traffic", h.handleGetTrafficStats) mux.HandleFunc("/api/stats/traffic", h.authMiddleware(h.handleGetTrafficStats))
mux.HandleFunc("/api/stats/monitor", h.handleTrafficMonitor) mux.HandleFunc("/api/stats/monitor", h.authMiddleware(h.handleTrafficMonitor))
mux.HandleFunc("/admin", h.handleManagement) mux.HandleFunc("/admin", h.authMiddleware(h.handleManagement))
mux.HandleFunc("/health", h.handleHealth) mux.HandleFunc("/health", h.handleHealth)
} }
// authMiddleware 认证中间件
func (h *Handler) authMiddleware(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// 从请求头中获取 API Key
apiKey := r.Header.Get("X-API-Key")
// 如果请求头中没有,尝试从查询参数中获取
if apiKey == "" {
apiKey = r.URL.Query().Get("api_key")
}
// 验证 API Key
if apiKey != h.apiKey {
h.writeError(w, http.StatusUnauthorized, "无效的 API 密钥")
return
}
// 认证通过,继续处理请求
next(w, r)
}
}
// validateHostOrIP 验证主机名或IP地址 // validateHostOrIP 验证主机名或IP地址
func (h *Handler) validateHostOrIP(hostOrIP string) error { func (h *Handler) validateHostOrIP(hostOrIP string) error {
if hostOrIP == "" { if hostOrIP == "" {

View File

@ -14,6 +14,8 @@ import (
"testing" "testing"
) )
const testAPIKey = "test-api-key-12345"
// setupTestHandler 创建测试用的 Handler // setupTestHandler 创建测试用的 Handler
func setupTestHandler(t *testing.T, useTunnel bool) (*Handler, *db.Database, func()) { func setupTestHandler(t *testing.T, useTunnel bool) (*Handler, *db.Database, func()) {
// 创建临时数据库 // 创建临时数据库
@ -40,7 +42,7 @@ func setupTestHandler(t *testing.T, useTunnel bool) (*Handler, *db.Database, fun
tunnelServer.Start() tunnelServer.Start()
} }
handler := NewHandler(database, fwdMgr, tunnelServer) handler := NewHandler(database, fwdMgr, tunnelServer, testAPIKey)
cleanup := func() { cleanup := func() {
fwdMgr.StopAll() fwdMgr.StopAll()
@ -54,6 +56,11 @@ func setupTestHandler(t *testing.T, useTunnel bool) (*Handler, *db.Database, fun
return handler, database, cleanup return handler, database, cleanup
} }
// addAuthHeader 添加认证头到请求
func addAuthHeader(req *http.Request) {
req.Header.Set("X-API-Key", testAPIKey)
}
// TestNewHandler 测试创建处理器 // TestNewHandler 测试创建处理器
func TestNewHandler(t *testing.T) { func TestNewHandler(t *testing.T) {
handler, _, cleanup := setupTestHandler(t, false) handler, _, cleanup := setupTestHandler(t, false)
@ -138,6 +145,7 @@ func TestHandleCreateMapping(t *testing.T) {
body, _ := json.Marshal(reqBody) body, _ := json.Marshal(reqBody)
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body)) req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
addAuthHeader(req)
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler.handleCreateMapping(w, req) handler.handleCreateMapping(w, req)
@ -192,6 +200,7 @@ func TestHandleCreateMappingInvalidPort(t *testing.T) {
body, _ := json.Marshal(reqBody) body, _ := json.Marshal(reqBody)
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body)) req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
addAuthHeader(req)
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler.handleCreateMapping(w, req) handler.handleCreateMapping(w, req)
@ -218,6 +227,7 @@ func TestHandleCreateMappingDuplicate(t *testing.T) {
// 第一次创建 // 第一次创建
body, _ := json.Marshal(reqBody) body, _ := json.Marshal(reqBody)
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body)) req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
addAuthHeader(req)
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler.handleCreateMapping(w, req) handler.handleCreateMapping(w, req)
@ -228,6 +238,7 @@ func TestHandleCreateMappingDuplicate(t *testing.T) {
// 第二次创建(应该失败) // 第二次创建(应该失败)
body, _ = json.Marshal(reqBody) body, _ = json.Marshal(reqBody)
req = httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body)) req = httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
addAuthHeader(req)
w = httptest.NewRecorder() w = httptest.NewRecorder()
handler.handleCreateMapping(w, req) handler.handleCreateMapping(w, req)
@ -242,6 +253,7 @@ func TestHandleCreateMappingInvalidJSON(t *testing.T) {
defer cleanup() defer cleanup()
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader([]byte("invalid json"))) req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader([]byte("invalid json")))
addAuthHeader(req)
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler.handleCreateMapping(w, req) handler.handleCreateMapping(w, req)
@ -265,6 +277,7 @@ func TestHandleCreateMappingInvalidIP(t *testing.T) {
body, _ := json.Marshal(reqBody) body, _ := json.Marshal(reqBody)
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body)) req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
addAuthHeader(req)
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler.handleCreateMapping(w, req) handler.handleCreateMapping(w, req)
@ -288,6 +301,7 @@ func TestHandleCreateMappingEmptyIP(t *testing.T) {
body, _ := json.Marshal(reqBody) body, _ := json.Marshal(reqBody)
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body)) req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
addAuthHeader(req)
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler.handleCreateMapping(w, req) handler.handleCreateMapping(w, req)
@ -311,6 +325,7 @@ func TestHandleCreateMappingTunnelNotConnected(t *testing.T) {
body, _ := json.Marshal(reqBody) body, _ := json.Marshal(reqBody)
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body)) req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
addAuthHeader(req)
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler.handleCreateMapping(w, req) handler.handleCreateMapping(w, req)
@ -381,6 +396,7 @@ func TestHandleListMappings(t *testing.T) {
database.AddMapping(15002, "192.168.1.102", 15002, false) database.AddMapping(15002, "192.168.1.102", 15002, false)
req := httptest.NewRequest(http.MethodGet, "/api/mapping/list", nil) req := httptest.NewRequest(http.MethodGet, "/api/mapping/list", nil)
addAuthHeader(req)
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler.handleListMappings(w, req) handler.handleListMappings(w, req)
@ -410,6 +426,7 @@ func TestHandleListMappingsEmpty(t *testing.T) {
defer cleanup() defer cleanup()
req := httptest.NewRequest(http.MethodGet, "/api/mapping/list", nil) req := httptest.NewRequest(http.MethodGet, "/api/mapping/list", nil)
addAuthHeader(req)
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler.handleListMappings(w, req) handler.handleListMappings(w, req)
@ -447,6 +464,7 @@ func TestHandleMethodNotAllowed(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest(tt.method, "/test", nil) req := httptest.NewRequest(tt.method, "/test", nil)
addAuthHeader(req)
w := httptest.NewRecorder() w := httptest.NewRecorder()
tt.handler(w, req) tt.handler(w, req)
@ -543,7 +561,7 @@ func BenchmarkHandleHealth(b *testing.B) {
defer database.Close() defer database.Close()
fwdMgr := forwarder.NewManager() fwdMgr := forwarder.NewManager()
handler := NewHandler(database, fwdMgr, nil) handler := NewHandler(database, fwdMgr, nil, "test-api-key")
req := httptest.NewRequest(http.MethodGet, "/health", nil) req := httptest.NewRequest(http.MethodGet, "/health", nil)
@ -568,9 +586,10 @@ func BenchmarkHandleListMappings(b *testing.B) {
} }
fwdMgr := forwarder.NewManager() fwdMgr := forwarder.NewManager()
handler := NewHandler(database, fwdMgr, nil) handler := NewHandler(database, fwdMgr, nil, "test-api-key")
req := httptest.NewRequest(http.MethodGet, "/api/mapping/list", nil) req := httptest.NewRequest(http.MethodGet, "/api/mapping/list", nil)
req.Header.Set("X-API-Key", "test-api-key")
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
@ -578,3 +597,73 @@ func BenchmarkHandleListMappings(b *testing.B) {
handler.handleListMappings(w, req) handler.handleListMappings(w, req)
} }
} }
// TestAuthMiddleware 测试认证中间件
func TestAuthMiddleware(t *testing.T) {
handler, _, cleanup := setupTestHandler(t, false)
defer cleanup()
tests := []struct {
name string
apiKey string
useHeader bool
useQueryParam bool
expectedStatus int
expectedMsg string
}{
{
name: "有效的API密钥(请求头)",
apiKey: testAPIKey,
useHeader: true,
expectedStatus: http.StatusOK,
},
{
name: "有效的API密钥(查询参数)",
apiKey: testAPIKey,
useQueryParam: true,
expectedStatus: http.StatusOK,
},
{
name: "无效的API密钥",
apiKey: "invalid-key",
useHeader: true,
expectedStatus: http.StatusUnauthorized,
expectedMsg: "无效的 API 密钥",
},
{
name: "缺少API密钥",
apiKey: "",
expectedStatus: http.StatusUnauthorized,
expectedMsg: "无效的 API 密钥",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
url := "/api/mapping/list"
if tt.useQueryParam {
url += "?api_key=" + tt.apiKey
}
req := httptest.NewRequest(http.MethodGet, url, nil)
if tt.useHeader && tt.apiKey != "" {
req.Header.Set("X-API-Key", tt.apiKey)
}
w := httptest.NewRecorder()
handler.handleListMappings(w, req)
if w.Code != tt.expectedStatus {
t.Errorf("状态码不正确,期望 %d得到 %d", tt.expectedStatus, w.Code)
}
if tt.expectedMsg != "" {
var result Response
json.NewDecoder(w.Body).Decode(&result)
if result.Message != tt.expectedMsg {
t.Errorf("错误消息不正确,期望 '%s',得到 '%s'", tt.expectedMsg, result.Message)
}
}
})
}
}

View File

@ -30,6 +30,7 @@ type TunnelConfig struct {
// APIConfig HTTP API 配置 // APIConfig HTTP API 配置
type APIConfig struct { type APIConfig struct {
ListenPort int `yaml:"listen_port"` ListenPort int `yaml:"listen_port"`
APIKey string `yaml:"api_key"`
} }
// DatabaseConfig 数据库配置 // DatabaseConfig 数据库配置
@ -74,6 +75,9 @@ func (c *Config) Validate() error {
if c.API.ListenPort <= 0 { if c.API.ListenPort <= 0 {
return fmt.Errorf("API 端口必须大于 0") return fmt.Errorf("API 端口必须大于 0")
} }
if c.API.APIKey == "" {
return fmt.Errorf("API 密钥不能为空")
}
if c.Database.Path == "" { if c.Database.Path == "" {
return fmt.Errorf("数据库路径不能为空") return fmt.Errorf("数据库路径不能为空")
} }

View File

@ -96,7 +96,7 @@ func (s *serverService) Start() error {
// 创建 HTTP API 处理器 // 创建 HTTP API 处理器
log.Println("初始化 HTTP API...") log.Println("初始化 HTTP API...")
s.apiHandler = api.NewHandler(database, s.fwdManager, s.tunnelServer) s.apiHandler = api.NewHandler(database, s.fwdManager, s.tunnelServer, cfg.API.APIKey)
// 启动 HTTP API 服务器 // 启动 HTTP API 服务器
go func() { go func() {