Skip to content

Commit 980d422

Browse files
committed
Fix status code: split pipeline load from input parsing
Pipeline loading errors -> 500; input parsing errors -> 400. Signed-off-by: Raphael Glon <[email protected]>
1 parent 45907cd commit 980d422

File tree

2 files changed

+29
-2
lines changed

2 files changed

+29
-2
lines changed

api_inference_community/routes.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
IMAGE,
1414
IMAGE_INPUTS,
1515
IMAGE_OUTPUTS,
16+
KNOWN_INPUTS,
1617
ffmpeg_convert,
1718
normalize_payload,
1819
parse_accept,
@@ -88,6 +89,15 @@ def already_left(request: Request) -> bool:
8889
async def pipeline_route(request: Request) -> Response:
8990
start = time.time()
9091

92+
task = os.environ["TASK"]
93+
94+
# Shortcut: quickly check the task is in enum: no need to go any further otherwise, as we know for sure that
95+
# normalize_payload will fail below: this saves us from waiting for the pipeline to load before returning
96+
if task not in KNOWN_INPUTS:
97+
msg = f"The task `{task}` is not recognized by api-inference-community"
98+
logger.error(msg)
99+
return JSONResponse({"error": msg}, status_code=500)
100+
91101
if os.getenv("DISCARD_LEFT", "0").lower() in [
92102
"1",
93103
"true",
@@ -97,16 +107,30 @@ async def pipeline_route(request: Request) -> Response:
97107
return Response(status_code=204)
98108

99109
payload = await request.body()
100-
task = os.environ["TASK"]
110+
101111
if os.getenv("DEBUG", "0") in {"1", "true"}:
102112
pipe = request.app.get_pipeline()
113+
103114
try:
104115
pipe = request.app.get_pipeline()
105116
try:
106117
sampling_rate = pipe.sampling_rate
107118
except Exception:
108119
sampling_rate = None
120+
if task in AUDIO_INPUTS:
121+
msg = f"The task `{task}` is not recognized by api-inference-community"
122+
logger.error(msg)
123+
return JSONResponse({"error": msg}, status_code=500)
124+
except Exception as e:
125+
return JSONResponse({"error": str(e)}, status_code=500)
126+
127+
try:
109128
inputs, params = normalize_payload(payload, task, sampling_rate=sampling_rate)
129+
except EnvironmentError as e:
130+
# Since the environment edge cases are already caught above, this should no longer happen here;
131+
# harmless to keep it, just in case
132+
logger.error("Error while parsing input %s", e)
133+
return JSONResponse({"error": str(e)}, status_code=500)
110134
except ValidationError as e:
111135
errors = []
112136
for error in e.errors():
@@ -120,7 +144,9 @@ async def pipeline_route(request: Request) -> Response:
120144
)
121145
return JSONResponse({"error": errors}, status_code=400)
122146
except Exception as e:
123-
return JSONResponse({"error": str(e)}, status_code=500)
147+
# We assume the payload is bad -> 400
148+
logger.warning("Error while parsing input %s", e)
149+
return JSONResponse({"error": str(e)}, status_code=400)
124150

125151
accept = request.headers.get("accept", "")
126152
lora_adapter = request.headers.get("lora")

api_inference_community/validation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,7 @@ def check_inputs(inputs, tag):
218218
"zero-shot-classification",
219219
}
220220

221+
KNOWN_INPUTS = AUDIO_INPUTS.union(IMAGE_INPUTS).union(TEXT_INPUTS)
221222

222223
AUDIO = [
223224
"flac",

0 commit comments

Comments
 (0)