13
13
IMAGE ,
14
14
IMAGE_INPUTS ,
15
15
IMAGE_OUTPUTS ,
16
+ KNOWN_INPUTS ,
16
17
ffmpeg_convert ,
17
18
normalize_payload ,
18
19
parse_accept ,
@@ -88,6 +89,15 @@ def already_left(request: Request) -> bool:
88
89
async def pipeline_route (request : Request ) -> Response :
89
90
start = time .time ()
90
91
92
+ task = os .environ ["TASK" ]
93
+
94
+ # Shortcut: quickly check the task is in enum: no need to go any further otherwise, as we know for sure that
95
+ # normalize_payload will fail below: this avoids us to wait for the pipeline to be loaded to return
96
+ if task not in KNOWN_INPUTS :
97
+ msg = f"The task `{ task } ` is not recognized by api-inference-community"
98
+ logger .error (msg )
99
+ return JSONResponse ({"error" : msg }, status_code = 500 )
100
+
91
101
if os .getenv ("DISCARD_LEFT" , "0" ).lower () in [
92
102
"1" ,
93
103
"true" ,
@@ -97,16 +107,30 @@ async def pipeline_route(request: Request) -> Response:
97
107
return Response (status_code = 204 )
98
108
99
109
payload = await request .body ()
100
- task = os . environ [ "TASK" ]
110
+
101
111
if os .getenv ("DEBUG" , "0" ) in {"1" , "true" }:
102
112
pipe = request .app .get_pipeline ()
113
+
103
114
try :
104
115
pipe = request .app .get_pipeline ()
105
116
try :
106
117
sampling_rate = pipe .sampling_rate
107
118
except Exception :
108
119
sampling_rate = None
120
+ if task in AUDIO_INPUTS :
121
+ msg = f"Sampling rate is expected for model for audio task { task } "
122
+ logger .error (msg )
123
+ return JSONResponse ({"error" : msg }, status_code = 500 )
124
+ except Exception as e :
125
+ return JSONResponse ({"error" : str (e )}, status_code = 500 )
126
+
127
+ try :
109
128
inputs , params = normalize_payload (payload , task , sampling_rate = sampling_rate )
129
+ except EnvironmentError as e :
130
+ # Since we catch the environment edge cases earlier above, this should not happen here anymore
131
+ # harmless to keep it, just in case
132
+ logger .error ("Error while parsing input %s" , e )
133
+ return JSONResponse ({"error" : str (e )}, status_code = 500 )
110
134
except ValidationError as e :
111
135
errors = []
112
136
for error in e .errors ():
@@ -120,7 +144,9 @@ async def pipeline_route(request: Request) -> Response:
120
144
)
121
145
return JSONResponse ({"error" : errors }, status_code = 400 )
122
146
except Exception as e :
123
- return JSONResponse ({"error" : str (e )}, status_code = 500 )
147
+ # We assume the payload is bad -> 400
148
+ logger .warning ("Error while parsing input %s" , e )
149
+ return JSONResponse ({"error" : str (e )}, status_code = 400 )
124
150
125
151
accept = request .headers .get ("accept" , "" )
126
152
lora_adapter = request .headers .get ("lora" )
0 commit comments