Skip to content

Commit 0a867c7

Browse files
Fixing UTs
1 parent c7b7a42 commit 0a867c7

File tree

2 files changed

+27
-15
lines changed

2 files changed

+27
-15
lines changed

tests/unitary/with_extras/aqua/test_deployment.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,7 @@
3333
AquaContainerConfig,
3434
AquaContainerConfigItem,
3535
)
36-
from ads.aqua.model.enums import MultiModelSupportedTaskType
37-
from ads.aqua.modeldeployment import AquaDeploymentApp, MDInferenceResponse
36+
from ads.aqua.modeldeployment import AquaDeploymentApp
3837
from ads.aqua.modeldeployment.entities import (
3938
AquaDeployment,
4039
AquaDeploymentConfig,

tests/unitary/with_extras/aqua/test_deployment_handler.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
import ads.config
1717
from ads.aqua.extension.deployment_handler import (
1818
AquaDeploymentHandler,
19-
AquaDeploymentInferenceHandler,
2019
AquaDeploymentParamsHandler,
20+
AquaDeploymentStreamingInferenceHandler,
2121
)
2222

2323

@@ -224,23 +224,36 @@ def test_validate_deployment_params(
224224
)
225225

226226

227-
class TestAquaDeploymentInferenceHandler(unittest.TestCase):
227+
class TestAquaDeploymentStreamingInferenceHandler(unittest.TestCase):
228228
@patch.object(IPythonHandler, "__init__")
229229
def setUp(self, ipython_init_mock) -> None:
230230
ipython_init_mock.return_value = None
231-
self.inference_handler = AquaDeploymentInferenceHandler(
232-
MagicMock(), MagicMock()
233-
)
234-
self.inference_handler.request = MagicMock()
235-
self.inference_handler.finish = MagicMock()
236-
237-
@patch("ads.aqua.modeldeployment.MDInferenceResponse.get_model_deployment_response")
231+
self.handler = AquaDeploymentStreamingInferenceHandler(MagicMock(), MagicMock())
232+
self.handler.request = MagicMock()
233+
self.handler.set_header = MagicMock()
234+
self.handler.write = MagicMock()
235+
self.handler.flush = MagicMock()
236+
self.handler.finish = MagicMock()
237+
238+
@patch("ads.aqua.modeldeployment.AquaDeploymentApp.get_model_deployment_response")
238239
def test_post(self, mock_get_model_deployment_response):
239240
"""Test post method to return model deployment response."""
240-
self.inference_handler.get_json_body = MagicMock(
241-
return_value=TestDataset.inference_request
241+
mock_response_gen = iter(["chunk1", "chunk2"])
242+
243+
mock_get_model_deployment_response.return_value = mock_response_gen
244+
245+
self.handler.get_json_body = MagicMock(
246+
return_value={"prompt": "Hello", "model": "some-model"}
242247
)
243-
self.inference_handler.post()
248+
self.handler.request.headers = {"route": "test-route"}
249+
250+
self.handler.post("mock-deployment-id")
251+
244252
mock_get_model_deployment_response.assert_called_with(
245-
TestDataset.inference_request["endpoint"]
253+
"mock-deployment-id",
254+
{"prompt": "Hello", "model": "some-model"},
255+
"test-route",
246256
)
257+
self.handler.write.assert_any_call("chunk1")
258+
self.handler.write.assert_any_call("chunk2")
259+
self.handler.finish.assert_called_once()

0 commit comments

Comments
 (0)