Commit 0474833

fix test

Wauplin committed Feb 10, 2025
1 parent 53cca3c commit 0474833

Showing 4 changed files with 79 additions and 81 deletions.
2 changes: 1 addition & 1 deletion src/huggingface_hub/inference/_providers/fal_ai.py
@@ -39,7 +39,7 @@ def prepare_request(
headers = {**build_hf_headers(token=api_key), **headers}

# Route to the proxy if the api_key is a HF TOKEN
base_url = get_base_url("fai-ai", BASE_URL, api_key)
base_url = get_base_url("fal-ai", BASE_URL, api_key)
if not api_key.startswith("hf_"):
headers["authorization"] = f"Key {api_key}"

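The one-character fix above matters because the provider name becomes the proxy route segment. As a rough sketch only (the real get_base_url helper is not shown in this diff), routing consistent with the test assertions below could look like this:

# Hypothetical sketch, not part of this commit: a get_base_url-style helper
# that picks between the HF proxy and the provider's own endpoint.
def get_base_url(provider: str, base_url: str, api_key: str) -> str:
    if api_key.startswith("hf_"):
        # Hugging Face tokens are routed through the proxy, so the provider
        # segment must be spelled exactly "fal-ai".
        return f"https://router.huggingface.co/{provider}"
    # A provider key calls the provider directly (e.g. https://fal.run/).
    return base_url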
7 changes: 1 addition & 6 deletions src/huggingface_hub/inference/_providers/replicate.py
@@ -61,12 +61,7 @@ def prepare_request(
headers=headers,
)

def _prepare_payload(
self,
inputs: Any,
parameters: Dict[str, Any],
model: str,
) -> Dict[str, Any]:
def _prepare_payload(self, inputs: Any, parameters: Dict[str, Any], model: str) -> Dict[str, Any]:
payload: Dict[str, Any] = {
"input": {
"prompt": inputs,
2 changes: 1 addition & 1 deletion src/huggingface_hub/inference/_providers/together.py
@@ -45,7 +45,7 @@ def prepare_request(

# Route to the proxy if the api_key is a HF TOKEN
base_url = get_base_url("together", BASE_URL, api_key)
mapped_model = mapped_model = get_mapped_model("fal-ai", model, self.task)
mapped_model = mapped_model = get_mapped_model("together", model, self.task)

if "model" in parameters:
parameters["model"] = mapped_model
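The change above is a copy-paste correction: the Together task was resolving its model mapping under the "fal-ai" key. The real get_mapped_model helper is not shown in this diff and may resolve mappings dynamically from the Hub; purely as an illustration, a provider-keyed lookup behaves like this:

# Hypothetical sketch, not part of this commit. The "together" entry mirrors
# the assertion in the tests below (Meta-Llama-3-70B-Instruct maps to
# meta-llama/Llama-3-70b-chat-hf); the real helper may not use a static table.
_EXAMPLE_MAPPING = {
    "together": {
        "meta-llama/Meta-Llama-3-70B-Instruct": "meta-llama/Llama-3-70b-chat-hf",
    },
}


def get_mapped_model(provider: str, model: str, task: str) -> str:
    # Looking up under the wrong provider key (the bug fixed above) would miss
    # or mis-resolve the provider-specific model id.
    try:
        return _EXAMPLE_MAPPING[provider][model]
    except KeyError:
        raise ValueError(f"No {provider} mapping found for {model} ({task}).")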
149 changes: 76 additions & 73 deletions tests/test_inference_providers.py
@@ -15,9 +15,11 @@
from huggingface_hub.inference._providers.replicate import ReplicateTask, ReplicateTextToSpeechTask
from huggingface_hub.inference._providers.sambanova import SambanovaConversationalTask
from huggingface_hub.inference._providers.together import TogetherTextGenerationTask, TogetherTextToImageTask
from tests.testing_utils import with_production_testing


class TestHFInferenceProvider:
@with_production_testing
def test_prepare_request(self):
helper = HFInferenceTask("text-classification")
request = helper.prepare_request(
@@ -34,7 +36,7 @@ def test_prepare_request(self):
assert request.headers["authorization"] == "Bearer hf_test_token"
assert request.json == {"inputs": "this is a dummy input", "parameters": {}}

# Testing conversational task separately
@with_production_testing
def test_prepare_request_conversational(self):
helper = HFInferenceConversational()
request = helper.prepare_request(
@@ -98,11 +100,9 @@ def test_get_response(self):


class TestFalAIProvider:
def test_prepare_request(self):
helper = FalAITextToImageTask()

# Test with custom fal.ai key
request = helper.prepare_request(
@with_production_testing
def test_prepare_request_no_routing(self):
request = FalAITextToImageTask().prepare_request(
inputs="dummy text input",
parameters={},
headers={},
@@ -112,18 +112,9 @@ def test_prepare_request(self):
assert request.url.startswith("https://fal.run/")
assert request.headers["authorization"] == "Key my_fal_ai_key"

# Test with missing token
with pytest.raises(ValueError, match="You must provide an api_key to work with fal.ai API."):
helper.prepare_request(
inputs="dummy text input",
parameters={},
headers={},
model="black-forest-labs/FLUX.1-dev",
api_key=None,
)

# Test routing
request = helper.prepare_request(
@with_production_testing
def test_prepare_request_with_routing(self):
request = FalAITextToImageTask().prepare_request(
inputs="dummy text input",
parameters={},
headers={},
@@ -133,6 +124,17 @@ def test_prepare_request(self):
assert request.headers["authorization"] == "Bearer hf_test_token"
assert request.url.startswith("https://router.huggingface.co/fal-ai")

@with_production_testing
def test_prepare_request_no_api_key(self):
with pytest.raises(ValueError, match="You must provide an api_key to work with fal.ai API."):
FalAITextToImageTask().prepare_request(
inputs="dummy text input",
parameters={},
headers={},
model="black-forest-labs/FLUX.1-dev",
api_key=None,
)

@pytest.mark.parametrize(
"helper,inputs,parameters,expected_payload",
[
@@ -182,11 +184,9 @@ def test_get_response(self):


class TestReplicateProvider:
def test_prepare_request(self):
helper = ReplicateTask("text-to-image")

# Test with custom replicate key
request = helper.prepare_request(
@with_production_testing
def test_prepare_request_no_routing(self):
request = ReplicateTask("text-to-image").prepare_request(
inputs="dummy text input",
parameters={},
headers={},
@@ -196,28 +196,31 @@ def test_prepare_request(self):
assert request.url.startswith("https://api.replicate.com/")
assert request.headers["Prefer"] == "wait"

@with_production_testing
def test_prepare_request_with_routing(self):
request = ReplicateTask("text-to-image").prepare_request(
inputs="dummy text input",
parameters={},
headers={},
model="black-forest-labs/FLUX.1-schnell",
api_key="hf_test_token",
)
assert request.url.startswith("https://router.huggingface.co/replicate")

@with_production_testing
def test_prepare_request_no_api_key(self):
# Test with missing token
with pytest.raises(ValueError, match="You must provide an api_key to work with Replicate API."):
helper.prepare_request(
ReplicateTask("text-to-image").prepare_request(
inputs="dummy text input",
parameters={},
headers={},
model="black-forest-labs/FLUX.1-schnell",
api_key=None,
)

# Test routing
request = helper.prepare_request(
inputs="dummy text input",
parameters={},
headers={},
model="black-forest-labs/FLUX.1-schnell",
api_key="hf_test_token",
)
assert request.url.startswith("https://router.huggingface.co/replicate")

@pytest.mark.parametrize(
"helper,model,inputs,parameters,expected_payload",
"helper,mapped_model,inputs,parameters,expected_payload",
[
(
ReplicateTask("text-to-image"),
@@ -233,7 +236,7 @@ def test_prepare_request(self):
),
(
ReplicateTextToSpeechTask(),
"hexgrad/Kokoro-82M",
"hexgrad/Kokoro-82M:f559560eb822dc509045f3921a1921234918b91739db4bf3daab2169b71c7a13",
"Hello world",
{},
{
@@ -245,7 +248,7 @@ def test_prepare_request(self):
),
(
ReplicateTask("text-to-video"),
"genmo/mochi-1-preview",
"genmo/mochi-1-preview:1944af04d098ef69bed7f9d335d102e652203f268ec4aaa2d836f6217217e460",
"a cat walking",
{"num_frames": 16},
{
@@ -259,20 +262,17 @@ def test_prepare_request(self):
],
ids=["text-to-image", "text-to-speech", "text-to-video"],
)
def test_prepare_payload(self, helper, model, inputs, parameters, expected_payload):
mapped_model = helper._map_model(model)
payload = helper._prepare_payload(inputs, parameters, mapped_model)
assert payload == expected_payload
def test_prepare_payload(self, helper, mapped_model, inputs, parameters, expected_payload):
assert expected_payload == helper._prepare_payload(inputs, parameters, mapped_model)

def test_get_response(self):
pytest.skip("Not implemented yet")


class TestTogetherProvider:
def test_prepare_request(self):
helper = TogetherTextGenerationTask("conversational")
# Test with custom together key
request = helper.prepare_request(
@with_production_testing
def test_prepare_request_no_routing(self):
request = TogetherTextGenerationTask("conversational").prepare_request(
inputs="this is a dummy input",
parameters={},
headers={},
@@ -282,18 +282,9 @@ def test_prepare_request(self):
assert request.url.startswith("https://api.together.xyz/")
assert request.model == "meta-llama/Llama-3-70b-chat-hf"

# Test with missing token
with pytest.raises(ValueError, match="You must provide an api_key to work with Together API."):
helper.prepare_request(
inputs="this is a dummy input",
parameters={},
headers={},
model="meta-llama/Meta-Llama-3-70B-Instruct",
api_key=None,
)

# Test routing
request = helper.prepare_request(
@with_production_testing
def test_prepare_request_with_routing(self):
request = TogetherTextGenerationTask("conversational").prepare_request(
inputs="this is a dummy input",
parameters={},
headers={},
@@ -302,6 +293,17 @@ def test_prepare_request(self):
)
assert request.url.startswith("https://router.huggingface.co/together")

@with_production_testing
def test_prepare_request_no_api_key(self):
with pytest.raises(ValueError, match="You must provide an api_key to work with Together API."):
TogetherTextGenerationTask("conversational").prepare_request(
inputs="this is a dummy input",
parameters={},
headers={},
model="meta-llama/Meta-Llama-3-70B-Instruct",
api_key=None,
)

@pytest.mark.parametrize(
"helper,inputs,parameters,expected_payload",
[
@@ -343,10 +345,9 @@ def test_get_response(self):


class TestSambanovaProvider:
def test_prepare_request(self):
helper = SambanovaConversationalTask()
# Test with custom sambanova key
request = helper.prepare_request(
@with_production_testing
def test_prepare_request_no_routing(self):
request = SambanovaConversationalTask().prepare_request(
inputs="this is a dummy input",
parameters={},
headers={},
@@ -357,18 +358,9 @@ def test_prepare_request(self):
assert request.model == "Meta-Llama-3.1-8B-Instruct"
assert "messages" in request.json

# Test with missing token
with pytest.raises(ValueError, match="You must provide an api_key to work with Sambanova API."):
helper.prepare_request(
inputs="this is a dummy input",
parameters={},
headers={},
model="meta-llama/Llama-3.1-8B-Instruct",
api_key=None,
)

# Test routing
request = helper.prepare_request(
@with_production_testing
def test_prepare_request_with_routing(self):
request = SambanovaConversationalTask().prepare_request(
inputs="this is a dummy input",
parameters={},
headers={},
@@ -377,5 +369,16 @@ def test_prepare_request(self):
)
assert request.url.startswith("https://router.huggingface.co/sambanova")

@with_production_testing
def test_prepare_request_no_api_key(self):
with pytest.raises(ValueError, match="You must provide an api_key to work with Sambanova API."):
SambanovaConversationalTask().prepare_request(
inputs="this is a dummy input",
parameters={},
headers={},
model="meta-llama/Llama-3.1-8B-Instruct",
api_key=None,
)

def test_get_response(self):
pytest.skip("Not implemented yet")

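Every test in this file now carries the @with_production_testing decorator imported from tests.testing_utils. That helper is not part of this diff; a plausible sketch, assuming it simply points the wrapped test at the production Hub endpoint rather than a staging one, is:

# Hypothetical sketch, not the actual tests.testing_utils helper: a decorator
# along these lines would run the wrapped test against the production Hub.
from functools import wraps
from unittest.mock import patch


def with_production_testing(fn):
    @wraps(fn)
    def wrapper(*args, **kwargs):
        # Assumed constant name; the real helper may patch other settings too.
        with patch("huggingface_hub.constants.ENDPOINT", "https://huggingface.co"):
            return fn(*args, **kwargs)

    return wrapper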