Skip to content

Commit 9676c61

Browse files
committed
expand pydantic utils to auto reload parent modules when new registration happens
Signed-off-by: Mark Kurtz <[email protected]>
1 parent a7e62c9 commit 9676c61

File tree

1 file changed

+80
-2
lines changed

1 file changed

+80
-2
lines changed

src/speculators/utils/pydantic_utils.py

Lines changed: 80 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from __future__ import annotations
1212

1313
from abc import ABC, abstractmethod
14-
from typing import Any, ClassVar, Generic, TypeVar
14+
from typing import Any, ClassVar, Generic, TypeVar, get_args, get_origin
1515

1616
from pydantic import BaseModel, GetCoreSchemaHandler
1717
from pydantic_core import CoreSchema, core_schema
@@ -41,15 +41,93 @@ class ReloadableBaseModel(BaseModel):
4141
"""
4242

4343
@classmethod
44-
def reload_schema(cls) -> None:
44+
def reload_schema(cls, parents: bool = True) -> None:
4545
"""
4646
Reload the class schema with updated registry information.
4747
4848
Forces a complete rebuild of the Pydantic model schema to incorporate
4949
any changes made to associated registries or validation rules.
50+
51+
:param parents: Whether to also rebuild schemas for any pydantic parent
52+
types that reference this model.
5053
"""
5154
cls.model_rebuild(force=True)
5255

56+
if parents:
57+
cls.reload_parent_schemas()
58+
59+
@classmethod
60+
def reload_parent_schemas(cls):
61+
"""
62+
Recursively reload schemas for all parent Pydantic models.
63+
64+
Traverses the inheritance hierarchy to find all parent classes that
65+
are Pydantic models and triggers schema rebuilding on each to ensure
66+
that any changes in child models are reflected in parent schemas.
67+
"""
68+
potential_parents: set[type[BaseModel]] = {BaseModel}
69+
stack: list[type[BaseModel]] = [BaseModel]
70+
71+
while stack:
72+
current = stack.pop()
73+
for subclass in current.__subclasses__():
74+
if (
75+
issubclass(subclass, BaseModel)
76+
and subclass is not cls
77+
and subclass not in potential_parents
78+
):
79+
potential_parents.add(subclass)
80+
stack.append(subclass)
81+
82+
for check in cls.__mro__:
83+
if isinstance(check, type) and issubclass(check, BaseModel):
84+
cls._reload_schemas_depending_on(check, potential_parents)
85+
86+
@classmethod
87+
def _reload_schemas_depending_on(cls, target: type[BaseModel], types: set[type]):
88+
changed = True
89+
while changed:
90+
changed = False
91+
for candidate in types:
92+
if (
93+
isinstance(candidate, type)
94+
and issubclass(candidate, BaseModel)
95+
and any(
96+
cls._uses_type(target, field_info.annotation)
97+
for field_info in candidate.model_fields.values()
98+
if field_info.annotation is not None
99+
)
100+
):
101+
try:
102+
before = candidate.model_json_schema()
103+
except Exception: # noqa: BLE001
104+
before = None
105+
candidate.model_rebuild(force=True)
106+
if before is not None:
107+
after = candidate.model_json_schema()
108+
changed |= before != after
109+
110+
@classmethod
111+
def _uses_type(cls, target: type, candidate: type) -> bool:
112+
if target is candidate:
113+
return True
114+
115+
origin = get_origin(candidate)
116+
117+
if origin is None:
118+
return isinstance(candidate, type) and issubclass(candidate, target)
119+
120+
if isinstance(origin, type) and (
121+
target is origin or issubclass(origin, target)
122+
):
123+
return True
124+
125+
for arg in get_args(candidate) or []:
126+
if isinstance(arg, type) and cls._uses_type(target, arg):
127+
return True
128+
129+
return False
130+
53131

54132
class PydanticClassRegistryMixin(
55133
ReloadableBaseModel, RegistryMixin[type[BaseModelT]], ABC, Generic[BaseModelT]

0 commit comments

Comments
 (0)