Commit

refacto
Wauplin committed Feb 7, 2025
1 parent 80acc5c commit 5639de0
Showing 7 changed files with 129 additions and 320 deletions.
22 changes: 21 additions & 1 deletion src/huggingface_hub/hf_api.py
@@ -698,6 +698,19 @@ def __init__(self, **kwargs):
self.last_commit = last_commit


@dataclass
class InferenceProviderMapping:
status: Literal["live", "staging"]
provider_id: str
task: str

def __init__(self, **kwargs):
self.status = kwargs.pop("status")
self.provider_id = kwargs.pop("providerId")
self.task = kwargs.pop("task")
self.__dict__.update(**kwargs)


@dataclass
class ModelInfo:
"""
@@ -788,7 +801,7 @@ class ModelInfo:
gated: Optional[Literal["auto", "manual", False]]
gguf: Optional[Dict]
inference: Optional[Literal["warm", "cold", "frozen"]]
inference_provider_mapping: Optional[Dict]
inference_provider_mapping: Optional[Dict[str, InferenceProviderMapping]]
likes: Optional[int]
library_name: Optional[str]
tags: Optional[List[str]]
@@ -821,8 +834,15 @@ def __init__(self, **kwargs):
self.likes = kwargs.pop("likes", None)
self.library_name = kwargs.pop("library_name", None)
self.gguf = kwargs.pop("gguf", None)

self.inference = kwargs.pop("inference", None)
self.inference_provider_mapping = kwargs.pop("inferenceProviderMapping", None)
if self.inference_provider_mapping:
self.inference_provider_mapping = {
provider: InferenceProviderMapping(**value)
for provider, value in self.inference_provider_mapping.items()
}

self.tags = kwargs.pop("tags", None)
self.pipeline_tag = kwargs.pop("pipeline_tag", None)
self.mask_token = kwargs.pop("mask_token", None)
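The hf_api.py change above parses the raw `inferenceProviderMapping` payload into `InferenceProviderMapping` objects keyed by provider name. A minimal sketch of how the new field could be consumed (the model ID is illustrative, reused from the example in the hardcoded mapping further below):

from huggingface_hub import model_info

# Fetch only the provider mapping for a model and list, per provider,
# the provider-side model ID, the supported task and the status ("live" or "staging").
info = model_info("Qwen/Qwen2.5-Coder-32B-Instruct", expand=["inferenceProviderMapping"])
for provider, mapping in (info.inference_provider_mapping or {}).items():
    print(f"{provider}: {mapping.provider_id} ({mapping.task}, {mapping.status})")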
39 changes: 1 addition & 38 deletions src/huggingface_hub/inference/_common.py
@@ -50,12 +50,7 @@
ValidationError,
)

from ..utils import (
get_session,
is_aiohttp_available,
is_numpy_available,
is_pillow_available,
)
from ..utils import get_session, is_aiohttp_available, is_numpy_available, is_pillow_available
from ._generated.types import ChatCompletionStreamOutput, TextGenerationStreamOutput


@@ -104,38 +99,6 @@ def prepare_request(
def get_response(self, response: Union[bytes, Dict]) -> Any: ...


#### Fetching Inference Providers model mapping
_PROVIDER_MAPPINGS: Optional[Dict[str, Dict]] = None


def _fetch_provider_mappings(model: str) -> Dict:
"""
Fetch provider mappings for a model from the Hub.
"""
from ..hf_api import model_info

info = model_info(model, expand=["inferenceProviderMapping"])
provider_mapping = info.inference_provider_mapping
if provider_mapping is None:
raise ValueError(f"No provider mapping found for model {model}")
return provider_mapping


def _get_provider_mapping(model: str, provider: str) -> Dict:
"""
Map a model ID to a provider-specific ID.
"""
global _PROVIDER_MAPPINGS
if _PROVIDER_MAPPINGS is None:
_PROVIDER_MAPPINGS = _fetch_provider_mappings(model)
if not _PROVIDER_MAPPINGS:
logger.warning(f"No provider mappings found for model {model}")

provider_mapping = _PROVIDER_MAPPINGS.get(provider, {})

return provider_mapping


# Add dataclass for ModelStatus. We use this dataclass in get_model_status function.
@dataclass
class ModelStatus:
76 changes: 76 additions & 0 deletions src/huggingface_hub/inference/_providers/_common.py
@@ -0,0 +1,76 @@
import logging
from functools import cache
from typing import Any, Dict, Optional

from huggingface_hub import constants


logger = logging.getLogger(__name__)


# Dev purposes only.
# If you want to try to run inference for a new model locally before it's registered on huggingface.co
# for a given Inference Provider, you can add it to the following dictionary.
HARDCODED_MODEL_ID_MAPPING: Dict[str, Dict[str, str]] = {
# "HF model ID" => "Model ID on Inference Provider's side"
#
# Example:
# "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
"fal-ai": {},
"fireworks-ai": {},
"hf-inference": {},
"replicate": {},
"sambanova": {},
"together": {},
}


def get_base_url(provider: str, base_url: str, api_key: str) -> str:
# Route to the proxy if the api_key is a HF TOKEN
if api_key.startswith("hf_"):
logger.info(f"Calling '{provider}' provider through Hugging Face router.")
return constants.INFERENCE_PROXY_TEMPLATE.format(provider=provider)
else:
logger.info("Calling '{provider}' provider directly.")
return base_url


def get_mapped_model(provider: str, model: Optional[str], task: str) -> str:
if model is None:
raise ValueError(f"Please provide an HF model ID supported by {provider}.")

# hardcoded mapping for local testing
if HARDCODED_MODEL_ID_MAPPING.get(provider, {}).get(model):
return HARDCODED_MODEL_ID_MAPPING[provider][model]

provider_mapping = _fetch_inference_provider_mapping(model).get(provider)
if provider_mapping is None:
raise ValueError(f"Model {model} is not supported by provider {provider}.")

if provider_mapping.task != task:
raise ValueError(
f"Model {model} is not supported for task {task} and provider {provider}. "
f"Supported task: {provider_mapping.task}."
)

if provider_mapping.status == "staging":
logger.warning(f"Model {model} is in staging mode for provider {provider}. Meant for test purposes only.")
return provider_mapping.provider_id


def filter_none(d: Dict[str, Any]) -> Dict[str, Any]:
return {k: v for k, v in d.items() if v is not None}


@cache
def _fetch_inference_provider_mapping(model: str) -> Dict:
"""
Fetch provider mappings for a model from the Hub.
"""
from huggingface_hub.hf_api import model_info

info = model_info(model, expand=["inferenceProviderMapping"])
provider_mapping = info.inference_provider_mapping
if provider_mapping is None:
raise ValueError(f"No provider mapping found for model {model}")
return provider_mapping
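Taken together, the new helpers in _providers/_common.py centralize what each provider module previously re-implemented: get_mapped_model resolves an HF model ID to the provider's own ID (checking the hardcoded dev mapping first, then the cached Hub lookup), get_base_url routes calls through the Hugging Face proxy whenever the API key is an HF token, and filter_none drops unset parameters. A hedged sketch of how a provider helper might combine them (provider, model and token values are illustrative, and get_mapped_model performs a network call to the Hub):

from huggingface_hub.inference._providers._common import filter_none, get_base_url, get_mapped_model

provider = "fal-ai"

# Resolve the HF model ID to the provider-side model ID for the requested task
# (raises ValueError if the model/task pair is not supported by this provider).
mapped_model = get_mapped_model(provider, "black-forest-labs/FLUX.1-dev", task="text-to-image")

# An "hf_" token routes the call through the Hugging Face proxy; any other key goes direct.
base_url = get_base_url(provider, base_url="https://fal.run", api_key="hf_xxx")

# None-valued parameters are stripped before building the payload.
payload = {"prompt": "an astronaut riding a horse", **filter_none({"seed": None, "num_inference_steps": 25})}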
81 changes: 11 additions & 70 deletions src/huggingface_hub/inference/_providers/fal_ai.py
@@ -2,39 +2,16 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Union

from huggingface_hub import constants
from huggingface_hub.inference._common import RequestParameters, TaskProviderHelper, _as_dict, _get_provider_mapping
from huggingface_hub.inference._common import RequestParameters, TaskProviderHelper, _as_dict
from huggingface_hub.utils import build_hf_headers, get_session, get_token, logging

from ._common import filter_none, get_base_url, get_mapped_model


logger = logging.get_logger(__name__)


BASE_URL = "https://fal.run"
SUPPORTED_MODELS = {
"automatic-speech-recognition": {
"openai/whisper-large-v3": "fal-ai/whisper",
},
"text-to-image": {
"black-forest-labs/FLUX.1-dev": "fal-ai/flux/dev",
"black-forest-labs/FLUX.1-schnell": "fal-ai/flux/schnell",
"ByteDance/SDXL-Lightning": "fal-ai/lightning-models",
"fal/AuraFlow-v0.2": "fal-ai/aura-flow",
"Kwai-Kolors/Kolors": "fal-ai/kolors",
"PixArt-alpha/PixArt-Sigma-XL-2-1024-MS": "fal-ai/pixart-sigma",
"playgroundai/playground-v2.5-1024px-aesthetic": "fal-ai/playground-v25",
"stabilityai/stable-diffusion-3-medium": "fal-ai/stable-diffusion-v3-medium",
"stabilityai/stable-diffusion-3.5-large": "fal-ai/stable-diffusion-v35-large",
"Warlord-K/Sana-1024": "fal-ai/sana",
},
"text-to-speech": {
"m-a-p/YuE-s1-7B-anneal-en-cot": "fal-ai/yue",
},
"text-to-video": {
"genmo/mochi-1-preview": "fal-ai/mochi-v1",
"tencent/HunyuanVideo": "fal-ai/hunyuan-video",
},
}


class FalAITask(TaskProviderHelper, ABC):
@@ -59,20 +36,13 @@ def prepare_request(
raise ValueError(
"You must provide an api_key to work with fal.ai API or log in with `huggingface-cli login`."
)
mapped_model = self._map_model(model)
headers = {
**build_hf_headers(token=api_key),
**headers,
}
mapped_model = get_mapped_model("fal-ai", model, self.task)
headers = {**build_hf_headers(token=api_key), **headers}

# Route to the proxy if the api_key is a HF TOKEN
if api_key.startswith("hf_"):
base_url = constants.INFERENCE_PROXY_TEMPLATE.format(provider="fal-ai")
logger.info("Calling fal.ai provider through Hugging Face proxy.")
else:
base_url = BASE_URL
base_url = get_base_url("fai-ai", BASE_URL, api_key)
if not api_key.startswith("hf_"):
headers["authorization"] = f"Key {api_key}"
logger.info("Calling fal.ai provider directly.")

payload = self._prepare_payload(inputs, parameters=parameters)

@@ -85,28 +55,6 @@ def prepare_request(
headers=headers,
)

def _map_model(self, model: Optional[str]) -> str:
if model is None:
raise ValueError("Please provide a HF model ID supported by fal.ai.")
provider_mapping = _get_provider_mapping(model, "fal-ai")
if provider_mapping:
provider_task = provider_mapping.get("task")
status = provider_mapping.get("status")
if provider_task != self.task:
raise ValueError(
f"Model {model} is not supported for task {self.task} and provider fal.ai. "
f"Supported task: {provider_task}."
)
if status == "staging":
logger.warning(f"Model {model} is in staging mode for provider fal.ai. Meant for test purposes only.")
return provider_mapping["providerId"]
if self.task not in SUPPORTED_MODELS:
raise ValueError(f"Task {self.task} not supported with fal.ai.")
mapped_model = SUPPORTED_MODELS[self.task].get(model)
if mapped_model is None:
raise ValueError(f"Model {model} is not supported with fal.ai for task {self.task}.")
return mapped_model

@abstractmethod
def _prepare_payload(self, inputs: Any, parameters: Dict[str, Any]) -> Dict[str, Any]: ...

@@ -129,10 +77,7 @@ def _prepare_payload(self, inputs: Any, parameters: Dict[str, Any]) -> Dict[str,
content_type = "audio/mpeg"
audio_url = f"data:{content_type};base64,{audio_b64}"

return {
"audio_url": audio_url,
**{k: v for k, v in parameters.items() if v is not None},
}
return {"audio_url": audio_url, **filter_none(parameters)}

def get_response(self, response: Union[bytes, Dict]) -> Any:
text = _as_dict(response)["text"]
@@ -146,7 +91,7 @@ def __init__(self):
super().__init__("text-to-image")

def _prepare_payload(self, inputs: Any, parameters: Dict[str, Any]) -> Dict[str, Any]:
parameters = {k: v for k, v in parameters.items() if v is not None}
parameters = filter_none(parameters)
if "width" in parameters and "height" in parameters:
parameters["image_size"] = {
"width": parameters.pop("width"),
@@ -164,10 +109,7 @@ def __init__(self):
super().__init__("text-to-speech")

def _prepare_payload(self, inputs: Any, parameters: Dict[str, Any]) -> Dict[str, Any]:
return {
"lyrics": inputs,
**{k: v for k, v in parameters.items() if v is not None},
}
return {"lyrics": inputs, **filter_none(parameters)}

def get_response(self, response: Union[bytes, Dict]) -> Any:
url = _as_dict(response)["audio"]["url"]
@@ -179,8 +121,7 @@ def __init__(self):
super().__init__("text-to-video")

def _prepare_payload(self, inputs: Any, parameters: Dict[str, Any]) -> Dict[str, Any]:
parameters = {k: v for k, v in parameters.items() if v is not None}
return {"prompt": inputs, **parameters}
return {"prompt": inputs, **filter_none(parameters)}

def get_response(self, response: Union[bytes, Dict]) -> Any:
url = _as_dict(response)["video"]["url"]
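For reference, the text-to-image payload shaping kept in the fal.ai helper folds width/height into the provider's nested "image_size" field after dropping unset parameters. A small self-contained sketch of the same transformation (parameter values are illustrative):

# Equivalent of filter_none followed by the width/height -> image_size rewrite.
parameters = {"width": 1024, "height": 768, "guidance_scale": None}
parameters = {k: v for k, v in parameters.items() if v is not None}
if "width" in parameters and "height" in parameters:
    parameters["image_size"] = {
        "width": parameters.pop("width"),
        "height": parameters.pop("height"),
    }
print(parameters)  # {'image_size': {'width': 1024, 'height': 768}}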
(Diffs for the remaining 3 changed files are not shown.)
