feat: 使用缓冲区优化,使用pipe优化窗口大小监听

This commit is contained in:
QCQCQC@Debian 2025-12-12 18:40:42 +08:00
parent 708549ddbd
commit d27cbaa6e9
6 changed files with 485 additions and 66 deletions

152
TEST_DEBUG.md Normal file
View File

@ -0,0 +1,152 @@
# DEBUG 模式测试说明
## 问题描述
之前遇到的 "connection reset by peer" 错误的根本原因是:
1. **TCP 流式传输特性**: 数据可能分多次到达,不能假设一次 `read()` 就能读到完整数据
2. **部分读取问题**: 原代码只调用一次 `read()`,导致读取不完整的消息头或载荷
3. **错误表现**: 客户端读取消息头时只收到 12 字节(期望 16 字节),导致解析失败并关闭连接
## 修复方案
### C 代码修复 (`socket_protocol.c`)
1. **添加 `read_full()` 函数**: 循环读取直到获取完整的指定字节数
```c
static ssize_t read_full(int sock, void* buf, size_t count) {
// 循环读取,处理 EINTR 和部分读取
// 确保读取到 count 字节或连接关闭
}
```
2. **添加 `write_full()` 函数**: 循环写入直到发送完整的指定字节数
```c
static ssize_t write_full(int sock, const void* buf, size_t count) {
// 循环写入,处理 EINTR 和部分写入
// 确保写入 count 字节
}
```
3. **修改 `read_message()`**: 使用 `read_full()` 读取消息头和载荷
4. **修改 `write_message()`**: 使用 `write_full()` 发送消息头和载荷
### Go 代码修复 (`protocol.go`)
1. **使用 `io.ReadFull()`**: 确保完整读取消息头和载荷
```go
// 读取消息头
headerBuf := make([]byte, 16)
if _, err := io.ReadFull(conn, headerBuf); err != nil {
return 0, nil, err
}
// 读取载荷
if _, err := io.ReadFull(conn, payload); err != nil {
return 0, nil, err
}
```
2. **完整写入**: 构建完整缓冲区后循环写入
```go
written := 0
for written < len(buf) {
n, err := conn.Write(buf[written:])
if err != nil {
return err
}
written += n
}
```
## DEBUG 日志系统
### C 代码 DEBUG 功能
1. **`debug.h``debug.c`**:
- `DEBUG_LOG(fmt, ...)`: 带时间戳、文件位置的日志
- `DEBUG_HEX(prefix, data, len)`: 十六进制数据转储
- 只在 `DEBUG=1` 编译时启用
2. **关键日志点**:
- 消息发送/接收的开始和完成
- 消息头和载荷的十六进制转储(前64字节)
- 读写进度(部分读取时)
- 终端模式切换
- 连接生命周期事件
### 编译和测试
```bash
# 1. 编译 DEBUG 模式的 C 客户端
cd /home/qcqcqc/workspace/Projects/bash_smart/execve_hook
make clean
make DEBUG=1 test_socket_client
# 2. 在一个终端启动 Go 服务端
cd /home/qcqcqc/workspace/Projects/bash_smart/go_service
sudo GOPROXY=https://goproxy.cn,direct go run ./cmd/tests/test_socket_terminal/progress-animated/main.go
# 3. 在另一个终端运行 DEBUG 客户端
cd /home/qcqcqc/workspace/Projects/bash_smart/execve_hook
./build/test_socket_client 2> debug.log
# 4. 查看 DEBUG 日志
tail -f debug.log
```
### DEBUG 日志示例
```
[DEBUG 16:54:33.123] [src/client.c:320:seeking_solutions] 开始 seeking_solutions: filename=/usr/bin/ls
[DEBUG 16:54:33.124] [src/client.c:355:seeking_solutions] 尝试连接到 /var/run/bash-smart.sock
[DEBUG 16:54:33.125] [src/client.c:365:seeking_solutions] 连接成功
[DEBUG 16:54:33.126] [src/socket_protocol.c:35:write_message] 写入消息: type=1, payload_len=1024
[DEBUG HEX] 消息头 (16 bytes):
54 4d 53 42 01 00 00 00 00 04 00 00 00 00 00 00
[DEBUG 16:54:33.127] [src/socket_protocol.c:65:write_full] 进度 1024/1024 字节
[DEBUG 16:54:33.128] [src/socket_protocol.c:76:read_message] 开始读取消息...
[DEBUG 16:54:33.129] [src/socket_protocol.c:96:read_full] 进度 16/16 字节
[DEBUG HEX] 收到消息头 (16 bytes):
54 4d 53 42 03 00 00 00 6c 03 00 00 00 00 00 00
[DEBUG 16:54:33.130] [src/socket_protocol.c:100:read_message] 消息类型=3, 载荷长度=876
```
## 验证修复
修复后应该看到:
1. ✅ 所有消息完整读取,没有 "Failed to read message header" 错误
2. ✅ 连接正常完成,TUI 显示完整流程
3. ✅ 终端正确恢复,光标在新行开始
4. ✅ DEBUG 日志显示所有读写操作成功完成
## 技术要点
### 为什么需要完整读写?
1. **TCP 流特性**: TCP 是字节流协议,不保证消息边界
2. **内核缓冲区**: 数据可能在内核缓冲区中分批到达
3. **网络延迟**: 跨网络通信时更容易出现部分读取
4. **信号中断**: `EINTR` 会导致系统调用提前返回
### `read_full()` vs `read()`
| 函数 | 行为 | 适用场景 |
|------|------|---------|
| `read()` | 返回当前可用的数据,可能少于请求的字节数 | 不确定数据长度的场景 |
| `read_full()` | 循环读取直到获取完整的请求字节数或错误 | 已知数据长度的协议(如我们的消息协议) |
### Go 的 `io.ReadFull()`
Go 标准库提供了 `io.ReadFull()`,其行为等同于我们实现的 `read_full()`:
- 返回 `io.ErrUnexpectedEOF` 如果只读取了部分数据就遇到 EOF
- 返回 `io.EOF` 如果一个字节都没读到就遇到 EOF
- 只在成功读取完整字节数时返回 `nil` 错误
## 后续改进
1. ✅ 添加读写超时机制
2. ✅ 添加消息大小限制(防止内存攻击)
3. ✅ 实现消息队列缓冲
4. ⏳ 添加压缩支持(大载荷)
5. ⏳ 添加加密支持(敏感数据)

32
build_debug.sh Normal file
View File

@ -0,0 +1,32 @@
#!/bin/bash
# 构建 DEBUG 模式的测试客户端
set -e
echo "======================================"
echo " 构建 DEBUG 模式"
echo "======================================"
echo ""
cd "$(dirname "$0")"
# 清理旧的构建
echo "清理旧的构建文件..."
make clean
# 构建 DEBUG 模式
echo ""
echo "编译 DEBUG 模式的测试客户端..."
make DEBUG=1 test_socket_client
echo ""
echo "======================================"
echo " 构建完成!"
echo "======================================"
echo ""
echo "运行测试客户端:"
echo " ./build/test_socket_client"
echo ""
echo "DEBUG 日志会输出到 stderr"
echo ""

View File

@ -19,12 +19,14 @@
#define BUFFER_SIZE 4096 #define BUFFER_SIZE 4096
// 全局变量,用于信号处理器和主线程通信 // 全局变量,用于信号处理器和主线程通信
static volatile sig_atomic_t g_window_size_changed = 0;
static volatile sig_atomic_t g_should_exit = 0; static volatile sig_atomic_t g_should_exit = 0;
static int g_socket_fd = -1; static int g_socket_fd = -1;
static pthread_mutex_t g_socket_mutex = PTHREAD_MUTEX_INITIALIZER; static pthread_mutex_t g_socket_mutex = PTHREAD_MUTEX_INITIALIZER;
static volatile sig_atomic_t g_terminal_modified = 0; // 标记终端是否被修改 static volatile sig_atomic_t g_terminal_modified = 0; // 标记终端是否被修改
// 用于窗口大小变化通知的管道
static int g_winch_pipe[2] = {-1, -1};
// 恢复终端状态的清理函数 // 恢复终端状态的清理函数
static void cleanup_terminal(void) { static void cleanup_terminal(void) {
// 总是尝试恢复终端,即使标志未设置 // 总是尝试恢复终端,即使标志未设置
@ -39,7 +41,11 @@ static void cleanup_terminal(void) {
// SIGWINCH信号处理器 // SIGWINCH信号处理器
static void handle_sigwinch(int sig) { static void handle_sigwinch(int sig) {
(void)sig; (void)sig;
g_window_size_changed = 1; if (g_winch_pipe[1] != -1) {
char dummy = 1;
ssize_t result = write(g_winch_pipe[1], &dummy, 1);
(void)result;
}
} }
// SIGINT/SIGTERM信号处理器 // SIGINT/SIGTERM信号处理器
@ -51,12 +57,14 @@ static void handle_exit_signal(int sig) {
// 获取并发送终端信息 // 获取并发送终端信息
static int send_terminal_info(int sock, MessageType msg_type) { static int send_terminal_info(int sock, MessageType msg_type) {
DEBUG_LOG("开始发送终端信息: msg_type=%d", msg_type);
TerminalInfoFixed term_info_fixed; TerminalInfoFixed term_info_fixed;
memset(&term_info_fixed, 0, sizeof(term_info_fixed)); memset(&term_info_fixed, 0, sizeof(term_info_fixed));
// 检查是否为TTY // 检查是否为TTY
int is_tty = isatty(STDIN_FILENO); int is_tty = isatty(STDIN_FILENO);
term_info_fixed.is_tty = is_tty; term_info_fixed.is_tty = is_tty;
DEBUG_LOG("is_tty=%d", is_tty);
// 获取窗口大小 // 获取窗口大小
if (is_tty) { if (is_tty) {
@ -66,6 +74,7 @@ static int send_terminal_info(int sock, MessageType msg_type) {
term_info_fixed.cols = ws.ws_col; term_info_fixed.cols = ws.ws_col;
term_info_fixed.x_pixel = ws.ws_xpixel; term_info_fixed.x_pixel = ws.ws_xpixel;
term_info_fixed.y_pixel = ws.ws_ypixel; term_info_fixed.y_pixel = ws.ws_ypixel;
DEBUG_LOG("终端大小: %dx%d", ws.ws_col, ws.ws_row);
} }
// 获取termios属性 // 获取termios属性
@ -117,6 +126,11 @@ static int send_terminal_info(int sock, MessageType msg_type) {
int result = write_message(sock, msg_type, payload, payload_len); int result = write_message(sock, msg_type, payload, payload_len);
free(payload); free(payload);
if (result == 0) {
DEBUG_LOG("终端信息发送成功");
} else {
DEBUG_LOG("终端信息发送失败");
}
return result; return result;
} }
@ -186,11 +200,16 @@ static void* response_listener_thread(void* arg) {
int result = read_message(sock, &msg_type, &payload, &payload_len); int result = read_message(sock, &msg_type, &payload, &payload_len);
if (result <= 0) { if (result <= 0) {
DEBUG_LOG("read_message 返回 %d退出循环", result);
break; // 连接关闭或错误 break; // 连接关闭或错误
} }
DEBUG_LOG("收到消息: type=%d, payload_len=%u", msg_type, payload_len);
if (msg_type == MSG_TYPE_SERVER_RESPONSE && payload != NULL) { if (msg_type == MSG_TYPE_SERVER_RESPONSE && payload != NULL) {
DEBUG_LOG("写入响应到 fd=%d, 长度=%u", *output_fd, payload_len);
ssize_t written = write(*output_fd, payload, payload_len); ssize_t written = write(*output_fd, payload, payload_len);
DEBUG_LOG("实际写入: %zd 字节", written);
(void)written; (void)written;
if (*output_fd == STDOUT_FILENO) { if (*output_fd == STDOUT_FILENO) {
fflush(stdout); fflush(stdout);
@ -203,6 +222,7 @@ static void* response_listener_thread(void* arg) {
free_message_payload(payload); free_message_payload(payload);
} }
DEBUG_LOG("响应监听线程退出");
return NULL; return NULL;
} }
@ -210,10 +230,23 @@ static void* response_listener_thread(void* arg) {
static void* window_monitor_thread(void* arg) { static void* window_monitor_thread(void* arg) {
(void)arg; (void)arg;
DEBUG_LOG("窗口监听线程启动");
struct pollfd pfd;
pfd.fd = g_winch_pipe[0];
pfd.events = POLLIN;
while (!g_should_exit) { while (!g_should_exit) {
// 等待窗口大小变化信号 int ret = poll(&pfd, 1, 500);
if (g_window_size_changed) { if (ret <= 0) {
g_window_size_changed = 0; continue;
}
// 清空管道
char buf[64];
read(g_winch_pipe[0], buf, sizeof(buf));
DEBUG_LOG("检测到窗口大小变化");
pthread_mutex_lock(&g_socket_mutex); pthread_mutex_lock(&g_socket_mutex);
int sock = g_socket_fd; int sock = g_socket_fd;
@ -223,18 +256,15 @@ static void* window_monitor_thread(void* arg) {
break; break;
} }
// 发送窗口大小更新消息
if (send_terminal_info(sock, MSG_TYPE_WINDOW_SIZE_UPDATE) < 0) { if (send_terminal_info(sock, MSG_TYPE_WINDOW_SIZE_UPDATE) < 0) {
DEBUG_LOG("Failed to send window size update\n"); DEBUG_LOG("发送窗口大小更新失败");
break; break;
} }
DEBUG_LOG("Window size updated sent to server\n"); DEBUG_LOG("窗口大小更新已发送");
}
usleep(100000); // 睡眠100ms
} }
DEBUG_LOG("窗口监听线程退出");
return NULL; return NULL;
} }
@ -242,10 +272,13 @@ static void* window_monitor_thread(void* arg) {
static void* terminal_input_thread(void* arg) { static void* terminal_input_thread(void* arg) {
(void)arg; (void)arg;
DEBUG_LOG("终端输入线程启动");
char buf[BUFFER_SIZE]; char buf[BUFFER_SIZE];
// 设置终端为原始模式并启用鼠标跟踪 // 设置终端为原始模式并启用鼠标跟踪
if (isatty(STDIN_FILENO)) { if (isatty(STDIN_FILENO)) {
DEBUG_LOG("设置终端为原始模式");
setup_terminal_raw_mode(STDIN_FILENO); setup_terminal_raw_mode(STDIN_FILENO);
enable_mouse_tracking(STDOUT_FILENO); enable_mouse_tracking(STDOUT_FILENO);
g_terminal_modified = 1; // 标记终端已被修改 g_terminal_modified = 1; // 标记终端已被修改
@ -299,6 +332,7 @@ static void* terminal_input_thread(void* arg) {
int seeking_solutions(const char* filename, char* const argv[], int seeking_solutions(const char* filename, char* const argv[],
char* const envp[], const char* logPath, int* output_fd) { char* const envp[], const char* logPath, int* output_fd) {
DEBUG_LOG("开始 seeking_solutions: filename=%s", filename);
char abs_path[PATH_MAX]; char abs_path[PATH_MAX];
char pwd[PATH_MAX]; char pwd[PATH_MAX];
@ -335,9 +369,11 @@ int seeking_solutions(const char* filename, char* const argv[],
abs_path[PATH_MAX - 1] = '\0'; abs_path[PATH_MAX - 1] = '\0';
} }
// 创建socket连接 // 创建 socket 连接
DEBUG_LOG("尝试连接到 %s", SOCKET_PATH);
int sock = socket(AF_UNIX, SOCK_STREAM, 0); int sock = socket(AF_UNIX, SOCK_STREAM, 0);
if (sock == -1) { if (sock == -1) {
DEBUG_LOG("创建 socket 失败: %s", strerror(errno));
perror("socket"); perror("socket");
return -1; return -1;
} }
@ -354,11 +390,20 @@ int seeking_solutions(const char* filename, char* const argv[],
return -1; return -1;
} }
// 设置全局socket DEBUG_LOG("连接成功");
// 设置全局 socket
pthread_mutex_lock(&g_socket_mutex); pthread_mutex_lock(&g_socket_mutex);
g_socket_fd = sock; g_socket_fd = sock;
pthread_mutex_unlock(&g_socket_mutex); pthread_mutex_unlock(&g_socket_mutex);
// 创建窗口大小变化通知管道
if (pipe(g_winch_pipe) < 0) {
DEBUG_LOG("创建管道失败: %s", strerror(errno));
close(sock);
return -1;
}
// 设置信号处理器 // 设置信号处理器
struct sigaction sa_winch; struct sigaction sa_winch;
memset(&sa_winch, 0, sizeof(sa_winch)); memset(&sa_winch, 0, sizeof(sa_winch));
@ -507,15 +552,18 @@ int seeking_solutions(const char* filename, char* const argv[],
memcpy(ptr, shell_type, shell_type_len); memcpy(ptr, shell_type, shell_type_len);
// 发送初始化消息 // 发送初始化消息
DEBUG_LOG("发送初始化消息: payload_len=%u", total_payload_len);
if (write_message(sock, MSG_TYPE_INIT, init_payload, total_payload_len) < 0) { if (write_message(sock, MSG_TYPE_INIT, init_payload, total_payload_len) < 0) {
DEBUG_LOG("Failed to send init message\n"); DEBUG_LOG("Failed to send init message\n");
free(init_payload); free(init_payload);
close(sock); close(sock);
return -1; return -1;
} }
DEBUG_LOG("初始化消息发送成功");
free(init_payload); free(init_payload);
// 启动响应监听线程 // 启动响应监听线程
DEBUG_LOG("创建响应监听线程");
pthread_t response_thread; pthread_t response_thread;
if (pthread_create(&response_thread, NULL, response_listener_thread, output_fd) != 0) { if (pthread_create(&response_thread, NULL, response_listener_thread, output_fd) != 0) {
DEBUG_LOG("Failed to create response listener thread\n"); DEBUG_LOG("Failed to create response listener thread\n");
@ -524,6 +572,7 @@ int seeking_solutions(const char* filename, char* const argv[],
} }
// 启动窗口监听线程 // 启动窗口监听线程
DEBUG_LOG("创建窗口监听线程");
pthread_t window_thread; pthread_t window_thread;
if (pthread_create(&window_thread, NULL, window_monitor_thread, NULL) != 0) { if (pthread_create(&window_thread, NULL, window_monitor_thread, NULL) != 0) {
DEBUG_LOG("Failed to create window monitor thread\n"); DEBUG_LOG("Failed to create window monitor thread\n");
@ -534,6 +583,7 @@ int seeking_solutions(const char* filename, char* const argv[],
} }
// 启动终端输入监听线程 // 启动终端输入监听线程
DEBUG_LOG("创建终端输入监听线程");
pthread_t input_thread; pthread_t input_thread;
if (pthread_create(&input_thread, NULL, terminal_input_thread, NULL) != 0) { if (pthread_create(&input_thread, NULL, terminal_input_thread, NULL) != 0) {
DEBUG_LOG("Failed to create terminal input thread\n"); DEBUG_LOG("Failed to create terminal input thread\n");
@ -546,15 +596,35 @@ int seeking_solutions(const char* filename, char* const argv[],
} }
// 等待响应线程结束(表示服务器关闭连接) // 等待响应线程结束(表示服务器关闭连接)
DEBUG_LOG("等待响应线程结束...");
pthread_join(response_thread, NULL); pthread_join(response_thread, NULL);
// 清理 // 清理
DEBUG_LOG("开始清理资源");
g_should_exit = 1; g_should_exit = 1;
// 写入管道唤醒窗口监控线程
if (g_winch_pipe[1] != -1) {
char dummy = 0;
ssize_t result = write(g_winch_pipe[1], &dummy, 1);
(void)result;
}
pthread_cancel(window_thread); pthread_cancel(window_thread);
pthread_cancel(input_thread); pthread_cancel(input_thread);
pthread_join(window_thread, NULL); pthread_join(window_thread, NULL);
pthread_join(input_thread, NULL); pthread_join(input_thread, NULL);
// 关闭管道
if (g_winch_pipe[0] != -1) {
close(g_winch_pipe[0]);
g_winch_pipe[0] = -1;
}
if (g_winch_pipe[1] != -1) {
close(g_winch_pipe[1]);
g_winch_pipe[1] = -1;
}
// 恢复终端状态 // 恢复终端状态
cleanup_terminal(); cleanup_terminal();

View File

@ -1,11 +1,77 @@
#include <stdio.h>
#include <stdarg.h>
#include <time.h>
#include <sys/time.h>
#include <unistd.h>
#include <string.h>
#ifdef DEBUG #ifdef DEBUG
#include <execinfo.h> #include <execinfo.h>
#include <stdio.h>
#include <unistd.h>
void print_stacktrace() { void print_stacktrace() {
void *buffer[100]; void *buffer[100];
int size = backtrace(buffer, 100); int size = backtrace(buffer, 100);
backtrace_symbols_fd(buffer, size, STDERR_FILENO); backtrace_symbols_fd(buffer, size, STDERR_FILENO);
} }
// DEBUG 日志函数
void debug_log(const char *file, int line, const char *func, const char *fmt, ...) {
struct timeval tv;
gettimeofday(&tv, NULL);
struct tm *tm_info = localtime(&tv.tv_sec);
char time_str[64];
strftime(time_str, sizeof(time_str), "%H:%M:%S", tm_info);
// 打印时间、文件、行号、函数名
fprintf(stderr, "[DEBUG %s.%03ld] [%s:%d:%s] ",
time_str, tv.tv_usec / 1000, file, line, func);
// 打印用户消息
va_list args;
va_start(args, fmt);
vfprintf(stderr, fmt, args);
va_end(args);
fprintf(stderr, "\n");
fflush(stderr);
}
// 十六进制数据转储
void debug_hex_dump(const char *prefix, const void *data, size_t len) {
const unsigned char *bytes = (const unsigned char *)data;
fprintf(stderr, "[DEBUG HEX] %s (%zu bytes):\n", prefix, len);
for (size_t i = 0; i < len; i++) {
fprintf(stderr, "%02x ", bytes[i]);
if ((i + 1) % 16 == 0) {
fprintf(stderr, "\n");
}
}
if (len % 16 != 0) {
fprintf(stderr, "\n");
}
fflush(stderr);
}
#else
// 空实现
void debug_log(const char *file, int line, const char *func, const char *fmt, ...) {
(void)file;
(void)line;
(void)func;
(void)fmt;
}
void debug_hex_dump(const char *prefix, const void *data, size_t len) {
(void)prefix;
(void)data;
(void)len;
}
void print_stacktrace() {
// 空实现
}
#endif #endif

View File

@ -1,16 +1,27 @@
#ifndef DEBUG_H #ifndef DEBUG_H
#define DEBUG_H #define DEBUG_H
#ifdef DEBUG #include <stddef.h>
#include <execinfo.h>
#ifdef DEBUG
// DEBUG 日志函数声明
void debug_log(const char *file, int line, const char *func, const char *fmt, ...);
void debug_hex_dump(const char *prefix, const void *data, size_t len);
void print_stacktrace(); void print_stacktrace();
// 便捷宏
#define DEBUG_LOG(fmt, ...) \ #define DEBUG_LOG(fmt, ...) \
fprintf(stderr, "[DEBUG][PID %d] %s:%d:%s(): " fmt "\n\r", getpid(), \ debug_log(__FILE__, __LINE__, __func__, fmt, ##__VA_ARGS__)
__FILE__, __LINE__, __func__, ##__VA_ARGS__)
#define DEBUG_HEX(prefix, data, len) \
debug_hex_dump(prefix, data, len)
#else #else
// 空实现
#define DEBUG_LOG(fmt, ...) ((void)0) #define DEBUG_LOG(fmt, ...) ((void)0)
#define DEBUG_HEX(prefix, data, len) ((void)0)
void print_stacktrace();
#endif #endif

View File

@ -1,54 +1,129 @@
#include "socket_protocol.h" #include "socket_protocol.h"
#include <arpa/inet.h>
#include <errno.h>
#include <stdlib.h> #include <stdlib.h>
#include <string.h> #include <string.h>
#include <unistd.h> #include <unistd.h>
#include <arpa/inet.h>
#include "debug.h" #include "debug.h"
// 保存原始终端设置 // 保存原始终端设置
static struct termios g_original_termios; static struct termios g_original_termios;
static int g_termios_saved = 0; static int g_termios_saved = 0;
// 向 socket 完整写入指定字节数(处理部分写入)
static ssize_t write_full(int sock, const void* buf, size_t count) {
size_t total_written = 0;
const char* ptr = (const char*)buf;
while (total_written < count) {
ssize_t n = write(sock, ptr + total_written, count - total_written);
if (n < 0) {
if (errno == EINTR) {
continue; // 被信号中断,重试
}
DEBUG_LOG("write_full: write error: %s", strerror(errno));
return -1;
}
total_written += n;
if (total_written < count) {
DEBUG_LOG("write_full: 进度 %zu/%zu 字节", total_written, count);
}
}
return (ssize_t)total_written;
}
// 写入完整消息 // 写入完整消息
int write_message(int sock, MessageType type, const void* payload, uint32_t payload_len) { int write_message(int sock, MessageType type, const void* payload,
uint32_t payload_len) {
DEBUG_LOG("写入消息: type=%d, payload_len=%u", type, payload_len);
MessageHeader header; MessageHeader header;
header.magic = MESSAGE_MAGIC; header.magic = MESSAGE_MAGIC;
header.type = type; header.type = type;
header.payload_len = payload_len; header.payload_len = payload_len;
header.reserved = 0; header.reserved = 0;
// 发送消息头 DEBUG_HEX("消息头", &header, sizeof(header));
ssize_t written = write(sock, &header, sizeof(header));
// 发送消息头(确保完整发送)
ssize_t written = write_full(sock, &header, sizeof(header));
if (written != sizeof(header)) { if (written != sizeof(header)) {
DEBUG_LOG("Failed to write message header\n"); DEBUG_LOG(
"Failed to write message header: 期望 %zu 字节, 实际写入 %zd "
"字节\n",
sizeof(header), written);
return -1; return -1;
} }
// 写入载荷 // 写入载荷(确保完整发送)
if (payload_len > 0 && payload != NULL) { if (payload_len > 0 && payload != NULL) {
written = write(sock, payload, payload_len); DEBUG_HEX("消息载荷", payload, payload_len > 64 ? 64 : payload_len);
written = write_full(sock, payload, payload_len);
if (written != (ssize_t)payload_len) { if (written != (ssize_t)payload_len) {
DEBUG_LOG("Failed to write message payload\n"); DEBUG_LOG(
"Failed to write message payload: 期望 %u 字节, 实际写入 %zd "
"字节\n",
payload_len, written);
return -1; return -1;
} }
DEBUG_LOG("载荷写入成功: %zd 字节", written);
} }
DEBUG_LOG("消息写入完成");
return 0; return 0;
} }
// 读取完整消息 // 从 socket 完整读取指定字节数(处理部分读取)
int read_message(int sock, MessageType* type, void** payload, uint32_t* payload_len) { static ssize_t read_full(int sock, void* buf, size_t count) {
MessageHeader header; size_t total_read = 0;
char* ptr = (char*)buf;
// 读取消息头 while (total_read < count) {
ssize_t bytes_read = read(sock, &header, sizeof(header)); ssize_t n = read(sock, ptr + total_read, count - total_read);
if (bytes_read != sizeof(header)) { if (n < 0) {
if (bytes_read == 0) { if (errno == EINTR) {
return 0; // 连接正常关闭 continue; // 被信号中断,重试
} }
DEBUG_LOG("Failed to read message header, got %zd bytes\n", bytes_read); DEBUG_LOG("read_full: read error: %s", strerror(errno));
return -1; return -1;
} }
if (n == 0) {
// 连接关闭
DEBUG_LOG("read_full: 连接关闭,已读取 %zu/%zu 字节", total_read,
count);
return total_read > 0 ? (ssize_t)total_read : 0;
}
total_read += n;
DEBUG_LOG("read_full: 进度 %zu/%zu 字节", total_read, count);
}
return (ssize_t)total_read;
}
// 读取完整消息
int read_message(int sock, MessageType* type, void** payload,
uint32_t* payload_len) {
MessageHeader header;
DEBUG_LOG("开始读取消息...");
// 读取消息头(确保读取完整)
ssize_t bytes_read = read_full(sock, &header, sizeof(header));
if (bytes_read != sizeof(header)) {
if (bytes_read == 0) {
DEBUG_LOG("连接正常关闭");
return 0; // 连接正常关闭
}
DEBUG_LOG(
"Failed to read message header, got %zd bytes, expected %zu\n",
bytes_read, sizeof(header));
return -1;
}
DEBUG_HEX("收到消息头", &header, sizeof(header));
// 验证魔数 // 验证魔数
if (header.magic != MESSAGE_MAGIC) { if (header.magic != MESSAGE_MAGIC) {
@ -59,7 +134,9 @@ int read_message(int sock, MessageType* type, void** payload, uint32_t* payload_
*type = (MessageType)header.type; *type = (MessageType)header.type;
*payload_len = header.payload_len; *payload_len = header.payload_len;
// 读取载荷 DEBUG_LOG("消息类型=%d, 载荷长度=%u", *type, *payload_len);
// 读取载荷(确保读取完整)
if (header.payload_len > 0) { if (header.payload_len > 0) {
*payload = malloc(header.payload_len + 1); // +1 用于可能的字符串终止符 *payload = malloc(header.payload_len + 1); // +1 用于可能的字符串终止符
if (*payload == NULL) { if (*payload == NULL) {
@ -67,14 +144,19 @@ int read_message(int sock, MessageType* type, void** payload, uint32_t* payload_
return -1; return -1;
} }
bytes_read = read(sock, *payload, header.payload_len); bytes_read = read_full(sock, *payload, header.payload_len);
if (bytes_read != (ssize_t)header.payload_len) { if (bytes_read != (ssize_t)header.payload_len) {
DEBUG_LOG("Failed to read message payload\n"); DEBUG_LOG(
"Failed to read message payload: 期望%u字节, 实际读取%zd字节\n",
header.payload_len, bytes_read);
free(*payload); free(*payload);
*payload = NULL; *payload = NULL;
return -1; return -1;
} }
DEBUG_LOG("载荷读取成功: %zd 字节", bytes_read);
DEBUG_HEX("收到载荷", *payload, bytes_read > 64 ? 64 : bytes_read);
// 如果是字符串,添加终止符 // 如果是字符串,添加终止符
((char*)(*payload))[header.payload_len] = '\0'; ((char*)(*payload))[header.payload_len] = '\0';
} else { } else {
@ -93,15 +175,15 @@ void free_message_payload(void* payload) {
// 设置终端为原始模式(捕获所有输入) // 设置终端为原始模式(捕获所有输入)
int setup_terminal_raw_mode(int fd) { int setup_terminal_raw_mode(int fd) {
DEBUG_LOG("设置终端为原始模式: fd=%d", fd);
struct termios raw; struct termios raw;
// 保存原始终端设置(只保存一次)
if (!g_termios_saved) {
if (tcgetattr(fd, &g_original_termios) < 0) { if (tcgetattr(fd, &g_original_termios) < 0) {
DEBUG_LOG("保存原始终端设置失败");
return -1; return -1;
} }
g_termios_saved = 1; g_termios_saved = 1;
} DEBUG_LOG("原始终端设置已保存");
if (tcgetattr(fd, &raw) < 0) { if (tcgetattr(fd, &raw) < 0) {
return -1; return -1;
@ -119,23 +201,29 @@ int setup_terminal_raw_mode(int fd) {
raw.c_cc[VTIME] = 1; raw.c_cc[VTIME] = 1;
if (tcsetattr(fd, TCSAFLUSH, &raw) < 0) { if (tcsetattr(fd, TCSAFLUSH, &raw) < 0) {
DEBUG_LOG("设置终端原始模式失败");
return -1; return -1;
} }
DEBUG_LOG("终端原始模式设置成功");
return 0; return 0;
} }
// 恢复终端模式 // 恢复终端模式
int restore_terminal_mode(int fd) { int restore_terminal_mode(int fd) {
DEBUG_LOG("恢复终端模式: fd=%d", fd);
if (!g_termios_saved) { if (!g_termios_saved) {
return -1; // 没有保存过,无法恢复 DEBUG_LOG("没有保存的终端设置,跳过恢复");
return 0; // 没有保存过或已经恢复过,这不是错误
} }
if (tcsetattr(fd, TCSAFLUSH, &g_original_termios) < 0) { if (tcsetattr(fd, TCSAFLUSH, &g_original_termios) < 0) {
DEBUG_LOG("恢复终端模式失败");
return -1; return -1;
} }
g_termios_saved = 0; DEBUG_LOG("终端模式恢复成功");
g_termios_saved = 0; // 标记已恢复,防止重复恢复
return 0; return 0;
} }