package handler import ( "bash_go_service/shared/pkg/client" "bash_go_service/shared/pkg/constants" "bash_go_service/shared/pkg/logger" "bytes" "crypto/md5" "encoding/gob" "fmt" "io" "os" "os/exec" "path/filepath" "time" "version/api" ) const ( updaterExeName = "updater-main" ) type UpdateInfo struct { ExeFileOldPath string ExeFileNewPath string ExeTargetPath string SoFileOldPath string SoFileNewPath string SoTargetPath string } func CreateUpdateInfo( exeFileOldPath string, exeFileNewPath string, exeTargetPath string, soFileOldPath string, soFileNewPath string, soTargetPath string, ) *UpdateInfo { return &UpdateInfo{ ExeFileOldPath: exeFileOldPath, ExeFileNewPath: exeFileNewPath, ExeTargetPath: exeTargetPath, SoFileOldPath: soFileOldPath, SoFileNewPath: soFileNewPath, SoTargetPath: soTargetPath, } } // saveUpdateInfo 函数实现 func SaveUpdateInfo(exePath string, info UpdateInfo) error { // 创建一个字节缓冲区 var buffer bytes.Buffer // 创建一个新的编码器,并将结构体编码到缓冲区 encoder := gob.NewEncoder(&buffer) err := encoder.Encode(info) if err != nil { return fmt.Errorf("error encoding UpdateInfo: %v", err) } // 确保目录存在 err = os.MkdirAll(exePath, 0755) if err != nil { return fmt.Errorf("error creating directory: %v", err) } // 构建文件路径 filePath := filepath.Join(exePath, constants.UpdateInfoFileName) // 将二进制数据写入文件 err = os.WriteFile(filePath, buffer.Bytes(), 0644) if err != nil { return fmt.Errorf("error writing to file: %v", err) } return nil } // LoadUpdateInfo 函数实现 func LoadUpdateInfo(exePath string) (*UpdateInfo, error) { // 构建文件路径 filePath := filepath.Join(exePath, "update-infos") // 从文件中读取二进制数据 data, err := os.ReadFile(filePath) if err != nil { return nil, fmt.Errorf("error reading from file: %v", err) } // 创建一个新的解码器,并将二进制数据解码到结构体 buffer := bytes.NewBuffer(data) decoder := gob.NewDecoder(buffer) var info UpdateInfo err = decoder.Decode(&info) if err != nil { return nil, fmt.Errorf("error decoding binary data: %v", err) } return &info, nil } // DoUpdate checks for updates, downloads if needed, and executes the update. func DoUpdate() { logger.Debug("start check updating") res, err := CheckVersionUpdate() needUpdate, newVersion, md5sum := res.NeedUpdate, res.Version, res.ExeMd5 if err != nil { logger.Error("Failed to check for updates: %v", err) return } if !needUpdate { logger.Info("No update needed.") return } currentExePath, err := os.Executable() if err != nil { logger.Error("Failed to get executable path: %v", err) return } currentExeDir := filepath.Dir(currentExePath) currentExeBase := filepath.Base(currentExePath) backupPath, err := BackupCurrentExecutable(currentExePath, currentExeBase, currentExeDir) if err != nil { logger.Error("Failed to backup current executable: %v", err) return } // 这只下载exe文件进行更新,如果需要更新其他,可以自定义逻辑 newExePath, err := DownloadNewExecutable(currentExeDir, currentExeBase, newVersion, "exe") if err != nil { logger.Error("Failed to download new executable: %v", err) // Attempt rollback _ = os.Rename(backupPath, currentExePath) return } if err := VerifyMD5(newExePath, md5sum); err != nil { logger.Error("MD5 verification failed: %v", err) // Attempt rollback _ = os.Rename(backupPath, currentExePath) _ = os.Remove(newExePath) return } if err := ExecuteUpdate(currentExeDir, backupPath, newExePath, currentExePath, true); err != nil { logger.Error("Failed to execute update: %v", err) // Attempt rollback _ = os.Rename(backupPath, currentExePath) _ = os.Remove(newExePath) return } logger.Info("Update process initiated, current process exiting.") os.Exit(0) } // checkVersionUpdate checks if a new version is available. func CheckVersionUpdate() (res *api.Result, err error) { res, err = api.GetVersion() if err != nil { logger.Error("Failed to get version information: %v", err) return nil, err } logger.Debug("version response: %v", res) return res, nil } // exePath 待备份的文件路径, exeBase 文件名, exeDir 目录 // 返回 备份文件路径和错误信息 func BackupCurrentExecutable(exePath, exeBase, exeDir string) (backupPath string, err error) { backupName := fmt.Sprintf("%s.bak.%s", exeBase, time.Now().Format("20060102_150405")) backupPath = filepath.Join(exeDir, backupName) logger.Debug("backup path: %v", backupPath) err = copyFile(exePath, backupPath) if err != nil { logger.Error("Failed to backup current executable: %v", err) return "", err } logger.Info("Backup created: %s", backupPath) return backupPath, nil } // downloadNewExecutable downloads the new executable file. func DownloadNewExecutable(exeDir, exeBase, newVersion string, filetype string) (newExePath string, err error) { newExePath = filepath.Join(exeDir, exeBase+".new") cli := client.NewClient() query := map[string]string{ "version": newVersion, "filetype": filetype, } err = cli.Download(constants.DownloadNewApi, query, newExePath) if err != nil { logger.Error("Download failed: %v", err) return "", err } logger.Info("New executable downloaded to: %s", newExePath) return newExePath, nil } // verifyMD5 checks if the downloaded file's MD5 matches the expected value. func VerifyMD5(filePath, expectedMD5 string) error { if expectedMD5 == "" { logger.Error("MD5 is empty, skipping verification.") return nil } match, err := checkFileMD5(filePath, expectedMD5) if err != nil { logger.Error("Failed to check MD5: %v", err) return err } if !match { logger.Error("MD5 mismatch: expected %s", expectedMD5) return fmt.Errorf("MD5 mismatch: expected %s", expectedMD5) } logger.Info("MD5 checksum verified.") err = os.Chmod(filePath, 0755) if err != nil { logger.Error("Failed to set executable permission for new file: %v", err) return err } logger.Info("Executable permission set for new file.") return nil } // executeUpdate starts the updater process to replace the current executable. func ExecuteUpdate(exeDir, backupPath, newExePath, currentExePath string, autoStart bool) error { updaterPath := filepath.Join(exeDir, updaterExeName) autoStartFlag := "false" if autoStart { autoStartFlag = "true" } cmd := exec.Command(updaterPath, backupPath, newExePath, currentExePath, autoStartFlag) cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr err := cmd.Start() if err != nil { logger.Error("Failed to start updater: %v", err) return err } return nil } // checkFileMD5 calculates the MD5 hash of a file and compares it with the expected MD5. func checkFileMD5(filePath, expectedMD5 string) (bool, error) { f, err := os.Open(filePath) if err != nil { return false, err } defer f.Close() hash := md5.New() if _, err := io.Copy(hash, f); err != nil { return false, err } actualMD5 := fmt.Sprintf("%x", hash.Sum(nil)) logger.Debug("actual MD5: %v", actualMD5) return actualMD5 == expectedMD5, nil } // copyFile copies a file from source to destination. func copyFile(src, dst string) error { sourceFile, err := os.Open(src) if err != nil { return err } defer sourceFile.Close() destFile, err := os.Create(dst) if err != nil { return err } defer destFile.Close() _, err = io.Copy(destFile, sourceFile) return err }