diff --git a/app/__pycache__/exception_handlers.cpython-310.pyc b/app/__pycache__/exception_handlers.cpython-310.pyc new file mode 100644 index 0000000..62efad7 Binary files /dev/null and b/app/__pycache__/exception_handlers.cpython-310.pyc differ diff --git a/app/__pycache__/main.cpython-310.pyc b/app/__pycache__/main.cpython-310.pyc index 423f7fd..18c4917 100644 Binary files a/app/__pycache__/main.cpython-310.pyc and b/app/__pycache__/main.cpython-310.pyc differ diff --git a/app/__pycache__/router.cpython-310.pyc b/app/__pycache__/router.cpython-310.pyc index db4f25a..1ace912 100644 Binary files a/app/__pycache__/router.cpython-310.pyc and b/app/__pycache__/router.cpython-310.pyc differ diff --git a/app/exception_handlers.py b/app/exception_handlers.py new file mode 100644 index 0000000..af60670 --- /dev/null +++ b/app/exception_handlers.py @@ -0,0 +1,39 @@ +# app/exception_handlers.py +from fastapi import Request +from fastapi.responses import JSONResponse +from fastapi.exceptions import RequestValidationError +from starlette.exceptions import HTTPException as StarletteHTTPException +from app.schemas.response import ResponseModel + +async def validation_exception_handler(request: Request, exc: RequestValidationError): + return JSONResponse( + status_code=422, + content=ResponseModel( + msg="Validation Failed", + success=False, + data=exc.errors(), # 可视情况换成 None 或 str(exc) + code=422 + ).model_dump() + ) + +async def http_exception_handler(request: Request, exc: StarletteHTTPException): + return JSONResponse( + status_code=exc.status_code, + content=ResponseModel( + msg=exc.detail, + success=False, + data=None, + code=exc.status_code + ).model_dump() + ) + +async def general_exception_handler(request: Request, exc: Exception): + return JSONResponse( + status_code=500, + content=ResponseModel( + msg="Internal Server Error", + success=False, + data=str(exc), + code=500 + ).model_dump() + ) diff --git a/app/main.py b/app/main.py index 521e565..39810ea 100644 --- a/app/main.py +++ b/app/main.py @@ -1,5 +1,13 @@ from fastapi import FastAPI from app.router import router +from app.utils.response_wrapper import standard_response +from app.exception_handlers import ( + validation_exception_handler, + http_exception_handler, + general_exception_handler +) +from fastapi.exceptions import RequestValidationError +from starlette.exceptions import HTTPException as StarletteHTTPException app = FastAPI( title="Vector Search API", @@ -9,8 +17,13 @@ app = FastAPI( openapi_url="/openapi.json" ) +# 注册自定义异常处理器 +app.add_exception_handler(RequestValidationError, validation_exception_handler) +app.add_exception_handler(StarletteHTTPException, http_exception_handler) +app.add_exception_handler(Exception, general_exception_handler) app.include_router(router) @app.get("/") +@standard_response def root(): return {"message": "Vector Search API is running"} diff --git a/app/router.py b/app/router.py index 47c3367..e2c01b7 100644 --- a/app/router.py +++ b/app/router.py @@ -8,6 +8,7 @@ from app.dependencies import ( id_to_index, index_to_id ) import faiss +from app.utils.response_wrapper import standard_response router = APIRouter() aggregate_index = None @@ -21,6 +22,7 @@ class ErrorQuery(BaseModel): top_n: int @router.post("/insert", summary="插入数据并重新建立索引") +@standard_response def insert_errors(errors: List[ErrorInsert]): global aggregate_index, id_to_index, index_to_id @@ -40,10 +42,12 @@ def insert_errors(errors: List[ErrorInsert]): return {"inserted": [e.db_id for e in errors]} @router.get("/list", summary="获取已建立索引的所有 db_id") +@standard_response def list_db_ids(): return list(error_memory.keys()) @router.delete("/delete", summary="删除指定 db_id 的数据") +@standard_response def delete_errors(ids: List[str] = Query(...)): global aggregate_index, id_to_index, index_to_id deleted = [] @@ -57,6 +61,7 @@ def delete_errors(ids: List[str] = Query(...)): return {"deleted": deleted} @router.put("/update", summary="更新指定 db_id 的数据") +@standard_response def update_error(error: ErrorInsert): if error.db_id not in error_memory: raise HTTPException(status_code=404, detail="db_id not found") @@ -77,7 +82,29 @@ def update_error(error: ErrorInsert): aggregate_index, id_to_index, index_to_id = rebuild_index(error_memory) return {"updated": error.db_id} +@router.get("/page", summary="分页获取错误信息") +@standard_response +def get_error_page(page: int = Query(1, ge=1), page_size: int = Query(10, gt=0)): + if not error_memory: + raise HTTPException(status_code=404, detail="数据库为空") + + total = len(error_memory) + start = (page - 1) * page_size + end = start + page_size + + if start >= total: + raise HTTPException(status_code=404, detail="页码超出范围") + + paginated_errors = list(error_memory.items())[start:end] + return { + "total": total, + "page": page, + "page_size": page_size, + "data": [{"db_id": db_id, "error": entry["error"]} for db_id, entry in paginated_errors] + } + @router.post("/search", summary="搜索 top_n 个最相似的错误") +@standard_response def search_error(query: ErrorQuery): if not error_memory: raise HTTPException(status_code=404, detail="数据库为空") diff --git a/app/schemas/__pycache__/response.cpython-310.pyc b/app/schemas/__pycache__/response.cpython-310.pyc new file mode 100644 index 0000000..30adef2 Binary files /dev/null and b/app/schemas/__pycache__/response.cpython-310.pyc differ diff --git a/app/schemas/response.py b/app/schemas/response.py new file mode 100644 index 0000000..be10411 --- /dev/null +++ b/app/schemas/response.py @@ -0,0 +1,9 @@ +# app/schemas/response.py +from pydantic import BaseModel +from typing import Any, Optional + +class ResponseModel(BaseModel): + msg: str + success: bool + data: Optional[Any] = None + code: Optional[int] = None diff --git a/app/utils/__pycache__/response_wrapper.cpython-310.pyc b/app/utils/__pycache__/response_wrapper.cpython-310.pyc new file mode 100644 index 0000000..952d2e0 Binary files /dev/null and b/app/utils/__pycache__/response_wrapper.cpython-310.pyc differ diff --git a/app/utils/response_wrapper.py b/app/utils/response_wrapper.py new file mode 100644 index 0000000..4d28851 --- /dev/null +++ b/app/utils/response_wrapper.py @@ -0,0 +1,16 @@ +# app/utils/response_wrapper.py +from functools import wraps +from fastapi.responses import JSONResponse +from app.schemas.response import ResponseModel +import inspect + +def standard_response(func): + @wraps(func) + async def wrapper(*args, **kwargs): + if inspect.iscoroutinefunction(func): + result = await func(*args, **kwargs) + else: + result = func(*args, **kwargs) + response = ResponseModel(msg="success", success=True, data=result, code=200) + return JSONResponse(content=response.model_dump()) + return wrapper