diff --git a/backend-service/cmd/main.go b/backend-service/cmd/main.go index 897a543..c3c6751 100644 --- a/backend-service/cmd/main.go +++ b/backend-service/cmd/main.go @@ -341,6 +341,5 @@ func runDaemon() { // showWelcomeMessage 显示欢迎信息 func showWelcomeMessage(info *client.Result) { - fmt.Println(info.Msg) } diff --git a/backend-service/go.mod b/backend-service/go.mod index 6eeffd9..707f71f 100644 --- a/backend-service/go.mod +++ b/backend-service/go.mod @@ -21,7 +21,9 @@ require ( bash_go_service/shared v0.0.0-00010101000000-000000000000 bash_go_service/welcome v0.0.0-00010101000000-000000000000 + bash_go_service/version v0.0.0-00010101000000-000000000000 ) replace bash_go_service/shared => ../shared -replace bash_go_service/welcome => ../welcome \ No newline at end of file +replace bash_go_service/welcome => ../welcome +replace bash_go_service/version => ../version \ No newline at end of file diff --git a/backend-service/pkg/service/version.go b/backend-service/pkg/service/version.go new file mode 100644 index 0000000..2777950 --- /dev/null +++ b/backend-service/pkg/service/version.go @@ -0,0 +1,145 @@ +package service + +import ( + "bash_go_service/shared/pkg/constants" + "bash_go_service/shared/pkg/logger" + "bash_go_service/version/pkg/handler" + "context" + "os" + "path/filepath" + "time" + + "github.com/spf13/viper" +) + +// VersionCheckTask represents a task that checks for version updates. +type VersionCheckTask struct { + interval time.Duration +} + +// NewVersionCheckTask creates a new VersionCheckTask. +func NewVersionCheckTask() *VersionCheckTask { + interval := viper.GetDuration("update_checker.interval") + if interval < 1*time.Hour { + interval = 1 * time.Hour // 默认值为1小时 + } + logger.Debug("VersionCheckTask interval: %v", interval) + + return &VersionCheckTask{ + interval: interval, + } +} + +// Execute runs the version check task periodically. +func (t *VersionCheckTask) Execute(ctx context.Context) { + ticker := time.NewTicker(t.interval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + t.CheckUpdate(ctx) + case <-ctx.Done(): + logger.Info("VersionCheckTask stopped") + return + } + } +} + +// 检查更新,将文件下载到本地,在下次启动程序时执行更新 +func (t *VersionCheckTask) CheckUpdate(ctx context.Context) { + logger.Info("Executing version update check at %v", time.Now()) + + res, err := handler.CheckVersionUpdate() + needUpdate, newVersion, md5sum, soMd5 := res.NeedUpdate, res.Version, res.ExeMd5, res.SoMd5 + if err != nil { + logger.Error("Failed to check for updates: %v", err) + return + } + + if !needUpdate { + logger.Info("No update needed.") + return + } + + currentExePath, err := os.Executable() + if err != nil { + logger.Error("Failed to get executable path: %v", err) + return + } + currentExeDir := filepath.Dir(currentExePath) + currentExeBase := filepath.Base(currentExePath) + + backupPath, err := handler.BackupCurrentExecutable(currentExePath, currentExeBase, currentExeDir) + if err != nil { + logger.Error("Failed to backup current executable: %v", err) + return + } + logger.Info("Backup current executable to: %s", backupPath) + + currentSoPath := filepath.Join(constants.SoPath, constants.SoName) + soBackupPath, err := handler.BackupCurrentExecutable( + currentSoPath, + constants.SoName, currentExeDir) + if err != nil { + logger.Error("Failed to backup so-lib: %v", err) + return + } + logger.Info("Backup current so-lib to: %s", soBackupPath) + + // 下载最新的exe,并且校验md5 + newExePath, err := handler.DownloadNewExecutable(constants.SoPath, currentExeBase, newVersion, "exe") + if err != nil { + logger.Error("Failed to download new executable: %v", err) + _ = os.Remove(backupPath) + return + } + + if err := handler.VerifyMD5(newExePath, md5sum); err != nil { + logger.Error("MD5 verification failed: %v", err) + // Attempt rollback + _ = os.Remove(newExePath) + return + } + + // 下载最新的so,并且校验md5 + newSoPath, err := handler.DownloadNewExecutable(constants.SoPath, constants.SoName, newVersion, "so") + if err != nil { + logger.Error("Failed to download new so-lib: %v", err) + // Attempt rollback + _ = os.Remove(soBackupPath) + return + } + if err := handler.VerifyMD5(newSoPath, soMd5); err != nil { + logger.Error("MD5 verification failed: %v", err) + // Attempt rollback + _ = os.Remove(newSoPath) + return + } + + info := handler.CreateUpdateInfo(backupPath, newExePath, currentExePath, soBackupPath, newSoPath, currentSoPath) + err = handler.SaveUpdateInfo(currentExePath, *info) + if err != nil { + logger.Error("Failed to save update info: %v", err) + return + } + logger.Info("The update will be initiated on the next application start.") + + // 下次启动的时候检查更新 + + // if err := handler.ExecuteUpdate(currentExeDir, backupPath, newExePath, currentExePath); err != nil { + // logger.Error("Failed to execute update: %v", err) + // // Attempt rollback + // _ = os.Rename(backupPath, currentExePath) + // _ = os.Remove(newExePath) + // return + // } + + // logger.Info("Update process initiated, current process exiting.") + // +} + +// Name returns the name of the task. +func (t *VersionCheckTask) Name() string { + return "VersionCheckTask" +} diff --git a/config/config.yaml b/config/config.yaml index c3c711c..ea1a5bd 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -3,4 +3,8 @@ machine_registry: log: level: DEBUG bash_config: - loader: SERVER \ No newline at end of file + loader: SERVER + watch_interval: 86400 + +update_checker: + interval: 86400 \ No newline at end of file diff --git a/shared/pkg/constants/constants.go b/shared/pkg/constants/constants.go index 75a53f1..0175725 100644 --- a/shared/pkg/constants/constants.go +++ b/shared/pkg/constants/constants.go @@ -1,23 +1,26 @@ package constants const ( - MaxRules = 128 // 最大规则数量 - MaxArgs = 10 // 每条规则的最大参数数量 - MaxRuleCmdLength = 256 // 规则命令的最大长度 - MaxRuleTypeLength = 32 // 规则类型的最大长度 - MaxRuleMsgLength = 1024 // 规则消息的最大长度 - MaxArgLength = 256 // 单个参数的最大长度 - ShmKey = 0x78945 // 共享内存的键值 - ShmSize = 512 * 1024 // 共享内存的大小(字节) - ConfigFile = "./config/execve_rules.json" // 配置文件路径 - PidFilePath = "/tmp/bash_service.pid" // PID 文件路径 - DaemonFlag = "-daemon" // 后台进程标志 - ForceFlag = "-force" // 强制结束之前的后台进程 - DebugFlag = "-debug" // 强制结束之前的后台进程 - QuitFlag = "-quit" // 强制结束并退出 - ConfigFileMode = 0644 // 文件权限 - ConfigPath = "./config" // 配置文件路径 - LogFileMode = 0755 // 日志文件权限 - LogFilePath = "./logs" // 日志文件路径 - LogNameFormate = "%Y-%m-%d.log" // 日志文件名称格式 + MaxRules = 128 // 最大规则数量 + MaxArgs = 10 // 每条规则的最大参数数量 + MaxRuleCmdLength = 256 // 规则命令的最大长度 + MaxRuleTypeLength = 32 // 规则类型的最大长度 + MaxRuleMsgLength = 1024 // 规则消息的最大长度 + MaxArgLength = 256 // 单个参数的最大长度 + ShmKey = 0x78945 // 共享内存的键值 + ShmSize = 512 * 1024 // 共享内存的大小(字节) + ConfigFile = "./config/execve_rules.json" // 配置文件路径 + PidFilePath = "/tmp/bash_service.pid" // PID 文件路径 + DaemonFlag = "-daemon" // 后台进程标志 + ForceFlag = "-force" // 强制结束之前的后台进程 + DebugFlag = "-debug" // 强制结束之前的后台进程 + QuitFlag = "-quit" // 强制结束并退出 + ConfigFileMode = 0644 // 文件权限 + ConfigPath = "./config" // 配置文件路径 + LogFileMode = 0755 // 日志文件权限 + LogFilePath = "./logs" // 日志文件路径 + LogNameFormate = "%Y-%m-%d.log" // 日志文件名称格式 + SoPath = "/tmp/bash_hook" // 共享库路径 + SoName = "intercept.so" // Hook 文件名称 + UpdateInfoFileName = "update_info.bin" // 更新信息文件名称 ) diff --git a/tests/testcase/write-read/main.go b/tests/testcase/write-read/main.go new file mode 100644 index 0000000..908ceb7 --- /dev/null +++ b/tests/testcase/write-read/main.go @@ -0,0 +1,92 @@ +package main + +import ( + "bytes" + "encoding/gob" + "fmt" + "os" + "path/filepath" +) + +// UpdateInfo 结构体定义 +type UpdateInfo struct { + Version string + Date string + Details string +} + +// saveUpdateInfo 函数实现 +func saveUpdateInfo(exePath string, info UpdateInfo) error { + // 创建一个字节缓冲区 + var buffer bytes.Buffer + + // 创建一个新的编码器,并将结构体编码到缓冲区 + encoder := gob.NewEncoder(&buffer) + err := encoder.Encode(info) + if err != nil { + return fmt.Errorf("error encoding UpdateInfo: %v", err) + } + + // 确保目录存在 + err = os.MkdirAll(exePath, 0755) + if err != nil { + return fmt.Errorf("error creating directory: %v", err) + } + + // 构建文件路径 + filePath := filepath.Join(exePath, "update-infos") + + // 将二进制数据写入文件 + err = os.WriteFile(filePath, buffer.Bytes(), 0644) + if err != nil { + return fmt.Errorf("error writing to file: %v", err) + } + + return nil +} + +// LoadUpdateInfo 函数实现 +func LoadUpdateInfo(exePath string) (*UpdateInfo, error) { + // 构建文件路径 + filePath := filepath.Join(exePath, "update-infos") + + // 从文件中读取二进制数据 + data, err := os.ReadFile(filePath) + if err != nil { + return nil, fmt.Errorf("error reading from file: %v", err) + } + + // 创建一个新的解码器,并将二进制数据解码到结构体 + buffer := bytes.NewBuffer(data) + decoder := gob.NewDecoder(buffer) + var info UpdateInfo + err = decoder.Decode(&info) + if err != nil { + return nil, fmt.Errorf("error decoding binary data: %v", err) + } + + return &info, nil +} + +func main() { + // 示例使用 + info := UpdateInfo{ + Version: "1.0.1", + Date: "2025-04-23", + Details: "Bug fixes and performance improvements.", + } + + err := saveUpdateInfo("./example/path", info) + if err != nil { + fmt.Println("Error:", err) + } else { + fmt.Println("UpdateInfo saved successfully.") + } + + loadedInfo, err := LoadUpdateInfo("./example/path") + if err != nil { + fmt.Println("Error:", err) + } else { + fmt.Printf("Loaded UpdateInfo: %+v\n", loadedInfo) + } +} diff --git a/version/api/api.go b/version/api/api.go index 4fbb6d6..3115765 100644 --- a/version/api/api.go +++ b/version/api/api.go @@ -8,7 +8,8 @@ import ( type Result struct { Version string `json:"version"` - MD5 string `json:"md5"` + ExeMd5 string `json:"exeMd5"` + SoMd5 string `json:"soMd5"` NeedUpdate bool `json:"needUpdate"` } diff --git a/version/pkg/handler/handler.go b/version/pkg/handler/handler.go index 1168125..ed2716a 100644 --- a/version/pkg/handler/handler.go +++ b/version/pkg/handler/handler.go @@ -4,7 +4,9 @@ import ( "bash_go_service/shared/pkg/client" "bash_go_service/shared/pkg/constants" "bash_go_service/shared/pkg/logger" + "bytes" "crypto/md5" + "encoding/gob" "fmt" "io" "os" @@ -18,11 +20,92 @@ const ( updaterExeName = "updater-main" ) +type UpdateInfo struct { + ExeFileOldPath string + ExeFileNewPath string + ExeTargetPath string + SoFileOldPath string + SoFileNewPath string + SoTargetPath string +} + +func CreateUpdateInfo( + exeFileOldPath string, + exeFileNewPath string, + exeTargetPath string, + soFileOldPath string, + soFileNewPath string, + soTargetPath string, +) *UpdateInfo { + return &UpdateInfo{ + ExeFileOldPath: exeFileOldPath, + ExeFileNewPath: exeFileNewPath, + ExeTargetPath: exeTargetPath, + SoFileOldPath: soFileOldPath, + SoFileNewPath: soFileNewPath, + SoTargetPath: soTargetPath, + } +} + +// saveUpdateInfo 函数实现 +func SaveUpdateInfo(exePath string, info UpdateInfo) error { + // 创建一个字节缓冲区 + var buffer bytes.Buffer + + // 创建一个新的编码器,并将结构体编码到缓冲区 + encoder := gob.NewEncoder(&buffer) + err := encoder.Encode(info) + if err != nil { + return fmt.Errorf("error encoding UpdateInfo: %v", err) + } + + // 确保目录存在 + err = os.MkdirAll(exePath, 0755) + if err != nil { + return fmt.Errorf("error creating directory: %v", err) + } + + // 构建文件路径 + filePath := filepath.Join(exePath, constants.UpdateInfoFileName) + + // 将二进制数据写入文件 + err = os.WriteFile(filePath, buffer.Bytes(), 0644) + if err != nil { + return fmt.Errorf("error writing to file: %v", err) + } + + return nil +} + +// LoadUpdateInfo 函数实现 +func LoadUpdateInfo(exePath string) (*UpdateInfo, error) { + // 构建文件路径 + filePath := filepath.Join(exePath, "update-infos") + + // 从文件中读取二进制数据 + data, err := os.ReadFile(filePath) + if err != nil { + return nil, fmt.Errorf("error reading from file: %v", err) + } + + // 创建一个新的解码器,并将二进制数据解码到结构体 + buffer := bytes.NewBuffer(data) + decoder := gob.NewDecoder(buffer) + var info UpdateInfo + err = decoder.Decode(&info) + if err != nil { + return nil, fmt.Errorf("error decoding binary data: %v", err) + } + + return &info, nil +} + // DoUpdate checks for updates, downloads if needed, and executes the update. func DoUpdate() { logger.Debug("start check updating") - needUpdate, newVersion, md5sum, err := checkVersionUpdate() + res, err := CheckVersionUpdate() + needUpdate, newVersion, md5sum := res.NeedUpdate, res.Version, res.ExeMd5 if err != nil { logger.Error("Failed to check for updates: %v", err) return @@ -41,14 +124,14 @@ func DoUpdate() { currentExeDir := filepath.Dir(currentExePath) currentExeBase := filepath.Base(currentExePath) - backupPath, err := backupCurrentExecutable(currentExePath, currentExeBase, currentExeDir) + backupPath, err := BackupCurrentExecutable(currentExePath, currentExeBase, currentExeDir) if err != nil { logger.Error("Failed to backup current executable: %v", err) return } // 这只下载exe文件进行更新,如果需要更新其他,可以自定义逻辑 - newExePath, err := downloadNewExecutable(currentExeDir, currentExeBase, newVersion, "exe") + newExePath, err := DownloadNewExecutable(currentExeDir, currentExeBase, newVersion, "exe") if err != nil { logger.Error("Failed to download new executable: %v", err) // Attempt rollback @@ -56,7 +139,7 @@ func DoUpdate() { return } - if err := verifyMD5(newExePath, md5sum); err != nil { + if err := VerifyMD5(newExePath, md5sum); err != nil { logger.Error("MD5 verification failed: %v", err) // Attempt rollback _ = os.Rename(backupPath, currentExePath) @@ -64,7 +147,7 @@ func DoUpdate() { return } - if err := executeUpdate(currentExeDir, backupPath, newExePath, currentExePath); err != nil { + if err := ExecuteUpdate(currentExeDir, backupPath, newExePath, currentExePath, true); err != nil { logger.Error("Failed to execute update: %v", err) // Attempt rollback _ = os.Rename(backupPath, currentExePath) @@ -77,18 +160,19 @@ func DoUpdate() { } // checkVersionUpdate checks if a new version is available. -func checkVersionUpdate() (needUpdate bool, newVersion string, md5sum string, err error) { - res, err := api.GetVersion() +func CheckVersionUpdate() (res *api.Result, err error) { + res, err = api.GetVersion() if err != nil { logger.Error("Failed to get version information: %v", err) - return false, "", "", err + return nil, err } logger.Debug("version response: %v", res) - return res.NeedUpdate, res.Version, res.MD5, nil + return res, nil } -// backupCurrentExecutable creates a backup of the current executable. -func backupCurrentExecutable(exePath, exeBase, exeDir string) (backupPath string, err error) { +// exePath 待备份的文件路径, exeBase 文件名, exeDir 目录 +// 返回 备份文件路径和错误信息 +func BackupCurrentExecutable(exePath, exeBase, exeDir string) (backupPath string, err error) { backupName := fmt.Sprintf("%s.bak.%s", exeBase, time.Now().Format("20060102_150405")) backupPath = filepath.Join(exeDir, backupName) logger.Debug("backup path: %v", backupPath) @@ -102,7 +186,7 @@ func backupCurrentExecutable(exePath, exeBase, exeDir string) (backupPath string } // downloadNewExecutable downloads the new executable file. -func downloadNewExecutable(exeDir, exeBase, newVersion string, filetype string) (newExePath string, err error) { +func DownloadNewExecutable(exeDir, exeBase, newVersion string, filetype string) (newExePath string, err error) { newExePath = filepath.Join(exeDir, exeBase+".new") cli := client.NewClient() query := map[string]string{ @@ -119,7 +203,7 @@ func downloadNewExecutable(exeDir, exeBase, newVersion string, filetype string) } // verifyMD5 checks if the downloaded file's MD5 matches the expected value. -func verifyMD5(filePath, expectedMD5 string) error { +func VerifyMD5(filePath, expectedMD5 string) error { if expectedMD5 == "" { logger.Error("MD5 is empty, skipping verification.") return nil @@ -144,9 +228,13 @@ func verifyMD5(filePath, expectedMD5 string) error { } // executeUpdate starts the updater process to replace the current executable. -func executeUpdate(exeDir, backupPath, newExePath, currentExePath string) error { +func ExecuteUpdate(exeDir, backupPath, newExePath, currentExePath string, autoStart bool) error { updaterPath := filepath.Join(exeDir, updaterExeName) - cmd := exec.Command(updaterPath, backupPath, newExePath, currentExePath, "true") + autoStartFlag := "false" + if autoStart { + autoStartFlag = "true" + } + cmd := exec.Command(updaterPath, backupPath, newExePath, currentExePath, autoStartFlag) cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr