
Commit 2c56667
fix imports and typing
hanouticelina committed Feb 5, 2025
1 parent dbda294 · commit 2c56667
Showing 9 changed files with 26 additions and 17 deletions.
5 changes: 5 additions & 0 deletions src/huggingface_hub/hf_api.py
@@ -728,6 +728,8 @@ class ModelInfo:
             Status of the model on the inference API.
             Warm models are available for immediate use. Cold models will be loaded on first inference call.
             Frozen models are not available in Inference API.
+        inference_provider_mapping (`Dict`, *optional*):
+            Model's inference provider mapping.
         likes (`int`):
             Number of likes of the model.
         library_name (`str`, *optional*):
@@ -759,6 +761,7 @@ class ModelInfo:
             Model's safetensors information.
         security_repo_status (`Dict`, *optional*):
             Model's security scan status.
     """

id: str
@@ -773,6 +776,7 @@
     gated: Optional[Literal["auto", "manual", False]]
     gguf: Optional[Dict]
     inference: Optional[Literal["warm", "cold", "frozen"]]
+    inference_provider_mapping: Optional[Dict]
     likes: Optional[int]
     library_name: Optional[str]
     tags: Optional[List[str]]
@@ -806,6 +810,7 @@ def __init__(self, **kwargs):
         self.library_name = kwargs.pop("library_name", None)
         self.gguf = kwargs.pop("gguf", None)
         self.inference = kwargs.pop("inference", None)
+        self.inference_provider_mapping = kwargs.pop("inferenceProviderMapping", None)
         self.tags = kwargs.pop("tags", None)
         self.pipeline_tag = kwargs.pop("pipeline_tag", None)
         self.mask_token = kwargs.pop("mask_token", None)
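Note: with this change the provider mapping can be read straight off `ModelInfo`. A minimal sketch of the new field in use (the model id below is a placeholder, not from this commit):

    from huggingface_hub import model_info

    # Request the expanded field added in this commit; the attribute is None if
    # the Hub returns no mapping for the model.
    info = model_info("some-org/some-model", expand=["inferenceProviderMapping"])
    print(info.inference_provider_mapping)  # mapping of providers to provider-specific model info, or None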
3 changes: 2 additions & 1 deletion src/huggingface_hub/inference/_client.py
@@ -263,7 +263,8 @@ def post(
"`InferenceClient.post` is deprecated and should not be used directly anymore."
)
provider_helper = HFInferenceTask(task or "unknown")
url = provider_helper.build_url(provider_helper.map_model(model))
mapped_model = provider_helper.map_model(model, provider="hf-inference", task=task)
url = provider_helper.build_url(mapped_model)
headers = provider_helper.prepare_headers(headers=self.headers, api_key=self.token)
return self._inner_post(
request_parameters=RequestParameters(
20 changes: 11 additions & 9 deletions src/huggingface_hub/inference/_common.py
@@ -49,7 +49,6 @@
     UnknownError,
     ValidationError,
 )
-from huggingface_hub.hf_api import HfApi

 from ..utils import (
     get_session,
@@ -64,6 +63,7 @@
     from aiohttp import ClientResponse, ClientSession
     from PIL.Image import Image

+from ..hf_api import model_info
 # TYPES
 UrlT = str
 PathT = Union[str, Path]
@@ -104,11 +104,13 @@ def prepare_request(
     @abstractmethod
     def get_response(self, response: Union[bytes, Dict]) -> Any: ...

-    def map_model(self, model: Optional[str], chat_completion: bool = False) -> str:
+    def map_model(
+        self, model: Optional[str], provider: str, task: Optional[str] = None, conversational: bool = False
+    ) -> str:
         """Default implementation for mapping model IDs to provider-specific IDs."""
         if model is None:
-            raise ValueError(f"Please provide a model available on {self.provider}.")
-        return _get_provider_model_id(model, self.provider, self.task, chat_completion)
+            raise ValueError(f"Please provide a model available on {provider}.")
+        return _get_provider_model_id(model, provider, task, conversational)


#### Fetching Inference Providers model mapping
@@ -119,11 +121,11 @@ def _fetch_provider_mappings(model: str) -> Dict:
"""
Fetch provider mappings for a model from the Hub.
"""
try:
info = HfApi().model_info(model, expand=["inferenceProviderMapping"])
return info.get("inferenceProviderMapping", {})
except Exception as e:
raise ValueError(f"Failed to get provider mapping for model {model}: {e}")
info = model_info(model, expand=["inferenceProviderMapping"])
provider_mapping = info.inference_provider_mapping
if provider_mapping is None:
raise ValueError(f"No provider mapping found for model {model}")
return provider_mapping


def _get_provider_model_id(
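Note: the typing fix here is that `model_info` returns a `ModelInfo` object, not a dict, so the old `info.get("inferenceProviderMapping", {})` could never work; the rewrite uses the `inference_provider_mapping` attribute added above. The `map_model` signature change also ripples through every provider file below: provider and task are now explicit arguments instead of helper state. A sketch of the updated call (values are placeholders; `HFInferenceTask` is the helper shown in this diff):

    # Old (removed):  helper.map_model(model, chat_completion=False)
    # New (added):    provider and task are passed explicitly.
    helper = HFInferenceTask("text-generation")
    mapped = helper.map_model(model, provider="hf-inference", task="text-generation", conversational=False)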
3 changes: 2 additions & 1 deletion src/huggingface_hub/inference/_generated/_async_client.py
@@ -258,7 +258,8 @@ async def post(
"`InferenceClient.post` is deprecated and should not be used directly anymore."
)
provider_helper = HFInferenceTask(task or "unknown")
url = provider_helper.build_url(provider_helper.map_model(model))
mapped_model = provider_helper.map_model(model, provider="hf-inference", task=task)
url = provider_helper.build_url(mapped_model)
headers = provider_helper.prepare_headers(headers=self.headers, api_key=self.token)
return await self._inner_post(
request_parameters=RequestParameters(
2 changes: 1 addition & 1 deletion src/huggingface_hub/inference/_providers/fal_ai.py
@@ -36,7 +36,7 @@ def prepare_request(
            raise ValueError(
                "You must provide an api_key to work with fal.ai API or log in with `huggingface-cli login`."
            )
-        mapped_model = self.map_model(model, conversational=conversational)
+        mapped_model = self.map_model(model, provider="fal-ai", task=self.task, conversational=conversational)
         headers = {
             **build_hf_headers(token=api_key),
             **headers,
4 changes: 2 additions & 2 deletions src/huggingface_hub/inference/_providers/hf_inference.py
@@ -80,7 +80,7 @@ def prepare_request(
     ) -> RequestParameters:
         if extra_payload is None:
             extra_payload = {}
-        mapped_model = self.map_model(model, conversational=conversational)
+        mapped_model = self.map_model(model, provider="hf-inference", task=self.task, conversational=conversational)
         url = self.build_url(mapped_model)
         data, json = self._prepare_payload(
             inputs, parameters=parameters, model=mapped_model, extra_payload=extra_payload
@@ -165,7 +165,7 @@ def prepare_request(
         extra_payload: Optional[Dict[str, Any]] = None,
         conversational: bool = False,
     ) -> RequestParameters:
-        mapped_model = self.map_model(model, conversational=conversational)
+        mapped_model = self.map_model(model, provider="hf-inference", task=self.task, conversational=conversational)
         payload_model = parameters.get("model") or mapped_model

         if payload_model is None or payload_model.startswith(("http://", "https://")):
2 changes: 1 addition & 1 deletion src/huggingface_hub/inference/_providers/replicate.py
@@ -47,7 +47,7 @@ def prepare_request(
            base_url = BASE_URL
            logger.info("Calling Replicate provider directly.")

-        mapped_model = self.map_model(model, conversational=conversational)
+        mapped_model = self.map_model(model, provider="replicate", task=self.task, conversational=conversational)
         url = _build_url(base_url, mapped_model)

         headers = {
2 changes: 1 addition & 1 deletion src/huggingface_hub/inference/_providers/sambanova.py
@@ -43,7 +43,7 @@ def prepare_request(
logger.info("Calling Sambanova provider directly.")
headers = {**build_hf_headers(token=api_key), **headers}

mapped_model = self.map_model(model, conversational=conversational)
mapped_model = self.map_model(model, provider="sambanova", task=self.task, conversational=conversational)
payload = {
"messages": inputs,
**{k: v for k, v in parameters.items() if v is not None},
2 changes: 1 addition & 1 deletion src/huggingface_hub/inference/_providers/together.py
@@ -52,7 +52,7 @@ def prepare_request(
         else:
             base_url = BASE_URL
             logger.info("Calling Together provider directly.")
-        mapped_model = self.map_model(model, conversational=conversational)
+        mapped_model = self.map_model(model, provider="together", task=self.task, conversational=conversational)
         if "model" in parameters:
             parameters["model"] = mapped_model
         payload = self._prepare_payload(inputs, parameters=parameters)
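Taken together, these call-site updates leave the public `InferenceClient` surface unchanged; model-to-provider resolution stays internal. A usage sketch (provider argument as available in this release cycle; the model id is a placeholder):

    from huggingface_hub import InferenceClient

    # map_model() resolves the Hub model id to the provider-specific id internally.
    client = InferenceClient(provider="together")
    out = client.chat_completion(
        messages=[{"role": "user", "content": "Hello!"}],
        model="some-org/some-chat-model",
    )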
