From af5d8a528ddfad62064a8dbbafb0d700fe6ab9c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9mence=20Lesn=C3=A9?= Date: Wed, 9 Oct 2024 18:03:41 +0200 Subject: [PATCH] fix: API error management --- app/main.py | 60 +++++++++++++++++++++++++---------------------------- 1 file changed, 28 insertions(+), 32 deletions(-) diff --git a/app/main.py b/app/main.py index 569af698..98f487dc 100644 --- a/app/main.py +++ b/app/main.py @@ -314,13 +314,13 @@ async def call_list_get( try: phone_number = PhoneNumber(phone_number) if phone_number else None except ValueError as e: - raise _validation_error(e) + raise RequestValidationError([f"Invalid phone number: {e}"]) from e count = 100 calls, _ = await _db.call_asearch_all(phone_number=phone_number, count=count) if not calls: - raise _standard_error( - message=f"Call {phone_number} not found", + raise HTTPException( + detail=f"Call {phone_number} not found", status_code=HTTPStatus.NOT_FOUND, ) @@ -352,11 +352,11 @@ async def call_get(call_id_or_phone_number: str) -> CallGetModel: try: phone_number = PhoneNumber(call_id_or_phone_number) except ValueError as e: - raise _validation_error(e) + raise RequestValidationError([str(e)]) from e call = await _db.call_asearch_one(phone_number=phone_number) if not call: - raise _standard_error( - message=f"Call {phone_number} not found", + raise HTTPException( + detail=f"Call {call_id_or_phone_number} not found", status_code=HTTPStatus.NOT_FOUND, ) return TypeAdapter(CallGetModel).dump_python(call) @@ -379,7 +379,7 @@ async def call_post(request: Request) -> CallGetModel: body = await request.json() initiate = CallInitiateModel.model_validate(body) except ValidationError as e: - raise _validation_error(e) + raise RequestValidationError([str(e)]) from e url, call = await _communicationservices_event_url(initiate.phone_number, initiate) span_attribute(CallAttributes.CALL_ID, str(call.call_id)) @@ -490,8 +490,8 @@ async def communicationservices_event_post( # Validate JWT token service_jwt: str | None = request.headers.get("Authorization") if not service_jwt: - raise _standard_error( - message="Authorization header missing", + raise HTTPException( + detail="Authorization header missing", status_code=HTTPStatus.UNAUTHORIZED, ) @@ -511,24 +511,18 @@ async def communicationservices_event_post( ) except jwt.PyJWTError: logger.warning("Invalid JWT token", exc_info=True) - raise _standard_error( - message="Invalid JWT token", + raise HTTPException( + detail="Invalid JWT token", status_code=HTTPStatus.UNAUTHORIZED, ) # Validate request try: events = await request.json() - except ValueError: - raise _standard_error( - message="Invalid JSON format", - status_code=HTTPStatus.BAD_REQUEST, - ) + except ValueError as e: + raise RequestValidationError([f"Invalid JSON format: {e}"]) from e if not events or not isinstance(events, list): - raise _standard_error( - message="Events must be a list", - status_code=HTTPStatus.BAD_REQUEST, - ) + raise RequestValidationError(["Events must be a list"]) # Process events in parallel await asyncio.gather( @@ -820,11 +814,9 @@ async def _training_callback(_call: CallStateModel) -> None: training_callback=_training_callback, ) if not event_status: - return JSONResponse( - content=_standard_error( - message="SMS event failed", - status_code=HTTPStatus.INTERNAL_SERVER_ERROR, - ) + raise HTTPException( + detail="SMS event failed", + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, ) return Response( @@ -835,7 +827,9 @@ async def _training_callback(_call: CallStateModel) -> None: @api.exception_handler(StarletteHTTPException) -async def http_exception_handler(request: Request, exc: StarletteHTTPException): # noqa: ARG001 +async def http_exception_handler( + request: Request, exc: StarletteHTTPException +) -> JSONResponse: # noqa: ARG001 """ Handle HTTP exceptions and return the error in a standard format. """ @@ -846,7 +840,9 @@ async def http_exception_handler(request: Request, exc: StarletteHTTPException): @api.exception_handler(RequestValidationError) -async def validation_exception_handler(request: Request, exc: RequestValidationError): # noqa: ARG001 +async def validation_exception_handler( + request: Request, exc: RequestValidationError +) -> JSONResponse: # noqa: ARG001 """ Handle validation exceptions and return the error in a standard format. """ @@ -876,7 +872,7 @@ def _str_to_contexts(value: str | None) -> set[CallContextEnum] | None: return res or None -def _validation_error(e: ValidationError | Exception) -> HTTPException: +def _validation_error(e: ValidationError | Exception) -> JSONResponse: """ Generate a standard validation error response. """ @@ -896,9 +892,9 @@ def _validation_error(e: ValidationError | Exception) -> HTTPException: def _standard_error( message: str, + status_code, details: list[str] | None = None, - status_code: HTTPStatus = HTTPStatus.BAD_REQUEST, -) -> HTTPException: +) -> JSONResponse: """ Generate a standard error response. """ @@ -908,8 +904,8 @@ def _standard_error( message=message, ) ) - return HTTPException( - detail=model.model_dump_json(), + return JSONResponse( + content=model.model_dump(mode="json"), status_code=status_code, )