Skip to content

Commit e5fb741

Browse files
authored
Switch to typing-inspection (#556)
Also fix some logic to be a bit more robust.
1 parent 998eb5a commit e5fb741

File tree

4 files changed

+81
-53
lines changed

4 files changed

+81
-53
lines changed

pydantic_settings/sources.py

+52-51
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,15 @@
4141
from dotenv import dotenv_values
4242
from pydantic import AliasChoices, AliasPath, BaseModel, Json, RootModel, Secret, TypeAdapter
4343
from pydantic._internal._repr import Representation
44-
from pydantic._internal._typing_extra import WithArgsTypes, origin_is_union, typing_base
45-
from pydantic._internal._utils import deep_update, is_model_class, lenient_issubclass
44+
from pydantic._internal._utils import deep_update, is_model_class
4645
from pydantic.dataclasses import is_pydantic_dataclass
4746
from pydantic.fields import FieldInfo
4847
from pydantic_core import PydanticUndefined
49-
from typing_extensions import _AnnotatedAlias, get_args, get_origin
48+
from typing_extensions import get_args, get_origin
49+
from typing_inspection import typing_objects
50+
from typing_inspection.introspection import is_union_origin
5051

51-
from pydantic_settings.utils import path_type_label
52+
from pydantic_settings.utils import _lenient_issubclass, _WithArgsTypes, path_type_label
5253

5354
if TYPE_CHECKING:
5455
if sys.version_info >= (3, 11):
@@ -484,7 +485,7 @@ def _extract_field_info(self, field: FieldInfo, field_name: str) -> list[tuple[s
484485
field_info.append((v_alias, self._apply_case_sensitive(v_alias), False))
485486

486487
if not v_alias or self.config.get('populate_by_name', False):
487-
if origin_is_union(get_origin(field.annotation)) and _union_is_complex(field.annotation, field.metadata):
488+
if is_union_origin(get_origin(field.annotation)) and _union_is_complex(field.annotation, field.metadata):
488489
field_info.append((field_name, self._apply_case_sensitive(self.env_prefix + field_name), True))
489490
else:
490491
field_info.append((field_name, self._apply_case_sensitive(self.env_prefix + field_name), False))
@@ -530,12 +531,13 @@ class Settings(BaseSettings):
530531
annotation = field.annotation
531532

532533
# If field is Optional, we need to find the actual type
533-
args = get_args(annotation)
534-
if origin_is_union(get_origin(field.annotation)) and len(args) == 2 and type(None) in args:
535-
for arg in args:
536-
if arg is not None:
537-
annotation = arg
538-
break
534+
if is_union_origin(get_origin(field.annotation)):
535+
args = get_args(annotation)
536+
if len(args) == 2 and type(None) in args:
537+
for arg in args:
538+
if arg is not None:
539+
annotation = arg
540+
break
539541

540542
# This is here to make mypy happy
541543
# Item "None" of "Optional[Type[Any]]" has no attribute "model_fields"
@@ -553,7 +555,7 @@ class Settings(BaseSettings):
553555
values[name] = value
554556
continue
555557

556-
if lenient_issubclass(sub_model_field.annotation, BaseModel) and isinstance(value, dict):
558+
if _lenient_issubclass(sub_model_field.annotation, BaseModel) and isinstance(value, dict):
557559
values[sub_model_field_name] = self._replace_field_names_case_insensitively(sub_model_field, value)
558560
else:
559561
values[sub_model_field_name] = value
@@ -623,7 +625,7 @@ def __call__(self) -> dict[str, Any]:
623625
field_value = None
624626
if (
625627
not self.case_sensitive
626-
# and lenient_issubclass(field.annotation, BaseModel)
628+
# and _lenient_issubclass(field.annotation, BaseModel)
627629
and isinstance(field_value, dict)
628630
):
629631
data[field_key] = self._replace_field_names_case_insensitively(field, field_value)
@@ -842,7 +844,7 @@ def _field_is_complex(self, field: FieldInfo) -> tuple[bool, bool]:
842844
"""
843845
if self.field_is_complex(field):
844846
allow_parse_failure = False
845-
elif origin_is_union(get_origin(field.annotation)) and _union_is_complex(field.annotation, field.metadata):
847+
elif is_union_origin(get_origin(field.annotation)) and _union_is_complex(field.annotation, field.metadata):
846848
allow_parse_failure = True
847849
else:
848850
return False, False
@@ -888,12 +890,11 @@ class Cfg(BaseSettings):
888890
return None
889891

890892
annotation = field.annotation if isinstance(field, FieldInfo) else field
891-
if origin_is_union(get_origin(annotation)) or isinstance(annotation, WithArgsTypes):
892-
for type_ in get_args(annotation):
893-
type_has_key = self.next_field(type_, key, case_sensitive)
894-
if type_has_key:
895-
return type_has_key
896-
elif is_model_class(annotation) or is_pydantic_dataclass(annotation):
893+
for type_ in get_args(annotation):
894+
type_has_key = self.next_field(type_, key, case_sensitive)
895+
if type_has_key:
896+
return type_has_key
897+
if is_model_class(annotation) or is_pydantic_dataclass(annotation):
897898
fields = _get_model_fields(annotation)
898899
# `case_sensitive is None` is here to be compatible with the old behavior.
899900
# Has to be removed in V3.
@@ -923,7 +924,8 @@ def explode_env_vars(self, field_name: str, field: FieldInfo, env_vars: Mapping[
923924
if not self.env_nested_delimiter:
924925
return {}
925926

926-
is_dict = lenient_issubclass(get_origin(field.annotation), dict)
927+
ann = field.annotation
928+
is_dict = ann is dict or _lenient_issubclass(get_origin(ann), dict)
927929

928930
prefixes = [
929931
f'{env_name}{self.env_nested_delimiter}' for _, env_name, _ in self._extract_field_info(field, field_name)
@@ -1065,7 +1067,7 @@ def __call__(self) -> dict[str, Any]:
10651067
(
10661068
_annotation_is_complex(field.annotation, field.metadata)
10671069
or (
1068-
origin_is_union(get_origin(field.annotation))
1070+
is_union_origin(get_origin(field.annotation))
10691071
and _union_is_complex(field.annotation, field.metadata)
10701072
)
10711073
)
@@ -1382,7 +1384,7 @@ def _get_merge_parsed_list_types(
13821384
merge_type = self._cli_dict_args.get(field_name, list)
13831385
if (
13841386
merge_type is list
1385-
or not origin_is_union(get_origin(merge_type))
1387+
or not is_union_origin(get_origin(merge_type))
13861388
or not any(
13871389
type_
13881390
for type_ in get_args(merge_type)
@@ -1526,7 +1528,7 @@ def _sort_arg_fields(self, model: type[BaseModel]) -> list[tuple[str, FieldInfo]
15261528
alias_names, *_ = _get_alias_names(field_name, field_info)
15271529
if len(alias_names) > 1:
15281530
raise SettingsError(f'subcommand argument {model.__name__}.{field_name} has multiple aliases')
1529-
field_types = [type_ for type_ in get_args(field_info.annotation) if type_ is not type(None)]
1531+
field_types = (type_ for type_ in get_args(field_info.annotation) if type_ is not type(None))
15301532
for field_type in field_types:
15311533
if not (is_model_class(field_type) or is_pydantic_dataclass(field_type)):
15321534
raise SettingsError(
@@ -1984,19 +1986,20 @@ def _metavar_format_recurse(self, obj: Any) -> str:
19841986
return '...'
19851987
elif isinstance(obj, Representation):
19861988
return repr(obj)
1987-
elif isinstance(obj, typing_extensions.TypeAliasType):
1989+
elif typing_objects.is_typealiastype(obj):
19881990
return str(obj)
19891991

1990-
if not isinstance(obj, (typing_base, WithArgsTypes, type)):
1992+
origin = get_origin(obj)
1993+
if origin is None and not isinstance(obj, (type, typing.ForwardRef, typing_extensions.ForwardRef)):
19911994
obj = obj.__class__
19921995

1993-
if origin_is_union(get_origin(obj)):
1996+
if is_union_origin(origin):
19941997
return self._metavar_format_choices(list(map(self._metavar_format_recurse, self._get_modified_args(obj))))
1995-
elif get_origin(obj) in (typing_extensions.Literal, typing.Literal):
1998+
elif typing_objects.is_literal(origin):
19961999
return self._metavar_format_choices(list(map(str, self._get_modified_args(obj))))
1997-
elif lenient_issubclass(obj, Enum):
2000+
elif _lenient_issubclass(obj, Enum):
19982001
return self._metavar_format_choices([val.name for val in obj])
1999-
elif isinstance(obj, WithArgsTypes):
2002+
elif isinstance(obj, _WithArgsTypes):
20002003
return self._metavar_format_choices(
20012004
list(map(self._metavar_format_recurse, self._get_modified_args(obj))),
20022005
obj_qualname=obj.__qualname__ if hasattr(obj, '__qualname__') else str(obj),
@@ -2292,25 +2295,22 @@ def read_env_file(
22922295
def _annotation_is_complex(annotation: type[Any] | None, metadata: list[Any]) -> bool:
22932296
# If the model is a root model, the root annotation should be used to
22942297
# evaluate the complexity.
2295-
try:
2296-
if annotation is not None and issubclass(annotation, RootModel):
2297-
# In some rare cases (see test_root_model_as_field),
2298-
# the root attribute is not available. For these cases, python 3.8 and 3.9
2299-
# return 'RootModelRootType'.
2300-
root_annotation = annotation.__annotations__.get('root', None)
2301-
if root_annotation is not None and root_annotation != 'RootModelRootType':
2302-
annotation = root_annotation
2303-
except TypeError:
2304-
pass
2298+
if annotation is not None and _lenient_issubclass(annotation, RootModel) and annotation is not RootModel:
2299+
annotation = cast('type[RootModel[Any]]', annotation)
2300+
root_annotation = annotation.model_fields['root'].annotation
2301+
if root_annotation is not None:
2302+
annotation = root_annotation
23052303

23062304
if any(isinstance(md, Json) for md in metadata): # type: ignore[misc]
23072305
return False
2306+
2307+
origin = get_origin(annotation)
2308+
23082309
# Check if annotation is of the form Annotated[type, metadata].
2309-
if isinstance(annotation, _AnnotatedAlias):
2310+
if typing_objects.is_annotated(origin):
23102311
# Return result of recursive call on inner type.
23112312
inner, *meta = get_args(annotation)
23122313
return _annotation_is_complex(inner, meta)
2313-
origin = get_origin(annotation)
23142314

23152315
if origin is Secret:
23162316
return False
@@ -2324,12 +2324,12 @@ def _annotation_is_complex(annotation: type[Any] | None, metadata: list[Any]) ->
23242324

23252325

23262326
def _annotation_is_complex_inner(annotation: type[Any] | None) -> bool:
2327-
if lenient_issubclass(annotation, (str, bytes)):
2327+
if _lenient_issubclass(annotation, (str, bytes)):
23282328
return False
23292329

2330-
return lenient_issubclass(annotation, (BaseModel, Mapping, Sequence, tuple, set, frozenset, deque)) or is_dataclass(
2331-
annotation
2332-
)
2330+
return _lenient_issubclass(
2331+
annotation, (BaseModel, Mapping, Sequence, tuple, set, frozenset, deque)
2332+
) or is_dataclass(annotation)
23332333

23342334

23352335
def _union_is_complex(annotation: type[Any] | None, metadata: list[Any]) -> bool:
@@ -2353,22 +2353,23 @@ def _annotation_contains_types(
23532353

23542354

23552355
def _strip_annotated(annotation: Any) -> Any:
2356-
while get_origin(annotation) == Annotated:
2357-
annotation = get_args(annotation)[0]
2358-
return annotation
2356+
if typing_objects.is_annotated(get_origin(annotation)):
2357+
return annotation.__origin__
2358+
else:
2359+
return annotation
23592360

23602361

23612362
def _annotation_enum_val_to_name(annotation: type[Any] | None, value: Any) -> Optional[str]:
23622363
for type_ in (annotation, get_origin(annotation), *get_args(annotation)):
2363-
if lenient_issubclass(type_, Enum):
2364+
if _lenient_issubclass(type_, Enum):
23642365
if value in tuple(val.value for val in type_):
23652366
return type_(value).name
23662367
return None
23672368

23682369

23692370
def _annotation_enum_name_to_val(annotation: type[Any] | None, name: Any) -> Any:
23702371
for type_ in (annotation, get_origin(annotation), *get_args(annotation)):
2371-
if lenient_issubclass(type_, Enum):
2372+
if _lenient_issubclass(type_, Enum):
23722373
if name in tuple(val.name for val in type_):
23732374
return type_[name]
23742375
return None

pydantic_settings/utils.py

+24
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
1+
import sys
2+
import types
13
from pathlib import Path
4+
from typing import Any, _GenericAlias # type: ignore [attr-defined]
5+
6+
from typing_extensions import get_origin
27

38
_PATH_TYPE_LABELS = {
49
Path.is_dir: 'directory',
@@ -22,3 +27,22 @@ def path_type_label(p: Path) -> str:
2227
return name
2328

2429
return 'unknown'
30+
31+
32+
# TODO remove and replace usage by `isinstance(cls, type) and issubclass(cls, class_or_tuple)`
33+
# once we drop support for Python 3.10.
34+
def _lenient_issubclass(cls: Any, class_or_tuple: Any) -> bool: # pragma: no cover
35+
try:
36+
return isinstance(cls, type) and issubclass(cls, class_or_tuple)
37+
except TypeError:
38+
if get_origin(cls) is not None:
39+
# Up until Python 3.10, isinstance(<generic_alias>, type) is True
40+
# (e.g. list[int])
41+
return False
42+
raise
43+
44+
45+
if sys.version_info < (3, 10):
46+
_WithArgsTypes = tuple()
47+
else:
48+
_WithArgsTypes = (_GenericAlias, types.GenericAlias, types.UnionType)

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ requires-python = '>=3.9'
4242
dependencies = [
4343
'pydantic>=2.7.0',
4444
'python-dotenv>=0.21.0',
45+
'typing-inspection>=0.4.0',
4546
]
4647
dynamic = ['version']
4748

requirements/pyproject.txt

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#
2-
# This file is autogenerated by pip-compile with Python 3.8
2+
# This file is autogenerated by pip-compile with Python 3.13
33
# by the following command:
44
#
55
# pip-compile --extra=azure-key-vault --extra=toml --extra=yaml --no-emit-index-url --output-file=requirements/pyproject.txt pyproject.toml
@@ -63,11 +63,13 @@ tomli==2.0.1
6363
# via pydantic-settings (pyproject.toml)
6464
typing-extensions==4.12.2
6565
# via
66-
# annotated-types
6766
# azure-core
6867
# azure-identity
6968
# azure-keyvault-secrets
7069
# pydantic
7170
# pydantic-core
71+
# typing-inspection
72+
typing-inspection==0.4.0
73+
# via pydantic-settings (pyproject.toml)
7274
urllib3==2.2.2
7375
# via requests

0 commit comments

Comments
 (0)