|
18 | 18 | ComputeShapeSummary,
|
19 | 19 | ContainerPath,
|
20 | 20 | )
|
21 |
| -from ads.aqua.common.enums import InferenceContainerTypeFamily, ModelFormat, Tags |
| 21 | +from ads.aqua.common.enums import ( |
| 22 | + InferenceContainerTypeFamily, |
| 23 | + ModelFormat, |
| 24 | + PredictEndpoints, |
| 25 | + Tags, |
| 26 | +) |
22 | 27 | from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
|
23 | 28 | from ads.aqua.common.utils import (
|
24 | 29 | DEFINED_METADATA_TO_FILE_MAP,
|
@@ -937,7 +942,6 @@ def get(self, model_deployment_id: str, **kwargs) -> "AquaDeploymentDetail":
|
937 | 942 | model_deployment = self.ds_client.get_model_deployment(
|
938 | 943 | model_deployment_id=model_deployment_id, **kwargs
|
939 | 944 | ).data
|
940 |
| - |
941 | 945 | oci_aqua = (
|
942 | 946 | (
|
943 | 947 | Tags.AQUA_TAG in model_deployment.freeform_tags
|
@@ -982,7 +986,6 @@ def get(self, model_deployment_id: str, **kwargs) -> "AquaDeploymentDetail":
|
982 | 986 | aqua_deployment = AquaDeployment.from_oci_model_deployment(
|
983 | 987 | model_deployment, self.region
|
984 | 988 | )
|
985 |
| - |
986 | 989 | if Tags.MULTIMODEL_TYPE_TAG in model_deployment.freeform_tags:
|
987 | 990 | aqua_model_id = model_deployment.freeform_tags.get(
|
988 | 991 | Tags.AQUA_MODEL_ID_TAG, UNKNOWN
|
@@ -1013,7 +1016,6 @@ def get(self, model_deployment_id: str, **kwargs) -> "AquaDeploymentDetail":
|
1013 | 1016 | aqua_deployment.models = [
|
1014 | 1017 | AquaMultiModelRef(**metadata) for metadata in multi_model_metadata
|
1015 | 1018 | ]
|
1016 |
| - |
1017 | 1019 | return AquaDeploymentDetail(
|
1018 | 1020 | **vars(aqua_deployment),
|
1019 | 1021 | log_group=AquaResourceIdentifier(
|
@@ -1334,7 +1336,9 @@ def _stream_sanitizer(response):
|
1334 | 1336 | continue
|
1335 | 1337 |
|
1336 | 1338 | @telemetry(entry_point="plugin=inference&action=get_response", name="aqua")
|
1337 |
| - def get_model_deployment_response(self, model_deployment_id: str, payload: dict): |
| 1339 | + def get_model_deployment_response( |
| 1340 | + self, model_deployment_id: str, payload: dict, route_override_header: str |
| 1341 | + ): |
1338 | 1342 | """
|
1339 | 1343 | Returns Model deployment inference response in streaming fashion
|
1340 | 1344 |
|
@@ -1365,29 +1369,34 @@ def get_model_deployment_response(self, model_deployment_id: str, payload: dict)
|
1365 | 1369 | model_deployment = self.get(model_deployment_id)
|
1366 | 1370 | endpoint = model_deployment.endpoint + "/predictWithResponseStream"
|
1367 | 1371 | endpoint_type = model_deployment.environment_variables.get(
|
1368 |
| - "MODEL_DEPLOY_PREDICT_ENDPOINT", "/v1/completions" |
| 1372 | + "MODEL_DEPLOY_PREDICT_ENDPOINT", PredictEndpoints.TEXT_COMPLETIONS_ENDPOINT |
1369 | 1373 | )
|
1370 | 1374 | aqua_client = Client(endpoint=endpoint)
|
1371 |
| - if endpoint_type == "/v1/completions": |
1372 |
| - for chunk in aqua_client.generate( |
1373 |
| - prompt=payload.pop("prompt"), |
| 1375 | + |
| 1376 | + if PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT in ( |
| 1377 | + endpoint_type, |
| 1378 | + route_override_header, |
| 1379 | + ): |
| 1380 | + for chunk in aqua_client.chat( |
| 1381 | + messages=payload.pop("messages"), |
1374 | 1382 | payload=payload,
|
1375 | 1383 | stream=True,
|
1376 | 1384 | ):
|
1377 | 1385 | try:
|
1378 |
| - yield chunk["choices"][0]["text"] |
| 1386 | + yield chunk["choices"][0]["delta"]["content"] |
1379 | 1387 | except Exception as e:
|
1380 | 1388 | logger.debug(
|
1381 | 1389 | f"Exception occurred while parsing streaming response: {e}"
|
1382 | 1390 | )
|
1383 |
| - else: |
1384 |
| - for chunk in aqua_client.chat( |
1385 |
| - messages=payload.pop("messages"), |
| 1391 | + |
| 1392 | + elif endpoint_type == PredictEndpoints.TEXT_COMPLETIONS_ENDPOINT: |
| 1393 | + for chunk in aqua_client.generate( |
| 1394 | + prompt=payload.pop("prompt"), |
1386 | 1395 | payload=payload,
|
1387 | 1396 | stream=True,
|
1388 | 1397 | ):
|
1389 | 1398 | try:
|
1390 |
| - yield chunk["choices"][0]["delta"]["content"] |
| 1399 | + yield chunk["choices"][0]["text"] |
1391 | 1400 | except Exception as e:
|
1392 | 1401 | logger.debug(
|
1393 | 1402 | f"Exception occurred while parsing streaming response: {e}"
|
|
0 commit comments