from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel
from typing import List
import numpy as np
from app.dependencies import (
    model, error_memory, clean_traceback, split_traceback_layers,
    compute_layered_similarity_sco, rebuild_index,
    id_to_index, index_to_id
)
import faiss
from app.utils.response_wrapper import standard_response

router = APIRouter()

aggregate_index = None
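# NOTE: `aggregate_index` is a module-level cache of the FAISS index built over
# the mean-pooled traceback vectors; it is rebuilt from `error_memory` via
# `rebuild_index` after every insert/update/delete. `id_to_index` / `index_to_id`
# map db_ids to FAISS row positions and back; rebinding them with `global` below
# only rebinds this module's copies of the names imported from `app.dependencies`.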


class ErrorInsert(BaseModel):
    error: str
    db_id: str


class ErrorQuery(BaseModel):
    error: str
    top_n: int
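
# Illustrative request bodies (values are made up, not from the source):
#   ErrorInsert -> {"error": "Traceback (most recent call last): ...", "db_id": "err-001"}
#   ErrorQuery  -> {"error": "Traceback (most recent call last): ...", "top_n": 5}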


@router.post("/insert", summary="Insert data and rebuild the index")
@standard_response
def insert_errors(errors: List[ErrorInsert]):
    global aggregate_index, id_to_index, index_to_id

    for entry in errors:
        cleaned = clean_traceback(entry.error)
        layers = split_traceback_layers(cleaned)
        layer_vectors = model.encode(layers)
        agg_vector = np.mean(layer_vectors, axis=0)

        error_memory[entry.db_id] = {
            "error": entry.error,
            "vector": agg_vector,
            "layers": layers,
            "layer_vectors": layer_vectors
        }
    aggregate_index, id_to_index, index_to_id = rebuild_index(error_memory)
    return {"inserted": [e.db_id for e in errors]}
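
# Example call (assuming the app mounts this router at the root and serves on
# localhost:8000; adjust host/prefix to your deployment):
#   curl -X POST http://localhost:8000/insert \
#        -H "Content-Type: application/json" \
#        -d '[{"error": "Traceback (most recent call last): ...", "db_id": "err-001"}]'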


@router.get("/list", summary="Get all indexed db_ids")
@standard_response
def list_db_ids():
    return list(error_memory.keys())
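
# Returns only the db_ids currently held in `error_memory`; the FAISS index
# itself is not consulted here.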


@router.delete("/delete", summary="Delete the data for the specified db_ids")
@standard_response
def delete_errors(ids: List[str] = Query(...)):
    global aggregate_index, id_to_index, index_to_id
    deleted = []

    for db_id in ids:
        if db_id in error_memory:
            del error_memory[db_id]
            deleted.append(db_id)

    aggregate_index, id_to_index, index_to_id = rebuild_index(error_memory)
    return {"deleted": deleted}
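
# Note: the index is rebuilt even when none of the requested db_ids existed;
# unknown ids are silently skipped rather than raising a 404.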


@router.put("/update", summary="Update the data for the specified db_id")
@standard_response
def update_error(error: ErrorInsert):
    global aggregate_index, id_to_index, index_to_id

    if error.db_id not in error_memory:
        raise HTTPException(status_code=404, detail="db_id not found")

    cleaned = clean_traceback(error.error)
    layers = split_traceback_layers(cleaned)
    layer_vectors = model.encode(layers)
    agg_vector = np.mean(layer_vectors, axis=0)

    error_memory[error.db_id] = {
        "error": error.error,
        "vector": agg_vector,
        "layers": layers,
        "layer_vectors": layer_vectors
    }

    aggregate_index, id_to_index, index_to_id = rebuild_index(error_memory)
    return {"updated": error.db_id}
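
# Example call (same host/prefix assumptions as above):
#   curl -X PUT http://localhost:8000/update \
#        -H "Content-Type: application/json" \
#        -d '{"error": "Traceback (most recent call last): ...", "db_id": "err-001"}'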


@router.get("/page", summary="Get error entries with pagination")
@standard_response
def get_error_page(page: int = Query(1, ge=1), page_size: int = Query(10, gt=0)):
    if not error_memory:
        raise HTTPException(status_code=404, detail="Database is empty")

    total = len(error_memory)
    start = (page - 1) * page_size
    end = start + page_size

    if start >= total:
        raise HTTPException(status_code=404, detail="Page number out of range")

    paginated_errors = list(error_memory.items())[start:end]
    return {
        "total": total,
        "page": page,
        "page_size": page_size,
        "data": [{"db_id": db_id, "error": entry["error"]} for db_id, entry in paginated_errors]
    }
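
# Illustrative response body (before `standard_response` presumably wraps it in
# the project's common envelope):
#   {"total": 42, "page": 1, "page_size": 10,
#    "data": [{"db_id": "err-001", "error": "Traceback ..."}, ...]}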


@router.post("/search", summary="Search for the top_n most similar errors")
@standard_response
def search_error(query: ErrorQuery):
    if not error_memory:
        raise HTTPException(status_code=404, detail="Database is empty")

    cleaned = clean_traceback(query.error)
    user_layers = split_traceback_layers(cleaned)
    user_vectors = model.encode(user_layers)

    user_agg_vector = np.mean(user_vectors, axis=0).astype('float32').reshape(1, -1)
    k = min(query.top_n, len(error_memory))
    sim, indices = aggregate_index.search(user_agg_vector, k)

    results = []
    for idx, score in zip(indices[0], sim[0]):
        db_id = index_to_id[idx]
        db_entry = error_memory[db_id]

        # Layered matching score
        layer_score = compute_layered_similarity_sco(user_vectors, db_entry["layer_vectors"])

        # Match layer by layer and extract keywords from each layer
        matched_keywords = []
        for user_layer, db_layer in zip(user_layers, db_entry["layers"]):
            # Simply take the words common to both texts as keywords
            # (this could be swapped for a more sophisticated matching method)
            user_tokens = set(user_layer.split())
            db_tokens = set(db_layer.split())
            common_tokens = user_tokens & db_tokens
            matched_keywords.append(list(common_tokens))

        results.append({
            "db_id": db_id,
            "aggregate_similarity": round(float(score), 4),    # score from the aggregate index
            "layer_similarity": round(float(layer_score), 4),  # layered matching score
            "matched_layers": db_entry["layers"],              # each matched traceback layer
            "matched_keywords": matched_keywords               # keywords matched per layer
        })

    # Sort by layered score first, then by aggregate score
    results.sort(key=lambda x: (x["layer_similarity"], x["aggregate_similarity"]), reverse=True)

    return results
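
# Example call (same host/prefix assumptions as above):
#   curl -X POST http://localhost:8000/search \
#        -H "Content-Type: application/json" \
#        -d '{"error": "Traceback (most recent call last): ...", "top_n": 3}'
#
# Results come back ordered by the layered score first and the aggregate FAISS
# score second, both descending.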