From e8083c0cac0b6ab13b1fd5b214aa62efa928b971 Mon Sep 17 00:00:00 2001 From: Lucain Pouget Date: Tue, 11 Feb 2025 11:48:56 +0100 Subject: [PATCH] fai ai get response tests --- tests/test_inference_providers.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tests/test_inference_providers.py b/tests/test_inference_providers.py index d259453ea5..c2afc239f5 100644 --- a/tests/test_inference_providers.py +++ b/tests/test_inference_providers.py @@ -59,16 +59,37 @@ def test_text_to_image_payload(self): "image_size": {"width": 512, "height": 512}, } + def test_text_to_image_response(self, mocker): + helper = FalAITextToImageTask() + mock = mocker.patch("huggingface_hub.inference._providers.fal_ai.get_session") + response = helper.get_response({"images": [{"url": "image_url"}]}) + mock.return_value.get.assert_called_once_with("image_url") + assert response == mock.return_value.get.return_value.content + def test_text_to_speech_payload(self): helper = FalAITextToSpeechTask() payload = helper._prepare_payload("Hello world", {}, "username/repo_name") assert payload == {"lyrics": "Hello world"} + def test_text_to_speech_response(self, mocker): + helper = FalAITextToSpeechTask() + mock = mocker.patch("huggingface_hub.inference._providers.fal_ai.get_session") + response = helper.get_response({"audio": {"url": "audio_url"}}) + mock.return_value.get.assert_called_once_with("audio_url") + assert response == mock.return_value.get.return_value.content + def test_text_to_video_payload(self): helper = FalAITextToVideoTask() payload = helper._prepare_payload("a cat walking", {"num_frames": 16}, "username/repo_name") assert payload == {"prompt": "a cat walking", "num_frames": 16} + def test_text_to_video_response(self, mocker): + helper = FalAITextToVideoTask() + mock = mocker.patch("huggingface_hub.inference._providers.fal_ai.get_session") + response = helper.get_response({"video": {"url": "video_url"}}) + mock.return_value.get.assert_called_once_with("video_url") + assert response == mock.return_value.get.return_value.content + class TestHFInferenceProvider: def test_prepare_mapped_model(self, mocker):