Skip to content

Commit e985e22

Browse files
committed
Serialization fixes.
1 parent 49ab894 commit e985e22

File tree

4 files changed

+217
-17
lines changed

4 files changed

+217
-17
lines changed

docs/index.md

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1173,6 +1173,29 @@ CliApp.run(Git, cli_args=['clone', 'repo', 'dir']).model_dump() == {
11731173

11741174
When executing a subcommand with an asynchronous cli_cmd, Pydantic settings automatically detects whether the current thread already has an active event loop. If so, the async command is run in a fresh thread to avoid conflicts. Otherwise, it uses asyncio.run() in the current thread. This handling ensures your asynchronous subcommands "just work" without additional manual setup.
11751175

1176+
### Serializing CLI Arguments
1177+
1178+
An instantiated Pydantic model can be serialized into its CLI arguments using the `CliApp.serialize` method.
1179+
1180+
```py
1181+
from pydantic import BaseModel
1182+
1183+
from pydantic_settings import CliApp
1184+
1185+
1186+
class Nested(BaseModel):
1187+
that: int
1188+
1189+
1190+
class Settings(BaseModel):
1191+
this: str
1192+
nested: Nested
1193+
1194+
1195+
print(CliApp.serialize(Settings(this='hello', nested=Nested(that=123))))
1196+
#> ['--this', 'hello', '--nested.that', '123']
1197+
```
1198+
11761199
### Mutually Exclusive Groups
11771200

11781201
CLI mutually exclusive groups can be created by inheriting from the `CliMutuallyExclusiveGroup` class.

pydantic_settings/main.py

Lines changed: 37 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,25 @@ class CliApp:
477477
CLI applications.
478478
"""
479479

480+
@staticmethod
481+
def _get_base_settings_cls(model_cls: type[Any]) -> type[BaseSettings]:
482+
if issubclass(model_cls, BaseSettings):
483+
return model_cls
484+
485+
class CliAppBaseSettings(BaseSettings, model_cls): # type: ignore
486+
__doc__ = model_cls.__doc__
487+
model_config = SettingsConfigDict(
488+
nested_model_default_partial_update=True,
489+
case_sensitive=True,
490+
cli_hide_none_type=True,
491+
cli_avoid_json=True,
492+
cli_enforce_required=True,
493+
cli_implicit_flags=True,
494+
cli_kebab_case=True,
495+
)
496+
497+
return CliAppBaseSettings
498+
480499
@staticmethod
481500
def _run_cli_cmd(model: Any, cli_cmd_method_name: str, is_required: bool) -> Any:
482501
command = getattr(type(model), cli_cmd_method_name, None)
@@ -575,22 +594,10 @@ def run(
575594
model_init_data['_cli_exit_on_error'] = cli_exit_on_error
576595
model_init_data['_cli_settings_source'] = cli_settings
577596
if not issubclass(model_cls, BaseSettings):
578-
579-
class CliAppBaseSettings(BaseSettings, model_cls): # type: ignore
580-
__doc__ = model_cls.__doc__
581-
model_config = SettingsConfigDict(
582-
nested_model_default_partial_update=True,
583-
case_sensitive=True,
584-
cli_hide_none_type=True,
585-
cli_avoid_json=True,
586-
cli_enforce_required=True,
587-
cli_implicit_flags=True,
588-
cli_kebab_case=True,
589-
)
590-
591-
model = CliAppBaseSettings(**model_init_data)
597+
base_settings_cls = CliApp._get_base_settings_cls(model_cls)
598+
model = base_settings_cls(**model_init_data)
592599
model_init_data = {}
593-
for field_name, field_info in type(model).model_fields.items():
600+
for field_name, field_info in base_settings_cls.model_fields.items():
594601
model_init_data[_field_name_for_signature(field_name, field_info)] = getattr(model, field_name)
595602

596603
return CliApp._run_cli_cmd(model_cls(**model_init_data), cli_cmd_method_name, is_required=False)
@@ -619,3 +626,18 @@ def run_subcommand(
619626

620627
subcommand = get_subcommand(model, is_required=True, cli_exit_on_error=cli_exit_on_error)
621628
return CliApp._run_cli_cmd(subcommand, cli_cmd_method_name, is_required=True)
629+
630+
@staticmethod
631+
def serialize(model: PydanticModel) -> list[str]:
632+
"""
633+
Serializes the CLI arguments for a Pydantic data model.
634+
635+
Args:
636+
model: The data model to serialize.
637+
638+
Returns:
639+
The serialized CLI arguments for the data model.
640+
"""
641+
642+
base_settings_cls = CliApp._get_base_settings_cls(type(model))
643+
return CliSettingsSource._serialized_args(model, base_settings_cls.model_config)

pydantic_settings/sources/providers/cli.py

Lines changed: 123 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
)
3636

3737
import typing_extensions
38-
from pydantic import BaseModel, Field
38+
from pydantic import AliasChoices, AliasPath, BaseModel, Field, create_model
3939
from pydantic._internal._repr import Representation
4040
from pydantic._internal._utils import is_model_class
4141
from pydantic.dataclasses import is_pydantic_dataclass
@@ -47,7 +47,15 @@
4747

4848
from ...exceptions import SettingsError
4949
from ...utils import _lenient_issubclass, _WithArgsTypes
50-
from ..types import NoDecode, _CliExplicitFlag, _CliImplicitFlag, _CliPositionalArg, _CliSubCommand, _CliUnknownArgs
50+
from ..types import (
51+
NoDecode,
52+
PydanticModel,
53+
_CliExplicitFlag,
54+
_CliImplicitFlag,
55+
_CliPositionalArg,
56+
_CliSubCommand,
57+
_CliUnknownArgs,
58+
)
5159
from ..utils import (
5260
_annotation_contains_types,
5361
_annotation_enum_val_to_name,
@@ -1084,3 +1092,116 @@ def _help_format(
10841092
def _is_field_suppressed(self, field_info: FieldInfo) -> bool:
10851093
_help = field_info.description if field_info.description else ''
10861094
return _help == CLI_SUPPRESS or CLI_SUPPRESS in field_info.metadata
1095+
1096+
@classmethod
1097+
def _update_alias_path_only_default(
1098+
cls, arg_name: str, value: Any, field_info: FieldInfo, alias_path_only_defaults: dict[str, Any]
1099+
) -> tuple[str, list[Any] | dict[str, Any]]:
1100+
alias_path: AliasPath = [
1101+
alias if isinstance(alias, AliasPath) else cast(AliasPath, alias.choices[0])
1102+
for alias in (field_info.alias, field_info.validation_alias)
1103+
if isinstance(alias, (AliasPath, AliasChoices))
1104+
][0]
1105+
1106+
alias_nested_paths: list[str] = alias_path.path[1:-1] # type: ignore
1107+
if '.' in arg_name:
1108+
alias_nested_paths = arg_name.split('.') + alias_nested_paths
1109+
arg_name = alias_nested_paths.pop(0)
1110+
1111+
if not alias_nested_paths:
1112+
alias_path_only_defaults.setdefault(arg_name, [])
1113+
alias_default = alias_path_only_defaults[arg_name]
1114+
else:
1115+
alias_path_only_defaults.setdefault(arg_name, {})
1116+
current_path = alias_path_only_defaults[arg_name]
1117+
1118+
for nested_path in alias_nested_paths[:-1]:
1119+
current_path.setdefault(nested_path, {})
1120+
current_path = current_path[nested_path]
1121+
current_path.setdefault(alias_nested_paths[-1], [])
1122+
alias_default = current_path[alias_nested_paths[-1]]
1123+
1124+
alias_path_index = cast(int, alias_path.path[-1])
1125+
alias_default.extend([''] * max(alias_path_index + 1 - len(alias_default), 0))
1126+
alias_default[alias_path_index] = value
1127+
return arg_name, alias_path_only_defaults[arg_name]
1128+
1129+
@classmethod
1130+
def _serialized_args(cls, model: PydanticModel, model_config: Any, prefix: str = '') -> list[str]:
1131+
model_field_definitions: dict[str, Any] = {}
1132+
for field_name, field_info in _get_model_fields(type(model)).items():
1133+
model_default = getattr(model, field_name)
1134+
if field_info.default == model_default:
1135+
continue
1136+
if _CliSubCommand in field_info.metadata and model_default is None:
1137+
continue
1138+
model_field_definitions[field_name] = (field_info.annotation, field_info)
1139+
cli_serialize_cls = create_model('CliSerialize', __config__=model_config, **model_field_definitions)
1140+
1141+
added_args: set[str] = set()
1142+
alias_path_args: dict[str, str] = {}
1143+
alias_path_only_defaults: dict[str, Any] = {}
1144+
optional_args: list[str | list[Any] | dict[str, Any]] = []
1145+
positional_args: list[str | list[Any] | dict[str, Any]] = []
1146+
subcommand_args: list[str] = []
1147+
cli_settings = CliSettingsSource[Any](cli_serialize_cls)
1148+
for field_name, field_info in _get_model_fields(cli_serialize_cls).items():
1149+
model_default = getattr(model, field_name)
1150+
alias_names, is_alias_path_only = _get_alias_names(
1151+
field_name, field_info, alias_path_args=alias_path_args, case_sensitive=cli_settings.case_sensitive
1152+
)
1153+
preferred_alias = alias_names[0]
1154+
if _CliSubCommand in field_info.metadata:
1155+
subcommand_args.append(preferred_alias)
1156+
subcommand_args += cls._serialized_args(model_default, model_config)
1157+
continue
1158+
if is_model_class(type(model_default)) or is_pydantic_dataclass(type(model_default)):
1159+
positional_args += cls._serialized_args(
1160+
model_default, model_config, prefix=f'{prefix}{preferred_alias}.'
1161+
)
1162+
continue
1163+
1164+
arg_name = f'{prefix}{cls._check_kebab_name(cli_settings, preferred_alias)}'
1165+
value: str | list[Any] | dict[str, Any] = (
1166+
json.dumps(model_default) if isinstance(model_default, (dict, list, set)) else str(model_default)
1167+
)
1168+
1169+
if is_alias_path_only:
1170+
# For alias path only, we wont know the complete value until we've finished parsing the entire class. In
1171+
# this case, insert value as a non-string reference pointing to the relevant alias_path_only_defaults
1172+
# entry and convert into completed string value later.
1173+
arg_name, value = cls._update_alias_path_only_default(
1174+
arg_name, value, field_info, alias_path_only_defaults
1175+
)
1176+
1177+
if arg_name in added_args:
1178+
continue
1179+
added_args.add(arg_name)
1180+
1181+
if _CliPositionalArg in field_info.metadata:
1182+
if is_alias_path_only:
1183+
positional_args.append(value)
1184+
continue
1185+
for value in model_default if isinstance(model_default, list) else [model_default]:
1186+
value = json.dumps(value) if isinstance(value, (dict, list, set)) else str(value)
1187+
positional_args.append(value)
1188+
continue
1189+
1190+
flag_chars = f'{cli_settings.cli_flag_prefix_char * min(len(arg_name), 2)}'
1191+
1192+
kwargs = {'metavar': cls._metavar_format(cli_settings, field_info.annotation)}
1193+
cls._convert_bool_flag(cli_settings, kwargs, field_info, model_default)
1194+
# Note: cls._convert_bool_flag will add action to kwargs if value is implicit bool flag
1195+
if 'action' in kwargs and model_default is False:
1196+
flag_chars += 'no-'
1197+
1198+
optional_args.append(f'{flag_chars}{arg_name}')
1199+
1200+
# If implicit bool flag, do not add a value
1201+
if 'action' not in kwargs:
1202+
optional_args.append(value)
1203+
1204+
serialized_args: list[str] = []
1205+
serialized_args += [json.dumps(value) if not isinstance(value, str) else value for value in optional_args]
1206+
serialized_args += [json.dumps(value) if not isinstance(value, str) else value for value in positional_args]
1207+
return serialized_args + subcommand_args

tests/test_source_cli.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,21 @@ class Cfg(BaseSettings, cli_avoid_json=avoid_json):
297297
'alias_str': 'str',
298298
}
299299

300+
serialized_cli_args = CliApp.serialize(cfg)
301+
assert serialized_cli_args == [
302+
'-a',
303+
'a',
304+
'--path1',
305+
'["", "b1"]',
306+
'-b',
307+
'b',
308+
'--path2',
309+
'{"deep": ["", "b2"]}',
310+
'--str',
311+
'str',
312+
]
313+
assert CliApp.run(Cfg, cli_args=serialized_cli_args).model_dump() == cfg.model_dump()
314+
300315

301316
@pytest.mark.parametrize('avoid_json', [True, False])
302317
def test_cli_alias_nested_arg(capsys, monkeypatch, avoid_json):
@@ -333,6 +348,19 @@ class Cfg(BaseSettings, cli_avoid_json=avoid_json):
333348
}
334349
}
335350

351+
serialized_cli_args = CliApp.serialize(cfg)
352+
assert serialized_cli_args == [
353+
'--nest.a',
354+
'a',
355+
'--nest',
356+
'{"path1": ["", "b1"], "path2": {"deep": ["", "b2"]}}',
357+
'--nest.b',
358+
'b',
359+
'--nest.str',
360+
'str',
361+
]
362+
assert CliApp.run(Cfg, cli_args=serialized_cli_args).model_dump() == cfg.model_dump()
363+
336364

337365
def test_cli_alias_exceptions(capsys, monkeypatch):
338366
with pytest.raises(SettingsError, match='subcommand argument BadCliSubCommand.foo has multiple aliases'):
@@ -2574,3 +2602,9 @@ class Settings(BaseSettings):
25742602
'nested': {'option': 'bar'},
25752603
'option2': 'foo2',
25762604
}
2605+
2606+
2607+
# ADD TEST KEBAB CASE
2608+
# ADD TEST SUBCOMMAND BEFORE POSITIONAL
2609+
# ADD TEST IMPLICIT BOOL FLAGS
2610+
# ADD TEST ONLY SERIALIZE NON-DEFAULT VALUES

0 commit comments

Comments
 (0)