This commit is contained in:
Pan Qiancheng 2025-10-14 13:57:47 +08:00
parent b6baedeaf8
commit 6b02def8de
14 changed files with 1821 additions and 4 deletions

8
.gitignore vendored
View File

@ -5,10 +5,10 @@ bin/
*.dll
*.so
*.dylib
server
client
server-linux
client-linux
/server
/client
/server-linux
/client-linux
# 测试文件
*.test

48
src/client/main.go Normal file
View File

@ -0,0 +1,48 @@
package main
import (
"flag"
"log"
"os"
"os/signal"
"port-forward/client/tunnel"
"syscall"
)
func main() {
// 解析命令行参数
serverAddr := flag.String("server", "localhost:9000", "隧道服务器地址 (host:port)")
flag.Parse()
log.SetFlags(log.LstdFlags | log.Lshortfile)
// 创建隧道客户端
log.Printf("隧道客户端启动...")
log.Printf("服务器地址: %s", *serverAddr)
client := tunnel.NewClient(*serverAddr)
// 启动客户端
if err := client.Start(); err != nil {
log.Fatalf("启动隧道客户端失败: %v", err)
}
log.Println("===========================================")
log.Println("隧道客户端运行中...")
log.Println("按 Ctrl+C 退出")
log.Println("===========================================")
// 等待中断信号
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
<-sigChan
log.Println("\n接收到关闭信号正在关闭...")
// 停止客户端
if err := client.Stop(); err != nil {
log.Printf("停止客户端失败: %v", err)
}
log.Println("客户端已关闭")
}

285
src/client/tunnel/client.go Normal file
View File

@ -0,0 +1,285 @@
package tunnel
import (
"context"
"encoding/binary"
"fmt"
"io"
"log"
"net"
"sync"
"time"
)
const (
// HeaderSize 消息头大小
HeaderSize = 8
// MaxPacketSize 最大包大小
MaxPacketSize = 1024 * 1024
// ReconnectDelay 重连延迟
ReconnectDelay = 5 * time.Second
)
// Client 内网穿透客户端
type Client struct {
serverAddr string
serverConn net.Conn
cancel context.CancelFunc
ctx context.Context
wg sync.WaitGroup
mu sync.RWMutex
// 连接管理
connections map[uint32]*LocalConnection
connMu sync.RWMutex
}
// LocalConnection 本地连接
type LocalConnection struct {
ID uint32
TargetAddr string
Conn net.Conn
closeChan chan struct{}
}
// NewClient 创建新的隧道客户端
func NewClient(serverAddr string) *Client {
ctx, cancel := context.WithCancel(context.Background())
return &Client{
serverAddr: serverAddr,
cancel: cancel,
ctx: ctx,
connections: make(map[uint32]*LocalConnection),
}
}
// Start 启动隧道客户端
func (c *Client) Start() error {
log.Printf("正在连接到隧道服务器: %s", c.serverAddr)
c.wg.Add(1)
go c.connectLoop()
return nil
}
// connectLoop 连接循环(支持自动重连)
func (c *Client) connectLoop() {
defer c.wg.Done()
for {
select {
case <-c.ctx.Done():
return
default:
}
conn, err := net.DialTimeout("tcp", c.serverAddr, 10*time.Second)
if err != nil {
log.Printf("连接隧道服务器失败: %v%v 后重试", err, ReconnectDelay)
time.Sleep(ReconnectDelay)
continue
}
log.Printf("已连接到隧道服务器: %s", c.serverAddr)
c.mu.Lock()
c.serverConn = conn
c.mu.Unlock()
// 处理连接
if err := c.handleServerConnection(conn); err != nil {
if err != io.EOF {
log.Printf("处理服务器连接出错: %v", err)
}
}
c.mu.Lock()
c.serverConn = nil
c.mu.Unlock()
// 关闭所有本地连接
c.connMu.Lock()
for _, conn := range c.connections {
close(conn.closeChan)
if conn.Conn != nil {
conn.Conn.Close()
}
}
c.connections = make(map[uint32]*LocalConnection)
c.connMu.Unlock()
log.Printf("与服务器断开连接,%v 后重连", ReconnectDelay)
time.Sleep(ReconnectDelay)
}
}
// handleServerConnection 处理服务器连接
func (c *Client) handleServerConnection(conn net.Conn) error {
for {
select {
case <-c.ctx.Done():
return nil
default:
}
// 读取消息头
header := make([]byte, HeaderSize)
if _, err := io.ReadFull(conn, header); err != nil {
return err
}
dataLen := binary.BigEndian.Uint32(header[0:4])
connID := binary.BigEndian.Uint32(header[4:8])
if dataLen > MaxPacketSize {
return fmt.Errorf("数据包过大: %d bytes", dataLen)
}
// 读取数据
data := make([]byte, dataLen)
if _, err := io.ReadFull(conn, data); err != nil {
return err
}
// 处理数据
c.handleData(connID, data)
}
}
// handleData 处理接收到的数据
func (c *Client) handleData(connID uint32, data []byte) {
c.connMu.Lock()
localConn, exists := c.connections[connID]
if !exists {
// 新连接,需要建立到本地服务的连接
// 从数据中解析目标端口(这里简化处理,实际应该从协议中获取)
localConn = &LocalConnection{
ID: connID,
closeChan: make(chan struct{}),
}
c.connections[connID] = localConn
c.connMu.Unlock()
// 启动本地连接处理
c.wg.Add(1)
go c.handleLocalConnection(localConn)
// 重新获取锁并发送数据
c.connMu.Lock()
}
c.connMu.Unlock()
// 发送数据到本地连接
if localConn.Conn != nil {
localConn.Conn.Write(data)
}
}
// handleLocalConnection 处理本地连接
func (c *Client) handleLocalConnection(localConn *LocalConnection) {
defer c.wg.Done()
defer func() {
c.connMu.Lock()
delete(c.connections, localConn.ID)
c.connMu.Unlock()
if localConn.Conn != nil {
localConn.Conn.Close()
}
}()
// 连接到本地目标服务
// 这里使用固定的本地地址,实际应该根据映射配置
targetAddr := localConn.TargetAddr
if targetAddr == "" {
targetAddr = "127.0.0.1:22" // 默认 SSH
}
conn, err := net.DialTimeout("tcp", targetAddr, 5*time.Second)
if err != nil {
log.Printf("连接本地服务失败 (连接 %d -> %s): %v", localConn.ID, targetAddr, err)
return
}
defer conn.Close()
localConn.Conn = conn
log.Printf("建立本地连接: %d -> %s", localConn.ID, targetAddr)
// 从本地服务读取数据并发送到服务器
buffer := make([]byte, 32*1024)
for {
select {
case <-localConn.closeChan:
return
case <-c.ctx.Done():
return
default:
}
conn.SetReadDeadline(time.Now().Add(30 * time.Second))
n, err := conn.Read(buffer)
if err != nil {
if err != io.EOF && !isTimeout(err) {
log.Printf("读取本地连接失败 (连接 %d): %v", localConn.ID, err)
}
return
}
// 发送到服务器
c.mu.RLock()
serverConn := c.serverConn
c.mu.RUnlock()
if serverConn == nil {
return
}
data := make([]byte, HeaderSize+n)
binary.BigEndian.PutUint32(data[0:4], uint32(n))
binary.BigEndian.PutUint32(data[4:8], localConn.ID)
copy(data[HeaderSize:], buffer[:n])
if _, err := serverConn.Write(data); err != nil {
log.Printf("发送数据到服务器失败 (连接 %d): %v", localConn.ID, err)
return
}
}
}
// Stop 停止隧道客户端
func (c *Client) Stop() error {
log.Println("正在停止隧道客户端...")
c.cancel()
c.mu.Lock()
if c.serverConn != nil {
c.serverConn.Close()
}
c.mu.Unlock()
// 等待所有协程结束
done := make(chan struct{})
go func() {
c.wg.Wait()
close(done)
}()
select {
case <-done:
log.Println("隧道客户端已停止")
case <-time.After(5 * time.Second):
log.Println("隧道客户端停止超时")
}
return nil
}
// isTimeout 检查是否为超时错误
func isTimeout(err error) bool {
if netErr, ok := err.(net.Error); ok {
return netErr.Timeout()
}
return false
}

View File

View File

247
src/server/api/api.go Normal file
View File

@ -0,0 +1,247 @@
package api
import (
"encoding/json"
"fmt"
"log"
"net"
"net/http"
"port-forward/server/db"
"port-forward/server/forwarder"
"port-forward/server/tunnel"
"strconv"
"time"
)
// Handler HTTP API 处理器
type Handler struct {
db *db.Database
forwarderMgr *forwarder.Manager
tunnelServer *tunnel.Server
portRangeFrom int
portRangeEnd int
useTunnel bool
}
// NewHandler 创建新的 API 处理器
func NewHandler(database *db.Database, fwdMgr *forwarder.Manager, ts *tunnel.Server, portFrom, portEnd int, useTunnel bool) *Handler {
return &Handler{
db: database,
forwarderMgr: fwdMgr,
tunnelServer: ts,
portRangeFrom: portFrom,
portRangeEnd: portEnd,
useTunnel: useTunnel,
}
}
// CreateMappingRequest 创建映射请求
type CreateMappingRequest struct {
Port int `json:"port"` // 源端口和目标端口(相同)
TargetIP string `json:"target_ip"` // 目标 IP非隧道模式使用
}
// RemoveMappingRequest 删除映射请求
type RemoveMappingRequest struct {
Port int `json:"port"`
}
// Response 统一响应格式
type Response struct {
Success bool `json:"success"`
Message string `json:"message"`
Data interface{} `json:"data,omitempty"`
}
// RegisterRoutes 注册路由
func (h *Handler) RegisterRoutes(mux *http.ServeMux) {
mux.HandleFunc("/api/mapping/create", h.handleCreateMapping)
mux.HandleFunc("/api/mapping/remove", h.handleRemoveMapping)
mux.HandleFunc("/api/mapping/list", h.handleListMappings)
mux.HandleFunc("/health", h.handleHealth)
}
// handleCreateMapping 处理创建映射请求
func (h *Handler) handleCreateMapping(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
h.writeError(w, http.StatusMethodNotAllowed, "只支持 POST 方法")
return
}
var req CreateMappingRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
h.writeError(w, http.StatusBadRequest, "请求格式错误: "+err.Error())
return
}
// 验证端口范围
if req.Port < h.portRangeFrom || req.Port > h.portRangeEnd {
h.writeError(w, http.StatusBadRequest, fmt.Sprintf("端口必须在 %d-%d 范围内", h.portRangeFrom, h.portRangeEnd))
return
}
// 检查端口是否已被使用
if h.forwarderMgr.Exists(req.Port) {
h.writeError(w, http.StatusConflict, "端口已被占用")
return
}
// 非隧道模式需要验证 IP
if !h.useTunnel {
if req.TargetIP == "" {
h.writeError(w, http.StatusBadRequest, "target_ip 不能为空")
return
}
if net.ParseIP(req.TargetIP) == nil {
h.writeError(w, http.StatusBadRequest, "target_ip 格式无效")
return
}
} else {
// 隧道模式,检查隧道是否连接
if !h.tunnelServer.IsConnected() {
h.writeError(w, http.StatusServiceUnavailable, "隧道未连接")
return
}
// 隧道模式使用本地地址
req.TargetIP = "127.0.0.1"
}
// 添加到数据库
if err := h.db.AddMapping(req.Port, req.TargetIP, req.Port); err != nil {
h.writeError(w, http.StatusInternalServerError, "保存映射失败: "+err.Error())
return
}
// 启动转发器
if err := h.forwarderMgr.Add(req.Port, req.TargetIP, req.Port); err != nil {
// 回滚数据库操作
h.db.RemoveMapping(req.Port)
h.writeError(w, http.StatusInternalServerError, "启动转发失败: "+err.Error())
return
}
log.Printf("创建端口映射: %d -> %s:%d", req.Port, req.TargetIP, req.Port)
h.writeSuccess(w, "端口映射创建成功", map[string]interface{}{
"port": req.Port,
"target_ip": req.TargetIP,
"use_tunnel": h.useTunnel,
})
}
// handleRemoveMapping 处理删除映射请求
func (h *Handler) handleRemoveMapping(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
h.writeError(w, http.StatusMethodNotAllowed, "只支持 POST 方法")
return
}
var req RemoveMappingRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
h.writeError(w, http.StatusBadRequest, "请求格式错误: "+err.Error())
return
}
// 验证端口范围
if req.Port < h.portRangeFrom || req.Port > h.portRangeEnd {
h.writeError(w, http.StatusBadRequest, fmt.Sprintf("端口必须在 %d-%d 范围内", h.portRangeFrom, h.portRangeEnd))
return
}
// 检查映射是否存在
if !h.forwarderMgr.Exists(req.Port) {
h.writeError(w, http.StatusNotFound, "端口映射不存在")
return
}
// 停止转发器
if err := h.forwarderMgr.Remove(req.Port); err != nil {
h.writeError(w, http.StatusInternalServerError, "停止转发失败: "+err.Error())
return
}
// 从数据库删除
if err := h.db.RemoveMapping(req.Port); err != nil {
log.Printf("从数据库删除映射失败 (端口 %d): %v", req.Port, err)
// 即使数据库删除失败,转发器已经停止,仍然返回成功
}
log.Printf("删除端口映射: %d", req.Port)
h.writeSuccess(w, "端口映射删除成功", map[string]interface{}{
"port": req.Port,
})
}
// handleListMappings 处理列出所有映射请求
func (h *Handler) handleListMappings(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
h.writeError(w, http.StatusMethodNotAllowed, "只支持 GET 方法")
return
}
mappings, err := h.db.GetAllMappings()
if err != nil {
h.writeError(w, http.StatusInternalServerError, "获取映射列表失败: "+err.Error())
return
}
h.writeSuccess(w, "获取映射列表成功", map[string]interface{}{
"mappings": mappings,
"count": len(mappings),
"use_tunnel": h.useTunnel,
})
}
// handleHealth 健康检查
func (h *Handler) handleHealth(w http.ResponseWriter, r *http.Request) {
status := map[string]interface{}{
"status": "ok",
"tunnel_enabled": h.useTunnel,
"tunnel_connected": false,
}
if h.useTunnel {
status["tunnel_connected"] = h.tunnelServer.IsConnected()
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(status)
}
// writeSuccess 写入成功响应
func (h *Handler) writeSuccess(w http.ResponseWriter, message string, data interface{}) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(Response{
Success: true,
Message: message,
Data: data,
})
}
// writeError 写入错误响应
func (h *Handler) writeError(w http.ResponseWriter, statusCode int, message string) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(statusCode)
json.NewEncoder(w).Encode(Response{
Success: false,
Message: message,
})
}
// Start 启动 HTTP 服务器
func Start(handler *Handler, port int) error {
mux := http.NewServeMux()
handler.RegisterRoutes(mux)
server := &http.Server{
Addr: ":" + strconv.Itoa(port),
Handler: mux,
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
}
log.Printf("HTTP API 服务启动: 端口 %d", port)
return server.ListenAndServe()
}

View File

@ -0,0 +1,81 @@
package config
import (
"fmt"
"os"
"gopkg.in/yaml.v3"
)
// Config 应用配置结构
type Config struct {
PortRange PortRangeConfig `yaml:"port_range"`
Tunnel TunnelConfig `yaml:"tunnel"`
API APIConfig `yaml:"api"`
Database DatabaseConfig `yaml:"database"`
}
// PortRangeConfig 端口范围配置
type PortRangeConfig struct {
From int `yaml:"from"`
End int `yaml:"end"`
}
// TunnelConfig 内网穿透配置
type TunnelConfig struct {
Enabled bool `yaml:"enabled"`
ListenPort int `yaml:"listen_port"`
}
// APIConfig HTTP API 配置
type APIConfig struct {
ListenPort int `yaml:"listen_port"`
}
// DatabaseConfig 数据库配置
type DatabaseConfig struct {
Path string `yaml:"path"`
}
// Load 从文件加载配置
func Load(path string) (*Config, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("读取配置文件失败: %w", err)
}
var config Config
if err := yaml.Unmarshal(data, &config); err != nil {
return nil, fmt.Errorf("解析配置文件失败: %w", err)
}
// 验证配置
if err := config.Validate(); err != nil {
return nil, fmt.Errorf("配置验证失败: %w", err)
}
return &config, nil
}
// Validate 验证配置的有效性
func (c *Config) Validate() error {
if c.PortRange.From <= 0 || c.PortRange.End <= 0 {
return fmt.Errorf("端口范围必须大于 0")
}
if c.PortRange.From > c.PortRange.End {
return fmt.Errorf("起始端口不能大于结束端口")
}
if c.PortRange.End-c.PortRange.From > 10000 {
return fmt.Errorf("端口范围过大,最多支持 10000 个端口")
}
if c.Tunnel.Enabled && c.Tunnel.ListenPort <= 0 {
return fmt.Errorf("内网穿透端口必须大于 0")
}
if c.API.ListenPort <= 0 {
return fmt.Errorf("API 端口必须大于 0")
}
if c.Database.Path == "" {
return fmt.Errorf("数据库路径不能为空")
}
return nil
}

View File

@ -0,0 +1,136 @@
package config
import (
"os"
"testing"
)
func TestLoadConfig(t *testing.T) {
// 创建临时配置文件
configContent := `
port_range:
from: 10000
end: 10100
tunnel:
enabled: true
listen_port: 9000
api:
listen_port: 8080
database:
path: "./data/mappings.db"
`
tmpFile, err := os.CreateTemp("", "config_test_*.yaml")
if err != nil {
t.Fatalf("创建临时文件失败: %v", err)
}
defer os.Remove(tmpFile.Name())
if _, err := tmpFile.Write([]byte(configContent)); err != nil {
t.Fatalf("写入配置文件失败: %v", err)
}
tmpFile.Close()
// 加载配置
cfg, err := Load(tmpFile.Name())
if err != nil {
t.Fatalf("加载配置失败: %v", err)
}
// 验证配置
if cfg.PortRange.From != 10000 {
t.Errorf("期望起始端口为 10000得到 %d", cfg.PortRange.From)
}
if cfg.PortRange.End != 10100 {
t.Errorf("期望结束端口为 10100得到 %d", cfg.PortRange.End)
}
if !cfg.Tunnel.Enabled {
t.Error("期望隧道启用")
}
if cfg.Tunnel.ListenPort != 9000 {
t.Errorf("期望隧道端口为 9000得到 %d", cfg.Tunnel.ListenPort)
}
if cfg.API.ListenPort != 8080 {
t.Errorf("期望 API 端口为 8080得到 %d", cfg.API.ListenPort)
}
}
func TestValidateConfig(t *testing.T) {
tests := []struct {
name string
config Config
wantErr bool
}{
{
name: "有效配置",
config: Config{
PortRange: PortRangeConfig{From: 10000, End: 10100},
Tunnel: TunnelConfig{Enabled: true, ListenPort: 9000},
API: APIConfig{ListenPort: 8080},
Database: DatabaseConfig{Path: "./data/test.db"},
},
wantErr: false,
},
{
name: "无效端口范围 - 起始端口为0",
config: Config{
PortRange: PortRangeConfig{From: 0, End: 10100},
Tunnel: TunnelConfig{Enabled: false, ListenPort: 0},
API: APIConfig{ListenPort: 8080},
Database: DatabaseConfig{Path: "./data/test.db"},
},
wantErr: true,
},
{
name: "无效端口范围 - 起始大于结束",
config: Config{
PortRange: PortRangeConfig{From: 10100, End: 10000},
Tunnel: TunnelConfig{Enabled: false, ListenPort: 0},
API: APIConfig{ListenPort: 8080},
Database: DatabaseConfig{Path: "./data/test.db"},
},
wantErr: true,
},
{
name: "端口范围过大",
config: Config{
PortRange: PortRangeConfig{From: 1, End: 20000},
Tunnel: TunnelConfig{Enabled: false, ListenPort: 0},
API: APIConfig{ListenPort: 8080},
Database: DatabaseConfig{Path: "./data/test.db"},
},
wantErr: true,
},
{
name: "启用隧道但端口无效",
config: Config{
PortRange: PortRangeConfig{From: 10000, End: 10100},
Tunnel: TunnelConfig{Enabled: true, ListenPort: 0},
API: APIConfig{ListenPort: 8080},
Database: DatabaseConfig{Path: "./data/test.db"},
},
wantErr: true,
},
{
name: "数据库路径为空",
config: Config{
PortRange: PortRangeConfig{From: 10000, End: 10100},
Tunnel: TunnelConfig{Enabled: false, ListenPort: 0},
API: APIConfig{ListenPort: 8080},
Database: DatabaseConfig{Path: ""},
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.config.Validate()
if (err != nil) != tt.wantErr {
t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

178
src/server/db/database.go Normal file
View File

@ -0,0 +1,178 @@
package db
import (
"database/sql"
"fmt"
"os"
"path/filepath"
"sync"
_ "github.com/mattn/go-sqlite3"
)
// Mapping 端口映射结构
type Mapping struct {
ID int64 `json:"id"`
SourcePort int `json:"source_port"`
TargetIP string `json:"target_ip"`
TargetPort int `json:"target_port"`
CreatedAt string `json:"created_at"`
}
// Database 数据库管理器
type Database struct {
db *sql.DB
mu sync.RWMutex
}
// New 创建新的数据库管理器
func New(dbPath string) (*Database, error) {
// 确保数据库目录存在
dir := filepath.Dir(dbPath)
if err := os.MkdirAll(dir, 0755); err != nil {
return nil, fmt.Errorf("创建数据库目录失败: %w", err)
}
db, err := sql.Open("sqlite3", dbPath)
if err != nil {
return nil, fmt.Errorf("打开数据库失败: %w", err)
}
// 设置连接池参数
db.SetMaxOpenConns(25)
db.SetMaxIdleConns(5)
database := &Database{db: db}
// 初始化表结构
if err := database.initTables(); err != nil {
db.Close()
return nil, err
}
return database, nil
}
// initTables 初始化数据库表
func (d *Database) initTables() error {
query := `
CREATE TABLE IF NOT EXISTS mappings (
id INTEGER PRIMARY KEY AUTOINCREMENT,
source_port INTEGER NOT NULL UNIQUE,
target_ip TEXT NOT NULL,
target_port INTEGER NOT NULL,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX IF NOT EXISTS idx_source_port ON mappings(source_port);
`
_, err := d.db.Exec(query)
if err != nil {
return fmt.Errorf("初始化数据库表失败: %w", err)
}
return nil
}
// AddMapping 添加端口映射
func (d *Database) AddMapping(sourcePort int, targetIP string, targetPort int) error {
d.mu.Lock()
defer d.mu.Unlock()
query := `INSERT INTO mappings (source_port, target_ip, target_port) VALUES (?, ?, ?)`
_, err := d.db.Exec(query, sourcePort, targetIP, targetPort)
if err != nil {
return fmt.Errorf("添加端口映射失败: %w", err)
}
return nil
}
// RemoveMapping 删除端口映射
func (d *Database) RemoveMapping(sourcePort int) error {
d.mu.Lock()
defer d.mu.Unlock()
query := `DELETE FROM mappings WHERE source_port = ?`
result, err := d.db.Exec(query, sourcePort)
if err != nil {
return fmt.Errorf("删除端口映射失败: %w", err)
}
rows, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("获取影响行数失败: %w", err)
}
if rows == 0 {
return fmt.Errorf("端口映射不存在")
}
return nil
}
// GetMapping 获取指定端口的映射
func (d *Database) GetMapping(sourcePort int) (*Mapping, error) {
d.mu.RLock()
defer d.mu.RUnlock()
query := `SELECT id, source_port, target_ip, target_port, created_at FROM mappings WHERE source_port = ?`
var mapping Mapping
err := d.db.QueryRow(query, sourcePort).Scan(
&mapping.ID,
&mapping.SourcePort,
&mapping.TargetIP,
&mapping.TargetPort,
&mapping.CreatedAt,
)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, fmt.Errorf("查询端口映射失败: %w", err)
}
return &mapping, nil
}
// GetAllMappings 获取所有端口映射
func (d *Database) GetAllMappings() ([]*Mapping, error) {
d.mu.RLock()
defer d.mu.RUnlock()
query := `SELECT id, source_port, target_ip, target_port, created_at FROM mappings ORDER BY source_port`
rows, err := d.db.Query(query)
if err != nil {
return nil, fmt.Errorf("查询所有映射失败: %w", err)
}
defer rows.Close()
var mappings []*Mapping
for rows.Next() {
var mapping Mapping
if err := rows.Scan(
&mapping.ID,
&mapping.SourcePort,
&mapping.TargetIP,
&mapping.TargetPort,
&mapping.CreatedAt,
); err != nil {
return nil, fmt.Errorf("扫描映射记录失败: %w", err)
}
mappings = append(mappings, &mapping)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("遍历映射记录失败: %w", err)
}
return mappings, nil
}
// Close 关闭数据库连接
func (d *Database) Close() error {
return d.db.Close()
}

View File

@ -0,0 +1,125 @@
package db
import (
"os"
"testing"
)
func TestDatabase(t *testing.T) {
// 使用临时数据库
dbPath := "/tmp/test_mappings.db"
defer os.Remove(dbPath)
// 创建数据库
db, err := New(dbPath)
if err != nil {
t.Fatalf("创建数据库失败: %v", err)
}
defer db.Close()
t.Run("添加映射", func(t *testing.T) {
err := db.AddMapping(10001, "192.168.1.100", 22)
if err != nil {
t.Errorf("添加映射失败: %v", err)
}
})
t.Run("获取映射", func(t *testing.T) {
mapping, err := db.GetMapping(10001)
if err != nil {
t.Errorf("获取映射失败: %v", err)
}
if mapping == nil {
t.Error("映射不应该为空")
}
if mapping.SourcePort != 10001 {
t.Errorf("期望源端口为 10001得到 %d", mapping.SourcePort)
}
if mapping.TargetIP != "192.168.1.100" {
t.Errorf("期望目标 IP 为 192.168.1.100,得到 %s", mapping.TargetIP)
}
if mapping.TargetPort != 22 {
t.Errorf("期望目标端口为 22得到 %d", mapping.TargetPort)
}
})
t.Run("添加重复映射应该失败", func(t *testing.T) {
err := db.AddMapping(10001, "192.168.1.101", 22)
if err == nil {
t.Error("添加重复映射应该失败")
}
})
t.Run("获取所有映射", func(t *testing.T) {
// 添加更多映射
db.AddMapping(10002, "192.168.1.101", 22)
db.AddMapping(10003, "192.168.1.102", 22)
mappings, err := db.GetAllMappings()
if err != nil {
t.Errorf("获取所有映射失败: %v", err)
}
if len(mappings) != 3 {
t.Errorf("期望 3 个映射,得到 %d", len(mappings))
}
})
t.Run("删除映射", func(t *testing.T) {
err := db.RemoveMapping(10001)
if err != nil {
t.Errorf("删除映射失败: %v", err)
}
mapping, err := db.GetMapping(10001)
if err != nil {
t.Errorf("查询映射失败: %v", err)
}
if mapping != nil {
t.Error("映射应该已被删除")
}
})
t.Run("删除不存在的映射应该失败", func(t *testing.T) {
err := db.RemoveMapping(99999)
if err == nil {
t.Error("删除不存在的映射应该失败")
}
})
}
func TestDatabaseConcurrency(t *testing.T) {
dbPath := "/tmp/test_concurrent.db"
defer os.Remove(dbPath)
db, err := New(dbPath)
if err != nil {
t.Fatalf("创建数据库失败: %v", err)
}
defer db.Close()
// 并发添加映射
done := make(chan bool)
for i := 0; i < 10; i++ {
go func(port int) {
err := db.AddMapping(10000+port, "192.168.1.100", port)
if err != nil {
t.Logf("添加映射失败 (端口 %d): %v", 10000+port, err)
}
done <- true
}(i)
}
// 等待所有操作完成
for i := 0; i < 10; i++ {
<-done
}
// 验证映射数量
mappings, err := db.GetAllMappings()
if err != nil {
t.Errorf("获取所有映射失败: %v", err)
}
if len(mappings) == 0 {
t.Error("应该至少有一些映射")
}
}

View File

@ -0,0 +1,278 @@
package forwarder
import (
"context"
"fmt"
"io"
"log"
"net"
"sync"
"time"
)
// Forwarder 端口转发器
type Forwarder struct {
sourcePort int
targetAddr string
listener net.Listener
cancel context.CancelFunc
ctx context.Context
wg sync.WaitGroup
tunnelConn net.Conn
useTunnel bool
mu sync.RWMutex
}
// NewForwarder 创建新的端口转发器
func NewForwarder(sourcePort int, targetIP string, targetPort int) *Forwarder {
ctx, cancel := context.WithCancel(context.Background())
return &Forwarder{
sourcePort: sourcePort,
targetAddr: fmt.Sprintf("%s:%d", targetIP, targetPort),
cancel: cancel,
ctx: ctx,
useTunnel: false,
}
}
// NewTunnelForwarder 创建使用隧道的端口转发器
func NewTunnelForwarder(sourcePort int, targetPort int, tunnelConn net.Conn) *Forwarder {
ctx, cancel := context.WithCancel(context.Background())
return &Forwarder{
sourcePort: sourcePort,
targetAddr: fmt.Sprintf("127.0.0.1:%d", targetPort),
tunnelConn: tunnelConn,
useTunnel: true,
cancel: cancel,
ctx: ctx,
}
}
// Start 启动端口转发
func (f *Forwarder) Start() error {
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", f.sourcePort))
if err != nil {
return fmt.Errorf("监听端口 %d 失败: %w", f.sourcePort, err)
}
f.listener = listener
log.Printf("端口转发启动: %d -> %s (tunnel: %v)", f.sourcePort, f.targetAddr, f.useTunnel)
f.wg.Add(1)
go f.acceptLoop()
return nil
}
// acceptLoop 接受连接循环
func (f *Forwarder) acceptLoop() {
defer f.wg.Done()
for {
select {
case <-f.ctx.Done():
return
default:
}
// 设置接受超时,避免阻塞关闭
f.listener.(*net.TCPListener).SetDeadline(time.Now().Add(time.Second))
conn, err := f.listener.Accept()
if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
continue
}
select {
case <-f.ctx.Done():
return
default:
log.Printf("接受连接失败 (端口 %d): %v", f.sourcePort, err)
continue
}
}
f.wg.Add(1)
go f.handleConnection(conn)
}
}
// handleConnection 处理单个连接
func (f *Forwarder) handleConnection(clientConn net.Conn) {
defer f.wg.Done()
defer clientConn.Close()
var targetConn net.Conn
var err error
if f.useTunnel {
// 使用隧道连接
f.mu.RLock()
targetConn = f.tunnelConn
f.mu.RUnlock()
if targetConn == nil {
log.Printf("隧道连接不可用 (端口 %d)", f.sourcePort)
return
}
} else {
// 直接连接目标
dialer := &net.Dialer{
Timeout: 10 * time.Second,
KeepAlive: 30 * time.Second,
}
targetConn, err = dialer.DialContext(f.ctx, "tcp", f.targetAddr)
if err != nil {
log.Printf("连接目标失败 (端口 %d -> %s): %v", f.sourcePort, f.targetAddr, err)
return
}
defer targetConn.Close()
}
// 双向转发
errChan := make(chan error, 2)
// 客户端 -> 目标
go func() {
_, err := io.Copy(targetConn, clientConn)
errChan <- err
}()
// 目标 -> 客户端
go func() {
_, err := io.Copy(clientConn, targetConn)
errChan <- err
}()
// 等待任一方向完成或出错
select {
case <-errChan:
// 连接已关闭或出错
case <-f.ctx.Done():
// 转发器被停止
}
}
// Stop 停止端口转发
func (f *Forwarder) Stop() error {
f.cancel()
if f.listener != nil {
if err := f.listener.Close(); err != nil {
log.Printf("关闭监听器失败 (端口 %d): %v", f.sourcePort, err)
}
}
// 等待所有连接处理完成最多等待5秒
done := make(chan struct{})
go func() {
f.wg.Wait()
close(done)
}()
select {
case <-done:
log.Printf("端口转发已停止: %d", f.sourcePort)
case <-time.After(5 * time.Second):
log.Printf("端口转发停止超时: %d", f.sourcePort)
}
return nil
}
// SetTunnelConn 设置隧道连接
func (f *Forwarder) SetTunnelConn(conn net.Conn) {
f.mu.Lock()
defer f.mu.Unlock()
f.tunnelConn = conn
}
// Manager 转发器管理器
type Manager struct {
forwarders map[int]*Forwarder
mu sync.RWMutex
}
// NewManager 创建新的转发器管理器
func NewManager() *Manager {
return &Manager{
forwarders: make(map[int]*Forwarder),
}
}
// Add 添加并启动转发器
func (m *Manager) Add(sourcePort int, targetIP string, targetPort int) error {
m.mu.Lock()
defer m.mu.Unlock()
if _, exists := m.forwarders[sourcePort]; exists {
return fmt.Errorf("端口 %d 已被占用", sourcePort)
}
forwarder := NewForwarder(sourcePort, targetIP, targetPort)
if err := forwarder.Start(); err != nil {
return err
}
m.forwarders[sourcePort] = forwarder
return nil
}
// AddTunnel 添加使用隧道的转发器
func (m *Manager) AddTunnel(sourcePort int, targetPort int, tunnelConn net.Conn) error {
m.mu.Lock()
defer m.mu.Unlock()
if _, exists := m.forwarders[sourcePort]; exists {
return fmt.Errorf("端口 %d 已被占用", sourcePort)
}
forwarder := NewTunnelForwarder(sourcePort, targetPort, tunnelConn)
if err := forwarder.Start(); err != nil {
return err
}
m.forwarders[sourcePort] = forwarder
return nil
}
// Remove 移除并停止转发器
func (m *Manager) Remove(sourcePort int) error {
m.mu.Lock()
defer m.mu.Unlock()
forwarder, exists := m.forwarders[sourcePort]
if !exists {
return fmt.Errorf("端口 %d 的转发器不存在", sourcePort)
}
if err := forwarder.Stop(); err != nil {
return err
}
delete(m.forwarders, sourcePort)
return nil
}
// Exists 检查转发器是否存在
func (m *Manager) Exists(sourcePort int) bool {
m.mu.RLock()
defer m.mu.RUnlock()
_, exists := m.forwarders[sourcePort]
return exists
}
// StopAll 停止所有转发器
func (m *Manager) StopAll() {
m.mu.Lock()
defer m.mu.Unlock()
for port, forwarder := range m.forwarders {
if err := forwarder.Stop(); err != nil {
log.Printf("停止端口 %d 的转发器失败: %v", port, err)
}
}
m.forwarders = make(map[int]*Forwarder)
}

119
src/server/main.go Normal file
View File

@ -0,0 +1,119 @@
package main
import (
"context"
"flag"
"log"
"os"
"os/signal"
"port-forward/server/api"
"port-forward/server/config"
"port-forward/server/db"
"port-forward/server/forwarder"
"port-forward/server/tunnel"
"syscall"
"time"
)
func main() {
// 解析命令行参数
configPath := flag.String("config", "config.yaml", "配置文件路径")
flag.Parse()
// 加载配置
log.Println("加载配置文件...")
cfg, err := config.Load(*configPath)
if err != nil {
log.Fatalf("加载配置失败: %v", err)
}
// 初始化数据库
log.Println("初始化数据库...")
database, err := db.New(cfg.Database.Path)
if err != nil {
log.Fatalf("初始化数据库失败: %v", err)
}
defer database.Close()
// 创建转发器管理器
log.Println("创建转发器管理器...")
fwdManager := forwarder.NewManager()
// 如果启用隧道,启动隧道服务器
var tunnelServer *tunnel.Server
if cfg.Tunnel.Enabled {
log.Println("启动隧道服务器...")
tunnelServer = tunnel.NewServer(cfg.Tunnel.ListenPort)
if err := tunnelServer.Start(); err != nil {
log.Fatalf("启动隧道服务器失败: %v", err)
}
defer tunnelServer.Stop()
}
// 从数据库加载现有映射并启动转发器
log.Println("加载现有端口映射...")
mappings, err := database.GetAllMappings()
if err != nil {
log.Fatalf("加载端口映射失败: %v", err)
}
for _, mapping := range mappings {
// 验证端口在范围内
if mapping.SourcePort < cfg.PortRange.From || mapping.SourcePort > cfg.PortRange.End {
log.Printf("警告: 端口 %d 超出范围,跳过", mapping.SourcePort)
continue
}
log.Printf("恢复端口映射: %d -> %s:%d", mapping.SourcePort, mapping.TargetIP, mapping.TargetPort)
if err := fwdManager.Add(mapping.SourcePort, mapping.TargetIP, mapping.TargetPort); err != nil {
log.Printf("警告: 启动端口 %d 的转发失败: %v", mapping.SourcePort, err)
}
}
log.Printf("成功加载 %d 个端口映射", len(mappings))
// 创建 HTTP API 处理器
log.Println("初始化 HTTP API...")
apiHandler := api.NewHandler(
database,
fwdManager,
tunnelServer,
cfg.PortRange.From,
cfg.PortRange.End,
cfg.Tunnel.Enabled,
)
// 启动 HTTP API 服务器
go func() {
if err := api.Start(apiHandler, cfg.API.ListenPort); err != nil {
log.Fatalf("启动 HTTP API 服务失败: %v", err)
}
}()
log.Println("===========================================")
log.Printf("服务器启动成功!")
log.Printf("端口范围: %d-%d", cfg.PortRange.From, cfg.PortRange.End)
log.Printf("HTTP API: http://localhost:%d", cfg.API.ListenPort)
if cfg.Tunnel.Enabled {
log.Printf("隧道服务: 端口 %d", cfg.Tunnel.ListenPort)
}
log.Println("===========================================")
// 等待中断信号
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
<-sigChan
log.Println("\n接收到关闭信号正在优雅关闭...")
// 创建关闭上下文
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
// 停止所有转发器
log.Println("停止所有端口转发...")
fwdManager.StopAll()
log.Println("服务器已关闭")
<-ctx.Done()
}

320
src/server/tunnel/tunnel.go Normal file
View File

@ -0,0 +1,320 @@
package tunnel
import (
"context"
"encoding/binary"
"fmt"
"io"
"log"
"net"
"sync"
"time"
)
// Protocol 定义隧道协议
// 消息格式: [4字节长度][4字节端口][数据]
const (
// HeaderSize 消息头大小(长度+端口)
HeaderSize = 8
// MaxPacketSize 最大包大小 (1MB)
MaxPacketSize = 1024 * 1024
)
// Server 内网穿透服务器
type Server struct {
listenPort int
listener net.Listener
client net.Conn
cancel context.CancelFunc
ctx context.Context
wg sync.WaitGroup
mu sync.RWMutex
// 连接管理
connections map[uint32]*Connection
connMu sync.RWMutex
nextConnID uint32
}
// Connection 表示一个客户端连接
type Connection struct {
ID uint32
TargetPort int
ClientConn net.Conn
readChan chan []byte
writeChan chan []byte
closeChan chan struct{}
}
// NewServer 创建新的隧道服务器
func NewServer(listenPort int) *Server {
ctx, cancel := context.WithCancel(context.Background())
return &Server{
listenPort: listenPort,
cancel: cancel,
ctx: ctx,
connections: make(map[uint32]*Connection),
}
}
// Start 启动隧道服务器
func (s *Server) Start() error {
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", s.listenPort))
if err != nil {
return fmt.Errorf("启动隧道服务器失败: %w", err)
}
s.listener = listener
log.Printf("隧道服务器启动: 端口 %d", s.listenPort)
s.wg.Add(1)
go s.acceptLoop()
return nil
}
// acceptLoop 接受客户端连接
func (s *Server) acceptLoop() {
defer s.wg.Done()
for {
select {
case <-s.ctx.Done():
return
default:
}
s.listener.(*net.TCPListener).SetDeadline(time.Now().Add(time.Second))
conn, err := s.listener.Accept()
if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
continue
}
select {
case <-s.ctx.Done():
return
default:
log.Printf("接受隧道连接失败: %v", err)
continue
}
}
// 只允许一个客户端连接
s.mu.Lock()
if s.client != nil {
log.Printf("拒绝额外的隧道连接: %s", conn.RemoteAddr())
conn.Close()
s.mu.Unlock()
continue
}
s.client = conn
s.mu.Unlock()
log.Printf("隧道客户端已连接: %s", conn.RemoteAddr())
s.wg.Add(1)
go s.handleClient(conn)
}
}
// handleClient 处理客户端连接
func (s *Server) handleClient(conn net.Conn) {
defer s.wg.Done()
defer func() {
conn.Close()
s.mu.Lock()
s.client = nil
s.mu.Unlock()
log.Printf("隧道客户端已断开")
// 关闭所有活动连接
s.connMu.Lock()
for _, c := range s.connections {
close(c.closeChan)
}
s.connections = make(map[uint32]*Connection)
s.connMu.Unlock()
}()
// 读取来自客户端的数据
for {
select {
case <-s.ctx.Done():
return
default:
}
// 读取消息头
header := make([]byte, HeaderSize)
if _, err := io.ReadFull(conn, header); err != nil {
if err != io.EOF {
log.Printf("读取隧道消息头失败: %v", err)
}
return
}
dataLen := binary.BigEndian.Uint32(header[0:4])
connID := binary.BigEndian.Uint32(header[4:8])
if dataLen > MaxPacketSize {
log.Printf("数据包过大: %d bytes", dataLen)
return
}
// 读取数据
data := make([]byte, dataLen)
if _, err := io.ReadFull(conn, data); err != nil {
log.Printf("读取隧道数据失败: %v", err)
return
}
// 将数据发送到对应的连接
s.connMu.RLock()
connection, exists := s.connections[connID]
s.connMu.RUnlock()
if exists {
select {
case connection.readChan <- data:
case <-connection.closeChan:
case <-time.After(5 * time.Second):
log.Printf("向连接 %d 发送数据超时", connID)
}
}
}
}
// ForwardConnection 转发连接到隧道
func (s *Server) ForwardConnection(clientConn net.Conn, targetPort int) error {
s.mu.RLock()
tunnelConn := s.client
s.mu.RUnlock()
if tunnelConn == nil {
return fmt.Errorf("隧道连接不可用")
}
// 创建连接对象
s.connMu.Lock()
connID := s.nextConnID
s.nextConnID++
connection := &Connection{
ID: connID,
TargetPort: targetPort,
ClientConn: clientConn,
readChan: make(chan []byte, 100),
writeChan: make(chan []byte, 100),
closeChan: make(chan struct{}),
}
s.connections[connID] = connection
s.connMu.Unlock()
defer func() {
s.connMu.Lock()
delete(s.connections, connID)
s.connMu.Unlock()
close(connection.closeChan)
clientConn.Close()
}()
// 启动读写协程
errChan := make(chan error, 2)
// 从客户端读取并发送到隧道
go func() {
buffer := make([]byte, 32*1024)
for {
select {
case <-connection.closeChan:
return
default:
}
clientConn.SetReadDeadline(time.Now().Add(30 * time.Second))
n, err := clientConn.Read(buffer)
if err != nil {
errChan <- err
return
}
// 发送到隧道
data := make([]byte, HeaderSize+n)
binary.BigEndian.PutUint32(data[0:4], uint32(n))
binary.BigEndian.PutUint32(data[4:8], connID)
copy(data[HeaderSize:], buffer[:n])
s.mu.RLock()
_, err = tunnelConn.Write(data)
s.mu.RUnlock()
if err != nil {
errChan <- err
return
}
}
}()
// 从隧道读取并发送到客户端
go func() {
for {
select {
case data := <-connection.readChan:
if _, err := clientConn.Write(data); err != nil {
errChan <- err
return
}
case <-connection.closeChan:
return
}
}
}()
// 等待错误或关闭
select {
case <-errChan:
case <-connection.closeChan:
case <-s.ctx.Done():
}
return nil
}
// IsConnected 检查隧道是否已连接
func (s *Server) IsConnected() bool {
s.mu.RLock()
defer s.mu.RUnlock()
return s.client != nil
}
// Stop 停止隧道服务器
func (s *Server) Stop() error {
s.cancel()
if s.listener != nil {
s.listener.Close()
}
s.mu.Lock()
if s.client != nil {
s.client.Close()
}
s.mu.Unlock()
// 等待所有协程结束
done := make(chan struct{})
go func() {
s.wg.Wait()
close(done)
}()
select {
case <-done:
log.Printf("隧道服务器已停止")
case <-time.After(5 * time.Second):
log.Printf("隧道服务器停止超时")
}
return nil
}