@@ -5,6 +5,7 @@
 from typing import List, Union
 from urllib.parse import urlparse
 
+from tornado.iostream import StreamClosedError
 from tornado.web import HTTPError
 
 from ads.aqua.common.decorator import handle_exceptions
@@ -175,21 +176,9 @@ def list_shapes(self):
         )
 
 
-class AquaDeploymentInferenceHandler(AquaAPIhandler):
-    @staticmethod
-    def validate_predict_url(endpoint):
-        try:
-            url = urlparse(endpoint)
-            if url.scheme != "https":
-                return False
-            if not url.netloc:
-                return False
-            return url.path.endswith("/predict")
-        except Exception:
-            return False
-
+class AquaDeploymentStreamingInferenceHandler(AquaAPIhandler):
     @handle_exceptions
-    def post(self, *args, **kwargs):  # noqa: ARG002
+    async def post(self, *args, **kwargs):  # noqa: ARG002
         """
         Handles inference request for the Active Model Deployments
         Raises
@@ -205,12 +194,7 @@ def post(self, *args, **kwargs): # noqa: ARG002
         if not input_data:
             raise HTTPError(400, Errors.NO_INPUT_DATA)
 
-        endpoint = input_data.get("endpoint")
-        if not endpoint:
-            raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("endpoint"))
-
-        if not self.validate_predict_url(endpoint):
-            raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT.format("endpoint"))
+        model_deployment_id = input_data.get("id")
 
         prompt = input_data.get("prompt")
         if not prompt:
@@ -226,11 +210,24 @@ def post(self, *args, **kwargs): # noqa: ARG002
                 400, Errors.INVALID_INPUT_DATA_FORMAT.format("model_params")
             ) from ex
 
-        return self.finish(
-            MDInferenceResponse(prompt, model_params_obj).get_model_deployment_response(
-                endpoint
-            )
-        )
+        self.set_header("Content-Type", "text/event-stream")
+        self.set_header("Cache-Control", "no-cache")
+        self.set_header("Transfer-Encoding", "chunked")
+        await self.flush()
+
+        try:
+            response_gen = MDInferenceResponse(
+                prompt, model_params_obj
+            ).get_model_deployment_response(model_deployment_id)
+            for chunk in response_gen:
+                if not chunk:
+                    continue
+                self.write(f"data: {chunk}\n\n")
+                await self.flush()
+        except StreamClosedError:
+            self.log.warning("Client disconnected.")
+        finally:
+            self.finish()
 
 
 class AquaDeploymentParamsHandler(AquaAPIhandler):
@@ -294,5 +291,5 @@ def post(self, *args, **kwargs): # noqa: ARG002
     ("deployments/?([^/]*)", AquaDeploymentHandler),
     ("deployments/?([^/]*)/activate", AquaDeploymentHandler),
     ("deployments/?([^/]*)/deactivate", AquaDeploymentHandler),
-    ("inference", AquaDeploymentInferenceHandler),
+    ("inference", AquaDeploymentStreamingInferenceHandler),
 ]
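
For context, the new handler streams Server-Sent-Events-style lines ("data: <chunk>\n\n") instead of returning a single JSON response, so callers must read the response incrementally. Below is a minimal client sketch, not part of this change: the base URL, port, and deployment OCID are hypothetical placeholders, and only the "id", "prompt", and "model_params" fields are taken from the diff above.

import requests  # any HTTP client with streaming support would work

# Hypothetical URL; the actual prefix depends on how the Aqua handlers are mounted.
url = "http://localhost:8888/aqua/inference"

payload = {
    "id": "ocid1.datasciencemodeldeployment.oc1..<unique_id>",  # model deployment OCID (placeholder)
    "prompt": "What is the capital of France?",
    "model_params": {"max_tokens": 100},
}

with requests.post(url, json=payload, stream=True) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines(decode_unicode=True):
        # The handler writes each event as "data: <chunk>" followed by a blank line.
        if line and line.startswith("data: "):
            print(line[len("data: "):])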