Skip to content

Commit

Permalink
Handle extra fields in inference types (#2839)
Browse files Browse the repository at this point in the history
* Handle extra fields in inference types

* mypy

* fix regex
  • Loading branch information
Wauplin authored Feb 11, 2025
1 parent e5c84bc commit 39514e7
Show file tree
Hide file tree
Showing 35 changed files with 230 additions and 203 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,15 @@
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from dataclasses import dataclass
from typing import Literal, Optional

from .base import BaseInferenceType
from .base import BaseInferenceType, dataclass_with_extra


AudioClassificationOutputTransform = Literal["sigmoid", "softmax", "none"]


@dataclass
@dataclass_with_extra
class AudioClassificationParameters(BaseInferenceType):
"""Additional inference parameters for Audio Classification"""

Expand All @@ -22,7 +21,7 @@ class AudioClassificationParameters(BaseInferenceType):
"""When specified, limits the output to the top K most probable classes."""


@dataclass
@dataclass_with_extra
class AudioClassificationInput(BaseInferenceType):
"""Inputs for Audio Classification inference"""

Expand All @@ -34,7 +33,7 @@ class AudioClassificationInput(BaseInferenceType):
"""Additional inference parameters for Audio Classification"""


@dataclass
@dataclass_with_extra
class AudioClassificationOutputElement(BaseInferenceType):
"""Outputs for Audio Classification inference"""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,20 @@
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from dataclasses import dataclass
from typing import Any

from .base import BaseInferenceType
from .base import BaseInferenceType, dataclass_with_extra


@dataclass
@dataclass_with_extra
class AudioToAudioInput(BaseInferenceType):
    """Inputs for Audio to Audio inference"""

    # NOTE: extra payload fields not declared here are still kept as attributes
    # (and shown in __repr__) thanks to the `dataclass_with_extra` decorator.
    inputs: Any
    """The input audio data"""


@dataclass
@dataclass_with_extra
class AudioToAudioOutputElement(BaseInferenceType):
"""Outputs of inference for the Audio To Audio task
A generated audio file with its label.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,15 @@
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from dataclasses import dataclass
from typing import List, Literal, Optional, Union

from .base import BaseInferenceType
from .base import BaseInferenceType, dataclass_with_extra


AutomaticSpeechRecognitionEarlyStoppingEnum = Literal["never"]


@dataclass
@dataclass_with_extra
class AutomaticSpeechRecognitionGenerationParameters(BaseInferenceType):
"""Parametrization of the text generation process"""

Expand Down Expand Up @@ -72,7 +71,7 @@ class AutomaticSpeechRecognitionGenerationParameters(BaseInferenceType):
"""Whether the model should use the past last key/values attentions to speed up decoding"""


@dataclass
@dataclass_with_extra
class AutomaticSpeechRecognitionParameters(BaseInferenceType):
"""Additional inference parameters for Automatic Speech Recognition"""

Expand All @@ -83,7 +82,7 @@ class AutomaticSpeechRecognitionParameters(BaseInferenceType):
"""Parametrization of the text generation process"""


@dataclass
@dataclass_with_extra
class AutomaticSpeechRecognitionInput(BaseInferenceType):
"""Inputs for Automatic Speech Recognition inference"""

Expand All @@ -95,15 +94,15 @@ class AutomaticSpeechRecognitionInput(BaseInferenceType):
"""Additional inference parameters for Automatic Speech Recognition"""


@dataclass
@dataclass_with_extra
class AutomaticSpeechRecognitionOutputChunk(BaseInferenceType):
    """A single timestamped chunk of an automatic speech recognition transcript."""

    text: str
    """A chunk of text identified by the model"""
    timestamp: List[float]
    """The start and end timestamps corresponding with the text"""


@dataclass
@dataclass_with_extra
class AutomaticSpeechRecognitionOutput(BaseInferenceType):
"""Outputs of inference for the Automatic Speech Recognition task"""

Expand Down
21 changes: 21 additions & 0 deletions src/huggingface_hub/inference/_generated/types/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,22 @@
T = TypeVar("T", bound="BaseInferenceType")


def _repr_with_extra(self):
    """Render ``ClassName(field=..., ...)`` covering declared dataclass fields first, then any extra attributes."""
    declared = list(self.__dataclass_fields__.keys())
    extras = [key for key in self.__dict__ if key not in declared]
    rendered = ", ".join(f"{key}={self.__dict__[key]!r}" for key in declared + extras)
    return f"{self.__class__.__name__}({rendered})"


def dataclass_with_extra(cls: Type[T]) -> Type[T]:
    """Decorator to add a custom __repr__ method to a dataclass, showing all fields, including extra ones.
    This decorator only works with dataclasses that inherit from `BaseInferenceType`.
    """
    decorated = dataclass(cls)
    # Override the generated __repr__ so attributes set outside the declared
    # fields (see BaseInferenceType.parse_obj) are also displayed.
    decorated.__repr__ = _repr_with_extra  # type: ignore[method-assign]
    return decorated


@dataclass
class BaseInferenceType(dict):
"""Base class for all inference types.
Expand Down Expand Up @@ -115,6 +131,11 @@ def parse_obj(cls: Type[T], data: Union[bytes, str, List, Dict]) -> Union[List[T

# Add remaining fields as dict attributes
item.update(other_values)

# Add remaining fields as extra dataclass fields.
# They won't be part of the dataclass fields but will be accessible as attributes.
# Use @dataclass_with_extra to show them in __repr__.
item.__dict__.update(other_values)
return item

def __post_init__(self):
Expand Down
59 changes: 29 additions & 30 deletions src/huggingface_hub/inference/_generated/types/chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,27 @@
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from dataclasses import dataclass
from typing import Any, List, Literal, Optional, Union

from .base import BaseInferenceType
from .base import BaseInferenceType, dataclass_with_extra


@dataclass
@dataclass_with_extra
class ChatCompletionInputURL(BaseInferenceType):
    """URL wrapper used by image chunks of a chat completion input message."""

    url: str  # location of the resource (used as `image_url` in message chunks)


ChatCompletionInputMessageChunkType = Literal["text", "image_url"]


@dataclass
@dataclass_with_extra
class ChatCompletionInputMessageChunk(BaseInferenceType):
    """One chunk of a chat input message, discriminated by `type`."""

    type: "ChatCompletionInputMessageChunkType"  # "text" or "image_url"
    image_url: Optional[ChatCompletionInputURL] = None  # presumably set when type == "image_url" — TODO confirm
    text: Optional[str] = None  # presumably set when type == "text" — TODO confirm


@dataclass
@dataclass_with_extra
class ChatCompletionInputMessage(BaseInferenceType):
content: Union[List[ChatCompletionInputMessageChunk], str]
role: str
Expand All @@ -34,7 +33,7 @@ class ChatCompletionInputMessage(BaseInferenceType):
ChatCompletionInputGrammarTypeType = Literal["json", "regex"]


@dataclass
@dataclass_with_extra
class ChatCompletionInputGrammarType(BaseInferenceType):
type: "ChatCompletionInputGrammarTypeType"
value: Any
Expand All @@ -44,7 +43,7 @@ class ChatCompletionInputGrammarType(BaseInferenceType):
"""


@dataclass
@dataclass_with_extra
class ChatCompletionInputStreamOptions(BaseInferenceType):
include_usage: bool
"""If set, an additional chunk will be streamed before the data: [DONE] message. The usage
Expand All @@ -54,33 +53,33 @@ class ChatCompletionInputStreamOptions(BaseInferenceType):
"""


@dataclass
@dataclass_with_extra
class ChatCompletionInputFunctionName(BaseInferenceType):
    """Reference to a function by name (used by `ChatCompletionInputToolChoiceClass`)."""

    name: str


@dataclass
@dataclass_with_extra
class ChatCompletionInputToolChoiceClass(BaseInferenceType):
    """Tool choice that points at a specific function by name."""

    function: ChatCompletionInputFunctionName


ChatCompletionInputToolChoiceEnum = Literal["auto", "none", "required"]


@dataclass
@dataclass_with_extra
class ChatCompletionInputFunctionDefinition(BaseInferenceType):
    """Definition of a function exposed to the model via `ChatCompletionInputTool`."""

    arguments: Any  # schema/arguments payload; structure is server-defined — TODO confirm
    name: str
    description: Optional[str] = None


@dataclass
@dataclass_with_extra
class ChatCompletionInputTool(BaseInferenceType):
    """A tool made available to the model: a function definition plus its type tag."""

    function: ChatCompletionInputFunctionDefinition
    type: str


@dataclass
@dataclass_with_extra
class ChatCompletionInput(BaseInferenceType):
"""Chat Completion Input.
Auto-generated from TGI specs.
Expand Down Expand Up @@ -162,61 +161,61 @@ class ChatCompletionInput(BaseInferenceType):
"""


@dataclass
@dataclass_with_extra
class ChatCompletionOutputTopLogprob(BaseInferenceType):
    """A candidate token together with its log-probability."""

    logprob: float
    token: str


@dataclass
@dataclass_with_extra
class ChatCompletionOutputLogprob(BaseInferenceType):
    """Log-probability info for one generated token, with its top alternatives."""

    logprob: float
    token: str
    top_logprobs: List[ChatCompletionOutputTopLogprob]


@dataclass
@dataclass_with_extra
class ChatCompletionOutputLogprobs(BaseInferenceType):
    """Per-token log-probability details for a completion choice."""

    content: List[ChatCompletionOutputLogprob]


@dataclass
@dataclass_with_extra
class ChatCompletionOutputFunctionDefinition(BaseInferenceType):
    """Function invocation emitted by the model (name plus raw arguments)."""

    arguments: Any  # raw arguments payload; structure is server-defined — TODO confirm
    name: str
    description: Optional[str] = None


@dataclass
@dataclass_with_extra
class ChatCompletionOutputToolCall(BaseInferenceType):
    """A tool call produced by the model, wrapping a function invocation."""

    function: ChatCompletionOutputFunctionDefinition
    id: str
    type: str


@dataclass
@dataclass_with_extra
class ChatCompletionOutputMessage(BaseInferenceType):
    """Message returned by the model: text content and/or tool calls."""

    role: str
    content: Optional[str] = None
    tool_calls: Optional[List[ChatCompletionOutputToolCall]] = None


@dataclass
@dataclass_with_extra
class ChatCompletionOutputComplete(BaseInferenceType):
    """One finished completion choice, with its index and optional logprobs."""

    finish_reason: str
    index: int
    message: ChatCompletionOutputMessage
    logprobs: Optional[ChatCompletionOutputLogprobs] = None


@dataclass
@dataclass_with_extra
class ChatCompletionOutputUsage(BaseInferenceType):
    """Token-count counters reported with a chat completion response."""

    completion_tokens: int
    prompt_tokens: int
    total_tokens: int


@dataclass
@dataclass_with_extra
class ChatCompletionOutput(BaseInferenceType):
"""Chat Completion Output.
Auto-generated from TGI specs.
Expand All @@ -232,61 +231,61 @@ class ChatCompletionOutput(BaseInferenceType):
usage: ChatCompletionOutputUsage


@dataclass
@dataclass_with_extra
class ChatCompletionStreamOutputFunction(BaseInferenceType):
    """Function-call payload carried by a streamed delta."""

    arguments: str  # NOTE(review): likely a partial JSON fragment accumulated across chunks — confirm
    name: Optional[str] = None


@dataclass
@dataclass_with_extra
class ChatCompletionStreamOutputDeltaToolCall(BaseInferenceType):
    """Tool-call fragment carried by a streamed message delta."""

    function: ChatCompletionStreamOutputFunction
    id: str
    index: int
    type: str


@dataclass
@dataclass_with_extra
class ChatCompletionStreamOutputDelta(BaseInferenceType):
    """Incremental message delta in a chat completion stream."""

    role: str
    content: Optional[str] = None
    tool_calls: Optional[ChatCompletionStreamOutputDeltaToolCall] = None


@dataclass
@dataclass_with_extra
class ChatCompletionStreamOutputTopLogprob(BaseInferenceType):
    """A candidate token together with its log-probability (streaming variant)."""

    logprob: float
    token: str


@dataclass
@dataclass_with_extra
class ChatCompletionStreamOutputLogprob(BaseInferenceType):
    """Log-probability info for one streamed token, with its top alternatives."""

    logprob: float
    token: str
    top_logprobs: List[ChatCompletionStreamOutputTopLogprob]


@dataclass
@dataclass_with_extra
class ChatCompletionStreamOutputLogprobs(BaseInferenceType):
    """Per-token log-probability details for a streamed choice."""

    content: List[ChatCompletionStreamOutputLogprob]


@dataclass
@dataclass_with_extra
class ChatCompletionStreamOutputChoice(BaseInferenceType):
    """One choice inside a streamed chat completion chunk."""

    delta: ChatCompletionStreamOutputDelta
    index: int
    finish_reason: Optional[str] = None
    logprobs: Optional[ChatCompletionStreamOutputLogprobs] = None


@dataclass
@dataclass_with_extra
class ChatCompletionStreamOutputUsage(BaseInferenceType):
    """Token-count counters reported with a streamed chat completion."""

    completion_tokens: int
    prompt_tokens: int
    total_tokens: int


@dataclass
@dataclass_with_extra
class ChatCompletionStreamOutput(BaseInferenceType):
"""Chat Completion Stream Output.
Auto-generated from TGI specs.
Expand Down
Loading

0 comments on commit 39514e7

Please sign in to comment.