diff --git a/aaargs/__init__.py b/aaargs/__init__.py index d35ee36..acd0529 100644 --- a/aaargs/__init__.py +++ b/aaargs/__init__.py @@ -1,5 +1,6 @@ """The aaargs library to help with attribute autocompletion and argparse library""" import argparse +import dataclasses import importlib.metadata import typing @@ -27,9 +28,9 @@ class ArgumentParser(zninit.ZnInit): def __init_subclass__(cls, **kwargs): """Allow adding arguments through subclass creation""" super().__init_subclass__() - for key in kwargs: + for key, value in kwargs.items(): if key in dir(cls): - setattr(cls, key, kwargs[key]) + setattr(cls, key, value) else: raise AttributeError(f"Class {cls} has no attribute '{key}'.") return cls @@ -56,7 +57,7 @@ def get_parser(cls) -> argparse.ArgumentParser: arguments: typing.List[Argument] = cls._get_descriptors() for argument in arguments: - parser.add_argument(*argument.name_or_flags, **argument.kwargs) + parser.add_argument(*argument.name_or_flags, **argument.options.get_dict()) return parser @@ -92,6 +93,28 @@ def parse_args(cls, args=None, namespace=None): ) from err +@dataclasses.dataclass +class _ArgumentOptions: + action: typing.Any + choices: typing.Any + const: typing.Any + default: typing.Any + dest: typing.Any + help: typing.Any + metavar: typing.Any + nargs: typing.Any + required: bool + type: typing.Any + + def get_dict(self) -> dict: + """Get a dict of all value pairs that are not None""" + return { + key.name: getattr(self, key.name) + for key in dataclasses.fields(self) + if getattr(self, key.name) is not None + } + + class Argument(zninit.Descriptor): """An argparse argument.""" @@ -120,55 +143,46 @@ def __init__( argument, if no name_or_flags are provided. """ - if not required and default is zninit.Empty: + if required: + if default in (zninit.Empty, None): + default = zninit.Empty + else: + raise TypeError( + "When using 'required=True' the argument 'default' must be None" + ) + elif default is zninit.Empty: default = None super().__init__(default=default) self.name_or_flags = name_or_flags self.positional = positional - self.kwargs = {} - - if action is not None: - self.kwargs["action"] = action - if choices is not None: - self.kwargs["choices"] = choices - if const is not None: - self.kwargs["const"] = const - if default is not None: - self.kwargs["default"] = default - if dest is not None: - self.kwargs["dest"] = dest - if help is not None: - self.kwargs["help"] = help - if metavar is not None: - self.kwargs["metavar"] = metavar - if nargs is not None: - self.kwargs["nargs"] = nargs - if required is not None: - self.kwargs["required"] = required - if type is not None: - self.kwargs["type"] = type + self.options = _ArgumentOptions( + action=action, + choices=choices, + const=const, + default=default, + dest=dest, + help=help, + metavar=metavar, + nargs=nargs, + required=required, + type=type, + ) - def __get__(self, instance, owner=None): - """Get method of the descriptor + self._check_input() - This class is used to set the name and allows for the special case: + def _check_input(self): + if self.options.required and self.positional: + raise TypeError("'required' is an invalid argument for positionals`") - >>> class MyArgs(ArgumentParser): - >>> filename = Argument() - >>> verbose: bool = Argument() - - which will define a positional argument without defining 'name_or_flags'. - When using 'positional=False' it will be converted to a keyword only argument. - Futhermore, it allows for boolean arguments without defining 'positional=False' - or 'action=store_true' explicitly. + @property + def _is_boolean(self) -> bool: + """Check type annotations if Argument is defined as boolean""" + return self.owner.__annotations__.get(self.name) in ["bool", bool] - """ - if ( - self.owner.__annotations__.get(self.name) in ["bool", bool] - and self.kwargs.get("action") is None - ): - self.kwargs["action"] = "store_true" + def _handle_boolean_annotation(self): + if self._is_boolean and self.options.action is None: + self.options.action = "store_true" if len(self.name_or_flags) == 0: if self.positional: raise TypeError( @@ -182,7 +196,27 @@ def __get__(self, instance, owner=None): ) self.name_or_flags = (f"--{self.name}",) + def __get__(self, instance, owner=None): + """Get method of the descriptor + + This class is used to set the name and allows for the special case: + + >>> class MyArgs(ArgumentParser): + >>> filename = Argument() + >>> verbose: bool = Argument() + + which will define a positional argument without defining 'name_or_flags'. + When using 'positional=False' it will be converted to a keyword only argument. + Futhermore, it allows for boolean arguments without defining 'positional=False' + or 'action=store_true' explicitly. + + """ + self._handle_boolean_annotation() + if len(self.name_or_flags) == 0: self.name_or_flags = (self.name if self.positional else f"--{self.name}",) + if self._is_boolean and self.default in (None, zninit.Empty): + self._default = False + return super().__get__(instance, owner) diff --git a/tests/test_aaargs.py b/tests/test_aaargs.py index 737ec7f..efcffdf 100644 --- a/tests/test_aaargs.py +++ b/tests/test_aaargs.py @@ -45,6 +45,7 @@ class Parser(ArgumentParser): args = Parser.parse_args(["myfile"]) assert args.filename == "myfile" + assert Parser(filename="myfile").filename == "myfile" class Parser(ArgumentParser): description = "Lorem Ipsum" @@ -54,6 +55,9 @@ class Parser(ArgumentParser): args = Parser.parse_args(["myfile", "-e", "utf-8"]) assert args.filename == "myfile" assert args.encoding == "utf-8" + args = Parser(filename="myfile", encoding="utf-8") + assert args.filename == "myfile" + assert args.encoding == "utf-8" args = Parser.parse_args(["myfile", "--encoding", "utf-8"]) assert args.filename == "myfile" @@ -67,6 +71,9 @@ class Parser(ArgumentParser): args = Parser.parse_args(["myfile", "-e", "utf-8"]) assert args.filename == "myfile" assert args.e == "utf-8" + args = Parser(filename="myfile", e="utf-8") + assert args.filename == "myfile" + assert args.e == "utf-8" def test_parse_args_with_defaults(): @@ -77,6 +84,8 @@ class Parser(ArgumentParser): args = Parser.parse_args("") assert args.filename == "myfile.txt" + assert Parser().filename == "myfile.txt" + def test_args_positional(): class Parser(ArgumentParser): @@ -162,13 +171,16 @@ class Parser(ArgumentParser): verbose: annotation = Argument("--verbose") parser = Parser.parse_args(["someone", "--verbose"]) - assert parser.verbose + assert parser.verbose is True assert parser.name == "someone" parser = Parser.parse_args(["someone"]) - assert not parser.verbose + assert parser.verbose is False assert parser.name == "someone" + assert Parser().verbose is False + assert Parser(verbose=True).verbose is True + class Parser(ArgumentParser): name: str = Argument(positional=True) verbose: annotation = Argument(default=True) @@ -186,6 +198,11 @@ class Parser(ArgumentParser): def test_required(): + with pytest.raises(TypeError): + # required is invalid for positionals + class Parser(ArgumentParser): + name: str = Argument(positional=True, required=True) + class Parser(ArgumentParser): name: str = Argument(required=False, default=None) @@ -202,4 +219,4 @@ class Parser(ArgumentParser): name: str = Argument(required=True) with pytest.raises(TypeError): - parser = Parser() + _ = Parser()