41
41
from dotenv import dotenv_values
42
42
from pydantic import AliasChoices , AliasPath , BaseModel , Json , RootModel , Secret , TypeAdapter
43
43
from pydantic ._internal ._repr import Representation
44
- from pydantic ._internal ._typing_extra import WithArgsTypes , origin_is_union , typing_base
45
- from pydantic ._internal ._utils import deep_update , is_model_class , lenient_issubclass
44
+ from pydantic ._internal ._utils import deep_update , is_model_class
46
45
from pydantic .dataclasses import is_pydantic_dataclass
47
46
from pydantic .fields import FieldInfo
48
47
from pydantic_core import PydanticUndefined
49
- from typing_extensions import _AnnotatedAlias , get_args , get_origin
48
+ from typing_extensions import get_args , get_origin
49
+ from typing_inspection import typing_objects
50
+ from typing_inspection .introspection import is_union_origin
50
51
51
- from pydantic_settings .utils import path_type_label
52
+ from pydantic_settings .utils import _lenient_issubclass , _WithArgsTypes , path_type_label
52
53
53
54
if TYPE_CHECKING :
54
55
if sys .version_info >= (3 , 11 ):
@@ -484,7 +485,7 @@ def _extract_field_info(self, field: FieldInfo, field_name: str) -> list[tuple[s
484
485
field_info .append ((v_alias , self ._apply_case_sensitive (v_alias ), False ))
485
486
486
487
if not v_alias or self .config .get ('populate_by_name' , False ):
487
- if origin_is_union (get_origin (field .annotation )) and _union_is_complex (field .annotation , field .metadata ):
488
+ if is_union_origin (get_origin (field .annotation )) and _union_is_complex (field .annotation , field .metadata ):
488
489
field_info .append ((field_name , self ._apply_case_sensitive (self .env_prefix + field_name ), True ))
489
490
else :
490
491
field_info .append ((field_name , self ._apply_case_sensitive (self .env_prefix + field_name ), False ))
@@ -530,12 +531,13 @@ class Settings(BaseSettings):
530
531
annotation = field .annotation
531
532
532
533
# If field is Optional, we need to find the actual type
533
- args = get_args (annotation )
534
- if origin_is_union (get_origin (field .annotation )) and len (args ) == 2 and type (None ) in args :
535
- for arg in args :
536
- if arg is not None :
537
- annotation = arg
538
- break
534
+ if is_union_origin (get_origin (field .annotation )):
535
+ args = get_args (annotation )
536
+ if len (args ) == 2 and type (None ) in args :
537
+ for arg in args :
538
+ if arg is not None :
539
+ annotation = arg
540
+ break
539
541
540
542
# This is here to make mypy happy
541
543
# Item "None" of "Optional[Type[Any]]" has no attribute "model_fields"
@@ -553,7 +555,7 @@ class Settings(BaseSettings):
553
555
values [name ] = value
554
556
continue
555
557
556
- if lenient_issubclass (sub_model_field .annotation , BaseModel ) and isinstance (value , dict ):
558
+ if _lenient_issubclass (sub_model_field .annotation , BaseModel ) and isinstance (value , dict ):
557
559
values [sub_model_field_name ] = self ._replace_field_names_case_insensitively (sub_model_field , value )
558
560
else :
559
561
values [sub_model_field_name ] = value
@@ -623,7 +625,7 @@ def __call__(self) -> dict[str, Any]:
623
625
field_value = None
624
626
if (
625
627
not self .case_sensitive
626
- # and lenient_issubclass (field.annotation, BaseModel)
628
+ # and _lenient_issubclass (field.annotation, BaseModel)
627
629
and isinstance (field_value , dict )
628
630
):
629
631
data [field_key ] = self ._replace_field_names_case_insensitively (field , field_value )
@@ -842,7 +844,7 @@ def _field_is_complex(self, field: FieldInfo) -> tuple[bool, bool]:
842
844
"""
843
845
if self .field_is_complex (field ):
844
846
allow_parse_failure = False
845
- elif origin_is_union (get_origin (field .annotation )) and _union_is_complex (field .annotation , field .metadata ):
847
+ elif is_union_origin (get_origin (field .annotation )) and _union_is_complex (field .annotation , field .metadata ):
846
848
allow_parse_failure = True
847
849
else :
848
850
return False , False
@@ -888,12 +890,11 @@ class Cfg(BaseSettings):
888
890
return None
889
891
890
892
annotation = field .annotation if isinstance (field , FieldInfo ) else field
891
- if origin_is_union (get_origin (annotation )) or isinstance (annotation , WithArgsTypes ):
892
- for type_ in get_args (annotation ):
893
- type_has_key = self .next_field (type_ , key , case_sensitive )
894
- if type_has_key :
895
- return type_has_key
896
- elif is_model_class (annotation ) or is_pydantic_dataclass (annotation ):
893
+ for type_ in get_args (annotation ):
894
+ type_has_key = self .next_field (type_ , key , case_sensitive )
895
+ if type_has_key :
896
+ return type_has_key
897
+ if is_model_class (annotation ) or is_pydantic_dataclass (annotation ):
897
898
fields = _get_model_fields (annotation )
898
899
# `case_sensitive is None` is here to be compatible with the old behavior.
899
900
# Has to be removed in V3.
@@ -923,7 +924,8 @@ def explode_env_vars(self, field_name: str, field: FieldInfo, env_vars: Mapping[
923
924
if not self .env_nested_delimiter :
924
925
return {}
925
926
926
- is_dict = lenient_issubclass (get_origin (field .annotation ), dict )
927
+ ann = field .annotation
928
+ is_dict = ann is dict or _lenient_issubclass (get_origin (ann ), dict )
927
929
928
930
prefixes = [
929
931
f'{ env_name } { self .env_nested_delimiter } ' for _ , env_name , _ in self ._extract_field_info (field , field_name )
@@ -1065,7 +1067,7 @@ def __call__(self) -> dict[str, Any]:
1065
1067
(
1066
1068
_annotation_is_complex (field .annotation , field .metadata )
1067
1069
or (
1068
- origin_is_union (get_origin (field .annotation ))
1070
+ is_union_origin (get_origin (field .annotation ))
1069
1071
and _union_is_complex (field .annotation , field .metadata )
1070
1072
)
1071
1073
)
@@ -1382,7 +1384,7 @@ def _get_merge_parsed_list_types(
1382
1384
merge_type = self ._cli_dict_args .get (field_name , list )
1383
1385
if (
1384
1386
merge_type is list
1385
- or not origin_is_union (get_origin (merge_type ))
1387
+ or not is_union_origin (get_origin (merge_type ))
1386
1388
or not any (
1387
1389
type_
1388
1390
for type_ in get_args (merge_type )
@@ -1526,7 +1528,7 @@ def _sort_arg_fields(self, model: type[BaseModel]) -> list[tuple[str, FieldInfo]
1526
1528
alias_names , * _ = _get_alias_names (field_name , field_info )
1527
1529
if len (alias_names ) > 1 :
1528
1530
raise SettingsError (f'subcommand argument { model .__name__ } .{ field_name } has multiple aliases' )
1529
- field_types = [ type_ for type_ in get_args (field_info .annotation ) if type_ is not type (None )]
1531
+ field_types = ( type_ for type_ in get_args (field_info .annotation ) if type_ is not type (None ))
1530
1532
for field_type in field_types :
1531
1533
if not (is_model_class (field_type ) or is_pydantic_dataclass (field_type )):
1532
1534
raise SettingsError (
@@ -1984,19 +1986,20 @@ def _metavar_format_recurse(self, obj: Any) -> str:
1984
1986
return '...'
1985
1987
elif isinstance (obj , Representation ):
1986
1988
return repr (obj )
1987
- elif isinstance (obj , typing_extensions . TypeAliasType ):
1989
+ elif typing_objects . is_typealiastype (obj ):
1988
1990
return str (obj )
1989
1991
1990
- if not isinstance (obj , (typing_base , WithArgsTypes , type )):
1992
+ origin = get_origin (obj )
1993
+ if origin is None and not isinstance (obj , (type , typing .ForwardRef , typing_extensions .ForwardRef )):
1991
1994
obj = obj .__class__
1992
1995
1993
- if origin_is_union ( get_origin ( obj ) ):
1996
+ if is_union_origin ( origin ):
1994
1997
return self ._metavar_format_choices (list (map (self ._metavar_format_recurse , self ._get_modified_args (obj ))))
1995
- elif get_origin ( obj ) in ( typing_extensions . Literal , typing . Literal ):
1998
+ elif typing_objects . is_literal ( origin ):
1996
1999
return self ._metavar_format_choices (list (map (str , self ._get_modified_args (obj ))))
1997
- elif lenient_issubclass (obj , Enum ):
2000
+ elif _lenient_issubclass (obj , Enum ):
1998
2001
return self ._metavar_format_choices ([val .name for val in obj ])
1999
- elif isinstance (obj , WithArgsTypes ):
2002
+ elif isinstance (obj , _WithArgsTypes ):
2000
2003
return self ._metavar_format_choices (
2001
2004
list (map (self ._metavar_format_recurse , self ._get_modified_args (obj ))),
2002
2005
obj_qualname = obj .__qualname__ if hasattr (obj , '__qualname__' ) else str (obj ),
@@ -2292,25 +2295,22 @@ def read_env_file(
2292
2295
def _annotation_is_complex (annotation : type [Any ] | None , metadata : list [Any ]) -> bool :
2293
2296
# If the model is a root model, the root annotation should be used to
2294
2297
# evaluate the complexity.
2295
- try :
2296
- if annotation is not None and issubclass (annotation , RootModel ):
2297
- # In some rare cases (see test_root_model_as_field),
2298
- # the root attribute is not available. For these cases, python 3.8 and 3.9
2299
- # return 'RootModelRootType'.
2300
- root_annotation = annotation .__annotations__ .get ('root' , None )
2301
- if root_annotation is not None and root_annotation != 'RootModelRootType' :
2302
- annotation = root_annotation
2303
- except TypeError :
2304
- pass
2298
+ if annotation is not None and _lenient_issubclass (annotation , RootModel ) and annotation is not RootModel :
2299
+ annotation = cast ('type[RootModel[Any]]' , annotation )
2300
+ root_annotation = annotation .model_fields ['root' ].annotation
2301
+ if root_annotation is not None :
2302
+ annotation = root_annotation
2305
2303
2306
2304
if any (isinstance (md , Json ) for md in metadata ): # type: ignore[misc]
2307
2305
return False
2306
+
2307
+ origin = get_origin (annotation )
2308
+
2308
2309
# Check if annotation is of the form Annotated[type, metadata].
2309
- if isinstance ( annotation , _AnnotatedAlias ):
2310
+ if typing_objects . is_annotated ( origin ):
2310
2311
# Return result of recursive call on inner type.
2311
2312
inner , * meta = get_args (annotation )
2312
2313
return _annotation_is_complex (inner , meta )
2313
- origin = get_origin (annotation )
2314
2314
2315
2315
if origin is Secret :
2316
2316
return False
@@ -2324,12 +2324,12 @@ def _annotation_is_complex(annotation: type[Any] | None, metadata: list[Any]) ->
2324
2324
2325
2325
2326
2326
def _annotation_is_complex_inner (annotation : type [Any ] | None ) -> bool :
2327
- if lenient_issubclass (annotation , (str , bytes )):
2327
+ if _lenient_issubclass (annotation , (str , bytes )):
2328
2328
return False
2329
2329
2330
- return lenient_issubclass ( annotation , ( BaseModel , Mapping , Sequence , tuple , set , frozenset , deque )) or is_dataclass (
2331
- annotation
2332
- )
2330
+ return _lenient_issubclass (
2331
+ annotation , ( BaseModel , Mapping , Sequence , tuple , set , frozenset , deque )
2332
+ ) or is_dataclass ( annotation )
2333
2333
2334
2334
2335
2335
def _union_is_complex (annotation : type [Any ] | None , metadata : list [Any ]) -> bool :
@@ -2353,22 +2353,23 @@ def _annotation_contains_types(
2353
2353
2354
2354
2355
2355
def _strip_annotated (annotation : Any ) -> Any :
2356
- while get_origin (annotation ) == Annotated :
2357
- annotation = get_args (annotation )[0 ]
2358
- return annotation
2356
+ if typing_objects .is_annotated (get_origin (annotation )):
2357
+ return annotation .__origin__
2358
+ else :
2359
+ return annotation
2359
2360
2360
2361
2361
2362
def _annotation_enum_val_to_name (annotation : type [Any ] | None , value : Any ) -> Optional [str ]:
2362
2363
for type_ in (annotation , get_origin (annotation ), * get_args (annotation )):
2363
- if lenient_issubclass (type_ , Enum ):
2364
+ if _lenient_issubclass (type_ , Enum ):
2364
2365
if value in tuple (val .value for val in type_ ):
2365
2366
return type_ (value ).name
2366
2367
return None
2367
2368
2368
2369
2369
2370
def _annotation_enum_name_to_val (annotation : type [Any ] | None , name : Any ) -> Any :
2370
2371
for type_ in (annotation , get_origin (annotation ), * get_args (annotation )):
2371
- if lenient_issubclass (type_ , Enum ):
2372
+ if _lenient_issubclass (type_ , Enum ):
2372
2373
if name in tuple (val .name for val in type_ ):
2373
2374
return type_ [name ]
2374
2375
return None
0 commit comments