import re
import time
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import random
import string
import psutil
import os
# import pynvml

# Initialize GPU monitoring (uncomment if pynvml is available)
# pynvml.nvmlInit()
# gpu_handle = pynvml.nvmlDeviceGetHandleByIndex(0)  # use the first GPU by default

def get_memory_info():
    """Return the resident memory usage of the current process in MB."""
    process = psutil.Process(os.getpid())
    mem_info = process.memory_info()
    return mem_info.rss / 1024 / 1024  # MB


def get_gpu_memory_info():
    """Return (used, total) GPU memory in MB; stubbed to (0, 0) while pynvml is disabled."""
    # mem = pynvml.nvmlDeviceGetMemoryInfo(gpu_handle)
    # used = mem.used / 1024 / 1024  # MB
    # total = mem.total / 1024 / 1024
    # return used, total
    return 0, 0

# 1. Initialize the embedding model
model = SentenceTransformer('all-MiniLM-L6-v2')

# 2. Simulate a larger error database
NUM_ERRORS = 100
error_db_large = []

def generate_random_traceback(num_lines=10):
    """Generate a synthetic Python traceback with random frames and a random error message."""
    traceback = "Traceback (most recent call last):\n"
    for i in range(num_lines):
        file_path = f"/path/to/file_{''.join(random.choices(string.ascii_lowercase, k=5))}.py"
        line_num = random.randint(1, 200)
        function_name = f"function_{''.join(random.choices(string.ascii_lowercase, k=8))}"
        traceback += f'  File "{file_path}", line {line_num}, in {function_name}\n'
    error_type = random.choice(["IndexError", "TypeError", "ValueError", "KeyError", "ModuleNotFoundError"])
    error_message = f"{error_type}: {' '.join(random.choices(string.ascii_lowercase, k=15))}"
    traceback += error_message + "\n"
    return traceback


for i in range(NUM_ERRORS):
    error_db_large.append({
        "error": generate_random_traceback(num_lines=random.randint(5, 15)),
        "solution": f"Sample solution for error {i}"
    })

# 3. Error-message cleaning function (same as before)
def clean_traceback(traceback_string):
    """Strip noise from a traceback: the header line, absolute paths, object addresses, blank lines."""
    cleaned = traceback_string
    cleaned = cleaned.replace('\r', '')
    cleaned = re.sub(r'^Traceback \(most recent call last\):[\n]+', '', cleaned, flags=re.MULTILINE)
    file_line_regex = re.compile(r'^ *File \"(.*?)\", line (\d+)(, in .*?)?$', re.MULTILINE)

    def replace_file_line(match):
        full_path = match.group(1)
        line_num = match.group(2)
        function_part = match.group(3) or ''
        # Keep only the file name so machine-specific paths do not affect the embedding
        last_sep_index = max(full_path.rfind('/'), full_path.rfind('\\'))
        filename = full_path[last_sep_index + 1:] if last_sep_index != -1 else full_path
        return f'  File "{filename}", line {line_num}{function_part}'

    cleaned = file_line_regex.sub(replace_file_line, cleaned)
    cleaned = re.sub(r'<.* at 0x[0-9a-fA-F]+>', '<...>', cleaned)
    cleaned = re.sub(r'\n\s*\n+', '\n', cleaned).strip()
    return cleaned

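# Example (illustrative): clean_traceback rewrites a frame such as
#     File "/path/to/file_abcde.py", line 42, in function_xyzabcde
# as
#     File "file_abcde.py", line 42, in function_xyzabcde
# and drops the "Traceback (most recent call last):" header, "<... at 0x...>" addresses
# and blank lines, so semantically identical errors from different machines embed alike.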
# 4. Clean the large error database
start_time = time.time()
cleaned_error_db_large = []
for entry in error_db_large:
    cleaned_error = clean_traceback(entry["error"])
    cleaned_error_db_large.append({
        "cleaned_error": cleaned_error,
        "original_error": entry["error"],
        "solution": entry["solution"]
    })

cleaning_time = time.time() - start_time
print(f"Cleaning {NUM_ERRORS} error messages took: {cleaning_time:.4f} s")
print(f"Memory usage: {get_memory_info():.2f} MB")
gpu_used, gpu_total = get_gpu_memory_info()
print(f"GPU memory usage: {gpu_used:.2f} / {gpu_total:.2f} MB")

# 5. Embed the cleaned error messages and build the index
start_time = time.time()
# Normalize the embeddings so the squared L2 distances returned by FAISS below
# can be converted into cosine similarities.
embeddings_large = np.array([
    model.encode(e['cleaned_error'], normalize_embeddings=True)
    for e in cleaned_error_db_large
])
embedding_time = time.time() - start_time
print(f"Embedding {NUM_ERRORS} error messages took: {embedding_time:.4f} s")
print(f"Memory usage: {get_memory_info():.2f} MB")
gpu_used, gpu_total = get_gpu_memory_info()
print(f"GPU memory usage: {gpu_used:.2f} / {gpu_total:.2f} MB")

start_time = time.time()
# IndexFlatL2 performs exact (brute-force) L2 search; sentence-transformers already
# returns float32 vectors, which is what FAISS expects.
index_large = faiss.IndexFlatL2(embeddings_large.shape[1])
index_large.add(embeddings_large)
index_building_time = time.time() - start_time
print(f"Building the FAISS index with {NUM_ERRORS} vectors took: {index_building_time:.4f} s")
print(f"Memory usage: {get_memory_info():.2f} MB")
gpu_used, gpu_total = get_gpu_memory_info()
print(f"GPU memory usage: {gpu_used:.2f} / {gpu_total:.2f} MB")

# 6. A user error report (test query; the timed queries below generate their own errors)
user_error_performance = generate_random_traceback(num_lines=12)
cleaned_user_error_performance = clean_traceback(user_error_performance)
query_vec_performance = model.encode(cleaned_user_error_performance, normalize_embeddings=True)
query_vec_performance = np.array([query_vec_performance])

NUM_QUERIES = 100  # number of queries
k = 10  # number of most similar errors returned per query

query_times = []
similarities = []

for i in range(NUM_QUERIES):
    user_error = generate_random_traceback(num_lines=random.randint(8, 14))
    cleaned_user_error = clean_traceback(user_error)
    query_vec = model.encode(cleaned_user_error, normalize_embeddings=True)
    query_vec = np.array([query_vec])

    start_query_time = time.time()
    D, I = index_large.search(query_vec, k=k)
    elapsed = time.time() - start_query_time
    query_times.append(elapsed)

    matched_index = I[0][0]
    distance = D[0][0]
    # IndexFlatL2 returns squared L2 distances. For unit-norm vectors
    # ||u - v||^2 = 2 - 2 * cos(u, v), so cosine similarity = 1 - distance / 2.
    similarity = 1 - distance / 2
    similarities.append(similarity)

# Print the statistics
print("\n--- Repeated-query performance test ---")
print(f"Total queries: {NUM_QUERIES}")
print(f"Mean query time: {np.mean(query_times) * 1000:.4f} ms")
print(f"Min query time: {np.min(query_times) * 1000:.4f} ms")
print(f"Max query time: {np.max(query_times) * 1000:.4f} ms")
print(f"Query time std dev: {np.std(query_times) * 1000:.4f} ms")
print(f"Mean top-1 match similarity: {np.mean(similarities):.4f}")
print(f"Similarity std dev: {np.std(similarities):.4f}")
print(f"Memory usage: {get_memory_info():.2f} MB")
gpu_used, gpu_total = get_gpu_memory_info()
print(f"GPU memory usage: {gpu_used:.2f} / {gpu_total:.2f} MB")