small clean up
hanouticelina committed Feb 5, 2025
1 parent cc73ead commit a178f03
Showing 8 changed files with 64 additions and 35 deletions.
3 changes: 1 addition & 2 deletions src/huggingface_hub/inference/_client.py
@@ -263,8 +263,7 @@ def post(
             "`InferenceClient.post` is deprecated and should not be used directly anymore."
         )
         provider_helper = HFInferenceTask(task or "unknown")
-        mapped_model = provider_helper.map_model(model, provider="hf-inference", task=task)
-        url = provider_helper.build_url(mapped_model)
+        url = provider_helper.build_url(provider_helper.map_model(model))
         headers = provider_helper.prepare_headers(headers=self.headers, api_key=self.token)
         return self._inner_post(
             request_parameters=RequestParameters(
19 changes: 2 additions & 17 deletions src/huggingface_hub/inference/_common.py
@@ -107,14 +107,6 @@ def prepare_request(
     @abstractmethod
     def get_response(self, response: Union[bytes, Dict]) -> Any: ...

-    def map_model(
-        self, model: Optional[str], provider: str, task: Optional[str] = None, conversational: bool = False
-    ) -> str:
-        """Default implementation for mapping model IDs to provider-specific IDs."""
-        if model is None:
-            raise ValueError(f"Please provide a model available on {provider}.")
-        return _get_provider_model_id(model, provider, task, conversational)
-

 #### Fetching Inference Providers model mapping
 _PROVIDER_MAPPINGS: Optional[Dict[str, Dict]] = None
@@ -136,17 +128,10 @@ def _fetch_provider_mappings(model: str) -> Dict:
     return provider_mapping


-def _get_provider_model_id(
-    model: str, provider: str, task: Optional[str] = None, chat_completion: bool = False
-) -> str:
+def _get_provider_model_id(model: str, provider: str, task: str, conversational: bool = False) -> str:
     """
     Map a model ID to a provider-specific ID.
     """
-    if provider == "hf-inference":
-        return model
-
-    if task is None:
-        raise ValueError("task must be specified when using a third-party provider")
-
     global _PROVIDER_MAPPINGS
     if _PROVIDER_MAPPINGS is None:
@@ -159,7 +144,7 @@ def _get_provider_model_id(
         raise ValueError(f"Model {model} is not supported by provider {provider}")

     provider_task = provider_mapping.get("task")
-    requested_task = "conversational" if task == "text-generation" and chat_completion else task
+    requested_task = "conversational" if task == "text-generation" and conversational else task

     if provider_task != requested_task:
         raise ValueError(
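
After this change, `_get_provider_model_id` requires an explicit task and uses the renamed `conversational` flag; the hf-inference short-circuit moves out of this function. A standalone sketch of the task-resolution rule the hunk above applies (illustrative only, not part of the commit):

    # Chat-style text-generation requests resolve to the "conversational" task;
    # every other task name is compared as-is against the provider's mapping.
    def resolve_requested_task(task: str, conversational: bool = False) -> str:
        return "conversational" if task == "text-generation" and conversational else task

    assert resolve_requested_task("text-generation", conversational=True) == "conversational"
    assert resolve_requested_task("text-generation") == "text-generation"
    assert resolve_requested_task("text-to-image") == "text-to-image"
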
3 changes: 1 addition & 2 deletions src/huggingface_hub/inference/_generated/_async_client.py
@@ -258,8 +258,7 @@ async def post(
             "`InferenceClient.post` is deprecated and should not be used directly anymore."
         )
         provider_helper = HFInferenceTask(task or "unknown")
-        mapped_model = provider_helper.map_model(model, provider="hf-inference", task=task)
-        url = provider_helper.build_url(mapped_model)
+        url = provider_helper.build_url(provider_helper.map_model(model))
         headers = provider_helper.prepare_headers(headers=self.headers, api_key=self.token)
         return await self._inner_post(
             request_parameters=RequestParameters(
15 changes: 13 additions & 2 deletions src/huggingface_hub/inference/_providers/fal_ai.py
@@ -3,7 +3,7 @@
 from typing import Any, Dict, Optional, Union

 from huggingface_hub import constants
-from huggingface_hub.inference._common import RequestParameters, TaskProviderHelper, _as_dict
+from huggingface_hub.inference._common import RequestParameters, TaskProviderHelper, _as_dict, _get_provider_model_id
 from huggingface_hub.utils import build_hf_headers, get_session, get_token, logging
@@ -36,7 +36,7 @@ def prepare_request(
             raise ValueError(
                 "You must provide an api_key to work with fal.ai API or log in with `huggingface-cli login`."
             )
-        mapped_model = self.map_model(model, provider="fal-ai", task=self.task, conversational=conversational)
+        mapped_model = self.map_model(model=model, task=self.task, conversational=conversational)
         headers = {
             **build_hf_headers(token=api_key),
             **headers,
@@ -62,6 +62,17 @@ def prepare_request(
             headers=headers,
         )

+    def map_model(
+        self,
+        model: Optional[str],
+        task: str,
+        conversational: bool = False,
+    ) -> str:
+        """Default implementation for mapping model HF model IDs to provider model IDs."""
+        if model is None:
+            raise ValueError("Please provide a HF model ID supported by fal.ai.")
+        return _get_provider_model_id(model, "fal-ai", task, conversational)
+
     @abstractmethod
     def _prepare_payload(self, inputs: Any, parameters: Dict[str, Any]) -> Dict[str, Any]: ...
14 changes: 8 additions & 6 deletions src/huggingface_hub/inference/_providers/hf_inference.py
@@ -80,11 +80,7 @@ def prepare_request(
     ) -> RequestParameters:
         if extra_payload is None:
             extra_payload = {}
-        if model is None:
-            model = get_recommended_model(self.task)
-        else:
-            model = self.map_model(model, provider="hf-inference", task=self.task, conversational=conversational)
-
+        model = self.map_model(model=model)
         url = self.build_url(model)
         data, json = self._prepare_payload(inputs, parameters=parameters, model=model, extra_payload=extra_payload)
         headers = self.prepare_headers(headers=headers, api_key=api_key)
@@ -98,6 +94,12 @@ def prepare_request(
             headers=headers,
         )

+    def map_model(
+        self,
+        model: Optional[str],
+    ) -> str:
+        return model if model is not None else get_recommended_model(self.task)
+
     def build_url(self, model: str) -> str:
         # hf-inference provider can handle URLs (e.g. Inference Endpoints or TGI deployment)
         if model.startswith(("http://", "https://")):
@@ -167,7 +169,7 @@ def prepare_request(
         extra_payload: Optional[Dict[str, Any]] = None,
         conversational: bool = False,
     ) -> RequestParameters:
-        mapped_model = self.map_model(model, provider="hf-inference", task=self.task, conversational=conversational)
+        mapped_model = self.map_model(model)
         payload_model = parameters.get("model") or mapped_model

         if payload_model is None or payload_model.startswith(("http://", "https://")):
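
With this change, hf-inference's `map_model` passes a given model ID (or URL) through unchanged and only falls back to a recommended model when none is provided. A minimal sketch of that behavior, with `get_recommended_model` stubbed out since only its call site appears in this diff:

    from typing import Optional

    def get_recommended_model(task: str) -> str:
        # Stub standing in for huggingface_hub's recommended-model lookup.
        return f"recommended-model-for-{task}"

    def map_model(model: Optional[str], task: str) -> str:
        # Explicit model IDs and URLs pass through; otherwise fall back per task.
        return model if model is not None else get_recommended_model(task)

    print(map_model(None, "text-classification"))                               # recommended fallback
    print(map_model("https://my-tgi-endpoint.example", "text-classification"))  # passed through
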
15 changes: 13 additions & 2 deletions src/huggingface_hub/inference/_providers/replicate.py
@@ -1,7 +1,7 @@
 from typing import Any, Dict, Optional, Union

 from huggingface_hub import constants
-from huggingface_hub.inference._common import RequestParameters, TaskProviderHelper, _as_dict
+from huggingface_hub.inference._common import RequestParameters, TaskProviderHelper, _as_dict, _get_provider_model_id
 from huggingface_hub.utils import build_hf_headers, get_session, get_token, logging
@@ -47,7 +47,7 @@ def prepare_request(
             base_url = BASE_URL
             logger.info("Calling Replicate provider directly.")

-        mapped_model = self.map_model(model, provider="replicate", task=self.task, conversational=conversational)
+        mapped_model = self.map_model(model=model, task=self.task, conversational=conversational)
         url = _build_url(base_url, mapped_model)

         headers = {
@@ -67,6 +67,17 @@ def prepare_request(
             headers=headers,
         )

+    def map_model(
+        self,
+        model: Optional[str],
+        task: str,
+        conversational: bool = False,
+    ) -> str:
+        """Default implementation for mapping model HF model IDs to provider model IDs."""
+        if model is None:
+            raise ValueError("Please provide a HF model ID supported by Replicate.")
+        return _get_provider_model_id(model, "replicate", task, conversational)
+
     def _prepare_payload(
         self,
         inputs: Any,
15 changes: 13 additions & 2 deletions src/huggingface_hub/inference/_providers/sambanova.py
@@ -1,7 +1,7 @@
 from typing import Any, Dict, Optional, Union

 from huggingface_hub import constants
-from huggingface_hub.inference._common import RequestParameters, TaskProviderHelper
+from huggingface_hub.inference._common import RequestParameters, TaskProviderHelper, _get_provider_model_id
 from huggingface_hub.utils import build_hf_headers, get_token, logging
@@ -43,7 +43,7 @@ def prepare_request(
         logger.info("Calling Sambanova provider directly.")
         headers = {**build_hf_headers(token=api_key), **headers}

-        mapped_model = self.map_model(model, provider="sambanova", task=self.task, conversational=conversational)
+        mapped_model = self.map_model(model=model, task=self.task, conversational=conversational)
         payload = {
             "messages": inputs,
             **{k: v for k, v in parameters.items() if v is not None},
@@ -59,5 +59,16 @@ def prepare_request(
             headers=headers,
         )

+    def map_model(
+        self,
+        model: Optional[str],
+        task: str,
+        conversational: bool = False,
+    ) -> str:
+        """Default implementation for mapping model HF model IDs to provider model IDs."""
+        if model is None:
+            raise ValueError("Please provide a HF model ID supported by Sambanova.")
+        return _get_provider_model_id(model, "sambanova", task, conversational)
+
     def get_response(self, response: Union[bytes, Dict]) -> Any:
         return response
15 changes: 13 additions & 2 deletions src/huggingface_hub/inference/_providers/together.py
@@ -3,7 +3,7 @@
 from typing import Any, Dict, Optional, Union

 from huggingface_hub import constants
-from huggingface_hub.inference._common import RequestParameters, TaskProviderHelper, _as_dict
+from huggingface_hub.inference._common import RequestParameters, TaskProviderHelper, _as_dict, _get_provider_model_id
 from huggingface_hub.utils import build_hf_headers, get_token, logging
@@ -52,7 +52,7 @@ def prepare_request(
         else:
             base_url = BASE_URL
             logger.info("Calling Together provider directly.")
-        mapped_model = self.map_model(model, provider="together", task=self.task, conversational=conversational)
+        mapped_model = self.map_model(model=model, task=self.task, conversational=conversational)
         if "model" in parameters:
             parameters["model"] = mapped_model
         payload = self._prepare_payload(inputs, parameters=parameters)
@@ -66,6 +66,17 @@ def prepare_request(
             headers=headers,
         )

+    def map_model(
+        self,
+        model: Optional[str],
+        task: str,
+        conversational: bool = False,
+    ) -> str:
+        """Default implementation for mapping model HF model IDs to provider model IDs."""
+        if model is None:
+            raise ValueError("Please provide a HF model ID supported by Together.")
+        return _get_provider_model_id(model, "together", task, conversational)
+
     def get_response(self, response: Union[bytes, Dict]) -> Any:
         return response
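
The last four files follow one pattern: each third-party helper now owns a `map_model` that hard-codes its provider name and delegates to `_get_provider_model_id`, so call sites drop the `provider=` argument. A condensed sketch of that pattern (the class name and mapping body are illustrative; only `map_model`'s shape comes from the diffs above):

    from typing import Optional

    class ExampleProviderHelper:
        # Illustrative stand-in for a provider helper such as the Together or Replicate one.
        provider = "example-provider"

        def map_model(self, model: Optional[str], task: str, conversational: bool = False) -> str:
            if model is None:
                raise ValueError(f"Please provide a HF model ID supported by {self.provider}.")
            # The real helpers delegate to _get_provider_model_id(model, provider, task, conversational).
            return f"{self.provider}:{model}"

    helper = ExampleProviderHelper()
    # Call sites no longer pass provider=...; each helper knows its own provider.
    print(helper.map_model("meta-llama/Llama-3.1-8B-Instruct", task="text-generation", conversational=True))
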
