# vector_match_tests/perf_2.py — FAISS incremental add/search performance test

import re
import time
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import random
import string
# 1. Initialize the sentence-embedding model used to vectorize tracebacks.
model = SentenceTransformer('all-MiniLM-L6-v2')
# Dimensionality of the vectors this model emits; used to size the FAISS index.
EMBEDDING_DIM = model.get_sentence_embedding_dimension()
# 2. Simulated error-data generation.
def generate_random_traceback(num_lines=10):
    """Build a synthetic Python-style traceback with *num_lines* stack frames.

    Each frame carries a random file path, line number and function name;
    the final line is a random exception type plus a random message.

    Args:
        num_lines: Number of ``File "...", line N, in func`` frames to emit.

    Returns:
        A newline-terminated traceback-formatted string.
    """
    parts = ["Traceback (most recent call last):\n"]
    for _ in range(num_lines):
        file_path = f"/path/to/file_{''.join(random.choices(string.ascii_lowercase, k=5))}.py"
        line_num = random.randint(1, 200)
        function_name = f"function_{''.join(random.choices(string.ascii_lowercase, k=8))}"
        parts.append(f' File "{file_path}", line {line_num}, in {function_name}\n')
    error_type = random.choice(["IndexError", "TypeError", "ValueError", "KeyError", "ModuleNotFoundError"])
    error_message = f"{error_type}: {' '.join(random.choices(string.ascii_lowercase, k=15))}"
    parts.append(error_message + "\n")
    return "".join(parts)
def generate_new_error_entry(index):
    """Create one synthetic DB record: a random traceback plus a canned solution tag."""
    frame_count = random.randint(5, 15)
    return {
        "error": generate_random_traceback(num_lines=frame_count),
        "solution": f"Real-time solution for error {index}",
    }
# 3. Traceback-cleaning function (same role as before): normalize a traceback
# so that equivalent errors from different runs produce near-identical text.
def clean_traceback(traceback_string):
    """Normalize *traceback_string* for embedding/similarity matching.

    Steps:
      * strip carriage returns and the "Traceback (most recent call last):" header;
      * reduce absolute file paths (both / and \\ separators) to the bare
        filename, since directories vary between machines;
      * collapse ``<object at 0x...>`` reprs to ``<...>``;
      * squeeze blank lines and strip surrounding whitespace.

    Returns:
        The cleaned traceback text.
    """
    cleaned = traceback_string.replace('\r', '')
    cleaned = re.sub(r'^Traceback \(most recent call last\):[\n]+', '', cleaned, flags=re.MULTILINE)
    file_line_regex = re.compile(r'^ *File \"(.*?)\", line (\d+)(, in .*?)?$', re.MULTILINE)

    def replace_file_line(match):
        full_path = match.group(1)
        line_num = match.group(2)
        function_part = match.group(3) or ''
        last_sep_index = max(full_path.rfind('/'), full_path.rfind('\\'))
        filename = full_path[last_sep_index + 1:] if last_sep_index != -1 else full_path
        # BUG FIX: the computed filename was previously discarded and the
        # replacement hard-coded "(unknown)"; keep the filename so distinct
        # files remain distinguishable after cleaning.
        return f' File "{filename}", line {line_num}{function_part}'

    cleaned = file_line_regex.sub(replace_file_line, cleaned)
    cleaned = re.sub(r'<.* at 0x[0-9a-fA-F]+>', '<...>', cleaned)
    cleaned = re.sub(r'\n\s*\n+', '\n', cleaned).strip()
    return cleaned
# 4. Load the initial data set and build the FAISS index.
INITIAL_SIZE = 500
initial_error_db = [generate_new_error_entry(i) for i in range(INITIAL_SIZE)]
cleaned_initial_errors = [clean_traceback(entry["error"]) for entry in initial_error_db]
# PERF: encode all texts in a single batched model.encode() call instead of
# one call per text — batching is dramatically faster for transformer models.
initial_embeddings = np.asarray(model.encode(cleaned_initial_errors))
index = faiss.IndexFlatL2(EMBEDDING_DIM)
index.add(initial_embeddings)
print(f"初始加载 {INITIAL_SIZE} 条错误信息到向量数据库。")
# 5. Simulate real-time inserts and periodic similarity queries.
NUM_UPDATES = 100
QUERY_INTERVAL = 10
BATCH_SIZE = 5
print("\n--- 模拟实时更新和查询 ---")
# Keep every cleaned error ever indexed so that any FAISS position can be
# mapped back to its source text when picking a random query below.
all_cleaned_errors = list(cleaned_initial_errors)
for i in range(NUM_UPDATES):
    # Generate a fresh batch of simulated error entries.
    new_entries = [generate_new_error_entry(INITIAL_SIZE + i * BATCH_SIZE + j) for j in range(BATCH_SIZE)]
    cleaned_new_errors = [clean_traceback(entry["error"]) for entry in new_entries]
    # PERF: one batched model.encode() call per batch instead of per text.
    new_embeddings = np.asarray(model.encode(cleaned_new_errors))
    # Add to the FAISS index and time only the insertion itself.
    start_add_time = time.time()
    index.add(new_embeddings)
    add_time = time.time() - start_add_time
    all_cleaned_errors.extend(cleaned_new_errors)
    print(f"添加了 {len(new_entries)} 条新错误信息,耗时: {add_time:.6f} 秒,当前数据库大小: {index.ntotal}")
    # Periodically run a nearest-neighbour query.
    if (i + 1) % QUERY_INTERVAL == 0:
        # Pick any error already in the index (initial or newly added).
        # BUG FIX: the original tried to recover post-initial errors from the
        # last batch only, via fragile slice arithmetic that produced a *list*
        # (often empty) instead of a string; index the accumulated text list.
        random_index = random.randint(0, index.ntotal - 1)
        query_error = all_cleaned_errors[random_index]
        query_vec = model.encode(query_error)
        query_vec = np.array([query_vec])
        start_search_time = time.time()
        D, I = index.search(query_vec, k=1)
        search_time = time.time() - start_search_time
        print(f" 查询耗时 (数据库大小: {index.ntotal}): {search_time:.6f}")
print("\n--- 实时更新和查询测试完成 ---")