Fix Inference Client VCR tests #2858

Merged: 18 commits, Feb 14, 2025
12 changes: 10 additions & 2 deletions .github/workflows/python-tests.yml
@@ -26,6 +26,7 @@ jobs:
[
"Repository only",
"Everything else",
"Inference only"

]
include:
@@ -64,7 +65,7 @@

case "${{ matrix.test_name }}" in

"Repository only" | "Everything else")
"Repository only" | "Everything else" | "Inference only")
sudo apt update
sudo apt install -y libsndfile1-dev
;;
@@ -112,8 +113,15 @@ jobs:
eval $PYTEST
;;

"Inference only")
# Run inference tests concurrently
PYTEST="$PYTEST ../tests -k 'test_inference' -n 4"
echo $PYTEST
eval $PYTEST
;;

"Everything else")
PYTEST="$PYTEST ../tests -k 'not TestRepository' -n 4"
PYTEST="$PYTEST ../tests -k 'not TestRepository and not test_inference' -n 4"
echo $PYTEST
eval $PYTEST
;;
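The new "Inference only" matrix entry runs the VCR-backed inference tests on their own, in parallel via pytest-xdist. A rough local reproduction of the same selection (a sketch, assuming pytest and pytest-xdist are installed and the tests live under `../tests` as in CI):

```python
# Local reproduction sketch of the CI selection above; paths mirror the workflow.
import pytest

# "-k test_inference" selects the inference tests; "-n 4" fans them out over 4 workers.
exit_code = pytest.main(["../tests", "-k", "test_inference", "-n", "4"])
raise SystemExit(exit_code)
```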
12 changes: 6 additions & 6 deletions src/huggingface_hub/inference/_client.py
@@ -40,7 +40,7 @@

from requests import HTTPError

from huggingface_hub.constants import ALL_INFERENCE_API_FRAMEWORKS, INFERENCE_ENDPOINT, MAIN_INFERENCE_API_FRAMEWORKS
from huggingface_hub import constants
from huggingface_hub.errors import BadRequestError, InferenceTimeoutError
from huggingface_hub.inference._common import (
TASKS_EXPECTING_IMAGES,
@@ -3300,9 +3300,9 @@ def list_deployed_models(

# Resolve which frameworks to check
if frameworks is None:
frameworks = MAIN_INFERENCE_API_FRAMEWORKS
frameworks = constants.MAIN_INFERENCE_API_FRAMEWORKS
elif frameworks == "all":
frameworks = ALL_INFERENCE_API_FRAMEWORKS
frameworks = constants.ALL_INFERENCE_API_FRAMEWORKS
elif isinstance(frameworks, str):
frameworks = [frameworks]
frameworks = list(set(frameworks))
@@ -3322,7 +3322,7 @@ def _unpack_response(framework: str, items: List[Dict]) -> None:

for framework in frameworks:
response = get_session().get(
f"{INFERENCE_ENDPOINT}/framework/{framework}", headers=build_hf_headers(token=self.token)
f"{constants.INFERENCE_ENDPOINT}/framework/{framework}", headers=build_hf_headers(token=self.token)
)
hf_raise_for_status(response)
_unpack_response(framework, response.json())
@@ -3384,7 +3384,7 @@ def get_endpoint_info(self, *, model: Optional[str] = None) -> Dict[str, Any]:
if model.startswith(("http://", "https://")):
url = model.rstrip("/") + "/info"
else:
url = f"{INFERENCE_ENDPOINT}/models/{model}/info"
url = f"{constants.INFERENCE_ENDPOINT}/models/{model}/info"

response = get_session().get(url, headers=build_hf_headers(token=self.token))
hf_raise_for_status(response)
@@ -3472,7 +3472,7 @@ def get_model_status(self, model: Optional[str] = None) -> ModelStatus:
raise ValueError("Model id not provided.")
if model.startswith("https://"):
raise NotImplementedError("Model status is only available for Inference API endpoints.")
url = f"{INFERENCE_ENDPOINT}/status/{model}"
url = f"{constants.INFERENCE_ENDPOINT}/status/{model}"

response = get_session().get(url, headers=build_hf_headers(token=self.token))
hf_raise_for_status(response)
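Switching from `from huggingface_hub.constants import ...` to `from huggingface_hub import constants` means the endpoint is looked up at call time, presumably so the VCR test suite can patch it after import. A minimal sketch of that pattern (the localhost URL and test are illustrative, not from this PR):

```python
# Sketch: a module-level attribute lookup is patchable, a from-import copy is not.
from huggingface_hub import constants


def status_url(model: str) -> str:
    # Read at call time, so tests can redirect requests to a recorded endpoint.
    return f"{constants.INFERENCE_ENDPOINT}/status/{model}"


def test_status_url(monkeypatch):
    monkeypatch.setattr(constants, "INFERENCE_ENDPOINT", "http://localhost:8080")
    assert status_url("gpt2") == "http://localhost:8080/status/gpt2"
```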
12 changes: 6 additions & 6 deletions src/huggingface_hub/inference/_generated/_async_client.py
@@ -25,7 +25,7 @@
import warnings
from typing import TYPE_CHECKING, Any, AsyncIterable, Dict, List, Literal, Optional, Set, Union, overload

from huggingface_hub.constants import ALL_INFERENCE_API_FRAMEWORKS, INFERENCE_ENDPOINT, MAIN_INFERENCE_API_FRAMEWORKS
from huggingface_hub import constants
from huggingface_hub.errors import InferenceTimeoutError
from huggingface_hub.inference._common import (
TASKS_EXPECTING_IMAGES,
@@ -3365,9 +3365,9 @@ async def list_deployed_models(

# Resolve which frameworks to check
if frameworks is None:
frameworks = MAIN_INFERENCE_API_FRAMEWORKS
frameworks = constants.MAIN_INFERENCE_API_FRAMEWORKS
elif frameworks == "all":
frameworks = ALL_INFERENCE_API_FRAMEWORKS
frameworks = constants.ALL_INFERENCE_API_FRAMEWORKS
elif isinstance(frameworks, str):
frameworks = [frameworks]
frameworks = list(set(frameworks))
@@ -3387,7 +3387,7 @@ def _unpack_response(framework: str, items: List[Dict]) -> None:

for framework in frameworks:
response = get_session().get(
f"{INFERENCE_ENDPOINT}/framework/{framework}", headers=build_hf_headers(token=self.token)
f"{constants.INFERENCE_ENDPOINT}/framework/{framework}", headers=build_hf_headers(token=self.token)
)
hf_raise_for_status(response)
_unpack_response(framework, response.json())
@@ -3491,7 +3491,7 @@ async def get_endpoint_info(self, *, model: Optional[str] = None) -> Dict[str, Any]:
if model.startswith(("http://", "https://")):
url = model.rstrip("/") + "/info"
else:
url = f"{INFERENCE_ENDPOINT}/models/{model}/info"
url = f"{constants.INFERENCE_ENDPOINT}/models/{model}/info"

async with self._get_client_session(headers=build_hf_headers(token=self.token)) as client:
response = await client.get(url, proxy=self.proxies)
@@ -3583,7 +3583,7 @@ async def get_model_status(self, model: Optional[str] = None) -> ModelStatus:
raise ValueError("Model id not provided.")
if model.startswith("https://"):
raise NotImplementedError("Model status is only available for Inference API endpoints.")
url = f"{INFERENCE_ENDPOINT}/status/{model}"
url = f"{constants.INFERENCE_ENDPOINT}/status/{model}"

async with self._get_client_session(headers=build_hf_headers(token=self.token)) as client:
response = await client.get(url, proxy=self.proxies)
4 changes: 2 additions & 2 deletions src/huggingface_hub/inference/_providers/__init__.py
@@ -46,9 +46,9 @@
"image-classification": HFInferenceBinaryInputTask("image-classification"),
"image-segmentation": HFInferenceBinaryInputTask("image-segmentation"),
"document-question-answering": HFInferenceTask("document-question-answering"),
"image-to-text": HFInferenceTask("image-to-text"),
"image-to-text": HFInferenceBinaryInputTask("image-to-text"),
"object-detection": HFInferenceBinaryInputTask("object-detection"),
"audio-to-audio": HFInferenceTask("audio-to-audio"),
"audio-to-audio": HFInferenceBinaryInputTask("audio-to-audio"),
"zero-shot-image-classification": HFInferenceBinaryInputTask("zero-shot-image-classification"),
"zero-shot-classification": HFInferenceTask("zero-shot-classification"),
"image-to-image": HFInferenceBinaryInputTask("image-to-image"),
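`image-to-text` and `audio-to-audio` move to `HFInferenceBinaryInputTask` because these tasks send the input file as raw bytes in the request body rather than a JSON payload. A usage sketch from the client side (the model id and file name are examples, not part of this PR):

```python
# Usage sketch: binary-input tasks post the file content directly as the body.
from huggingface_hub import InferenceClient

client = InferenceClient()
# The local image is read and sent as bytes by HFInferenceBinaryInputTask.
caption = client.image_to_text("cat.jpg", model="Salesforce/blip-image-captioning-base")
print(caption)
```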
17 changes: 9 additions & 8 deletions src/huggingface_hub/inference/_providers/_common.py
@@ -18,6 +18,7 @@
# Example:
# "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
"fal-ai": {},
"fireworks-ai": {},
"hf-inference": {},
"replicate": {},
"sambanova": {},
@@ -65,12 +66,12 @@ def prepare_request(
url = self._prepare_url(api_key, mapped_model)

# prepare payload (to customize in subclasses)
payload = self._prepare_payload(inputs, parameters, mapped_model=mapped_model)
payload = self._prepare_payload_as_dict(inputs, parameters, mapped_model=mapped_model)
if payload is not None:
payload = recursive_merge(payload, extra_payload or {})

# body data (to customize in subclasses)
data = self._prepare_body(inputs, parameters, mapped_model, extra_payload)
data = self._prepare_payload_as_bytes(inputs, parameters, mapped_model, extra_payload)

# check if both payload and data are set and return
if payload is not None and data is not None:
@@ -159,21 +160,21 @@ def _prepare_route(self, mapped_model: str) -> str:
"""
return ""

def _prepare_payload(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
"""Return the payload to use for the request, as a dict.

Override this method in subclasses for customized payloads.
Only one of `_prepare_payload` and `_prepare_body` should return a value.
Only one of `_prepare_payload_as_dict` and `_prepare_payload_as_bytes` should return a value.
"""
return None

def _prepare_body(
def _prepare_payload_as_bytes(
self, inputs: Any, parameters: Dict, mapped_model: str, extra_payload: Optional[Dict]
) -> Optional[bytes]:
"""Return the body to use for the request, as bytes.

Override this method in subclasses for customized body data.
Only one of `_prepare_payload` and `_prepare_body` should return a value.
Only one of `_prepare_payload_as_dict` and `_prepare_payload_as_bytes` should return a value.
"""
return None
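The renamed hooks make the contract explicit: a subclass returns either a JSON-serializable dict from `_prepare_payload_as_dict` or raw bytes from `_prepare_payload_as_bytes`, never both. A self-contained paraphrase of the dispatch enforced by `prepare_request` (not the library's exact code):

```python
# Paraphrase of the dict-vs-bytes contract; names and error message are illustrative.
from typing import Any, Dict, Optional


def dispatch(payload: Optional[Dict[str, Any]], data: Optional[bytes]) -> Dict[str, Any]:
    """Exactly one hook may return a value: the dict becomes the JSON body,
    the bytes become the raw body of the HTTP request."""
    if payload is not None and data is not None:
        raise ValueError("Both payload and body data are set; a task must define only one.")
    return {"json": payload} if payload is not None else {"data": data}


print(dispatch({"inputs": "hello"}, None))  # {'json': {'inputs': 'hello'}}
print(dispatch(None, b"\x89PNG..."))        # {'data': b'\x89PNG...'}
```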

@@ -183,9 +184,9 @@ def _fetch_inference_provider_mapping(model: str) -> Dict:
"""
Fetch provider mappings for a model from the Hub.
"""
from huggingface_hub.hf_api import model_info
from huggingface_hub.hf_api import HfApi

info = model_info(model, expand=["inferenceProviderMapping"])
info = HfApi().model_info(model, expand=["inferenceProviderMapping"])
provider_mapping = info.inference_provider_mapping
if provider_mapping is None:
raise ValueError(f"No provider mapping found for model {model}")
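The mapping fetch now goes through `HfApi().model_info(...)` with `expand=["inferenceProviderMapping"]`. A hedged usage sketch of the same call outside the helper (the model id is only an example; whether a mapping exists depends on the Hub):

```python
# Sketch: inspect a model's provider mapping directly.
from huggingface_hub import HfApi

info = HfApi().model_info("deepseek-ai/DeepSeek-R1", expand=["inferenceProviderMapping"])
print(info.inference_provider_mapping)  # None if no provider serves this model
```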
8 changes: 4 additions & 4 deletions src/huggingface_hub/inference/_providers/fal_ai.py
@@ -25,7 +25,7 @@ class FalAIAutomaticSpeechRecognitionTask(FalAITask):
def __init__(self):
super().__init__("automatic-speech-recognition")

def _prepare_payload(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
if isinstance(inputs, str) and inputs.startswith(("http://", "https://")):
# If input is a URL, pass it directly
audio_url = inputs
@@ -52,7 +52,7 @@ class FalAITextToImageTask(FalAITask):
def __init__(self):
super().__init__("text-to-image")

def _prepare_payload(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
parameters = filter_none(parameters)
if "width" in parameters and "height" in parameters:
parameters["image_size"] = {
@@ -70,7 +70,7 @@ class FalAITextToSpeechTask(FalAITask):
def __init__(self):
super().__init__("text-to-speech")

def _prepare_payload(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
return {"lyrics": inputs, **filter_none(parameters)}

def get_response(self, response: Union[bytes, Dict]) -> Any:
@@ -82,7 +82,7 @@ class FalAITextToVideoTask(FalAITask):
def __init__(self):
super().__init__("text-to-video")

def _prepare_payload(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
return {"prompt": inputs, **filter_none(parameters)}

def get_response(self, response: Union[bytes, Dict]) -> Any:
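For fal.ai text-to-image, the user-facing `width`/`height` parameters are folded into the provider's `image_size` field by `_prepare_payload_as_dict`. A client-side usage sketch (provider routing and prompt are illustrative; a fal.ai-enabled token is assumed):

```python
# Usage sketch: width/height are regrouped into fal.ai's image_size by the helper above.
from huggingface_hub import InferenceClient

client = InferenceClient(provider="fal-ai")
image = client.text_to_image("a watercolor fox", width=512, height=512)
image.save("fox.png")  # text_to_image returns a PIL image
```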
2 changes: 1 addition & 1 deletion src/huggingface_hub/inference/_providers/fireworks_ai.py
@@ -10,5 +10,5 @@ def __init__(self):
def _prepare_route(self, mapped_model: str) -> str:
return "/v1/chat/completions"

def _prepare_payload(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
return {"messages": inputs, **filter_none(parameters), "model": mapped_model}
9 changes: 6 additions & 3 deletions src/huggingface_hub/inference/_providers/hf_inference.py
@@ -46,7 +46,7 @@ def _prepare_url(self, api_key: str, mapped_model: str) -> str:
else f"{self.base_url}/models/{mapped_model}"
)

def _prepare_payload(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
if isinstance(inputs, bytes):
raise ValueError(f"Unexpected binary input for task {self.task}.")
if isinstance(inputs, Path):
@@ -55,7 +55,10 @@ def _prepare_payload(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:


class HFInferenceBinaryInputTask(HFInferenceTask):
def _prepare_body(
def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
return None

def _prepare_payload_as_bytes(
self, inputs: Any, parameters: Dict, mapped_model: str, extra_payload: Optional[Dict]
) -> Optional[bytes]:
parameters = filter_none({k: v for k, v in parameters.items() if v is not None})
@@ -80,7 +83,7 @@ class HFInferenceConversational(HFInferenceTask):
def __init__(self):
super().__init__("text-generation")

def _prepare_payload(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
payload_model = "tgi" if mapped_model.startswith(("http://", "https://")) else mapped_model
return {**filter_none(parameters), "model": payload_model, "messages": inputs}

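`HFInferenceConversational` keeps sending an OpenAI-style `messages` payload and substitutes `"tgi"` as the model name whenever the target is a raw URL. A usage sketch against a dedicated endpoint (the endpoint URL is a placeholder):

```python
# Usage sketch: with a URL-based model, the payload's "model" field becomes "tgi".
from huggingface_hub import InferenceClient

client = InferenceClient(model="https://my-endpoint.example.com")  # placeholder URL
out = client.chat_completion(
    messages=[{"role": "user", "content": "Say hello in one word."}],
    max_tokens=16,
)
print(out.choices[0].message.content)
```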
15 changes: 8 additions & 7 deletions src/huggingface_hub/inference/_providers/new_provider.md
@@ -6,7 +6,7 @@ Before adding a new provider to the `huggingface_hub` library, make sure it has

Create a new file under `src/huggingface_hub/inference/_providers/{provider_name}.py` and copy-paste the following snippet.

Implement the methods that require custom handling. Check out the base implementation to check default behavior. If you don't need to override a method, just remove it. At least one of `_prepare_payload` or `_prepare_body` must be overwritten.
Implement the methods that require custom handling. Check out the base implementation to check default behavior. If you don't need to override a method, just remove it. At least one of `_prepare_payload_as_dict` or `_prepare_payload_as_bytes` must be overwritten.

If the provider supports multiple tasks that require different implementations, create dedicated subclasses for each task, following the pattern shown in `fal_ai.py`.

@@ -42,23 +42,24 @@ class MyNewProviderTaskProviderHelper(TaskProviderHelper):
"""
return super()._prepare_route(mapped_model)

def _prepare_payload(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
"""Return the payload to use for the request, as a dict.

Override this method in subclasses for customized payloads.
Only one of `_prepare_payload` and `_prepare_body` should return a value.
Only one of `_prepare_payload_as_dict` and `_prepare_payload_as_bytes` should return a value.
"""
return super()._prepare_payload(inputs, parameters, mapped_model)
return super()._prepare_payload_as_dict(inputs, parameters, mapped_model)

def _prepare_body(
def _prepare_payload_as_bytes(
self, inputs: Any, parameters: Dict, mapped_model: str, extra_payload: Optional[Dict]
) -> Optional[bytes]:
"""Return the body to use for the request, as bytes.

Override this method in subclasses for customized body data.
Only one of `_prepare_payload` and `_prepare_body` should return a value.
Only one of `_prepare_payload_as_dict` and `_prepare_payload_as_bytes` should return a value.
"""
return super()._prepare_body(inputs, parameters, mapped_model, extra_payload)
return super()._prepare_payload_as_bytes(inputs, parameters, mapped_model, extra_payload)

```

### 2. Register the provider helper in `__init__.py`
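The registration step itself is truncated in this view. As a purely hypothetical sketch of what it typically looks like (the `PROVIDERS` name, task key, and module name are assumptions, not taken from this PR; mirror whatever mapping `_providers/__init__.py` actually defines):

```python
# Hypothetical registration sketch for _providers/__init__.py.
from .my_new_provider import MyNewProviderTaskProviderHelper  # assumed module name

PROVIDERS = {
    # ...existing providers ("fal-ai", "hf-inference", "replicate", ...)...
    "my-new-provider": {
        "conversational": MyNewProviderTaskProviderHelper(),
    },
}
```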
6 changes: 3 additions & 3 deletions src/huggingface_hub/inference/_providers/replicate.py
@@ -19,7 +19,7 @@ def _prepare_route(self, mapped_model: str) -> str:
return "/v1/predictions"
return f"/v1/models/{mapped_model}/predictions"

def _prepare_payload(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
payload: Dict[str, Any] = {"input": {"prompt": inputs, **filter_none(parameters)}}
if ":" in mapped_model:
version = mapped_model.split(":", 1)[1]
@@ -43,7 +43,7 @@ class ReplicateTextToSpeechTask(ReplicateTask):
def __init__(self):
super().__init__("text-to-speech")

def _prepare_payload(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
payload: Dict = super()._prepare_payload(inputs, parameters, mapped_model) # type: ignore[assignment]
def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
payload: Dict = super()._prepare_payload_as_dict(inputs, parameters, mapped_model) # type: ignore[assignment]
payload["input"]["text"] = payload["input"].pop("prompt") # rename "prompt" to "text" for TTS
return payload
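Replicate payloads wrap everything under an `input` key, and the text-to-speech subclass renames `prompt` to `text`. Simplified, illustrative payload shapes grounded in the two methods above:

```python
# Simplified shapes of the payloads produced by the Replicate helpers above.
text_to_image_payload = {"input": {"prompt": "a sunset over the sea"}}
text_to_speech_payload = {"input": {"text": "Hello world"}}  # "prompt" renamed to "text"
print(text_to_image_payload, text_to_speech_payload)
```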
2 changes: 1 addition & 1 deletion src/huggingface_hub/inference/_providers/sambanova.py
@@ -10,5 +10,5 @@ def __init__(self):
def _prepare_route(self, mapped_model: str) -> str:
return "/v1/chat/completions"

def _prepare_payload(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
return {"messages": inputs, **filter_none(parameters), "model": mapped_model}
4 changes: 2 additions & 2 deletions src/huggingface_hub/inference/_providers/together.py
@@ -24,15 +24,15 @@ def _prepare_route(self, mapped_model: str) -> str:

class TogetherTextGenerationTask(TogetherTask):
# Handle both "text-generation" and "conversational"
def _prepare_payload(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
return {"messages": inputs, **filter_none(parameters), "model": mapped_model}


class TogetherTextToImageTask(TogetherTask):
def __init__(self):
super().__init__("text-to-image")

def _prepare_payload(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
parameters = filter_none(parameters)
if "num_inference_steps" in parameters:
parameters["steps"] = parameters.pop("num_inference_steps")
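Together's text-to-image task renames `num_inference_steps` to `steps` before sending. An illustration of that rename in isolation (the extra `guidance_scale` key is only there to show untouched parameters pass through):

```python
# Standalone illustration of the parameter rename done by TogetherTextToImageTask.
parameters = {"num_inference_steps": 20, "guidance_scale": 7.0}
parameters["steps"] = parameters.pop("num_inference_steps")
print(parameters)  # {'guidance_scale': 7.0, 'steps': 20}
```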

Large diffs are not rendered by default.
