From 2c566671aba3a8f40f2593fe5f98f979a62708c9 Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Wed, 5 Feb 2025 10:59:52 +0100 Subject: [PATCH] fix imports and typing --- src/huggingface_hub/hf_api.py | 5 +++++ src/huggingface_hub/inference/_client.py | 3 ++- src/huggingface_hub/inference/_common.py | 20 ++++++++++--------- .../inference/_generated/_async_client.py | 3 ++- .../inference/_providers/fal_ai.py | 2 +- .../inference/_providers/hf_inference.py | 4 ++-- .../inference/_providers/replicate.py | 2 +- .../inference/_providers/sambanova.py | 2 +- .../inference/_providers/together.py | 2 +- 9 files changed, 26 insertions(+), 17 deletions(-) diff --git a/src/huggingface_hub/hf_api.py b/src/huggingface_hub/hf_api.py index 1d4f92d35b..937da81199 100644 --- a/src/huggingface_hub/hf_api.py +++ b/src/huggingface_hub/hf_api.py @@ -728,6 +728,8 @@ class ModelInfo: Status of the model on the inference API. Warm models are available for immediate use. Cold models will be loaded on first inference call. Frozen models are not available in Inference API. + inference_provider_mapping (`Dict`, *optional*): + Model's inference provider mapping. likes (`int`): Number of likes of the model. library_name (`str`, *optional*): @@ -759,6 +761,7 @@ class ModelInfo: Model's safetensors information. security_repo_status (`Dict`, *optional*): Model's security scan status. + """ id: str @@ -773,6 +776,7 @@ class ModelInfo: gated: Optional[Literal["auto", "manual", False]] gguf: Optional[Dict] inference: Optional[Literal["warm", "cold", "frozen"]] + inference_provider_mapping: Optional[Dict] likes: Optional[int] library_name: Optional[str] tags: Optional[List[str]] @@ -806,6 +810,7 @@ def __init__(self, **kwargs): self.library_name = kwargs.pop("library_name", None) self.gguf = kwargs.pop("gguf", None) self.inference = kwargs.pop("inference", None) + self.inference_provider_mapping = kwargs.pop("inferenceProviderMapping", None) self.tags = kwargs.pop("tags", None) self.pipeline_tag = kwargs.pop("pipeline_tag", None) self.mask_token = kwargs.pop("mask_token", None) diff --git a/src/huggingface_hub/inference/_client.py b/src/huggingface_hub/inference/_client.py index deb5da15a0..72686bde9d 100644 --- a/src/huggingface_hub/inference/_client.py +++ b/src/huggingface_hub/inference/_client.py @@ -263,7 +263,8 @@ def post( "`InferenceClient.post` is deprecated and should not be used directly anymore." ) provider_helper = HFInferenceTask(task or "unknown") - url = provider_helper.build_url(provider_helper.map_model(model)) + mapped_model = provider_helper.map_model(model, provider="hf-inference", task=task) + url = provider_helper.build_url(mapped_model) headers = provider_helper.prepare_headers(headers=self.headers, api_key=self.token) return self._inner_post( request_parameters=RequestParameters( diff --git a/src/huggingface_hub/inference/_common.py b/src/huggingface_hub/inference/_common.py index fec2d93819..21f3643346 100644 --- a/src/huggingface_hub/inference/_common.py +++ b/src/huggingface_hub/inference/_common.py @@ -49,7 +49,6 @@ UnknownError, ValidationError, ) -from huggingface_hub.hf_api import HfApi from ..utils import ( get_session, @@ -64,6 +63,7 @@ from aiohttp import ClientResponse, ClientSession from PIL.Image import Image + from ..hf_api import model_info # TYPES UrlT = str PathT = Union[str, Path] @@ -104,11 +104,13 @@ def prepare_request( @abstractmethod def get_response(self, response: Union[bytes, Dict]) -> Any: ... - def map_model(self, model: Optional[str], chat_completion: bool = False) -> str: + def map_model( + self, model: Optional[str], provider: str, task: Optional[str] = None, conversational: bool = False + ) -> str: """Default implementation for mapping model IDs to provider-specific IDs.""" if model is None: - raise ValueError(f"Please provide a model available on {self.provider}.") - return _get_provider_model_id(model, self.provider, self.task, chat_completion) + raise ValueError(f"Please provide a model available on {provider}.") + return _get_provider_model_id(model, provider, task, conversational) #### Fetching Inference Providers model mapping @@ -119,11 +121,11 @@ def _fetch_provider_mappings(model: str) -> Dict: """ Fetch provider mappings for a model from the Hub. """ - try: - info = HfApi().model_info(model, expand=["inferenceProviderMapping"]) - return info.get("inferenceProviderMapping", {}) - except Exception as e: - raise ValueError(f"Failed to get provider mapping for model {model}: {e}") + info = model_info(model, expand=["inferenceProviderMapping"]) + provider_mapping = info.inference_provider_mapping + if provider_mapping is None: + raise ValueError(f"No provider mapping found for model {model}") + return provider_mapping def _get_provider_model_id( diff --git a/src/huggingface_hub/inference/_generated/_async_client.py b/src/huggingface_hub/inference/_generated/_async_client.py index 31de6ccfce..0bca79efd2 100644 --- a/src/huggingface_hub/inference/_generated/_async_client.py +++ b/src/huggingface_hub/inference/_generated/_async_client.py @@ -258,7 +258,8 @@ async def post( "`InferenceClient.post` is deprecated and should not be used directly anymore." ) provider_helper = HFInferenceTask(task or "unknown") - url = provider_helper.build_url(provider_helper.map_model(model)) + mapped_model = provider_helper.map_model(model, provider="hf-inference", task=task) + url = provider_helper.build_url(mapped_model) headers = provider_helper.prepare_headers(headers=self.headers, api_key=self.token) return await self._inner_post( request_parameters=RequestParameters( diff --git a/src/huggingface_hub/inference/_providers/fal_ai.py b/src/huggingface_hub/inference/_providers/fal_ai.py index ebb1b3a6da..6b2bd7835d 100644 --- a/src/huggingface_hub/inference/_providers/fal_ai.py +++ b/src/huggingface_hub/inference/_providers/fal_ai.py @@ -36,7 +36,7 @@ def prepare_request( raise ValueError( "You must provide an api_key to work with fal.ai API or log in with `huggingface-cli login`." ) - mapped_model = self.map_model(model, conversational=conversational) + mapped_model = self.map_model(model, provider="fal-ai", task=self.task, conversational=conversational) headers = { **build_hf_headers(token=api_key), **headers, diff --git a/src/huggingface_hub/inference/_providers/hf_inference.py b/src/huggingface_hub/inference/_providers/hf_inference.py index 8a8ff84194..da5eb539f9 100644 --- a/src/huggingface_hub/inference/_providers/hf_inference.py +++ b/src/huggingface_hub/inference/_providers/hf_inference.py @@ -80,7 +80,7 @@ def prepare_request( ) -> RequestParameters: if extra_payload is None: extra_payload = {} - mapped_model = self.map_model(model, conversational=conversational) + mapped_model = self.map_model(model, provider="hf-inference", task=self.task, conversational=conversational) url = self.build_url(mapped_model) data, json = self._prepare_payload( inputs, parameters=parameters, model=mapped_model, extra_payload=extra_payload @@ -165,7 +165,7 @@ def prepare_request( extra_payload: Optional[Dict[str, Any]] = None, conversational: bool = False, ) -> RequestParameters: - mapped_model = self.map_model(model, conversational=conversational) + mapped_model = self.map_model(model, provider="hf-inference", task=self.task, conversational=conversational) payload_model = parameters.get("model") or mapped_model if payload_model is None or payload_model.startswith(("http://", "https://")): diff --git a/src/huggingface_hub/inference/_providers/replicate.py b/src/huggingface_hub/inference/_providers/replicate.py index d49fb016bc..7a74003a40 100644 --- a/src/huggingface_hub/inference/_providers/replicate.py +++ b/src/huggingface_hub/inference/_providers/replicate.py @@ -47,7 +47,7 @@ def prepare_request( base_url = BASE_URL logger.info("Calling Replicate provider directly.") - mapped_model = self.map_model(model, conversational=conversational) + mapped_model = self.map_model(model, provider="replicate", task=self.task, conversational=conversational) url = _build_url(base_url, mapped_model) headers = { diff --git a/src/huggingface_hub/inference/_providers/sambanova.py b/src/huggingface_hub/inference/_providers/sambanova.py index 8fe942837d..f1b9ec12e1 100644 --- a/src/huggingface_hub/inference/_providers/sambanova.py +++ b/src/huggingface_hub/inference/_providers/sambanova.py @@ -43,7 +43,7 @@ def prepare_request( logger.info("Calling Sambanova provider directly.") headers = {**build_hf_headers(token=api_key), **headers} - mapped_model = self.map_model(model, conversational=conversational) + mapped_model = self.map_model(model, provider="sambanova", task=self.task, conversational=conversational) payload = { "messages": inputs, **{k: v for k, v in parameters.items() if v is not None}, diff --git a/src/huggingface_hub/inference/_providers/together.py b/src/huggingface_hub/inference/_providers/together.py index 64610574da..1f2f2bb023 100644 --- a/src/huggingface_hub/inference/_providers/together.py +++ b/src/huggingface_hub/inference/_providers/together.py @@ -52,7 +52,7 @@ def prepare_request( else: base_url = BASE_URL logger.info("Calling Together provider directly.") - mapped_model = self.map_model(model, conversational=conversational) + mapped_model = self.map_model(model, provider="together", task=self.task, conversational=conversational) if "model" in parameters: parameters["model"] = mapped_model payload = self._prepare_payload(inputs, parameters=parameters)