feat: 新增api keys
This commit is contained in:
parent
550043f3f4
commit
a3ddc33d17
|
|
@ -0,0 +1,136 @@
|
|||
# API 认证说明
|
||||
|
||||
## 概述
|
||||
|
||||
所有 API 接口(除了 `/health` 健康检查接口)都需要提供有效的 API 密钥才能访问。
|
||||
|
||||
## 配置 API 密钥
|
||||
|
||||
在配置文件 `config.yaml` 中设置 API 密钥:
|
||||
|
||||
```yaml
|
||||
# HTTP API 配置
|
||||
api:
|
||||
listen_port: 8080
|
||||
api_key: "your-secret-api-key-here" # 修改为你的密钥
|
||||
```
|
||||
|
||||
**安全建议:**
|
||||
- 使用强密码生成器生成复杂的 API 密钥
|
||||
- 定期更换 API 密钥
|
||||
- 不要在公共代码库中提交包含真实密钥的配置文件
|
||||
|
||||
## 使用 API 密钥
|
||||
|
||||
### 方式 1: 通过 HTTP 请求头(推荐)
|
||||
|
||||
在请求头中添加 `X-API-Key`:
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:8080/api/mapping/create \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "X-API-Key: your-secret-api-key-here" \
|
||||
-d '{
|
||||
"source_port": 30001,
|
||||
"target_host": "192.168.1.100",
|
||||
"target_port": 3306,
|
||||
"use_tunnel": false
|
||||
}'
|
||||
```
|
||||
|
||||
### 方式 2: 通过 URL 查询参数
|
||||
|
||||
在 URL 中添加 `api_key` 参数:
|
||||
|
||||
```bash
|
||||
curl -X POST "http://localhost:8080/api/mapping/create?api_key=your-secret-api-key-here" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"source_port": 30001,
|
||||
"target_host": "192.168.1.100",
|
||||
"target_port": 3306,
|
||||
"use_tunnel": false
|
||||
}'
|
||||
```
|
||||
|
||||
### 使用 PowerShell
|
||||
|
||||
```powershell
|
||||
$headers = @{
|
||||
"Content-Type" = "application/json"
|
||||
"X-API-Key" = "your-secret-api-key-here"
|
||||
}
|
||||
|
||||
$body = @{
|
||||
source_port = 30001
|
||||
target_host = "192.168.1.100"
|
||||
target_port = 3306
|
||||
use_tunnel = $false
|
||||
} | ConvertTo-Json
|
||||
|
||||
Invoke-RestMethod -Uri "http://localhost:8080/api/mapping/create" `
|
||||
-Method Post `
|
||||
-Headers $headers `
|
||||
-Body $body
|
||||
```
|
||||
|
||||
### 使用 Python
|
||||
|
||||
```python
|
||||
import requests
|
||||
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'X-API-Key': 'your-secret-api-key-here'
|
||||
}
|
||||
|
||||
data = {
|
||||
'source_port': 30001,
|
||||
'target_host': '192.168.1.100',
|
||||
'target_port': 3306,
|
||||
'use_tunnel': False
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
'http://localhost:8080/api/mapping/create',
|
||||
headers=headers,
|
||||
json=data
|
||||
)
|
||||
|
||||
print(response.json())
|
||||
```
|
||||
|
||||
## 需要认证的 API 接口
|
||||
|
||||
以下接口都需要提供有效的 API 密钥:
|
||||
|
||||
- `POST /api/mapping/create` - 创建端口映射
|
||||
- `POST /api/mapping/remove` - 删除端口映射
|
||||
- `GET /api/mapping/list` - 列出所有映射
|
||||
- `GET /api/stats/traffic` - 获取流量统计
|
||||
- `GET /api/stats/monitor` - 流量监控页面
|
||||
- `GET /admin` - 管理页面
|
||||
|
||||
## 不需要认证的接口
|
||||
|
||||
- `GET /health` - 健康检查接口(公开访问)
|
||||
|
||||
## 错误响应
|
||||
|
||||
如果 API 密钥无效或缺失,服务器将返回 401 状态码:
|
||||
|
||||
```json
|
||||
{
|
||||
"success": false,
|
||||
"message": "无效的 API 密钥"
|
||||
}
|
||||
```
|
||||
|
||||
## 浏览器访问
|
||||
|
||||
对于需要通过浏览器访问的页面(如 `/admin` 和 `/api/stats/monitor`),可以在 URL 中添加 `api_key` 参数:
|
||||
|
||||
```
|
||||
http://localhost:8080/admin?api_key=your-secret-api-key-here
|
||||
http://localhost:8080/api/stats/monitor?api_key=your-secret-api-key-here
|
||||
```
|
||||
|
|
@ -13,6 +13,7 @@ tunnel:
|
|||
# HTTP API 配置
|
||||
api:
|
||||
listen_port: 8080
|
||||
api_key: "your-secret-api-key-here"
|
||||
|
||||
# 数据库配置
|
||||
database:
|
||||
|
|
|
|||
|
|
@ -17,19 +17,21 @@ import (
|
|||
|
||||
// Handler HTTP API 处理器
|
||||
type Handler struct {
|
||||
db *db.Database
|
||||
forwarderMgr *forwarder.Manager
|
||||
tunnelServer *tunnel.Server
|
||||
db *db.Database
|
||||
forwarderMgr *forwarder.Manager
|
||||
tunnelServer *tunnel.Server
|
||||
apiKey string
|
||||
// portRangeFrom int
|
||||
// portRangeEnd int
|
||||
}
|
||||
|
||||
// NewHandler 创建新的 API 处理器
|
||||
func NewHandler(database *db.Database, fwdMgr *forwarder.Manager, ts *tunnel.Server) *Handler {
|
||||
func NewHandler(database *db.Database, fwdMgr *forwarder.Manager, ts *tunnel.Server, apiKey string) *Handler {
|
||||
return &Handler{
|
||||
db: database,
|
||||
forwarderMgr: fwdMgr,
|
||||
tunnelServer: ts,
|
||||
db: database,
|
||||
forwarderMgr: fwdMgr,
|
||||
tunnelServer: ts,
|
||||
apiKey: apiKey,
|
||||
// portRangeFrom: portFrom,
|
||||
// portRangeEnd: portEnd,
|
||||
}
|
||||
|
|
@ -57,15 +59,37 @@ type Response struct {
|
|||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *Handler) RegisterRoutes(mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/mapping/create", h.handleCreateMapping)
|
||||
mux.HandleFunc("/api/mapping/remove", h.handleRemoveMapping)
|
||||
mux.HandleFunc("/api/mapping/list", h.handleListMappings)
|
||||
mux.HandleFunc("/api/stats/traffic", h.handleGetTrafficStats)
|
||||
mux.HandleFunc("/api/stats/monitor", h.handleTrafficMonitor)
|
||||
mux.HandleFunc("/admin", h.handleManagement)
|
||||
mux.HandleFunc("/api/mapping/create", h.authMiddleware(h.handleCreateMapping))
|
||||
mux.HandleFunc("/api/mapping/remove", h.authMiddleware(h.handleRemoveMapping))
|
||||
mux.HandleFunc("/api/mapping/list", h.authMiddleware(h.handleListMappings))
|
||||
mux.HandleFunc("/api/stats/traffic", h.authMiddleware(h.handleGetTrafficStats))
|
||||
mux.HandleFunc("/api/stats/monitor", h.authMiddleware(h.handleTrafficMonitor))
|
||||
mux.HandleFunc("/admin", h.authMiddleware(h.handleManagement))
|
||||
mux.HandleFunc("/health", h.handleHealth)
|
||||
}
|
||||
|
||||
// authMiddleware 认证中间件
|
||||
func (h *Handler) authMiddleware(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// 从请求头中获取 API Key
|
||||
apiKey := r.Header.Get("X-API-Key")
|
||||
|
||||
// 如果请求头中没有,尝试从查询参数中获取
|
||||
if apiKey == "" {
|
||||
apiKey = r.URL.Query().Get("api_key")
|
||||
}
|
||||
|
||||
// 验证 API Key
|
||||
if apiKey != h.apiKey {
|
||||
h.writeError(w, http.StatusUnauthorized, "无效的 API 密钥")
|
||||
return
|
||||
}
|
||||
|
||||
// 认证通过,继续处理请求
|
||||
next(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// validateHostOrIP 验证主机名或IP地址
|
||||
func (h *Handler) validateHostOrIP(hostOrIP string) error {
|
||||
if hostOrIP == "" {
|
||||
|
|
@ -156,7 +180,7 @@ func (h *Handler) handleCreateMapping(w http.ResponseWriter, r *http.Request) {
|
|||
// 直接模式:直接TCP转发
|
||||
err = h.forwarderMgr.Add(req.SourcePort, req.TargetHost, req.TargetPort)
|
||||
}
|
||||
|
||||
|
||||
if err != nil {
|
||||
// 回滚数据库操作
|
||||
h.db.RemoveMapping(req.SourcePort)
|
||||
|
|
@ -305,11 +329,11 @@ func (h *Handler) handleGetTrafficStats(w http.ResponseWriter, r *http.Request)
|
|||
|
||||
// 获取所有端口映射的流量统计
|
||||
forwarderStats := h.forwarderMgr.GetAllTrafficStats()
|
||||
|
||||
|
||||
// 构建响应
|
||||
mappings := make([]stats.PortTrafficStats, 0, len(forwarderStats))
|
||||
var totalSent, totalReceived uint64
|
||||
|
||||
|
||||
for port, stat := range forwarderStats {
|
||||
mappings = append(mappings, stats.PortTrafficStats{
|
||||
Port: port,
|
||||
|
|
@ -320,11 +344,11 @@ func (h *Handler) handleGetTrafficStats(w http.ResponseWriter, r *http.Request)
|
|||
totalSent += stat.BytesSent
|
||||
totalReceived += stat.BytesReceived
|
||||
}
|
||||
|
||||
|
||||
// 加上隧道的流量
|
||||
totalSent += tunnelStats.BytesSent
|
||||
totalReceived += tunnelStats.BytesReceived
|
||||
|
||||
|
||||
response := stats.AllTrafficStats{
|
||||
Tunnel: tunnelStats,
|
||||
Mappings: mappings,
|
||||
|
|
@ -346,4 +370,4 @@ func (h *Handler) handleTrafficMonitor(w http.ResponseWriter, r *http.Request) {
|
|||
func (h *Handler) handleManagement(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
fmt.Fprint(w, managementHTML)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -14,20 +14,22 @@ import (
|
|||
"testing"
|
||||
)
|
||||
|
||||
const testAPIKey = "test-api-key-12345"
|
||||
|
||||
// setupTestHandler 创建测试用的 Handler
|
||||
func setupTestHandler(t *testing.T, useTunnel bool) (*Handler, *db.Database, func()) {
|
||||
// 创建临时数据库
|
||||
tmpDir := t.TempDir()
|
||||
dbPath := filepath.Join(tmpDir, "test.db")
|
||||
|
||||
|
||||
database, err := db.New(dbPath)
|
||||
if err != nil {
|
||||
t.Fatalf("创建数据库失败: %v", err)
|
||||
}
|
||||
|
||||
|
||||
// 创建转发器管理器
|
||||
fwdMgr := forwarder.NewManager()
|
||||
|
||||
|
||||
// 创建隧道服务器(如果启用)
|
||||
var tunnelServer *tunnel.Server
|
||||
if useTunnel {
|
||||
|
|
@ -35,13 +37,13 @@ func setupTestHandler(t *testing.T, useTunnel bool) (*Handler, *db.Database, fun
|
|||
listener, _ := net.Listen("tcp", "127.0.0.1:0")
|
||||
port := listener.Addr().(*net.TCPAddr).Port
|
||||
listener.Close()
|
||||
|
||||
|
||||
tunnelServer = tunnel.NewServer(port)
|
||||
tunnelServer.Start()
|
||||
}
|
||||
|
||||
handler := NewHandler(database, fwdMgr, tunnelServer)
|
||||
|
||||
|
||||
handler := NewHandler(database, fwdMgr, tunnelServer, testAPIKey)
|
||||
|
||||
cleanup := func() {
|
||||
fwdMgr.StopAll()
|
||||
if tunnelServer != nil {
|
||||
|
|
@ -50,23 +52,28 @@ func setupTestHandler(t *testing.T, useTunnel bool) (*Handler, *db.Database, fun
|
|||
database.Close()
|
||||
os.RemoveAll(tmpDir)
|
||||
}
|
||||
|
||||
|
||||
return handler, database, cleanup
|
||||
}
|
||||
|
||||
// addAuthHeader 添加认证头到请求
|
||||
func addAuthHeader(req *http.Request) {
|
||||
req.Header.Set("X-API-Key", testAPIKey)
|
||||
}
|
||||
|
||||
// TestNewHandler 测试创建处理器
|
||||
func TestNewHandler(t *testing.T) {
|
||||
handler, _, cleanup := setupTestHandler(t, false)
|
||||
defer cleanup()
|
||||
|
||||
|
||||
if handler == nil {
|
||||
t.Fatal("创建处理器失败")
|
||||
}
|
||||
|
||||
|
||||
// if handler.portRangeFrom != 10000 {
|
||||
// t.Errorf("起始端口不正确,期望 10000,得到 %d", handler.portRangeFrom)
|
||||
// }
|
||||
|
||||
|
||||
// if handler.portRangeEnd != 20000 {
|
||||
// t.Errorf("结束端口不正确,期望 20000,得到 %d", handler.portRangeEnd)
|
||||
// }
|
||||
|
|
@ -76,22 +83,22 @@ func TestNewHandler(t *testing.T) {
|
|||
func TestHandleHealth(t *testing.T) {
|
||||
handler, _, cleanup := setupTestHandler(t, false)
|
||||
defer cleanup()
|
||||
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/health", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
|
||||
handler.handleHealth(w, req)
|
||||
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("状态码不正确,期望 200,得到 %d", w.Code)
|
||||
}
|
||||
|
||||
|
||||
var result map[string]interface{}
|
||||
err := json.NewDecoder(w.Body).Decode(&result)
|
||||
if err != nil {
|
||||
t.Fatalf("解析响应失败: %v", err)
|
||||
}
|
||||
|
||||
|
||||
if result["status"] != "ok" {
|
||||
t.Errorf("健康状态不正确,期望 ok,得到 %v", result["status"])
|
||||
}
|
||||
|
|
@ -101,23 +108,23 @@ func TestHandleHealth(t *testing.T) {
|
|||
func TestHandleHealthWithTunnel(t *testing.T) {
|
||||
handler, _, cleanup := setupTestHandler(t, true)
|
||||
defer cleanup()
|
||||
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/health", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
|
||||
handler.handleHealth(w, req)
|
||||
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("状态码不正确,期望 200,得到 %d", w.Code)
|
||||
}
|
||||
|
||||
|
||||
var result map[string]interface{}
|
||||
json.NewDecoder(w.Body).Decode(&result)
|
||||
|
||||
|
||||
if result["tunnel_enabled"] != true {
|
||||
t.Error("隧道应该启用")
|
||||
}
|
||||
|
||||
|
||||
// 隧道未连接客户端时应该为 false
|
||||
if result["tunnel_connected"] != false {
|
||||
t.Error("隧道应该未连接")
|
||||
|
|
@ -128,41 +135,42 @@ func TestHandleHealthWithTunnel(t *testing.T) {
|
|||
func TestHandleCreateMapping(t *testing.T) {
|
||||
handler, database, cleanup := setupTestHandler(t, false)
|
||||
defer cleanup()
|
||||
|
||||
|
||||
reqBody := CreateMappingRequest{
|
||||
// Port: 15000,
|
||||
SourcePort: 15000,
|
||||
TargetPort: 15000,
|
||||
TargetHost: "192.168.1.100",
|
||||
}
|
||||
|
||||
|
||||
body, _ := json.Marshal(reqBody)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
|
||||
addAuthHeader(req)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
|
||||
handler.handleCreateMapping(w, req)
|
||||
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("状态码不正确,期望 200,得到 %d", w.Code)
|
||||
}
|
||||
|
||||
|
||||
var result Response
|
||||
json.NewDecoder(w.Body).Decode(&result)
|
||||
|
||||
|
||||
if !result.Success {
|
||||
t.Errorf("创建映射失败: %s", result.Message)
|
||||
}
|
||||
|
||||
|
||||
// 验证数据库中存在映射
|
||||
mapping, err := database.GetMapping(15000)
|
||||
if err != nil {
|
||||
t.Fatalf("获取映射失败: %v", err)
|
||||
}
|
||||
|
||||
|
||||
if mapping == nil {
|
||||
t.Fatal("映射不存在")
|
||||
}
|
||||
|
||||
|
||||
if mapping.TargetHost != "192.168.1.100" {
|
||||
t.Errorf("目标 IP 不正确,期望 192.168.1.100,得到 %s", mapping.TargetHost)
|
||||
}
|
||||
|
|
@ -172,7 +180,7 @@ func TestHandleCreateMapping(t *testing.T) {
|
|||
func TestHandleCreateMappingInvalidPort(t *testing.T) {
|
||||
handler, _, cleanup := setupTestHandler(t, false)
|
||||
defer cleanup()
|
||||
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
port int
|
||||
|
|
@ -181,7 +189,7 @@ func TestHandleCreateMappingInvalidPort(t *testing.T) {
|
|||
{"端口太大", 25000},
|
||||
{"端口为0", 0},
|
||||
}
|
||||
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
reqBody := CreateMappingRequest{
|
||||
|
|
@ -189,13 +197,14 @@ func TestHandleCreateMappingInvalidPort(t *testing.T) {
|
|||
TargetPort: tt.port,
|
||||
TargetHost: "192.168.1.100",
|
||||
}
|
||||
|
||||
|
||||
body, _ := json.Marshal(reqBody)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
|
||||
addAuthHeader(req)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
|
||||
handler.handleCreateMapping(w, req)
|
||||
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("状态码不正确,期望 400,得到 %d", w.Code)
|
||||
}
|
||||
|
|
@ -207,30 +216,32 @@ func TestHandleCreateMappingInvalidPort(t *testing.T) {
|
|||
func TestHandleCreateMappingDuplicate(t *testing.T) {
|
||||
handler, _, cleanup := setupTestHandler(t, false)
|
||||
defer cleanup()
|
||||
|
||||
|
||||
reqBody := CreateMappingRequest{
|
||||
// Port: 15000,
|
||||
SourcePort: 15000,
|
||||
TargetPort: 15000,
|
||||
TargetHost: "192.168.1.100",
|
||||
}
|
||||
|
||||
|
||||
// 第一次创建
|
||||
body, _ := json.Marshal(reqBody)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
|
||||
addAuthHeader(req)
|
||||
w := httptest.NewRecorder()
|
||||
handler.handleCreateMapping(w, req)
|
||||
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("第一次创建失败")
|
||||
}
|
||||
|
||||
|
||||
// 第二次创建(应该失败)
|
||||
body, _ = json.Marshal(reqBody)
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
|
||||
addAuthHeader(req)
|
||||
w = httptest.NewRecorder()
|
||||
handler.handleCreateMapping(w, req)
|
||||
|
||||
|
||||
if w.Code != http.StatusConflict {
|
||||
t.Errorf("状态码不正确,期望 409,得到 %d", w.Code)
|
||||
}
|
||||
|
|
@ -240,12 +251,13 @@ func TestHandleCreateMappingDuplicate(t *testing.T) {
|
|||
func TestHandleCreateMappingInvalidJSON(t *testing.T) {
|
||||
handler, _, cleanup := setupTestHandler(t, false)
|
||||
defer cleanup()
|
||||
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader([]byte("invalid json")))
|
||||
addAuthHeader(req)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
|
||||
handler.handleCreateMapping(w, req)
|
||||
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("状态码不正确,期望 400,得到 %d", w.Code)
|
||||
}
|
||||
|
|
@ -255,20 +267,21 @@ func TestHandleCreateMappingInvalidJSON(t *testing.T) {
|
|||
func TestHandleCreateMappingInvalidIP(t *testing.T) {
|
||||
handler, _, cleanup := setupTestHandler(t, false)
|
||||
defer cleanup()
|
||||
|
||||
|
||||
reqBody := CreateMappingRequest{
|
||||
// Port: 15000,
|
||||
SourcePort: 15000,
|
||||
TargetPort: 15000,
|
||||
TargetHost: "", // 使用空字符串而不是无效域名,避免DNS查询超时
|
||||
}
|
||||
|
||||
|
||||
body, _ := json.Marshal(reqBody)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
|
||||
addAuthHeader(req)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
|
||||
handler.handleCreateMapping(w, req)
|
||||
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("状态码不正确,期望 400,得到 %d", w.Code)
|
||||
}
|
||||
|
|
@ -278,20 +291,21 @@ func TestHandleCreateMappingInvalidIP(t *testing.T) {
|
|||
func TestHandleCreateMappingEmptyIP(t *testing.T) {
|
||||
handler, _, cleanup := setupTestHandler(t, false)
|
||||
defer cleanup()
|
||||
|
||||
|
||||
reqBody := CreateMappingRequest{
|
||||
// Port: 15000,
|
||||
SourcePort: 15000,
|
||||
TargetPort: 15000,
|
||||
TargetHost: "",
|
||||
}
|
||||
|
||||
|
||||
body, _ := json.Marshal(reqBody)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
|
||||
addAuthHeader(req)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
|
||||
handler.handleCreateMapping(w, req)
|
||||
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("状态码不正确,期望 400,得到 %d", w.Code)
|
||||
}
|
||||
|
|
@ -301,20 +315,21 @@ func TestHandleCreateMappingEmptyIP(t *testing.T) {
|
|||
func TestHandleCreateMappingTunnelNotConnected(t *testing.T) {
|
||||
handler, _, cleanup := setupTestHandler(t, true)
|
||||
defer cleanup()
|
||||
|
||||
|
||||
reqBody := CreateMappingRequest{
|
||||
// Port: 15000,
|
||||
SourcePort: 15000,
|
||||
TargetPort: 15000,
|
||||
UseTunnel: true, // 明确指定使用隧道模式
|
||||
UseTunnel: true, // 明确指定使用隧道模式
|
||||
}
|
||||
|
||||
|
||||
body, _ := json.Marshal(reqBody)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
|
||||
addAuthHeader(req)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
|
||||
handler.handleCreateMapping(w, req)
|
||||
|
||||
|
||||
if w.Code != http.StatusServiceUnavailable {
|
||||
t.Errorf("状态码不正确,期望 503,得到 %d", w.Code)
|
||||
}
|
||||
|
|
@ -324,25 +339,25 @@ func TestHandleCreateMappingTunnelNotConnected(t *testing.T) {
|
|||
func TestHandleRemoveMapping(t *testing.T) {
|
||||
handler, database, cleanup := setupTestHandler(t, false)
|
||||
defer cleanup()
|
||||
|
||||
|
||||
// 先创建一个映射
|
||||
database.AddMapping(15000, "192.168.1.100", 15000, false)
|
||||
handler.forwarderMgr.Add(15000, "192.168.1.100", 15000)
|
||||
|
||||
|
||||
reqBody := RemoveMappingRequest{
|
||||
Port: 15000,
|
||||
}
|
||||
|
||||
|
||||
body, _ := json.Marshal(reqBody)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/mapping/remove", bytes.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
|
||||
handler.handleRemoveMapping(w, req)
|
||||
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("状态码不正确,期望 200,得到 %d", w.Code)
|
||||
}
|
||||
|
||||
|
||||
// 验证映射已删除
|
||||
mapping, _ := database.GetMapping(15000)
|
||||
if mapping != nil {
|
||||
|
|
@ -354,17 +369,17 @@ func TestHandleRemoveMapping(t *testing.T) {
|
|||
func TestHandleRemoveMappingNotExist(t *testing.T) {
|
||||
handler, _, cleanup := setupTestHandler(t, false)
|
||||
defer cleanup()
|
||||
|
||||
|
||||
reqBody := RemoveMappingRequest{
|
||||
Port: 15000,
|
||||
}
|
||||
|
||||
|
||||
body, _ := json.Marshal(reqBody)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/mapping/remove", bytes.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
|
||||
handler.handleRemoveMapping(w, req)
|
||||
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("状态码不正确,期望 404,得到 %d", w.Code)
|
||||
}
|
||||
|
|
@ -374,31 +389,32 @@ func TestHandleRemoveMappingNotExist(t *testing.T) {
|
|||
func TestHandleListMappings(t *testing.T) {
|
||||
handler, database, cleanup := setupTestHandler(t, false)
|
||||
defer cleanup()
|
||||
|
||||
|
||||
// 添加一些映射
|
||||
database.AddMapping(15000, "192.168.1.100", 15000, false)
|
||||
database.AddMapping(15001, "192.168.1.101", 15001, true)
|
||||
database.AddMapping(15002, "192.168.1.102", 15002, false)
|
||||
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/mapping/list", nil)
|
||||
addAuthHeader(req)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
|
||||
handler.handleListMappings(w, req)
|
||||
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("状态码不正确,期望 200,得到 %d", w.Code)
|
||||
}
|
||||
|
||||
|
||||
var result Response
|
||||
json.NewDecoder(w.Body).Decode(&result)
|
||||
|
||||
|
||||
if !result.Success {
|
||||
t.Errorf("列出映射失败: %s", result.Message)
|
||||
}
|
||||
|
||||
|
||||
data := result.Data.(map[string]interface{})
|
||||
count := int(data["count"].(float64))
|
||||
|
||||
|
||||
if count != 3 {
|
||||
t.Errorf("映射数量不正确,期望 3,得到 %d", count)
|
||||
}
|
||||
|
|
@ -408,22 +424,23 @@ func TestHandleListMappings(t *testing.T) {
|
|||
func TestHandleListMappingsEmpty(t *testing.T) {
|
||||
handler, _, cleanup := setupTestHandler(t, false)
|
||||
defer cleanup()
|
||||
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/mapping/list", nil)
|
||||
addAuthHeader(req)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
|
||||
handler.handleListMappings(w, req)
|
||||
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("状态码不正确,期望 200,得到 %d", w.Code)
|
||||
}
|
||||
|
||||
|
||||
var result Response
|
||||
json.NewDecoder(w.Body).Decode(&result)
|
||||
|
||||
|
||||
data := result.Data.(map[string]interface{})
|
||||
count := int(data["count"].(float64))
|
||||
|
||||
|
||||
if count != 0 {
|
||||
t.Errorf("映射数量不正确,期望 0,得到 %d", count)
|
||||
}
|
||||
|
|
@ -433,7 +450,7 @@ func TestHandleListMappingsEmpty(t *testing.T) {
|
|||
func TestHandleMethodNotAllowed(t *testing.T) {
|
||||
handler, _, cleanup := setupTestHandler(t, false)
|
||||
defer cleanup()
|
||||
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
handler func(http.ResponseWriter, *http.Request)
|
||||
|
|
@ -443,14 +460,15 @@ func TestHandleMethodNotAllowed(t *testing.T) {
|
|||
{"删除映射 GET", handler.handleRemoveMapping, http.MethodGet},
|
||||
{"列出映射 POST", handler.handleListMappings, http.MethodPost},
|
||||
}
|
||||
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(tt.method, "/test", nil)
|
||||
addAuthHeader(req)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
|
||||
tt.handler(w, req)
|
||||
|
||||
|
||||
if w.Code != http.StatusMethodNotAllowed {
|
||||
t.Errorf("状态码不正确,期望 405,得到 %d", w.Code)
|
||||
}
|
||||
|
|
@ -462,10 +480,10 @@ func TestHandleMethodNotAllowed(t *testing.T) {
|
|||
func TestRegisterRoutes(t *testing.T) {
|
||||
handler, _, cleanup := setupTestHandler(t, false)
|
||||
defer cleanup()
|
||||
|
||||
|
||||
mux := http.NewServeMux()
|
||||
handler.RegisterRoutes(mux)
|
||||
|
||||
|
||||
// 测试路由是否注册
|
||||
routes := []string{
|
||||
"/api/mapping/create",
|
||||
|
|
@ -473,13 +491,13 @@ func TestRegisterRoutes(t *testing.T) {
|
|||
"/api/mapping/list",
|
||||
"/health",
|
||||
}
|
||||
|
||||
|
||||
for _, route := range routes {
|
||||
req := httptest.NewRequest(http.MethodGet, route, nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
|
||||
mux.ServeHTTP(w, req)
|
||||
|
||||
|
||||
// 如果路由不存在,应该返回 404
|
||||
if w.Code == http.StatusNotFound {
|
||||
t.Errorf("路由 %s 未注册", route)
|
||||
|
|
@ -491,21 +509,21 @@ func TestRegisterRoutes(t *testing.T) {
|
|||
func TestWriteSuccess(t *testing.T) {
|
||||
handler, _, cleanup := setupTestHandler(t, false)
|
||||
defer cleanup()
|
||||
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
handler.writeSuccess(w, "测试成功", map[string]string{"key": "value"})
|
||||
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("状态码不正确,期望 200,得到 %d", w.Code)
|
||||
}
|
||||
|
||||
|
||||
var result Response
|
||||
json.NewDecoder(w.Body).Decode(&result)
|
||||
|
||||
|
||||
if !result.Success {
|
||||
t.Error("Success 应该为 true")
|
||||
}
|
||||
|
||||
|
||||
if result.Message != "测试成功" {
|
||||
t.Errorf("消息不正确,期望 '测试成功',得到 '%s'", result.Message)
|
||||
}
|
||||
|
|
@ -515,21 +533,21 @@ func TestWriteSuccess(t *testing.T) {
|
|||
func TestWriteError(t *testing.T) {
|
||||
handler, _, cleanup := setupTestHandler(t, false)
|
||||
defer cleanup()
|
||||
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
handler.writeError(w, http.StatusBadRequest, "测试错误")
|
||||
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("状态码不正确,期望 400,得到 %d", w.Code)
|
||||
}
|
||||
|
||||
|
||||
var result Response
|
||||
json.NewDecoder(w.Body).Decode(&result)
|
||||
|
||||
|
||||
if result.Success {
|
||||
t.Error("Success 应该为 false")
|
||||
}
|
||||
|
||||
|
||||
if result.Message != "测试错误" {
|
||||
t.Errorf("消息不正确,期望 '测试错误',得到 '%s'", result.Message)
|
||||
}
|
||||
|
|
@ -541,12 +559,12 @@ func BenchmarkHandleHealth(b *testing.B) {
|
|||
dbPath := filepath.Join(tmpDir, "bench.db")
|
||||
database, _ := db.New(dbPath)
|
||||
defer database.Close()
|
||||
|
||||
|
||||
fwdMgr := forwarder.NewManager()
|
||||
handler := NewHandler(database, fwdMgr, nil)
|
||||
|
||||
handler := NewHandler(database, fwdMgr, nil, "test-api-key")
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/health", nil)
|
||||
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
w := httptest.NewRecorder()
|
||||
|
|
@ -560,21 +578,92 @@ func BenchmarkHandleListMappings(b *testing.B) {
|
|||
dbPath := filepath.Join(tmpDir, "bench.db")
|
||||
database, _ := db.New(dbPath)
|
||||
defer database.Close()
|
||||
|
||||
|
||||
// 添加一些映射
|
||||
for i := 0; i < 100; i++ {
|
||||
useTunnel := i%2 == 0 // 偶数使用隧道模式
|
||||
database.AddMapping(10000+i, "192.168.1.1", 10000+i, useTunnel)
|
||||
}
|
||||
|
||||
|
||||
fwdMgr := forwarder.NewManager()
|
||||
handler := NewHandler(database, fwdMgr, nil)
|
||||
|
||||
handler := NewHandler(database, fwdMgr, nil, "test-api-key")
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/mapping/list", nil)
|
||||
|
||||
req.Header.Set("X-API-Key", "test-api-key")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
w := httptest.NewRecorder()
|
||||
handler.handleListMappings(w, req)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthMiddleware 测试认证中间件
|
||||
func TestAuthMiddleware(t *testing.T) {
|
||||
handler, _, cleanup := setupTestHandler(t, false)
|
||||
defer cleanup()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
apiKey string
|
||||
useHeader bool
|
||||
useQueryParam bool
|
||||
expectedStatus int
|
||||
expectedMsg string
|
||||
}{
|
||||
{
|
||||
name: "有效的API密钥(请求头)",
|
||||
apiKey: testAPIKey,
|
||||
useHeader: true,
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "有效的API密钥(查询参数)",
|
||||
apiKey: testAPIKey,
|
||||
useQueryParam: true,
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "无效的API密钥",
|
||||
apiKey: "invalid-key",
|
||||
useHeader: true,
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedMsg: "无效的 API 密钥",
|
||||
},
|
||||
{
|
||||
name: "缺少API密钥",
|
||||
apiKey: "",
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedMsg: "无效的 API 密钥",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
url := "/api/mapping/list"
|
||||
if tt.useQueryParam {
|
||||
url += "?api_key=" + tt.apiKey
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, url, nil)
|
||||
if tt.useHeader && tt.apiKey != "" {
|
||||
req.Header.Set("X-API-Key", tt.apiKey)
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
handler.handleListMappings(w, req)
|
||||
|
||||
if w.Code != tt.expectedStatus {
|
||||
t.Errorf("状态码不正确,期望 %d,得到 %d", tt.expectedStatus, w.Code)
|
||||
}
|
||||
|
||||
if tt.expectedMsg != "" {
|
||||
var result Response
|
||||
json.NewDecoder(w.Body).Decode(&result)
|
||||
if result.Message != tt.expectedMsg {
|
||||
t.Errorf("错误消息不正确,期望 '%s',得到 '%s'", tt.expectedMsg, result.Message)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,9 +10,9 @@ import (
|
|||
// Config 应用配置结构
|
||||
type Config struct {
|
||||
// PortRange PortRangeConfig `yaml:"port_range"`
|
||||
Tunnel TunnelConfig `yaml:"tunnel"`
|
||||
API APIConfig `yaml:"api"`
|
||||
Database DatabaseConfig `yaml:"database"`
|
||||
Tunnel TunnelConfig `yaml:"tunnel"`
|
||||
API APIConfig `yaml:"api"`
|
||||
Database DatabaseConfig `yaml:"database"`
|
||||
}
|
||||
|
||||
// PortRangeConfig 端口范围配置
|
||||
|
|
@ -29,7 +29,8 @@ type TunnelConfig struct {
|
|||
|
||||
// APIConfig HTTP API 配置
|
||||
type APIConfig struct {
|
||||
ListenPort int `yaml:"listen_port"`
|
||||
ListenPort int `yaml:"listen_port"`
|
||||
APIKey string `yaml:"api_key"`
|
||||
}
|
||||
|
||||
// DatabaseConfig 数据库配置
|
||||
|
|
@ -74,8 +75,11 @@ func (c *Config) Validate() error {
|
|||
if c.API.ListenPort <= 0 {
|
||||
return fmt.Errorf("API 端口必须大于 0")
|
||||
}
|
||||
if c.API.APIKey == "" {
|
||||
return fmt.Errorf("API 密钥不能为空")
|
||||
}
|
||||
if c.Database.Path == "" {
|
||||
return fmt.Errorf("数据库路径不能为空")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -96,7 +96,7 @@ func (s *serverService) Start() error {
|
|||
|
||||
// 创建 HTTP API 处理器
|
||||
log.Println("初始化 HTTP API...")
|
||||
s.apiHandler = api.NewHandler(database, s.fwdManager, s.tunnelServer)
|
||||
s.apiHandler = api.NewHandler(database, s.fwdManager, s.tunnelServer, cfg.API.APIKey)
|
||||
|
||||
// 启动 HTTP API 服务器
|
||||
go func() {
|
||||
|
|
|
|||
Loading…
Reference in New Issue