Skip to content

Add support for Pydantic models in stubgen #19095

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions mypy/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,7 @@ def __init__(
self.processing_enum = False
self.processing_dataclass = False
self.dataclass_field_specifier: tuple[str, ...] = ()
self.processing_pydantic_model = False

@property
def _current_class(self) -> ClassDef | None:
Expand Down Expand Up @@ -808,6 +809,14 @@ def visit_class_def(self, o: ClassDef) -> None:
if self.analyzed and (spec := find_dataclass_transform_spec(o)):
self.processing_dataclass = True
self.dataclass_field_specifier = spec.field_specifiers
is_pydantic_model = False
for base_type_expr in o.base_type_exprs:
if isinstance(base_type_expr, (NameExpr, MemberExpr)) and self.get_fullname(
base_type_expr
).endswith("BaseModel"):
is_pydantic_model = True
break
self.processing_pydantic_model = is_pydantic_model
super().visit_class_def(o)
self.dedent()
self._vars.pop()
Expand All @@ -825,6 +834,7 @@ def visit_class_def(self, o: ClassDef) -> None:
self.dataclass_field_specifier = ()
self._class_stack.pop(-1)
self.processing_enum = False
self.processing_pydantic_model = False

def get_base_types(self, cdef: ClassDef) -> list[str]:
"""Get list of base classes for a class."""
Expand Down Expand Up @@ -1289,6 +1299,9 @@ def get_assign_initializer(self, rvalue: Expression) -> str:
if not (isinstance(rvalue, TempNode) and rvalue.no_rhs):
return " = ..."
# TODO: support other possible cases, where initializer is important
if self.processing_pydantic_model:
if not (isinstance(rvalue, TempNode) and rvalue.no_rhs):
return " = ..."

# By default, no initializer is required:
return ""
Expand Down
118 changes: 118 additions & 0 deletions test-data/unit/stubgen.test
Original file line number Diff line number Diff line change
Expand Up @@ -4718,3 +4718,121 @@ class DCMeta(type): ...

class DC(metaclass=DCMeta):
x: str

[case testPydanticBaseModel]
import pydantic

class User(pydantic.BaseModel):
id: int
name: str
active: bool = True
optional_field: str | None = None
[out]
import pydantic

class User(pydantic.BaseModel):
id: int
name: str
active: bool = ...
optional_field: str | None = ...

[case testPydanticBaseModelWithAnnotationsOnly]
import pydantic

class ConfigSettings(pydantic.BaseModel):
# Fields without initialization
db_name: str
port: int
debug: bool
[out]
import pydantic

class ConfigSettings(pydantic.BaseModel):
db_name: str
port: int
debug: bool

[case testPydanticNestedBaseModel]
from pydantic import BaseModel

class Address(BaseModel):
street: str
city: str

class User(BaseModel):
name: str
age: int
address: Address | None = None
[out]
from pydantic import BaseModel

class Address(BaseModel):
street: str
city: str

class User(BaseModel):
name: str
age: int
address: Address | None = ...

[case testPydanticBaseModelComplex]
from pydantic import BaseModel
from typing import Dict, List, Optional, Union

class Item(BaseModel):
name: str
description: Optional[str] = None
tags: List[str] = []
properties: Dict[str, Union[str, int, float, bool]] = {}
[out]
from pydantic import BaseModel

class Item(BaseModel):
name: str
description: str | None = ...
tags: list[str] = ...
properties: dict[str, str | int | float | bool] = ...

[case testPydanticBaseModelInheritance]
from pydantic import BaseModel

class BaseUser(BaseModel):
id: int
active: bool = True

class User(BaseUser):
name: str
email: str
[out]
from pydantic import BaseModel

class BaseUser(BaseModel):
id: int
active: bool = ...

class User(BaseUser):
name: str
email: str

[case testPydanticModelWithMethods]
from pydantic import BaseModel

class User(BaseModel):
id: int
name: str

def get_display_name(self) -> str:
return f"User {self.name}"

@property
def display_id(self) -> str:
return f"ID: {self.id}"
[out]
from pydantic import BaseModel

class User(BaseModel):
id: int
name: str
def get_display_name(self) -> str: ...
@property
def display_id(self) -> str: ...