first draft of dynamic mapping
hanouticelina committed Feb 5, 2025
1 parent f2e1889 commit dbda294
Showing 9 changed files with 95 additions and 186 deletions.
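In short, this draft replaces the hardcoded per-provider SUPPORTED_MODELS tables with a model-to-provider mapping fetched dynamically from the Hub. As a quick orientation (an illustrative sketch, not part of the commit), the Hub call the new code relies on looks like this; the shape of the returned mapping ("task" and "providerId" keys) is inferred from how _common.py consumes it:

from huggingface_hub import HfApi

# Ask the Hub for the inference-provider mapping of a single model.
info = HfApi().model_info("black-forest-labs/FLUX.1-dev", expand=["inferenceProviderMapping"])
# Expected shape, e.g.: {"fal-ai": {"task": "text-to-image", "providerId": "fal-ai/flux/dev"}}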
31 changes: 11 additions & 20 deletions src/huggingface_hub/hf_api.py
@@ -132,56 +132,47 @@

ExpandModelProperty_T = Literal[
"author",
"baseModels",
"cardData",
"childrenModelCount",
"config",
"citation",
"createdAt",
"description",
"disabled",
"downloads",
"downloadsAllTime",
"gated",
"gguf",
"inference",
"inferenceProviderMapping",
"lastModified",
"library_name",
"likes",
"mask_token",
"model-index",
"pipeline_tag",
"paperswithcode_id",
"private",
"safetensors",
"resourceGroup",
"sha",
"siblings",
"spaces",
"tags",
"transformersInfo",
"trendingScore",
"widgetData",
"usedStorage",
"resourceGroup",
]

ExpandDatasetProperty_T = Literal[
"author",
"cardData",
"citation",
"createdAt",
"disabled",
"description",
"disabled",
"downloads",
"downloadsAllTime",
"gated",
"lastModified",
"likes",
"paperswithcode_id",
"private",
"siblings",
"resourceGroup",
"sha",
"trendingScore",
"siblings",
"tags",
"trendingScore",
"usedStorage",
"resourceGroup",
]

ExpandSpaceProperty_T = Literal[
@@ -194,15 +185,15 @@
"likes",
"models",
"private",
"resourceGroup",
"runtime",
"sdk",
"siblings",
"sha",
"siblings",
"subdomain",
"tags",
"trendingScore",
"usedStorage",
"resourceGroup",
]

USERNAME_PLACEHOLDER = "hf_user"
1 change: 1 addition & 0 deletions src/huggingface_hub/inference/_client.py
@@ -975,6 +975,7 @@ def chat_completion(
headers=self.headers,
model=model_id_or_url,
api_key=self.token,
conversational=True,
)
data = self._inner_post(request_parameters, stream=stream)

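For context, a sketch of what the new conversational flag does (illustrative only; the model and provider names are hypothetical, and _get_provider_model_id is defined in _common.py below): when chat_completion passes conversational=True, a "text-generation" task is rewritten to "conversational" before matching the provider's declared task.

from huggingface_hub.inference._common import _get_provider_model_id

# Only matches a provider entry whose declared task is "conversational".
provider_id = _get_provider_model_id(
    "meta-llama/Llama-3.1-8B-Instruct",  # hypothetical model
    provider="together",  # hypothetical provider
    task="text-generation",
    chat_completion=True,
)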
57 changes: 57 additions & 0 deletions src/huggingface_hub/inference/_common.py
@@ -49,6 +49,7 @@
UnknownError,
ValidationError,
)
from huggingface_hub.hf_api import HfApi

from ..utils import (
get_session,
@@ -98,10 +99,66 @@ def prepare_request(
model: Optional[str],
api_key: Optional[str],
extra_payload: Optional[Dict[str, Any]] = None,
conversational: bool = False,
) -> RequestParameters: ...
@abstractmethod
def get_response(self, response: Union[bytes, Dict]) -> Any: ...

    def map_model(self, model: Optional[str], conversational: bool = False) -> str:
        """Default implementation for mapping model IDs to provider-specific IDs."""
        if model is None:
            raise ValueError(f"Please provide a model available on {self.provider}.")
        return _get_provider_model_id(model, self.provider, self.task, chat_completion=conversational)


#### Fetching Inference Providers model mapping
# Cache keyed by model ID, so mappings fetched for one model are never reused for another.
_PROVIDER_MAPPINGS: Dict[str, Dict] = {}


def _fetch_provider_mappings(model: str) -> Dict:
    """
    Fetch provider mappings for a model from the Hub.
    """
    try:
        info = HfApi().model_info(model, expand=["inferenceProviderMapping"])
    except Exception as e:
        raise ValueError(f"Failed to get provider mapping for model {model}: {e}") from e
    # `model_info` returns a `ModelInfo` dataclass, not a dict, so read the attribute.
    return getattr(info, "inference_provider_mapping", None) or {}


def _get_provider_model_id(
    model: str, provider: str, task: Optional[str] = None, chat_completion: bool = False
) -> str:
    """
    Map a model ID to a provider-specific ID.
    """
    if provider == "hf-inference":
        return model

    if task is None:
        raise ValueError("task must be specified when using a third-party provider")

    if model not in _PROVIDER_MAPPINGS:
        _PROVIDER_MAPPINGS[model] = _fetch_provider_mappings(model)
    if not _PROVIDER_MAPPINGS[model]:
        logger.warning(f"No provider mappings found for model {model}")

    provider_mapping = _PROVIDER_MAPPINGS[model].get(provider, {})
    if not provider_mapping:
        raise ValueError(f"Model {model} is not supported by provider {provider}.")

    provider_task = provider_mapping.get("task")
    requested_task = "conversational" if task == "text-generation" and chat_completion else task

    if provider_task != requested_task:
        raise ValueError(
            f"Model {model} is not supported for task {requested_task} and provider {provider}. "
            f"Supported task: {provider_task}."
        )

    return provider_mapping.get("providerId", model)


# Add dataclass for ModelStatus. We use this dataclass in get_model_status function.
@dataclass
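End to end, the new flow can be exercised directly (a sketch: it assumes the Hub returns the mapping shown in the comment, and the model/provider pair comes from the table deleted from fal_ai.py below):

from huggingface_hub.inference._common import _get_provider_model_id

# Assuming the Hub returns:
#   {"fal-ai": {"task": "text-to-image", "providerId": "fal-ai/flux/dev"}}
provider_id = _get_provider_model_id(
    "black-forest-labs/FLUX.1-dev", provider="fal-ai", task="text-to-image"
)
print(provider_id)  # -> "fal-ai/flux/dev"

Note that "hf-inference" short-circuits to the original model ID, and an unknown provider raises ValueError.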
1 change: 1 addition & 0 deletions src/huggingface_hub/inference/_generated/_async_client.py
@@ -1015,6 +1015,7 @@ async def chat_completion(
headers=self.headers,
model=model_id_or_url,
api_key=self.token,
conversational=True,
)
data = await self._inner_post(request_parameters, stream=stream)

39 changes: 2 additions & 37 deletions src/huggingface_hub/inference/_providers/fal_ai.py
@@ -12,31 +12,6 @@

BASE_URL = "https://fal.run"

SUPPORTED_MODELS = {
"automatic-speech-recognition": {
"openai/whisper-large-v3": "fal-ai/whisper",
},
"text-to-image": {
"black-forest-labs/FLUX.1-dev": "fal-ai/flux/dev",
"black-forest-labs/FLUX.1-schnell": "fal-ai/flux/schnell",
"ByteDance/SDXL-Lightning": "fal-ai/lightning-models",
"fal/AuraFlow-v0.2": "fal-ai/aura-flow",
"Kwai-Kolors/Kolors": "fal-ai/kolors",
"PixArt-alpha/PixArt-Sigma-XL-2-1024-MS": "fal-ai/pixart-sigma",
"playgroundai/playground-v2.5-1024px-aesthetic": "fal-ai/playground-v25",
"stabilityai/stable-diffusion-3-medium": "fal-ai/stable-diffusion-v3-medium",
"stabilityai/stable-diffusion-3.5-large": "fal-ai/stable-diffusion-v35-large",
"Warlord-K/Sana-1024": "fal-ai/sana",
},
"text-to-speech": {
"m-a-p/YuE-s1-7B-anneal-en-cot": "fal-ai/yue",
},
"text-to-video": {
"genmo/mochi-1-preview": "fal-ai/mochi-v1",
"tencent/HunyuanVideo": "fal-ai/hunyuan-video",
},
}


class FalAITask(TaskProviderHelper, ABC):
"""Base class for FalAI API tasks."""
@@ -53,15 +28,15 @@ def prepare_request(
model: Optional[str],
api_key: Optional[str],
extra_payload: Optional[Dict[str, Any]] = None,
conversational: bool = False,
) -> RequestParameters:
if api_key is None:
api_key = get_token()
if api_key is None:
raise ValueError(
"You must provide an api_key to work with fal.ai API or log in with `huggingface-cli login`."
)

mapped_model = self._map_model(model)
mapped_model = self.map_model(model, conversational=conversational)
headers = {
**build_hf_headers(token=api_key),
**headers,
@@ -87,16 +62,6 @@
headers=headers,
)

def _map_model(self, model: Optional[str]) -> str:
if model is None:
raise ValueError("Please provide a model available on FalAI.")
if self.task not in SUPPORTED_MODELS:
raise ValueError(f"Task {self.task} not supported with FalAI.")
mapped_model = SUPPORTED_MODELS[self.task].get(model)
if mapped_model is None:
raise ValueError(f"Model {model} is not supported with FalAI for task {self.task}.")
return mapped_model

@abstractmethod
def _prepare_payload(self, inputs: Any, parameters: Dict[str, Any]) -> Dict[str, Any]: ...

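The static table above is now resolved from the Hub at request time. One consequence worth illustrating (a sketch; the whisper-to-fal.ai pairing is taken from the deleted table): requesting a model from a provider under the wrong task now raises with the provider's declared task in the message.

from huggingface_hub.inference._common import _get_provider_model_id

try:
    # openai/whisper-large-v3 maps to fal.ai for automatic-speech-recognition,
    # so asking for text-to-image should fail.
    _get_provider_model_id("openai/whisper-large-v3", provider="fal-ai", task="text-to-image")
except ValueError as e:
    print(e)  # ... not supported for task text-to-image ... Supported task: automatic-speech-recognition.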
26 changes: 16 additions & 10 deletions src/huggingface_hub/inference/_providers/hf_inference.py
@@ -2,7 +2,12 @@
from typing import Any, Dict, List, Optional, Tuple, Union

from huggingface_hub.constants import ENDPOINT
from huggingface_hub.inference._common import RequestParameters, TaskProviderHelper, _b64_encode, _open_as_binary
from huggingface_hub.inference._common import (
RequestParameters,
TaskProviderHelper,
_b64_encode,
_open_as_binary,
)
from huggingface_hub.utils import build_hf_headers, get_session, hf_raise_for_status


@@ -71,12 +76,15 @@ def prepare_request(
model: Optional[str],
api_key: Optional[str],
extra_payload: Optional[Dict[str, Any]] = None,
conversational: bool = False,
) -> RequestParameters:
if extra_payload is None:
extra_payload = {}
mapped_model = self.map_model(model)
mapped_model = self.map_model(model, conversational=conversational)
url = self.build_url(mapped_model)
data, json = self._prepare_payload(inputs, parameters=parameters, model=model, extra_payload=extra_payload)
data, json = self._prepare_payload(
inputs, parameters=parameters, model=mapped_model, extra_payload=extra_payload
)
headers = self.prepare_headers(headers=headers, api_key=api_key)

return RequestParameters(
@@ -88,9 +96,6 @@
headers=headers,
)

def map_model(self, model: Optional[str]) -> str:
return model if model is not None else get_recommended_model(self.task)

def build_url(self, model: str) -> str:
# hf-inference provider can handle URLs (e.g. Inference Endpoints or TGI deployment)
if model.startswith(("http://", "https://")):
@@ -158,9 +163,10 @@ def prepare_request(
model: Optional[str],
api_key: Optional[str],
extra_payload: Optional[Dict[str, Any]] = None,
conversational: bool = False,
) -> RequestParameters:
model = self.map_model(model)
payload_model = parameters.get("model") or model
mapped_model = self.map_model(model, conversational=conversational)
payload_model = parameters.get("model") or mapped_model

if payload_model is None or payload_model.startswith(("http://", "https://")):
payload_model = "tgi" # use a random string if not provided
@@ -174,9 +180,9 @@
headers = self.prepare_headers(headers=headers, api_key=api_key)

return RequestParameters(
url=self.build_url(model),
url=self.build_url(mapped_model),
task=self.task,
model=model,
model=mapped_model,
json=json,
data=None,
headers=headers,
34 changes: 3 additions & 31 deletions src/huggingface_hub/inference/_providers/replicate.py
@@ -10,26 +10,6 @@

BASE_URL = "https://api.replicate.com"

SUPPORTED_MODELS = {
"text-to-image": {
"black-forest-labs/FLUX.1-dev": "black-forest-labs/flux-dev",
"black-forest-labs/FLUX.1-schnell": "black-forest-labs/flux-schnell",
"ByteDance/Hyper-SD": "bytedance/hyper-flux-16step:382cf8959fb0f0d665b26e7e80b8d6dc3faaef1510f14ce017e8c732bb3d1eb7",
"ByteDance/SDXL-Lightning": "bytedance/sdxl-lightning-4step:5599ed30703defd1d160a25a63321b4dec97101d98b4674bcc56e41f62f35637",
"playgroundai/playground-v2.5-1024px-aesthetic": "playgroundai/playground-v2.5-1024px-aesthetic:a45f82a1382bed5c7aeb861dac7c7d191b0fdf74d8d57c4a0e6ed7d4d0bf7d24",
"stabilityai/stable-diffusion-3.5-large-turbo": "stability-ai/stable-diffusion-3.5-large-turbo",
"stabilityai/stable-diffusion-3.5-large": "stability-ai/stable-diffusion-3.5-large",
"stabilityai/stable-diffusion-3.5-medium": "stability-ai/stable-diffusion-3.5-medium",
"stabilityai/stable-diffusion-xl-base-1.0": "stability-ai/sdxl:7762fd07cf82c948538e41f63f77d685e02b063e37e496e96eefd46c929f9bdc",
},
"text-to-speech": {
"hexgrad/Kokoro-82M": "jaaari/kokoro-82m:f559560eb822dc509045f3921a1921234918b91739db4bf3daab2169b71c7a13",
},
"text-to-video": {
"genmo/mochi-1-preview": "genmoai/mochi-1:1944af04d098ef69bed7f9d335d102e652203f268ec4aaa2d836f6217217e460",
},
}


def _build_url(base_url: str, model: str) -> str:
if ":" in model:
@@ -50,6 +30,7 @@ def prepare_request(
model: Optional[str],
api_key: Optional[str],
extra_payload: Optional[Dict[str, Any]] = None,
conversational: bool = False,
) -> RequestParameters:
if api_key is None:
api_key = get_token()
@@ -65,7 +46,8 @@
else:
base_url = BASE_URL
logger.info("Calling Replicate provider directly.")
mapped_model = self._map_model(model)

mapped_model = self.map_model(model, conversational=conversational)
url = _build_url(base_url, mapped_model)

headers = {
@@ -85,16 +67,6 @@
headers=headers,
)

def _map_model(self, model: Optional[str]) -> str:
if model is None:
raise ValueError("Please provide a model available on Replicate.")
if self.task not in SUPPORTED_MODELS:
raise ValueError(f"Task {self.task} not supported with Replicate.")
mapped_model = SUPPORTED_MODELS[self.task].get(model)
if mapped_model is None:
raise ValueError(f"Model {model} is not supported with Replicate for task {self.task}.")
return mapped_model

def _prepare_payload(
self,
inputs: Any,