Skip to content

Commit 67d4940

Browse files
committed
Add decorator to parse function type hints
1 parent ed44d8b commit 67d4940

File tree

3 files changed

+120
-0
lines changed

3 files changed

+120
-0
lines changed

docs/guide.md

+48
Original file line numberDiff line numberDiff line change
@@ -692,6 +692,54 @@ flag (as in `--obj=True`), or by making sure there's another flag after any
692692
boolean flag argument.
693693

694694

695+
#### Type hints
696+
697+
Fire can be configured to use type hints information by decorating functions with `UseTypeHints()` decorator.
698+
Only `int`, `float` and `str` type hints are respected by default, everything else is ignored (parsed as usual).
699+
Quite common usecase is to instruct fire not to convert strings to integer/floats by supplying `str`
700+
type annotation.
701+
702+
See minimal example below:
703+
704+
```python
705+
import fire
706+
707+
from fire.decorators import UseTypeHints
708+
709+
710+
@UseTypeHints() # () are mandatory here
711+
def main(a: str, b: float):
712+
print(type(a), type(b), type(c), type(d))
713+
714+
715+
if __name__ == "__main__":
716+
fire.Fire(main)
717+
```
718+
719+
When invoked with `python command.py 1 2` this code will produce `str float`.
720+
721+
You can pass custom type hints parsers via decorator argument, following example shows how to parse custom lists:
722+
723+
```python
724+
import fire
725+
726+
from fire.decorators import UseTypeHints
727+
728+
729+
@UseTypeHints({list: lambda arg: [float(x) for x in arg.split(";")]})
730+
def main(a: list, b: str):
731+
print(a)
732+
733+
734+
if __name__ == "__main__":
735+
fire.Fire(main)
736+
```
737+
738+
This code will convert argument `1;2;3;4` argument into `[1.0, 2.0, 3.0, 4.0]` list with floats.
739+
To override default behavior for `int`, `str`, and `float` type hints you need to add them into dictionary supplied to
740+
`UseTypeHints` decorator.
741+
742+
695743
### Using Fire Flags
696744

697745
Fire CLIs all come with a number of flags. These flags should be separated from

fire/decorators.py

+40
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,46 @@
2929
ACCEPTS_POSITIONAL_ARGS = 'ACCEPTS_POSITIONAL_ARGS'
3030

3131

32+
def UseTypeHints(type_hints_mapping=None):
33+
"""Instruct fire to use type hints information when parsing args for this
34+
function.
35+
36+
Args:
37+
type_hints_mapping: mapping of type hints into parsing functions, by
38+
default floats, ints and strings are treated, and all other type
39+
hints are ignored (parsed as usual)
40+
Returns:
41+
The decorated function, which now has metadata telling Fire how to perform
42+
according to type hints.
43+
44+
Examples:
45+
@UseTypeHints()
46+
def main(a, b:int, c:float=2.0)
47+
assert isinstance(b, int)
48+
assert isinstance(c, float)
49+
50+
@UseTypeHints({list: lambda s: s.split(";")})
51+
def main(a, c: list):
52+
assert isinstance(c, list)
53+
"""
54+
default_type_hints_mapping = {float: float, int: int, str: str}
55+
if type_hints_mapping is None:
56+
type_hints_mapping = {}
57+
type_hints_mapping.update(default_type_hints_mapping)
58+
59+
def _Decorator(fn):
60+
signature = inspect.signature(fn)
61+
named = {}
62+
for name, param in signature.parameters.items():
63+
has_type_hint = param.annotation is not param.empty
64+
if has_type_hint and param.annotation in type_hints_mapping:
65+
named[name] = type_hints_mapping[param.annotation]
66+
decorator = SetParseFns(**named)
67+
decorated_func = decorator(fn)
68+
return decorated_func
69+
return _Decorator
70+
71+
3272
def SetParseFn(fn, *arguments):
3373
"""Sets the fn for Fire to use to parse args when calling the decorated fn.
3474

fire/decorators_test.py

+32
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
from __future__ import absolute_import
1818
from __future__ import division
1919
from __future__ import print_function
20+
import sys
21+
import unittest
2022

2123
from fire import core
2224
from fire import decorators
@@ -90,6 +92,18 @@ def example7(self, arg1, arg2=None, *varargs, **kwargs): # pylint: disable=keyw
9092
return arg1, arg2, varargs, kwargs
9193

9294

95+
if sys.version_info >= (3, 5):
96+
class WithTypeHints(object):
97+
98+
@decorators.UseTypeHints()
99+
def example8(self, a: int, b: str, c, d : float = None):
100+
return a, b, c, d
101+
102+
@decorators.UseTypeHints({list: lambda arg: list(map(int, arg.split(";")))})
103+
def example9(self, a: str, b, c: list, d : list = None):
104+
return a, b, c, d
105+
106+
93107
class FireDecoratorsTest(testutils.BaseTestCase):
94108

95109
def testSetParseFnsNamedArgs(self):
@@ -169,6 +183,24 @@ def testSetParseFn(self):
169183
command=['example7', '1', '--arg2=2', '3', '4', '--kwarg=5']),
170184
('1', '2', ('3', '4'), {'kwarg': '5'}))
171185

186+
@unittest.skipIf(sys.version_info < (3, 5),
187+
'Type hints were introduced in python 3.5')
188+
def testDefaultTypeHints(self):
189+
self.assertEqual(
190+
core.Fire(WithTypeHints,
191+
command=['example8', '1', '2', '3', '--d=4']),
192+
(1, '2', 3, 4)
193+
)
194+
195+
@unittest.skipIf(sys.version_info < (3, 5),
196+
'Type hints were introduced in python 3.5')
197+
def testCustomTypeHints(self):
198+
self.assertEqual(
199+
core.Fire(WithTypeHints,
200+
command=['example9', '1', '2', '3', '--d=4;5;6']),
201+
('1', 2, [3], [4, 5, 6])
202+
)
203+
172204

173205
if __name__ == '__main__':
174206
testutils.main()

0 commit comments

Comments
 (0)