统一返回格式
This commit is contained in:
parent
073781f809
commit
1bb97f78fe
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
|
@ -0,0 +1,39 @@
|
|||
# app/exception_handlers.py
|
||||
from fastapi import Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
from app.schemas.response import ResponseModel
|
||||
|
||||
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
||||
return JSONResponse(
|
||||
status_code=422,
|
||||
content=ResponseModel(
|
||||
msg="Validation Failed",
|
||||
success=False,
|
||||
data=exc.errors(), # 可视情况换成 None 或 str(exc)
|
||||
code=422
|
||||
).model_dump()
|
||||
)
|
||||
|
||||
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content=ResponseModel(
|
||||
msg=exc.detail,
|
||||
success=False,
|
||||
data=None,
|
||||
code=exc.status_code
|
||||
).model_dump()
|
||||
)
|
||||
|
||||
async def general_exception_handler(request: Request, exc: Exception):
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=ResponseModel(
|
||||
msg="Internal Server Error",
|
||||
success=False,
|
||||
data=str(exc),
|
||||
code=500
|
||||
).model_dump()
|
||||
)
|
||||
13
app/main.py
13
app/main.py
|
|
@ -1,5 +1,13 @@
|
|||
from fastapi import FastAPI
|
||||
from app.router import router
|
||||
from app.utils.response_wrapper import standard_response
|
||||
from app.exception_handlers import (
|
||||
validation_exception_handler,
|
||||
http_exception_handler,
|
||||
general_exception_handler
|
||||
)
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
|
||||
app = FastAPI(
|
||||
title="Vector Search API",
|
||||
|
|
@ -9,8 +17,13 @@ app = FastAPI(
|
|||
openapi_url="/openapi.json"
|
||||
)
|
||||
|
||||
# 注册自定义异常处理器
|
||||
app.add_exception_handler(RequestValidationError, validation_exception_handler)
|
||||
app.add_exception_handler(StarletteHTTPException, http_exception_handler)
|
||||
app.add_exception_handler(Exception, general_exception_handler)
|
||||
app.include_router(router)
|
||||
|
||||
@app.get("/")
|
||||
@standard_response
|
||||
def root():
|
||||
return {"message": "Vector Search API is running"}
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from app.dependencies import (
|
|||
id_to_index, index_to_id
|
||||
)
|
||||
import faiss
|
||||
from app.utils.response_wrapper import standard_response
|
||||
|
||||
router = APIRouter()
|
||||
aggregate_index = None
|
||||
|
|
@ -21,6 +22,7 @@ class ErrorQuery(BaseModel):
|
|||
top_n: int
|
||||
|
||||
@router.post("/insert", summary="插入数据并重新建立索引")
|
||||
@standard_response
|
||||
def insert_errors(errors: List[ErrorInsert]):
|
||||
global aggregate_index, id_to_index, index_to_id
|
||||
|
||||
|
|
@ -40,10 +42,12 @@ def insert_errors(errors: List[ErrorInsert]):
|
|||
return {"inserted": [e.db_id for e in errors]}
|
||||
|
||||
@router.get("/list", summary="获取已建立索引的所有 db_id")
|
||||
@standard_response
|
||||
def list_db_ids():
|
||||
return list(error_memory.keys())
|
||||
|
||||
@router.delete("/delete", summary="删除指定 db_id 的数据")
|
||||
@standard_response
|
||||
def delete_errors(ids: List[str] = Query(...)):
|
||||
global aggregate_index, id_to_index, index_to_id
|
||||
deleted = []
|
||||
|
|
@ -57,6 +61,7 @@ def delete_errors(ids: List[str] = Query(...)):
|
|||
return {"deleted": deleted}
|
||||
|
||||
@router.put("/update", summary="更新指定 db_id 的数据")
|
||||
@standard_response
|
||||
def update_error(error: ErrorInsert):
|
||||
if error.db_id not in error_memory:
|
||||
raise HTTPException(status_code=404, detail="db_id not found")
|
||||
|
|
@ -77,7 +82,29 @@ def update_error(error: ErrorInsert):
|
|||
aggregate_index, id_to_index, index_to_id = rebuild_index(error_memory)
|
||||
return {"updated": error.db_id}
|
||||
|
||||
@router.get("/page", summary="分页获取错误信息")
|
||||
@standard_response
|
||||
def get_error_page(page: int = Query(1, ge=1), page_size: int = Query(10, gt=0)):
|
||||
if not error_memory:
|
||||
raise HTTPException(status_code=404, detail="数据库为空")
|
||||
|
||||
total = len(error_memory)
|
||||
start = (page - 1) * page_size
|
||||
end = start + page_size
|
||||
|
||||
if start >= total:
|
||||
raise HTTPException(status_code=404, detail="页码超出范围")
|
||||
|
||||
paginated_errors = list(error_memory.items())[start:end]
|
||||
return {
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"data": [{"db_id": db_id, "error": entry["error"]} for db_id, entry in paginated_errors]
|
||||
}
|
||||
|
||||
@router.post("/search", summary="搜索 top_n 个最相似的错误")
|
||||
@standard_response
|
||||
def search_error(query: ErrorQuery):
|
||||
if not error_memory:
|
||||
raise HTTPException(status_code=404, detail="数据库为空")
|
||||
|
|
|
|||
Binary file not shown.
|
|
@ -0,0 +1,9 @@
|
|||
# app/schemas/response.py
|
||||
from pydantic import BaseModel
|
||||
from typing import Any, Optional
|
||||
|
||||
class ResponseModel(BaseModel):
|
||||
msg: str
|
||||
success: bool
|
||||
data: Optional[Any] = None
|
||||
code: Optional[int] = None
|
||||
Binary file not shown.
|
|
@ -0,0 +1,16 @@
|
|||
# app/utils/response_wrapper.py
|
||||
from functools import wraps
|
||||
from fastapi.responses import JSONResponse
|
||||
from app.schemas.response import ResponseModel
|
||||
import inspect
|
||||
|
||||
def standard_response(func):
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
if inspect.iscoroutinefunction(func):
|
||||
result = await func(*args, **kwargs)
|
||||
else:
|
||||
result = func(*args, **kwargs)
|
||||
response = ResponseModel(msg="success", success=True, data=result, code=200)
|
||||
return JSONResponse(content=response.model_dump())
|
||||
return wrapper
|
||||
Loading…
Reference in New Issue