Skip to content

Commit c7b7a42

Browse files
Adding endpoint override feature
1 parent 2e6195d commit c7b7a42

File tree

3 files changed

+34
-22
lines changed

3 files changed

+34
-22
lines changed

ads/aqua/common/enums.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,12 @@ class Resource(ExtendedEnum):
2020
MODEL_VERSION_SET = "model-version-sets"
2121

2222

class PredictEndpoints(ExtendedEnum):
    """Known inference routes exposed by an OpenAI-compatible model deployment.

    Used to pick the client call (chat vs. text completion) for a streaming
    request, either from the deployment's configured predict endpoint or from
    a per-request route override header.
    """

    # Streaming chat-style inference (messages in, deltas out).
    CHAT_COMPLETIONS_ENDPOINT = "/v1/chat/completions"
    # Streaming prompt-completion inference (prompt in, text out).
    TEXT_COMPLETIONS_ENDPOINT = "/v1/completions"
    # NOTE(review): the OpenAI-compatible route is usually plural
    # ("/v1/embeddings") — confirm the deployed container really serves the
    # singular path before relying on this value.
    EMBEDDING_ENDPOINT = "/v1/embedding"
2329
class Tags(ExtendedEnum):
2430
TASK = "task"
2531
LICENSE = "license"

ads/aqua/extension/deployment_handler.py

Lines changed: 5 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -176,7 +176,7 @@ def list_shapes(self):
176176

177177
class AquaDeploymentStreamingInferenceHandler(AquaAPIhandler):
178178
@handle_exceptions
179-
async def post(self, model_deployment_id):
179+
def post(self, model_deployment_id):
180180
"""
181181
Handles streaming inference request for the Active Model Deployments
182182
Raises
@@ -201,21 +201,18 @@ async def post(self, model_deployment_id):
201201
)
202202
if not input_data.get("model"):
203203
raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("model"))
204-
205-
if "stream" not in input_data:
206-
input_data.update(stream=True)
207-
204+
route_override_header = self.request.headers.get("route", None)
208205
self.set_header("Content-Type", "text/event-stream")
209206
self.set_header("Cache-Control", "no-cache")
210207
self.set_header("Transfer-Encoding", "chunked")
211-
await self.flush()
208+
self.flush()
212209
try:
213210
response_gen = AquaDeploymentApp().get_model_deployment_response(
214-
model_deployment_id, input_data
211+
model_deployment_id, input_data, route_override_header
215212
)
216213
for chunk in response_gen:
217214
self.write(chunk)
218-
await self.flush()
215+
self.flush()
219216
except Exception as ex:
220217
raise HTTPError(500, str(ex)) from ex
221218
finally:

ads/aqua/modeldeployment/deployment.py

Lines changed: 23 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,12 @@
1818
ComputeShapeSummary,
1919
ContainerPath,
2020
)
21-
from ads.aqua.common.enums import InferenceContainerTypeFamily, ModelFormat, Tags
21+
from ads.aqua.common.enums import (
22+
InferenceContainerTypeFamily,
23+
ModelFormat,
24+
PredictEndpoints,
25+
Tags,
26+
)
2227
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
2328
from ads.aqua.common.utils import (
2429
DEFINED_METADATA_TO_FILE_MAP,
@@ -937,7 +942,6 @@ def get(self, model_deployment_id: str, **kwargs) -> "AquaDeploymentDetail":
937942
model_deployment = self.ds_client.get_model_deployment(
938943
model_deployment_id=model_deployment_id, **kwargs
939944
).data
940-
941945
oci_aqua = (
942946
(
943947
Tags.AQUA_TAG in model_deployment.freeform_tags
@@ -982,7 +986,6 @@ def get(self, model_deployment_id: str, **kwargs) -> "AquaDeploymentDetail":
982986
aqua_deployment = AquaDeployment.from_oci_model_deployment(
983987
model_deployment, self.region
984988
)
985-
986989
if Tags.MULTIMODEL_TYPE_TAG in model_deployment.freeform_tags:
987990
aqua_model_id = model_deployment.freeform_tags.get(
988991
Tags.AQUA_MODEL_ID_TAG, UNKNOWN
@@ -1013,7 +1016,6 @@ def get(self, model_deployment_id: str, **kwargs) -> "AquaDeploymentDetail":
10131016
aqua_deployment.models = [
10141017
AquaMultiModelRef(**metadata) for metadata in multi_model_metadata
10151018
]
1016-
10171019
return AquaDeploymentDetail(
10181020
**vars(aqua_deployment),
10191021
log_group=AquaResourceIdentifier(
@@ -1334,7 +1336,9 @@ def _stream_sanitizer(response):
13341336
continue
13351337

13361338
@telemetry(entry_point="plugin=inference&action=get_response", name="aqua")
1337-
def get_model_deployment_response(self, model_deployment_id: str, payload: dict):
1339+
def get_model_deployment_response(
1340+
self, model_deployment_id: str, payload: dict, route_override_header: str
1341+
):
13381342
"""
13391343
Returns Model deployment inference response in streaming fashion
13401344
@@ -1365,29 +1369,34 @@ def get_model_deployment_response(self, model_deployment_id: str, payload: dict)
13651369
model_deployment = self.get(model_deployment_id)
13661370
endpoint = model_deployment.endpoint + "/predictWithResponseStream"
13671371
endpoint_type = model_deployment.environment_variables.get(
1368-
"MODEL_DEPLOY_PREDICT_ENDPOINT", "/v1/completions"
1372+
"MODEL_DEPLOY_PREDICT_ENDPOINT", PredictEndpoints.TEXT_COMPLETIONS_ENDPOINT
13691373
)
13701374
aqua_client = Client(endpoint=endpoint)
1371-
if endpoint_type == "/v1/completions":
1372-
for chunk in aqua_client.generate(
1373-
prompt=payload.pop("prompt"),
1375+
1376+
if PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT in (
1377+
endpoint_type,
1378+
route_override_header,
1379+
):
1380+
for chunk in aqua_client.chat(
1381+
messages=payload.pop("messages"),
13741382
payload=payload,
13751383
stream=True,
13761384
):
13771385
try:
1378-
yield chunk["choices"][0]["text"]
1386+
yield chunk["choices"][0]["delta"]["content"]
13791387
except Exception as e:
13801388
logger.debug(
13811389
f"Exception occurred while parsing streaming response: {e}"
13821390
)
1383-
else:
1384-
for chunk in aqua_client.chat(
1385-
messages=payload.pop("messages"),
1391+
1392+
elif endpoint_type == PredictEndpoints.TEXT_COMPLETIONS_ENDPOINT:
1393+
for chunk in aqua_client.generate(
1394+
prompt=payload.pop("prompt"),
13861395
payload=payload,
13871396
stream=True,
13881397
):
13891398
try:
1390-
yield chunk["choices"][0]["delta"]["content"]
1399+
yield chunk["choices"][0]["text"]
13911400
except Exception as e:
13921401
logger.debug(
13931402
f"Exception occurred while parsing streaming response: {e}"

0 commit comments

Comments
 (0)