Skip to content

Commit

Permalink
use two classes for text-generation and conversational
Browse files Browse the repository at this point in the history
  • Loading branch information
hanouticelina committed Feb 14, 2025
1 parent 50c8bc5 commit 3381ced
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 40 deletions.
6 changes: 3 additions & 3 deletions src/huggingface_hub/inference/_providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from .hyperbolic import HyperbolicTextGenerationTask, HyperbolicTextToImageTask
from .replicate import ReplicateTask, ReplicateTextToSpeechTask
from .sambanova import SambanovaConversationalTask
from .together import TogetherTextGenerationTask, TogetherTextToImageTask
from .together import TogetherConversationalTask, TogetherTextGenerationTask, TogetherTextToImageTask


PROVIDER_T = Literal[
Expand Down Expand Up @@ -78,8 +78,8 @@
},
"together": {
"text-to-image": TogetherTextToImageTask(),
"conversational": TogetherTextGenerationTask("conversational"),
"text-generation": TogetherTextGenerationTask("text-generation"),
"conversational": TogetherConversationalTask(),
"text-generation": TogetherTextGenerationTask(),
},
}

Expand Down
37 changes: 25 additions & 12 deletions src/huggingface_hub/inference/_providers/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,23 +180,36 @@ def _prepare_payload_as_bytes(
return None


class BaseConversationalTask(TaskProviderHelper):
"""
Base class for conversational (chat completion) tasks.
The schema follows the OpenAI API format defined here: https://platform.openai.com/docs/api-reference/chat
"""

def __init__(self, provider: str, base_url: str):
super().__init__(provider=provider, base_url=base_url, task="conversational")

def _prepare_route(self, mapped_model: str) -> str:
return "/v1/chat/completions"

def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
return {"messages": inputs, **filter_none(parameters), "model": mapped_model}


class BaseTextGenerationTask(TaskProviderHelper):
def __init__(self, provider: str, base_url: str, task: str):
super().__init__(provider=provider, base_url=base_url, task=task)
"""
Base class for text-generation (completion) tasks.
The schema follows the OpenAI API format defined here: https://platform.openai.com/docs/api-reference/completions
"""

def __init__(self, provider: str, base_url: str):
super().__init__(provider=provider, base_url=base_url, task="text-generation")

def _prepare_route(self, mapped_model: str) -> str:
if self.task == "conversational":
return "/v1/chat/completions"
elif self.task == "text-generation":
return "/v1/completions"
raise ValueError(f"Unsupported task '{self.task}' for {self.provider}.")
return "/v1/completions"

def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
if self.task == "conversational":
return {"messages": inputs, **filter_none(parameters), "model": mapped_model}
elif self.task == "text-generation":
return {"prompt": inputs, **filter_none(parameters), "model": mapped_model}
raise ValueError(f"Unsupported task '{self.task}' for {self.provider}.")
return {"prompt": inputs, **filter_none(parameters), "model": mapped_model}


@lru_cache(maxsize=None)
Expand Down
6 changes: 3 additions & 3 deletions src/huggingface_hub/inference/_providers/fireworks_ai.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from ._common import BaseTextGenerationTask
from ._common import BaseConversationalTask


class FireworksAIConversationalTask(BaseTextGenerationTask):
class FireworksAIConversationalTask(BaseConversationalTask):
def __init__(self):
super().__init__(provider="fireworks-ai", base_url="https://api.fireworks.ai/inference", task="conversational")
super().__init__(provider="fireworks-ai", base_url="https://api.fireworks.ai/inference")
20 changes: 11 additions & 9 deletions src/huggingface_hub/inference/_providers/hyperbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Any, Dict, Optional, Union

from huggingface_hub.inference._common import _as_dict
from huggingface_hub.inference._providers._common import BaseTextGenerationTask, TaskProviderHelper, filter_none
from huggingface_hub.inference._providers._common import BaseConversationalTask, TaskProviderHelper, filter_none


class HyperbolicTextToImageTask(TaskProviderHelper):
Expand Down Expand Up @@ -30,12 +30,14 @@ def get_response(self, response: Union[bytes, Dict]) -> Any:
return base64.b64decode(response_dict["images"][0]["image"])


class HyperbolicTextGenerationTask(BaseTextGenerationTask):
def __init__(self, task: str):
super().__init__(provider="hyperbolic", base_url="https://api.hyperbolic.xyz", task=task)
class HyperbolicTextGenerationTask(BaseConversationalTask):
"""
Special case for Hyperbolic, where text-generation task is handled as a conversational task.
"""

def _prepare_route(self, mapped_model: str) -> str:
# For Hyperbolic, the route is the same for text-generation and conversational tasks
if self.task in ("text-generation", "conversational"):
return "/v1/chat/completions"
raise ValueError(f"Unsupported task '{self.task}' for Hyperbolic.")
def __init__(self, task: str):
super().__init__(
provider="hyperbolic",
base_url="https://api.hyperbolic.xyz",
)
self.task = task
6 changes: 3 additions & 3 deletions src/huggingface_hub/inference/_providers/sambanova.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from huggingface_hub.inference._providers._common import BaseTextGenerationTask
from huggingface_hub.inference._providers._common import BaseConversationalTask


class SambanovaConversationalTask(BaseTextGenerationTask):
class SambanovaConversationalTask(BaseConversationalTask):
def __init__(self):
super().__init__(provider="sambanova", base_url="https://api.sambanova.ai", task="conversational")
super().__init__(provider="sambanova", base_url="https://api.sambanova.ai")
16 changes: 13 additions & 3 deletions src/huggingface_hub/inference/_providers/together.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@
from typing import Any, Dict, Optional, Union

from huggingface_hub.inference._common import _as_dict
from huggingface_hub.inference._providers._common import BaseTextGenerationTask, TaskProviderHelper, filter_none
from huggingface_hub.inference._providers._common import (
BaseConversationalTask,
BaseTextGenerationTask,
TaskProviderHelper,
filter_none,
)


class TogetherTask(TaskProviderHelper, ABC):
Expand All @@ -23,8 +28,13 @@ def _prepare_route(self, mapped_model: str) -> str:


class TogetherTextGenerationTask(BaseTextGenerationTask):
def __init__(self, task: str):
super().__init__(provider="together", base_url="https://api.together.xyz", task=task)
def __init__(self):
super().__init__(provider="together", base_url="https://api.together.xyz")


class TogetherConversationalTask(BaseConversationalTask):
def __init__(self):
super().__init__(provider="together", base_url="https://api.together.xyz")


class TogetherTextToImageTask(TogetherTask):
Expand Down
15 changes: 8 additions & 7 deletions tests/test_inference_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@
)
from huggingface_hub.inference._providers.replicate import ReplicateTask, ReplicateTextToSpeechTask
from huggingface_hub.inference._providers.sambanova import SambanovaConversationalTask
from huggingface_hub.inference._providers.together import TogetherTextGenerationTask, TogetherTextToImageTask
from huggingface_hub.inference._providers.together import (
TogetherConversationalTask,
TogetherTextGenerationTask,
TogetherTextToImageTask,
)


class TestFalAIProvider:
Expand Down Expand Up @@ -233,9 +237,6 @@ def test_prepare_route(self):
helper = HyperbolicTextGenerationTask("conversational")
assert helper._prepare_route("username/repo_name") == "/v1/chat/completions"

with pytest.raises(ValueError, match="Unsupported task 'invalid-task' for Hyperbolic."):
HyperbolicTextGenerationTask("invalid-task")._prepare_route("username/repo_name")

def test_prepare_payload_conversational(self):
"""Test payload preparation for conversational task."""
helper = HyperbolicTextGenerationTask("conversational")
Expand Down Expand Up @@ -357,17 +358,17 @@ def test_prepare_payload_as_dict(self):

class TestTogetherProvider:
def test_prepare_route(self):
helper = TogetherTextGenerationTask("text-generation")
helper = TogetherTextGenerationTask()
assert helper._prepare_route("username/repo_name") == "/v1/completions"

helper = TogetherTextGenerationTask("conversational")
helper = TogetherConversationalTask()
assert helper._prepare_route("username/repo_name") == "/v1/chat/completions"

helper = TogetherTextToImageTask()
assert helper._prepare_route("username/repo_name") == "/v1/images/generations"

def test_prepare_payload_as_dict_conversational(self):
helper = TogetherTextGenerationTask("conversational")
helper = TogetherConversationalTask()
payload = helper._prepare_payload_as_dict(
[{"role": "user", "content": "Hello!"}], {}, "meta-llama/Llama-3.1-8B-Instruct"
)
Expand Down

0 comments on commit 3381ced

Please sign in to comment.