bash_problem_search/app/router.py

from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel
from typing import List
import numpy as np
from app.dependencies import (
    model, error_memory, clean_traceback, split_traceback_layers,
    compute_layered_similarity_sco, rebuild_index,
    id_to_index, index_to_id,
)

router = APIRouter()

# FAISS index over the aggregated error vectors; rebuilt whenever error_memory changes.
aggregate_index = None


class ErrorInsert(BaseModel):
    error: str
    db_id: str


class ErrorQuery(BaseModel):
    error: str
    top_n: int
@router.post("/insert", summary="插入数据并重新建立索引")
def insert_errors(errors: List[ErrorInsert]):
global aggregate_index, id_to_index, index_to_id
for entry in errors:
cleaned = clean_traceback(entry.error)
layers = split_traceback_layers(cleaned)
layer_vectors = model.encode(layers)
agg_vector = np.mean(layer_vectors, axis=0)
error_memory[entry.db_id] = {
"error": entry.error,
"vector": agg_vector,
"layers": layers,
"layer_vectors": layer_vectors
}
aggregate_index, id_to_index, index_to_id = rebuild_index(error_memory)
return {"inserted": [e.db_id for e in errors]}
@router.get("/list", summary="获取已建立索引的所有 db_id")
def list_db_ids():
return list(error_memory.keys())
@router.delete("/delete", summary="删除指定 db_id 的数据")
def delete_errors(ids: List[str] = Query(...)):
global aggregate_index, id_to_index, index_to_id
deleted = []
for db_id in ids:
if db_id in error_memory:
del error_memory[db_id]
deleted.append(db_id)
aggregate_index, id_to_index, index_to_id = rebuild_index(error_memory)
return {"deleted": deleted}
@router.put("/update", summary="更新指定 db_id 的数据")
def update_error(error: ErrorInsert):
if error.db_id not in error_memory:
raise HTTPException(status_code=404, detail="db_id not found")
cleaned = clean_traceback(error.error)
layers = split_traceback_layers(cleaned)
layer_vectors = model.encode(layers)
agg_vector = np.mean(layer_vectors, axis=0)
error_memory[error.db_id] = {
"error": error.error,
"vector": agg_vector,
"layers": layers,
"layer_vectors": layer_vectors
}
global aggregate_index, id_to_index, index_to_id
aggregate_index, id_to_index, index_to_id = rebuild_index(error_memory)
return {"updated": error.db_id}
@router.post("/search", summary="搜索 top_n 个最相似的错误")
def search_error(query: ErrorQuery):
if not error_memory:
raise HTTPException(status_code=404, detail="数据库为空")
cleaned = clean_traceback(query.error)
user_layers = split_traceback_layers(cleaned)
user_vectors = model.encode(user_layers)
user_agg_vector = np.mean(user_vectors, axis=0).astype('float32').reshape(1, -1)
k = min(query.top_n, len(error_memory))
sim, indices = aggregate_index.search(user_agg_vector, k)
results = []
for idx, score in zip(indices[0], sim[0]):
db_id = index_to_id[idx]
db_entry = error_memory[db_id]
layer_score = compute_layered_similarity_sco(user_vectors, db_entry["layer_vectors"])
results.append({
"db_id": db_id,
"similarity": round(layer_score, 4),
"matched_layers": db_entry["layers"]
})
return results
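

# --- Usage sketch (assumptions: this router is mounted on a FastAPI app, e.g. via
# `app.include_router(router)` in a hypothetical app/main.py, and the model and
# helpers in app.dependencies are importable). A minimal end-to-end check with
# FastAPI's TestClient, not part of the router itself:
#
#     from fastapi.testclient import TestClient
#     from app.main import app  # hypothetical application module
#
#     client = TestClient(app)
#     client.post("/insert", json=[
#         {"error": "Traceback (most recent call last): ...", "db_id": "e1"},
#     ])
#     resp = client.post("/search", json={"error": "Traceback (most recent call last): ...", "top_n": 3})
#     print(resp.json())  # e.g. [{"db_id": "e1", "similarity": 0.97, "matched_layers": [...]}]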