diff --git a/src/huggingface_hub/hf_api.py b/src/huggingface_hub/hf_api.py
index af35e0bdc5..259c876e1a 100644
--- a/src/huggingface_hub/hf_api.py
+++ b/src/huggingface_hub/hf_api.py
@@ -698,6 +698,19 @@ def __init__(self, **kwargs):
         self.last_commit = last_commit


+@dataclass
+class InferenceProviderMapping:
+    status: Literal["live", "staging"]
+    provider_id: str
+    task: str
+
+    def __init__(self, **kwargs):
+        self.status = kwargs.pop("status")
+        self.provider_id = kwargs.pop("providerId")
+        self.task = kwargs.pop("task")
+        self.__dict__.update(**kwargs)
+
+
 @dataclass
 class ModelInfo:
     """
@@ -788,7 +801,7 @@ class ModelInfo:
     gated: Optional[Literal["auto", "manual", False]]
     gguf: Optional[Dict]
     inference: Optional[Literal["warm", "cold", "frozen"]]
-    inference_provider_mapping: Optional[Dict]
+    inference_provider_mapping: Optional[Dict[str, InferenceProviderMapping]]
     likes: Optional[int]
     library_name: Optional[str]
     tags: Optional[List[str]]
@@ -821,8 +834,15 @@ def __init__(self, **kwargs):
         self.likes = kwargs.pop("likes", None)
         self.library_name = kwargs.pop("library_name", None)
         self.gguf = kwargs.pop("gguf", None)
+        self.inference = kwargs.pop("inference", None)
         self.inference_provider_mapping = kwargs.pop("inferenceProviderMapping", None)
+        if self.inference_provider_mapping:
+            self.inference_provider_mapping = {
+                provider: InferenceProviderMapping(**value)
+                for provider, value in self.inference_provider_mapping.items()
+            }
+
         self.tags = kwargs.pop("tags", None)
         self.pipeline_tag = kwargs.pop("pipeline_tag", None)
         self.mask_token = kwargs.pop("mask_token", None)
diff --git a/src/huggingface_hub/inference/_common.py b/src/huggingface_hub/inference/_common.py
index 0c47b6f1d1..a3d869fee9 100644
--- a/src/huggingface_hub/inference/_common.py
+++ b/src/huggingface_hub/inference/_common.py
@@ -50,12 +50,7 @@
     ValidationError,
 )

-from ..utils import (
-    get_session,
-    is_aiohttp_available,
-    is_numpy_available,
-    is_pillow_available,
-)
+from ..utils import get_session, is_aiohttp_available, is_numpy_available, is_pillow_available
 from ._generated.types import ChatCompletionStreamOutput, TextGenerationStreamOutput


@@ -104,38 +99,6 @@ def prepare_request(

     def get_response(self, response: Union[bytes, Dict]) -> Any: ...

-#### Fetching Inference Providers model mapping
-_PROVIDER_MAPPINGS: Optional[Dict[str, Dict]] = None
-
-
-def _fetch_provider_mappings(model: str) -> Dict:
-    """
-    Fetch provider mappings for a model from the Hub.
-    """
-    from ..hf_api import model_info
-
-    info = model_info(model, expand=["inferenceProviderMapping"])
-    provider_mapping = info.inference_provider_mapping
-    if provider_mapping is None:
-        raise ValueError(f"No provider mapping found for model {model}")
-    return provider_mapping
-
-
-def _get_provider_mapping(model: str, provider: str) -> Dict:
-    """
-    Map a model ID to a provider-specific ID.
-    """
-    global _PROVIDER_MAPPINGS
-    if _PROVIDER_MAPPINGS is None:
-        _PROVIDER_MAPPINGS = _fetch_provider_mappings(model)
-    if not _PROVIDER_MAPPINGS:
-        logger.warning(f"No provider mappings found for model {model}")
-
-    provider_mapping = _PROVIDER_MAPPINGS.get(provider, {})
-
-    return provider_mapping
-
-
 # Add dataclass for ModelStatus. We use this dataclass in get_model_status function.
 @dataclass
 class ModelStatus:
diff --git a/src/huggingface_hub/inference/_providers/_common.py b/src/huggingface_hub/inference/_providers/_common.py
new file mode 100644
index 0000000000..5d5d2afcd1
--- /dev/null
+++ b/src/huggingface_hub/inference/_providers/_common.py
@@ -0,0 +1,76 @@
+import logging
+from functools import cache
+from typing import Any, Dict, Optional
+
+from huggingface_hub import constants
+
+
+logger = logging.getLogger(__name__)
+
+
+# Dev purposes only.
+# If you want to try to run inference for a new model locally before it's registered on huggingface.co
+# for a given Inference Provider, you can add it to the following dictionary.
+HARDCODED_MODEL_ID_MAPPING: Dict[str, Dict[str, str]] = {
+    # "HF model ID" => "Model ID on Inference Provider's side"
+    #
+    # Example:
+    # "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
+    "fal-ai": {},
+    "fireworks-ai": {},
+    "hf-inference": {},
+    "replicate": {},
+    "sambanova": {},
+    "together": {},
+}
+
+
+def get_base_url(provider: str, base_url: str, api_key: str) -> str:
+    # Route to the proxy if the api_key is a HF TOKEN
+    if api_key.startswith("hf_"):
+        logger.info(f"Calling '{provider}' provider through Hugging Face router.")
+        return constants.INFERENCE_PROXY_TEMPLATE.format(provider=provider)
+    else:
+        logger.info(f"Calling '{provider}' provider directly.")
+        return base_url
+
+
+def get_mapped_model(provider: str, model: Optional[str], task: str) -> str:
+    if model is None:
+        raise ValueError(f"Please provide an HF model ID supported by {provider}.")
+
+    # hardcoded mapping for local testing
+    if HARDCODED_MODEL_ID_MAPPING.get(provider, {}).get(model):
+        return HARDCODED_MODEL_ID_MAPPING[provider][model]
+
+    provider_mapping = _fetch_inference_provider_mapping(model).get(provider)
+    if provider_mapping is None:
+        raise ValueError(f"Model {model} is not supported by provider {provider}.")
+
+    if provider_mapping.task != task:
+        raise ValueError(
+            f"Model {model} is not supported for task {task} and provider {provider}. "
+            f"Supported task: {provider_mapping.task}."
+        )
+
+    if provider_mapping.status == "staging":
+        logger.warning(f"Model {model} is in staging mode for provider {provider}. Meant for test purposes only.")
+    return provider_mapping.provider_id
+
+
+def filter_none(d: Dict[str, Any]) -> Dict[str, Any]:
+    return {k: v for k, v in d.items() if v is not None}
+
+
+@cache
+def _fetch_inference_provider_mapping(model: str) -> Dict:
+    """
+    Fetch provider mappings for a model from the Hub.
+    """
+    from huggingface_hub.hf_api import model_info
+
+    info = model_info(model, expand=["inferenceProviderMapping"])
+    provider_mapping = info.inference_provider_mapping
+    if provider_mapping is None:
+        raise ValueError(f"No provider mapping found for model {model}")
+    return provider_mapping
diff --git a/src/huggingface_hub/inference/_providers/fal_ai.py b/src/huggingface_hub/inference/_providers/fal_ai.py
index a8e894f42c..aa88b89318 100644
--- a/src/huggingface_hub/inference/_providers/fal_ai.py
+++ b/src/huggingface_hub/inference/_providers/fal_ai.py
@@ -2,39 +2,16 @@
 from abc import ABC, abstractmethod
 from typing import Any, Dict, Optional, Union

-from huggingface_hub import constants
-from huggingface_hub.inference._common import RequestParameters, TaskProviderHelper, _as_dict, _get_provider_mapping
+from huggingface_hub.inference._common import RequestParameters, TaskProviderHelper, _as_dict
 from huggingface_hub.utils import build_hf_headers, get_session, get_token, logging

+from ._common import filter_none, get_base_url, get_mapped_model
+

 logger = logging.get_logger(__name__)

 BASE_URL = "https://fal.run"

-SUPPORTED_MODELS = {
-    "automatic-speech-recognition": {
-        "openai/whisper-large-v3": "fal-ai/whisper",
-    },
-    "text-to-image": {
-        "black-forest-labs/FLUX.1-dev": "fal-ai/flux/dev",
-        "black-forest-labs/FLUX.1-schnell": "fal-ai/flux/schnell",
-        "ByteDance/SDXL-Lightning": "fal-ai/lightning-models",
-        "fal/AuraFlow-v0.2": "fal-ai/aura-flow",
-        "Kwai-Kolors/Kolors": "fal-ai/kolors",
-        "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS": "fal-ai/pixart-sigma",
-        "playgroundai/playground-v2.5-1024px-aesthetic": "fal-ai/playground-v25",
-        "stabilityai/stable-diffusion-3-medium": "fal-ai/stable-diffusion-v3-medium",
-        "stabilityai/stable-diffusion-3.5-large": "fal-ai/stable-diffusion-v35-large",
-        "Warlord-K/Sana-1024": "fal-ai/sana",
-    },
-    "text-to-speech": {
-        "m-a-p/YuE-s1-7B-anneal-en-cot": "fal-ai/yue",
-    },
-    "text-to-video": {
-        "genmo/mochi-1-preview": "fal-ai/mochi-v1",
-        "tencent/HunyuanVideo": "fal-ai/hunyuan-video",
-    },
-}


 class FalAITask(TaskProviderHelper, ABC):
@@ -59,20 +36,13 @@ def prepare_request(
             raise ValueError(
                 "You must provide an api_key to work with fal.ai API or log in with `huggingface-cli login`."
             )
-        mapped_model = self._map_model(model)
-        headers = {
-            **build_hf_headers(token=api_key),
-            **headers,
-        }
+        mapped_model = get_mapped_model("fal-ai", model, self.task)
+        headers = {**build_hf_headers(token=api_key), **headers}

         # Route to the proxy if the api_key is a HF TOKEN
-        if api_key.startswith("hf_"):
-            base_url = constants.INFERENCE_PROXY_TEMPLATE.format(provider="fal-ai")
-            logger.info("Calling fal.ai provider through Hugging Face proxy.")
-        else:
-            base_url = BASE_URL
+        base_url = get_base_url("fal-ai", BASE_URL, api_key)
+        if not api_key.startswith("hf_"):
             headers["authorization"] = f"Key {api_key}"
-            logger.info("Calling fal.ai provider directly.")

         payload = self._prepare_payload(inputs, parameters=parameters)

@@ -85,28 +55,6 @@ def prepare_request(
             headers=headers,
         )

-    def _map_model(self, model: Optional[str]) -> str:
-        if model is None:
-            raise ValueError("Please provide a HF model ID supported by fal.ai.")
-        provider_mapping = _get_provider_mapping(model, "fal-ai")
-        if provider_mapping:
-            provider_task = provider_mapping.get("task")
-            status = provider_mapping.get("status")
-            if provider_task != self.task:
-                raise ValueError(
-                    f"Model {model} is not supported for task {self.task} and provider fal.ai. "
-                    f"Supported task: {provider_task}."
-                )
-            if status == "staging":
-                logger.warning(f"Model {model} is in staging mode for provider fal.ai. Meant for test purposes only.")
-            return provider_mapping["providerId"]
-        if self.task not in SUPPORTED_MODELS:
-            raise ValueError(f"Task {self.task} not supported with fal.ai.")
-        mapped_model = SUPPORTED_MODELS[self.task].get(model)
-        if mapped_model is None:
-            raise ValueError(f"Model {model} is not supported with fal.ai for task {self.task}.")
-        return mapped_model
-
     @abstractmethod
     def _prepare_payload(self, inputs: Any, parameters: Dict[str, Any]) -> Dict[str, Any]: ...

@@ -129,10 +77,7 @@ def _prepare_payload(self, inputs: Any, parameters: Dict[str, Any]) -> Dict[str,
             content_type = "audio/mpeg"
         audio_url = f"data:{content_type};base64,{audio_b64}"

-        return {
-            "audio_url": audio_url,
-            **{k: v for k, v in parameters.items() if v is not None},
-        }
+        return {"audio_url": audio_url, **filter_none(parameters)}

     def get_response(self, response: Union[bytes, Dict]) -> Any:
         text = _as_dict(response)["text"]
@@ -146,7 +91,7 @@ def __init__(self):
         super().__init__("text-to-image")

     def _prepare_payload(self, inputs: Any, parameters: Dict[str, Any]) -> Dict[str, Any]:
-        parameters = {k: v for k, v in parameters.items() if v is not None}
+        parameters = filter_none(parameters)
         if "width" in parameters and "height" in parameters:
             parameters["image_size"] = {
                 "width": parameters.pop("width"),
@@ -164,10 +109,7 @@ def __init__(self):
         super().__init__("text-to-speech")

     def _prepare_payload(self, inputs: Any, parameters: Dict[str, Any]) -> Dict[str, Any]:
-        return {
-            "lyrics": inputs,
-            **{k: v for k, v in parameters.items() if v is not None},
-        }
+        return {"lyrics": inputs, **filter_none(parameters)}

     def get_response(self, response: Union[bytes, Dict]) -> Any:
         url = _as_dict(response)["audio"]["url"]
@@ -179,8 +121,7 @@ def __init__(self):
         super().__init__("text-to-video")

     def _prepare_payload(self, inputs: Any, parameters: Dict[str, Any]) -> Dict[str, Any]:
-        parameters = {k: v for k, v in parameters.items() if v is not None}
-        return {"prompt": inputs, **parameters}
+        return {"prompt": inputs, **filter_none(parameters)}

     def get_response(self, response: Union[bytes, Dict]) -> Any:
         url = _as_dict(response)["video"]["url"]
diff --git a/src/huggingface_hub/inference/_providers/replicate.py b/src/huggingface_hub/inference/_providers/replicate.py
index b8b37fdca6..cc899f906d 100644
--- a/src/huggingface_hub/inference/_providers/replicate.py
+++ b/src/huggingface_hub/inference/_providers/replicate.py
@@ -1,40 +1,20 @@
 from typing import Any, Dict, Optional, Union

-from huggingface_hub import constants
 from huggingface_hub.inference._common import (
     RequestParameters,
     TaskProviderHelper,
     _as_dict,
-    _get_provider_mapping,
 )
 from huggingface_hub.utils import build_hf_headers, get_session, get_token, logging

+from ._common import filter_none, get_base_url, get_mapped_model
+

 logger = logging.get_logger(__name__)

 BASE_URL = "https://api.replicate.com"

-SUPPORTED_MODELS = {
-    "text-to-image": {
-        "black-forest-labs/FLUX.1-dev": "black-forest-labs/flux-dev",
-        "black-forest-labs/FLUX.1-schnell": "black-forest-labs/flux-schnell",
-        "ByteDance/Hyper-SD": "bytedance/hyper-flux-16step:382cf8959fb0f0d665b26e7e80b8d6dc3faaef1510f14ce017e8c732bb3d1eb7",
-        "ByteDance/SDXL-Lightning": "bytedance/sdxl-lightning-4step:5599ed30703defd1d160a25a63321b4dec97101d98b4674bcc56e41f62f35637",
-        "playgroundai/playground-v2.5-1024px-aesthetic": "playgroundai/playground-v2.5-1024px-aesthetic:a45f82a1382bed5c7aeb861dac7c7d191b0fdf74d8d57c4a0e6ed7d4d0bf7d24",
-        "stabilityai/stable-diffusion-3.5-large-turbo": "stability-ai/stable-diffusion-3.5-large-turbo",
-        "stabilityai/stable-diffusion-3.5-large": "stability-ai/stable-diffusion-3.5-large",
-        "stabilityai/stable-diffusion-3.5-medium": "stability-ai/stable-diffusion-3.5-medium",
-        "stabilityai/stable-diffusion-xl-base-1.0": "stability-ai/sdxl:7762fd07cf82c948538e41f63f77d685e02b063e37e496e96eefd46c929f9bdc",
-    },
-    "text-to-speech": {
-        "hexgrad/Kokoro-82M": "jaaari/kokoro-82m:f559560eb822dc509045f3921a1921234918b91739db4bf3daab2169b71c7a13",
-    },
-    "text-to-video": {
-        "genmo/mochi-1-preview": "genmoai/mochi-1:1944af04d098ef69bed7f9d335d102e652203f268ec4aaa2d836f6217217e460",
-    },
-}
-

 def _build_url(base_url: str, model: str) -> str:
     if ":" in model:
@@ -64,14 +44,9 @@ def prepare_request(
         )

         # Route to the proxy if the api_key is a HF TOKEN
-        if api_key.startswith("hf_"):
-            base_url = constants.INFERENCE_PROXY_TEMPLATE.format(provider="replicate")
-            logger.info("Calling Replicate provider through Hugging Face proxy.")
-        else:
-            base_url = BASE_URL
-            logger.info("Calling Replicate provider directly.")
-
-        mapped_model = self._map_model(model)
+        base_url = get_base_url("replicate", BASE_URL, api_key)
+
+        mapped_model = get_mapped_model("replicate", model, self.task)
         url = _build_url(base_url, mapped_model)

         headers = {
@@ -91,31 +66,6 @@ def prepare_request(
             headers=headers,
         )

-    def _map_model(self, model: Optional[str]) -> str:
-        if model is None:
-            raise ValueError("Please provide a HF model ID supported by Replicate.")
-        provider_mapping = _get_provider_mapping(model, "replicate")
-        if provider_mapping:
-            provider_task = provider_mapping.get("task")
-            status = provider_mapping.get("status")
-
-            if provider_task != self.task:
-                raise ValueError(
-                    f"Model {model} is not supported for task {self.task} and provider Replicate. "
-                    f"Supported task: {provider_task}."
-                )
-            if status == "staging":
-                logger.warning(
-                    f"Model {model} is in staging mode for provider Replicate. Meant for test purposes only."
-                )
-            return provider_mapping["providerId"]
-        if self.task not in SUPPORTED_MODELS:
-            raise ValueError(f"Task {self.task} not supported with Replicate.")
-        mapped_model = SUPPORTED_MODELS[self.task].get(model)
-        if mapped_model is None:
-            raise ValueError(f"Model {model} is not supported with Replicate for task {self.task}.")
-        return mapped_model
-
     def _prepare_payload(
         self,
         inputs: Any,
@@ -125,7 +75,7 @@ def _prepare_payload(
         payload: Dict[str, Any] = {
             "input": {
                 "prompt": inputs,
-                **{k: v for k, v in parameters.items() if v is not None},
+                **filter_none(parameters),
             }
         }
         if ":" in model:
@@ -157,12 +107,7 @@ def _prepare_payload(
         model: str,
     ) -> Dict[str, Any]:
         # The following payload might work only for a subset of text-to-speech Replicate models.
-        payload: Dict[str, Any] = {
-            "input": {
-                "text": inputs,
-                **{k: v for k, v in parameters.items() if v is not None},
-            },
-        }
+        payload: Dict[str, Any] = {"input": {"text": inputs, **filter_none(parameters)}}
         if ":" in model:
             version = model.split(":", 1)[1]
             payload["version"] = version
diff --git a/src/huggingface_hub/inference/_providers/sambanova.py b/src/huggingface_hub/inference/_providers/sambanova.py
index 38c7b65980..eabbb15422 100644
--- a/src/huggingface_hub/inference/_providers/sambanova.py
+++ b/src/huggingface_hub/inference/_providers/sambanova.py
@@ -1,36 +1,19 @@
 from typing import Any, Dict, Optional, Union

-from huggingface_hub import constants
 from huggingface_hub.inference._common import (
     RequestParameters,
     TaskProviderHelper,
-    _get_provider_mapping,
 )
 from huggingface_hub.utils import build_hf_headers, get_token, logging

+from ._common import filter_none, get_base_url, get_mapped_model
+

 logger = logging.get_logger(__name__)

 BASE_URL = "https://api.sambanova.ai"

-SUPPORTED_MODELS = {
-    "conversational": {
-        "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
-        "Qwen/Qwen2.5-72B-Instruct": "Qwen2.5-72B-Instruct",
-        "Qwen/QwQ-32B-Preview": "QwQ-32B-Preview",
-        "meta-llama/Llama-3.3-70B-Instruct": "Meta-Llama-3.3-70B-Instruct",
-        "meta-llama/Llama-3.2-1B": "Meta-Llama-3.2-1B-Instruct",
-        "meta-llama/Llama-3.2-3B": "Meta-Llama-3.2-3B-Instruct",
-        "meta-llama/Llama-3.2-11B-Vision-Instruct": "Llama-3.2-11B-Vision-Instruct",
-        "meta-llama/Llama-3.2-90B-Vision-Instruct": "Llama-3.2-90B-Vision-Instruct",
-        "meta-llama/Llama-3.1-8B-Instruct": "Meta-Llama-3.1-8B-Instruct",
-        "meta-llama/Llama-3.1-70B-Instruct": "Meta-Llama-3.1-70B-Instruct",
-        "meta-llama/Llama-3.1-405B-Instruct": "Meta-Llama-3.1-405B-Instruct",
-        "meta-llama/Llama-Guard-3-8B": "Meta-Llama-Guard-3-8B",
-    },
-}


 class SambanovaConversationalTask(TaskProviderHelper):
     def __init__(self):
@@ -55,20 +38,11 @@ def prepare_request(
         )

         # Route to the proxy if the api_key is a HF TOKEN
-        if api_key.startswith("hf_"):
-            base_url = constants.INFERENCE_PROXY_TEMPLATE.format(provider="sambanova")
-            logger.info("Calling Sambanova provider through Hugging Face proxy.")
-        else:
-            base_url = BASE_URL
-            logger.info("Calling Sambanova provider directly.")
+        base_url = get_base_url("sambanova", BASE_URL, api_key)

         headers = {**build_hf_headers(token=api_key), **headers}

-        mapped_model = self._map_model(model)
-        payload = {
-            "messages": inputs,
-            **{k: v for k, v in parameters.items() if v is not None},
-            "model": mapped_model,
-        }
+        mapped_model = get_mapped_model("sambanova", model, self.task)
+        payload = {"messages": inputs, **filter_none(parameters), "model": mapped_model}

         return RequestParameters(
             url=f"{base_url}/v1/chat/completions",
@@ -79,30 +53,5 @@ def prepare_request(
             headers=headers,
         )

-    def _map_model(self, model: Optional[str]) -> str:
-        if model is None:
-            raise ValueError("Please provide a HF model ID supported by Sambanova.")
-        provider_mapping = _get_provider_mapping(model, "sambanova")
-        if provider_mapping:
-            provider_task = provider_mapping.get("task")
-            status = provider_mapping.get("status")
-            if provider_task != self.task:
-                raise ValueError(
-                    f"Model {model} is not supported for task {self.task} and provider Sambanova. "
-                    f"Supported task: {provider_task}."
-                )
-            if status == "staging":
-                logger.warning(
-                    f"Model {model} is in staging mode for provider Sambanova. Meant for test purposes only."
-                )
-            return provider_mapping["providerId"]
-
-        if self.task not in SUPPORTED_MODELS:
-            raise ValueError(f"Task {self.task} not supported with Sambanova.")
-        mapped_model = SUPPORTED_MODELS[self.task].get(model)
-        if mapped_model is None:
-            raise ValueError(f"Model {model} is not supported with Sambanova for task {self.task}.")
-        return mapped_model
-
     def get_response(self, response: Union[bytes, Dict]) -> Any:
         return response
diff --git a/src/huggingface_hub/inference/_providers/together.py b/src/huggingface_hub/inference/_providers/together.py
index 53c6b76315..0af4e3257e 100644
--- a/src/huggingface_hub/inference/_providers/together.py
+++ b/src/huggingface_hub/inference/_providers/together.py
@@ -2,71 +2,21 @@
 from abc import ABC, abstractmethod
 from typing import Any, Dict, Optional, Union

-from huggingface_hub import constants
 from huggingface_hub.inference._common import (
     RequestParameters,
     TaskProviderHelper,
     _as_dict,
-    _get_provider_mapping,
 )
 from huggingface_hub.utils import build_hf_headers, get_token, logging

+from ._common import filter_none, get_base_url, get_mapped_model
+

 logger = logging.get_logger(__name__)

 BASE_URL = "https://api.together.xyz"

-SUPPORTED_MODELS = {
-    "conversational": {
-        "databricks/dbrx-instruct": "databricks/dbrx-instruct",
-        "deepseek-ai/DeepSeek-R1": "deepseek-ai/DeepSeek-R1",
-        "deepseek-ai/deepseek-llm-67b-chat": "deepseek-ai/deepseek-llm-67b-chat",
-        "deepseek-ai/DeepSeek-V3": "deepseek-ai/DeepSeek-V3",
-        "google/gemma-2-9b-it": "google/gemma-2-9b-it",
-        "google/gemma-2b-it": "google/gemma-2-27b-it",
-        "meta-llama/Llama-2-13b-chat-hf": "meta-llama/Llama-2-13b-chat-hf",
-        "meta-llama/Llama-2-7b-chat-hf": "meta-llama/Llama-2-7b-chat-hf",
-        "meta-llama/Llama-3.2-11B-Vision-Instruct": "meta-llama/Llama-Vision-Free",
-        "meta-llama/Llama-3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct-Turbo",
-        "meta-llama/Llama-3.2-90B-Vision-Instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo",
-        "meta-llama/Llama-3.3-70B-Instruct": "meta-llama/Llama-3.3-70B-Instruct-Turbo",
-        "meta-llama/Meta-Llama-3-70B-Instruct": "meta-llama/Llama-3-70b-chat-hf",
-        "meta-llama/Meta-Llama-3-8B-Instruct": "meta-llama/Meta-Llama-3-8B-Instruct-Turbo",
-        "meta-llama/Meta-Llama-3.1-405B-Instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
-        "meta-llama/Meta-Llama-3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
-        "meta-llama/Meta-Llama-3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
-        "microsoft/WizardLM-2-8x22B": "microsoft/WizardLM-2-8x22B",
-        "mistralai/Mistral-7B-Instruct-v0.3": "mistralai/Mistral-7B-Instruct-v0.3",
-        "mistralai/Mistral-Small-24B-Instruct-2501": "mistralai/Mistral-Small-24B-Instruct-2501",
-        "mistralai/Mixtral-8x22B-Instruct-v0.1": "mistralai/Mixtral-8x22B-Instruct-v0.1",
-        "mistralai/Mixtral-8x7B-Instruct-v0.1": "mistralai/Mixtral-8x7B-Instruct-v0.1",
-        "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO": "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
-        "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF": "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF",
-        "Qwen/Qwen2-72B-Instruct": "Qwen/Qwen2-72B-Instruct",
-        "Qwen/Qwen2.5-72B-Instruct": "Qwen/Qwen2.5-72B-Instruct-Turbo",
-        "Qwen/Qwen2.5-7B-Instruct": "Qwen/Qwen2.5-7B-Instruct-Turbo",
-        "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen/Qwen2.5-Coder-32B-Instruct",
-        "Qwen/QwQ-32B-Preview": "Qwen/QwQ-32B-Preview",
-        "scb10x/llama-3-typhoon-v1.5-8b-instruct": "scb10x/scb10x-llama3-typhoon-v1-5-8b-instruct",
-        "scb10x/llama-3-typhoon-v1.5x-70b-instruct-awq": "scb10x/scb10x-llama3-typhoon-v1-5x-4f316",
-    },
-    "text-generation": {
-        "meta-llama/Llama-2-70b-hf": "meta-llama/Llama-2-70b-hf",
-        "meta-llama/Meta-Llama-3-8B": "meta-llama/Meta-Llama-3-8B",
-        "mistralai/Mixtral-8x7B-v0.1": "mistralai/Mixtral-8x7B-v0.1",
-    },
-    "text-to-image": {
-        "black-forest-labs/FLUX.1-Canny-dev": "black-forest-labs/FLUX.1-canny",
-        "black-forest-labs/FLUX.1-Depth-dev": "black-forest-labs/FLUX.1-depth",
-        "black-forest-labs/FLUX.1-dev": "black-forest-labs/FLUX.1-dev",
-        "black-forest-labs/FLUX.1-Redux-dev": "black-forest-labs/FLUX.1-redux",
-        "black-forest-labs/FLUX.1-schnell": "black-forest-labs/FLUX.1-pro",
-        "stabilityai/stable-diffusion-xl-base-1.0": "stabilityai/stable-diffusion-xl-base-1.0",
-    },
-}
-
-
 PER_TASK_ROUTES = {
     "conversational": "v1/chat/completions",
     "text-generation": "v1/completions",
@@ -99,13 +49,9 @@ def prepare_request(
         headers = {**build_hf_headers(token=api_key), **headers}

         # Route to the proxy if the api_key is a HF TOKEN
-        if api_key.startswith("hf_"):
-            base_url = constants.INFERENCE_PROXY_TEMPLATE.format(provider="together")
-            logger.info("Calling Together provider through Hugging Face proxy.")
-        else:
-            base_url = BASE_URL
-            logger.info("Calling Together provider directly.")
-        mapped_model = self._map_model(model)
+        base_url = get_base_url("together", BASE_URL, api_key)
+        mapped_model = get_mapped_model("together", model, self.task)
+
         if "model" in parameters:
             parameters["model"] = mapped_model
         payload = self._prepare_payload(inputs, parameters=parameters)
@@ -119,32 +65,6 @@ def prepare_request(
             headers=headers,
         )

-    def _map_model(self, model: Optional[str]) -> str:
-        if model is None:
-            raise ValueError("Please provide a HF model ID supported by Together.")
-        provider_mapping = _get_provider_mapping(model, "together")
-
-        if provider_mapping:
-            provider_task = provider_mapping.get("task")
-            status = provider_mapping.get("status")
-            if provider_task != self.task:
-                raise ValueError(
-                    f"Model {model} is not supported for task {self.task} and provider Together. "
-                    f"Supported task: {provider_task}."
-                )
-            if status == "staging":
-                logger.warning(
-                    f"Model {model} is in staging mode for provider Together. Meant for test purposes only."
-                )
-            return provider_mapping["providerId"]
-
-        if self.task not in SUPPORTED_MODELS:
-            raise ValueError(f"Task {self.task} not supported with Together.")
-        mapped_model = SUPPORTED_MODELS[self.task].get(model)
-        if mapped_model is None:
-            raise ValueError(f"Model {model} is not supported with Together for task {self.task}.")
-        return mapped_model
-
     def get_response(self, response: Union[bytes, Dict]) -> Any:
         return response

@@ -155,7 +75,7 @@ def _prepare_payload(self, inputs: Any, parameters: Dict[str, Any]) -> Dict[str,
 class TogetherTextGenerationTask(TogetherTask):
     # Handle both "text-generation" and "conversational"
     def _prepare_payload(self, inputs: Any, parameters: Dict[str, Any]) -> Dict[str, Any]:
-        return {"messages": inputs, **{k: v for k, v in parameters.items() if v is not None}}
+        return {"messages": inputs, **filter_none(parameters)}


 class TogetherTextToImageTask(TogetherTask):
@@ -163,18 +83,13 @@ def __init__(self):
         super().__init__("text-to-image")

     def _prepare_payload(self, inputs: Any, parameters: Dict[str, Any]) -> Dict[str, Any]:
-        parameters = {k: v for k, v in parameters.items() if v is not None}
+        parameters = filter_none(parameters)
         if "num_inference_steps" in parameters:
             parameters["steps"] = parameters.pop("num_inference_steps")
         if "guidance_scale" in parameters:
             parameters["guidance"] = parameters.pop("guidance_scale")
-        payload = {
-            "prompt": inputs,
-            "response_format": "base64",
-            **parameters,
-        }
-        return payload
+        return {"prompt": inputs, "response_format": "base64", **parameters}

     def get_response(self, response: Union[bytes, Dict]) -> Any:
         response_dict = _as_dict(response)
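Note for reviewers: the snippet below is a minimal sketch, not part of the diff, of how the new helpers in `_providers/_common.py` are expected to compose once this change lands. The model ID and provider/task pairing are illustrative; whether a given pairing resolves depends entirely on the `inferenceProviderMapping` data the Hub returns for that model.

```python
# Sketch only: exercises the new provider-mapping helpers end to end.
# The model ID below is illustrative; it must expose an "inferenceProviderMapping"
# entry on the Hub for the lookups to succeed.
from huggingface_hub.hf_api import model_info
from huggingface_hub.inference._providers._common import filter_none, get_mapped_model

model_id = "black-forest-labs/FLUX.1-dev"

# 1. The expanded field is now parsed into InferenceProviderMapping objects.
info = model_info(model_id, expand=["inferenceProviderMapping"])
for provider, mapping in (info.inference_provider_mapping or {}).items():
    print(provider, mapping.provider_id, mapping.task, mapping.status)

# 2. Provider helpers resolve the provider-side model ID through the shared helper;
#    a ValueError is raised if the provider or task is not registered for the model.
provider_id = get_mapped_model("fal-ai", model_id, "text-to-image")

# 3. Payloads drop None-valued parameters via filter_none, as the tasks now do.
payload = {"prompt": "a photo of a cat", **filter_none({"width": 512, "height": None})}
print(provider_id, payload)
```

Because `_fetch_inference_provider_mapping` is decorated with `functools.cache`, repeated resolutions for the same model hit the Hub only once per process, while `HARDCODED_MODEL_ID_MAPPING` remains a local-testing escape hatch that bypasses the Hub lookup entirely.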