Skip to content

Commit 60d40db

Browse files
committed
Inject
1 parent 1807d2a commit 60d40db

File tree

7 files changed

+215
-80
lines changed

7 files changed

+215
-80
lines changed

README.md

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,29 @@ You can read values or passwords from files, by using the template
122122
or, more securely, read the file in the code
123123
`alphaconf.get('secret_file', Path).read_text().strip()`.
124124

125+
### Inject parameters
126+
127+
We can inject default values to functions from the configuration.
128+
Either one by one, where we can map a factory function or a configuration key.
129+
Or inject all automatically base on the parameter name.
130+
131+
```python
132+
from alphaconf.inject import inject, inject_auto
133+
134+
@inject('name', 'application.name')
135+
@inject_auto(ignore={'name'})
136+
def main(name: str, example=None):
137+
pass
138+
139+
# similar to
140+
def main(name: str=None, example=None):
141+
if name is None:
142+
name = alphaconf.get('application.name', str)
143+
if example is None:
144+
example = alphaconf.get('example', default=example)
145+
...
146+
```
147+
125148
### Invoke integration
126149

127150
Just add the lines below to parameterize invoke.

alphaconf/inject.py

Lines changed: 88 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,60 +1,92 @@
1+
import functools
12
import inspect
2-
from dataclasses import dataclass
3-
from typing import Callable, Dict, TypeVar
3+
from typing import Any, Callable, Dict, Optional, Union
44

55
import alphaconf
66

7-
R = TypeVar('R')
8-
9-
10-
@dataclass
11-
class InjectArgument:
12-
name: str
13-
verify: bool = False
14-
# rtype: type = None # TODO add type transformer
15-
16-
def get_value(self, type_spec, required):
17-
get_args: dict = {'key': self.name}
18-
if type_spec:
19-
get_args['type'] = type_spec
20-
if not required:
21-
get_args['default'] = None
22-
value = alphaconf.get(**get_args)
23-
return value
24-
25-
26-
class Injector:
27-
args: Dict[str, InjectArgument]
28-
prefix: str
29-
30-
def __init__(self, prefix: str = ""):
31-
if prefix and not prefix.endswith("."):
32-
prefix += "."
33-
self.prefix = prefix
34-
self.args = {}
35-
36-
def inject(self, name: str, optional, type, resolver):
37-
pass
38-
39-
def decorate(self, func: Callable[..., R]) -> Callable[[], R]:
40-
signature = inspect.signature(func)
41-
42-
def call():
43-
args = {} # TODO {**self.values}
44-
for name, iarg in self.args.items():
45-
param = signature.parameters.get(name, None)
46-
if not param:
47-
if iarg.verify:
48-
raise TypeError("Missing argument", name)
49-
continue
50-
arg_type = None
51-
if param.annotation is not param.empty and isinstance(param.annotation, type):
52-
arg_type = param.annotation
53-
required = param.default is param.empty
54-
value = iarg.get_value(arg_type, required)
55-
if value is None and not required:
56-
continue
57-
args[name] = value
58-
return func(**args)
59-
60-
return call
7+
from .internal.type_resolvers import type_from_annotation
8+
9+
__all__ = ["inject", "inject_auto"]
10+
11+
12+
class ParamDefaultsFunction:
13+
"""Function wrapper that injects default parameters"""
14+
15+
_arg_factory: Dict[str, Callable[[], Any]]
16+
17+
def __init__(self, func: Callable):
18+
self.func = func
19+
self.signature = inspect.signature(func)
20+
self._arg_factory = {}
21+
22+
def bind(self, name: str, factory: Callable[[], Any]):
23+
self._arg_factory[name] = factory
24+
25+
def __call__(self, *a, **kw):
26+
args = self.signature.bind_partial(*a, **kw).arguments
27+
kw.update(
28+
{name: factory() for name, factory in self._arg_factory.items() if name not in args}
29+
)
30+
return self.func(*a, **kw)
31+
32+
@staticmethod
33+
def wrap(func) -> "ParamDefaultsFunction":
34+
if isinstance(func, ParamDefaultsFunction):
35+
return func
36+
return functools.wraps(func)(ParamDefaultsFunction(func))
37+
38+
39+
def getter(
40+
key: str, ktype: Optional[type] = None, *, param: Optional[inspect.Parameter] = None
41+
) -> Callable[[], Any]:
42+
"""Factory function that calls alphaconf.get
43+
44+
The parameter from the signature can be given to extract the type to cast to
45+
and whether the configuration value is optional.
46+
47+
:param key: The key using in alphaconf.get
48+
:param ktype: Type to cast to
49+
:param param: The parameter object from the signature
50+
"""
51+
if ktype is None and param and (ptype := param.annotation) is not param.empty:
52+
ktype = next(type_from_annotation(ptype), None)
53+
if param is not None and param.default is not param.empty:
54+
xparam = param
55+
return (
56+
lambda: xparam.default
57+
if (value := alphaconf.get(key, ktype, default=None)) is None
58+
and xparam.default is not xparam.empty
59+
else value
60+
)
61+
return lambda: alphaconf.get(key, ktype)
62+
63+
64+
def inject(name: str, factory: Union[None, str, Callable[[], Any]]):
65+
"""Inject an argument to a function from a factory or alphaconf"""
66+
67+
def do_inject(func):
68+
f = ParamDefaultsFunction.wrap(func)
69+
if isinstance(factory, str) or factory is None:
70+
b = getter(factory or name, param=f.signature.parameters[name])
71+
else:
72+
b = factory
73+
f.bind(name, b)
74+
return f
75+
76+
return do_inject
77+
78+
79+
def inject_auto(*, prefix: str = "", ignore: set = set()):
80+
"""Inject automatically all paramters"""
81+
if prefix and not prefix.endswith("."):
82+
prefix += "."
83+
84+
def do_inject(func):
85+
f = ParamDefaultsFunction.wrap(func)
86+
for name, param in f.signature.parameters.items():
87+
if name in ignore:
88+
continue
89+
f.bind(name, getter(prefix + name, param=param))
90+
return f
91+
92+
return do_inject

alphaconf/internal/configuration.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import copy
22
import os
3-
import typing
43
import warnings
54
from enum import Enum
65
from typing import (
@@ -19,7 +18,7 @@
1918

2019
from omegaconf import Container, DictConfig, OmegaConf
2120

22-
from .type_resolvers import convert_to_type, pydantic
21+
from .type_resolvers import convert_to_type, pydantic, type_from_annotation
2322

2423
T = TypeVar('T')
2524

@@ -92,14 +91,14 @@ def get(self, key: Union[str, Type], type=None, *, default=raise_on_missing):
9291
)
9392
if value is raise_on_missing:
9493
if default is raise_on_missing:
95-
raise ValueError(f"No value for: {key}")
94+
raise KeyError(f"No value for: {key}")
9695
return default
9796
# check the returned type and convert when necessary
9897
if type is not None and isinstance(value, type):
9998
return value
10099
if isinstance(value, Container):
101100
value = OmegaConf.to_object(value)
102-
if type is not None and default is not None:
101+
if type is not None and value is not default:
103102
value = convert_to_type(value, type)
104103
return value
105104

@@ -110,12 +109,12 @@ def __get_type(self, key: Type, *, default=raise_on_missing):
110109
key_str = self.__type_path.get(key)
111110
if key_str is None:
112111
if default is raise_on_missing:
113-
raise ValueError(f"Key not found for type {key}")
112+
raise KeyError(f"Key not found for type {key}")
114113
return default
115114
try:
116115
value = self.get(key_str, key)
117116
self.__type_value = value
118-
except ValueError:
117+
except KeyError:
119118
if default is raise_on_missing:
120119
raise
121120
value = default
@@ -161,7 +160,7 @@ def setup_configuration(
161160
else:
162161
created_config = self.__prepare_config(conf, path=prefix)
163162
if not isinstance(created_config, DictConfig):
164-
raise ValueError("Failed to convert to a DictConfig")
163+
raise TypeError("Failed to convert to a DictConfig")
165164
config = created_config
166165
# add prefix and merge
167166
if prefix:
@@ -221,7 +220,7 @@ def __prepare_dictconfig(
221220
sub_configs = []
222221
for k, v in obj.items_ex(resolve=False):
223222
if not isinstance(k, str):
224-
raise ValueError("Expecting only str instances in dict")
223+
raise TypeError("Expecting only str instances in dict")
225224
if recursive:
226225
v = self.__prepare_config(v, path + k + ".")
227226
if '.' in k:
@@ -252,9 +251,6 @@ def __prepare_pydantic(self, obj, path):
252251
# pydantic instance, prepare helpers
253252
self.__prepare_pydantic(type(obj), path)
254253
return obj.model_dump(mode="json")
255-
# parse typing recursively for documentation
256-
for t in typing.get_args(obj):
257-
self.__prepare_pydantic(t, path)
258254
# check if not a type
259255
if not isinstance(obj, type):
260256
return obj
@@ -279,8 +275,9 @@ def __prepare_pydantic(self, obj, path):
279275
from alphaconf import SECRET_MASKS
280276

281277
SECRET_MASKS.append(lambda s: s == path)
282-
elif check_type and field.annotation:
283-
self.__prepare_pydantic(field.annotation, path + k + ".")
278+
elif check_type:
279+
for ftype in type_from_annotation(field.annotation):
280+
self.__prepare_pydantic(ftype, path + k + ".")
284281
return defaults
285282
return None
286283

alphaconf/internal/type_resolvers.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import datetime
2+
import typing
23
from pathlib import Path
34

45
from omegaconf import OmegaConf
@@ -61,3 +62,12 @@ def convert_to_type(value, type):
6162
if pydantic:
6263
return pydantic.TypeAdapter(type).validate_python(value)
6364
return type(value)
65+
66+
67+
def type_from_annotation(annotation) -> typing.Generator[type, None, None]:
68+
"""Given an annotation (optional), figure out the types"""
69+
if isinstance(annotation, type) and annotation is not type(None):
70+
yield annotation
71+
else:
72+
for t in typing.get_args(annotation):
73+
yield from type_from_annotation(t)

tests/test_alphaconf.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def test_setup_configuration():
9090

9191

9292
def test_setup_configuration_invalid():
93-
with pytest.raises(ValueError):
93+
with pytest.raises(TypeError):
9494
# invalid configuration (must be non-empty)
9595
alphaconf.setup_configuration(None)
9696

@@ -132,8 +132,8 @@ def test_app_environ(application):
132132
)
133133
application.setup_configuration(load_dotenv=False, env_prefixes=True)
134134
config = application.configuration
135-
with pytest.raises(ValueError):
136-
# XXX should not be loaded
135+
with pytest.raises(KeyError):
136+
# prefix with underscore only should be loaded
137137
config.get('xxx')
138138
assert config.get('testmyenv.x') == 'overwrite'
139139
assert config.get('testmyenv.y') == 'new'

tests/test_configuration.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def test_cast(config):
4848
# cast bool into int
4949
assert config.get('b') is True
5050
assert config.get('b', int) == 1
51+
assert config.get('b', int, default=None) == 1
5152
# cast Path
5253
assert isinstance(config.get('home', Path), Path)
5354

@@ -71,7 +72,7 @@ def test_select_empty(config):
7172

7273
def test_select_required(config):
7374
assert config.get('z', default=None) is None
74-
with pytest.raises(ValueError):
75+
with pytest.raises(KeyError):
7576
print(config.get('z'))
7677
assert config.get('z', default='a') == 'a'
7778

@@ -80,7 +81,7 @@ def test_select_required_incomplete(config_req):
8081
# when we have a default, return it
8182
assert config_req.get('req', default='def') == 'def'
8283
# when required, raise missing
83-
with pytest.raises(ValueError):
84+
with pytest.raises(KeyError):
8485
print(config_req.get('req'))
8586

8687

0 commit comments

Comments
 (0)