|
35 | 35 | )
|
36 | 36 |
|
37 | 37 | import typing_extensions
|
38 |
| -from pydantic import BaseModel, Field |
| 38 | +from pydantic import AliasChoices, AliasPath, BaseModel, Field, create_model |
39 | 39 | from pydantic._internal._repr import Representation
|
40 | 40 | from pydantic._internal._utils import is_model_class
|
41 | 41 | from pydantic.dataclasses import is_pydantic_dataclass
|
|
47 | 47 |
|
48 | 48 | from ...exceptions import SettingsError
|
49 | 49 | from ...utils import _lenient_issubclass, _WithArgsTypes
|
50 |
| -from ..types import NoDecode, _CliExplicitFlag, _CliImplicitFlag, _CliPositionalArg, _CliSubCommand, _CliUnknownArgs |
| 50 | +from ..types import ( |
| 51 | + NoDecode, |
| 52 | + PydanticModel, |
| 53 | + _CliExplicitFlag, |
| 54 | + _CliImplicitFlag, |
| 55 | + _CliPositionalArg, |
| 56 | + _CliSubCommand, |
| 57 | + _CliUnknownArgs, |
| 58 | +) |
51 | 59 | from ..utils import (
|
52 | 60 | _annotation_contains_types,
|
53 | 61 | _annotation_enum_val_to_name,
|
@@ -1084,3 +1092,116 @@ def _help_format(
|
1084 | 1092 | def _is_field_suppressed(self, field_info: FieldInfo) -> bool:
|
1085 | 1093 | _help = field_info.description if field_info.description else ''
|
1086 | 1094 | return _help == CLI_SUPPRESS or CLI_SUPPRESS in field_info.metadata
|
| 1095 | + |
| 1096 | + @classmethod |
| 1097 | + def _update_alias_path_only_default( |
| 1098 | + cls, arg_name: str, value: Any, field_info: FieldInfo, alias_path_only_defaults: dict[str, Any] |
| 1099 | + ) -> tuple[str, list[Any] | dict[str, Any]]: |
| 1100 | + alias_path: AliasPath = [ |
| 1101 | + alias if isinstance(alias, AliasPath) else cast(AliasPath, alias.choices[0]) |
| 1102 | + for alias in (field_info.alias, field_info.validation_alias) |
| 1103 | + if isinstance(alias, (AliasPath, AliasChoices)) |
| 1104 | + ][0] |
| 1105 | + |
| 1106 | + alias_nested_paths: list[str] = alias_path.path[1:-1] # type: ignore |
| 1107 | + if '.' in arg_name: |
| 1108 | + alias_nested_paths = arg_name.split('.') + alias_nested_paths |
| 1109 | + arg_name = alias_nested_paths.pop(0) |
| 1110 | + |
| 1111 | + if not alias_nested_paths: |
| 1112 | + alias_path_only_defaults.setdefault(arg_name, []) |
| 1113 | + alias_default = alias_path_only_defaults[arg_name] |
| 1114 | + else: |
| 1115 | + alias_path_only_defaults.setdefault(arg_name, {}) |
| 1116 | + current_path = alias_path_only_defaults[arg_name] |
| 1117 | + |
| 1118 | + for nested_path in alias_nested_paths[:-1]: |
| 1119 | + current_path.setdefault(nested_path, {}) |
| 1120 | + current_path = current_path[nested_path] |
| 1121 | + current_path.setdefault(alias_nested_paths[-1], []) |
| 1122 | + alias_default = current_path[alias_nested_paths[-1]] |
| 1123 | + |
| 1124 | + alias_path_index = cast(int, alias_path.path[-1]) |
| 1125 | + alias_default.extend([''] * max(alias_path_index + 1 - len(alias_default), 0)) |
| 1126 | + alias_default[alias_path_index] = value |
| 1127 | + return arg_name, alias_path_only_defaults[arg_name] |
| 1128 | + |
| 1129 | + @classmethod |
| 1130 | + def _serialized_args(cls, model: PydanticModel, model_config: Any, prefix: str = '') -> list[str]: |
| 1131 | + model_field_definitions: dict[str, Any] = {} |
| 1132 | + for field_name, field_info in _get_model_fields(type(model)).items(): |
| 1133 | + model_default = getattr(model, field_name) |
| 1134 | + if field_info.default == model_default: |
| 1135 | + continue |
| 1136 | + if _CliSubCommand in field_info.metadata and model_default is None: |
| 1137 | + continue |
| 1138 | + model_field_definitions[field_name] = (field_info.annotation, field_info) |
| 1139 | + cli_serialize_cls = create_model('CliSerialize', __config__=model_config, **model_field_definitions) |
| 1140 | + |
| 1141 | + added_args: set[str] = set() |
| 1142 | + alias_path_args: dict[str, str] = {} |
| 1143 | + alias_path_only_defaults: dict[str, Any] = {} |
| 1144 | + optional_args: list[str | list[Any] | dict[str, Any]] = [] |
| 1145 | + positional_args: list[str | list[Any] | dict[str, Any]] = [] |
| 1146 | + subcommand_args: list[str] = [] |
| 1147 | + cli_settings = CliSettingsSource[Any](cli_serialize_cls) |
| 1148 | + for field_name, field_info in _get_model_fields(cli_serialize_cls).items(): |
| 1149 | + model_default = getattr(model, field_name) |
| 1150 | + alias_names, is_alias_path_only = _get_alias_names( |
| 1151 | + field_name, field_info, alias_path_args=alias_path_args, case_sensitive=cli_settings.case_sensitive |
| 1152 | + ) |
| 1153 | + preferred_alias = alias_names[0] |
| 1154 | + if _CliSubCommand in field_info.metadata: |
| 1155 | + subcommand_args.append(preferred_alias) |
| 1156 | + subcommand_args += cls._serialized_args(model_default, model_config) |
| 1157 | + continue |
| 1158 | + if is_model_class(type(model_default)) or is_pydantic_dataclass(type(model_default)): |
| 1159 | + positional_args += cls._serialized_args( |
| 1160 | + model_default, model_config, prefix=f'{prefix}{preferred_alias}.' |
| 1161 | + ) |
| 1162 | + continue |
| 1163 | + |
| 1164 | + arg_name = f'{prefix}{cls._check_kebab_name(cli_settings, preferred_alias)}' |
| 1165 | + value: str | list[Any] | dict[str, Any] = ( |
| 1166 | + json.dumps(model_default) if isinstance(model_default, (dict, list, set)) else str(model_default) |
| 1167 | + ) |
| 1168 | + |
| 1169 | + if is_alias_path_only: |
| 1170 | + # For alias path only, we wont know the complete value until we've finished parsing the entire class. In |
| 1171 | + # this case, insert value as a non-string reference pointing to the relevant alias_path_only_defaults |
| 1172 | + # entry and convert into completed string value later. |
| 1173 | + arg_name, value = cls._update_alias_path_only_default( |
| 1174 | + arg_name, value, field_info, alias_path_only_defaults |
| 1175 | + ) |
| 1176 | + |
| 1177 | + if arg_name in added_args: |
| 1178 | + continue |
| 1179 | + added_args.add(arg_name) |
| 1180 | + |
| 1181 | + if _CliPositionalArg in field_info.metadata: |
| 1182 | + if is_alias_path_only: |
| 1183 | + positional_args.append(value) |
| 1184 | + continue |
| 1185 | + for value in model_default if isinstance(model_default, list) else [model_default]: |
| 1186 | + value = json.dumps(value) if isinstance(value, (dict, list, set)) else str(value) |
| 1187 | + positional_args.append(value) |
| 1188 | + continue |
| 1189 | + |
| 1190 | + flag_chars = f'{cli_settings.cli_flag_prefix_char * min(len(arg_name), 2)}' |
| 1191 | + |
| 1192 | + kwargs = {'metavar': cls._metavar_format(cli_settings, field_info.annotation)} |
| 1193 | + cls._convert_bool_flag(cli_settings, kwargs, field_info, model_default) |
| 1194 | + # Note: cls._convert_bool_flag will add action to kwargs if value is implicit bool flag |
| 1195 | + if 'action' in kwargs and model_default is False: |
| 1196 | + flag_chars += 'no-' |
| 1197 | + |
| 1198 | + optional_args.append(f'{flag_chars}{arg_name}') |
| 1199 | + |
| 1200 | + # If implicit bool flag, do not add a value |
| 1201 | + if 'action' not in kwargs: |
| 1202 | + optional_args.append(value) |
| 1203 | + |
| 1204 | + serialized_args: list[str] = [] |
| 1205 | + serialized_args += [json.dumps(value) if not isinstance(value, str) else value for value in optional_args] |
| 1206 | + serialized_args += [json.dumps(value) if not isinstance(value, str) else value for value in positional_args] |
| 1207 | + return serialized_args + subcommand_args |
0 commit comments