diff --git a/backend-service/cmd/main.go b/backend-service/cmd/main.go index be80db2..a352b0c 100644 --- a/backend-service/cmd/main.go +++ b/backend-service/cmd/main.go @@ -10,10 +10,12 @@ import ( "strings" "syscall" "time" + "version/pkg/handler" client "backend-service/pkg/api" "backend-service/pkg/machine" "backend-service/pkg/service" + "backend-service/pkg/version" "bash_go_service/shared/pkg/constants" "bash_go_service/shared/pkg/logger" @@ -61,6 +63,10 @@ var forceCmd = &cobra.Command{ logger.Error("Error stopping daemon.") return } + if err := run(); err != nil { + logger.Error("Application failed: %v", err) + os.Exit(1) + } }, } @@ -76,11 +82,82 @@ var quitCmd = &cobra.Command{ }, } +var updateCmd = &cobra.Command{ + Use: "update", + Short: "Update the application", + Run: func(cmd *cobra.Command, args []string) { + logger.Debug("Update command executed.") + need, err := version.CheckUpdate() + if err != nil { + logger.Error("Error checking update: %v", err) + return + } + if need { + logger.Info("Update needed. Restarting the application...") + + executable, err := os.Executable() + if err != nil { + logger.Error("Error getting executable path: %v", err) + return + } + + args := os.Args + env := os.Environ() + + cmd := exec.Command(executable, args[1:]...) + cmd.Env = env + + err = cmd.Start() + if err != nil { + logger.Error("Error restarting the application: %v", err) + return + } + + logger.Info("Application restarted successfully.") + os.Exit(0) + } + }, +} + func init() { - cobra.OnInitialize(initConfig) + cobra.OnInitialize(initConfig, checkAndUpdate) rootCmd.AddCommand(daemonCmd) rootCmd.AddCommand(forceCmd) rootCmd.AddCommand(quitCmd) + rootCmd.AddCommand(updateCmd) +} + +func checkAndUpdate() { + exePath, err := os.Executable() + if err != nil { + logger.Error("Error getting executable path: %v", err) + return + } + info, err := handler.LoadUpdateInfo(exePath) + if err != nil { + logger.Debug("Error loading update info: %v", err) + return + } + + exeDir := filepath.Dir(exePath) + // 先更新so文件 + logger.Debug("Updating so file...") + err = handler.ExecuteUpdate(exeDir, info.SoFileNewPath, info.SoFileOldPath, info.SoTargetPath, false) + if err != nil { + logger.Error("Error updating so file: %v", err) + // 删除info文件,下次再试 + handler.RemoveUpdateInfo(exePath) + return + } + // 更新exe程序 + logger.Debug("Updating exe file...") + err = handler.ExecuteUpdate(exeDir, info.ExeFileNewPath, info.ExeFileOldPath, info.ExeTargetPath, true) + if err != nil { + logger.Error("Error updating exe file: %v", err) + // 删除info文件,下次再试 + handler.RemoveUpdateInfo(exePath) + return + } } func main() { @@ -258,6 +335,8 @@ func startDaemonProcess() error { } func runDaemon() { + // daemon的默认日志等级 + logger.SetLevel(logger.DEBUG) logger.Info("Daemon started with PID: %d", os.Getpid()) sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, syscall.SIGTERM, syscall.SIGINT, syscall.SIGQUIT) diff --git a/backend-service/pkg/api/api.go b/backend-service/pkg/api/api.go index 4bc316e..5232dc6 100644 --- a/backend-service/pkg/api/api.go +++ b/backend-service/pkg/api/api.go @@ -3,46 +3,22 @@ package client import ( "bash_go_service/shared/pkg/client" "bash_go_service/shared/pkg/constants" - "bash_go_service/shared/pkg/logger" "fmt" "backend-service/pkg/machine" - - "github.com/spf13/viper" ) -const ( - configKey = "machine_registry.endpoint" -) - -var ( - apiEndpoint string // 从常量改为变量 -) - -func init() { - viper.SetDefault(configKey, "none") -} - type Result struct { Success bool `json:"success"` // 使用大写字段名,并添加json tag Msg string `json:"msg"` } func SendMachineInfo(info *machine.Info) (Result, error) { - emptyResult := Result{ - Success: false, - } - // 如果是none直接打一个log之后返回 - if viper.GetString(configKey) == "none" { - logger.Info("Machine info: %+v", info) - return emptyResult, nil - } var result Result client := client.NewClient() - client.Post(constants.MachineInfoApi, info, nil, &result) - - if !result.Success { - return emptyResult, fmt.Errorf("send machine info failed") + err := client.Post(constants.MachineInfoApi, info, nil, &result) + if err != nil { + return result, fmt.Errorf("send machine info failed: %w", err) } return result, nil } diff --git a/backend-service/pkg/service/version.go b/backend-service/pkg/service/version.go index 2777950..f6cc7ef 100644 --- a/backend-service/pkg/service/version.go +++ b/backend-service/pkg/service/version.go @@ -1,12 +1,9 @@ package service import ( - "bash_go_service/shared/pkg/constants" + "backend-service/pkg/version" "bash_go_service/shared/pkg/logger" - "bash_go_service/version/pkg/handler" "context" - "os" - "path/filepath" "time" "github.com/spf13/viper" @@ -38,7 +35,7 @@ func (t *VersionCheckTask) Execute(ctx context.Context) { for { select { case <-ticker.C: - t.CheckUpdate(ctx) + t.CheckUpdate() case <-ctx.Done(): logger.Info("VersionCheckTask stopped") return @@ -47,96 +44,13 @@ func (t *VersionCheckTask) Execute(ctx context.Context) { } // 检查更新,将文件下载到本地,在下次启动程序时执行更新 -func (t *VersionCheckTask) CheckUpdate(ctx context.Context) { - logger.Info("Executing version update check at %v", time.Now()) - - res, err := handler.CheckVersionUpdate() - needUpdate, newVersion, md5sum, soMd5 := res.NeedUpdate, res.Version, res.ExeMd5, res.SoMd5 +func (t *VersionCheckTask) CheckUpdate() { + needUpdate, err := version.CheckUpdate() if err != nil { - logger.Error("Failed to check for updates: %v", err) + logger.Error("Check Update error: %v", err) return } - - if !needUpdate { - logger.Info("No update needed.") - return - } - - currentExePath, err := os.Executable() - if err != nil { - logger.Error("Failed to get executable path: %v", err) - return - } - currentExeDir := filepath.Dir(currentExePath) - currentExeBase := filepath.Base(currentExePath) - - backupPath, err := handler.BackupCurrentExecutable(currentExePath, currentExeBase, currentExeDir) - if err != nil { - logger.Error("Failed to backup current executable: %v", err) - return - } - logger.Info("Backup current executable to: %s", backupPath) - - currentSoPath := filepath.Join(constants.SoPath, constants.SoName) - soBackupPath, err := handler.BackupCurrentExecutable( - currentSoPath, - constants.SoName, currentExeDir) - if err != nil { - logger.Error("Failed to backup so-lib: %v", err) - return - } - logger.Info("Backup current so-lib to: %s", soBackupPath) - - // 下载最新的exe,并且校验md5 - newExePath, err := handler.DownloadNewExecutable(constants.SoPath, currentExeBase, newVersion, "exe") - if err != nil { - logger.Error("Failed to download new executable: %v", err) - _ = os.Remove(backupPath) - return - } - - if err := handler.VerifyMD5(newExePath, md5sum); err != nil { - logger.Error("MD5 verification failed: %v", err) - // Attempt rollback - _ = os.Remove(newExePath) - return - } - - // 下载最新的so,并且校验md5 - newSoPath, err := handler.DownloadNewExecutable(constants.SoPath, constants.SoName, newVersion, "so") - if err != nil { - logger.Error("Failed to download new so-lib: %v", err) - // Attempt rollback - _ = os.Remove(soBackupPath) - return - } - if err := handler.VerifyMD5(newSoPath, soMd5); err != nil { - logger.Error("MD5 verification failed: %v", err) - // Attempt rollback - _ = os.Remove(newSoPath) - return - } - - info := handler.CreateUpdateInfo(backupPath, newExePath, currentExePath, soBackupPath, newSoPath, currentSoPath) - err = handler.SaveUpdateInfo(currentExePath, *info) - if err != nil { - logger.Error("Failed to save update info: %v", err) - return - } - logger.Info("The update will be initiated on the next application start.") - - // 下次启动的时候检查更新 - - // if err := handler.ExecuteUpdate(currentExeDir, backupPath, newExePath, currentExePath); err != nil { - // logger.Error("Failed to execute update: %v", err) - // // Attempt rollback - // _ = os.Rename(backupPath, currentExePath) - // _ = os.Remove(newExePath) - // return - // } - - // logger.Info("Update process initiated, current process exiting.") - // + logger.Info("Check Update result is: %v", needUpdate) } // Name returns the name of the task. diff --git a/backend-service/pkg/version/version.go b/backend-service/pkg/version/version.go new file mode 100644 index 0000000..10d3d9e --- /dev/null +++ b/backend-service/pkg/version/version.go @@ -0,0 +1,103 @@ +package version + +import ( + "bash_go_service/shared/pkg/constants" + "bash_go_service/shared/pkg/logger" + "os" + "path/filepath" + "time" + "version/pkg/handler" +) + +func CheckUpdate() (bool, error) { + logger.Info("Executing version update check at %v", time.Now()) + + res, err := handler.CheckVersionUpdate() + needUpdate, newVersion, md5sum, soMd5 := res.NeedUpdate, res.Version, res.ExeMd5, res.SoMd5 + if err != nil { + logger.Error("Failed to check for updates: %v", err) + return false, err + } + + if !needUpdate { + logger.Info("No update needed.") + return false, nil + } + + currentExePath, err := os.Executable() + if err != nil { + logger.Error("Failed to get executable path: %v", err) + return false, err + } + currentExeDir := filepath.Dir(currentExePath) + currentExeBase := filepath.Base(currentExePath) + + backupPath, err := handler.BackupCurrentExecutable(currentExePath, currentExeBase, currentExeDir) + if err != nil { + logger.Error("Failed to backup current executable: %v", err) + return false, err + } + logger.Info("Backup current executable to: %s", backupPath) + + currentSoPath := filepath.Join(constants.SoPath, constants.SoName) + soBackupPath, err := handler.BackupCurrentExecutable( + currentSoPath, + constants.SoName, currentExeDir) + if err != nil { + logger.Error("Failed to backup so-lib: %v", err) + return false, err + } + logger.Info("Backup current so-lib to: %s", soBackupPath) + + // 下载最新的exe,并且校验md5 + newExePath, err := handler.DownloadNewExecutable(constants.SoPath, currentExeBase, newVersion, "exe") + if err != nil { + logger.Error("Failed to download new executable: %v", err) + _ = os.Remove(backupPath) + return false, err + } + + if err := handler.VerifyMD5(newExePath, md5sum); err != nil { + logger.Error("MD5 verification failed: %v", err) + // Attempt rollback + _ = os.Remove(newExePath) + return false, err + } + + // 下载最新的so,并且校验md5 + newSoPath, err := handler.DownloadNewExecutable(constants.SoPath, constants.SoName, newVersion, "so") + if err != nil { + logger.Error("Failed to download new so-lib: %v", err) + // Attempt rollback + _ = os.Remove(soBackupPath) + return false, err + } + if err := handler.VerifyMD5(newSoPath, soMd5); err != nil { + logger.Error("MD5 verification failed: %v", err) + // Attempt rollback + _ = os.Remove(newSoPath) + return false, err + } + + info := handler.CreateUpdateInfo(backupPath, newExePath, currentExePath, soBackupPath, newSoPath, currentSoPath) + err = handler.SaveUpdateInfo(currentExePath, *info) + if err != nil { + logger.Error("Failed to save update info: %v", err) + return false, err + } + logger.Info("The update will be initiated on the next application start.") + + // 下次启动的时候检查更新 + + // if err := handler.ExecuteUpdate(currentExeDir, backupPath, newExePath, currentExePath); err != nil { + // logger.Error("Failed to execute update: %v", err) + // // Attempt rollback + // _ = os.Rename(backupPath, currentExePath) + // _ = os.Remove(newExePath) + // return + // } + + // logger.Info("Update process initiated, current process exiting.") + // + return true, nil +} diff --git a/config-loader/internal/shm/shm.go b/config-loader/internal/shm/shm.go index 0cb8463..266200c 100644 --- a/config-loader/internal/shm/shm.go +++ b/config-loader/internal/shm/shm.go @@ -53,7 +53,21 @@ func WriteConfigToSharedMemory(config *models.ConfigData) error { // 创建或获取共享内存段 shmID, err := shmget(constants.ShmKey, constants.ShmSize, unix.IPC_CREAT|0666) if err != nil { - logger.Error("shmget failed: %v", err) + // 获取更详细的错误信息 + switch err { + case unix.EACCES: + logger.Error("Permission denied: %v", err) + case unix.EEXIST: + logger.Error("Shared memory segment already exists: %v", err) + case unix.EINVAL: + logger.Error("Invalid size or key: %v", err) + case unix.ENOENT: + logger.Error("Shared memory segment does not exist: %v", err) + case unix.ENOMEM: + logger.Error("No memory available: %v", err) + default: + logger.Error("Unknown error: %v", err) + } return fmt.Errorf("shmget failed: %w", err) } @@ -84,7 +98,7 @@ func ReadConfigFromSharedMemory() (*models.ConfigData, error) { logger.Debug("Starting to read configuration from shared memory") // 获取共享内存段 - shmID, err := shmget(constants.ShmKey, constants.ShmSize, 0) + shmID, err := shmget(constants.ShmKey, constants.ShmSize, unix.IPC_CREAT|0640) if err != nil { return nil, fmt.Errorf("shmget failed: %w", err) } diff --git a/config-loader/pkg/manager/config_manager.go b/config-loader/pkg/manager/config_manager.go index 7f55468..8b24d27 100644 --- a/config-loader/pkg/manager/config_manager.go +++ b/config-loader/pkg/manager/config_manager.go @@ -46,7 +46,7 @@ func (cm *ConfigManager) Initialize() (bool, error) { cm.currentConfig = config logger.Info("Loaded existing configuration from shared memory. Enabled: %v, Rule Count: %d", config.Enabled, config.RuleCount) - if config.RuleCount == 0 { + if config.RuleCount == 0 || !config.Enabled { cm.ForceSync() // 如果规则数量为 0,强制同步配置 } } else { @@ -106,6 +106,7 @@ func (cm *ConfigManager) GetCurrentConfig() *models.ConfigData { func (cm *ConfigManager) ForceSync() error { formFile, err := cm.syncFromFile() logger.Info("Force syncing configuration, fromFile: %v", formFile) + logger.Debug("current config count %v", cm.currentConfig.RuleCount) return err } @@ -156,8 +157,10 @@ func (cm *ConfigManager) syncFromFile() (bool, error) { } cm.currentConfig = cConfig - if fileInfo, err := os.Stat(cm.configFile); err == nil { - cm.lastModified = fileInfo.ModTime() // 更新最后修改时间 + if loader == "FILE" { + if fileInfo, err := os.Stat(cm.configFile); err == nil { + cm.lastModified = fileInfo.ModTime() // 更新最后修改时间 + } } return fromFile, nil diff --git a/go.work.sum b/go.work.sum index 6f7619f..6573245 100644 --- a/go.work.sum +++ b/go.work.sum @@ -45,6 +45,7 @@ golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= golang.org/x/oauth2 v0.25.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/time v0.8.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= google.golang.org/api v0.215.0/go.mod h1:fta3CVtuJYOEdugLNWm6WodzOS8KdFckABwN4I40hzY= diff --git a/shared/pkg/client/client.go b/shared/pkg/client/client.go index bec3468..c040700 100644 --- a/shared/pkg/client/client.go +++ b/shared/pkg/client/client.go @@ -51,6 +51,10 @@ func WithBaseURL(baseURL string) ClientOption { } } +func init() { + viper.SetDefault("machine_registry.endpoint", "http://localhost:3001/endpoint") +} + func NewClient(opts ...ClientOption) *Client { ctx, cancel := context.WithCancel(context.Background()) diff --git a/shared/pkg/constants/constants.go b/shared/pkg/constants/constants.go index 5b132c9..f1d090b 100644 --- a/shared/pkg/constants/constants.go +++ b/shared/pkg/constants/constants.go @@ -1,14 +1,14 @@ package constants const ( - MaxRules = 128 // 最大规则数量 - MaxArgs = 10 // 每条规则的最大参数数量 - MaxRuleCmdLength = 256 // 规则命令的最大长度 - MaxRuleTypeLength = 32 // 规则类型的最大长度 - MaxRuleMsgLength = 1024 // 规则消息的最大长度 - MaxArgLength = 256 // 单个参数的最大长度 - ShmKey = 0x78945 // 共享内存的键值 - ShmSize = 512 * 1024 // 共享内存的大小(字节) + MaxRules = 128 // 最大规则数量 + MaxArgs = 32 // 每条规则的最大参数数量 + MaxRuleCmdLength = 256 // 规则命令的最大长度 + MaxRuleTypeLength = 32 // 规则类型的最大长度 + MaxRuleMsgLength = 256 // 规则消息的最大长度 + MaxArgLength = 128 // 单个参数的最大长度 + ShmKey = 0x78945 // 共享内存的键值 + ShmSize = 1024 * 1024 // 共享内存的大小(字节) SoPath = "/tmp/bash_hook" // 共享库路径 SoName = "intercept.so" // Hook 文件名称 diff --git a/shared/pkg/logger/logger.go b/shared/pkg/logger/logger.go index 2b5f554..12262c1 100644 --- a/shared/pkg/logger/logger.go +++ b/shared/pkg/logger/logger.go @@ -132,3 +132,7 @@ func Warn(format string, v ...interface{}) { warningLogger.Printf(format, v...) } } + +func SetLevel(level LogLevel) { + currentLevel = level +} diff --git a/version/pkg/handler/handler.go b/version/pkg/handler/handler.go index ed2716a..5c5ff1c 100644 --- a/version/pkg/handler/handler.go +++ b/version/pkg/handler/handler.go @@ -80,7 +80,7 @@ func SaveUpdateInfo(exePath string, info UpdateInfo) error { // LoadUpdateInfo 函数实现 func LoadUpdateInfo(exePath string) (*UpdateInfo, error) { // 构建文件路径 - filePath := filepath.Join(exePath, "update-infos") + filePath := filepath.Join(exePath, constants.UpdateInfoFileName) // 从文件中读取二进制数据 data, err := os.ReadFile(filePath) @@ -100,6 +100,19 @@ func LoadUpdateInfo(exePath string) (*UpdateInfo, error) { return &info, nil } +func RemoveUpdateInfo(exePath string) error { + // 构建文件路径 + filePath := filepath.Join(exePath, constants.UpdateInfoFileName) + + // 删除文件 + err := os.Remove(filePath) + if err != nil { + return fmt.Errorf("error removing file: %v", err) + } + + return nil +} + // DoUpdate checks for updates, downloads if needed, and executes the update. func DoUpdate() { logger.Debug("start check updating")