From 0bb14fd5364ccb3a7d77a6c430fbeb6ed4e1de52 Mon Sep 17 00:00:00 2001 From: "qcqcqc@wsl" <1220204124@zust.edu.cn> Date: Sun, 13 Apr 2025 16:56:17 +0800 Subject: [PATCH] =?UTF-8?q?=E8=87=AA=E5=8A=A8=E6=9B=B4=E6=96=B0=E4=B8=BB?= =?UTF-8?q?=E7=A8=8B=E5=BA=8F=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- go.work | 1 + shared/pkg/client/methods.go | 61 ++++++++++++ shared/pkg/constants/server.go | 1 + tests/testcase/userlogin/userlogin.go | 4 +- updater/cmd/main.go | 53 ++++++++++ updater/go.mod | 3 + version/api/api.go | 4 +- version/cmd/main.go | 4 +- version/handler/handler.go | 136 +++++++++++++++++++++++++- 9 files changed, 260 insertions(+), 7 deletions(-) create mode 100644 updater/cmd/main.go create mode 100644 updater/go.mod diff --git a/go.work b/go.work index 8fe1114..346da61 100644 --- a/go.work +++ b/go.work @@ -5,6 +5,7 @@ use ( ./config-loader ./shared ./tests + ./updater ./version ./welcome ) diff --git a/shared/pkg/client/methods.go b/shared/pkg/client/methods.go index 48869b2..97fa03e 100644 --- a/shared/pkg/client/methods.go +++ b/shared/pkg/client/methods.go @@ -5,9 +5,11 @@ import ( "bufio" "bytes" "encoding/json" + "fmt" "io" "net/http" "net/url" + "os" "path" ) @@ -261,3 +263,62 @@ func (c *Client) PostToStream(uri string, body interface{}, query map[string]str return nil } + +func (c *Client) Download(uri string, query map[string]string, filepath string) error { + c.trackRequest() + defer c.requestDone() + + // 构建请求URL + u, err := url.Parse(c.baseURL) + if err != nil { + logger.Error("Failed to parse base URL: %v", err) + return err + } + u.Path = path.Join(u.Path, uri) + + if query != nil { + q := u.Query() + for k, v := range query { + q.Set(k, v) + } + u.RawQuery = q.Encode() + } + + // 创建GET请求 + req, err := http.NewRequestWithContext(c.ctx, http.MethodGet, u.String(), nil) + if err != nil { + logger.Error("Failed to create GET request: %v", err) + return err + } + + // 发送请求 + resp, err := c.client.Do(req) + if err != nil { + logger.Error("Download request failed: %v", err) + return err + } + defer resp.Body.Close() + + // 检查响应状态 + if resp.StatusCode != http.StatusOK { + logger.Error("Received non-OK status code: %d", resp.StatusCode) + return fmt.Errorf("download failed with status code: %d", resp.StatusCode) + } + + // 创建目标文件 + out, err := os.Create(filepath) + if err != nil { + logger.Error("Failed to create file: %v", err) + return err + } + defer out.Close() + + // 将响应内容写入文件 + _, err = io.Copy(out, resp.Body) + if err != nil { + logger.Error("Failed to write file: %v", err) + return err + } + + return nil +} diff --git a/shared/pkg/constants/server.go b/shared/pkg/constants/server.go index ed4892b..4a87fbc 100644 --- a/shared/pkg/constants/server.go +++ b/shared/pkg/constants/server.go @@ -11,4 +11,5 @@ const ( MachineInfoApi = "/machine/info" GetVersionApi = "/version" QuestionStreamApi = "/question" + DownloadNewApi = "/download" ) diff --git a/tests/testcase/userlogin/userlogin.go b/tests/testcase/userlogin/userlogin.go index 795f973..7c0c215 100644 --- a/tests/testcase/userlogin/userlogin.go +++ b/tests/testcase/userlogin/userlogin.go @@ -12,14 +12,14 @@ func main() { // 创建welcome service service, err := welcome.NewService() if err != nil { - logger.Error("Failed to create welcome service:", err) + logger.Error("Failed to create welcome service: %v", err) return } defer service.Stop() // 启动服务 if err := service.Start(); err != nil { - logger.Error("Failed to start welcome service:", err) + logger.Error("Failed to start welcome service: %v", err) return } diff --git a/updater/cmd/main.go b/updater/cmd/main.go new file mode 100644 index 0000000..1034626 --- /dev/null +++ b/updater/cmd/main.go @@ -0,0 +1,53 @@ +package main + +import ( + "fmt" + "os" + "os/exec" + "path/filepath" +) + +func main() { + if len(os.Args) < 4 { + fmt.Println("Usage: updater ") + os.Exit(1) + } + + oldPath := os.Args[1] + newPath := os.Args[2] + targetPath := os.Args[3] + + // 替换当前程序 + err := os.Rename(newPath, targetPath) + if err != nil { + fmt.Printf("Failed to replace executable: %v\n", err) + _ = os.Rename(oldPath, targetPath) + os.Exit(1) + } + fmt.Printf("Replaced executable: %s -> %s\n", newPath, targetPath) + + // 设置可执行权限 + err = os.Chmod(targetPath, 0755) + if err != nil { + fmt.Printf("Failed to set permission on new executable: %v\n", err) + _ = os.Remove(targetPath) + _ = os.Rename(oldPath, targetPath) + os.Exit(1) + } + + // 启动新程序 + cmd := exec.Command(targetPath) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + cmd.Dir = filepath.Dir(targetPath) + + err = cmd.Start() + if err != nil { + fmt.Printf("Failed to start new process: %v\n", err) + // 回滚 + _ = os.Rename(oldPath, targetPath) + os.Exit(1) + } + + fmt.Println("New process started successfully.") +} diff --git a/updater/go.mod b/updater/go.mod new file mode 100644 index 0000000..8a8f7d2 --- /dev/null +++ b/updater/go.mod @@ -0,0 +1,3 @@ +module updater + +go 1.24.2 diff --git a/version/api/api.go b/version/api/api.go index df639e3..4fbb6d6 100644 --- a/version/api/api.go +++ b/version/api/api.go @@ -12,7 +12,7 @@ type Result struct { NeedUpdate bool `json:"needUpdate"` } -func GetVersion() *Result { +func GetVersion() (*Result, error) { apiEndpoint := constants.GetVersionApi client := client.NewClient() // params @@ -24,5 +24,5 @@ func GetVersion() *Result { if err != nil { logger.Error("Failed to get version: %v", err) } - return &result + return &result, err } diff --git a/version/cmd/main.go b/version/cmd/main.go index 8243111..1e00697 100644 --- a/version/cmd/main.go +++ b/version/cmd/main.go @@ -17,10 +17,12 @@ func initConfig() error { func main() { if err := initConfig(); err != nil { - fmt.Println("Failed to init config: %v", err) + fmt.Printf("Failed to init config: %v", err) return } logger.UpdateLogLevel() + logger.Debug("Version: %s", constants.CurrentVersion) + handler.DoUpdate() } diff --git a/version/handler/handler.go b/version/handler/handler.go index 6136269..2ba4d34 100644 --- a/version/handler/handler.go +++ b/version/handler/handler.go @@ -1,12 +1,144 @@ package handler import ( + "bash_go_service/shared/pkg/client" + "bash_go_service/shared/pkg/constants" "bash_go_service/shared/pkg/logger" + "crypto/md5" + "fmt" + "io" + "os" + "os/exec" + "path/filepath" + "time" "version/api" ) +const ( + updaterExeName = "updater-main" +) + func DoUpdate() { - logger.Info("start") - res := api.GetVersion() + logger.Debug("start check updating") + res, err := api.GetVersion() + if err != nil { + logger.Error("Failed to get version: %v", err) + return + } logger.Debug("res: %v", res) + + if !res.NeedUpdate { + logger.Info("No update needed.") + return + } + + exePath, err := os.Executable() + if err != nil { + logger.Error("Failed to get executable path: %v", err) + return + } + exeDir := filepath.Dir(exePath) + exeBase := filepath.Base(exePath) + + // 备份旧程序 + backupName := fmt.Sprintf("%s.bak.%s", exeBase, time.Now().Format("20060102_150405")) + backupPath := filepath.Join(exeDir, backupName) + logger.Debug("backupPath: %v", backupPath) + err = copyFile(exePath, backupPath) + if err != nil { + logger.Error("Failed to backup current executable: %v", err) + return + } + logger.Info("Backup created: %s", backupPath) + + // 下载新程序到临时路径 + newExePath := filepath.Join(exeDir, exeBase+".new") + cli := client.NewClient() + query := map[string]string{ + "version": res.Version, + } + err = cli.Download(constants.DownloadNewApi, query, newExePath) + if err != nil { + logger.Error("Download failed: %v", err) + _ = os.Rename(backupPath, exePath) // rollback + return + } + logger.Info("New executable downloaded to: %s", newExePath) + + // 🔥 校验 MD5 + if res.MD5 == "" { + logger.Error("MD5 is empty") + return + } + match, Merr := checkFileMD5(newExePath, res.MD5) + if Merr != nil { + logger.Error("Failed to check MD5: %v", Merr) + _ = os.Rename(backupPath, exePath) + return + } + if !match { + logger.Error("MD5 mismatch: expected %s", res.MD5) + _ = os.Rename(backupPath, exePath) + return + } + logger.Info("MD5 checksum verified.") + + // 设置新文件执行权限 + err = os.Chmod(newExePath, 0755) + if err != nil { + logger.Error("Failed to set executable permission: %v", err) + _ = os.Remove(newExePath) + _ = copyFile(backupPath, exePath) + return + } + logger.Info("Executable permission set for new file.") + + // 启动 updater 进行替换 + updaterPath := filepath.Join(exeDir, updaterExeName) + cmd := exec.Command(updaterPath, backupPath, newExePath, exePath) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + err = cmd.Start() + if err != nil { + logger.Error("Failed to start updater: %v", err) + _ = os.Rename(backupPath, exePath) + return + } + logger.Info("Updater started, exiting current process.") + os.Exit(0) +} + +// MD5 校验函数 +func checkFileMD5(filePath, expectedMD5 string) (bool, error) { + f, err := os.Open(filePath) + if err != nil { + return false, err + } + defer f.Close() + + hash := md5.New() + if _, err := io.Copy(hash, f); err != nil { + return false, err + } + actualMD5 := fmt.Sprintf("%x", hash.Sum(nil)) + return actualMD5 == expectedMD5, nil +} + +// 复制文件函数 +func copyFile(src, dst string) error { + sourceFile, err := os.Open(src) + if err != nil { + return err + } + defer sourceFile.Close() + + destFile, err := os.Create(dst) + if err != nil { + return err + } + defer destFile.Close() + + _, err = io.Copy(destFile, sourceFile) + return err }