195 lines
5.3 KiB
Go
195 lines
5.3 KiB
Go
package handler
|
||
|
||
import (
	"crypto/md5"
	"fmt"
	"io"
	"os"
	"os/exec"
	"path/filepath"
	"strings"
	"time"

	"bash_go_service/shared/pkg/client"
	"bash_go_service/shared/pkg/constants"
	"bash_go_service/shared/pkg/logger"
	"version/api"
)
|
||
|
||
// updaterExeName is the file name of the helper binary that executeUpdate
// launches from the directory of the running executable.
const updaterExeName = "updater-main"
|
||
|
||
// DoUpdate checks for updates, downloads if needed, and executes the update.
|
||
func DoUpdate() {
|
||
logger.Debug("start check updating")
|
||
|
||
needUpdate, newVersion, md5sum, err := checkVersionUpdate()
|
||
if err != nil {
|
||
logger.Error("Failed to check for updates: %v", err)
|
||
return
|
||
}
|
||
|
||
if !needUpdate {
|
||
logger.Info("No update needed.")
|
||
return
|
||
}
|
||
|
||
currentExePath, err := os.Executable()
|
||
if err != nil {
|
||
logger.Error("Failed to get executable path: %v", err)
|
||
return
|
||
}
|
||
currentExeDir := filepath.Dir(currentExePath)
|
||
currentExeBase := filepath.Base(currentExePath)
|
||
|
||
backupPath, err := backupCurrentExecutable(currentExePath, currentExeBase, currentExeDir)
|
||
if err != nil {
|
||
logger.Error("Failed to backup current executable: %v", err)
|
||
return
|
||
}
|
||
|
||
// 这只下载exe文件进行更新,如果需要更新其他,可以自定义逻辑
|
||
newExePath, err := downloadNewExecutable(currentExeDir, currentExeBase, newVersion, "exe")
|
||
if err != nil {
|
||
logger.Error("Failed to download new executable: %v", err)
|
||
// Attempt rollback
|
||
_ = os.Rename(backupPath, currentExePath)
|
||
return
|
||
}
|
||
|
||
if err := verifyMD5(newExePath, md5sum); err != nil {
|
||
logger.Error("MD5 verification failed: %v", err)
|
||
// Attempt rollback
|
||
_ = os.Rename(backupPath, currentExePath)
|
||
_ = os.Remove(newExePath)
|
||
return
|
||
}
|
||
|
||
if err := executeUpdate(currentExeDir, backupPath, newExePath, currentExePath); err != nil {
|
||
logger.Error("Failed to execute update: %v", err)
|
||
// Attempt rollback
|
||
_ = os.Rename(backupPath, currentExePath)
|
||
_ = os.Remove(newExePath)
|
||
return
|
||
}
|
||
|
||
logger.Info("Update process initiated, current process exiting.")
|
||
os.Exit(0)
|
||
}
|
||
|
||
// checkVersionUpdate checks if a new version is available.
|
||
func checkVersionUpdate() (needUpdate bool, newVersion string, md5sum string, err error) {
|
||
res, err := api.GetVersion()
|
||
if err != nil {
|
||
logger.Error("Failed to get version information: %v", err)
|
||
return false, "", "", err
|
||
}
|
||
logger.Debug("version response: %v", res)
|
||
return res.NeedUpdate, res.Version, res.MD5, nil
|
||
}
|
||
|
||
// backupCurrentExecutable creates a backup of the current executable.
|
||
func backupCurrentExecutable(exePath, exeBase, exeDir string) (backupPath string, err error) {
|
||
backupName := fmt.Sprintf("%s.bak.%s", exeBase, time.Now().Format("20060102_150405"))
|
||
backupPath = filepath.Join(exeDir, backupName)
|
||
logger.Debug("backup path: %v", backupPath)
|
||
err = copyFile(exePath, backupPath)
|
||
if err != nil {
|
||
logger.Error("Failed to backup current executable: %v", err)
|
||
return "", err
|
||
}
|
||
logger.Info("Backup created: %s", backupPath)
|
||
return backupPath, nil
|
||
}
|
||
|
||
// downloadNewExecutable downloads the new executable file.
|
||
func downloadNewExecutable(exeDir, exeBase, newVersion string, filetype string) (newExePath string, err error) {
|
||
newExePath = filepath.Join(exeDir, exeBase+".new")
|
||
cli := client.NewClient()
|
||
query := map[string]string{
|
||
"version": newVersion,
|
||
"filetype": filetype,
|
||
}
|
||
err = cli.Download(constants.DownloadNewApi, query, newExePath)
|
||
if err != nil {
|
||
logger.Error("Download failed: %v", err)
|
||
return "", err
|
||
}
|
||
logger.Info("New executable downloaded to: %s", newExePath)
|
||
return newExePath, nil
|
||
}
|
||
|
||
// verifyMD5 checks if the downloaded file's MD5 matches the expected value.
|
||
func verifyMD5(filePath, expectedMD5 string) error {
|
||
if expectedMD5 == "" {
|
||
logger.Error("MD5 is empty, skipping verification.")
|
||
return nil
|
||
}
|
||
match, err := checkFileMD5(filePath, expectedMD5)
|
||
if err != nil {
|
||
logger.Error("Failed to check MD5: %v", err)
|
||
return err
|
||
}
|
||
if !match {
|
||
logger.Error("MD5 mismatch: expected %s", expectedMD5)
|
||
return fmt.Errorf("MD5 mismatch: expected %s", expectedMD5)
|
||
}
|
||
logger.Info("MD5 checksum verified.")
|
||
err = os.Chmod(filePath, 0755)
|
||
if err != nil {
|
||
logger.Error("Failed to set executable permission for new file: %v", err)
|
||
return err
|
||
}
|
||
logger.Info("Executable permission set for new file.")
|
||
return nil
|
||
}
|
||
|
||
// executeUpdate starts the updater process to replace the current executable.
|
||
func executeUpdate(exeDir, backupPath, newExePath, currentExePath string) error {
|
||
updaterPath := filepath.Join(exeDir, updaterExeName)
|
||
cmd := exec.Command(updaterPath, backupPath, newExePath, currentExePath, "true")
|
||
cmd.Stdout = os.Stdout
|
||
cmd.Stderr = os.Stderr
|
||
|
||
err := cmd.Start()
|
||
if err != nil {
|
||
logger.Error("Failed to start updater: %v", err)
|
||
return err
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// checkFileMD5 calculates the MD5 hash of a file and compares it with the expected MD5.
|
||
func checkFileMD5(filePath, expectedMD5 string) (bool, error) {
|
||
f, err := os.Open(filePath)
|
||
if err != nil {
|
||
return false, err
|
||
}
|
||
defer f.Close()
|
||
|
||
hash := md5.New()
|
||
if _, err := io.Copy(hash, f); err != nil {
|
||
return false, err
|
||
}
|
||
actualMD5 := fmt.Sprintf("%x", hash.Sum(nil))
|
||
logger.Debug("actual MD5: %v", actualMD5)
|
||
return actualMD5 == expectedMD5, nil
|
||
}
|
||
|
||
// copyFile copies the contents of src to dst and gives dst the same
// permission bits as src.
//
// dst is created if missing and truncated if it already exists. Preserving
// the permission bits matters because this helper is used to back up
// executables: a copy created with default permissions would not be runnable.
// The error from closing dst is reported as well, since a failed close can
// mean the data was not fully written.
func copyFile(src, dst string) (err error) {
	sourceFile, err := os.Open(src)
	if err != nil {
		return err
	}
	defer sourceFile.Close()

	info, err := sourceFile.Stat()
	if err != nil {
		return err
	}
	mode := info.Mode().Perm()

	destFile, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, mode)
	if err != nil {
		return err
	}
	defer func() {
		// Keep the first error, but do not lose a close failure on an
		// otherwise successful copy.
		if closeErr := destFile.Close(); err == nil {
			err = closeErr
		}
	}()

	if _, err = io.Copy(destFile, sourceFile); err != nil {
		return err
	}

	// OpenFile's mode is filtered by the umask and ignored for an existing
	// file, so set the permission bits explicitly.
	return os.Chmod(dst, mode)
}
|