init
This commit is contained in:
parent
b6baedeaf8
commit
6b02def8de
|
|
@ -5,10 +5,10 @@ bin/
|
|||
*.dll
|
||||
*.so
|
||||
*.dylib
|
||||
server
|
||||
client
|
||||
server-linux
|
||||
client-linux
|
||||
/server
|
||||
/client
|
||||
/server-linux
|
||||
/client-linux
|
||||
|
||||
# 测试文件
|
||||
*.test
|
||||
|
|
|
|||
|
|
@ -0,0 +1,48 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"port-forward/client/tunnel"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// 解析命令行参数
|
||||
serverAddr := flag.String("server", "localhost:9000", "隧道服务器地址 (host:port)")
|
||||
flag.Parse()
|
||||
|
||||
log.SetFlags(log.LstdFlags | log.Lshortfile)
|
||||
|
||||
// 创建隧道客户端
|
||||
log.Printf("隧道客户端启动...")
|
||||
log.Printf("服务器地址: %s", *serverAddr)
|
||||
|
||||
client := tunnel.NewClient(*serverAddr)
|
||||
|
||||
// 启动客户端
|
||||
if err := client.Start(); err != nil {
|
||||
log.Fatalf("启动隧道客户端失败: %v", err)
|
||||
}
|
||||
|
||||
log.Println("===========================================")
|
||||
log.Println("隧道客户端运行中...")
|
||||
log.Println("按 Ctrl+C 退出")
|
||||
log.Println("===========================================")
|
||||
|
||||
// 等待中断信号
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
|
||||
|
||||
<-sigChan
|
||||
log.Println("\n接收到关闭信号,正在关闭...")
|
||||
|
||||
// 停止客户端
|
||||
if err := client.Stop(); err != nil {
|
||||
log.Printf("停止客户端失败: %v", err)
|
||||
}
|
||||
|
||||
log.Println("客户端已关闭")
|
||||
}
|
||||
|
|
@ -0,0 +1,285 @@
|
|||
package tunnel
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
// HeaderSize 消息头大小
|
||||
HeaderSize = 8
|
||||
// MaxPacketSize 最大包大小
|
||||
MaxPacketSize = 1024 * 1024
|
||||
// ReconnectDelay 重连延迟
|
||||
ReconnectDelay = 5 * time.Second
|
||||
)
|
||||
|
||||
// Client 内网穿透客户端
|
||||
type Client struct {
|
||||
serverAddr string
|
||||
serverConn net.Conn
|
||||
cancel context.CancelFunc
|
||||
ctx context.Context
|
||||
wg sync.WaitGroup
|
||||
mu sync.RWMutex
|
||||
|
||||
// 连接管理
|
||||
connections map[uint32]*LocalConnection
|
||||
connMu sync.RWMutex
|
||||
}
|
||||
|
||||
// LocalConnection 本地连接
|
||||
type LocalConnection struct {
|
||||
ID uint32
|
||||
TargetAddr string
|
||||
Conn net.Conn
|
||||
closeChan chan struct{}
|
||||
}
|
||||
|
||||
// NewClient 创建新的隧道客户端
|
||||
func NewClient(serverAddr string) *Client {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &Client{
|
||||
serverAddr: serverAddr,
|
||||
cancel: cancel,
|
||||
ctx: ctx,
|
||||
connections: make(map[uint32]*LocalConnection),
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动隧道客户端
|
||||
func (c *Client) Start() error {
|
||||
log.Printf("正在连接到隧道服务器: %s", c.serverAddr)
|
||||
|
||||
c.wg.Add(1)
|
||||
go c.connectLoop()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// connectLoop 连接循环(支持自动重连)
|
||||
func (c *Client) connectLoop() {
|
||||
defer c.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
conn, err := net.DialTimeout("tcp", c.serverAddr, 10*time.Second)
|
||||
if err != nil {
|
||||
log.Printf("连接隧道服务器失败: %v,%v 后重试", err, ReconnectDelay)
|
||||
time.Sleep(ReconnectDelay)
|
||||
continue
|
||||
}
|
||||
|
||||
log.Printf("已连接到隧道服务器: %s", c.serverAddr)
|
||||
|
||||
c.mu.Lock()
|
||||
c.serverConn = conn
|
||||
c.mu.Unlock()
|
||||
|
||||
// 处理连接
|
||||
if err := c.handleServerConnection(conn); err != nil {
|
||||
if err != io.EOF {
|
||||
log.Printf("处理服务器连接出错: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
c.serverConn = nil
|
||||
c.mu.Unlock()
|
||||
|
||||
// 关闭所有本地连接
|
||||
c.connMu.Lock()
|
||||
for _, conn := range c.connections {
|
||||
close(conn.closeChan)
|
||||
if conn.Conn != nil {
|
||||
conn.Conn.Close()
|
||||
}
|
||||
}
|
||||
c.connections = make(map[uint32]*LocalConnection)
|
||||
c.connMu.Unlock()
|
||||
|
||||
log.Printf("与服务器断开连接,%v 后重连", ReconnectDelay)
|
||||
time.Sleep(ReconnectDelay)
|
||||
}
|
||||
}
|
||||
|
||||
// handleServerConnection 处理服务器连接
|
||||
func (c *Client) handleServerConnection(conn net.Conn) error {
|
||||
for {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
|
||||
// 读取消息头
|
||||
header := make([]byte, HeaderSize)
|
||||
if _, err := io.ReadFull(conn, header); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dataLen := binary.BigEndian.Uint32(header[0:4])
|
||||
connID := binary.BigEndian.Uint32(header[4:8])
|
||||
|
||||
if dataLen > MaxPacketSize {
|
||||
return fmt.Errorf("数据包过大: %d bytes", dataLen)
|
||||
}
|
||||
|
||||
// 读取数据
|
||||
data := make([]byte, dataLen)
|
||||
if _, err := io.ReadFull(conn, data); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 处理数据
|
||||
c.handleData(connID, data)
|
||||
}
|
||||
}
|
||||
|
||||
// handleData 处理接收到的数据
|
||||
func (c *Client) handleData(connID uint32, data []byte) {
|
||||
c.connMu.Lock()
|
||||
localConn, exists := c.connections[connID]
|
||||
|
||||
if !exists {
|
||||
// 新连接,需要建立到本地服务的连接
|
||||
// 从数据中解析目标端口(这里简化处理,实际应该从协议中获取)
|
||||
localConn = &LocalConnection{
|
||||
ID: connID,
|
||||
closeChan: make(chan struct{}),
|
||||
}
|
||||
c.connections[connID] = localConn
|
||||
c.connMu.Unlock()
|
||||
|
||||
// 启动本地连接处理
|
||||
c.wg.Add(1)
|
||||
go c.handleLocalConnection(localConn)
|
||||
|
||||
// 重新获取锁并发送数据
|
||||
c.connMu.Lock()
|
||||
}
|
||||
c.connMu.Unlock()
|
||||
|
||||
// 发送数据到本地连接
|
||||
if localConn.Conn != nil {
|
||||
localConn.Conn.Write(data)
|
||||
}
|
||||
}
|
||||
|
||||
// handleLocalConnection 处理本地连接
|
||||
func (c *Client) handleLocalConnection(localConn *LocalConnection) {
|
||||
defer c.wg.Done()
|
||||
defer func() {
|
||||
c.connMu.Lock()
|
||||
delete(c.connections, localConn.ID)
|
||||
c.connMu.Unlock()
|
||||
|
||||
if localConn.Conn != nil {
|
||||
localConn.Conn.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
// 连接到本地目标服务
|
||||
// 这里使用固定的本地地址,实际应该根据映射配置
|
||||
targetAddr := localConn.TargetAddr
|
||||
if targetAddr == "" {
|
||||
targetAddr = "127.0.0.1:22" // 默认 SSH
|
||||
}
|
||||
|
||||
conn, err := net.DialTimeout("tcp", targetAddr, 5*time.Second)
|
||||
if err != nil {
|
||||
log.Printf("连接本地服务失败 (连接 %d -> %s): %v", localConn.ID, targetAddr, err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
localConn.Conn = conn
|
||||
log.Printf("建立本地连接: %d -> %s", localConn.ID, targetAddr)
|
||||
|
||||
// 从本地服务读取数据并发送到服务器
|
||||
buffer := make([]byte, 32*1024)
|
||||
for {
|
||||
select {
|
||||
case <-localConn.closeChan:
|
||||
return
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
conn.SetReadDeadline(time.Now().Add(30 * time.Second))
|
||||
n, err := conn.Read(buffer)
|
||||
if err != nil {
|
||||
if err != io.EOF && !isTimeout(err) {
|
||||
log.Printf("读取本地连接失败 (连接 %d): %v", localConn.ID, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 发送到服务器
|
||||
c.mu.RLock()
|
||||
serverConn := c.serverConn
|
||||
c.mu.RUnlock()
|
||||
|
||||
if serverConn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
data := make([]byte, HeaderSize+n)
|
||||
binary.BigEndian.PutUint32(data[0:4], uint32(n))
|
||||
binary.BigEndian.PutUint32(data[4:8], localConn.ID)
|
||||
copy(data[HeaderSize:], buffer[:n])
|
||||
|
||||
if _, err := serverConn.Write(data); err != nil {
|
||||
log.Printf("发送数据到服务器失败 (连接 %d): %v", localConn.ID, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stop 停止隧道客户端
|
||||
func (c *Client) Stop() error {
|
||||
log.Println("正在停止隧道客户端...")
|
||||
c.cancel()
|
||||
|
||||
c.mu.Lock()
|
||||
if c.serverConn != nil {
|
||||
c.serverConn.Close()
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
// 等待所有协程结束
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
c.wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
log.Println("隧道客户端已停止")
|
||||
case <-time.After(5 * time.Second):
|
||||
log.Println("隧道客户端停止超时")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isTimeout 检查是否为超时错误
|
||||
func isTimeout(err error) bool {
|
||||
if netErr, ok := err.(net.Error); ok {
|
||||
return netErr.Timeout()
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
|
@ -0,0 +1,247 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"port-forward/server/db"
|
||||
"port-forward/server/forwarder"
|
||||
"port-forward/server/tunnel"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Handler HTTP API 处理器
|
||||
type Handler struct {
|
||||
db *db.Database
|
||||
forwarderMgr *forwarder.Manager
|
||||
tunnelServer *tunnel.Server
|
||||
portRangeFrom int
|
||||
portRangeEnd int
|
||||
useTunnel bool
|
||||
}
|
||||
|
||||
// NewHandler 创建新的 API 处理器
|
||||
func NewHandler(database *db.Database, fwdMgr *forwarder.Manager, ts *tunnel.Server, portFrom, portEnd int, useTunnel bool) *Handler {
|
||||
return &Handler{
|
||||
db: database,
|
||||
forwarderMgr: fwdMgr,
|
||||
tunnelServer: ts,
|
||||
portRangeFrom: portFrom,
|
||||
portRangeEnd: portEnd,
|
||||
useTunnel: useTunnel,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateMappingRequest 创建映射请求
|
||||
type CreateMappingRequest struct {
|
||||
Port int `json:"port"` // 源端口和目标端口(相同)
|
||||
TargetIP string `json:"target_ip"` // 目标 IP(非隧道模式使用)
|
||||
}
|
||||
|
||||
// RemoveMappingRequest 删除映射请求
|
||||
type RemoveMappingRequest struct {
|
||||
Port int `json:"port"`
|
||||
}
|
||||
|
||||
// Response 统一响应格式
|
||||
type Response struct {
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message"`
|
||||
Data interface{} `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func (h *Handler) RegisterRoutes(mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/mapping/create", h.handleCreateMapping)
|
||||
mux.HandleFunc("/api/mapping/remove", h.handleRemoveMapping)
|
||||
mux.HandleFunc("/api/mapping/list", h.handleListMappings)
|
||||
mux.HandleFunc("/health", h.handleHealth)
|
||||
}
|
||||
|
||||
// handleCreateMapping 处理创建映射请求
|
||||
func (h *Handler) handleCreateMapping(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
h.writeError(w, http.StatusMethodNotAllowed, "只支持 POST 方法")
|
||||
return
|
||||
}
|
||||
|
||||
var req CreateMappingRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
h.writeError(w, http.StatusBadRequest, "请求格式错误: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 验证端口范围
|
||||
if req.Port < h.portRangeFrom || req.Port > h.portRangeEnd {
|
||||
h.writeError(w, http.StatusBadRequest, fmt.Sprintf("端口必须在 %d-%d 范围内", h.portRangeFrom, h.portRangeEnd))
|
||||
return
|
||||
}
|
||||
|
||||
// 检查端口是否已被使用
|
||||
if h.forwarderMgr.Exists(req.Port) {
|
||||
h.writeError(w, http.StatusConflict, "端口已被占用")
|
||||
return
|
||||
}
|
||||
|
||||
// 非隧道模式需要验证 IP
|
||||
if !h.useTunnel {
|
||||
if req.TargetIP == "" {
|
||||
h.writeError(w, http.StatusBadRequest, "target_ip 不能为空")
|
||||
return
|
||||
}
|
||||
if net.ParseIP(req.TargetIP) == nil {
|
||||
h.writeError(w, http.StatusBadRequest, "target_ip 格式无效")
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// 隧道模式,检查隧道是否连接
|
||||
if !h.tunnelServer.IsConnected() {
|
||||
h.writeError(w, http.StatusServiceUnavailable, "隧道未连接")
|
||||
return
|
||||
}
|
||||
// 隧道模式使用本地地址
|
||||
req.TargetIP = "127.0.0.1"
|
||||
}
|
||||
|
||||
// 添加到数据库
|
||||
if err := h.db.AddMapping(req.Port, req.TargetIP, req.Port); err != nil {
|
||||
h.writeError(w, http.StatusInternalServerError, "保存映射失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 启动转发器
|
||||
if err := h.forwarderMgr.Add(req.Port, req.TargetIP, req.Port); err != nil {
|
||||
// 回滚数据库操作
|
||||
h.db.RemoveMapping(req.Port)
|
||||
h.writeError(w, http.StatusInternalServerError, "启动转发失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("创建端口映射: %d -> %s:%d", req.Port, req.TargetIP, req.Port)
|
||||
|
||||
h.writeSuccess(w, "端口映射创建成功", map[string]interface{}{
|
||||
"port": req.Port,
|
||||
"target_ip": req.TargetIP,
|
||||
"use_tunnel": h.useTunnel,
|
||||
})
|
||||
}
|
||||
|
||||
// handleRemoveMapping 处理删除映射请求
|
||||
func (h *Handler) handleRemoveMapping(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
h.writeError(w, http.StatusMethodNotAllowed, "只支持 POST 方法")
|
||||
return
|
||||
}
|
||||
|
||||
var req RemoveMappingRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
h.writeError(w, http.StatusBadRequest, "请求格式错误: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 验证端口范围
|
||||
if req.Port < h.portRangeFrom || req.Port > h.portRangeEnd {
|
||||
h.writeError(w, http.StatusBadRequest, fmt.Sprintf("端口必须在 %d-%d 范围内", h.portRangeFrom, h.portRangeEnd))
|
||||
return
|
||||
}
|
||||
|
||||
// 检查映射是否存在
|
||||
if !h.forwarderMgr.Exists(req.Port) {
|
||||
h.writeError(w, http.StatusNotFound, "端口映射不存在")
|
||||
return
|
||||
}
|
||||
|
||||
// 停止转发器
|
||||
if err := h.forwarderMgr.Remove(req.Port); err != nil {
|
||||
h.writeError(w, http.StatusInternalServerError, "停止转发失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 从数据库删除
|
||||
if err := h.db.RemoveMapping(req.Port); err != nil {
|
||||
log.Printf("从数据库删除映射失败 (端口 %d): %v", req.Port, err)
|
||||
// 即使数据库删除失败,转发器已经停止,仍然返回成功
|
||||
}
|
||||
|
||||
log.Printf("删除端口映射: %d", req.Port)
|
||||
|
||||
h.writeSuccess(w, "端口映射删除成功", map[string]interface{}{
|
||||
"port": req.Port,
|
||||
})
|
||||
}
|
||||
|
||||
// handleListMappings 处理列出所有映射请求
|
||||
func (h *Handler) handleListMappings(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
h.writeError(w, http.StatusMethodNotAllowed, "只支持 GET 方法")
|
||||
return
|
||||
}
|
||||
|
||||
mappings, err := h.db.GetAllMappings()
|
||||
if err != nil {
|
||||
h.writeError(w, http.StatusInternalServerError, "获取映射列表失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.writeSuccess(w, "获取映射列表成功", map[string]interface{}{
|
||||
"mappings": mappings,
|
||||
"count": len(mappings),
|
||||
"use_tunnel": h.useTunnel,
|
||||
})
|
||||
}
|
||||
|
||||
// handleHealth 健康检查
|
||||
func (h *Handler) handleHealth(w http.ResponseWriter, r *http.Request) {
|
||||
status := map[string]interface{}{
|
||||
"status": "ok",
|
||||
"tunnel_enabled": h.useTunnel,
|
||||
"tunnel_connected": false,
|
||||
}
|
||||
|
||||
if h.useTunnel {
|
||||
status["tunnel_connected"] = h.tunnelServer.IsConnected()
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(status)
|
||||
}
|
||||
|
||||
// writeSuccess 写入成功响应
|
||||
func (h *Handler) writeSuccess(w http.ResponseWriter, message string, data interface{}) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(Response{
|
||||
Success: true,
|
||||
Message: message,
|
||||
Data: data,
|
||||
})
|
||||
}
|
||||
|
||||
// writeError 写入错误响应
|
||||
func (h *Handler) writeError(w http.ResponseWriter, statusCode int, message string) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(statusCode)
|
||||
json.NewEncoder(w).Encode(Response{
|
||||
Success: false,
|
||||
Message: message,
|
||||
})
|
||||
}
|
||||
|
||||
// Start 启动 HTTP 服务器
|
||||
func Start(handler *Handler, port int) error {
|
||||
mux := http.NewServeMux()
|
||||
handler.RegisterRoutes(mux)
|
||||
|
||||
server := &http.Server{
|
||||
Addr: ":" + strconv.Itoa(port),
|
||||
Handler: mux,
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
log.Printf("HTTP API 服务启动: 端口 %d", port)
|
||||
return server.ListenAndServe()
|
||||
}
|
||||
|
|
@ -0,0 +1,81 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// Config 应用配置结构
|
||||
type Config struct {
|
||||
PortRange PortRangeConfig `yaml:"port_range"`
|
||||
Tunnel TunnelConfig `yaml:"tunnel"`
|
||||
API APIConfig `yaml:"api"`
|
||||
Database DatabaseConfig `yaml:"database"`
|
||||
}
|
||||
|
||||
// PortRangeConfig 端口范围配置
|
||||
type PortRangeConfig struct {
|
||||
From int `yaml:"from"`
|
||||
End int `yaml:"end"`
|
||||
}
|
||||
|
||||
// TunnelConfig 内网穿透配置
|
||||
type TunnelConfig struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
ListenPort int `yaml:"listen_port"`
|
||||
}
|
||||
|
||||
// APIConfig HTTP API 配置
|
||||
type APIConfig struct {
|
||||
ListenPort int `yaml:"listen_port"`
|
||||
}
|
||||
|
||||
// DatabaseConfig 数据库配置
|
||||
type DatabaseConfig struct {
|
||||
Path string `yaml:"path"`
|
||||
}
|
||||
|
||||
// Load 从文件加载配置
|
||||
func Load(path string) (*Config, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取配置文件失败: %w", err)
|
||||
}
|
||||
|
||||
var config Config
|
||||
if err := yaml.Unmarshal(data, &config); err != nil {
|
||||
return nil, fmt.Errorf("解析配置文件失败: %w", err)
|
||||
}
|
||||
|
||||
// 验证配置
|
||||
if err := config.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("配置验证失败: %w", err)
|
||||
}
|
||||
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
// Validate 验证配置的有效性
|
||||
func (c *Config) Validate() error {
|
||||
if c.PortRange.From <= 0 || c.PortRange.End <= 0 {
|
||||
return fmt.Errorf("端口范围必须大于 0")
|
||||
}
|
||||
if c.PortRange.From > c.PortRange.End {
|
||||
return fmt.Errorf("起始端口不能大于结束端口")
|
||||
}
|
||||
if c.PortRange.End-c.PortRange.From > 10000 {
|
||||
return fmt.Errorf("端口范围过大,最多支持 10000 个端口")
|
||||
}
|
||||
if c.Tunnel.Enabled && c.Tunnel.ListenPort <= 0 {
|
||||
return fmt.Errorf("内网穿透端口必须大于 0")
|
||||
}
|
||||
if c.API.ListenPort <= 0 {
|
||||
return fmt.Errorf("API 端口必须大于 0")
|
||||
}
|
||||
if c.Database.Path == "" {
|
||||
return fmt.Errorf("数据库路径不能为空")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
@ -0,0 +1,136 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoadConfig(t *testing.T) {
|
||||
// 创建临时配置文件
|
||||
configContent := `
|
||||
port_range:
|
||||
from: 10000
|
||||
end: 10100
|
||||
|
||||
tunnel:
|
||||
enabled: true
|
||||
listen_port: 9000
|
||||
|
||||
api:
|
||||
listen_port: 8080
|
||||
|
||||
database:
|
||||
path: "./data/mappings.db"
|
||||
`
|
||||
tmpFile, err := os.CreateTemp("", "config_test_*.yaml")
|
||||
if err != nil {
|
||||
t.Fatalf("创建临时文件失败: %v", err)
|
||||
}
|
||||
defer os.Remove(tmpFile.Name())
|
||||
|
||||
if _, err := tmpFile.Write([]byte(configContent)); err != nil {
|
||||
t.Fatalf("写入配置文件失败: %v", err)
|
||||
}
|
||||
tmpFile.Close()
|
||||
|
||||
// 加载配置
|
||||
cfg, err := Load(tmpFile.Name())
|
||||
if err != nil {
|
||||
t.Fatalf("加载配置失败: %v", err)
|
||||
}
|
||||
|
||||
// 验证配置
|
||||
if cfg.PortRange.From != 10000 {
|
||||
t.Errorf("期望起始端口为 10000,得到 %d", cfg.PortRange.From)
|
||||
}
|
||||
if cfg.PortRange.End != 10100 {
|
||||
t.Errorf("期望结束端口为 10100,得到 %d", cfg.PortRange.End)
|
||||
}
|
||||
if !cfg.Tunnel.Enabled {
|
||||
t.Error("期望隧道启用")
|
||||
}
|
||||
if cfg.Tunnel.ListenPort != 9000 {
|
||||
t.Errorf("期望隧道端口为 9000,得到 %d", cfg.Tunnel.ListenPort)
|
||||
}
|
||||
if cfg.API.ListenPort != 8080 {
|
||||
t.Errorf("期望 API 端口为 8080,得到 %d", cfg.API.ListenPort)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config Config
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "有效配置",
|
||||
config: Config{
|
||||
PortRange: PortRangeConfig{From: 10000, End: 10100},
|
||||
Tunnel: TunnelConfig{Enabled: true, ListenPort: 9000},
|
||||
API: APIConfig{ListenPort: 8080},
|
||||
Database: DatabaseConfig{Path: "./data/test.db"},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "无效端口范围 - 起始端口为0",
|
||||
config: Config{
|
||||
PortRange: PortRangeConfig{From: 0, End: 10100},
|
||||
Tunnel: TunnelConfig{Enabled: false, ListenPort: 0},
|
||||
API: APIConfig{ListenPort: 8080},
|
||||
Database: DatabaseConfig{Path: "./data/test.db"},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "无效端口范围 - 起始大于结束",
|
||||
config: Config{
|
||||
PortRange: PortRangeConfig{From: 10100, End: 10000},
|
||||
Tunnel: TunnelConfig{Enabled: false, ListenPort: 0},
|
||||
API: APIConfig{ListenPort: 8080},
|
||||
Database: DatabaseConfig{Path: "./data/test.db"},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "端口范围过大",
|
||||
config: Config{
|
||||
PortRange: PortRangeConfig{From: 1, End: 20000},
|
||||
Tunnel: TunnelConfig{Enabled: false, ListenPort: 0},
|
||||
API: APIConfig{ListenPort: 8080},
|
||||
Database: DatabaseConfig{Path: "./data/test.db"},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "启用隧道但端口无效",
|
||||
config: Config{
|
||||
PortRange: PortRangeConfig{From: 10000, End: 10100},
|
||||
Tunnel: TunnelConfig{Enabled: true, ListenPort: 0},
|
||||
API: APIConfig{ListenPort: 8080},
|
||||
Database: DatabaseConfig{Path: "./data/test.db"},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "数据库路径为空",
|
||||
config: Config{
|
||||
PortRange: PortRangeConfig{From: 10000, End: 10100},
|
||||
Tunnel: TunnelConfig{Enabled: false, ListenPort: 0},
|
||||
API: APIConfig{ListenPort: 8080},
|
||||
Database: DatabaseConfig{Path: ""},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.config.Validate()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,178 @@
|
|||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
// Mapping 端口映射结构
|
||||
type Mapping struct {
|
||||
ID int64 `json:"id"`
|
||||
SourcePort int `json:"source_port"`
|
||||
TargetIP string `json:"target_ip"`
|
||||
TargetPort int `json:"target_port"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
}
|
||||
|
||||
// Database 数据库管理器
|
||||
type Database struct {
|
||||
db *sql.DB
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// New 创建新的数据库管理器
|
||||
func New(dbPath string) (*Database, error) {
|
||||
// 确保数据库目录存在
|
||||
dir := filepath.Dir(dbPath)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("创建数据库目录失败: %w", err)
|
||||
}
|
||||
|
||||
db, err := sql.Open("sqlite3", dbPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("打开数据库失败: %w", err)
|
||||
}
|
||||
|
||||
// 设置连接池参数
|
||||
db.SetMaxOpenConns(25)
|
||||
db.SetMaxIdleConns(5)
|
||||
|
||||
database := &Database{db: db}
|
||||
|
||||
// 初始化表结构
|
||||
if err := database.initTables(); err != nil {
|
||||
db.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return database, nil
|
||||
}
|
||||
|
||||
// initTables 初始化数据库表
|
||||
func (d *Database) initTables() error {
|
||||
query := `
|
||||
CREATE TABLE IF NOT EXISTS mappings (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
source_port INTEGER NOT NULL UNIQUE,
|
||||
target_ip TEXT NOT NULL,
|
||||
target_port INTEGER NOT NULL,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_source_port ON mappings(source_port);
|
||||
`
|
||||
|
||||
_, err := d.db.Exec(query)
|
||||
if err != nil {
|
||||
return fmt.Errorf("初始化数据库表失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddMapping 添加端口映射
|
||||
func (d *Database) AddMapping(sourcePort int, targetIP string, targetPort int) error {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
query := `INSERT INTO mappings (source_port, target_ip, target_port) VALUES (?, ?, ?)`
|
||||
_, err := d.db.Exec(query, sourcePort, targetIP, targetPort)
|
||||
if err != nil {
|
||||
return fmt.Errorf("添加端口映射失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveMapping 删除端口映射
|
||||
func (d *Database) RemoveMapping(sourcePort int) error {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
query := `DELETE FROM mappings WHERE source_port = ?`
|
||||
result, err := d.db.Exec(query, sourcePort)
|
||||
if err != nil {
|
||||
return fmt.Errorf("删除端口映射失败: %w", err)
|
||||
}
|
||||
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("获取影响行数失败: %w", err)
|
||||
}
|
||||
|
||||
if rows == 0 {
|
||||
return fmt.Errorf("端口映射不存在")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetMapping 获取指定端口的映射
|
||||
func (d *Database) GetMapping(sourcePort int) (*Mapping, error) {
|
||||
d.mu.RLock()
|
||||
defer d.mu.RUnlock()
|
||||
|
||||
query := `SELECT id, source_port, target_ip, target_port, created_at FROM mappings WHERE source_port = ?`
|
||||
|
||||
var mapping Mapping
|
||||
err := d.db.QueryRow(query, sourcePort).Scan(
|
||||
&mapping.ID,
|
||||
&mapping.SourcePort,
|
||||
&mapping.TargetIP,
|
||||
&mapping.TargetPort,
|
||||
&mapping.CreatedAt,
|
||||
)
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询端口映射失败: %w", err)
|
||||
}
|
||||
|
||||
return &mapping, nil
|
||||
}
|
||||
|
||||
// GetAllMappings 获取所有端口映射
|
||||
func (d *Database) GetAllMappings() ([]*Mapping, error) {
|
||||
d.mu.RLock()
|
||||
defer d.mu.RUnlock()
|
||||
|
||||
query := `SELECT id, source_port, target_ip, target_port, created_at FROM mappings ORDER BY source_port`
|
||||
|
||||
rows, err := d.db.Query(query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询所有映射失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var mappings []*Mapping
|
||||
for rows.Next() {
|
||||
var mapping Mapping
|
||||
if err := rows.Scan(
|
||||
&mapping.ID,
|
||||
&mapping.SourcePort,
|
||||
&mapping.TargetIP,
|
||||
&mapping.TargetPort,
|
||||
&mapping.CreatedAt,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("扫描映射记录失败: %w", err)
|
||||
}
|
||||
mappings = append(mappings, &mapping)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("遍历映射记录失败: %w", err)
|
||||
}
|
||||
|
||||
return mappings, nil
|
||||
}
|
||||
|
||||
// Close 关闭数据库连接
|
||||
func (d *Database) Close() error {
|
||||
return d.db.Close()
|
||||
}
|
||||
|
|
@ -0,0 +1,125 @@
|
|||
package db
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDatabase(t *testing.T) {
|
||||
// 使用临时数据库
|
||||
dbPath := "/tmp/test_mappings.db"
|
||||
defer os.Remove(dbPath)
|
||||
|
||||
// 创建数据库
|
||||
db, err := New(dbPath)
|
||||
if err != nil {
|
||||
t.Fatalf("创建数据库失败: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
t.Run("添加映射", func(t *testing.T) {
|
||||
err := db.AddMapping(10001, "192.168.1.100", 22)
|
||||
if err != nil {
|
||||
t.Errorf("添加映射失败: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("获取映射", func(t *testing.T) {
|
||||
mapping, err := db.GetMapping(10001)
|
||||
if err != nil {
|
||||
t.Errorf("获取映射失败: %v", err)
|
||||
}
|
||||
if mapping == nil {
|
||||
t.Error("映射不应该为空")
|
||||
}
|
||||
if mapping.SourcePort != 10001 {
|
||||
t.Errorf("期望源端口为 10001,得到 %d", mapping.SourcePort)
|
||||
}
|
||||
if mapping.TargetIP != "192.168.1.100" {
|
||||
t.Errorf("期望目标 IP 为 192.168.1.100,得到 %s", mapping.TargetIP)
|
||||
}
|
||||
if mapping.TargetPort != 22 {
|
||||
t.Errorf("期望目标端口为 22,得到 %d", mapping.TargetPort)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("添加重复映射应该失败", func(t *testing.T) {
|
||||
err := db.AddMapping(10001, "192.168.1.101", 22)
|
||||
if err == nil {
|
||||
t.Error("添加重复映射应该失败")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("获取所有映射", func(t *testing.T) {
|
||||
// 添加更多映射
|
||||
db.AddMapping(10002, "192.168.1.101", 22)
|
||||
db.AddMapping(10003, "192.168.1.102", 22)
|
||||
|
||||
mappings, err := db.GetAllMappings()
|
||||
if err != nil {
|
||||
t.Errorf("获取所有映射失败: %v", err)
|
||||
}
|
||||
if len(mappings) != 3 {
|
||||
t.Errorf("期望 3 个映射,得到 %d", len(mappings))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("删除映射", func(t *testing.T) {
|
||||
err := db.RemoveMapping(10001)
|
||||
if err != nil {
|
||||
t.Errorf("删除映射失败: %v", err)
|
||||
}
|
||||
|
||||
mapping, err := db.GetMapping(10001)
|
||||
if err != nil {
|
||||
t.Errorf("查询映射失败: %v", err)
|
||||
}
|
||||
if mapping != nil {
|
||||
t.Error("映射应该已被删除")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("删除不存在的映射应该失败", func(t *testing.T) {
|
||||
err := db.RemoveMapping(99999)
|
||||
if err == nil {
|
||||
t.Error("删除不存在的映射应该失败")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestDatabaseConcurrency(t *testing.T) {
|
||||
dbPath := "/tmp/test_concurrent.db"
|
||||
defer os.Remove(dbPath)
|
||||
|
||||
db, err := New(dbPath)
|
||||
if err != nil {
|
||||
t.Fatalf("创建数据库失败: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// 并发添加映射
|
||||
done := make(chan bool)
|
||||
for i := 0; i < 10; i++ {
|
||||
go func(port int) {
|
||||
err := db.AddMapping(10000+port, "192.168.1.100", port)
|
||||
if err != nil {
|
||||
t.Logf("添加映射失败 (端口 %d): %v", 10000+port, err)
|
||||
}
|
||||
done <- true
|
||||
}(i)
|
||||
}
|
||||
|
||||
// 等待所有操作完成
|
||||
for i := 0; i < 10; i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
// 验证映射数量
|
||||
mappings, err := db.GetAllMappings()
|
||||
if err != nil {
|
||||
t.Errorf("获取所有映射失败: %v", err)
|
||||
}
|
||||
if len(mappings) == 0 {
|
||||
t.Error("应该至少有一些映射")
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,278 @@
|
|||
package forwarder
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Forwarder 端口转发器
|
||||
type Forwarder struct {
|
||||
sourcePort int
|
||||
targetAddr string
|
||||
listener net.Listener
|
||||
cancel context.CancelFunc
|
||||
ctx context.Context
|
||||
wg sync.WaitGroup
|
||||
tunnelConn net.Conn
|
||||
useTunnel bool
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewForwarder 创建新的端口转发器
|
||||
func NewForwarder(sourcePort int, targetIP string, targetPort int) *Forwarder {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &Forwarder{
|
||||
sourcePort: sourcePort,
|
||||
targetAddr: fmt.Sprintf("%s:%d", targetIP, targetPort),
|
||||
cancel: cancel,
|
||||
ctx: ctx,
|
||||
useTunnel: false,
|
||||
}
|
||||
}
|
||||
|
||||
// NewTunnelForwarder 创建使用隧道的端口转发器
|
||||
func NewTunnelForwarder(sourcePort int, targetPort int, tunnelConn net.Conn) *Forwarder {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &Forwarder{
|
||||
sourcePort: sourcePort,
|
||||
targetAddr: fmt.Sprintf("127.0.0.1:%d", targetPort),
|
||||
tunnelConn: tunnelConn,
|
||||
useTunnel: true,
|
||||
cancel: cancel,
|
||||
ctx: ctx,
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动端口转发
|
||||
func (f *Forwarder) Start() error {
|
||||
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", f.sourcePort))
|
||||
if err != nil {
|
||||
return fmt.Errorf("监听端口 %d 失败: %w", f.sourcePort, err)
|
||||
}
|
||||
|
||||
f.listener = listener
|
||||
log.Printf("端口转发启动: %d -> %s (tunnel: %v)", f.sourcePort, f.targetAddr, f.useTunnel)
|
||||
|
||||
f.wg.Add(1)
|
||||
go f.acceptLoop()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// acceptLoop 接受连接循环
|
||||
func (f *Forwarder) acceptLoop() {
|
||||
defer f.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-f.ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// 设置接受超时,避免阻塞关闭
|
||||
f.listener.(*net.TCPListener).SetDeadline(time.Now().Add(time.Second))
|
||||
|
||||
conn, err := f.listener.Accept()
|
||||
if err != nil {
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case <-f.ctx.Done():
|
||||
return
|
||||
default:
|
||||
log.Printf("接受连接失败 (端口 %d): %v", f.sourcePort, err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
f.wg.Add(1)
|
||||
go f.handleConnection(conn)
|
||||
}
|
||||
}
|
||||
|
||||
// handleConnection 处理单个连接
|
||||
func (f *Forwarder) handleConnection(clientConn net.Conn) {
|
||||
defer f.wg.Done()
|
||||
defer clientConn.Close()
|
||||
|
||||
var targetConn net.Conn
|
||||
var err error
|
||||
|
||||
if f.useTunnel {
|
||||
// 使用隧道连接
|
||||
f.mu.RLock()
|
||||
targetConn = f.tunnelConn
|
||||
f.mu.RUnlock()
|
||||
|
||||
if targetConn == nil {
|
||||
log.Printf("隧道连接不可用 (端口 %d)", f.sourcePort)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// 直接连接目标
|
||||
dialer := &net.Dialer{
|
||||
Timeout: 10 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}
|
||||
|
||||
targetConn, err = dialer.DialContext(f.ctx, "tcp", f.targetAddr)
|
||||
if err != nil {
|
||||
log.Printf("连接目标失败 (端口 %d -> %s): %v", f.sourcePort, f.targetAddr, err)
|
||||
return
|
||||
}
|
||||
defer targetConn.Close()
|
||||
}
|
||||
|
||||
// 双向转发
|
||||
errChan := make(chan error, 2)
|
||||
|
||||
// 客户端 -> 目标
|
||||
go func() {
|
||||
_, err := io.Copy(targetConn, clientConn)
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
// 目标 -> 客户端
|
||||
go func() {
|
||||
_, err := io.Copy(clientConn, targetConn)
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
// 等待任一方向完成或出错
|
||||
select {
|
||||
case <-errChan:
|
||||
// 连接已关闭或出错
|
||||
case <-f.ctx.Done():
|
||||
// 转发器被停止
|
||||
}
|
||||
}
|
||||
|
||||
// Stop 停止端口转发
|
||||
func (f *Forwarder) Stop() error {
|
||||
f.cancel()
|
||||
|
||||
if f.listener != nil {
|
||||
if err := f.listener.Close(); err != nil {
|
||||
log.Printf("关闭监听器失败 (端口 %d): %v", f.sourcePort, err)
|
||||
}
|
||||
}
|
||||
|
||||
// 等待所有连接处理完成(最多等待5秒)
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
f.wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
log.Printf("端口转发已停止: %d", f.sourcePort)
|
||||
case <-time.After(5 * time.Second):
|
||||
log.Printf("端口转发停止超时: %d", f.sourcePort)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetTunnelConn 设置隧道连接
|
||||
func (f *Forwarder) SetTunnelConn(conn net.Conn) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.tunnelConn = conn
|
||||
}
|
||||
|
||||
// Manager 转发器管理器
|
||||
type Manager struct {
|
||||
forwarders map[int]*Forwarder
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewManager 创建新的转发器管理器
|
||||
func NewManager() *Manager {
|
||||
return &Manager{
|
||||
forwarders: make(map[int]*Forwarder),
|
||||
}
|
||||
}
|
||||
|
||||
// Add 添加并启动转发器
|
||||
func (m *Manager) Add(sourcePort int, targetIP string, targetPort int) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if _, exists := m.forwarders[sourcePort]; exists {
|
||||
return fmt.Errorf("端口 %d 已被占用", sourcePort)
|
||||
}
|
||||
|
||||
forwarder := NewForwarder(sourcePort, targetIP, targetPort)
|
||||
if err := forwarder.Start(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.forwarders[sourcePort] = forwarder
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddTunnel 添加使用隧道的转发器
|
||||
func (m *Manager) AddTunnel(sourcePort int, targetPort int, tunnelConn net.Conn) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if _, exists := m.forwarders[sourcePort]; exists {
|
||||
return fmt.Errorf("端口 %d 已被占用", sourcePort)
|
||||
}
|
||||
|
||||
forwarder := NewTunnelForwarder(sourcePort, targetPort, tunnelConn)
|
||||
if err := forwarder.Start(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.forwarders[sourcePort] = forwarder
|
||||
return nil
|
||||
}
|
||||
|
||||
// Remove 移除并停止转发器
|
||||
func (m *Manager) Remove(sourcePort int) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
forwarder, exists := m.forwarders[sourcePort]
|
||||
if !exists {
|
||||
return fmt.Errorf("端口 %d 的转发器不存在", sourcePort)
|
||||
}
|
||||
|
||||
if err := forwarder.Stop(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
delete(m.forwarders, sourcePort)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Exists 检查转发器是否存在
|
||||
func (m *Manager) Exists(sourcePort int) bool {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
_, exists := m.forwarders[sourcePort]
|
||||
return exists
|
||||
}
|
||||
|
||||
// StopAll 停止所有转发器
|
||||
func (m *Manager) StopAll() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
for port, forwarder := range m.forwarders {
|
||||
if err := forwarder.Stop(); err != nil {
|
||||
log.Printf("停止端口 %d 的转发器失败: %v", port, err)
|
||||
}
|
||||
}
|
||||
|
||||
m.forwarders = make(map[int]*Forwarder)
|
||||
}
|
||||
|
|
@ -0,0 +1,119 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"port-forward/server/api"
|
||||
"port-forward/server/config"
|
||||
"port-forward/server/db"
|
||||
"port-forward/server/forwarder"
|
||||
"port-forward/server/tunnel"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// 解析命令行参数
|
||||
configPath := flag.String("config", "config.yaml", "配置文件路径")
|
||||
flag.Parse()
|
||||
|
||||
// 加载配置
|
||||
log.Println("加载配置文件...")
|
||||
cfg, err := config.Load(*configPath)
|
||||
if err != nil {
|
||||
log.Fatalf("加载配置失败: %v", err)
|
||||
}
|
||||
|
||||
// 初始化数据库
|
||||
log.Println("初始化数据库...")
|
||||
database, err := db.New(cfg.Database.Path)
|
||||
if err != nil {
|
||||
log.Fatalf("初始化数据库失败: %v", err)
|
||||
}
|
||||
defer database.Close()
|
||||
|
||||
// 创建转发器管理器
|
||||
log.Println("创建转发器管理器...")
|
||||
fwdManager := forwarder.NewManager()
|
||||
|
||||
// 如果启用隧道,启动隧道服务器
|
||||
var tunnelServer *tunnel.Server
|
||||
if cfg.Tunnel.Enabled {
|
||||
log.Println("启动隧道服务器...")
|
||||
tunnelServer = tunnel.NewServer(cfg.Tunnel.ListenPort)
|
||||
if err := tunnelServer.Start(); err != nil {
|
||||
log.Fatalf("启动隧道服务器失败: %v", err)
|
||||
}
|
||||
defer tunnelServer.Stop()
|
||||
}
|
||||
|
||||
// 从数据库加载现有映射并启动转发器
|
||||
log.Println("加载现有端口映射...")
|
||||
mappings, err := database.GetAllMappings()
|
||||
if err != nil {
|
||||
log.Fatalf("加载端口映射失败: %v", err)
|
||||
}
|
||||
|
||||
for _, mapping := range mappings {
|
||||
// 验证端口在范围内
|
||||
if mapping.SourcePort < cfg.PortRange.From || mapping.SourcePort > cfg.PortRange.End {
|
||||
log.Printf("警告: 端口 %d 超出范围,跳过", mapping.SourcePort)
|
||||
continue
|
||||
}
|
||||
|
||||
log.Printf("恢复端口映射: %d -> %s:%d", mapping.SourcePort, mapping.TargetIP, mapping.TargetPort)
|
||||
if err := fwdManager.Add(mapping.SourcePort, mapping.TargetIP, mapping.TargetPort); err != nil {
|
||||
log.Printf("警告: 启动端口 %d 的转发失败: %v", mapping.SourcePort, err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("成功加载 %d 个端口映射", len(mappings))
|
||||
|
||||
// 创建 HTTP API 处理器
|
||||
log.Println("初始化 HTTP API...")
|
||||
apiHandler := api.NewHandler(
|
||||
database,
|
||||
fwdManager,
|
||||
tunnelServer,
|
||||
cfg.PortRange.From,
|
||||
cfg.PortRange.End,
|
||||
cfg.Tunnel.Enabled,
|
||||
)
|
||||
|
||||
// 启动 HTTP API 服务器
|
||||
go func() {
|
||||
if err := api.Start(apiHandler, cfg.API.ListenPort); err != nil {
|
||||
log.Fatalf("启动 HTTP API 服务失败: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
log.Println("===========================================")
|
||||
log.Printf("服务器启动成功!")
|
||||
log.Printf("端口范围: %d-%d", cfg.PortRange.From, cfg.PortRange.End)
|
||||
log.Printf("HTTP API: http://localhost:%d", cfg.API.ListenPort)
|
||||
if cfg.Tunnel.Enabled {
|
||||
log.Printf("隧道服务: 端口 %d", cfg.Tunnel.ListenPort)
|
||||
}
|
||||
log.Println("===========================================")
|
||||
|
||||
// 等待中断信号
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
|
||||
|
||||
<-sigChan
|
||||
log.Println("\n接收到关闭信号,正在优雅关闭...")
|
||||
|
||||
// 创建关闭上下文
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// 停止所有转发器
|
||||
log.Println("停止所有端口转发...")
|
||||
fwdManager.StopAll()
|
||||
|
||||
log.Println("服务器已关闭")
|
||||
<-ctx.Done()
|
||||
}
|
||||
|
|
@ -0,0 +1,320 @@
|
|||
package tunnel
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Protocol 定义隧道协议
|
||||
// 消息格式: [4字节长度][4字节端口][数据]
|
||||
|
||||
const (
|
||||
// HeaderSize 消息头大小(长度+端口)
|
||||
HeaderSize = 8
|
||||
// MaxPacketSize 最大包大小 (1MB)
|
||||
MaxPacketSize = 1024 * 1024
|
||||
)
|
||||
|
||||
// Server 内网穿透服务器
|
||||
type Server struct {
|
||||
listenPort int
|
||||
listener net.Listener
|
||||
client net.Conn
|
||||
cancel context.CancelFunc
|
||||
ctx context.Context
|
||||
wg sync.WaitGroup
|
||||
mu sync.RWMutex
|
||||
|
||||
// 连接管理
|
||||
connections map[uint32]*Connection
|
||||
connMu sync.RWMutex
|
||||
nextConnID uint32
|
||||
}
|
||||
|
||||
// Connection 表示一个客户端连接
|
||||
type Connection struct {
|
||||
ID uint32
|
||||
TargetPort int
|
||||
ClientConn net.Conn
|
||||
readChan chan []byte
|
||||
writeChan chan []byte
|
||||
closeChan chan struct{}
|
||||
}
|
||||
|
||||
// NewServer 创建新的隧道服务器
|
||||
func NewServer(listenPort int) *Server {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &Server{
|
||||
listenPort: listenPort,
|
||||
cancel: cancel,
|
||||
ctx: ctx,
|
||||
connections: make(map[uint32]*Connection),
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动隧道服务器
|
||||
func (s *Server) Start() error {
|
||||
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", s.listenPort))
|
||||
if err != nil {
|
||||
return fmt.Errorf("启动隧道服务器失败: %w", err)
|
||||
}
|
||||
|
||||
s.listener = listener
|
||||
log.Printf("隧道服务器启动: 端口 %d", s.listenPort)
|
||||
|
||||
s.wg.Add(1)
|
||||
go s.acceptLoop()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// acceptLoop 接受客户端连接
|
||||
func (s *Server) acceptLoop() {
|
||||
defer s.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
s.listener.(*net.TCPListener).SetDeadline(time.Now().Add(time.Second))
|
||||
conn, err := s.listener.Accept()
|
||||
if err != nil {
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
default:
|
||||
log.Printf("接受隧道连接失败: %v", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// 只允许一个客户端连接
|
||||
s.mu.Lock()
|
||||
if s.client != nil {
|
||||
log.Printf("拒绝额外的隧道连接: %s", conn.RemoteAddr())
|
||||
conn.Close()
|
||||
s.mu.Unlock()
|
||||
continue
|
||||
}
|
||||
s.client = conn
|
||||
s.mu.Unlock()
|
||||
|
||||
log.Printf("隧道客户端已连接: %s", conn.RemoteAddr())
|
||||
|
||||
s.wg.Add(1)
|
||||
go s.handleClient(conn)
|
||||
}
|
||||
}
|
||||
|
||||
// handleClient 处理客户端连接
|
||||
func (s *Server) handleClient(conn net.Conn) {
|
||||
defer s.wg.Done()
|
||||
defer func() {
|
||||
conn.Close()
|
||||
s.mu.Lock()
|
||||
s.client = nil
|
||||
s.mu.Unlock()
|
||||
log.Printf("隧道客户端已断开")
|
||||
|
||||
// 关闭所有活动连接
|
||||
s.connMu.Lock()
|
||||
for _, c := range s.connections {
|
||||
close(c.closeChan)
|
||||
}
|
||||
s.connections = make(map[uint32]*Connection)
|
||||
s.connMu.Unlock()
|
||||
}()
|
||||
|
||||
// 读取来自客户端的数据
|
||||
for {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// 读取消息头
|
||||
header := make([]byte, HeaderSize)
|
||||
if _, err := io.ReadFull(conn, header); err != nil {
|
||||
if err != io.EOF {
|
||||
log.Printf("读取隧道消息头失败: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
dataLen := binary.BigEndian.Uint32(header[0:4])
|
||||
connID := binary.BigEndian.Uint32(header[4:8])
|
||||
|
||||
if dataLen > MaxPacketSize {
|
||||
log.Printf("数据包过大: %d bytes", dataLen)
|
||||
return
|
||||
}
|
||||
|
||||
// 读取数据
|
||||
data := make([]byte, dataLen)
|
||||
if _, err := io.ReadFull(conn, data); err != nil {
|
||||
log.Printf("读取隧道数据失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 将数据发送到对应的连接
|
||||
s.connMu.RLock()
|
||||
connection, exists := s.connections[connID]
|
||||
s.connMu.RUnlock()
|
||||
|
||||
if exists {
|
||||
select {
|
||||
case connection.readChan <- data:
|
||||
case <-connection.closeChan:
|
||||
case <-time.After(5 * time.Second):
|
||||
log.Printf("向连接 %d 发送数据超时", connID)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ForwardConnection 转发连接到隧道
|
||||
func (s *Server) ForwardConnection(clientConn net.Conn, targetPort int) error {
|
||||
s.mu.RLock()
|
||||
tunnelConn := s.client
|
||||
s.mu.RUnlock()
|
||||
|
||||
if tunnelConn == nil {
|
||||
return fmt.Errorf("隧道连接不可用")
|
||||
}
|
||||
|
||||
// 创建连接对象
|
||||
s.connMu.Lock()
|
||||
connID := s.nextConnID
|
||||
s.nextConnID++
|
||||
|
||||
connection := &Connection{
|
||||
ID: connID,
|
||||
TargetPort: targetPort,
|
||||
ClientConn: clientConn,
|
||||
readChan: make(chan []byte, 100),
|
||||
writeChan: make(chan []byte, 100),
|
||||
closeChan: make(chan struct{}),
|
||||
}
|
||||
s.connections[connID] = connection
|
||||
s.connMu.Unlock()
|
||||
|
||||
defer func() {
|
||||
s.connMu.Lock()
|
||||
delete(s.connections, connID)
|
||||
s.connMu.Unlock()
|
||||
close(connection.closeChan)
|
||||
clientConn.Close()
|
||||
}()
|
||||
|
||||
// 启动读写协程
|
||||
errChan := make(chan error, 2)
|
||||
|
||||
// 从客户端读取并发送到隧道
|
||||
go func() {
|
||||
buffer := make([]byte, 32*1024)
|
||||
for {
|
||||
select {
|
||||
case <-connection.closeChan:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
clientConn.SetReadDeadline(time.Now().Add(30 * time.Second))
|
||||
n, err := clientConn.Read(buffer)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
// 发送到隧道
|
||||
data := make([]byte, HeaderSize+n)
|
||||
binary.BigEndian.PutUint32(data[0:4], uint32(n))
|
||||
binary.BigEndian.PutUint32(data[4:8], connID)
|
||||
copy(data[HeaderSize:], buffer[:n])
|
||||
|
||||
s.mu.RLock()
|
||||
_, err = tunnelConn.Write(data)
|
||||
s.mu.RUnlock()
|
||||
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// 从隧道读取并发送到客户端
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case data := <-connection.readChan:
|
||||
if _, err := clientConn.Write(data); err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
case <-connection.closeChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// 等待错误或关闭
|
||||
select {
|
||||
case <-errChan:
|
||||
case <-connection.closeChan:
|
||||
case <-s.ctx.Done():
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsConnected 检查隧道是否已连接
|
||||
func (s *Server) IsConnected() bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.client != nil
|
||||
}
|
||||
|
||||
// Stop 停止隧道服务器
|
||||
func (s *Server) Stop() error {
|
||||
s.cancel()
|
||||
|
||||
if s.listener != nil {
|
||||
s.listener.Close()
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
if s.client != nil {
|
||||
s.client.Close()
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
// 等待所有协程结束
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
s.wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
log.Printf("隧道服务器已停止")
|
||||
case <-time.After(5 * time.Second):
|
||||
log.Printf("隧道服务器停止超时")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
Loading…
Reference in New Issue