|
11 | 11 | from oci.data_science.models import ModelDeploymentShapeSummary
|
12 | 12 | from pydantic import ValidationError
|
13 | 13 |
|
| 14 | +from ads.aqua import Client |
14 | 15 | from ads.aqua.app import AquaApp, logger
|
15 | 16 | from ads.aqua.common.entities import (
|
16 | 17 | AquaMultiModelRef,
|
@@ -1361,7 +1362,33 @@ def get_model_deployment_response(self, model_deployment_id: str, payload: dict)
|
1361 | 1362 |
|
1362 | 1363 | """
|
1363 | 1364 |
|
1364 |
| - response = self.model_deployment_client.predict_with_response_stream( |
1365 |
| - model_deployment_id=model_deployment_id, request_body=payload |
| 1365 | + model_deployment = self.get(model_deployment_id) |
| 1366 | + endpoint = model_deployment.endpoint + "/predictWithResponseStream" |
| 1367 | + endpoint_type = model_deployment.environment_variables.get( |
| 1368 | + "MODEL_DEPLOY_PREDICT_ENDPOINT", "/v1/completions" |
1366 | 1369 | )
|
1367 |
| - yield from self._stream_sanitizer(response) |
| 1370 | + aqua_client = Client(endpoint=endpoint) |
| 1371 | + if endpoint_type == "/v1/completions": |
| 1372 | + for chunk in aqua_client.generate( |
| 1373 | + prompt=payload.pop("prompt"), |
| 1374 | + payload=payload, |
| 1375 | + stream=True, |
| 1376 | + ): |
| 1377 | + try: |
| 1378 | + yield chunk["choices"][0]["text"] |
| 1379 | + except Exception as e: |
| 1380 | + logger.debug( |
| 1381 | + f"Exception occurred while parsing streaming response: {e}" |
| 1382 | + ) |
| 1383 | + else: |
| 1384 | + for chunk in aqua_client.chat( |
| 1385 | + messages=payload.pop("messages"), |
| 1386 | + payload=payload, |
| 1387 | + stream=True, |
| 1388 | + ): |
| 1389 | + try: |
| 1390 | + yield chunk["choices"][0]["delta"]["content"] |
| 1391 | + except Exception as e: |
| 1392 | + logger.debug( |
| 1393 | + f"Exception occurred while parsing streaming response: {e}" |
| 1394 | + ) |
0 commit comments