|
16 | 16 | import ads.config
|
17 | 17 | from ads.aqua.extension.deployment_handler import (
|
18 | 18 | AquaDeploymentHandler,
|
19 |
| - AquaDeploymentInferenceHandler, |
20 | 19 | AquaDeploymentParamsHandler,
|
| 20 | + AquaDeploymentStreamingInferenceHandler, |
21 | 21 | )
|
22 | 22 |
|
23 | 23 |
|
@@ -224,23 +224,36 @@ def test_validate_deployment_params(
|
224 | 224 | )
|
225 | 225 |
|
226 | 226 |
|
227 |
| -class TestAquaDeploymentInferenceHandler(unittest.TestCase): |
| 227 | +class TestAquaDeploymentStreamingInferenceHandler(unittest.TestCase): |
228 | 228 | @patch.object(IPythonHandler, "__init__")
|
229 | 229 | def setUp(self, ipython_init_mock) -> None:
|
230 | 230 | ipython_init_mock.return_value = None
|
231 |
| - self.inference_handler = AquaDeploymentInferenceHandler( |
232 |
| - MagicMock(), MagicMock() |
233 |
| - ) |
234 |
| - self.inference_handler.request = MagicMock() |
235 |
| - self.inference_handler.finish = MagicMock() |
236 |
| - |
237 |
| - @patch("ads.aqua.modeldeployment.MDInferenceResponse.get_model_deployment_response") |
| 231 | + self.handler = AquaDeploymentStreamingInferenceHandler(MagicMock(), MagicMock()) |
| 232 | + self.handler.request = MagicMock() |
| 233 | + self.handler.set_header = MagicMock() |
| 234 | + self.handler.write = MagicMock() |
| 235 | + self.handler.flush = MagicMock() |
| 236 | + self.handler.finish = MagicMock() |
| 237 | + |
| 238 | + @patch("ads.aqua.modeldeployment.AquaDeploymentApp.get_model_deployment_response") |
238 | 239 | def test_post(self, mock_get_model_deployment_response):
|
239 | 240 | """Test post method to return model deployment response."""
|
240 |
| - self.inference_handler.get_json_body = MagicMock( |
241 |
| - return_value=TestDataset.inference_request |
| 241 | + mock_response_gen = iter(["chunk1", "chunk2"]) |
| 242 | + |
| 243 | + mock_get_model_deployment_response.return_value = mock_response_gen |
| 244 | + |
| 245 | + self.handler.get_json_body = MagicMock( |
| 246 | + return_value={"prompt": "Hello", "model": "some-model"} |
242 | 247 | )
|
243 |
| - self.inference_handler.post() |
| 248 | + self.handler.request.headers = {"route": "test-route"} |
| 249 | + |
| 250 | + self.handler.post("mock-deployment-id") |
| 251 | + |
244 | 252 | mock_get_model_deployment_response.assert_called_with(
|
245 |
| - TestDataset.inference_request["endpoint"] |
| 253 | + "mock-deployment-id", |
| 254 | + {"prompt": "Hello", "model": "some-model"}, |
| 255 | + "test-route", |
246 | 256 | )
|
| 257 | + self.handler.write.assert_any_call("chunk1") |
| 258 | + self.handler.write.assert_any_call("chunk2") |
| 259 | + self.handler.finish.assert_called_once() |
0 commit comments