package handler

import (
	"bytes"
	"crypto/md5"
	"encoding/gob"
	"fmt"
	"io"
	"os"
	"os/exec"
	"path/filepath"
	"time"

	"bash_go_service/shared/pkg/client"
	"bash_go_service/shared/pkg/constants"
	"bash_go_service/shared/pkg/logger"
	"version/api"
)

const (
	// updaterExeName is the file name of the external updater binary that is
	// launched from the current executable's directory.
	updaterExeName = "updater-main"
)

// UpdateInfo is the record persisted (gob-encoded) by SaveUpdateInfo and read
// back by LoadUpdateInfo. For the executable ("Exe") and, presumably, a shared
// library ("So"), it holds the old file path, the new file path and the
// target path.
//
// NOTE(review): the precise meaning of Old/New/Target is defined by whoever
// produces and consumes this record — confirm against the updater.
type UpdateInfo struct {
	ExeFileOldPath string
	ExeFileNewPath string
	ExeTargetPath  string
	SoFileOldPath  string
	SoFileNewPath  string
	SoTargetPath   string
}

// CreateUpdateInfo returns a new UpdateInfo populated with the given paths.
func CreateUpdateInfo(
	exeFileOldPath string,
	exeFileNewPath string,
	exeTargetPath string,
	soFileOldPath string,
	soFileNewPath string,
	soTargetPath string,
) *UpdateInfo {
	return &UpdateInfo{
		ExeFileOldPath: exeFileOldPath,
		ExeFileNewPath: exeFileNewPath,
		ExeTargetPath:  exeTargetPath,
		SoFileOldPath:  soFileOldPath,
		SoFileNewPath:  soFileNewPath,
		SoTargetPath:   soTargetPath,
	}
}

// SaveUpdateInfo gob-encodes info and writes it to the file
// constants.UpdateInfoFileName in the parent of exePath
// (filepath.Join(exePath, "..")), i.e. the directory containing exePath when
// exePath names a file. The directory is created if missing and an existing
// file is overwritten. Returned errors wrap the underlying cause (%w).
func SaveUpdateInfo(exePath string, info UpdateInfo) error {
	// Encode before touching the filesystem so an encoding failure never
	// leaves a truncated file behind.
	var buf bytes.Buffer
	if err := gob.NewEncoder(&buf).Encode(info); err != nil {
		return fmt.Errorf("error encoding UpdateInfo: %w", err)
	}

	if err := os.MkdirAll(updateInfoDir(exePath), 0755); err != nil {
		return fmt.Errorf("error creating directory: %w", err)
	}

	if err := os.WriteFile(updateInfoFilePath(exePath), buf.Bytes(), 0644); err != nil {
		return fmt.Errorf("error writing to file: %w", err)
	}
	return nil
}

// LoadUpdateInfo reads and gob-decodes the file written by SaveUpdateInfo for
// the same exePath. Returned errors wrap the underlying cause (%w), so callers
// can use errors.Is(err, fs.ErrNotExist) to detect a missing file.
func LoadUpdateInfo(exePath string) (*UpdateInfo, error) {
	data, err := os.ReadFile(updateInfoFilePath(exePath))
	if err != nil {
		return nil, fmt.Errorf("error reading from file: %w", err)
	}

	var info UpdateInfo
	if err := gob.NewDecoder(bytes.NewReader(data)).Decode(&info); err != nil {
		return nil, fmt.Errorf("error decoding binary data: %w", err)
	}
	return &info, nil
}

// RemoveUpdateInfo deletes the file written by SaveUpdateInfo for the same
// exePath. A missing file is reported as an error that wraps (%w) the
// underlying *PathError.
func RemoveUpdateInfo(exePath string) error {
	// Resolve the path exactly like SaveUpdateInfo/LoadUpdateInfo (parent of
	// exePath); joining the file name directly onto exePath pointed at a
	// different file than the one that was saved.
	if err := os.Remove(updateInfoFilePath(exePath)); err != nil {
		return fmt.Errorf("error removing file: %w", err)
	}
	return nil
}

// DoUpdate checks for updates, downloads if needed, and executes the update.
//
// NOTE(review): this flow is currently disabled. Before re-enabling it, note
// that the last argument of DownloadNewExecutable is the download URL, so the
// "exe" placeholder below must be replaced with the real URL.
//
// func DoUpdate() {
// 	logger.Debug("start check updating")
// 	res, err := CheckVersionUpdate()
// 	if err != nil {
// 		logger.Error("Failed to check for updates: %v", err)
// 		return
// 	}
// 	// Read res only after err is checked: res is nil on failure.
// 	needUpdate, newVersion, md5sum := res.NeedUpdate, res.Version, res.ExeMd5
// 	if !needUpdate {
// 		logger.Info("No update needed.")
// 		return
// 	}
// 	currentExePath, err := os.Executable()
// 	if err != nil {
// 		logger.Error("Failed to get executable path: %v", err)
// 		return
// 	}
// 	currentExeDir := filepath.Dir(currentExePath)
// 	currentExeBase := filepath.Base(currentExePath)
// 	backupPath, err := BackupCurrentExecutable(currentExePath, currentExeBase, currentExeDir)
// 	if err != nil {
// 		logger.Error("Failed to backup current executable: %v", err)
// 		return
// 	}
// 	// Only the exe file is downloaded here; customize this logic if other
// 	// files need updating.
// 	newExePath, err := DownloadNewExecutable(currentExeDir, currentExeBase, newVersion, "exe")
// 	if err != nil {
// 		logger.Error("Failed to download new executable: %v", err)
// 		// Attempt rollback
// 		_ = os.Rename(backupPath, currentExePath)
// 		return
// 	}
// 	if err := VerifyMD5(newExePath, md5sum); err != nil {
// 		logger.Error("MD5 verification failed: %v", err)
// 		// Attempt rollback
// 		_ = os.Rename(backupPath, currentExePath)
// 		_ = os.Remove(newExePath)
// 		return
// 	}
// 	if err := ExecuteUpdate(currentExeDir, backupPath, newExePath, currentExePath, true); err != nil {
// 		logger.Error("Failed to execute update: %v", err)
// 		// Attempt rollback
// 		_ = os.Rename(backupPath, currentExePath)
// 		_ = os.Remove(newExePath)
// 		return
// 	}
// 	logger.Info("Update process initiated, current process exiting.")
// 	os.Exit(0)
// }

// updateInfoDir returns the directory that holds the update-info file for
// exePath: its parent, filepath.Join(exePath, "..").
func updateInfoDir(exePath string) string {
	return filepath.Join(exePath, "..")
}

// updateInfoFilePath returns the full path of the update-info file for exePath.
func updateInfoFilePath(exePath string) string {
	return filepath.Join(updateInfoDir(exePath), constants.UpdateInfoFileName)
}
checkVersionUpdate checks if a new version is available. func CheckVersionUpdate() (res *api.Result, err error) { res, err = api.GetVersion() if err != nil { logger.Error("Failed to get version information: %v", err) return nil, err } return res, nil } // exePath 待备份的文件路径, exeBase 文件名, exeDir 目录 // 返回 备份文件路径和错误信息 func BackupCurrentExecutable(exePath, exeBase, exeDir string) (backupPath string, err error) { backupName := fmt.Sprintf("%s.bak.%s", exeBase, time.Now().Format("20060102_150405")) backupPath = filepath.Join(exeDir, backupName) logger.Debug("backup path: %v", backupPath) err = copyFile(exePath, backupPath) if err != nil { logger.Error("Failed to backup current executable: %v", err) return "", err } logger.Info("Backup created: %s", backupPath) return backupPath, nil } // downloadNewExecutable downloads the new executable file. func DownloadNewExecutable(exeDir, exeBase, newVersion string, url string) (newExePath string, err error) { logger.Debug("downloading new executable from: %s", url) // 这里的url是api的地址 newExePath = filepath.Join(exeDir, exeBase+".new") cli := client.NewClient(client.WithBaseURL(url)) err = cli.Download("", nil, newExePath) if err != nil { logger.Error("Download failed: %v", err) return "", err } logger.Info("New executable downloaded to: %s", newExePath) return newExePath, nil } // verifyMD5 checks if the downloaded file's MD5 matches the expected value. 
func VerifyMD5(filePath, expectedMD5 string) error { if expectedMD5 == "" { logger.Error("MD5 is empty, skipping verification.") return nil } match, err := checkFileMD5(filePath, expectedMD5) if err != nil { logger.Error("Failed to check MD5: %v", err) return err } if !match { logger.Error("MD5 mismatch: expected %s", expectedMD5) return fmt.Errorf("MD5 mismatch: expected %s", expectedMD5) } logger.Info("MD5 checksum verified.") err = os.Chmod(filePath, 0755) if err != nil { logger.Error("Failed to set executable permission for new file: %v", err) return err } logger.Info("Executable permission set for new file.") return nil } // executeUpdate starts the updater process to replace the current executable. func ExecuteUpdate(exeDir, backupPath, newExePath, currentExePath string, autoStart bool) error { updaterPath := filepath.Join(exeDir, updaterExeName) autoStartFlag := "false" if autoStart { autoStartFlag = "true" } cmd := exec.Command(updaterPath, backupPath, newExePath, currentExePath, autoStartFlag) cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr err := cmd.Start() if err != nil { logger.Error("Failed to start updater: %v", err) return err } return nil } // checkFileMD5 calculates the MD5 hash of a file and compares it with the expected MD5. func checkFileMD5(filePath, expectedMD5 string) (bool, error) { f, err := os.Open(filePath) if err != nil { return false, err } defer f.Close() hash := md5.New() if _, err := io.Copy(hash, f); err != nil { return false, err } actualMD5 := fmt.Sprintf("%x", hash.Sum(nil)) logger.Debug("actual MD5: %v", actualMD5) return actualMD5 == expectedMD5, nil } // copyFile copies a file from source to destination. func copyFile(src, dst string) error { sourceFile, err := os.Open(src) if err != nil { return err } defer sourceFile.Close() destFile, err := os.Create(dst) if err != nil { return err } defer destFile.Close() _, err = io.Copy(destFile, sourceFile) return err }