检查版本更新任务

This commit is contained in:
Pan Qiancheng 2025-04-23 13:24:09 +08:00
parent 0b37200b15
commit 69eb29f925
8 changed files with 372 additions and 38 deletions

View File

@ -341,6 +341,5 @@ func runDaemon() {
// showWelcomeMessage 显示欢迎信息 // showWelcomeMessage 显示欢迎信息
func showWelcomeMessage(info *client.Result) { func showWelcomeMessage(info *client.Result) {
fmt.Println(info.Msg) fmt.Println(info.Msg)
} }

View File

@ -21,7 +21,9 @@ require (
bash_go_service/shared v0.0.0-00010101000000-000000000000 bash_go_service/shared v0.0.0-00010101000000-000000000000
bash_go_service/welcome v0.0.0-00010101000000-000000000000 bash_go_service/welcome v0.0.0-00010101000000-000000000000
bash_go_service/version v0.0.0-00010101000000-000000000000
) )
replace bash_go_service/shared => ../shared replace bash_go_service/shared => ../shared
replace bash_go_service/welcome => ../welcome replace bash_go_service/welcome => ../welcome
replace bash_go_service/version => ../version

View File

@ -0,0 +1,145 @@
package service
import (
"bash_go_service/shared/pkg/constants"
"bash_go_service/shared/pkg/logger"
"bash_go_service/version/pkg/handler"
"context"
"os"
"path/filepath"
"time"
"github.com/spf13/viper"
)
// VersionCheckTask represents a task that checks for version updates.
type VersionCheckTask struct {
interval time.Duration
}
// NewVersionCheckTask creates a new VersionCheckTask.
func NewVersionCheckTask() *VersionCheckTask {
interval := viper.GetDuration("update_checker.interval")
if interval < 1*time.Hour {
interval = 1 * time.Hour // 默认值为1小时
}
logger.Debug("VersionCheckTask interval: %v", interval)
return &VersionCheckTask{
interval: interval,
}
}
// Execute runs the version check task periodically.
func (t *VersionCheckTask) Execute(ctx context.Context) {
ticker := time.NewTicker(t.interval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
t.CheckUpdate(ctx)
case <-ctx.Done():
logger.Info("VersionCheckTask stopped")
return
}
}
}
// 检查更新,将文件下载到本地,在下次启动程序时执行更新
func (t *VersionCheckTask) CheckUpdate(ctx context.Context) {
logger.Info("Executing version update check at %v", time.Now())
res, err := handler.CheckVersionUpdate()
needUpdate, newVersion, md5sum, soMd5 := res.NeedUpdate, res.Version, res.ExeMd5, res.SoMd5
if err != nil {
logger.Error("Failed to check for updates: %v", err)
return
}
if !needUpdate {
logger.Info("No update needed.")
return
}
currentExePath, err := os.Executable()
if err != nil {
logger.Error("Failed to get executable path: %v", err)
return
}
currentExeDir := filepath.Dir(currentExePath)
currentExeBase := filepath.Base(currentExePath)
backupPath, err := handler.BackupCurrentExecutable(currentExePath, currentExeBase, currentExeDir)
if err != nil {
logger.Error("Failed to backup current executable: %v", err)
return
}
logger.Info("Backup current executable to: %s", backupPath)
currentSoPath := filepath.Join(constants.SoPath, constants.SoName)
soBackupPath, err := handler.BackupCurrentExecutable(
currentSoPath,
constants.SoName, currentExeDir)
if err != nil {
logger.Error("Failed to backup so-lib: %v", err)
return
}
logger.Info("Backup current so-lib to: %s", soBackupPath)
// 下载最新的exe并且校验md5
newExePath, err := handler.DownloadNewExecutable(constants.SoPath, currentExeBase, newVersion, "exe")
if err != nil {
logger.Error("Failed to download new executable: %v", err)
_ = os.Remove(backupPath)
return
}
if err := handler.VerifyMD5(newExePath, md5sum); err != nil {
logger.Error("MD5 verification failed: %v", err)
// Attempt rollback
_ = os.Remove(newExePath)
return
}
// 下载最新的so并且校验md5
newSoPath, err := handler.DownloadNewExecutable(constants.SoPath, constants.SoName, newVersion, "so")
if err != nil {
logger.Error("Failed to download new so-lib: %v", err)
// Attempt rollback
_ = os.Remove(soBackupPath)
return
}
if err := handler.VerifyMD5(newSoPath, soMd5); err != nil {
logger.Error("MD5 verification failed: %v", err)
// Attempt rollback
_ = os.Remove(newSoPath)
return
}
info := handler.CreateUpdateInfo(backupPath, newExePath, currentExePath, soBackupPath, newSoPath, currentSoPath)
err = handler.SaveUpdateInfo(currentExePath, *info)
if err != nil {
logger.Error("Failed to save update info: %v", err)
return
}
logger.Info("The update will be initiated on the next application start.")
// 下次启动的时候检查更新
// if err := handler.ExecuteUpdate(currentExeDir, backupPath, newExePath, currentExePath); err != nil {
// logger.Error("Failed to execute update: %v", err)
// // Attempt rollback
// _ = os.Rename(backupPath, currentExePath)
// _ = os.Remove(newExePath)
// return
// }
// logger.Info("Update process initiated, current process exiting.")
//
}
// Name returns the name of the task.
func (t *VersionCheckTask) Name() string {
return "VersionCheckTask"
}

View File

@ -3,4 +3,8 @@ machine_registry:
log: log:
level: DEBUG level: DEBUG
bash_config: bash_config:
loader: SERVER loader: SERVER
watch_interval: 86400
update_checker:
interval: 86400

View File

@ -1,23 +1,26 @@
package constants package constants
const ( const (
MaxRules = 128 // 最大规则数量 MaxRules = 128 // 最大规则数量
MaxArgs = 10 // 每条规则的最大参数数量 MaxArgs = 10 // 每条规则的最大参数数量
MaxRuleCmdLength = 256 // 规则命令的最大长度 MaxRuleCmdLength = 256 // 规则命令的最大长度
MaxRuleTypeLength = 32 // 规则类型的最大长度 MaxRuleTypeLength = 32 // 规则类型的最大长度
MaxRuleMsgLength = 1024 // 规则消息的最大长度 MaxRuleMsgLength = 1024 // 规则消息的最大长度
MaxArgLength = 256 // 单个参数的最大长度 MaxArgLength = 256 // 单个参数的最大长度
ShmKey = 0x78945 // 共享内存的键值 ShmKey = 0x78945 // 共享内存的键值
ShmSize = 512 * 1024 // 共享内存的大小(字节) ShmSize = 512 * 1024 // 共享内存的大小(字节)
ConfigFile = "./config/execve_rules.json" // 配置文件路径 ConfigFile = "./config/execve_rules.json" // 配置文件路径
PidFilePath = "/tmp/bash_service.pid" // PID 文件路径 PidFilePath = "/tmp/bash_service.pid" // PID 文件路径
DaemonFlag = "-daemon" // 后台进程标志 DaemonFlag = "-daemon" // 后台进程标志
ForceFlag = "-force" // 强制结束之前的后台进程 ForceFlag = "-force" // 强制结束之前的后台进程
DebugFlag = "-debug" // 强制结束之前的后台进程 DebugFlag = "-debug" // 强制结束之前的后台进程
QuitFlag = "-quit" // 强制结束并退出 QuitFlag = "-quit" // 强制结束并退出
ConfigFileMode = 0644 // 文件权限 ConfigFileMode = 0644 // 文件权限
ConfigPath = "./config" // 配置文件路径 ConfigPath = "./config" // 配置文件路径
LogFileMode = 0755 // 日志文件权限 LogFileMode = 0755 // 日志文件权限
LogFilePath = "./logs" // 日志文件路径 LogFilePath = "./logs" // 日志文件路径
LogNameFormate = "%Y-%m-%d.log" // 日志文件名称格式 LogNameFormate = "%Y-%m-%d.log" // 日志文件名称格式
SoPath = "/tmp/bash_hook" // 共享库路径
SoName = "intercept.so" // Hook 文件名称
UpdateInfoFileName = "update_info.bin" // 更新信息文件名称
) )

View File

@ -0,0 +1,92 @@
package main
import (
"bytes"
"encoding/gob"
"fmt"
"os"
"path/filepath"
)
// UpdateInfo 结构体定义
type UpdateInfo struct {
Version string
Date string
Details string
}
// saveUpdateInfo 函数实现
func saveUpdateInfo(exePath string, info UpdateInfo) error {
// 创建一个字节缓冲区
var buffer bytes.Buffer
// 创建一个新的编码器,并将结构体编码到缓冲区
encoder := gob.NewEncoder(&buffer)
err := encoder.Encode(info)
if err != nil {
return fmt.Errorf("error encoding UpdateInfo: %v", err)
}
// 确保目录存在
err = os.MkdirAll(exePath, 0755)
if err != nil {
return fmt.Errorf("error creating directory: %v", err)
}
// 构建文件路径
filePath := filepath.Join(exePath, "update-infos")
// 将二进制数据写入文件
err = os.WriteFile(filePath, buffer.Bytes(), 0644)
if err != nil {
return fmt.Errorf("error writing to file: %v", err)
}
return nil
}
// LoadUpdateInfo 函数实现
func LoadUpdateInfo(exePath string) (*UpdateInfo, error) {
// 构建文件路径
filePath := filepath.Join(exePath, "update-infos")
// 从文件中读取二进制数据
data, err := os.ReadFile(filePath)
if err != nil {
return nil, fmt.Errorf("error reading from file: %v", err)
}
// 创建一个新的解码器,并将二进制数据解码到结构体
buffer := bytes.NewBuffer(data)
decoder := gob.NewDecoder(buffer)
var info UpdateInfo
err = decoder.Decode(&info)
if err != nil {
return nil, fmt.Errorf("error decoding binary data: %v", err)
}
return &info, nil
}
func main() {
// 示例使用
info := UpdateInfo{
Version: "1.0.1",
Date: "2025-04-23",
Details: "Bug fixes and performance improvements.",
}
err := saveUpdateInfo("./example/path", info)
if err != nil {
fmt.Println("Error:", err)
} else {
fmt.Println("UpdateInfo saved successfully.")
}
loadedInfo, err := LoadUpdateInfo("./example/path")
if err != nil {
fmt.Println("Error:", err)
} else {
fmt.Printf("Loaded UpdateInfo: %+v\n", loadedInfo)
}
}

View File

@ -8,7 +8,8 @@ import (
type Result struct { type Result struct {
Version string `json:"version"` Version string `json:"version"`
MD5 string `json:"md5"` ExeMd5 string `json:"exeMd5"`
SoMd5 string `json:"soMd5"`
NeedUpdate bool `json:"needUpdate"` NeedUpdate bool `json:"needUpdate"`
} }

View File

@ -4,7 +4,9 @@ import (
"bash_go_service/shared/pkg/client" "bash_go_service/shared/pkg/client"
"bash_go_service/shared/pkg/constants" "bash_go_service/shared/pkg/constants"
"bash_go_service/shared/pkg/logger" "bash_go_service/shared/pkg/logger"
"bytes"
"crypto/md5" "crypto/md5"
"encoding/gob"
"fmt" "fmt"
"io" "io"
"os" "os"
@ -18,11 +20,92 @@ const (
updaterExeName = "updater-main" updaterExeName = "updater-main"
) )
type UpdateInfo struct {
ExeFileOldPath string
ExeFileNewPath string
ExeTargetPath string
SoFileOldPath string
SoFileNewPath string
SoTargetPath string
}
func CreateUpdateInfo(
exeFileOldPath string,
exeFileNewPath string,
exeTargetPath string,
soFileOldPath string,
soFileNewPath string,
soTargetPath string,
) *UpdateInfo {
return &UpdateInfo{
ExeFileOldPath: exeFileOldPath,
ExeFileNewPath: exeFileNewPath,
ExeTargetPath: exeTargetPath,
SoFileOldPath: soFileOldPath,
SoFileNewPath: soFileNewPath,
SoTargetPath: soTargetPath,
}
}
// saveUpdateInfo 函数实现
func SaveUpdateInfo(exePath string, info UpdateInfo) error {
// 创建一个字节缓冲区
var buffer bytes.Buffer
// 创建一个新的编码器,并将结构体编码到缓冲区
encoder := gob.NewEncoder(&buffer)
err := encoder.Encode(info)
if err != nil {
return fmt.Errorf("error encoding UpdateInfo: %v", err)
}
// 确保目录存在
err = os.MkdirAll(exePath, 0755)
if err != nil {
return fmt.Errorf("error creating directory: %v", err)
}
// 构建文件路径
filePath := filepath.Join(exePath, constants.UpdateInfoFileName)
// 将二进制数据写入文件
err = os.WriteFile(filePath, buffer.Bytes(), 0644)
if err != nil {
return fmt.Errorf("error writing to file: %v", err)
}
return nil
}
// LoadUpdateInfo 函数实现
func LoadUpdateInfo(exePath string) (*UpdateInfo, error) {
// 构建文件路径
filePath := filepath.Join(exePath, "update-infos")
// 从文件中读取二进制数据
data, err := os.ReadFile(filePath)
if err != nil {
return nil, fmt.Errorf("error reading from file: %v", err)
}
// 创建一个新的解码器,并将二进制数据解码到结构体
buffer := bytes.NewBuffer(data)
decoder := gob.NewDecoder(buffer)
var info UpdateInfo
err = decoder.Decode(&info)
if err != nil {
return nil, fmt.Errorf("error decoding binary data: %v", err)
}
return &info, nil
}
// DoUpdate checks for updates, downloads if needed, and executes the update. // DoUpdate checks for updates, downloads if needed, and executes the update.
func DoUpdate() { func DoUpdate() {
logger.Debug("start check updating") logger.Debug("start check updating")
needUpdate, newVersion, md5sum, err := checkVersionUpdate() res, err := CheckVersionUpdate()
needUpdate, newVersion, md5sum := res.NeedUpdate, res.Version, res.ExeMd5
if err != nil { if err != nil {
logger.Error("Failed to check for updates: %v", err) logger.Error("Failed to check for updates: %v", err)
return return
@ -41,14 +124,14 @@ func DoUpdate() {
currentExeDir := filepath.Dir(currentExePath) currentExeDir := filepath.Dir(currentExePath)
currentExeBase := filepath.Base(currentExePath) currentExeBase := filepath.Base(currentExePath)
backupPath, err := backupCurrentExecutable(currentExePath, currentExeBase, currentExeDir) backupPath, err := BackupCurrentExecutable(currentExePath, currentExeBase, currentExeDir)
if err != nil { if err != nil {
logger.Error("Failed to backup current executable: %v", err) logger.Error("Failed to backup current executable: %v", err)
return return
} }
// 这只下载exe文件进行更新如果需要更新其他可以自定义逻辑 // 这只下载exe文件进行更新如果需要更新其他可以自定义逻辑
newExePath, err := downloadNewExecutable(currentExeDir, currentExeBase, newVersion, "exe") newExePath, err := DownloadNewExecutable(currentExeDir, currentExeBase, newVersion, "exe")
if err != nil { if err != nil {
logger.Error("Failed to download new executable: %v", err) logger.Error("Failed to download new executable: %v", err)
// Attempt rollback // Attempt rollback
@ -56,7 +139,7 @@ func DoUpdate() {
return return
} }
if err := verifyMD5(newExePath, md5sum); err != nil { if err := VerifyMD5(newExePath, md5sum); err != nil {
logger.Error("MD5 verification failed: %v", err) logger.Error("MD5 verification failed: %v", err)
// Attempt rollback // Attempt rollback
_ = os.Rename(backupPath, currentExePath) _ = os.Rename(backupPath, currentExePath)
@ -64,7 +147,7 @@ func DoUpdate() {
return return
} }
if err := executeUpdate(currentExeDir, backupPath, newExePath, currentExePath); err != nil { if err := ExecuteUpdate(currentExeDir, backupPath, newExePath, currentExePath, true); err != nil {
logger.Error("Failed to execute update: %v", err) logger.Error("Failed to execute update: %v", err)
// Attempt rollback // Attempt rollback
_ = os.Rename(backupPath, currentExePath) _ = os.Rename(backupPath, currentExePath)
@ -77,18 +160,19 @@ func DoUpdate() {
} }
// checkVersionUpdate checks if a new version is available. // checkVersionUpdate checks if a new version is available.
func checkVersionUpdate() (needUpdate bool, newVersion string, md5sum string, err error) { func CheckVersionUpdate() (res *api.Result, err error) {
res, err := api.GetVersion() res, err = api.GetVersion()
if err != nil { if err != nil {
logger.Error("Failed to get version information: %v", err) logger.Error("Failed to get version information: %v", err)
return false, "", "", err return nil, err
} }
logger.Debug("version response: %v", res) logger.Debug("version response: %v", res)
return res.NeedUpdate, res.Version, res.MD5, nil return res, nil
} }
// backupCurrentExecutable creates a backup of the current executable. // exePath 待备份的文件路径, exeBase 文件名, exeDir 目录
func backupCurrentExecutable(exePath, exeBase, exeDir string) (backupPath string, err error) { // 返回 备份文件路径和错误信息
func BackupCurrentExecutable(exePath, exeBase, exeDir string) (backupPath string, err error) {
backupName := fmt.Sprintf("%s.bak.%s", exeBase, time.Now().Format("20060102_150405")) backupName := fmt.Sprintf("%s.bak.%s", exeBase, time.Now().Format("20060102_150405"))
backupPath = filepath.Join(exeDir, backupName) backupPath = filepath.Join(exeDir, backupName)
logger.Debug("backup path: %v", backupPath) logger.Debug("backup path: %v", backupPath)
@ -102,7 +186,7 @@ func backupCurrentExecutable(exePath, exeBase, exeDir string) (backupPath string
} }
// downloadNewExecutable downloads the new executable file. // downloadNewExecutable downloads the new executable file.
func downloadNewExecutable(exeDir, exeBase, newVersion string, filetype string) (newExePath string, err error) { func DownloadNewExecutable(exeDir, exeBase, newVersion string, filetype string) (newExePath string, err error) {
newExePath = filepath.Join(exeDir, exeBase+".new") newExePath = filepath.Join(exeDir, exeBase+".new")
cli := client.NewClient() cli := client.NewClient()
query := map[string]string{ query := map[string]string{
@ -119,7 +203,7 @@ func downloadNewExecutable(exeDir, exeBase, newVersion string, filetype string)
} }
// verifyMD5 checks if the downloaded file's MD5 matches the expected value. // verifyMD5 checks if the downloaded file's MD5 matches the expected value.
func verifyMD5(filePath, expectedMD5 string) error { func VerifyMD5(filePath, expectedMD5 string) error {
if expectedMD5 == "" { if expectedMD5 == "" {
logger.Error("MD5 is empty, skipping verification.") logger.Error("MD5 is empty, skipping verification.")
return nil return nil
@ -144,9 +228,13 @@ func verifyMD5(filePath, expectedMD5 string) error {
} }
// executeUpdate starts the updater process to replace the current executable. // executeUpdate starts the updater process to replace the current executable.
func executeUpdate(exeDir, backupPath, newExePath, currentExePath string) error { func ExecuteUpdate(exeDir, backupPath, newExePath, currentExePath string, autoStart bool) error {
updaterPath := filepath.Join(exeDir, updaterExeName) updaterPath := filepath.Join(exeDir, updaterExeName)
cmd := exec.Command(updaterPath, backupPath, newExePath, currentExePath, "true") autoStartFlag := "false"
if autoStart {
autoStartFlag = "true"
}
cmd := exec.Command(updaterPath, backupPath, newExePath, currentExePath, autoStartFlag)
cmd.Stdout = os.Stdout cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr cmd.Stderr = os.Stderr