拆分函数

This commit is contained in:
Pan Qiancheng 2025-04-23 09:20:29 +08:00
parent 153cb9e597
commit 8d2bef03a1
3 changed files with 195 additions and 146 deletions

View File

@ -4,7 +4,7 @@ import (
"bash_go_service/shared/pkg/constants" "bash_go_service/shared/pkg/constants"
"bash_go_service/shared/pkg/logger" "bash_go_service/shared/pkg/logger"
"fmt" "fmt"
"version/handler" "version/pkg/handler"
"github.com/spf13/viper" "github.com/spf13/viper"
) )

View File

@ -1,145 +0,0 @@
package handler
import (
"bash_go_service/shared/pkg/client"
"bash_go_service/shared/pkg/constants"
"bash_go_service/shared/pkg/logger"
"crypto/md5"
"fmt"
"io"
"os"
"os/exec"
"path/filepath"
"time"
"version/api"
)
const (
updaterExeName = "updater-main"
)
func DoUpdate() {
logger.Debug("start check updating")
res, err := api.GetVersion()
if err != nil {
logger.Error("Failed to get version: %v", err)
return
}
logger.Debug("res: %v", res)
if !res.NeedUpdate {
logger.Info("No update needed.")
return
}
exePath, err := os.Executable()
if err != nil {
logger.Error("Failed to get executable path: %v", err)
return
}
exeDir := filepath.Dir(exePath)
exeBase := filepath.Base(exePath)
// 备份旧程序
backupName := fmt.Sprintf("%s.bak.%s", exeBase, time.Now().Format("20060102_150405"))
backupPath := filepath.Join(exeDir, backupName)
logger.Debug("backupPath: %v", backupPath)
err = copyFile(exePath, backupPath)
if err != nil {
logger.Error("Failed to backup current executable: %v", err)
return
}
logger.Info("Backup created: %s", backupPath)
// 下载新程序到临时路径
newExePath := filepath.Join(exeDir, exeBase+".new")
cli := client.NewClient()
query := map[string]string{
"version": res.Version,
}
err = cli.Download(constants.DownloadNewApi, query, newExePath)
if err != nil {
logger.Error("Download failed: %v", err)
_ = os.Rename(backupPath, exePath) // rollback
return
}
logger.Info("New executable downloaded to: %s", newExePath)
// 🔥 校验 MD5
if res.MD5 == "" {
logger.Error("MD5 is empty")
return
}
match, Merr := checkFileMD5(newExePath, res.MD5)
if Merr != nil {
logger.Error("Failed to check MD5: %v", Merr)
_ = os.Rename(backupPath, exePath)
return
}
if !match {
logger.Error("MD5 mismatch: expected %s", res.MD5)
_ = os.Rename(backupPath, exePath)
return
}
logger.Info("MD5 checksum verified.")
// 设置新文件执行权限
err = os.Chmod(newExePath, 0755)
if err != nil {
logger.Error("Failed to set executable permission: %v", err)
_ = os.Remove(newExePath)
_ = copyFile(backupPath, exePath)
return
}
logger.Info("Executable permission set for new file.")
// 启动 updater 进行替换
updaterPath := filepath.Join(exeDir, updaterExeName)
cmd := exec.Command(updaterPath, backupPath, newExePath, exePath, "true")
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
err = cmd.Start()
if err != nil {
logger.Error("Failed to start updater: %v", err)
_ = os.Rename(backupPath, exePath)
return
}
logger.Info("Updater started, exiting current process.")
os.Exit(0)
}
// MD5 校验函数
func checkFileMD5(filePath, expectedMD5 string) (bool, error) {
f, err := os.Open(filePath)
if err != nil {
return false, err
}
defer f.Close()
hash := md5.New()
if _, err := io.Copy(hash, f); err != nil {
return false, err
}
actualMD5 := fmt.Sprintf("%x", hash.Sum(nil))
logger.Info("actual MD5: %v", actualMD5)
return actualMD5 == expectedMD5, nil
}
// 复制文件函数
func copyFile(src, dst string) error {
sourceFile, err := os.Open(src)
if err != nil {
return err
}
defer sourceFile.Close()
destFile, err := os.Create(dst)
if err != nil {
return err
}
defer destFile.Close()
_, err = io.Copy(destFile, sourceFile)
return err
}

View File

@ -0,0 +1,194 @@
package handler
import (
"bash_go_service/shared/pkg/client"
"bash_go_service/shared/pkg/constants"
"bash_go_service/shared/pkg/logger"
"crypto/md5"
"fmt"
"io"
"os"
"os/exec"
"path/filepath"
"time"
"version/api"
)
const (
updaterExeName = "updater-main"
)
// DoUpdate checks for updates, downloads if needed, and executes the update.
func DoUpdate() {
logger.Debug("start check updating")
needUpdate, newVersion, md5sum, err := checkVersionUpdate()
if err != nil {
logger.Error("Failed to check for updates: %v", err)
return
}
if !needUpdate {
logger.Info("No update needed.")
return
}
currentExePath, err := os.Executable()
if err != nil {
logger.Error("Failed to get executable path: %v", err)
return
}
currentExeDir := filepath.Dir(currentExePath)
currentExeBase := filepath.Base(currentExePath)
backupPath, err := backupCurrentExecutable(currentExePath, currentExeBase, currentExeDir)
if err != nil {
logger.Error("Failed to backup current executable: %v", err)
return
}
// 这只下载exe文件进行更新如果需要更新其他可以自定义逻辑
newExePath, err := downloadNewExecutable(currentExeDir, currentExeBase, newVersion, "exe")
if err != nil {
logger.Error("Failed to download new executable: %v", err)
// Attempt rollback
_ = os.Rename(backupPath, currentExePath)
return
}
if err := verifyMD5(newExePath, md5sum); err != nil {
logger.Error("MD5 verification failed: %v", err)
// Attempt rollback
_ = os.Rename(backupPath, currentExePath)
_ = os.Remove(newExePath)
return
}
if err := executeUpdate(currentExeDir, backupPath, newExePath, currentExePath); err != nil {
logger.Error("Failed to execute update: %v", err)
// Attempt rollback
_ = os.Rename(backupPath, currentExePath)
_ = os.Remove(newExePath)
return
}
logger.Info("Update process initiated, current process exiting.")
os.Exit(0)
}
// checkVersionUpdate checks if a new version is available.
func checkVersionUpdate() (needUpdate bool, newVersion string, md5sum string, err error) {
res, err := api.GetVersion()
if err != nil {
logger.Error("Failed to get version information: %v", err)
return false, "", "", err
}
logger.Debug("version response: %v", res)
return res.NeedUpdate, res.Version, res.MD5, nil
}
// backupCurrentExecutable creates a backup of the current executable.
func backupCurrentExecutable(exePath, exeBase, exeDir string) (backupPath string, err error) {
backupName := fmt.Sprintf("%s.bak.%s", exeBase, time.Now().Format("20060102_150405"))
backupPath = filepath.Join(exeDir, backupName)
logger.Debug("backup path: %v", backupPath)
err = copyFile(exePath, backupPath)
if err != nil {
logger.Error("Failed to backup current executable: %v", err)
return "", err
}
logger.Info("Backup created: %s", backupPath)
return backupPath, nil
}
// downloadNewExecutable downloads the new executable file.
func downloadNewExecutable(exeDir, exeBase, newVersion string, filetype string) (newExePath string, err error) {
newExePath = filepath.Join(exeDir, exeBase+".new")
cli := client.NewClient()
query := map[string]string{
"version": newVersion,
"filetype": filetype,
}
err = cli.Download(constants.DownloadNewApi, query, newExePath)
if err != nil {
logger.Error("Download failed: %v", err)
return "", err
}
logger.Info("New executable downloaded to: %s", newExePath)
return newExePath, nil
}
// verifyMD5 checks if the downloaded file's MD5 matches the expected value.
func verifyMD5(filePath, expectedMD5 string) error {
if expectedMD5 == "" {
logger.Error("MD5 is empty, skipping verification.")
return nil
}
match, err := checkFileMD5(filePath, expectedMD5)
if err != nil {
logger.Error("Failed to check MD5: %v", err)
return err
}
if !match {
logger.Error("MD5 mismatch: expected %s", expectedMD5)
return fmt.Errorf("MD5 mismatch: expected %s", expectedMD5)
}
logger.Info("MD5 checksum verified.")
err = os.Chmod(filePath, 0755)
if err != nil {
logger.Error("Failed to set executable permission for new file: %v", err)
return err
}
logger.Info("Executable permission set for new file.")
return nil
}
// executeUpdate starts the updater process to replace the current executable.
func executeUpdate(exeDir, backupPath, newExePath, currentExePath string) error {
updaterPath := filepath.Join(exeDir, updaterExeName)
cmd := exec.Command(updaterPath, backupPath, newExePath, currentExePath, "true")
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
err := cmd.Start()
if err != nil {
logger.Error("Failed to start updater: %v", err)
return err
}
return nil
}
// checkFileMD5 calculates the MD5 hash of a file and compares it with the expected MD5.
func checkFileMD5(filePath, expectedMD5 string) (bool, error) {
f, err := os.Open(filePath)
if err != nil {
return false, err
}
defer f.Close()
hash := md5.New()
if _, err := io.Copy(hash, f); err != nil {
return false, err
}
actualMD5 := fmt.Sprintf("%x", hash.Sum(nil))
logger.Debug("actual MD5: %v", actualMD5)
return actualMD5 == expectedMD5, nil
}
// copyFile copies a file from source to destination.
func copyFile(src, dst string) error {
sourceFile, err := os.Open(src)
if err != nil {
return err
}
defer sourceFile.Close()
destFile, err := os.Create(dst)
if err != nil {
return err
}
defer destFile.Close()
_, err = io.Copy(destFile, sourceFile)
return err
}