部分代码调整位置,修改了参数量大小

This commit is contained in:
Pan Qiancheng 2025-04-23 17:45:42 +08:00
parent 851c5e538d
commit 7d77c78db8
11 changed files with 245 additions and 134 deletions

View File

@ -10,10 +10,12 @@ import (
"strings" "strings"
"syscall" "syscall"
"time" "time"
"version/pkg/handler"
client "backend-service/pkg/api" client "backend-service/pkg/api"
"backend-service/pkg/machine" "backend-service/pkg/machine"
"backend-service/pkg/service" "backend-service/pkg/service"
"backend-service/pkg/version"
"bash_go_service/shared/pkg/constants" "bash_go_service/shared/pkg/constants"
"bash_go_service/shared/pkg/logger" "bash_go_service/shared/pkg/logger"
@ -61,6 +63,10 @@ var forceCmd = &cobra.Command{
logger.Error("Error stopping daemon.") logger.Error("Error stopping daemon.")
return return
} }
if err := run(); err != nil {
logger.Error("Application failed: %v", err)
os.Exit(1)
}
}, },
} }
@ -76,11 +82,82 @@ var quitCmd = &cobra.Command{
}, },
} }
var updateCmd = &cobra.Command{
Use: "update",
Short: "Update the application",
Run: func(cmd *cobra.Command, args []string) {
logger.Debug("Update command executed.")
need, err := version.CheckUpdate()
if err != nil {
logger.Error("Error checking update: %v", err)
return
}
if need {
logger.Info("Update needed. Restarting the application...")
executable, err := os.Executable()
if err != nil {
logger.Error("Error getting executable path: %v", err)
return
}
args := os.Args
env := os.Environ()
cmd := exec.Command(executable, args[1:]...)
cmd.Env = env
err = cmd.Start()
if err != nil {
logger.Error("Error restarting the application: %v", err)
return
}
logger.Info("Application restarted successfully.")
os.Exit(0)
}
},
}
func init() { func init() {
cobra.OnInitialize(initConfig) cobra.OnInitialize(initConfig, checkAndUpdate)
rootCmd.AddCommand(daemonCmd) rootCmd.AddCommand(daemonCmd)
rootCmd.AddCommand(forceCmd) rootCmd.AddCommand(forceCmd)
rootCmd.AddCommand(quitCmd) rootCmd.AddCommand(quitCmd)
rootCmd.AddCommand(updateCmd)
}
func checkAndUpdate() {
exePath, err := os.Executable()
if err != nil {
logger.Error("Error getting executable path: %v", err)
return
}
info, err := handler.LoadUpdateInfo(exePath)
if err != nil {
logger.Debug("Error loading update info: %v", err)
return
}
exeDir := filepath.Dir(exePath)
// 先更新so文件
logger.Debug("Updating so file...")
err = handler.ExecuteUpdate(exeDir, info.SoFileNewPath, info.SoFileOldPath, info.SoTargetPath, false)
if err != nil {
logger.Error("Error updating so file: %v", err)
// 删除info文件下次再试
handler.RemoveUpdateInfo(exePath)
return
}
// 更新exe程序
logger.Debug("Updating exe file...")
err = handler.ExecuteUpdate(exeDir, info.ExeFileNewPath, info.ExeFileOldPath, info.ExeTargetPath, true)
if err != nil {
logger.Error("Error updating exe file: %v", err)
// 删除info文件下次再试
handler.RemoveUpdateInfo(exePath)
return
}
} }
func main() { func main() {
@ -258,6 +335,8 @@ func startDaemonProcess() error {
} }
func runDaemon() { func runDaemon() {
// daemon的默认日志等级
logger.SetLevel(logger.DEBUG)
logger.Info("Daemon started with PID: %d", os.Getpid()) logger.Info("Daemon started with PID: %d", os.Getpid())
sigChan := make(chan os.Signal, 1) sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGTERM, syscall.SIGINT, syscall.SIGQUIT) signal.Notify(sigChan, syscall.SIGTERM, syscall.SIGINT, syscall.SIGQUIT)

View File

@ -3,46 +3,22 @@ package client
import ( import (
"bash_go_service/shared/pkg/client" "bash_go_service/shared/pkg/client"
"bash_go_service/shared/pkg/constants" "bash_go_service/shared/pkg/constants"
"bash_go_service/shared/pkg/logger"
"fmt" "fmt"
"backend-service/pkg/machine" "backend-service/pkg/machine"
"github.com/spf13/viper"
) )
const (
configKey = "machine_registry.endpoint"
)
var (
apiEndpoint string // 从常量改为变量
)
func init() {
viper.SetDefault(configKey, "none")
}
type Result struct { type Result struct {
Success bool `json:"success"` // 使用大写字段名并添加json tag Success bool `json:"success"` // 使用大写字段名并添加json tag
Msg string `json:"msg"` Msg string `json:"msg"`
} }
func SendMachineInfo(info *machine.Info) (Result, error) { func SendMachineInfo(info *machine.Info) (Result, error) {
emptyResult := Result{
Success: false,
}
// 如果是none直接打一个log之后返回
if viper.GetString(configKey) == "none" {
logger.Info("Machine info: %+v", info)
return emptyResult, nil
}
var result Result var result Result
client := client.NewClient() client := client.NewClient()
client.Post(constants.MachineInfoApi, info, nil, &result) err := client.Post(constants.MachineInfoApi, info, nil, &result)
if err != nil {
if !result.Success { return result, fmt.Errorf("send machine info failed: %w", err)
return emptyResult, fmt.Errorf("send machine info failed")
} }
return result, nil return result, nil
} }

View File

@ -1,12 +1,9 @@
package service package service
import ( import (
"bash_go_service/shared/pkg/constants" "backend-service/pkg/version"
"bash_go_service/shared/pkg/logger" "bash_go_service/shared/pkg/logger"
"bash_go_service/version/pkg/handler"
"context" "context"
"os"
"path/filepath"
"time" "time"
"github.com/spf13/viper" "github.com/spf13/viper"
@ -38,7 +35,7 @@ func (t *VersionCheckTask) Execute(ctx context.Context) {
for { for {
select { select {
case <-ticker.C: case <-ticker.C:
t.CheckUpdate(ctx) t.CheckUpdate()
case <-ctx.Done(): case <-ctx.Done():
logger.Info("VersionCheckTask stopped") logger.Info("VersionCheckTask stopped")
return return
@ -47,96 +44,13 @@ func (t *VersionCheckTask) Execute(ctx context.Context) {
} }
// 检查更新,将文件下载到本地,在下次启动程序时执行更新 // 检查更新,将文件下载到本地,在下次启动程序时执行更新
func (t *VersionCheckTask) CheckUpdate(ctx context.Context) { func (t *VersionCheckTask) CheckUpdate() {
logger.Info("Executing version update check at %v", time.Now()) needUpdate, err := version.CheckUpdate()
res, err := handler.CheckVersionUpdate()
needUpdate, newVersion, md5sum, soMd5 := res.NeedUpdate, res.Version, res.ExeMd5, res.SoMd5
if err != nil { if err != nil {
logger.Error("Failed to check for updates: %v", err) logger.Error("Check Update error: %v", err)
return return
} }
logger.Info("Check Update result is: %v", needUpdate)
if !needUpdate {
logger.Info("No update needed.")
return
}
currentExePath, err := os.Executable()
if err != nil {
logger.Error("Failed to get executable path: %v", err)
return
}
currentExeDir := filepath.Dir(currentExePath)
currentExeBase := filepath.Base(currentExePath)
backupPath, err := handler.BackupCurrentExecutable(currentExePath, currentExeBase, currentExeDir)
if err != nil {
logger.Error("Failed to backup current executable: %v", err)
return
}
logger.Info("Backup current executable to: %s", backupPath)
currentSoPath := filepath.Join(constants.SoPath, constants.SoName)
soBackupPath, err := handler.BackupCurrentExecutable(
currentSoPath,
constants.SoName, currentExeDir)
if err != nil {
logger.Error("Failed to backup so-lib: %v", err)
return
}
logger.Info("Backup current so-lib to: %s", soBackupPath)
// 下载最新的exe并且校验md5
newExePath, err := handler.DownloadNewExecutable(constants.SoPath, currentExeBase, newVersion, "exe")
if err != nil {
logger.Error("Failed to download new executable: %v", err)
_ = os.Remove(backupPath)
return
}
if err := handler.VerifyMD5(newExePath, md5sum); err != nil {
logger.Error("MD5 verification failed: %v", err)
// Attempt rollback
_ = os.Remove(newExePath)
return
}
// 下载最新的so并且校验md5
newSoPath, err := handler.DownloadNewExecutable(constants.SoPath, constants.SoName, newVersion, "so")
if err != nil {
logger.Error("Failed to download new so-lib: %v", err)
// Attempt rollback
_ = os.Remove(soBackupPath)
return
}
if err := handler.VerifyMD5(newSoPath, soMd5); err != nil {
logger.Error("MD5 verification failed: %v", err)
// Attempt rollback
_ = os.Remove(newSoPath)
return
}
info := handler.CreateUpdateInfo(backupPath, newExePath, currentExePath, soBackupPath, newSoPath, currentSoPath)
err = handler.SaveUpdateInfo(currentExePath, *info)
if err != nil {
logger.Error("Failed to save update info: %v", err)
return
}
logger.Info("The update will be initiated on the next application start.")
// 下次启动的时候检查更新
// if err := handler.ExecuteUpdate(currentExeDir, backupPath, newExePath, currentExePath); err != nil {
// logger.Error("Failed to execute update: %v", err)
// // Attempt rollback
// _ = os.Rename(backupPath, currentExePath)
// _ = os.Remove(newExePath)
// return
// }
// logger.Info("Update process initiated, current process exiting.")
//
} }
// Name returns the name of the task. // Name returns the name of the task.

View File

@ -0,0 +1,103 @@
package version
import (
"bash_go_service/shared/pkg/constants"
"bash_go_service/shared/pkg/logger"
"os"
"path/filepath"
"time"
"version/pkg/handler"
)
func CheckUpdate() (bool, error) {
logger.Info("Executing version update check at %v", time.Now())
res, err := handler.CheckVersionUpdate()
needUpdate, newVersion, md5sum, soMd5 := res.NeedUpdate, res.Version, res.ExeMd5, res.SoMd5
if err != nil {
logger.Error("Failed to check for updates: %v", err)
return false, err
}
if !needUpdate {
logger.Info("No update needed.")
return false, nil
}
currentExePath, err := os.Executable()
if err != nil {
logger.Error("Failed to get executable path: %v", err)
return false, err
}
currentExeDir := filepath.Dir(currentExePath)
currentExeBase := filepath.Base(currentExePath)
backupPath, err := handler.BackupCurrentExecutable(currentExePath, currentExeBase, currentExeDir)
if err != nil {
logger.Error("Failed to backup current executable: %v", err)
return false, err
}
logger.Info("Backup current executable to: %s", backupPath)
currentSoPath := filepath.Join(constants.SoPath, constants.SoName)
soBackupPath, err := handler.BackupCurrentExecutable(
currentSoPath,
constants.SoName, currentExeDir)
if err != nil {
logger.Error("Failed to backup so-lib: %v", err)
return false, err
}
logger.Info("Backup current so-lib to: %s", soBackupPath)
// 下载最新的exe并且校验md5
newExePath, err := handler.DownloadNewExecutable(constants.SoPath, currentExeBase, newVersion, "exe")
if err != nil {
logger.Error("Failed to download new executable: %v", err)
_ = os.Remove(backupPath)
return false, err
}
if err := handler.VerifyMD5(newExePath, md5sum); err != nil {
logger.Error("MD5 verification failed: %v", err)
// Attempt rollback
_ = os.Remove(newExePath)
return false, err
}
// 下载最新的so并且校验md5
newSoPath, err := handler.DownloadNewExecutable(constants.SoPath, constants.SoName, newVersion, "so")
if err != nil {
logger.Error("Failed to download new so-lib: %v", err)
// Attempt rollback
_ = os.Remove(soBackupPath)
return false, err
}
if err := handler.VerifyMD5(newSoPath, soMd5); err != nil {
logger.Error("MD5 verification failed: %v", err)
// Attempt rollback
_ = os.Remove(newSoPath)
return false, err
}
info := handler.CreateUpdateInfo(backupPath, newExePath, currentExePath, soBackupPath, newSoPath, currentSoPath)
err = handler.SaveUpdateInfo(currentExePath, *info)
if err != nil {
logger.Error("Failed to save update info: %v", err)
return false, err
}
logger.Info("The update will be initiated on the next application start.")
// 下次启动的时候检查更新
// if err := handler.ExecuteUpdate(currentExeDir, backupPath, newExePath, currentExePath); err != nil {
// logger.Error("Failed to execute update: %v", err)
// // Attempt rollback
// _ = os.Rename(backupPath, currentExePath)
// _ = os.Remove(newExePath)
// return
// }
// logger.Info("Update process initiated, current process exiting.")
//
return true, nil
}

View File

@ -53,7 +53,21 @@ func WriteConfigToSharedMemory(config *models.ConfigData) error {
// 创建或获取共享内存段 // 创建或获取共享内存段
shmID, err := shmget(constants.ShmKey, constants.ShmSize, unix.IPC_CREAT|0666) shmID, err := shmget(constants.ShmKey, constants.ShmSize, unix.IPC_CREAT|0666)
if err != nil { if err != nil {
logger.Error("shmget failed: %v", err) // 获取更详细的错误信息
switch err {
case unix.EACCES:
logger.Error("Permission denied: %v", err)
case unix.EEXIST:
logger.Error("Shared memory segment already exists: %v", err)
case unix.EINVAL:
logger.Error("Invalid size or key: %v", err)
case unix.ENOENT:
logger.Error("Shared memory segment does not exist: %v", err)
case unix.ENOMEM:
logger.Error("No memory available: %v", err)
default:
logger.Error("Unknown error: %v", err)
}
return fmt.Errorf("shmget failed: %w", err) return fmt.Errorf("shmget failed: %w", err)
} }
@ -84,7 +98,7 @@ func ReadConfigFromSharedMemory() (*models.ConfigData, error) {
logger.Debug("Starting to read configuration from shared memory") logger.Debug("Starting to read configuration from shared memory")
// 获取共享内存段 // 获取共享内存段
shmID, err := shmget(constants.ShmKey, constants.ShmSize, 0) shmID, err := shmget(constants.ShmKey, constants.ShmSize, unix.IPC_CREAT|0640)
if err != nil { if err != nil {
return nil, fmt.Errorf("shmget failed: %w", err) return nil, fmt.Errorf("shmget failed: %w", err)
} }

View File

@ -46,7 +46,7 @@ func (cm *ConfigManager) Initialize() (bool, error) {
cm.currentConfig = config cm.currentConfig = config
logger.Info("Loaded existing configuration from shared memory. Enabled: %v, Rule Count: %d", logger.Info("Loaded existing configuration from shared memory. Enabled: %v, Rule Count: %d",
config.Enabled, config.RuleCount) config.Enabled, config.RuleCount)
if config.RuleCount == 0 { if config.RuleCount == 0 || !config.Enabled {
cm.ForceSync() // 如果规则数量为 0强制同步配置 cm.ForceSync() // 如果规则数量为 0强制同步配置
} }
} else { } else {
@ -106,6 +106,7 @@ func (cm *ConfigManager) GetCurrentConfig() *models.ConfigData {
func (cm *ConfigManager) ForceSync() error { func (cm *ConfigManager) ForceSync() error {
formFile, err := cm.syncFromFile() formFile, err := cm.syncFromFile()
logger.Info("Force syncing configuration, fromFile: %v", formFile) logger.Info("Force syncing configuration, fromFile: %v", formFile)
logger.Debug("current config count %v", cm.currentConfig.RuleCount)
return err return err
} }
@ -156,8 +157,10 @@ func (cm *ConfigManager) syncFromFile() (bool, error) {
} }
cm.currentConfig = cConfig cm.currentConfig = cConfig
if fileInfo, err := os.Stat(cm.configFile); err == nil { if loader == "FILE" {
cm.lastModified = fileInfo.ModTime() // 更新最后修改时间 if fileInfo, err := os.Stat(cm.configFile); err == nil {
cm.lastModified = fileInfo.ModTime() // 更新最后修改时间
}
} }
return fromFile, nil return fromFile, nil

View File

@ -45,6 +45,7 @@ golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
golang.org/x/oauth2 v0.25.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/oauth2 v0.25.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/time v0.8.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/time v0.8.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
google.golang.org/api v0.215.0/go.mod h1:fta3CVtuJYOEdugLNWm6WodzOS8KdFckABwN4I40hzY= google.golang.org/api v0.215.0/go.mod h1:fta3CVtuJYOEdugLNWm6WodzOS8KdFckABwN4I40hzY=

View File

@ -51,6 +51,10 @@ func WithBaseURL(baseURL string) ClientOption {
} }
} }
func init() {
viper.SetDefault("machine_registry.endpoint", "http://localhost:3001/endpoint")
}
func NewClient(opts ...ClientOption) *Client { func NewClient(opts ...ClientOption) *Client {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())

View File

@ -1,14 +1,14 @@
package constants package constants
const ( const (
MaxRules = 128 // 最大规则数量 MaxRules = 128 // 最大规则数量
MaxArgs = 10 // 每条规则的最大参数数量 MaxArgs = 32 // 每条规则的最大参数数量
MaxRuleCmdLength = 256 // 规则命令的最大长度 MaxRuleCmdLength = 256 // 规则命令的最大长度
MaxRuleTypeLength = 32 // 规则类型的最大长度 MaxRuleTypeLength = 32 // 规则类型的最大长度
MaxRuleMsgLength = 1024 // 规则消息的最大长度 MaxRuleMsgLength = 256 // 规则消息的最大长度
MaxArgLength = 256 // 单个参数的最大长度 MaxArgLength = 128 // 单个参数的最大长度
ShmKey = 0x78945 // 共享内存的键值 ShmKey = 0x78945 // 共享内存的键值
ShmSize = 512 * 1024 // 共享内存的大小(字节) ShmSize = 1024 * 1024 // 共享内存的大小(字节)
SoPath = "/tmp/bash_hook" // 共享库路径 SoPath = "/tmp/bash_hook" // 共享库路径
SoName = "intercept.so" // Hook 文件名称 SoName = "intercept.so" // Hook 文件名称

View File

@ -132,3 +132,7 @@ func Warn(format string, v ...interface{}) {
warningLogger.Printf(format, v...) warningLogger.Printf(format, v...)
} }
} }
func SetLevel(level LogLevel) {
currentLevel = level
}

View File

@ -80,7 +80,7 @@ func SaveUpdateInfo(exePath string, info UpdateInfo) error {
// LoadUpdateInfo 函数实现 // LoadUpdateInfo 函数实现
func LoadUpdateInfo(exePath string) (*UpdateInfo, error) { func LoadUpdateInfo(exePath string) (*UpdateInfo, error) {
// 构建文件路径 // 构建文件路径
filePath := filepath.Join(exePath, "update-infos") filePath := filepath.Join(exePath, constants.UpdateInfoFileName)
// 从文件中读取二进制数据 // 从文件中读取二进制数据
data, err := os.ReadFile(filePath) data, err := os.ReadFile(filePath)
@ -100,6 +100,19 @@ func LoadUpdateInfo(exePath string) (*UpdateInfo, error) {
return &info, nil return &info, nil
} }
func RemoveUpdateInfo(exePath string) error {
// 构建文件路径
filePath := filepath.Join(exePath, constants.UpdateInfoFileName)
// 删除文件
err := os.Remove(filePath)
if err != nil {
return fmt.Errorf("error removing file: %v", err)
}
return nil
}
// DoUpdate checks for updates, downloads if needed, and executes the update. // DoUpdate checks for updates, downloads if needed, and executes the update.
func DoUpdate() { func DoUpdate() {
logger.Debug("start check updating") logger.Debug("start check updating")