6
6
import warnings
7
7
from argparse import Action , ArgumentParser
8
8
from contextlib import ExitStack
9
- from dataclasses import MISSING
9
+ from dataclasses import MISSING , fields , is_dataclass
10
10
from pathlib import Path
11
11
from types import SimpleNamespace
12
12
from typing import Optional , Sequence , Type , Union
15
15
import yaml
16
16
from tyro import cli
17
17
from tyro ._argparse_formatter import TyroArgumentParser
18
- from tyro ._fields import NonpropagatingMissingType
19
- # NOTE in the future versions of tyro, include that way:
20
- # from tyro._singleton import NonpropagatingMissingType
18
+ from tyro ._singleton import MISSING_NONPROP
21
19
from tyro .extras import get_parser
22
20
23
- from .auxiliary import yield_annotations , yield_defaults
21
+ from .auxiliary import yield_annotations
24
22
from .form_dict import EnvClass , MissingTagValue
25
23
from .tag import Tag
26
24
from .tag_factory import tag_factory
@@ -137,8 +135,9 @@ def run_tyro_parser(env_or_list: Type[EnvClass] | list[Type[EnvClass]],
137
135
with ExitStack () as stack :
138
136
[stack .enter_context (p ) for p in patches ] # apply just the chosen mocks
139
137
res = cli (type_form , args = args , ** kwargs )
140
- if isinstance (res , NonpropagatingMissingType ):
141
- # NOTE tyro does not work if a required positional is missing tyro.cli() returns just NonpropagatingMissingType.
138
+ if res is MISSING_NONPROP :
139
+ # NOTE tyro does not work if a required positional is missing tyro.cli()
140
+ # returns just NonpropagatingMissingType (MISSING_NONPROP).
142
141
# If this is supported, I might set other attributes like required (date, time).
143
142
# Fail if missing:
144
143
# files: Positional[list[Path]]
@@ -217,12 +216,12 @@ def set_default(kwargs, field_name, val):
217
216
setattr (kwargs ["default" ], field_name , val )
218
217
219
218
220
- def _parse_cli (env_or_list : Type [EnvClass ] | list [Type [EnvClass ]],
221
- config_file : Path | None = None ,
222
- add_verbosity = True ,
223
- ask_for_missing = True ,
224
- args = None ,
225
- ** kwargs ) -> tuple [EnvClass | None , dict , WrongFields ]:
219
+ def parse_cli (env_or_list : Type [EnvClass ] | list [Type [EnvClass ]],
220
+ config_file : Path | None = None ,
221
+ add_verbosity = True ,
222
+ ask_for_missing = True ,
223
+ args = None ,
224
+ ** kwargs ) -> tuple [EnvClass | None , dict , WrongFields ]:
226
225
""" Parse CLI arguments, possibly merged from a config file.
227
226
228
227
Args:
@@ -243,41 +242,90 @@ def _parse_cli(env_or_list: Type[EnvClass] | list[Type[EnvClass]],
243
242
# Load config file
244
243
if config_file and subcommands :
245
244
# Reading config files when using subcommands is not implemented.
246
- static = {}
247
245
kwargs ["default" ] = None
248
246
warnings .warn (f"Config file { config_file } is ignored because subcommands are used."
249
247
" It is not easy to set how this should work."
250
248
" Describe the developer your usecase so that they might implement this." )
251
- if "default" not in kwargs and not subcommands :
249
+
250
+ if "default" not in kwargs and not subcommands and config_file :
252
251
# Undocumented feature. User put a namespace into kwargs["default"]
253
252
# that already serves for defaults. We do not fetch defaults yet from a config file.
254
- disk = {}
255
- if config_file :
256
- disk = yaml .safe_load (config_file .read_text ()) or {} # empty file is ok
257
- # Nested dataclasses have to be properly initialized. YAML gave them as dicts only.
258
- for key in (key for key , val in disk .items () if isinstance (val , dict )):
259
- disk [key ] = env .__annotations__ [key ](** disk [key ])
260
-
261
- # Fill default fields
262
- if pydantic and issubclass (env , BaseModel ):
263
- # Unfortunately, pydantic needs to fill the default with the actual values,
264
- # the default value takes the precedence over the hard coded one, even if missing.
265
- static = {key : env .model_fields .get (key ).default
266
- for ann in yield_annotations (env ) for key in ann if not key .startswith ("__" ) and not key in disk }
267
- # static = {key: env_.model_fields.get(key).default
268
- # for key, _ in iterate_attributes(env_) if not key in disk}
269
- elif attr and attr .has (env ):
270
- # Unfortunately, attrs needs to fill the default with the actual values,
271
- # the default value takes the precedence over the hard coded one, even if missing.
272
- # NOTE Might not work for inherited models.
273
- static = {key : field .default
274
- for key , field in attr .fields_dict (env ).items () if not key .startswith ("__" ) and not key in disk }
275
- else :
276
- # To ensure the configuration file does not need to contain all keys, we have to fill in the missing ones.
277
- # Otherwise, tyro will spawn warnings about missing fields.
278
- static = {key : val
279
- for key , val in yield_defaults (env ) if not key .startswith ("__" ) and not key in disk }
280
- kwargs ["default" ] = SimpleNamespace (** (static | disk ))
253
+ disk = yaml .safe_load (config_file .read_text ()) or {} # empty file is ok
254
+ kwargs ["default" ] = _create_with_missing (env , disk )
281
255
282
256
# Load configuration from CLI
283
257
return run_tyro_parser (subcommands or env , kwargs , add_verbosity , ask_for_missing , args )
258
+
259
+
260
+ def _create_with_missing (env , disk : dict ):
261
+ """
262
+ Create a default instance of an Env object. This is due to provent tyro to spawn warnings about missing fields.
263
+ Nested dataclasses have to be properly initialized. YAML gave them as dicts only.
264
+ """
265
+
266
+ # Determine model
267
+ if pydantic and issubclass (env , BaseModel ):
268
+ m = _process_pydantic
269
+ elif attr and attr .has (env ):
270
+ m = _process_attr
271
+ else : # dataclass
272
+ m = _process_dataclass
273
+
274
+ # Fill default fields with the config file values or leave the defaults.
275
+ # Unfortunately, we have to fill the defaults, we cannot leave them empty
276
+ # as the default value takes the precedence over the hard coded one, even if missing.
277
+ out = {}
278
+ for name , v in m (env , disk ):
279
+ out [name ] = v
280
+ disk .pop (name , None )
281
+
282
+ # Check for unknown fields
283
+ if disk :
284
+ warnings .warn (f"Unknown fields in the configuration file: { ', ' .join (disk )} " )
285
+
286
+ # Safely initialize the model
287
+ return env (** out )
288
+
289
+
290
+ def _process_pydantic (env , disk ):
291
+ for name , f in env .model_fields .items ():
292
+ if name in disk :
293
+ if isinstance (f .default , BaseModel ):
294
+ v = _create_with_missing (f .default .__class__ , disk [name ])
295
+ else :
296
+ v = disk [name ]
297
+ elif f .default is not None :
298
+ v = f .default
299
+ yield name , v
300
+
301
+
302
+ def _process_attr (env , disk ):
303
+ for f in attr .fields (env ):
304
+ if f .name in disk :
305
+ if attr .has (f .default ):
306
+ v = _create_with_missing (f .default .__class__ , disk [f .name ])
307
+ else :
308
+ v = disk [f .name ]
309
+ elif f .default is not attr .NOTHING :
310
+ v = f .default
311
+ else :
312
+ v = MISSING_NONPROP
313
+ yield f .name , v
314
+
315
+
316
+ def _process_dataclass (env , disk ):
317
+ for f in fields (env ):
318
+ if f .name .startswith ("__" ):
319
+ continue
320
+ elif f .name in disk :
321
+ if is_dataclass (f .type ):
322
+ v = _create_with_missing (f .type , disk [f .name ])
323
+ else :
324
+ v = disk [f .name ]
325
+ elif f .default_factory is not MISSING :
326
+ v = f .default_factory ()
327
+ elif f .default is not MISSING :
328
+ v = f .default
329
+ else :
330
+ v = MISSING_NONPROP
331
+ yield f .name , v
0 commit comments