Skip to content

Commit

Permalink
rename function
Browse files Browse the repository at this point in the history
  • Loading branch information
hanouticelina committed Feb 5, 2025
1 parent d43462a commit 4988db0
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 19 deletions.
7 changes: 2 additions & 5 deletions src/huggingface_hub/inference/_providers/fal_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def prepare_request(
raise ValueError(
"You must provide an api_key to work with fal.ai API or log in with `huggingface-cli login`."
)
mapped_model = self.map_model(model)
mapped_model = self._map_model(model)
headers = {
**build_hf_headers(token=api_key),
**headers,
Expand All @@ -85,10 +85,7 @@ def prepare_request(
headers=headers,
)

def map_model(
self,
model: Optional[str],
) -> str:
def _map_model(self, model: Optional[str]) -> str:
"""Default implementation for mapping model HF model IDs to provider model IDs."""
if model is None:
raise ValueError("Please provide a HF model ID supported by fal.ai.")
Expand Down
15 changes: 7 additions & 8 deletions src/huggingface_hub/inference/_providers/hf_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,24 +79,23 @@ def prepare_request(
) -> RequestParameters:
if extra_payload is None:
extra_payload = {}
model = self.map_model(model)
url = self.build_url(model)
data, json = self._prepare_payload(inputs, parameters=parameters, model=model, extra_payload=extra_payload)
mapped_model = self._map_model(model)
url = self.build_url(mapped_model)
data, json = self._prepare_payload(
inputs, parameters=parameters, model=mapped_model, extra_payload=extra_payload
)
headers = self.prepare_headers(headers=headers, api_key=api_key)

return RequestParameters(
url=url,
task=self.task,
model=model,
model=mapped_model,
json=json,
data=data,
headers=headers,
)

def map_model(
self,
model: Optional[str],
) -> str:
def _map_model(self, model: Optional[str]) -> str:
return model if model is not None else get_recommended_model(self.task)

def build_url(self, model: str) -> str:
Expand Down
4 changes: 2 additions & 2 deletions src/huggingface_hub/inference/_providers/replicate.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def prepare_request(
base_url = BASE_URL
logger.info("Calling Replicate provider directly.")

mapped_model = self.map_model(model)
mapped_model = self._map_model(model)
url = _build_url(base_url, mapped_model)

headers = {
Expand All @@ -91,7 +91,7 @@ def prepare_request(
headers=headers,
)

def map_model(self, model: Optional[str]) -> str:
def _map_model(self, model: Optional[str]) -> str:
"""Default implementation for mapping model HF model IDs to provider model IDs."""
if model is None:
raise ValueError("Please provide a HF model ID supported by Replicate.")
Expand Down
4 changes: 2 additions & 2 deletions src/huggingface_hub/inference/_providers/sambanova.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def prepare_request(
logger.info("Calling Sambanova provider directly.")
headers = {**build_hf_headers(token=api_key), **headers}

mapped_model = self.map_model(model)
mapped_model = self._map_model(model)
payload = {
"messages": inputs,
**{k: v for k, v in parameters.items() if v is not None},
Expand All @@ -79,7 +79,7 @@ def prepare_request(
headers=headers,
)

def map_model(self, model: Optional[str]) -> str:
def _map_model(self, model: Optional[str]) -> str:
"""Default implementation for mapping model HF model IDs to provider model IDs."""
if model is None:
raise ValueError("Please provide a HF model ID supported by Sambanova.")
Expand Down
4 changes: 2 additions & 2 deletions src/huggingface_hub/inference/_providers/together.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def prepare_request(
else:
base_url = BASE_URL
logger.info("Calling Together provider directly.")
mapped_model = self.map_model(model)
mapped_model = self._map_model(model)
if "model" in parameters:
parameters["model"] = mapped_model
payload = self._prepare_payload(inputs, parameters=parameters)
Expand All @@ -119,7 +119,7 @@ def prepare_request(
headers=headers,
)

def map_model(self, model: Optional[str]) -> str:
def _map_model(self, model: Optional[str]) -> str:
"""Default implementation for mapping model HF model IDs to provider model IDs."""
if model is None:
raise ValueError("Please provide a HF model ID supported by Together.")
Expand Down

0 comments on commit 4988db0

Please sign in to comment.