自动更新主程序功能

This commit is contained in:
Pan Qiancheng 2025-04-13 16:56:17 +08:00
parent d6d9d7709f
commit 0bb14fd536
9 changed files with 260 additions and 7 deletions

View File

@ -5,6 +5,7 @@ use (
./config-loader ./config-loader
./shared ./shared
./tests ./tests
./updater
./version ./version
./welcome ./welcome
) )

View File

@ -5,9 +5,11 @@ import (
"bufio" "bufio"
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt"
"io" "io"
"net/http" "net/http"
"net/url" "net/url"
"os"
"path" "path"
) )
@ -261,3 +263,62 @@ func (c *Client) PostToStream(uri string, body interface{}, query map[string]str
return nil return nil
} }
func (c *Client) Download(uri string, query map[string]string, filepath string) error {
c.trackRequest()
defer c.requestDone()
// 构建请求URL
u, err := url.Parse(c.baseURL)
if err != nil {
logger.Error("Failed to parse base URL: %v", err)
return err
}
u.Path = path.Join(u.Path, uri)
if query != nil {
q := u.Query()
for k, v := range query {
q.Set(k, v)
}
u.RawQuery = q.Encode()
}
// 创建GET请求
req, err := http.NewRequestWithContext(c.ctx, http.MethodGet, u.String(), nil)
if err != nil {
logger.Error("Failed to create GET request: %v", err)
return err
}
// 发送请求
resp, err := c.client.Do(req)
if err != nil {
logger.Error("Download request failed: %v", err)
return err
}
defer resp.Body.Close()
// 检查响应状态
if resp.StatusCode != http.StatusOK {
logger.Error("Received non-OK status code: %d", resp.StatusCode)
return fmt.Errorf("download failed with status code: %d", resp.StatusCode)
}
// 创建目标文件
out, err := os.Create(filepath)
if err != nil {
logger.Error("Failed to create file: %v", err)
return err
}
defer out.Close()
// 将响应内容写入文件
_, err = io.Copy(out, resp.Body)
if err != nil {
logger.Error("Failed to write file: %v", err)
return err
}
return nil
}

View File

@ -11,4 +11,5 @@ const (
MachineInfoApi = "/machine/info" MachineInfoApi = "/machine/info"
GetVersionApi = "/version" GetVersionApi = "/version"
QuestionStreamApi = "/question" QuestionStreamApi = "/question"
DownloadNewApi = "/download"
) )

View File

@ -12,14 +12,14 @@ func main() {
// 创建welcome service // 创建welcome service
service, err := welcome.NewService() service, err := welcome.NewService()
if err != nil { if err != nil {
logger.Error("Failed to create welcome service:", err) logger.Error("Failed to create welcome service: %v", err)
return return
} }
defer service.Stop() defer service.Stop()
// 启动服务 // 启动服务
if err := service.Start(); err != nil { if err := service.Start(); err != nil {
logger.Error("Failed to start welcome service:", err) logger.Error("Failed to start welcome service: %v", err)
return return
} }

53
updater/cmd/main.go Normal file
View File

@ -0,0 +1,53 @@
package main
import (
"fmt"
"os"
"os/exec"
"path/filepath"
)
func main() {
if len(os.Args) < 4 {
fmt.Println("Usage: updater <oldPath> <newPath> <targetPath>")
os.Exit(1)
}
oldPath := os.Args[1]
newPath := os.Args[2]
targetPath := os.Args[3]
// 替换当前程序
err := os.Rename(newPath, targetPath)
if err != nil {
fmt.Printf("Failed to replace executable: %v\n", err)
_ = os.Rename(oldPath, targetPath)
os.Exit(1)
}
fmt.Printf("Replaced executable: %s -> %s\n", newPath, targetPath)
// 设置可执行权限
err = os.Chmod(targetPath, 0755)
if err != nil {
fmt.Printf("Failed to set permission on new executable: %v\n", err)
_ = os.Remove(targetPath)
_ = os.Rename(oldPath, targetPath)
os.Exit(1)
}
// 启动新程序
cmd := exec.Command(targetPath)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
cmd.Dir = filepath.Dir(targetPath)
err = cmd.Start()
if err != nil {
fmt.Printf("Failed to start new process: %v\n", err)
// 回滚
_ = os.Rename(oldPath, targetPath)
os.Exit(1)
}
fmt.Println("New process started successfully.")
}

3
updater/go.mod Normal file
View File

@ -0,0 +1,3 @@
module updater
go 1.24.2

View File

@ -12,7 +12,7 @@ type Result struct {
NeedUpdate bool `json:"needUpdate"` NeedUpdate bool `json:"needUpdate"`
} }
func GetVersion() *Result { func GetVersion() (*Result, error) {
apiEndpoint := constants.GetVersionApi apiEndpoint := constants.GetVersionApi
client := client.NewClient() client := client.NewClient()
// params // params
@ -24,5 +24,5 @@ func GetVersion() *Result {
if err != nil { if err != nil {
logger.Error("Failed to get version: %v", err) logger.Error("Failed to get version: %v", err)
} }
return &result return &result, err
} }

View File

@ -17,10 +17,12 @@ func initConfig() error {
func main() { func main() {
if err := initConfig(); err != nil { if err := initConfig(); err != nil {
fmt.Println("Failed to init config: %v", err) fmt.Printf("Failed to init config: %v", err)
return return
} }
logger.UpdateLogLevel() logger.UpdateLogLevel()
logger.Debug("Version: %s", constants.CurrentVersion)
handler.DoUpdate() handler.DoUpdate()
} }

View File

@ -1,12 +1,144 @@
package handler package handler
import ( import (
"bash_go_service/shared/pkg/client"
"bash_go_service/shared/pkg/constants"
"bash_go_service/shared/pkg/logger" "bash_go_service/shared/pkg/logger"
"crypto/md5"
"fmt"
"io"
"os"
"os/exec"
"path/filepath"
"time"
"version/api" "version/api"
) )
const (
updaterExeName = "updater-main"
)
func DoUpdate() { func DoUpdate() {
logger.Info("start") logger.Debug("start check updating")
res := api.GetVersion() res, err := api.GetVersion()
if err != nil {
logger.Error("Failed to get version: %v", err)
return
}
logger.Debug("res: %v", res) logger.Debug("res: %v", res)
if !res.NeedUpdate {
logger.Info("No update needed.")
return
}
exePath, err := os.Executable()
if err != nil {
logger.Error("Failed to get executable path: %v", err)
return
}
exeDir := filepath.Dir(exePath)
exeBase := filepath.Base(exePath)
// 备份旧程序
backupName := fmt.Sprintf("%s.bak.%s", exeBase, time.Now().Format("20060102_150405"))
backupPath := filepath.Join(exeDir, backupName)
logger.Debug("backupPath: %v", backupPath)
err = copyFile(exePath, backupPath)
if err != nil {
logger.Error("Failed to backup current executable: %v", err)
return
}
logger.Info("Backup created: %s", backupPath)
// 下载新程序到临时路径
newExePath := filepath.Join(exeDir, exeBase+".new")
cli := client.NewClient()
query := map[string]string{
"version": res.Version,
}
err = cli.Download(constants.DownloadNewApi, query, newExePath)
if err != nil {
logger.Error("Download failed: %v", err)
_ = os.Rename(backupPath, exePath) // rollback
return
}
logger.Info("New executable downloaded to: %s", newExePath)
// 🔥 校验 MD5
if res.MD5 == "" {
logger.Error("MD5 is empty")
return
}
match, Merr := checkFileMD5(newExePath, res.MD5)
if Merr != nil {
logger.Error("Failed to check MD5: %v", Merr)
_ = os.Rename(backupPath, exePath)
return
}
if !match {
logger.Error("MD5 mismatch: expected %s", res.MD5)
_ = os.Rename(backupPath, exePath)
return
}
logger.Info("MD5 checksum verified.")
// 设置新文件执行权限
err = os.Chmod(newExePath, 0755)
if err != nil {
logger.Error("Failed to set executable permission: %v", err)
_ = os.Remove(newExePath)
_ = copyFile(backupPath, exePath)
return
}
logger.Info("Executable permission set for new file.")
// 启动 updater 进行替换
updaterPath := filepath.Join(exeDir, updaterExeName)
cmd := exec.Command(updaterPath, backupPath, newExePath, exePath)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
err = cmd.Start()
if err != nil {
logger.Error("Failed to start updater: %v", err)
_ = os.Rename(backupPath, exePath)
return
}
logger.Info("Updater started, exiting current process.")
os.Exit(0)
}
// MD5 校验函数
func checkFileMD5(filePath, expectedMD5 string) (bool, error) {
f, err := os.Open(filePath)
if err != nil {
return false, err
}
defer f.Close()
hash := md5.New()
if _, err := io.Copy(hash, f); err != nil {
return false, err
}
actualMD5 := fmt.Sprintf("%x", hash.Sum(nil))
return actualMD5 == expectedMD5, nil
}
// 复制文件函数
func copyFile(src, dst string) error {
sourceFile, err := os.Open(src)
if err != nil {
return err
}
defer sourceFile.Close()
destFile, err := os.Create(dst)
if err != nil {
return err
}
defer destFile.Close()
_, err = io.Copy(destFile, sourceFile)
return err
} }