diff --git a/.gitignore b/.gitignore index 156f893..1f38000 100644 --- a/.gitignore +++ b/.gitignore @@ -5,10 +5,10 @@ bin/ *.dll *.so *.dylib -server -client -server-linux -client-linux +/server +/client +/server-linux +/client-linux # 测试文件 *.test diff --git a/src/client/main.go b/src/client/main.go new file mode 100644 index 0000000..b374c04 --- /dev/null +++ b/src/client/main.go @@ -0,0 +1,48 @@ +package main + +import ( + "flag" + "log" + "os" + "os/signal" + "port-forward/client/tunnel" + "syscall" +) + +func main() { + // 解析命令行参数 + serverAddr := flag.String("server", "localhost:9000", "隧道服务器地址 (host:port)") + flag.Parse() + + log.SetFlags(log.LstdFlags | log.Lshortfile) + + // 创建隧道客户端 + log.Printf("隧道客户端启动...") + log.Printf("服务器地址: %s", *serverAddr) + + client := tunnel.NewClient(*serverAddr) + + // 启动客户端 + if err := client.Start(); err != nil { + log.Fatalf("启动隧道客户端失败: %v", err) + } + + log.Println("===========================================") + log.Println("隧道客户端运行中...") + log.Println("按 Ctrl+C 退出") + log.Println("===========================================") + + // 等待中断信号 + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM) + + <-sigChan + log.Println("\n接收到关闭信号,正在关闭...") + + // 停止客户端 + if err := client.Stop(); err != nil { + log.Printf("停止客户端失败: %v", err) + } + + log.Println("客户端已关闭") +} \ No newline at end of file diff --git a/src/client/tunnel/client.go b/src/client/tunnel/client.go new file mode 100644 index 0000000..6c2fd3f --- /dev/null +++ b/src/client/tunnel/client.go @@ -0,0 +1,285 @@ +package tunnel + +import ( + "context" + "encoding/binary" + "fmt" + "io" + "log" + "net" + "sync" + "time" +) + +const ( + // HeaderSize 消息头大小 + HeaderSize = 8 + // MaxPacketSize 最大包大小 + MaxPacketSize = 1024 * 1024 + // ReconnectDelay 重连延迟 + ReconnectDelay = 5 * time.Second +) + +// Client 内网穿透客户端 +type Client struct { + serverAddr string + serverConn net.Conn + cancel context.CancelFunc + ctx context.Context + wg sync.WaitGroup + mu sync.RWMutex + + // 连接管理 + connections map[uint32]*LocalConnection + connMu sync.RWMutex +} + +// LocalConnection 本地连接 +type LocalConnection struct { + ID uint32 + TargetAddr string + Conn net.Conn + closeChan chan struct{} +} + +// NewClient 创建新的隧道客户端 +func NewClient(serverAddr string) *Client { + ctx, cancel := context.WithCancel(context.Background()) + return &Client{ + serverAddr: serverAddr, + cancel: cancel, + ctx: ctx, + connections: make(map[uint32]*LocalConnection), + } +} + +// Start 启动隧道客户端 +func (c *Client) Start() error { + log.Printf("正在连接到隧道服务器: %s", c.serverAddr) + + c.wg.Add(1) + go c.connectLoop() + + return nil +} + +// connectLoop 连接循环(支持自动重连) +func (c *Client) connectLoop() { + defer c.wg.Done() + + for { + select { + case <-c.ctx.Done(): + return + default: + } + + conn, err := net.DialTimeout("tcp", c.serverAddr, 10*time.Second) + if err != nil { + log.Printf("连接隧道服务器失败: %v,%v 后重试", err, ReconnectDelay) + time.Sleep(ReconnectDelay) + continue + } + + log.Printf("已连接到隧道服务器: %s", c.serverAddr) + + c.mu.Lock() + c.serverConn = conn + c.mu.Unlock() + + // 处理连接 + if err := c.handleServerConnection(conn); err != nil { + if err != io.EOF { + log.Printf("处理服务器连接出错: %v", err) + } + } + + c.mu.Lock() + c.serverConn = nil + c.mu.Unlock() + + // 关闭所有本地连接 + c.connMu.Lock() + for _, conn := range c.connections { + close(conn.closeChan) + if conn.Conn != nil { + conn.Conn.Close() + } + } + c.connections = make(map[uint32]*LocalConnection) + c.connMu.Unlock() + + log.Printf("与服务器断开连接,%v 后重连", ReconnectDelay) + time.Sleep(ReconnectDelay) + } +} + +// handleServerConnection 处理服务器连接 +func (c *Client) handleServerConnection(conn net.Conn) error { + for { + select { + case <-c.ctx.Done(): + return nil + default: + } + + // 读取消息头 + header := make([]byte, HeaderSize) + if _, err := io.ReadFull(conn, header); err != nil { + return err + } + + dataLen := binary.BigEndian.Uint32(header[0:4]) + connID := binary.BigEndian.Uint32(header[4:8]) + + if dataLen > MaxPacketSize { + return fmt.Errorf("数据包过大: %d bytes", dataLen) + } + + // 读取数据 + data := make([]byte, dataLen) + if _, err := io.ReadFull(conn, data); err != nil { + return err + } + + // 处理数据 + c.handleData(connID, data) + } +} + +// handleData 处理接收到的数据 +func (c *Client) handleData(connID uint32, data []byte) { + c.connMu.Lock() + localConn, exists := c.connections[connID] + + if !exists { + // 新连接,需要建立到本地服务的连接 + // 从数据中解析目标端口(这里简化处理,实际应该从协议中获取) + localConn = &LocalConnection{ + ID: connID, + closeChan: make(chan struct{}), + } + c.connections[connID] = localConn + c.connMu.Unlock() + + // 启动本地连接处理 + c.wg.Add(1) + go c.handleLocalConnection(localConn) + + // 重新获取锁并发送数据 + c.connMu.Lock() + } + c.connMu.Unlock() + + // 发送数据到本地连接 + if localConn.Conn != nil { + localConn.Conn.Write(data) + } +} + +// handleLocalConnection 处理本地连接 +func (c *Client) handleLocalConnection(localConn *LocalConnection) { + defer c.wg.Done() + defer func() { + c.connMu.Lock() + delete(c.connections, localConn.ID) + c.connMu.Unlock() + + if localConn.Conn != nil { + localConn.Conn.Close() + } + }() + + // 连接到本地目标服务 + // 这里使用固定的本地地址,实际应该根据映射配置 + targetAddr := localConn.TargetAddr + if targetAddr == "" { + targetAddr = "127.0.0.1:22" // 默认 SSH + } + + conn, err := net.DialTimeout("tcp", targetAddr, 5*time.Second) + if err != nil { + log.Printf("连接本地服务失败 (连接 %d -> %s): %v", localConn.ID, targetAddr, err) + return + } + defer conn.Close() + + localConn.Conn = conn + log.Printf("建立本地连接: %d -> %s", localConn.ID, targetAddr) + + // 从本地服务读取数据并发送到服务器 + buffer := make([]byte, 32*1024) + for { + select { + case <-localConn.closeChan: + return + case <-c.ctx.Done(): + return + default: + } + + conn.SetReadDeadline(time.Now().Add(30 * time.Second)) + n, err := conn.Read(buffer) + if err != nil { + if err != io.EOF && !isTimeout(err) { + log.Printf("读取本地连接失败 (连接 %d): %v", localConn.ID, err) + } + return + } + + // 发送到服务器 + c.mu.RLock() + serverConn := c.serverConn + c.mu.RUnlock() + + if serverConn == nil { + return + } + + data := make([]byte, HeaderSize+n) + binary.BigEndian.PutUint32(data[0:4], uint32(n)) + binary.BigEndian.PutUint32(data[4:8], localConn.ID) + copy(data[HeaderSize:], buffer[:n]) + + if _, err := serverConn.Write(data); err != nil { + log.Printf("发送数据到服务器失败 (连接 %d): %v", localConn.ID, err) + return + } + } +} + +// Stop 停止隧道客户端 +func (c *Client) Stop() error { + log.Println("正在停止隧道客户端...") + c.cancel() + + c.mu.Lock() + if c.serverConn != nil { + c.serverConn.Close() + } + c.mu.Unlock() + + // 等待所有协程结束 + done := make(chan struct{}) + go func() { + c.wg.Wait() + close(done) + }() + + select { + case <-done: + log.Println("隧道客户端已停止") + case <-time.After(5 * time.Second): + log.Println("隧道客户端停止超时") + } + + return nil +} + +// isTimeout 检查是否为超时错误 +func isTimeout(err error) bool { + if netErr, ok := err.(net.Error); ok { + return netErr.Timeout() + } + return false +} \ No newline at end of file diff --git a/config.yaml b/src/config.yaml similarity index 100% rename from config.yaml rename to src/config.yaml diff --git a/go.mod b/src/go.mod similarity index 100% rename from go.mod rename to src/go.mod diff --git a/go.sum b/src/go.sum similarity index 100% rename from go.sum rename to src/go.sum diff --git a/src/server/api/api.go b/src/server/api/api.go new file mode 100644 index 0000000..8ed910d --- /dev/null +++ b/src/server/api/api.go @@ -0,0 +1,247 @@ +package api + +import ( + "encoding/json" + "fmt" + "log" + "net" + "net/http" + "port-forward/server/db" + "port-forward/server/forwarder" + "port-forward/server/tunnel" + "strconv" + "time" +) + +// Handler HTTP API 处理器 +type Handler struct { + db *db.Database + forwarderMgr *forwarder.Manager + tunnelServer *tunnel.Server + portRangeFrom int + portRangeEnd int + useTunnel bool +} + +// NewHandler 创建新的 API 处理器 +func NewHandler(database *db.Database, fwdMgr *forwarder.Manager, ts *tunnel.Server, portFrom, portEnd int, useTunnel bool) *Handler { + return &Handler{ + db: database, + forwarderMgr: fwdMgr, + tunnelServer: ts, + portRangeFrom: portFrom, + portRangeEnd: portEnd, + useTunnel: useTunnel, + } +} + +// CreateMappingRequest 创建映射请求 +type CreateMappingRequest struct { + Port int `json:"port"` // 源端口和目标端口(相同) + TargetIP string `json:"target_ip"` // 目标 IP(非隧道模式使用) +} + +// RemoveMappingRequest 删除映射请求 +type RemoveMappingRequest struct { + Port int `json:"port"` +} + +// Response 统一响应格式 +type Response struct { + Success bool `json:"success"` + Message string `json:"message"` + Data interface{} `json:"data,omitempty"` +} + +// RegisterRoutes 注册路由 +func (h *Handler) RegisterRoutes(mux *http.ServeMux) { + mux.HandleFunc("/api/mapping/create", h.handleCreateMapping) + mux.HandleFunc("/api/mapping/remove", h.handleRemoveMapping) + mux.HandleFunc("/api/mapping/list", h.handleListMappings) + mux.HandleFunc("/health", h.handleHealth) +} + +// handleCreateMapping 处理创建映射请求 +func (h *Handler) handleCreateMapping(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + h.writeError(w, http.StatusMethodNotAllowed, "只支持 POST 方法") + return + } + + var req CreateMappingRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + h.writeError(w, http.StatusBadRequest, "请求格式错误: "+err.Error()) + return + } + + // 验证端口范围 + if req.Port < h.portRangeFrom || req.Port > h.portRangeEnd { + h.writeError(w, http.StatusBadRequest, fmt.Sprintf("端口必须在 %d-%d 范围内", h.portRangeFrom, h.portRangeEnd)) + return + } + + // 检查端口是否已被使用 + if h.forwarderMgr.Exists(req.Port) { + h.writeError(w, http.StatusConflict, "端口已被占用") + return + } + + // 非隧道模式需要验证 IP + if !h.useTunnel { + if req.TargetIP == "" { + h.writeError(w, http.StatusBadRequest, "target_ip 不能为空") + return + } + if net.ParseIP(req.TargetIP) == nil { + h.writeError(w, http.StatusBadRequest, "target_ip 格式无效") + return + } + } else { + // 隧道模式,检查隧道是否连接 + if !h.tunnelServer.IsConnected() { + h.writeError(w, http.StatusServiceUnavailable, "隧道未连接") + return + } + // 隧道模式使用本地地址 + req.TargetIP = "127.0.0.1" + } + + // 添加到数据库 + if err := h.db.AddMapping(req.Port, req.TargetIP, req.Port); err != nil { + h.writeError(w, http.StatusInternalServerError, "保存映射失败: "+err.Error()) + return + } + + // 启动转发器 + if err := h.forwarderMgr.Add(req.Port, req.TargetIP, req.Port); err != nil { + // 回滚数据库操作 + h.db.RemoveMapping(req.Port) + h.writeError(w, http.StatusInternalServerError, "启动转发失败: "+err.Error()) + return + } + + log.Printf("创建端口映射: %d -> %s:%d", req.Port, req.TargetIP, req.Port) + + h.writeSuccess(w, "端口映射创建成功", map[string]interface{}{ + "port": req.Port, + "target_ip": req.TargetIP, + "use_tunnel": h.useTunnel, + }) +} + +// handleRemoveMapping 处理删除映射请求 +func (h *Handler) handleRemoveMapping(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + h.writeError(w, http.StatusMethodNotAllowed, "只支持 POST 方法") + return + } + + var req RemoveMappingRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + h.writeError(w, http.StatusBadRequest, "请求格式错误: "+err.Error()) + return + } + + // 验证端口范围 + if req.Port < h.portRangeFrom || req.Port > h.portRangeEnd { + h.writeError(w, http.StatusBadRequest, fmt.Sprintf("端口必须在 %d-%d 范围内", h.portRangeFrom, h.portRangeEnd)) + return + } + + // 检查映射是否存在 + if !h.forwarderMgr.Exists(req.Port) { + h.writeError(w, http.StatusNotFound, "端口映射不存在") + return + } + + // 停止转发器 + if err := h.forwarderMgr.Remove(req.Port); err != nil { + h.writeError(w, http.StatusInternalServerError, "停止转发失败: "+err.Error()) + return + } + + // 从数据库删除 + if err := h.db.RemoveMapping(req.Port); err != nil { + log.Printf("从数据库删除映射失败 (端口 %d): %v", req.Port, err) + // 即使数据库删除失败,转发器已经停止,仍然返回成功 + } + + log.Printf("删除端口映射: %d", req.Port) + + h.writeSuccess(w, "端口映射删除成功", map[string]interface{}{ + "port": req.Port, + }) +} + +// handleListMappings 处理列出所有映射请求 +func (h *Handler) handleListMappings(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + h.writeError(w, http.StatusMethodNotAllowed, "只支持 GET 方法") + return + } + + mappings, err := h.db.GetAllMappings() + if err != nil { + h.writeError(w, http.StatusInternalServerError, "获取映射列表失败: "+err.Error()) + return + } + + h.writeSuccess(w, "获取映射列表成功", map[string]interface{}{ + "mappings": mappings, + "count": len(mappings), + "use_tunnel": h.useTunnel, + }) +} + +// handleHealth 健康检查 +func (h *Handler) handleHealth(w http.ResponseWriter, r *http.Request) { + status := map[string]interface{}{ + "status": "ok", + "tunnel_enabled": h.useTunnel, + "tunnel_connected": false, + } + + if h.useTunnel { + status["tunnel_connected"] = h.tunnelServer.IsConnected() + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(status) +} + +// writeSuccess 写入成功响应 +func (h *Handler) writeSuccess(w http.ResponseWriter, message string, data interface{}) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(Response{ + Success: true, + Message: message, + Data: data, + }) +} + +// writeError 写入错误响应 +func (h *Handler) writeError(w http.ResponseWriter, statusCode int, message string) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(statusCode) + json.NewEncoder(w).Encode(Response{ + Success: false, + Message: message, + }) +} + +// Start 启动 HTTP 服务器 +func Start(handler *Handler, port int) error { + mux := http.NewServeMux() + handler.RegisterRoutes(mux) + + server := &http.Server{ + Addr: ":" + strconv.Itoa(port), + Handler: mux, + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + } + + log.Printf("HTTP API 服务启动: 端口 %d", port) + return server.ListenAndServe() +} \ No newline at end of file diff --git a/src/server/config/config.go b/src/server/config/config.go new file mode 100644 index 0000000..2e6dd4c --- /dev/null +++ b/src/server/config/config.go @@ -0,0 +1,81 @@ +package config + +import ( + "fmt" + "os" + + "gopkg.in/yaml.v3" +) + +// Config 应用配置结构 +type Config struct { + PortRange PortRangeConfig `yaml:"port_range"` + Tunnel TunnelConfig `yaml:"tunnel"` + API APIConfig `yaml:"api"` + Database DatabaseConfig `yaml:"database"` +} + +// PortRangeConfig 端口范围配置 +type PortRangeConfig struct { + From int `yaml:"from"` + End int `yaml:"end"` +} + +// TunnelConfig 内网穿透配置 +type TunnelConfig struct { + Enabled bool `yaml:"enabled"` + ListenPort int `yaml:"listen_port"` +} + +// APIConfig HTTP API 配置 +type APIConfig struct { + ListenPort int `yaml:"listen_port"` +} + +// DatabaseConfig 数据库配置 +type DatabaseConfig struct { + Path string `yaml:"path"` +} + +// Load 从文件加载配置 +func Load(path string) (*Config, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("读取配置文件失败: %w", err) + } + + var config Config + if err := yaml.Unmarshal(data, &config); err != nil { + return nil, fmt.Errorf("解析配置文件失败: %w", err) + } + + // 验证配置 + if err := config.Validate(); err != nil { + return nil, fmt.Errorf("配置验证失败: %w", err) + } + + return &config, nil +} + +// Validate 验证配置的有效性 +func (c *Config) Validate() error { + if c.PortRange.From <= 0 || c.PortRange.End <= 0 { + return fmt.Errorf("端口范围必须大于 0") + } + if c.PortRange.From > c.PortRange.End { + return fmt.Errorf("起始端口不能大于结束端口") + } + if c.PortRange.End-c.PortRange.From > 10000 { + return fmt.Errorf("端口范围过大,最多支持 10000 个端口") + } + if c.Tunnel.Enabled && c.Tunnel.ListenPort <= 0 { + return fmt.Errorf("内网穿透端口必须大于 0") + } + if c.API.ListenPort <= 0 { + return fmt.Errorf("API 端口必须大于 0") + } + if c.Database.Path == "" { + return fmt.Errorf("数据库路径不能为空") + } + return nil +} \ No newline at end of file diff --git a/src/server/config/config_test.go b/src/server/config/config_test.go new file mode 100644 index 0000000..dbc1592 --- /dev/null +++ b/src/server/config/config_test.go @@ -0,0 +1,136 @@ +package config + +import ( + "os" + "testing" +) + +func TestLoadConfig(t *testing.T) { + // 创建临时配置文件 + configContent := ` +port_range: + from: 10000 + end: 10100 + +tunnel: + enabled: true + listen_port: 9000 + +api: + listen_port: 8080 + +database: + path: "./data/mappings.db" +` + tmpFile, err := os.CreateTemp("", "config_test_*.yaml") + if err != nil { + t.Fatalf("创建临时文件失败: %v", err) + } + defer os.Remove(tmpFile.Name()) + + if _, err := tmpFile.Write([]byte(configContent)); err != nil { + t.Fatalf("写入配置文件失败: %v", err) + } + tmpFile.Close() + + // 加载配置 + cfg, err := Load(tmpFile.Name()) + if err != nil { + t.Fatalf("加载配置失败: %v", err) + } + + // 验证配置 + if cfg.PortRange.From != 10000 { + t.Errorf("期望起始端口为 10000,得到 %d", cfg.PortRange.From) + } + if cfg.PortRange.End != 10100 { + t.Errorf("期望结束端口为 10100,得到 %d", cfg.PortRange.End) + } + if !cfg.Tunnel.Enabled { + t.Error("期望隧道启用") + } + if cfg.Tunnel.ListenPort != 9000 { + t.Errorf("期望隧道端口为 9000,得到 %d", cfg.Tunnel.ListenPort) + } + if cfg.API.ListenPort != 8080 { + t.Errorf("期望 API 端口为 8080,得到 %d", cfg.API.ListenPort) + } +} + +func TestValidateConfig(t *testing.T) { + tests := []struct { + name string + config Config + wantErr bool + }{ + { + name: "有效配置", + config: Config{ + PortRange: PortRangeConfig{From: 10000, End: 10100}, + Tunnel: TunnelConfig{Enabled: true, ListenPort: 9000}, + API: APIConfig{ListenPort: 8080}, + Database: DatabaseConfig{Path: "./data/test.db"}, + }, + wantErr: false, + }, + { + name: "无效端口范围 - 起始端口为0", + config: Config{ + PortRange: PortRangeConfig{From: 0, End: 10100}, + Tunnel: TunnelConfig{Enabled: false, ListenPort: 0}, + API: APIConfig{ListenPort: 8080}, + Database: DatabaseConfig{Path: "./data/test.db"}, + }, + wantErr: true, + }, + { + name: "无效端口范围 - 起始大于结束", + config: Config{ + PortRange: PortRangeConfig{From: 10100, End: 10000}, + Tunnel: TunnelConfig{Enabled: false, ListenPort: 0}, + API: APIConfig{ListenPort: 8080}, + Database: DatabaseConfig{Path: "./data/test.db"}, + }, + wantErr: true, + }, + { + name: "端口范围过大", + config: Config{ + PortRange: PortRangeConfig{From: 1, End: 20000}, + Tunnel: TunnelConfig{Enabled: false, ListenPort: 0}, + API: APIConfig{ListenPort: 8080}, + Database: DatabaseConfig{Path: "./data/test.db"}, + }, + wantErr: true, + }, + { + name: "启用隧道但端口无效", + config: Config{ + PortRange: PortRangeConfig{From: 10000, End: 10100}, + Tunnel: TunnelConfig{Enabled: true, ListenPort: 0}, + API: APIConfig{ListenPort: 8080}, + Database: DatabaseConfig{Path: "./data/test.db"}, + }, + wantErr: true, + }, + { + name: "数据库路径为空", + config: Config{ + PortRange: PortRangeConfig{From: 10000, End: 10100}, + Tunnel: TunnelConfig{Enabled: false, ListenPort: 0}, + API: APIConfig{ListenPort: 8080}, + Database: DatabaseConfig{Path: ""}, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.config.Validate() + if (err != nil) != tt.wantErr { + t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} \ No newline at end of file diff --git a/src/server/db/database.go b/src/server/db/database.go new file mode 100644 index 0000000..b5e9bcb --- /dev/null +++ b/src/server/db/database.go @@ -0,0 +1,178 @@ +package db + +import ( + "database/sql" + "fmt" + "os" + "path/filepath" + "sync" + + _ "github.com/mattn/go-sqlite3" +) + +// Mapping 端口映射结构 +type Mapping struct { + ID int64 `json:"id"` + SourcePort int `json:"source_port"` + TargetIP string `json:"target_ip"` + TargetPort int `json:"target_port"` + CreatedAt string `json:"created_at"` +} + +// Database 数据库管理器 +type Database struct { + db *sql.DB + mu sync.RWMutex +} + +// New 创建新的数据库管理器 +func New(dbPath string) (*Database, error) { + // 确保数据库目录存在 + dir := filepath.Dir(dbPath) + if err := os.MkdirAll(dir, 0755); err != nil { + return nil, fmt.Errorf("创建数据库目录失败: %w", err) + } + + db, err := sql.Open("sqlite3", dbPath) + if err != nil { + return nil, fmt.Errorf("打开数据库失败: %w", err) + } + + // 设置连接池参数 + db.SetMaxOpenConns(25) + db.SetMaxIdleConns(5) + + database := &Database{db: db} + + // 初始化表结构 + if err := database.initTables(); err != nil { + db.Close() + return nil, err + } + + return database, nil +} + +// initTables 初始化数据库表 +func (d *Database) initTables() error { + query := ` + CREATE TABLE IF NOT EXISTS mappings ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + source_port INTEGER NOT NULL UNIQUE, + target_ip TEXT NOT NULL, + target_port INTEGER NOT NULL, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP + ); + CREATE INDEX IF NOT EXISTS idx_source_port ON mappings(source_port); + ` + + _, err := d.db.Exec(query) + if err != nil { + return fmt.Errorf("初始化数据库表失败: %w", err) + } + + return nil +} + +// AddMapping 添加端口映射 +func (d *Database) AddMapping(sourcePort int, targetIP string, targetPort int) error { + d.mu.Lock() + defer d.mu.Unlock() + + query := `INSERT INTO mappings (source_port, target_ip, target_port) VALUES (?, ?, ?)` + _, err := d.db.Exec(query, sourcePort, targetIP, targetPort) + if err != nil { + return fmt.Errorf("添加端口映射失败: %w", err) + } + + return nil +} + +// RemoveMapping 删除端口映射 +func (d *Database) RemoveMapping(sourcePort int) error { + d.mu.Lock() + defer d.mu.Unlock() + + query := `DELETE FROM mappings WHERE source_port = ?` + result, err := d.db.Exec(query, sourcePort) + if err != nil { + return fmt.Errorf("删除端口映射失败: %w", err) + } + + rows, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("获取影响行数失败: %w", err) + } + + if rows == 0 { + return fmt.Errorf("端口映射不存在") + } + + return nil +} + +// GetMapping 获取指定端口的映射 +func (d *Database) GetMapping(sourcePort int) (*Mapping, error) { + d.mu.RLock() + defer d.mu.RUnlock() + + query := `SELECT id, source_port, target_ip, target_port, created_at FROM mappings WHERE source_port = ?` + + var mapping Mapping + err := d.db.QueryRow(query, sourcePort).Scan( + &mapping.ID, + &mapping.SourcePort, + &mapping.TargetIP, + &mapping.TargetPort, + &mapping.CreatedAt, + ) + + if err == sql.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, fmt.Errorf("查询端口映射失败: %w", err) + } + + return &mapping, nil +} + +// GetAllMappings 获取所有端口映射 +func (d *Database) GetAllMappings() ([]*Mapping, error) { + d.mu.RLock() + defer d.mu.RUnlock() + + query := `SELECT id, source_port, target_ip, target_port, created_at FROM mappings ORDER BY source_port` + + rows, err := d.db.Query(query) + if err != nil { + return nil, fmt.Errorf("查询所有映射失败: %w", err) + } + defer rows.Close() + + var mappings []*Mapping + for rows.Next() { + var mapping Mapping + if err := rows.Scan( + &mapping.ID, + &mapping.SourcePort, + &mapping.TargetIP, + &mapping.TargetPort, + &mapping.CreatedAt, + ); err != nil { + return nil, fmt.Errorf("扫描映射记录失败: %w", err) + } + mappings = append(mappings, &mapping) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("遍历映射记录失败: %w", err) + } + + return mappings, nil +} + +// Close 关闭数据库连接 +func (d *Database) Close() error { + return d.db.Close() +} \ No newline at end of file diff --git a/src/server/db/database_test.go b/src/server/db/database_test.go new file mode 100644 index 0000000..af6c06c --- /dev/null +++ b/src/server/db/database_test.go @@ -0,0 +1,125 @@ +package db + +import ( + "os" + "testing" +) + +func TestDatabase(t *testing.T) { + // 使用临时数据库 + dbPath := "/tmp/test_mappings.db" + defer os.Remove(dbPath) + + // 创建数据库 + db, err := New(dbPath) + if err != nil { + t.Fatalf("创建数据库失败: %v", err) + } + defer db.Close() + + t.Run("添加映射", func(t *testing.T) { + err := db.AddMapping(10001, "192.168.1.100", 22) + if err != nil { + t.Errorf("添加映射失败: %v", err) + } + }) + + t.Run("获取映射", func(t *testing.T) { + mapping, err := db.GetMapping(10001) + if err != nil { + t.Errorf("获取映射失败: %v", err) + } + if mapping == nil { + t.Error("映射不应该为空") + } + if mapping.SourcePort != 10001 { + t.Errorf("期望源端口为 10001,得到 %d", mapping.SourcePort) + } + if mapping.TargetIP != "192.168.1.100" { + t.Errorf("期望目标 IP 为 192.168.1.100,得到 %s", mapping.TargetIP) + } + if mapping.TargetPort != 22 { + t.Errorf("期望目标端口为 22,得到 %d", mapping.TargetPort) + } + }) + + t.Run("添加重复映射应该失败", func(t *testing.T) { + err := db.AddMapping(10001, "192.168.1.101", 22) + if err == nil { + t.Error("添加重复映射应该失败") + } + }) + + t.Run("获取所有映射", func(t *testing.T) { + // 添加更多映射 + db.AddMapping(10002, "192.168.1.101", 22) + db.AddMapping(10003, "192.168.1.102", 22) + + mappings, err := db.GetAllMappings() + if err != nil { + t.Errorf("获取所有映射失败: %v", err) + } + if len(mappings) != 3 { + t.Errorf("期望 3 个映射,得到 %d", len(mappings)) + } + }) + + t.Run("删除映射", func(t *testing.T) { + err := db.RemoveMapping(10001) + if err != nil { + t.Errorf("删除映射失败: %v", err) + } + + mapping, err := db.GetMapping(10001) + if err != nil { + t.Errorf("查询映射失败: %v", err) + } + if mapping != nil { + t.Error("映射应该已被删除") + } + }) + + t.Run("删除不存在的映射应该失败", func(t *testing.T) { + err := db.RemoveMapping(99999) + if err == nil { + t.Error("删除不存在的映射应该失败") + } + }) +} + +func TestDatabaseConcurrency(t *testing.T) { + dbPath := "/tmp/test_concurrent.db" + defer os.Remove(dbPath) + + db, err := New(dbPath) + if err != nil { + t.Fatalf("创建数据库失败: %v", err) + } + defer db.Close() + + // 并发添加映射 + done := make(chan bool) + for i := 0; i < 10; i++ { + go func(port int) { + err := db.AddMapping(10000+port, "192.168.1.100", port) + if err != nil { + t.Logf("添加映射失败 (端口 %d): %v", 10000+port, err) + } + done <- true + }(i) + } + + // 等待所有操作完成 + for i := 0; i < 10; i++ { + <-done + } + + // 验证映射数量 + mappings, err := db.GetAllMappings() + if err != nil { + t.Errorf("获取所有映射失败: %v", err) + } + if len(mappings) == 0 { + t.Error("应该至少有一些映射") + } +} \ No newline at end of file diff --git a/src/server/forwarder/forwarder.go b/src/server/forwarder/forwarder.go new file mode 100644 index 0000000..a0ef8a2 --- /dev/null +++ b/src/server/forwarder/forwarder.go @@ -0,0 +1,278 @@ +package forwarder + +import ( + "context" + "fmt" + "io" + "log" + "net" + "sync" + "time" +) + +// Forwarder 端口转发器 +type Forwarder struct { + sourcePort int + targetAddr string + listener net.Listener + cancel context.CancelFunc + ctx context.Context + wg sync.WaitGroup + tunnelConn net.Conn + useTunnel bool + mu sync.RWMutex +} + +// NewForwarder 创建新的端口转发器 +func NewForwarder(sourcePort int, targetIP string, targetPort int) *Forwarder { + ctx, cancel := context.WithCancel(context.Background()) + return &Forwarder{ + sourcePort: sourcePort, + targetAddr: fmt.Sprintf("%s:%d", targetIP, targetPort), + cancel: cancel, + ctx: ctx, + useTunnel: false, + } +} + +// NewTunnelForwarder 创建使用隧道的端口转发器 +func NewTunnelForwarder(sourcePort int, targetPort int, tunnelConn net.Conn) *Forwarder { + ctx, cancel := context.WithCancel(context.Background()) + return &Forwarder{ + sourcePort: sourcePort, + targetAddr: fmt.Sprintf("127.0.0.1:%d", targetPort), + tunnelConn: tunnelConn, + useTunnel: true, + cancel: cancel, + ctx: ctx, + } +} + +// Start 启动端口转发 +func (f *Forwarder) Start() error { + listener, err := net.Listen("tcp", fmt.Sprintf(":%d", f.sourcePort)) + if err != nil { + return fmt.Errorf("监听端口 %d 失败: %w", f.sourcePort, err) + } + + f.listener = listener + log.Printf("端口转发启动: %d -> %s (tunnel: %v)", f.sourcePort, f.targetAddr, f.useTunnel) + + f.wg.Add(1) + go f.acceptLoop() + + return nil +} + +// acceptLoop 接受连接循环 +func (f *Forwarder) acceptLoop() { + defer f.wg.Done() + + for { + select { + case <-f.ctx.Done(): + return + default: + } + + // 设置接受超时,避免阻塞关闭 + f.listener.(*net.TCPListener).SetDeadline(time.Now().Add(time.Second)) + + conn, err := f.listener.Accept() + if err != nil { + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + continue + } + select { + case <-f.ctx.Done(): + return + default: + log.Printf("接受连接失败 (端口 %d): %v", f.sourcePort, err) + continue + } + } + + f.wg.Add(1) + go f.handleConnection(conn) + } +} + +// handleConnection 处理单个连接 +func (f *Forwarder) handleConnection(clientConn net.Conn) { + defer f.wg.Done() + defer clientConn.Close() + + var targetConn net.Conn + var err error + + if f.useTunnel { + // 使用隧道连接 + f.mu.RLock() + targetConn = f.tunnelConn + f.mu.RUnlock() + + if targetConn == nil { + log.Printf("隧道连接不可用 (端口 %d)", f.sourcePort) + return + } + } else { + // 直接连接目标 + dialer := &net.Dialer{ + Timeout: 10 * time.Second, + KeepAlive: 30 * time.Second, + } + + targetConn, err = dialer.DialContext(f.ctx, "tcp", f.targetAddr) + if err != nil { + log.Printf("连接目标失败 (端口 %d -> %s): %v", f.sourcePort, f.targetAddr, err) + return + } + defer targetConn.Close() + } + + // 双向转发 + errChan := make(chan error, 2) + + // 客户端 -> 目标 + go func() { + _, err := io.Copy(targetConn, clientConn) + errChan <- err + }() + + // 目标 -> 客户端 + go func() { + _, err := io.Copy(clientConn, targetConn) + errChan <- err + }() + + // 等待任一方向完成或出错 + select { + case <-errChan: + // 连接已关闭或出错 + case <-f.ctx.Done(): + // 转发器被停止 + } +} + +// Stop 停止端口转发 +func (f *Forwarder) Stop() error { + f.cancel() + + if f.listener != nil { + if err := f.listener.Close(); err != nil { + log.Printf("关闭监听器失败 (端口 %d): %v", f.sourcePort, err) + } + } + + // 等待所有连接处理完成(最多等待5秒) + done := make(chan struct{}) + go func() { + f.wg.Wait() + close(done) + }() + + select { + case <-done: + log.Printf("端口转发已停止: %d", f.sourcePort) + case <-time.After(5 * time.Second): + log.Printf("端口转发停止超时: %d", f.sourcePort) + } + + return nil +} + +// SetTunnelConn 设置隧道连接 +func (f *Forwarder) SetTunnelConn(conn net.Conn) { + f.mu.Lock() + defer f.mu.Unlock() + f.tunnelConn = conn +} + +// Manager 转发器管理器 +type Manager struct { + forwarders map[int]*Forwarder + mu sync.RWMutex +} + +// NewManager 创建新的转发器管理器 +func NewManager() *Manager { + return &Manager{ + forwarders: make(map[int]*Forwarder), + } +} + +// Add 添加并启动转发器 +func (m *Manager) Add(sourcePort int, targetIP string, targetPort int) error { + m.mu.Lock() + defer m.mu.Unlock() + + if _, exists := m.forwarders[sourcePort]; exists { + return fmt.Errorf("端口 %d 已被占用", sourcePort) + } + + forwarder := NewForwarder(sourcePort, targetIP, targetPort) + if err := forwarder.Start(); err != nil { + return err + } + + m.forwarders[sourcePort] = forwarder + return nil +} + +// AddTunnel 添加使用隧道的转发器 +func (m *Manager) AddTunnel(sourcePort int, targetPort int, tunnelConn net.Conn) error { + m.mu.Lock() + defer m.mu.Unlock() + + if _, exists := m.forwarders[sourcePort]; exists { + return fmt.Errorf("端口 %d 已被占用", sourcePort) + } + + forwarder := NewTunnelForwarder(sourcePort, targetPort, tunnelConn) + if err := forwarder.Start(); err != nil { + return err + } + + m.forwarders[sourcePort] = forwarder + return nil +} + +// Remove 移除并停止转发器 +func (m *Manager) Remove(sourcePort int) error { + m.mu.Lock() + defer m.mu.Unlock() + + forwarder, exists := m.forwarders[sourcePort] + if !exists { + return fmt.Errorf("端口 %d 的转发器不存在", sourcePort) + } + + if err := forwarder.Stop(); err != nil { + return err + } + + delete(m.forwarders, sourcePort) + return nil +} + +// Exists 检查转发器是否存在 +func (m *Manager) Exists(sourcePort int) bool { + m.mu.RLock() + defer m.mu.RUnlock() + _, exists := m.forwarders[sourcePort] + return exists +} + +// StopAll 停止所有转发器 +func (m *Manager) StopAll() { + m.mu.Lock() + defer m.mu.Unlock() + + for port, forwarder := range m.forwarders { + if err := forwarder.Stop(); err != nil { + log.Printf("停止端口 %d 的转发器失败: %v", port, err) + } + } + + m.forwarders = make(map[int]*Forwarder) +} \ No newline at end of file diff --git a/src/server/main.go b/src/server/main.go new file mode 100644 index 0000000..8177744 --- /dev/null +++ b/src/server/main.go @@ -0,0 +1,119 @@ +package main + +import ( + "context" + "flag" + "log" + "os" + "os/signal" + "port-forward/server/api" + "port-forward/server/config" + "port-forward/server/db" + "port-forward/server/forwarder" + "port-forward/server/tunnel" + "syscall" + "time" +) + +func main() { + // 解析命令行参数 + configPath := flag.String("config", "config.yaml", "配置文件路径") + flag.Parse() + + // 加载配置 + log.Println("加载配置文件...") + cfg, err := config.Load(*configPath) + if err != nil { + log.Fatalf("加载配置失败: %v", err) + } + + // 初始化数据库 + log.Println("初始化数据库...") + database, err := db.New(cfg.Database.Path) + if err != nil { + log.Fatalf("初始化数据库失败: %v", err) + } + defer database.Close() + + // 创建转发器管理器 + log.Println("创建转发器管理器...") + fwdManager := forwarder.NewManager() + + // 如果启用隧道,启动隧道服务器 + var tunnelServer *tunnel.Server + if cfg.Tunnel.Enabled { + log.Println("启动隧道服务器...") + tunnelServer = tunnel.NewServer(cfg.Tunnel.ListenPort) + if err := tunnelServer.Start(); err != nil { + log.Fatalf("启动隧道服务器失败: %v", err) + } + defer tunnelServer.Stop() + } + + // 从数据库加载现有映射并启动转发器 + log.Println("加载现有端口映射...") + mappings, err := database.GetAllMappings() + if err != nil { + log.Fatalf("加载端口映射失败: %v", err) + } + + for _, mapping := range mappings { + // 验证端口在范围内 + if mapping.SourcePort < cfg.PortRange.From || mapping.SourcePort > cfg.PortRange.End { + log.Printf("警告: 端口 %d 超出范围,跳过", mapping.SourcePort) + continue + } + + log.Printf("恢复端口映射: %d -> %s:%d", mapping.SourcePort, mapping.TargetIP, mapping.TargetPort) + if err := fwdManager.Add(mapping.SourcePort, mapping.TargetIP, mapping.TargetPort); err != nil { + log.Printf("警告: 启动端口 %d 的转发失败: %v", mapping.SourcePort, err) + } + } + + log.Printf("成功加载 %d 个端口映射", len(mappings)) + + // 创建 HTTP API 处理器 + log.Println("初始化 HTTP API...") + apiHandler := api.NewHandler( + database, + fwdManager, + tunnelServer, + cfg.PortRange.From, + cfg.PortRange.End, + cfg.Tunnel.Enabled, + ) + + // 启动 HTTP API 服务器 + go func() { + if err := api.Start(apiHandler, cfg.API.ListenPort); err != nil { + log.Fatalf("启动 HTTP API 服务失败: %v", err) + } + }() + + log.Println("===========================================") + log.Printf("服务器启动成功!") + log.Printf("端口范围: %d-%d", cfg.PortRange.From, cfg.PortRange.End) + log.Printf("HTTP API: http://localhost:%d", cfg.API.ListenPort) + if cfg.Tunnel.Enabled { + log.Printf("隧道服务: 端口 %d", cfg.Tunnel.ListenPort) + } + log.Println("===========================================") + + // 等待中断信号 + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM) + + <-sigChan + log.Println("\n接收到关闭信号,正在优雅关闭...") + + // 创建关闭上下文 + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + // 停止所有转发器 + log.Println("停止所有端口转发...") + fwdManager.StopAll() + + log.Println("服务器已关闭") + <-ctx.Done() +} \ No newline at end of file diff --git a/src/server/tunnel/tunnel.go b/src/server/tunnel/tunnel.go new file mode 100644 index 0000000..f1545ef --- /dev/null +++ b/src/server/tunnel/tunnel.go @@ -0,0 +1,320 @@ +package tunnel + +import ( + "context" + "encoding/binary" + "fmt" + "io" + "log" + "net" + "sync" + "time" +) + +// Protocol 定义隧道协议 +// 消息格式: [4字节长度][4字节端口][数据] + +const ( + // HeaderSize 消息头大小(长度+端口) + HeaderSize = 8 + // MaxPacketSize 最大包大小 (1MB) + MaxPacketSize = 1024 * 1024 +) + +// Server 内网穿透服务器 +type Server struct { + listenPort int + listener net.Listener + client net.Conn + cancel context.CancelFunc + ctx context.Context + wg sync.WaitGroup + mu sync.RWMutex + + // 连接管理 + connections map[uint32]*Connection + connMu sync.RWMutex + nextConnID uint32 +} + +// Connection 表示一个客户端连接 +type Connection struct { + ID uint32 + TargetPort int + ClientConn net.Conn + readChan chan []byte + writeChan chan []byte + closeChan chan struct{} +} + +// NewServer 创建新的隧道服务器 +func NewServer(listenPort int) *Server { + ctx, cancel := context.WithCancel(context.Background()) + return &Server{ + listenPort: listenPort, + cancel: cancel, + ctx: ctx, + connections: make(map[uint32]*Connection), + } +} + +// Start 启动隧道服务器 +func (s *Server) Start() error { + listener, err := net.Listen("tcp", fmt.Sprintf(":%d", s.listenPort)) + if err != nil { + return fmt.Errorf("启动隧道服务器失败: %w", err) + } + + s.listener = listener + log.Printf("隧道服务器启动: 端口 %d", s.listenPort) + + s.wg.Add(1) + go s.acceptLoop() + + return nil +} + +// acceptLoop 接受客户端连接 +func (s *Server) acceptLoop() { + defer s.wg.Done() + + for { + select { + case <-s.ctx.Done(): + return + default: + } + + s.listener.(*net.TCPListener).SetDeadline(time.Now().Add(time.Second)) + conn, err := s.listener.Accept() + if err != nil { + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + continue + } + select { + case <-s.ctx.Done(): + return + default: + log.Printf("接受隧道连接失败: %v", err) + continue + } + } + + // 只允许一个客户端连接 + s.mu.Lock() + if s.client != nil { + log.Printf("拒绝额外的隧道连接: %s", conn.RemoteAddr()) + conn.Close() + s.mu.Unlock() + continue + } + s.client = conn + s.mu.Unlock() + + log.Printf("隧道客户端已连接: %s", conn.RemoteAddr()) + + s.wg.Add(1) + go s.handleClient(conn) + } +} + +// handleClient 处理客户端连接 +func (s *Server) handleClient(conn net.Conn) { + defer s.wg.Done() + defer func() { + conn.Close() + s.mu.Lock() + s.client = nil + s.mu.Unlock() + log.Printf("隧道客户端已断开") + + // 关闭所有活动连接 + s.connMu.Lock() + for _, c := range s.connections { + close(c.closeChan) + } + s.connections = make(map[uint32]*Connection) + s.connMu.Unlock() + }() + + // 读取来自客户端的数据 + for { + select { + case <-s.ctx.Done(): + return + default: + } + + // 读取消息头 + header := make([]byte, HeaderSize) + if _, err := io.ReadFull(conn, header); err != nil { + if err != io.EOF { + log.Printf("读取隧道消息头失败: %v", err) + } + return + } + + dataLen := binary.BigEndian.Uint32(header[0:4]) + connID := binary.BigEndian.Uint32(header[4:8]) + + if dataLen > MaxPacketSize { + log.Printf("数据包过大: %d bytes", dataLen) + return + } + + // 读取数据 + data := make([]byte, dataLen) + if _, err := io.ReadFull(conn, data); err != nil { + log.Printf("读取隧道数据失败: %v", err) + return + } + + // 将数据发送到对应的连接 + s.connMu.RLock() + connection, exists := s.connections[connID] + s.connMu.RUnlock() + + if exists { + select { + case connection.readChan <- data: + case <-connection.closeChan: + case <-time.After(5 * time.Second): + log.Printf("向连接 %d 发送数据超时", connID) + } + } + } +} + +// ForwardConnection 转发连接到隧道 +func (s *Server) ForwardConnection(clientConn net.Conn, targetPort int) error { + s.mu.RLock() + tunnelConn := s.client + s.mu.RUnlock() + + if tunnelConn == nil { + return fmt.Errorf("隧道连接不可用") + } + + // 创建连接对象 + s.connMu.Lock() + connID := s.nextConnID + s.nextConnID++ + + connection := &Connection{ + ID: connID, + TargetPort: targetPort, + ClientConn: clientConn, + readChan: make(chan []byte, 100), + writeChan: make(chan []byte, 100), + closeChan: make(chan struct{}), + } + s.connections[connID] = connection + s.connMu.Unlock() + + defer func() { + s.connMu.Lock() + delete(s.connections, connID) + s.connMu.Unlock() + close(connection.closeChan) + clientConn.Close() + }() + + // 启动读写协程 + errChan := make(chan error, 2) + + // 从客户端读取并发送到隧道 + go func() { + buffer := make([]byte, 32*1024) + for { + select { + case <-connection.closeChan: + return + default: + } + + clientConn.SetReadDeadline(time.Now().Add(30 * time.Second)) + n, err := clientConn.Read(buffer) + if err != nil { + errChan <- err + return + } + + // 发送到隧道 + data := make([]byte, HeaderSize+n) + binary.BigEndian.PutUint32(data[0:4], uint32(n)) + binary.BigEndian.PutUint32(data[4:8], connID) + copy(data[HeaderSize:], buffer[:n]) + + s.mu.RLock() + _, err = tunnelConn.Write(data) + s.mu.RUnlock() + + if err != nil { + errChan <- err + return + } + } + }() + + // 从隧道读取并发送到客户端 + go func() { + for { + select { + case data := <-connection.readChan: + if _, err := clientConn.Write(data); err != nil { + errChan <- err + return + } + case <-connection.closeChan: + return + } + } + }() + + // 等待错误或关闭 + select { + case <-errChan: + case <-connection.closeChan: + case <-s.ctx.Done(): + } + + return nil +} + +// IsConnected 检查隧道是否已连接 +func (s *Server) IsConnected() bool { + s.mu.RLock() + defer s.mu.RUnlock() + return s.client != nil +} + +// Stop 停止隧道服务器 +func (s *Server) Stop() error { + s.cancel() + + if s.listener != nil { + s.listener.Close() + } + + s.mu.Lock() + if s.client != nil { + s.client.Close() + } + s.mu.Unlock() + + // 等待所有协程结束 + done := make(chan struct{}) + go func() { + s.wg.Wait() + close(done) + }() + + select { + case <-done: + log.Printf("隧道服务器已停止") + case <-time.After(5 * time.Second): + log.Printf("隧道服务器停止超时") + } + + return nil +} \ No newline at end of file