Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pydantic support in response_format #2647

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
mypy
lhoestq committed Oct 31, 2024
commit 30144293ef9bec249292d20c19ff4094bcb08fe0
6 changes: 3 additions & 3 deletions src/huggingface_hub/_webhooks_payload.py
Original file line number Diff line number Diff line change
@@ -14,7 +14,7 @@
# limitations under the License.
"""Contains data structures to parse the webhooks payload."""

from typing import Any, List, Literal, Optional
from typing import Any, List, Literal, Optional, Union

from .utils import is_pydantic_available

@@ -47,14 +47,14 @@ def schema(cls, *args, **kwargs) -> dict[str, Any]:
)

@classmethod
def model_validate_json(cls, json_data: str | bytes | bytearray, *args, **kwargs) -> "BaseModel":
def model_validate_json(cls, json_data: Union[str, bytes, bytearray], *args, **kwargs) -> "BaseModel":
raise ImportError(
"You must have `pydantic` installed to use `WebhookPayload`. This is an optional dependency that"
" should be installed separately. Please run `pip install --upgrade pydantic` and retry."
)

@classmethod
def parse_raw(cls, json_data: str | bytes | bytearray, *args, **kwargs) -> "BaseModel":
def parse_raw(cls, json_data: Union[str, bytes, bytearray], *args, **kwargs) -> "BaseModel":
raise ImportError(
"You must have `pydantic` installed to use `WebhookPayload`. This is an optional dependency that"
" should be installed separately. Please run `pip install --upgrade pydantic` and retry."
2 changes: 1 addition & 1 deletion src/huggingface_hub/inference/_client.py
Original file line number Diff line number Diff line change
@@ -877,7 +877,7 @@ def chat_completion(
ActivitySummary(location='park', activity='bike ride', animals_seen=3, animals=['puppy', 'cat', 'raccoon'])
```
"""
if issubclass(response_format, BaseModel):
if isinstance(response_format, type) and issubclass(response_format, BaseModel):
response_model = response_format
response_format = ChatCompletionInputGrammarType(
type="json",
2 changes: 1 addition & 1 deletion src/huggingface_hub/inference/_generated/_async_client.py
Original file line number Diff line number Diff line change
@@ -932,7 +932,7 @@ async def chat_completion(
ActivitySummary(location='park', activity='bike ride', animals_seen=3, animals=['puppy', 'cat', 'raccoon'])
```
"""
if issubclass(response_format, BaseModel):
if isinstance(response_format, type) and issubclass(response_format, BaseModel):
response_model = response_format
response_format = ChatCompletionInputGrammarType(
type="json",