自动更新主程序功能
This commit is contained in:
parent
d6d9d7709f
commit
0bb14fd536
1
go.work
1
go.work
|
|
@ -5,6 +5,7 @@ use (
|
|||
./config-loader
|
||||
./shared
|
||||
./tests
|
||||
./updater
|
||||
./version
|
||||
./welcome
|
||||
)
|
||||
|
|
|
|||
|
|
@ -5,9 +5,11 @@ import (
|
|||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
)
|
||||
|
||||
|
|
@ -261,3 +263,62 @@ func (c *Client) PostToStream(uri string, body interface{}, query map[string]str
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) Download(uri string, query map[string]string, filepath string) error {
|
||||
c.trackRequest()
|
||||
defer c.requestDone()
|
||||
|
||||
// 构建请求URL
|
||||
u, err := url.Parse(c.baseURL)
|
||||
if err != nil {
|
||||
logger.Error("Failed to parse base URL: %v", err)
|
||||
return err
|
||||
}
|
||||
u.Path = path.Join(u.Path, uri)
|
||||
|
||||
if query != nil {
|
||||
q := u.Query()
|
||||
for k, v := range query {
|
||||
q.Set(k, v)
|
||||
}
|
||||
u.RawQuery = q.Encode()
|
||||
}
|
||||
|
||||
// 创建GET请求
|
||||
req, err := http.NewRequestWithContext(c.ctx, http.MethodGet, u.String(), nil)
|
||||
if err != nil {
|
||||
logger.Error("Failed to create GET request: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
logger.Error("Download request failed: %v", err)
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 检查响应状态
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
logger.Error("Received non-OK status code: %d", resp.StatusCode)
|
||||
return fmt.Errorf("download failed with status code: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// 创建目标文件
|
||||
out, err := os.Create(filepath)
|
||||
if err != nil {
|
||||
logger.Error("Failed to create file: %v", err)
|
||||
return err
|
||||
}
|
||||
defer out.Close()
|
||||
|
||||
// 将响应内容写入文件
|
||||
_, err = io.Copy(out, resp.Body)
|
||||
if err != nil {
|
||||
logger.Error("Failed to write file: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,4 +11,5 @@ const (
|
|||
MachineInfoApi = "/machine/info"
|
||||
GetVersionApi = "/version"
|
||||
QuestionStreamApi = "/question"
|
||||
DownloadNewApi = "/download"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -12,14 +12,14 @@ func main() {
|
|||
// 创建welcome service
|
||||
service, err := welcome.NewService()
|
||||
if err != nil {
|
||||
logger.Error("Failed to create welcome service:", err)
|
||||
logger.Error("Failed to create welcome service: %v", err)
|
||||
return
|
||||
}
|
||||
defer service.Stop()
|
||||
|
||||
// 启动服务
|
||||
if err := service.Start(); err != nil {
|
||||
logger.Error("Failed to start welcome service:", err)
|
||||
logger.Error("Failed to start welcome service: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,53 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
func main() {
|
||||
if len(os.Args) < 4 {
|
||||
fmt.Println("Usage: updater <oldPath> <newPath> <targetPath>")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
oldPath := os.Args[1]
|
||||
newPath := os.Args[2]
|
||||
targetPath := os.Args[3]
|
||||
|
||||
// 替换当前程序
|
||||
err := os.Rename(newPath, targetPath)
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to replace executable: %v\n", err)
|
||||
_ = os.Rename(oldPath, targetPath)
|
||||
os.Exit(1)
|
||||
}
|
||||
fmt.Printf("Replaced executable: %s -> %s\n", newPath, targetPath)
|
||||
|
||||
// 设置可执行权限
|
||||
err = os.Chmod(targetPath, 0755)
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to set permission on new executable: %v\n", err)
|
||||
_ = os.Remove(targetPath)
|
||||
_ = os.Rename(oldPath, targetPath)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// 启动新程序
|
||||
cmd := exec.Command(targetPath)
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
cmd.Dir = filepath.Dir(targetPath)
|
||||
|
||||
err = cmd.Start()
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to start new process: %v\n", err)
|
||||
// 回滚
|
||||
_ = os.Rename(oldPath, targetPath)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
fmt.Println("New process started successfully.")
|
||||
}
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
module updater
|
||||
|
||||
go 1.24.2
|
||||
|
|
@ -12,7 +12,7 @@ type Result struct {
|
|||
NeedUpdate bool `json:"needUpdate"`
|
||||
}
|
||||
|
||||
func GetVersion() *Result {
|
||||
func GetVersion() (*Result, error) {
|
||||
apiEndpoint := constants.GetVersionApi
|
||||
client := client.NewClient()
|
||||
// params
|
||||
|
|
@ -24,5 +24,5 @@ func GetVersion() *Result {
|
|||
if err != nil {
|
||||
logger.Error("Failed to get version: %v", err)
|
||||
}
|
||||
return &result
|
||||
return &result, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,10 +17,12 @@ func initConfig() error {
|
|||
|
||||
func main() {
|
||||
if err := initConfig(); err != nil {
|
||||
fmt.Println("Failed to init config: %v", err)
|
||||
fmt.Printf("Failed to init config: %v", err)
|
||||
return
|
||||
}
|
||||
logger.UpdateLogLevel()
|
||||
|
||||
logger.Debug("Version: %s", constants.CurrentVersion)
|
||||
|
||||
handler.DoUpdate()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,12 +1,144 @@
|
|||
package handler
|
||||
|
||||
import (
|
||||
"bash_go_service/shared/pkg/client"
|
||||
"bash_go_service/shared/pkg/constants"
|
||||
"bash_go_service/shared/pkg/logger"
|
||||
"crypto/md5"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"time"
|
||||
"version/api"
|
||||
)
|
||||
|
||||
const (
|
||||
updaterExeName = "updater-main"
|
||||
)
|
||||
|
||||
func DoUpdate() {
|
||||
logger.Info("start")
|
||||
res := api.GetVersion()
|
||||
logger.Debug("start check updating")
|
||||
res, err := api.GetVersion()
|
||||
if err != nil {
|
||||
logger.Error("Failed to get version: %v", err)
|
||||
return
|
||||
}
|
||||
logger.Debug("res: %v", res)
|
||||
|
||||
if !res.NeedUpdate {
|
||||
logger.Info("No update needed.")
|
||||
return
|
||||
}
|
||||
|
||||
exePath, err := os.Executable()
|
||||
if err != nil {
|
||||
logger.Error("Failed to get executable path: %v", err)
|
||||
return
|
||||
}
|
||||
exeDir := filepath.Dir(exePath)
|
||||
exeBase := filepath.Base(exePath)
|
||||
|
||||
// 备份旧程序
|
||||
backupName := fmt.Sprintf("%s.bak.%s", exeBase, time.Now().Format("20060102_150405"))
|
||||
backupPath := filepath.Join(exeDir, backupName)
|
||||
logger.Debug("backupPath: %v", backupPath)
|
||||
err = copyFile(exePath, backupPath)
|
||||
if err != nil {
|
||||
logger.Error("Failed to backup current executable: %v", err)
|
||||
return
|
||||
}
|
||||
logger.Info("Backup created: %s", backupPath)
|
||||
|
||||
// 下载新程序到临时路径
|
||||
newExePath := filepath.Join(exeDir, exeBase+".new")
|
||||
cli := client.NewClient()
|
||||
query := map[string]string{
|
||||
"version": res.Version,
|
||||
}
|
||||
err = cli.Download(constants.DownloadNewApi, query, newExePath)
|
||||
if err != nil {
|
||||
logger.Error("Download failed: %v", err)
|
||||
_ = os.Rename(backupPath, exePath) // rollback
|
||||
return
|
||||
}
|
||||
logger.Info("New executable downloaded to: %s", newExePath)
|
||||
|
||||
// 🔥 校验 MD5
|
||||
if res.MD5 == "" {
|
||||
logger.Error("MD5 is empty")
|
||||
return
|
||||
}
|
||||
match, Merr := checkFileMD5(newExePath, res.MD5)
|
||||
if Merr != nil {
|
||||
logger.Error("Failed to check MD5: %v", Merr)
|
||||
_ = os.Rename(backupPath, exePath)
|
||||
return
|
||||
}
|
||||
if !match {
|
||||
logger.Error("MD5 mismatch: expected %s", res.MD5)
|
||||
_ = os.Rename(backupPath, exePath)
|
||||
return
|
||||
}
|
||||
logger.Info("MD5 checksum verified.")
|
||||
|
||||
// 设置新文件执行权限
|
||||
err = os.Chmod(newExePath, 0755)
|
||||
if err != nil {
|
||||
logger.Error("Failed to set executable permission: %v", err)
|
||||
_ = os.Remove(newExePath)
|
||||
_ = copyFile(backupPath, exePath)
|
||||
return
|
||||
}
|
||||
logger.Info("Executable permission set for new file.")
|
||||
|
||||
// 启动 updater 进行替换
|
||||
updaterPath := filepath.Join(exeDir, updaterExeName)
|
||||
cmd := exec.Command(updaterPath, backupPath, newExePath, exePath)
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
|
||||
err = cmd.Start()
|
||||
if err != nil {
|
||||
logger.Error("Failed to start updater: %v", err)
|
||||
_ = os.Rename(backupPath, exePath)
|
||||
return
|
||||
}
|
||||
logger.Info("Updater started, exiting current process.")
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
// MD5 校验函数
|
||||
func checkFileMD5(filePath, expectedMD5 string) (bool, error) {
|
||||
f, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
hash := md5.New()
|
||||
if _, err := io.Copy(hash, f); err != nil {
|
||||
return false, err
|
||||
}
|
||||
actualMD5 := fmt.Sprintf("%x", hash.Sum(nil))
|
||||
return actualMD5 == expectedMD5, nil
|
||||
}
|
||||
|
||||
// 复制文件函数
|
||||
func copyFile(src, dst string) error {
|
||||
sourceFile, err := os.Open(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer sourceFile.Close()
|
||||
|
||||
destFile, err := os.Create(dst)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer destFile.Close()
|
||||
|
||||
_, err = io.Copy(destFile, sourceFile)
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue