
Commit d1d942d

Updating payload format
1 parent (177f888) · commit d1d942d

File tree

5 files changed: +85 -72 lines changed

ads/aqua/app.py
ads/aqua/extension/deployment_handler.py
ads/aqua/modeldeployment/__init__.py
ads/aqua/modeldeployment/deployment.py
tests/unitary/with_extras/aqua/test_deployment.py

ads/aqua/app.py

Lines changed: 0 additions & 2 deletions
@@ -67,8 +67,6 @@ def __init__(self) -> None:
         self._md_auth = default_signer({"service_endpoint": OCI_MD_SERVICE_ENDPOINT})
         self.ds_client = oc.OCIClientFactory(**self._auth).data_science
         self.compute_client = oc.OCIClientFactory(**default_signer()).compute
-        print("self._md_auth: ", self._md_auth)
-        print("OCI_MD_SERVICE_ENDPOINT: ", OCI_MD_SERVICE_ENDPOINT)
         self.model_deployment_client = oc.OCIClientFactory(
             **self._md_auth
         ).model_deployment

ads/aqua/extension/deployment_handler.py

Lines changed: 18 additions & 23 deletions
@@ -11,8 +11,7 @@
 from ads.aqua.common.decorator import handle_exceptions
 from ads.aqua.extension.base_handler import AquaAPIhandler
 from ads.aqua.extension.errors import Errors
-from ads.aqua.modeldeployment import AquaDeploymentApp, MDInferenceResponse
-from ads.aqua.modeldeployment.entities import ModelParams
+from ads.aqua.modeldeployment import AquaDeploymentApp
 from ads.config import COMPARTMENT_OCID


@@ -178,9 +177,9 @@ def list_shapes(self):

 class AquaDeploymentStreamingInferenceHandler(AquaAPIhandler):
     @handle_exceptions
-    async def post(self, *args, **kwargs):  # noqa: ARG002
+    async def post(self, model_deployment_id):
         """
-        Handles inference request for the Active Model Deployments
+        Handles streaming inference request for the Active Model Deployments
         Raises
         ------
         HTTPError
@@ -194,38 +193,34 @@ async def post(self, *args, **kwargs):  # noqa: ARG002
         if not input_data:
             raise HTTPError(400, Errors.NO_INPUT_DATA)

-        model_deployment_id = input_data.get("id")
-
         prompt = input_data.get("prompt")
-        if not prompt:
-            raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("prompt"))
+        messages = input_data.get("messages")

-        model_params = (
-            input_data.get("model_params") if input_data.get("model_params") else {}
-        )
-        try:
-            model_params_obj = ModelParams(**model_params)
-        except Exception as ex:
+        if not prompt and not messages:
             raise HTTPError(
-                400, Errors.INVALID_INPUT_DATA_FORMAT.format("model_params")
-            ) from ex
+                400, Errors.MISSING_REQUIRED_PARAMETER.format("prompt/messages")
+            )
+        if not input_data.get("model"):
+            raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("model"))
+
+        if "stream" not in input_data:
+            input_data.update(stream=True)

         self.set_header("Content-Type", "text/event-stream")
         self.set_header("Cache-Control", "no-cache")
         self.set_header("Transfer-Encoding", "chunked")
         await self.flush()
-
         try:
-            response_gen = MDInferenceResponse(
-                prompt, model_params_obj
-            ).get_model_deployment_response(model_deployment_id)
+            response_gen = AquaDeploymentApp().get_model_deployment_response(
+                model_deployment_id, input_data
+            )
             for chunk in response_gen:
                 if not chunk:
                     continue
                 self.write(f"data: {chunk}\n\n")
                 await self.flush()
-        except StreamClosedError:
-            self.log.warning("Client disconnected.")
+        except StreamClosedError as ex:
+            raise HTTPError(500, str(ex)) from ex
         finally:
             self.finish()

@@ -291,5 +286,5 @@ def post(self, *args, **kwargs):  # noqa: ARG002
     ("deployments/?([^/]*)", AquaDeploymentHandler),
     ("deployments/?([^/]*)/activate", AquaDeploymentHandler),
     ("deployments/?([^/]*)/deactivate", AquaDeploymentHandler),
-    ("inference", AquaDeploymentStreamingInferenceHandler),
+    ("inference/stream/?([^/]*)", AquaDeploymentStreamingInferenceHandler),
 ]
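
The route change above moves the deployment OCID out of the request body and into the URL path, and the body now carries OpenAI-style completion fields ("prompt" or "messages", plus a required "model"). A minimal client-side sketch of the new contract, assuming the handler is mounted under the Aqua extension of a local notebook server; the base URL and OCID below are placeholders:

```python
import requests  # assumes the requests package; any streaming-capable HTTP client works

# Placeholder values: substitute your notebook server URL and a live deployment OCID.
BASE_URL = "http://localhost:8888/aqua"
DEPLOYMENT_OCID = "ocid1.datasciencemodeldeployment.oc1..<ocid>"

payload = {
    "model": "odsc-llm",       # required; the handler returns 400 without it
    "prompt": "What is 1+1?",  # or "messages": [...] for chat-style models
    "max_tokens": 64,
    # "stream" can be omitted: the handler defaults it to True.
}

# The handler writes Server-Sent-Events-style frames: "data: <chunk>\n\n".
with requests.post(
    f"{BASE_URL}/inference/stream/{DEPLOYMENT_OCID}", json=payload, stream=True
) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines(decode_unicode=True):
        if line and line.startswith("data: "):
            print(line[len("data: "):], end="", flush=True)
```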

ads/aqua/modeldeployment/__init__.py

Lines changed: 2 additions & 4 deletions
@@ -1,8 +1,6 @@
 #!/usr/bin/env python
-# -*- coding: utf-8 -*-
-# Copyright (c) 2024 Oracle and/or its affiliates.
+# Copyright (c) 2025 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 from ads.aqua.modeldeployment.deployment import AquaDeploymentApp
-from ads.aqua.modeldeployment.inference import MDInferenceResponse

-__all__ = ["AquaDeploymentApp", "MDInferenceResponse"]
+__all__ = ["AquaDeploymentApp"]

ads/aqua/modeldeployment/deployment.py

Lines changed: 58 additions & 3 deletions
@@ -628,7 +628,9 @@ def _create_multi(
             config_data["model_task"] = model.model_task

             if model.fine_tune_weights_location:
-                config_data["fine_tune_weights_location"] = model.fine_tune_weights_location
+                config_data["fine_tune_weights_location"] = (
+                    model.fine_tune_weights_location
+                )

             model_config.append(config_data)
             model_name_list.append(model.model_name)
@@ -789,7 +791,7 @@ def _create_deployment(
         telemetry_kwargs = {"ocid": get_ocid_substring(deployment_id, key_len=8)}

         if Tags.BASE_MODEL_CUSTOM in tags:
-            telemetry_kwargs[ "custom_base_model"] = True
+            telemetry_kwargs["custom_base_model"] = True

         # tracks unique deployments that were created in the user compartment
         self.telemetry.record_event_async(
@@ -1309,4 +1311,57 @@ def list_shapes(self, **kwargs) -> List[ComputeShapeSummary]:
                 or gpu_specs.shapes.get(oci_shape.name.upper()),
             )
             for oci_shape in oci_shapes
-        ]
+        ]
+
+    @staticmethod
+    def _stream_sanitizer(response):
+        for chunk in response.data.raw.stream(1024 * 1024, decode_content=True):
+            if not chunk:
+                continue
+
+            try:
+                decoded = chunk.decode("utf-8").strip()
+                if not decoded.startswith("data:"):
+                    continue
+
+                data_json = decoded[len("data:") :].strip()
+                parsed = json.loads(data_json)
+                text = parsed["choices"][0]["text"]
+                yield text
+
+            except Exception:
+                continue
+
+    @telemetry(entry_point="plugin=inference&action=get_response", name="aqua")
+    def get_model_deployment_response(self, model_deployment_id: str, payload: dict):
+        """
+        Returns Model deployment inference response in streaming fashion
+
+        Parameters
+        ----------
+        model_deployment_id: str
+            Model deployment ocid
+        payload: dict
+            model params.
+            {
+                "max_tokens": 1024,
+                "temperature": 0.5,
+                "prompt": "what are some good skills deep learning expert. Give us some tips on how to structure interview with some coding example?",
+                "top_p": 0.4,
+                "top_k": 100,
+                "model": "odsc-llm",
+                "frequency_penalty": 1,
+                "presence_penalty": 1,
+                "stream": true
+            }
+
+        Returns
+        -------
+        Model deployment inference response in streaming fashion
+
+        """
+
+        response = self.model_deployment_client.predict_with_response_stream(
+            model_deployment_id=model_deployment_id, request_body=payload
+        )
+        yield from self._stream_sanitizer(response)
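
A short usage sketch of the new generator: `predict_with_response_stream` is the model-deployment client call the method wraps, and `_stream_sanitizer` reduces each SSE frame to its `choices[0].text` field. The OCID is a placeholder and the payload mirrors the docstring example:

```python
from ads.aqua.modeldeployment import AquaDeploymentApp

app = AquaDeploymentApp()

payload = {
    "model": "odsc-llm",
    "prompt": "What is 1+1?",
    "max_tokens": 128,
    "temperature": 0.5,
    "stream": True,  # ask the serving container for incremental chunks
}

# Placeholder OCID: substitute an active model deployment in your tenancy.
for text in app.get_model_deployment_response(
    "ocid1.datasciencemodeldeployment.oc1..<ocid>", payload
):
    print(text, end="", flush=True)
```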

tests/unitary/with_extras/aqua/test_deployment.py

Lines changed: 7 additions & 40 deletions
@@ -487,7 +487,7 @@ class TestDataset:
                 "model_name": "test_model_1",
                 "model_task": "text_embedding",
                 "artifact_location": "test_location_1",
-                "fine_tune_weights_location" : None
+                "fine_tune_weights_location": None,
             },
             {
                 "env_var": {},
@@ -496,7 +496,7 @@ class TestDataset:
                 "model_name": "test_model_2",
                 "model_task": "image_text_to_text",
                 "artifact_location": "test_location_2",
-                "fine_tune_weights_location" : None
+                "fine_tune_weights_location": None,
             },
             {
                 "env_var": {},
@@ -505,7 +505,7 @@ class TestDataset:
                 "model_name": "test_model_3",
                 "model_task": "code_synthesis",
                 "artifact_location": "test_location_3",
-                "fine_tune_weights_location" : "oci://test_bucket@test_namespace/models/ft-models/meta-llama-3b/ocid1.datasciencejob.oc1.iad.<ocid>"
+                "fine_tune_weights_location": "oci://test_bucket@test_namespace/models/ft-models/meta-llama-3b/ocid1.datasciencejob.oc1.iad.<ocid>",
             },
         ],
         "model_id": "ocid1.datasciencemodel.oc1.<region>.<OCID>",
@@ -972,7 +972,7 @@ class TestDataset:
             "model_name": "model_one",
             "model_task": "text_embedding",
             "artifact_location": "artifact_location_one",
-            "fine_tune_weights_location": None
+            "fine_tune_weights_location": None,
         },
         {
             "env_var": {"--test_key_two": "test_value_two"},
@@ -981,7 +981,7 @@ class TestDataset:
             "model_name": "model_two",
             "model_task": "image_text_to_text",
             "artifact_location": "artifact_location_two",
-            "fine_tune_weights_location": None
+            "fine_tune_weights_location": None,
         },
         {
             "env_var": {"--test_key_three": "test_value_three"},
@@ -990,7 +990,7 @@ class TestDataset:
             "model_name": "model_three",
             "model_task": "code_synthesis",
             "artifact_location": "artifact_location_three",
-            "fine_tune_weights_location" : "oci://test_bucket@test_namespace/models/ft-models/meta-llama-3b/ocid1.datasciencejob.oc1.iad.<ocid>"
+            "fine_tune_weights_location": "oci://test_bucket@test_namespace/models/ft-models/meta-llama-3b/ocid1.datasciencejob.oc1.iad.<ocid>",
         },
     ]

@@ -1817,7 +1817,7 @@ def test_create_deployment_for_multi_model(
             model_task="code_synthesis",
             gpu_count=2,
             artifact_location="test_location_3",
-            fine_tune_weights_location= "oci://test_bucket@test_namespace/models/ft-models/meta-llama-3b/ocid1.datasciencejob.oc1.iad.<ocid>"
+            fine_tune_weights_location="oci://test_bucket@test_namespace/models/ft-models/meta-llama-3b/ocid1.datasciencejob.oc1.iad.<ocid>",
         )

         result = self.app.create(
@@ -2283,36 +2283,3 @@ def test_validate_multimodel_deployment_feasibility_positive_single(
             total_gpus,
             "test_data/deployment/aqua_summary_multi_model_single.json",
         )
-
-
-class TestMDInferenceResponse(unittest.TestCase):
-    def setUp(self):
-        self.app = MDInferenceResponse()
-
-    @classmethod
-    def setUpClass(cls):
-        cls.curr_dir = os.path.dirname(os.path.abspath(__file__))
-
-    @classmethod
-    def tearDownClass(cls):
-        cls.curr_dir = None
-
-    @patch("requests.post")
-    def test_get_model_deployment_response(self, mock_post):
-        """Test to check if model deployment response is returned correctly."""
-
-        endpoint = TestDataset.MODEL_DEPLOYMENT_URL + "/predict"
-        self.app.prompt = "What is 1+1?"
-        self.app.model_params = ModelParams(**TestDataset.model_params)
-
-        mock_response = MagicMock()
-        response_json = os.path.join(
-            self.curr_dir, "test_data/deployment/aqua_deployment_response.json"
-        )
-        with open(response_json, "r") as _file:
-            mock_response.content = _file.read()
-        mock_response.status_code = 200
-        mock_post.return_value = mock_response
-
-        result = self.app.get_model_deployment_response(endpoint)
-        assert result["choices"][0]["text"] == " The answer is 2"
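
The deleted `TestMDInferenceResponse` has no replacement in this commit. A hedged sketch of how the new streaming path could be unit-tested by faking the raw byte stream; the `__init__` bypass, the fixture chunks, and the assumption that the `@telemetry` decorator is side-effect-free in a test environment are all hypothetical, not part of the commit:

```python
from unittest.mock import MagicMock

from ads.aqua.modeldeployment import AquaDeploymentApp


def test_get_model_deployment_response_streaming():
    """Hypothetical test: the generator should yield choices[0].text from SSE frames."""
    # Bypass __init__ so no OCI auth is needed; inject a fake client instead.
    app = AquaDeploymentApp.__new__(AquaDeploymentApp)
    app.model_deployment_client = MagicMock()

    mock_response = MagicMock()
    # Fake the raw byte stream the serving container would emit.
    mock_response.data.raw.stream.return_value = iter(
        [
            b'data: {"choices": [{"text": " The answer"}]}',
            b'data: {"choices": [{"text": " is 2"}]}',
            b"data: [DONE]",  # non-JSON payloads are skipped by _stream_sanitizer
        ]
    )
    app.model_deployment_client.predict_with_response_stream.return_value = mock_response

    chunks = list(
        app.get_model_deployment_response(
            "ocid1.datasciencemodeldeployment.oc1..<ocid>",
            {"model": "odsc-llm", "prompt": "What is 1+1?", "stream": True},
        )
    )
    assert "".join(chunks) == " The answer is 2"
```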
