自动更新主程序功能
This commit is contained in:
parent
d6d9d7709f
commit
0bb14fd536
1
go.work
1
go.work
|
|
@ -5,6 +5,7 @@ use (
|
||||||
./config-loader
|
./config-loader
|
||||||
./shared
|
./shared
|
||||||
./tests
|
./tests
|
||||||
|
./updater
|
||||||
./version
|
./version
|
||||||
./welcome
|
./welcome
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -5,9 +5,11 @@ import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"os"
|
||||||
"path"
|
"path"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -261,3 +263,62 @@ func (c *Client) PostToStream(uri string, body interface{}, query map[string]str
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) Download(uri string, query map[string]string, filepath string) error {
|
||||||
|
c.trackRequest()
|
||||||
|
defer c.requestDone()
|
||||||
|
|
||||||
|
// 构建请求URL
|
||||||
|
u, err := url.Parse(c.baseURL)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to parse base URL: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
u.Path = path.Join(u.Path, uri)
|
||||||
|
|
||||||
|
if query != nil {
|
||||||
|
q := u.Query()
|
||||||
|
for k, v := range query {
|
||||||
|
q.Set(k, v)
|
||||||
|
}
|
||||||
|
u.RawQuery = q.Encode()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建GET请求
|
||||||
|
req, err := http.NewRequestWithContext(c.ctx, http.MethodGet, u.String(), nil)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to create GET request: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 发送请求
|
||||||
|
resp, err := c.client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Download request failed: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
// 检查响应状态
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
logger.Error("Received non-OK status code: %d", resp.StatusCode)
|
||||||
|
return fmt.Errorf("download failed with status code: %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建目标文件
|
||||||
|
out, err := os.Create(filepath)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to create file: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer out.Close()
|
||||||
|
|
||||||
|
// 将响应内容写入文件
|
||||||
|
_, err = io.Copy(out, resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to write file: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -11,4 +11,5 @@ const (
|
||||||
MachineInfoApi = "/machine/info"
|
MachineInfoApi = "/machine/info"
|
||||||
GetVersionApi = "/version"
|
GetVersionApi = "/version"
|
||||||
QuestionStreamApi = "/question"
|
QuestionStreamApi = "/question"
|
||||||
|
DownloadNewApi = "/download"
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -12,14 +12,14 @@ func main() {
|
||||||
// 创建welcome service
|
// 创建welcome service
|
||||||
service, err := welcome.NewService()
|
service, err := welcome.NewService()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Failed to create welcome service:", err)
|
logger.Error("Failed to create welcome service: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer service.Stop()
|
defer service.Stop()
|
||||||
|
|
||||||
// 启动服务
|
// 启动服务
|
||||||
if err := service.Start(); err != nil {
|
if err := service.Start(); err != nil {
|
||||||
logger.Error("Failed to start welcome service:", err)
|
logger.Error("Failed to start welcome service: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,53 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
if len(os.Args) < 4 {
|
||||||
|
fmt.Println("Usage: updater <oldPath> <newPath> <targetPath>")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
oldPath := os.Args[1]
|
||||||
|
newPath := os.Args[2]
|
||||||
|
targetPath := os.Args[3]
|
||||||
|
|
||||||
|
// 替换当前程序
|
||||||
|
err := os.Rename(newPath, targetPath)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Failed to replace executable: %v\n", err)
|
||||||
|
_ = os.Rename(oldPath, targetPath)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
fmt.Printf("Replaced executable: %s -> %s\n", newPath, targetPath)
|
||||||
|
|
||||||
|
// 设置可执行权限
|
||||||
|
err = os.Chmod(targetPath, 0755)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Failed to set permission on new executable: %v\n", err)
|
||||||
|
_ = os.Remove(targetPath)
|
||||||
|
_ = os.Rename(oldPath, targetPath)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 启动新程序
|
||||||
|
cmd := exec.Command(targetPath)
|
||||||
|
cmd.Stdout = os.Stdout
|
||||||
|
cmd.Stderr = os.Stderr
|
||||||
|
cmd.Dir = filepath.Dir(targetPath)
|
||||||
|
|
||||||
|
err = cmd.Start()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Failed to start new process: %v\n", err)
|
||||||
|
// 回滚
|
||||||
|
_ = os.Rename(oldPath, targetPath)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("New process started successfully.")
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,3 @@
|
||||||
|
module updater
|
||||||
|
|
||||||
|
go 1.24.2
|
||||||
|
|
@ -12,7 +12,7 @@ type Result struct {
|
||||||
NeedUpdate bool `json:"needUpdate"`
|
NeedUpdate bool `json:"needUpdate"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetVersion() *Result {
|
func GetVersion() (*Result, error) {
|
||||||
apiEndpoint := constants.GetVersionApi
|
apiEndpoint := constants.GetVersionApi
|
||||||
client := client.NewClient()
|
client := client.NewClient()
|
||||||
// params
|
// params
|
||||||
|
|
@ -24,5 +24,5 @@ func GetVersion() *Result {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Failed to get version: %v", err)
|
logger.Error("Failed to get version: %v", err)
|
||||||
}
|
}
|
||||||
return &result
|
return &result, err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -17,10 +17,12 @@ func initConfig() error {
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
if err := initConfig(); err != nil {
|
if err := initConfig(); err != nil {
|
||||||
fmt.Println("Failed to init config: %v", err)
|
fmt.Printf("Failed to init config: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
logger.UpdateLogLevel()
|
logger.UpdateLogLevel()
|
||||||
|
|
||||||
|
logger.Debug("Version: %s", constants.CurrentVersion)
|
||||||
|
|
||||||
handler.DoUpdate()
|
handler.DoUpdate()
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,144 @@
|
||||||
package handler
|
package handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bash_go_service/shared/pkg/client"
|
||||||
|
"bash_go_service/shared/pkg/constants"
|
||||||
"bash_go_service/shared/pkg/logger"
|
"bash_go_service/shared/pkg/logger"
|
||||||
|
"crypto/md5"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
"time"
|
||||||
"version/api"
|
"version/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
updaterExeName = "updater-main"
|
||||||
|
)
|
||||||
|
|
||||||
func DoUpdate() {
|
func DoUpdate() {
|
||||||
logger.Info("start")
|
logger.Debug("start check updating")
|
||||||
res := api.GetVersion()
|
res, err := api.GetVersion()
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to get version: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
logger.Debug("res: %v", res)
|
logger.Debug("res: %v", res)
|
||||||
|
|
||||||
|
if !res.NeedUpdate {
|
||||||
|
logger.Info("No update needed.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
exePath, err := os.Executable()
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to get executable path: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
exeDir := filepath.Dir(exePath)
|
||||||
|
exeBase := filepath.Base(exePath)
|
||||||
|
|
||||||
|
// 备份旧程序
|
||||||
|
backupName := fmt.Sprintf("%s.bak.%s", exeBase, time.Now().Format("20060102_150405"))
|
||||||
|
backupPath := filepath.Join(exeDir, backupName)
|
||||||
|
logger.Debug("backupPath: %v", backupPath)
|
||||||
|
err = copyFile(exePath, backupPath)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to backup current executable: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logger.Info("Backup created: %s", backupPath)
|
||||||
|
|
||||||
|
// 下载新程序到临时路径
|
||||||
|
newExePath := filepath.Join(exeDir, exeBase+".new")
|
||||||
|
cli := client.NewClient()
|
||||||
|
query := map[string]string{
|
||||||
|
"version": res.Version,
|
||||||
|
}
|
||||||
|
err = cli.Download(constants.DownloadNewApi, query, newExePath)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Download failed: %v", err)
|
||||||
|
_ = os.Rename(backupPath, exePath) // rollback
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logger.Info("New executable downloaded to: %s", newExePath)
|
||||||
|
|
||||||
|
// 🔥 校验 MD5
|
||||||
|
if res.MD5 == "" {
|
||||||
|
logger.Error("MD5 is empty")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
match, Merr := checkFileMD5(newExePath, res.MD5)
|
||||||
|
if Merr != nil {
|
||||||
|
logger.Error("Failed to check MD5: %v", Merr)
|
||||||
|
_ = os.Rename(backupPath, exePath)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !match {
|
||||||
|
logger.Error("MD5 mismatch: expected %s", res.MD5)
|
||||||
|
_ = os.Rename(backupPath, exePath)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logger.Info("MD5 checksum verified.")
|
||||||
|
|
||||||
|
// 设置新文件执行权限
|
||||||
|
err = os.Chmod(newExePath, 0755)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to set executable permission: %v", err)
|
||||||
|
_ = os.Remove(newExePath)
|
||||||
|
_ = copyFile(backupPath, exePath)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logger.Info("Executable permission set for new file.")
|
||||||
|
|
||||||
|
// 启动 updater 进行替换
|
||||||
|
updaterPath := filepath.Join(exeDir, updaterExeName)
|
||||||
|
cmd := exec.Command(updaterPath, backupPath, newExePath, exePath)
|
||||||
|
cmd.Stdout = os.Stdout
|
||||||
|
cmd.Stderr = os.Stderr
|
||||||
|
|
||||||
|
err = cmd.Start()
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to start updater: %v", err)
|
||||||
|
_ = os.Rename(backupPath, exePath)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logger.Info("Updater started, exiting current process.")
|
||||||
|
os.Exit(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MD5 校验函数
|
||||||
|
func checkFileMD5(filePath, expectedMD5 string) (bool, error) {
|
||||||
|
f, err := os.Open(filePath)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
hash := md5.New()
|
||||||
|
if _, err := io.Copy(hash, f); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
actualMD5 := fmt.Sprintf("%x", hash.Sum(nil))
|
||||||
|
return actualMD5 == expectedMD5, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 复制文件函数
|
||||||
|
func copyFile(src, dst string) error {
|
||||||
|
sourceFile, err := os.Open(src)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer sourceFile.Close()
|
||||||
|
|
||||||
|
destFile, err := os.Create(dst)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer destFile.Close()
|
||||||
|
|
||||||
|
_, err = io.Copy(destFile, sourceFile)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue