feat: 新增api keys
This commit is contained in:
parent
550043f3f4
commit
a3ddc33d17
|
|
@ -0,0 +1,136 @@
|
||||||
|
# API 认证说明
|
||||||
|
|
||||||
|
## 概述
|
||||||
|
|
||||||
|
所有 API 接口(除了 `/health` 健康检查接口)都需要提供有效的 API 密钥才能访问。
|
||||||
|
|
||||||
|
## 配置 API 密钥
|
||||||
|
|
||||||
|
在配置文件 `config.yaml` 中设置 API 密钥:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# HTTP API 配置
|
||||||
|
api:
|
||||||
|
listen_port: 8080
|
||||||
|
api_key: "your-secret-api-key-here" # 修改为你的密钥
|
||||||
|
```
|
||||||
|
|
||||||
|
**安全建议:**
|
||||||
|
- 使用强密码生成器生成复杂的 API 密钥
|
||||||
|
- 定期更换 API 密钥
|
||||||
|
- 不要在公共代码库中提交包含真实密钥的配置文件
|
||||||
|
|
||||||
|
## 使用 API 密钥
|
||||||
|
|
||||||
|
### 方式 1: 通过 HTTP 请求头(推荐)
|
||||||
|
|
||||||
|
在请求头中添加 `X-API-Key`:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -X POST http://localhost:8080/api/mapping/create \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-H "X-API-Key: your-secret-api-key-here" \
|
||||||
|
-d '{
|
||||||
|
"source_port": 30001,
|
||||||
|
"target_host": "192.168.1.100",
|
||||||
|
"target_port": 3306,
|
||||||
|
"use_tunnel": false
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
### 方式 2: 通过 URL 查询参数
|
||||||
|
|
||||||
|
在 URL 中添加 `api_key` 参数:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -X POST "http://localhost:8080/api/mapping/create?api_key=your-secret-api-key-here" \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"source_port": 30001,
|
||||||
|
"target_host": "192.168.1.100",
|
||||||
|
"target_port": 3306,
|
||||||
|
"use_tunnel": false
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
### 使用 PowerShell
|
||||||
|
|
||||||
|
```powershell
|
||||||
|
$headers = @{
|
||||||
|
"Content-Type" = "application/json"
|
||||||
|
"X-API-Key" = "your-secret-api-key-here"
|
||||||
|
}
|
||||||
|
|
||||||
|
$body = @{
|
||||||
|
source_port = 30001
|
||||||
|
target_host = "192.168.1.100"
|
||||||
|
target_port = 3306
|
||||||
|
use_tunnel = $false
|
||||||
|
} | ConvertTo-Json
|
||||||
|
|
||||||
|
Invoke-RestMethod -Uri "http://localhost:8080/api/mapping/create" `
|
||||||
|
-Method Post `
|
||||||
|
-Headers $headers `
|
||||||
|
-Body $body
|
||||||
|
```
|
||||||
|
|
||||||
|
### 使用 Python
|
||||||
|
|
||||||
|
```python
|
||||||
|
import requests
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
'X-API-Key': 'your-secret-api-key-here'
|
||||||
|
}
|
||||||
|
|
||||||
|
data = {
|
||||||
|
'source_port': 30001,
|
||||||
|
'target_host': '192.168.1.100',
|
||||||
|
'target_port': 3306,
|
||||||
|
'use_tunnel': False
|
||||||
|
}
|
||||||
|
|
||||||
|
response = requests.post(
|
||||||
|
'http://localhost:8080/api/mapping/create',
|
||||||
|
headers=headers,
|
||||||
|
json=data
|
||||||
|
)
|
||||||
|
|
||||||
|
print(response.json())
|
||||||
|
```
|
||||||
|
|
||||||
|
## 需要认证的 API 接口
|
||||||
|
|
||||||
|
以下接口都需要提供有效的 API 密钥:
|
||||||
|
|
||||||
|
- `POST /api/mapping/create` - 创建端口映射
|
||||||
|
- `POST /api/mapping/remove` - 删除端口映射
|
||||||
|
- `GET /api/mapping/list` - 列出所有映射
|
||||||
|
- `GET /api/stats/traffic` - 获取流量统计
|
||||||
|
- `GET /api/stats/monitor` - 流量监控页面
|
||||||
|
- `GET /admin` - 管理页面
|
||||||
|
|
||||||
|
## 不需要认证的接口
|
||||||
|
|
||||||
|
- `GET /health` - 健康检查接口(公开访问)
|
||||||
|
|
||||||
|
## 错误响应
|
||||||
|
|
||||||
|
如果 API 密钥无效或缺失,服务器将返回 401 状态码:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"success": false,
|
||||||
|
"message": "无效的 API 密钥"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## 浏览器访问
|
||||||
|
|
||||||
|
对于需要通过浏览器访问的页面(如 `/admin` 和 `/api/stats/monitor`),可以在 URL 中添加 `api_key` 参数:
|
||||||
|
|
||||||
|
```
|
||||||
|
http://localhost:8080/admin?api_key=your-secret-api-key-here
|
||||||
|
http://localhost:8080/api/stats/monitor?api_key=your-secret-api-key-here
|
||||||
|
```
|
||||||
|
|
@ -13,6 +13,7 @@ tunnel:
|
||||||
# HTTP API 配置
|
# HTTP API 配置
|
||||||
api:
|
api:
|
||||||
listen_port: 8080
|
listen_port: 8080
|
||||||
|
api_key: "your-secret-api-key-here"
|
||||||
|
|
||||||
# 数据库配置
|
# 数据库配置
|
||||||
database:
|
database:
|
||||||
|
|
|
||||||
|
|
@ -17,19 +17,21 @@ import (
|
||||||
|
|
||||||
// Handler HTTP API 处理器
|
// Handler HTTP API 处理器
|
||||||
type Handler struct {
|
type Handler struct {
|
||||||
db *db.Database
|
db *db.Database
|
||||||
forwarderMgr *forwarder.Manager
|
forwarderMgr *forwarder.Manager
|
||||||
tunnelServer *tunnel.Server
|
tunnelServer *tunnel.Server
|
||||||
|
apiKey string
|
||||||
// portRangeFrom int
|
// portRangeFrom int
|
||||||
// portRangeEnd int
|
// portRangeEnd int
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewHandler 创建新的 API 处理器
|
// NewHandler 创建新的 API 处理器
|
||||||
func NewHandler(database *db.Database, fwdMgr *forwarder.Manager, ts *tunnel.Server) *Handler {
|
func NewHandler(database *db.Database, fwdMgr *forwarder.Manager, ts *tunnel.Server, apiKey string) *Handler {
|
||||||
return &Handler{
|
return &Handler{
|
||||||
db: database,
|
db: database,
|
||||||
forwarderMgr: fwdMgr,
|
forwarderMgr: fwdMgr,
|
||||||
tunnelServer: ts,
|
tunnelServer: ts,
|
||||||
|
apiKey: apiKey,
|
||||||
// portRangeFrom: portFrom,
|
// portRangeFrom: portFrom,
|
||||||
// portRangeEnd: portEnd,
|
// portRangeEnd: portEnd,
|
||||||
}
|
}
|
||||||
|
|
@ -57,15 +59,37 @@ type Response struct {
|
||||||
|
|
||||||
// RegisterRoutes 注册路由
|
// RegisterRoutes 注册路由
|
||||||
func (h *Handler) RegisterRoutes(mux *http.ServeMux) {
|
func (h *Handler) RegisterRoutes(mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/api/mapping/create", h.handleCreateMapping)
|
mux.HandleFunc("/api/mapping/create", h.authMiddleware(h.handleCreateMapping))
|
||||||
mux.HandleFunc("/api/mapping/remove", h.handleRemoveMapping)
|
mux.HandleFunc("/api/mapping/remove", h.authMiddleware(h.handleRemoveMapping))
|
||||||
mux.HandleFunc("/api/mapping/list", h.handleListMappings)
|
mux.HandleFunc("/api/mapping/list", h.authMiddleware(h.handleListMappings))
|
||||||
mux.HandleFunc("/api/stats/traffic", h.handleGetTrafficStats)
|
mux.HandleFunc("/api/stats/traffic", h.authMiddleware(h.handleGetTrafficStats))
|
||||||
mux.HandleFunc("/api/stats/monitor", h.handleTrafficMonitor)
|
mux.HandleFunc("/api/stats/monitor", h.authMiddleware(h.handleTrafficMonitor))
|
||||||
mux.HandleFunc("/admin", h.handleManagement)
|
mux.HandleFunc("/admin", h.authMiddleware(h.handleManagement))
|
||||||
mux.HandleFunc("/health", h.handleHealth)
|
mux.HandleFunc("/health", h.handleHealth)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// authMiddleware 认证中间件
|
||||||
|
func (h *Handler) authMiddleware(next http.HandlerFunc) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// 从请求头中获取 API Key
|
||||||
|
apiKey := r.Header.Get("X-API-Key")
|
||||||
|
|
||||||
|
// 如果请求头中没有,尝试从查询参数中获取
|
||||||
|
if apiKey == "" {
|
||||||
|
apiKey = r.URL.Query().Get("api_key")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证 API Key
|
||||||
|
if apiKey != h.apiKey {
|
||||||
|
h.writeError(w, http.StatusUnauthorized, "无效的 API 密钥")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 认证通过,继续处理请求
|
||||||
|
next(w, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// validateHostOrIP 验证主机名或IP地址
|
// validateHostOrIP 验证主机名或IP地址
|
||||||
func (h *Handler) validateHostOrIP(hostOrIP string) error {
|
func (h *Handler) validateHostOrIP(hostOrIP string) error {
|
||||||
if hostOrIP == "" {
|
if hostOrIP == "" {
|
||||||
|
|
@ -156,7 +180,7 @@ func (h *Handler) handleCreateMapping(w http.ResponseWriter, r *http.Request) {
|
||||||
// 直接模式:直接TCP转发
|
// 直接模式:直接TCP转发
|
||||||
err = h.forwarderMgr.Add(req.SourcePort, req.TargetHost, req.TargetPort)
|
err = h.forwarderMgr.Add(req.SourcePort, req.TargetHost, req.TargetPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// 回滚数据库操作
|
// 回滚数据库操作
|
||||||
h.db.RemoveMapping(req.SourcePort)
|
h.db.RemoveMapping(req.SourcePort)
|
||||||
|
|
@ -305,11 +329,11 @@ func (h *Handler) handleGetTrafficStats(w http.ResponseWriter, r *http.Request)
|
||||||
|
|
||||||
// 获取所有端口映射的流量统计
|
// 获取所有端口映射的流量统计
|
||||||
forwarderStats := h.forwarderMgr.GetAllTrafficStats()
|
forwarderStats := h.forwarderMgr.GetAllTrafficStats()
|
||||||
|
|
||||||
// 构建响应
|
// 构建响应
|
||||||
mappings := make([]stats.PortTrafficStats, 0, len(forwarderStats))
|
mappings := make([]stats.PortTrafficStats, 0, len(forwarderStats))
|
||||||
var totalSent, totalReceived uint64
|
var totalSent, totalReceived uint64
|
||||||
|
|
||||||
for port, stat := range forwarderStats {
|
for port, stat := range forwarderStats {
|
||||||
mappings = append(mappings, stats.PortTrafficStats{
|
mappings = append(mappings, stats.PortTrafficStats{
|
||||||
Port: port,
|
Port: port,
|
||||||
|
|
@ -320,11 +344,11 @@ func (h *Handler) handleGetTrafficStats(w http.ResponseWriter, r *http.Request)
|
||||||
totalSent += stat.BytesSent
|
totalSent += stat.BytesSent
|
||||||
totalReceived += stat.BytesReceived
|
totalReceived += stat.BytesReceived
|
||||||
}
|
}
|
||||||
|
|
||||||
// 加上隧道的流量
|
// 加上隧道的流量
|
||||||
totalSent += tunnelStats.BytesSent
|
totalSent += tunnelStats.BytesSent
|
||||||
totalReceived += tunnelStats.BytesReceived
|
totalReceived += tunnelStats.BytesReceived
|
||||||
|
|
||||||
response := stats.AllTrafficStats{
|
response := stats.AllTrafficStats{
|
||||||
Tunnel: tunnelStats,
|
Tunnel: tunnelStats,
|
||||||
Mappings: mappings,
|
Mappings: mappings,
|
||||||
|
|
@ -346,4 +370,4 @@ func (h *Handler) handleTrafficMonitor(w http.ResponseWriter, r *http.Request) {
|
||||||
func (h *Handler) handleManagement(w http.ResponseWriter, r *http.Request) {
|
func (h *Handler) handleManagement(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||||
fmt.Fprint(w, managementHTML)
|
fmt.Fprint(w, managementHTML)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -14,20 +14,22 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const testAPIKey = "test-api-key-12345"
|
||||||
|
|
||||||
// setupTestHandler 创建测试用的 Handler
|
// setupTestHandler 创建测试用的 Handler
|
||||||
func setupTestHandler(t *testing.T, useTunnel bool) (*Handler, *db.Database, func()) {
|
func setupTestHandler(t *testing.T, useTunnel bool) (*Handler, *db.Database, func()) {
|
||||||
// 创建临时数据库
|
// 创建临时数据库
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
dbPath := filepath.Join(tmpDir, "test.db")
|
dbPath := filepath.Join(tmpDir, "test.db")
|
||||||
|
|
||||||
database, err := db.New(dbPath)
|
database, err := db.New(dbPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("创建数据库失败: %v", err)
|
t.Fatalf("创建数据库失败: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 创建转发器管理器
|
// 创建转发器管理器
|
||||||
fwdMgr := forwarder.NewManager()
|
fwdMgr := forwarder.NewManager()
|
||||||
|
|
||||||
// 创建隧道服务器(如果启用)
|
// 创建隧道服务器(如果启用)
|
||||||
var tunnelServer *tunnel.Server
|
var tunnelServer *tunnel.Server
|
||||||
if useTunnel {
|
if useTunnel {
|
||||||
|
|
@ -35,13 +37,13 @@ func setupTestHandler(t *testing.T, useTunnel bool) (*Handler, *db.Database, fun
|
||||||
listener, _ := net.Listen("tcp", "127.0.0.1:0")
|
listener, _ := net.Listen("tcp", "127.0.0.1:0")
|
||||||
port := listener.Addr().(*net.TCPAddr).Port
|
port := listener.Addr().(*net.TCPAddr).Port
|
||||||
listener.Close()
|
listener.Close()
|
||||||
|
|
||||||
tunnelServer = tunnel.NewServer(port)
|
tunnelServer = tunnel.NewServer(port)
|
||||||
tunnelServer.Start()
|
tunnelServer.Start()
|
||||||
}
|
}
|
||||||
|
|
||||||
handler := NewHandler(database, fwdMgr, tunnelServer)
|
handler := NewHandler(database, fwdMgr, tunnelServer, testAPIKey)
|
||||||
|
|
||||||
cleanup := func() {
|
cleanup := func() {
|
||||||
fwdMgr.StopAll()
|
fwdMgr.StopAll()
|
||||||
if tunnelServer != nil {
|
if tunnelServer != nil {
|
||||||
|
|
@ -50,23 +52,28 @@ func setupTestHandler(t *testing.T, useTunnel bool) (*Handler, *db.Database, fun
|
||||||
database.Close()
|
database.Close()
|
||||||
os.RemoveAll(tmpDir)
|
os.RemoveAll(tmpDir)
|
||||||
}
|
}
|
||||||
|
|
||||||
return handler, database, cleanup
|
return handler, database, cleanup
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// addAuthHeader 添加认证头到请求
|
||||||
|
func addAuthHeader(req *http.Request) {
|
||||||
|
req.Header.Set("X-API-Key", testAPIKey)
|
||||||
|
}
|
||||||
|
|
||||||
// TestNewHandler 测试创建处理器
|
// TestNewHandler 测试创建处理器
|
||||||
func TestNewHandler(t *testing.T) {
|
func TestNewHandler(t *testing.T) {
|
||||||
handler, _, cleanup := setupTestHandler(t, false)
|
handler, _, cleanup := setupTestHandler(t, false)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
if handler == nil {
|
if handler == nil {
|
||||||
t.Fatal("创建处理器失败")
|
t.Fatal("创建处理器失败")
|
||||||
}
|
}
|
||||||
|
|
||||||
// if handler.portRangeFrom != 10000 {
|
// if handler.portRangeFrom != 10000 {
|
||||||
// t.Errorf("起始端口不正确,期望 10000,得到 %d", handler.portRangeFrom)
|
// t.Errorf("起始端口不正确,期望 10000,得到 %d", handler.portRangeFrom)
|
||||||
// }
|
// }
|
||||||
|
|
||||||
// if handler.portRangeEnd != 20000 {
|
// if handler.portRangeEnd != 20000 {
|
||||||
// t.Errorf("结束端口不正确,期望 20000,得到 %d", handler.portRangeEnd)
|
// t.Errorf("结束端口不正确,期望 20000,得到 %d", handler.portRangeEnd)
|
||||||
// }
|
// }
|
||||||
|
|
@ -76,22 +83,22 @@ func TestNewHandler(t *testing.T) {
|
||||||
func TestHandleHealth(t *testing.T) {
|
func TestHandleHealth(t *testing.T) {
|
||||||
handler, _, cleanup := setupTestHandler(t, false)
|
handler, _, cleanup := setupTestHandler(t, false)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/health", nil)
|
req := httptest.NewRequest(http.MethodGet, "/health", nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
handler.handleHealth(w, req)
|
handler.handleHealth(w, req)
|
||||||
|
|
||||||
if w.Code != http.StatusOK {
|
if w.Code != http.StatusOK {
|
||||||
t.Errorf("状态码不正确,期望 200,得到 %d", w.Code)
|
t.Errorf("状态码不正确,期望 200,得到 %d", w.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
var result map[string]interface{}
|
var result map[string]interface{}
|
||||||
err := json.NewDecoder(w.Body).Decode(&result)
|
err := json.NewDecoder(w.Body).Decode(&result)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("解析响应失败: %v", err)
|
t.Fatalf("解析响应失败: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if result["status"] != "ok" {
|
if result["status"] != "ok" {
|
||||||
t.Errorf("健康状态不正确,期望 ok,得到 %v", result["status"])
|
t.Errorf("健康状态不正确,期望 ok,得到 %v", result["status"])
|
||||||
}
|
}
|
||||||
|
|
@ -101,23 +108,23 @@ func TestHandleHealth(t *testing.T) {
|
||||||
func TestHandleHealthWithTunnel(t *testing.T) {
|
func TestHandleHealthWithTunnel(t *testing.T) {
|
||||||
handler, _, cleanup := setupTestHandler(t, true)
|
handler, _, cleanup := setupTestHandler(t, true)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/health", nil)
|
req := httptest.NewRequest(http.MethodGet, "/health", nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
handler.handleHealth(w, req)
|
handler.handleHealth(w, req)
|
||||||
|
|
||||||
if w.Code != http.StatusOK {
|
if w.Code != http.StatusOK {
|
||||||
t.Errorf("状态码不正确,期望 200,得到 %d", w.Code)
|
t.Errorf("状态码不正确,期望 200,得到 %d", w.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
var result map[string]interface{}
|
var result map[string]interface{}
|
||||||
json.NewDecoder(w.Body).Decode(&result)
|
json.NewDecoder(w.Body).Decode(&result)
|
||||||
|
|
||||||
if result["tunnel_enabled"] != true {
|
if result["tunnel_enabled"] != true {
|
||||||
t.Error("隧道应该启用")
|
t.Error("隧道应该启用")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 隧道未连接客户端时应该为 false
|
// 隧道未连接客户端时应该为 false
|
||||||
if result["tunnel_connected"] != false {
|
if result["tunnel_connected"] != false {
|
||||||
t.Error("隧道应该未连接")
|
t.Error("隧道应该未连接")
|
||||||
|
|
@ -128,41 +135,42 @@ func TestHandleHealthWithTunnel(t *testing.T) {
|
||||||
func TestHandleCreateMapping(t *testing.T) {
|
func TestHandleCreateMapping(t *testing.T) {
|
||||||
handler, database, cleanup := setupTestHandler(t, false)
|
handler, database, cleanup := setupTestHandler(t, false)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
reqBody := CreateMappingRequest{
|
reqBody := CreateMappingRequest{
|
||||||
// Port: 15000,
|
// Port: 15000,
|
||||||
SourcePort: 15000,
|
SourcePort: 15000,
|
||||||
TargetPort: 15000,
|
TargetPort: 15000,
|
||||||
TargetHost: "192.168.1.100",
|
TargetHost: "192.168.1.100",
|
||||||
}
|
}
|
||||||
|
|
||||||
body, _ := json.Marshal(reqBody)
|
body, _ := json.Marshal(reqBody)
|
||||||
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
|
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
|
||||||
|
addAuthHeader(req)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
handler.handleCreateMapping(w, req)
|
handler.handleCreateMapping(w, req)
|
||||||
|
|
||||||
if w.Code != http.StatusOK {
|
if w.Code != http.StatusOK {
|
||||||
t.Errorf("状态码不正确,期望 200,得到 %d", w.Code)
|
t.Errorf("状态码不正确,期望 200,得到 %d", w.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
var result Response
|
var result Response
|
||||||
json.NewDecoder(w.Body).Decode(&result)
|
json.NewDecoder(w.Body).Decode(&result)
|
||||||
|
|
||||||
if !result.Success {
|
if !result.Success {
|
||||||
t.Errorf("创建映射失败: %s", result.Message)
|
t.Errorf("创建映射失败: %s", result.Message)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 验证数据库中存在映射
|
// 验证数据库中存在映射
|
||||||
mapping, err := database.GetMapping(15000)
|
mapping, err := database.GetMapping(15000)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("获取映射失败: %v", err)
|
t.Fatalf("获取映射失败: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if mapping == nil {
|
if mapping == nil {
|
||||||
t.Fatal("映射不存在")
|
t.Fatal("映射不存在")
|
||||||
}
|
}
|
||||||
|
|
||||||
if mapping.TargetHost != "192.168.1.100" {
|
if mapping.TargetHost != "192.168.1.100" {
|
||||||
t.Errorf("目标 IP 不正确,期望 192.168.1.100,得到 %s", mapping.TargetHost)
|
t.Errorf("目标 IP 不正确,期望 192.168.1.100,得到 %s", mapping.TargetHost)
|
||||||
}
|
}
|
||||||
|
|
@ -172,7 +180,7 @@ func TestHandleCreateMapping(t *testing.T) {
|
||||||
func TestHandleCreateMappingInvalidPort(t *testing.T) {
|
func TestHandleCreateMappingInvalidPort(t *testing.T) {
|
||||||
handler, _, cleanup := setupTestHandler(t, false)
|
handler, _, cleanup := setupTestHandler(t, false)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
port int
|
port int
|
||||||
|
|
@ -181,7 +189,7 @@ func TestHandleCreateMappingInvalidPort(t *testing.T) {
|
||||||
{"端口太大", 25000},
|
{"端口太大", 25000},
|
||||||
{"端口为0", 0},
|
{"端口为0", 0},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
reqBody := CreateMappingRequest{
|
reqBody := CreateMappingRequest{
|
||||||
|
|
@ -189,13 +197,14 @@ func TestHandleCreateMappingInvalidPort(t *testing.T) {
|
||||||
TargetPort: tt.port,
|
TargetPort: tt.port,
|
||||||
TargetHost: "192.168.1.100",
|
TargetHost: "192.168.1.100",
|
||||||
}
|
}
|
||||||
|
|
||||||
body, _ := json.Marshal(reqBody)
|
body, _ := json.Marshal(reqBody)
|
||||||
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
|
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
|
||||||
|
addAuthHeader(req)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
handler.handleCreateMapping(w, req)
|
handler.handleCreateMapping(w, req)
|
||||||
|
|
||||||
if w.Code != http.StatusBadRequest {
|
if w.Code != http.StatusBadRequest {
|
||||||
t.Errorf("状态码不正确,期望 400,得到 %d", w.Code)
|
t.Errorf("状态码不正确,期望 400,得到 %d", w.Code)
|
||||||
}
|
}
|
||||||
|
|
@ -207,30 +216,32 @@ func TestHandleCreateMappingInvalidPort(t *testing.T) {
|
||||||
func TestHandleCreateMappingDuplicate(t *testing.T) {
|
func TestHandleCreateMappingDuplicate(t *testing.T) {
|
||||||
handler, _, cleanup := setupTestHandler(t, false)
|
handler, _, cleanup := setupTestHandler(t, false)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
reqBody := CreateMappingRequest{
|
reqBody := CreateMappingRequest{
|
||||||
// Port: 15000,
|
// Port: 15000,
|
||||||
SourcePort: 15000,
|
SourcePort: 15000,
|
||||||
TargetPort: 15000,
|
TargetPort: 15000,
|
||||||
TargetHost: "192.168.1.100",
|
TargetHost: "192.168.1.100",
|
||||||
}
|
}
|
||||||
|
|
||||||
// 第一次创建
|
// 第一次创建
|
||||||
body, _ := json.Marshal(reqBody)
|
body, _ := json.Marshal(reqBody)
|
||||||
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
|
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
|
||||||
|
addAuthHeader(req)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
handler.handleCreateMapping(w, req)
|
handler.handleCreateMapping(w, req)
|
||||||
|
|
||||||
if w.Code != http.StatusOK {
|
if w.Code != http.StatusOK {
|
||||||
t.Fatalf("第一次创建失败")
|
t.Fatalf("第一次创建失败")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 第二次创建(应该失败)
|
// 第二次创建(应该失败)
|
||||||
body, _ = json.Marshal(reqBody)
|
body, _ = json.Marshal(reqBody)
|
||||||
req = httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
|
req = httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
|
||||||
|
addAuthHeader(req)
|
||||||
w = httptest.NewRecorder()
|
w = httptest.NewRecorder()
|
||||||
handler.handleCreateMapping(w, req)
|
handler.handleCreateMapping(w, req)
|
||||||
|
|
||||||
if w.Code != http.StatusConflict {
|
if w.Code != http.StatusConflict {
|
||||||
t.Errorf("状态码不正确,期望 409,得到 %d", w.Code)
|
t.Errorf("状态码不正确,期望 409,得到 %d", w.Code)
|
||||||
}
|
}
|
||||||
|
|
@ -240,12 +251,13 @@ func TestHandleCreateMappingDuplicate(t *testing.T) {
|
||||||
func TestHandleCreateMappingInvalidJSON(t *testing.T) {
|
func TestHandleCreateMappingInvalidJSON(t *testing.T) {
|
||||||
handler, _, cleanup := setupTestHandler(t, false)
|
handler, _, cleanup := setupTestHandler(t, false)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader([]byte("invalid json")))
|
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader([]byte("invalid json")))
|
||||||
|
addAuthHeader(req)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
handler.handleCreateMapping(w, req)
|
handler.handleCreateMapping(w, req)
|
||||||
|
|
||||||
if w.Code != http.StatusBadRequest {
|
if w.Code != http.StatusBadRequest {
|
||||||
t.Errorf("状态码不正确,期望 400,得到 %d", w.Code)
|
t.Errorf("状态码不正确,期望 400,得到 %d", w.Code)
|
||||||
}
|
}
|
||||||
|
|
@ -255,20 +267,21 @@ func TestHandleCreateMappingInvalidJSON(t *testing.T) {
|
||||||
func TestHandleCreateMappingInvalidIP(t *testing.T) {
|
func TestHandleCreateMappingInvalidIP(t *testing.T) {
|
||||||
handler, _, cleanup := setupTestHandler(t, false)
|
handler, _, cleanup := setupTestHandler(t, false)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
reqBody := CreateMappingRequest{
|
reqBody := CreateMappingRequest{
|
||||||
// Port: 15000,
|
// Port: 15000,
|
||||||
SourcePort: 15000,
|
SourcePort: 15000,
|
||||||
TargetPort: 15000,
|
TargetPort: 15000,
|
||||||
TargetHost: "", // 使用空字符串而不是无效域名,避免DNS查询超时
|
TargetHost: "", // 使用空字符串而不是无效域名,避免DNS查询超时
|
||||||
}
|
}
|
||||||
|
|
||||||
body, _ := json.Marshal(reqBody)
|
body, _ := json.Marshal(reqBody)
|
||||||
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
|
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
|
||||||
|
addAuthHeader(req)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
handler.handleCreateMapping(w, req)
|
handler.handleCreateMapping(w, req)
|
||||||
|
|
||||||
if w.Code != http.StatusBadRequest {
|
if w.Code != http.StatusBadRequest {
|
||||||
t.Errorf("状态码不正确,期望 400,得到 %d", w.Code)
|
t.Errorf("状态码不正确,期望 400,得到 %d", w.Code)
|
||||||
}
|
}
|
||||||
|
|
@ -278,20 +291,21 @@ func TestHandleCreateMappingInvalidIP(t *testing.T) {
|
||||||
func TestHandleCreateMappingEmptyIP(t *testing.T) {
|
func TestHandleCreateMappingEmptyIP(t *testing.T) {
|
||||||
handler, _, cleanup := setupTestHandler(t, false)
|
handler, _, cleanup := setupTestHandler(t, false)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
reqBody := CreateMappingRequest{
|
reqBody := CreateMappingRequest{
|
||||||
// Port: 15000,
|
// Port: 15000,
|
||||||
SourcePort: 15000,
|
SourcePort: 15000,
|
||||||
TargetPort: 15000,
|
TargetPort: 15000,
|
||||||
TargetHost: "",
|
TargetHost: "",
|
||||||
}
|
}
|
||||||
|
|
||||||
body, _ := json.Marshal(reqBody)
|
body, _ := json.Marshal(reqBody)
|
||||||
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
|
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
|
||||||
|
addAuthHeader(req)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
handler.handleCreateMapping(w, req)
|
handler.handleCreateMapping(w, req)
|
||||||
|
|
||||||
if w.Code != http.StatusBadRequest {
|
if w.Code != http.StatusBadRequest {
|
||||||
t.Errorf("状态码不正确,期望 400,得到 %d", w.Code)
|
t.Errorf("状态码不正确,期望 400,得到 %d", w.Code)
|
||||||
}
|
}
|
||||||
|
|
@ -301,20 +315,21 @@ func TestHandleCreateMappingEmptyIP(t *testing.T) {
|
||||||
func TestHandleCreateMappingTunnelNotConnected(t *testing.T) {
|
func TestHandleCreateMappingTunnelNotConnected(t *testing.T) {
|
||||||
handler, _, cleanup := setupTestHandler(t, true)
|
handler, _, cleanup := setupTestHandler(t, true)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
reqBody := CreateMappingRequest{
|
reqBody := CreateMappingRequest{
|
||||||
// Port: 15000,
|
// Port: 15000,
|
||||||
SourcePort: 15000,
|
SourcePort: 15000,
|
||||||
TargetPort: 15000,
|
TargetPort: 15000,
|
||||||
UseTunnel: true, // 明确指定使用隧道模式
|
UseTunnel: true, // 明确指定使用隧道模式
|
||||||
}
|
}
|
||||||
|
|
||||||
body, _ := json.Marshal(reqBody)
|
body, _ := json.Marshal(reqBody)
|
||||||
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
|
req := httptest.NewRequest(http.MethodPost, "/api/mapping/create", bytes.NewReader(body))
|
||||||
|
addAuthHeader(req)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
handler.handleCreateMapping(w, req)
|
handler.handleCreateMapping(w, req)
|
||||||
|
|
||||||
if w.Code != http.StatusServiceUnavailable {
|
if w.Code != http.StatusServiceUnavailable {
|
||||||
t.Errorf("状态码不正确,期望 503,得到 %d", w.Code)
|
t.Errorf("状态码不正确,期望 503,得到 %d", w.Code)
|
||||||
}
|
}
|
||||||
|
|
@ -324,25 +339,25 @@ func TestHandleCreateMappingTunnelNotConnected(t *testing.T) {
|
||||||
func TestHandleRemoveMapping(t *testing.T) {
|
func TestHandleRemoveMapping(t *testing.T) {
|
||||||
handler, database, cleanup := setupTestHandler(t, false)
|
handler, database, cleanup := setupTestHandler(t, false)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
// 先创建一个映射
|
// 先创建一个映射
|
||||||
database.AddMapping(15000, "192.168.1.100", 15000, false)
|
database.AddMapping(15000, "192.168.1.100", 15000, false)
|
||||||
handler.forwarderMgr.Add(15000, "192.168.1.100", 15000)
|
handler.forwarderMgr.Add(15000, "192.168.1.100", 15000)
|
||||||
|
|
||||||
reqBody := RemoveMappingRequest{
|
reqBody := RemoveMappingRequest{
|
||||||
Port: 15000,
|
Port: 15000,
|
||||||
}
|
}
|
||||||
|
|
||||||
body, _ := json.Marshal(reqBody)
|
body, _ := json.Marshal(reqBody)
|
||||||
req := httptest.NewRequest(http.MethodPost, "/api/mapping/remove", bytes.NewReader(body))
|
req := httptest.NewRequest(http.MethodPost, "/api/mapping/remove", bytes.NewReader(body))
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
handler.handleRemoveMapping(w, req)
|
handler.handleRemoveMapping(w, req)
|
||||||
|
|
||||||
if w.Code != http.StatusOK {
|
if w.Code != http.StatusOK {
|
||||||
t.Errorf("状态码不正确,期望 200,得到 %d", w.Code)
|
t.Errorf("状态码不正确,期望 200,得到 %d", w.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 验证映射已删除
|
// 验证映射已删除
|
||||||
mapping, _ := database.GetMapping(15000)
|
mapping, _ := database.GetMapping(15000)
|
||||||
if mapping != nil {
|
if mapping != nil {
|
||||||
|
|
@ -354,17 +369,17 @@ func TestHandleRemoveMapping(t *testing.T) {
|
||||||
func TestHandleRemoveMappingNotExist(t *testing.T) {
|
func TestHandleRemoveMappingNotExist(t *testing.T) {
|
||||||
handler, _, cleanup := setupTestHandler(t, false)
|
handler, _, cleanup := setupTestHandler(t, false)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
reqBody := RemoveMappingRequest{
|
reqBody := RemoveMappingRequest{
|
||||||
Port: 15000,
|
Port: 15000,
|
||||||
}
|
}
|
||||||
|
|
||||||
body, _ := json.Marshal(reqBody)
|
body, _ := json.Marshal(reqBody)
|
||||||
req := httptest.NewRequest(http.MethodPost, "/api/mapping/remove", bytes.NewReader(body))
|
req := httptest.NewRequest(http.MethodPost, "/api/mapping/remove", bytes.NewReader(body))
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
handler.handleRemoveMapping(w, req)
|
handler.handleRemoveMapping(w, req)
|
||||||
|
|
||||||
if w.Code != http.StatusNotFound {
|
if w.Code != http.StatusNotFound {
|
||||||
t.Errorf("状态码不正确,期望 404,得到 %d", w.Code)
|
t.Errorf("状态码不正确,期望 404,得到 %d", w.Code)
|
||||||
}
|
}
|
||||||
|
|
@ -374,31 +389,32 @@ func TestHandleRemoveMappingNotExist(t *testing.T) {
|
||||||
func TestHandleListMappings(t *testing.T) {
|
func TestHandleListMappings(t *testing.T) {
|
||||||
handler, database, cleanup := setupTestHandler(t, false)
|
handler, database, cleanup := setupTestHandler(t, false)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
// 添加一些映射
|
// 添加一些映射
|
||||||
database.AddMapping(15000, "192.168.1.100", 15000, false)
|
database.AddMapping(15000, "192.168.1.100", 15000, false)
|
||||||
database.AddMapping(15001, "192.168.1.101", 15001, true)
|
database.AddMapping(15001, "192.168.1.101", 15001, true)
|
||||||
database.AddMapping(15002, "192.168.1.102", 15002, false)
|
database.AddMapping(15002, "192.168.1.102", 15002, false)
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/api/mapping/list", nil)
|
req := httptest.NewRequest(http.MethodGet, "/api/mapping/list", nil)
|
||||||
|
addAuthHeader(req)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
handler.handleListMappings(w, req)
|
handler.handleListMappings(w, req)
|
||||||
|
|
||||||
if w.Code != http.StatusOK {
|
if w.Code != http.StatusOK {
|
||||||
t.Errorf("状态码不正确,期望 200,得到 %d", w.Code)
|
t.Errorf("状态码不正确,期望 200,得到 %d", w.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
var result Response
|
var result Response
|
||||||
json.NewDecoder(w.Body).Decode(&result)
|
json.NewDecoder(w.Body).Decode(&result)
|
||||||
|
|
||||||
if !result.Success {
|
if !result.Success {
|
||||||
t.Errorf("列出映射失败: %s", result.Message)
|
t.Errorf("列出映射失败: %s", result.Message)
|
||||||
}
|
}
|
||||||
|
|
||||||
data := result.Data.(map[string]interface{})
|
data := result.Data.(map[string]interface{})
|
||||||
count := int(data["count"].(float64))
|
count := int(data["count"].(float64))
|
||||||
|
|
||||||
if count != 3 {
|
if count != 3 {
|
||||||
t.Errorf("映射数量不正确,期望 3,得到 %d", count)
|
t.Errorf("映射数量不正确,期望 3,得到 %d", count)
|
||||||
}
|
}
|
||||||
|
|
@ -408,22 +424,23 @@ func TestHandleListMappings(t *testing.T) {
|
||||||
func TestHandleListMappingsEmpty(t *testing.T) {
|
func TestHandleListMappingsEmpty(t *testing.T) {
|
||||||
handler, _, cleanup := setupTestHandler(t, false)
|
handler, _, cleanup := setupTestHandler(t, false)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/api/mapping/list", nil)
|
req := httptest.NewRequest(http.MethodGet, "/api/mapping/list", nil)
|
||||||
|
addAuthHeader(req)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
handler.handleListMappings(w, req)
|
handler.handleListMappings(w, req)
|
||||||
|
|
||||||
if w.Code != http.StatusOK {
|
if w.Code != http.StatusOK {
|
||||||
t.Errorf("状态码不正确,期望 200,得到 %d", w.Code)
|
t.Errorf("状态码不正确,期望 200,得到 %d", w.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
var result Response
|
var result Response
|
||||||
json.NewDecoder(w.Body).Decode(&result)
|
json.NewDecoder(w.Body).Decode(&result)
|
||||||
|
|
||||||
data := result.Data.(map[string]interface{})
|
data := result.Data.(map[string]interface{})
|
||||||
count := int(data["count"].(float64))
|
count := int(data["count"].(float64))
|
||||||
|
|
||||||
if count != 0 {
|
if count != 0 {
|
||||||
t.Errorf("映射数量不正确,期望 0,得到 %d", count)
|
t.Errorf("映射数量不正确,期望 0,得到 %d", count)
|
||||||
}
|
}
|
||||||
|
|
@ -433,7 +450,7 @@ func TestHandleListMappingsEmpty(t *testing.T) {
|
||||||
func TestHandleMethodNotAllowed(t *testing.T) {
|
func TestHandleMethodNotAllowed(t *testing.T) {
|
||||||
handler, _, cleanup := setupTestHandler(t, false)
|
handler, _, cleanup := setupTestHandler(t, false)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
handler func(http.ResponseWriter, *http.Request)
|
handler func(http.ResponseWriter, *http.Request)
|
||||||
|
|
@ -443,14 +460,15 @@ func TestHandleMethodNotAllowed(t *testing.T) {
|
||||||
{"删除映射 GET", handler.handleRemoveMapping, http.MethodGet},
|
{"删除映射 GET", handler.handleRemoveMapping, http.MethodGet},
|
||||||
{"列出映射 POST", handler.handleListMappings, http.MethodPost},
|
{"列出映射 POST", handler.handleListMappings, http.MethodPost},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
req := httptest.NewRequest(tt.method, "/test", nil)
|
req := httptest.NewRequest(tt.method, "/test", nil)
|
||||||
|
addAuthHeader(req)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
tt.handler(w, req)
|
tt.handler(w, req)
|
||||||
|
|
||||||
if w.Code != http.StatusMethodNotAllowed {
|
if w.Code != http.StatusMethodNotAllowed {
|
||||||
t.Errorf("状态码不正确,期望 405,得到 %d", w.Code)
|
t.Errorf("状态码不正确,期望 405,得到 %d", w.Code)
|
||||||
}
|
}
|
||||||
|
|
@ -462,10 +480,10 @@ func TestHandleMethodNotAllowed(t *testing.T) {
|
||||||
func TestRegisterRoutes(t *testing.T) {
|
func TestRegisterRoutes(t *testing.T) {
|
||||||
handler, _, cleanup := setupTestHandler(t, false)
|
handler, _, cleanup := setupTestHandler(t, false)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
handler.RegisterRoutes(mux)
|
handler.RegisterRoutes(mux)
|
||||||
|
|
||||||
// 测试路由是否注册
|
// 测试路由是否注册
|
||||||
routes := []string{
|
routes := []string{
|
||||||
"/api/mapping/create",
|
"/api/mapping/create",
|
||||||
|
|
@ -473,13 +491,13 @@ func TestRegisterRoutes(t *testing.T) {
|
||||||
"/api/mapping/list",
|
"/api/mapping/list",
|
||||||
"/health",
|
"/health",
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, route := range routes {
|
for _, route := range routes {
|
||||||
req := httptest.NewRequest(http.MethodGet, route, nil)
|
req := httptest.NewRequest(http.MethodGet, route, nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
mux.ServeHTTP(w, req)
|
mux.ServeHTTP(w, req)
|
||||||
|
|
||||||
// 如果路由不存在,应该返回 404
|
// 如果路由不存在,应该返回 404
|
||||||
if w.Code == http.StatusNotFound {
|
if w.Code == http.StatusNotFound {
|
||||||
t.Errorf("路由 %s 未注册", route)
|
t.Errorf("路由 %s 未注册", route)
|
||||||
|
|
@ -491,21 +509,21 @@ func TestRegisterRoutes(t *testing.T) {
|
||||||
func TestWriteSuccess(t *testing.T) {
|
func TestWriteSuccess(t *testing.T) {
|
||||||
handler, _, cleanup := setupTestHandler(t, false)
|
handler, _, cleanup := setupTestHandler(t, false)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
handler.writeSuccess(w, "测试成功", map[string]string{"key": "value"})
|
handler.writeSuccess(w, "测试成功", map[string]string{"key": "value"})
|
||||||
|
|
||||||
if w.Code != http.StatusOK {
|
if w.Code != http.StatusOK {
|
||||||
t.Errorf("状态码不正确,期望 200,得到 %d", w.Code)
|
t.Errorf("状态码不正确,期望 200,得到 %d", w.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
var result Response
|
var result Response
|
||||||
json.NewDecoder(w.Body).Decode(&result)
|
json.NewDecoder(w.Body).Decode(&result)
|
||||||
|
|
||||||
if !result.Success {
|
if !result.Success {
|
||||||
t.Error("Success 应该为 true")
|
t.Error("Success 应该为 true")
|
||||||
}
|
}
|
||||||
|
|
||||||
if result.Message != "测试成功" {
|
if result.Message != "测试成功" {
|
||||||
t.Errorf("消息不正确,期望 '测试成功',得到 '%s'", result.Message)
|
t.Errorf("消息不正确,期望 '测试成功',得到 '%s'", result.Message)
|
||||||
}
|
}
|
||||||
|
|
@ -515,21 +533,21 @@ func TestWriteSuccess(t *testing.T) {
|
||||||
func TestWriteError(t *testing.T) {
|
func TestWriteError(t *testing.T) {
|
||||||
handler, _, cleanup := setupTestHandler(t, false)
|
handler, _, cleanup := setupTestHandler(t, false)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
handler.writeError(w, http.StatusBadRequest, "测试错误")
|
handler.writeError(w, http.StatusBadRequest, "测试错误")
|
||||||
|
|
||||||
if w.Code != http.StatusBadRequest {
|
if w.Code != http.StatusBadRequest {
|
||||||
t.Errorf("状态码不正确,期望 400,得到 %d", w.Code)
|
t.Errorf("状态码不正确,期望 400,得到 %d", w.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
var result Response
|
var result Response
|
||||||
json.NewDecoder(w.Body).Decode(&result)
|
json.NewDecoder(w.Body).Decode(&result)
|
||||||
|
|
||||||
if result.Success {
|
if result.Success {
|
||||||
t.Error("Success 应该为 false")
|
t.Error("Success 应该为 false")
|
||||||
}
|
}
|
||||||
|
|
||||||
if result.Message != "测试错误" {
|
if result.Message != "测试错误" {
|
||||||
t.Errorf("消息不正确,期望 '测试错误',得到 '%s'", result.Message)
|
t.Errorf("消息不正确,期望 '测试错误',得到 '%s'", result.Message)
|
||||||
}
|
}
|
||||||
|
|
@ -541,12 +559,12 @@ func BenchmarkHandleHealth(b *testing.B) {
|
||||||
dbPath := filepath.Join(tmpDir, "bench.db")
|
dbPath := filepath.Join(tmpDir, "bench.db")
|
||||||
database, _ := db.New(dbPath)
|
database, _ := db.New(dbPath)
|
||||||
defer database.Close()
|
defer database.Close()
|
||||||
|
|
||||||
fwdMgr := forwarder.NewManager()
|
fwdMgr := forwarder.NewManager()
|
||||||
handler := NewHandler(database, fwdMgr, nil)
|
handler := NewHandler(database, fwdMgr, nil, "test-api-key")
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/health", nil)
|
req := httptest.NewRequest(http.MethodGet, "/health", nil)
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
@ -560,21 +578,92 @@ func BenchmarkHandleListMappings(b *testing.B) {
|
||||||
dbPath := filepath.Join(tmpDir, "bench.db")
|
dbPath := filepath.Join(tmpDir, "bench.db")
|
||||||
database, _ := db.New(dbPath)
|
database, _ := db.New(dbPath)
|
||||||
defer database.Close()
|
defer database.Close()
|
||||||
|
|
||||||
// 添加一些映射
|
// 添加一些映射
|
||||||
for i := 0; i < 100; i++ {
|
for i := 0; i < 100; i++ {
|
||||||
useTunnel := i%2 == 0 // 偶数使用隧道模式
|
useTunnel := i%2 == 0 // 偶数使用隧道模式
|
||||||
database.AddMapping(10000+i, "192.168.1.1", 10000+i, useTunnel)
|
database.AddMapping(10000+i, "192.168.1.1", 10000+i, useTunnel)
|
||||||
}
|
}
|
||||||
|
|
||||||
fwdMgr := forwarder.NewManager()
|
fwdMgr := forwarder.NewManager()
|
||||||
handler := NewHandler(database, fwdMgr, nil)
|
handler := NewHandler(database, fwdMgr, nil, "test-api-key")
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/api/mapping/list", nil)
|
req := httptest.NewRequest(http.MethodGet, "/api/mapping/list", nil)
|
||||||
|
req.Header.Set("X-API-Key", "test-api-key")
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
handler.handleListMappings(w, req)
|
handler.handleListMappings(w, req)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestAuthMiddleware 测试认证中间件
|
||||||
|
func TestAuthMiddleware(t *testing.T) {
|
||||||
|
handler, _, cleanup := setupTestHandler(t, false)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
apiKey string
|
||||||
|
useHeader bool
|
||||||
|
useQueryParam bool
|
||||||
|
expectedStatus int
|
||||||
|
expectedMsg string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "有效的API密钥(请求头)",
|
||||||
|
apiKey: testAPIKey,
|
||||||
|
useHeader: true,
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "有效的API密钥(查询参数)",
|
||||||
|
apiKey: testAPIKey,
|
||||||
|
useQueryParam: true,
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "无效的API密钥",
|
||||||
|
apiKey: "invalid-key",
|
||||||
|
useHeader: true,
|
||||||
|
expectedStatus: http.StatusUnauthorized,
|
||||||
|
expectedMsg: "无效的 API 密钥",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "缺少API密钥",
|
||||||
|
apiKey: "",
|
||||||
|
expectedStatus: http.StatusUnauthorized,
|
||||||
|
expectedMsg: "无效的 API 密钥",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
url := "/api/mapping/list"
|
||||||
|
if tt.useQueryParam {
|
||||||
|
url += "?api_key=" + tt.apiKey
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, url, nil)
|
||||||
|
if tt.useHeader && tt.apiKey != "" {
|
||||||
|
req.Header.Set("X-API-Key", tt.apiKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
handler.handleListMappings(w, req)
|
||||||
|
|
||||||
|
if w.Code != tt.expectedStatus {
|
||||||
|
t.Errorf("状态码不正确,期望 %d,得到 %d", tt.expectedStatus, w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.expectedMsg != "" {
|
||||||
|
var result Response
|
||||||
|
json.NewDecoder(w.Body).Decode(&result)
|
||||||
|
if result.Message != tt.expectedMsg {
|
||||||
|
t.Errorf("错误消息不正确,期望 '%s',得到 '%s'", tt.expectedMsg, result.Message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -10,9 +10,9 @@ import (
|
||||||
// Config 应用配置结构
|
// Config 应用配置结构
|
||||||
type Config struct {
|
type Config struct {
|
||||||
// PortRange PortRangeConfig `yaml:"port_range"`
|
// PortRange PortRangeConfig `yaml:"port_range"`
|
||||||
Tunnel TunnelConfig `yaml:"tunnel"`
|
Tunnel TunnelConfig `yaml:"tunnel"`
|
||||||
API APIConfig `yaml:"api"`
|
API APIConfig `yaml:"api"`
|
||||||
Database DatabaseConfig `yaml:"database"`
|
Database DatabaseConfig `yaml:"database"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// PortRangeConfig 端口范围配置
|
// PortRangeConfig 端口范围配置
|
||||||
|
|
@ -29,7 +29,8 @@ type TunnelConfig struct {
|
||||||
|
|
||||||
// APIConfig HTTP API 配置
|
// APIConfig HTTP API 配置
|
||||||
type APIConfig struct {
|
type APIConfig struct {
|
||||||
ListenPort int `yaml:"listen_port"`
|
ListenPort int `yaml:"listen_port"`
|
||||||
|
APIKey string `yaml:"api_key"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// DatabaseConfig 数据库配置
|
// DatabaseConfig 数据库配置
|
||||||
|
|
@ -74,8 +75,11 @@ func (c *Config) Validate() error {
|
||||||
if c.API.ListenPort <= 0 {
|
if c.API.ListenPort <= 0 {
|
||||||
return fmt.Errorf("API 端口必须大于 0")
|
return fmt.Errorf("API 端口必须大于 0")
|
||||||
}
|
}
|
||||||
|
if c.API.APIKey == "" {
|
||||||
|
return fmt.Errorf("API 密钥不能为空")
|
||||||
|
}
|
||||||
if c.Database.Path == "" {
|
if c.Database.Path == "" {
|
||||||
return fmt.Errorf("数据库路径不能为空")
|
return fmt.Errorf("数据库路径不能为空")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -96,7 +96,7 @@ func (s *serverService) Start() error {
|
||||||
|
|
||||||
// 创建 HTTP API 处理器
|
// 创建 HTTP API 处理器
|
||||||
log.Println("初始化 HTTP API...")
|
log.Println("初始化 HTTP API...")
|
||||||
s.apiHandler = api.NewHandler(database, s.fwdManager, s.tunnelServer)
|
s.apiHandler = api.NewHandler(database, s.fwdManager, s.tunnelServer, cfg.API.APIKey)
|
||||||
|
|
||||||
// 启动 HTTP API 服务器
|
// 启动 HTTP API 服务器
|
||||||
go func() {
|
go func() {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue