298 lines
7.7 KiB
Go
298 lines
7.7 KiB
Go
package handler
|
||
|
||
import (
	"bytes"
	"crypto/md5"
	"encoding/gob"
	"fmt"
	"io"
	"os"
	"os/exec"
	"path/filepath"
	"strings"
	"time"

	"bash_go_service/shared/pkg/client"
	"bash_go_service/shared/pkg/constants"
	"bash_go_service/shared/pkg/logger"
	"version/api"
)
|
||
|
||
// updaterExeName is the file name of the helper binary that ExecuteUpdate
// launches; it is looked up in the same directory as the running executable.
const updaterExeName = "updater-main"
|
||
|
||
// UpdateInfo records the file paths involved in an update. It is persisted
// with encoding/gob by SaveUpdateInfo and read back by LoadUpdateInfo.
//
// NOTE(review): field meanings are inferred from their names (old copy, new
// copy, target being replaced; "So" presumably a shared library) — confirm
// against the updater that consumes this file.
type UpdateInfo struct {
	// Paths for the executable: old copy, new copy, and target path.
	ExeFileOldPath, ExeFileNewPath, ExeTargetPath string

	// Paths for the "So" file: old copy, new copy, and target path.
	SoFileOldPath, SoFileNewPath, SoTargetPath string
}
|
||
|
||
func CreateUpdateInfo(
|
||
exeFileOldPath string,
|
||
exeFileNewPath string,
|
||
exeTargetPath string,
|
||
soFileOldPath string,
|
||
soFileNewPath string,
|
||
soTargetPath string,
|
||
) *UpdateInfo {
|
||
return &UpdateInfo{
|
||
ExeFileOldPath: exeFileOldPath,
|
||
ExeFileNewPath: exeFileNewPath,
|
||
ExeTargetPath: exeTargetPath,
|
||
SoFileOldPath: soFileOldPath,
|
||
SoFileNewPath: soFileNewPath,
|
||
SoTargetPath: soTargetPath,
|
||
}
|
||
}
|
||
|
||
// saveUpdateInfo 函数实现
|
||
func SaveUpdateInfo(exePath string, info UpdateInfo) error {
|
||
saveDir := filepath.Join(exePath, "../")
|
||
// 创建一个字节缓冲区
|
||
var buffer bytes.Buffer
|
||
|
||
// 创建一个新的编码器,并将结构体编码到缓冲区
|
||
encoder := gob.NewEncoder(&buffer)
|
||
err := encoder.Encode(info)
|
||
if err != nil {
|
||
return fmt.Errorf("error encoding UpdateInfo: %v", err)
|
||
}
|
||
|
||
// 确保目录存在
|
||
err = os.MkdirAll(saveDir, 0755)
|
||
if err != nil {
|
||
return fmt.Errorf("error creating directory: %v", err)
|
||
}
|
||
|
||
// 构建文件路径
|
||
filePath := filepath.Join(saveDir, constants.UpdateInfoFileName)
|
||
|
||
// 将二进制数据写入文件
|
||
err = os.WriteFile(filePath, buffer.Bytes(), 0644)
|
||
if err != nil {
|
||
return fmt.Errorf("error writing to file: %v", err)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// LoadUpdateInfo 函数实现
|
||
func LoadUpdateInfo(exePath string) (*UpdateInfo, error) {
|
||
savePath := filepath.Join(exePath, "../")
|
||
// 构建文件路径
|
||
filePath := filepath.Join(savePath, constants.UpdateInfoFileName)
|
||
|
||
// 从文件中读取二进制数据
|
||
data, err := os.ReadFile(filePath)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("error reading from file: %v", err)
|
||
}
|
||
|
||
// 创建一个新的解码器,并将二进制数据解码到结构体
|
||
buffer := bytes.NewBuffer(data)
|
||
decoder := gob.NewDecoder(buffer)
|
||
var info UpdateInfo
|
||
err = decoder.Decode(&info)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("error decoding binary data: %v", err)
|
||
}
|
||
|
||
return &info, nil
|
||
}
|
||
|
||
func RemoveUpdateInfo(exePath string) error {
|
||
// 构建文件路径
|
||
filePath := filepath.Join(exePath, constants.UpdateInfoFileName)
|
||
|
||
// 删除文件
|
||
err := os.Remove(filePath)
|
||
if err != nil {
|
||
return fmt.Errorf("error removing file: %v", err)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// DoUpdate checks for updates, downloads if needed, and executes the update.
|
||
func DoUpdate() {
|
||
logger.Debug("start check updating")
|
||
|
||
res, err := CheckVersionUpdate()
|
||
needUpdate, newVersion, md5sum := res.NeedUpdate, res.Version, res.ExeMd5
|
||
if err != nil {
|
||
logger.Error("Failed to check for updates: %v", err)
|
||
return
|
||
}
|
||
|
||
if !needUpdate {
|
||
logger.Info("No update needed.")
|
||
return
|
||
}
|
||
|
||
currentExePath, err := os.Executable()
|
||
if err != nil {
|
||
logger.Error("Failed to get executable path: %v", err)
|
||
return
|
||
}
|
||
currentExeDir := filepath.Dir(currentExePath)
|
||
currentExeBase := filepath.Base(currentExePath)
|
||
|
||
backupPath, err := BackupCurrentExecutable(currentExePath, currentExeBase, currentExeDir)
|
||
if err != nil {
|
||
logger.Error("Failed to backup current executable: %v", err)
|
||
return
|
||
}
|
||
|
||
// 这只下载exe文件进行更新,如果需要更新其他,可以自定义逻辑
|
||
newExePath, err := DownloadNewExecutable(currentExeDir, currentExeBase, newVersion, "exe")
|
||
if err != nil {
|
||
logger.Error("Failed to download new executable: %v", err)
|
||
// Attempt rollback
|
||
_ = os.Rename(backupPath, currentExePath)
|
||
return
|
||
}
|
||
|
||
if err := VerifyMD5(newExePath, md5sum); err != nil {
|
||
logger.Error("MD5 verification failed: %v", err)
|
||
// Attempt rollback
|
||
_ = os.Rename(backupPath, currentExePath)
|
||
_ = os.Remove(newExePath)
|
||
return
|
||
}
|
||
|
||
if err := ExecuteUpdate(currentExeDir, backupPath, newExePath, currentExePath, true); err != nil {
|
||
logger.Error("Failed to execute update: %v", err)
|
||
// Attempt rollback
|
||
_ = os.Rename(backupPath, currentExePath)
|
||
_ = os.Remove(newExePath)
|
||
return
|
||
}
|
||
|
||
logger.Info("Update process initiated, current process exiting.")
|
||
os.Exit(0)
|
||
}
|
||
|
||
// checkVersionUpdate checks if a new version is available.
|
||
func CheckVersionUpdate() (res *api.Result, err error) {
|
||
res, err = api.GetVersion()
|
||
if err != nil {
|
||
logger.Error("Failed to get version information: %v", err)
|
||
return nil, err
|
||
}
|
||
logger.Debug("version response: %v", res)
|
||
return res, nil
|
||
}
|
||
|
||
// exePath 待备份的文件路径, exeBase 文件名, exeDir 目录
|
||
// 返回 备份文件路径和错误信息
|
||
func BackupCurrentExecutable(exePath, exeBase, exeDir string) (backupPath string, err error) {
|
||
backupName := fmt.Sprintf("%s.bak.%s", exeBase, time.Now().Format("20060102_150405"))
|
||
backupPath = filepath.Join(exeDir, backupName)
|
||
logger.Debug("backup path: %v", backupPath)
|
||
err = copyFile(exePath, backupPath)
|
||
if err != nil {
|
||
logger.Error("Failed to backup current executable: %v", err)
|
||
return "", err
|
||
}
|
||
logger.Info("Backup created: %s", backupPath)
|
||
return backupPath, nil
|
||
}
|
||
|
||
// downloadNewExecutable downloads the new executable file.
|
||
func DownloadNewExecutable(exeDir, exeBase, newVersion string, filetype string) (newExePath string, err error) {
|
||
newExePath = filepath.Join(exeDir, exeBase+".new")
|
||
cli := client.NewClient()
|
||
query := map[string]string{
|
||
"version": newVersion,
|
||
"filetype": filetype,
|
||
}
|
||
err = cli.Download(constants.DownloadNewApi, query, newExePath)
|
||
if err != nil {
|
||
logger.Error("Download failed: %v", err)
|
||
return "", err
|
||
}
|
||
logger.Info("New executable downloaded to: %s", newExePath)
|
||
return newExePath, nil
|
||
}
|
||
|
||
// verifyMD5 checks if the downloaded file's MD5 matches the expected value.
|
||
func VerifyMD5(filePath, expectedMD5 string) error {
|
||
if expectedMD5 == "" {
|
||
logger.Error("MD5 is empty, skipping verification.")
|
||
return nil
|
||
}
|
||
match, err := checkFileMD5(filePath, expectedMD5)
|
||
if err != nil {
|
||
logger.Error("Failed to check MD5: %v", err)
|
||
return err
|
||
}
|
||
if !match {
|
||
logger.Error("MD5 mismatch: expected %s", expectedMD5)
|
||
return fmt.Errorf("MD5 mismatch: expected %s", expectedMD5)
|
||
}
|
||
logger.Info("MD5 checksum verified.")
|
||
err = os.Chmod(filePath, 0755)
|
||
if err != nil {
|
||
logger.Error("Failed to set executable permission for new file: %v", err)
|
||
return err
|
||
}
|
||
logger.Info("Executable permission set for new file.")
|
||
return nil
|
||
}
|
||
|
||
// executeUpdate starts the updater process to replace the current executable.
|
||
func ExecuteUpdate(exeDir, backupPath, newExePath, currentExePath string, autoStart bool) error {
|
||
updaterPath := filepath.Join(exeDir, updaterExeName)
|
||
autoStartFlag := "false"
|
||
if autoStart {
|
||
autoStartFlag = "true"
|
||
}
|
||
cmd := exec.Command(updaterPath, backupPath, newExePath, currentExePath, autoStartFlag)
|
||
cmd.Stdout = os.Stdout
|
||
cmd.Stderr = os.Stderr
|
||
|
||
err := cmd.Start()
|
||
if err != nil {
|
||
logger.Error("Failed to start updater: %v", err)
|
||
return err
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// checkFileMD5 calculates the MD5 hash of a file and compares it with the expected MD5.
|
||
func checkFileMD5(filePath, expectedMD5 string) (bool, error) {
|
||
f, err := os.Open(filePath)
|
||
if err != nil {
|
||
return false, err
|
||
}
|
||
defer f.Close()
|
||
|
||
hash := md5.New()
|
||
if _, err := io.Copy(hash, f); err != nil {
|
||
return false, err
|
||
}
|
||
actualMD5 := fmt.Sprintf("%x", hash.Sum(nil))
|
||
logger.Debug("actual MD5: %v", actualMD5)
|
||
return actualMD5 == expectedMD5, nil
|
||
}
|
||
|
||
// copyFile copies the contents of src to dst, creating dst or truncating it
// if it already exists.
//
// A newly created dst receives src's permission bits (subject to the process
// umask), so copying an executable yields an executable. Errors from closing
// dst are returned, because Close can be where a failed write first surfaces.
func copyFile(src, dst string) (err error) {
	sourceFile, err := os.Open(src)
	if err != nil {
		return err
	}
	defer sourceFile.Close()

	srcInfo, err := sourceFile.Stat()
	if err != nil {
		return err
	}

	destFile, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, srcInfo.Mode().Perm())
	if err != nil {
		return err
	}
	defer func() {
		// Keep the first error: a copy failure takes precedence over Close.
		if closeErr := destFile.Close(); closeErr != nil && err == nil {
			err = closeErr
		}
	}()

	_, err = io.Copy(destFile, sourceFile)
	return err
}
|