Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 14d8ed9

Browse files
authoredAug 25, 2021
FactoryAggregate - non string keys (#496)
* Improve FactoryAggregate typing stub * Add implementation, typing stubs, and tests * Update changelog * Fix deepcopying * Add example * Update docs * Fix errors formatting for pypy3
1 parent 6af8181 commit 14d8ed9

File tree

9 files changed

+4105
-3804
lines changed

9 files changed

+4105
-3804
lines changed
 

‎docs/main/changelog.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ follows `Semantic versioning`_
99

1010
Development version
1111
-------------------
12+
- Add support of non-string keys for ``FactoryAggregate`` provider.
13+
- Improve ``FactoryAggregate`` typing stub.
1214
- Improve resource subclasses typing and make shutdown definition optional
1315
`PR #492 <https://github.com/ets-labs/python-dependency-injector/pull/492>`_.
1416
Thanks to `@EdwardBlair <https://github.com/EdwardBlair>`_ for suggesting the improvement.

‎docs/providers/factory.rst

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -148,13 +148,11 @@ provider with two peculiarities:
148148
Factory aggregate
149149
-----------------
150150

151-
:py:class:`FactoryAggregate` provider aggregates multiple factories. When you call the
152-
``FactoryAggregate`` it delegates the call to one of the factories.
151+
:py:class:`FactoryAggregate` provider aggregates multiple factories.
153152

154-
The aggregated factories are associated with the string names. When you call the
155-
``FactoryAggregate`` you have to provide one of the these names as a first argument.
156-
``FactoryAggregate`` looks for the factory with a matching name and delegates it the work. The
157-
rest of the arguments are passed to the delegated ``Factory``.
153+
The aggregated factories are associated with the string keys. When you call the
154+
``FactoryAggregate`` you have to provide one of the these keys as a first argument.
155+
``FactoryAggregate`` looks for the factory with a matching key and calls it with the rest of the arguments.
158156

159157
.. image:: images/factory_aggregate.png
160158
:width: 100%
@@ -165,17 +163,35 @@ rest of the arguments are passed to the delegated ``Factory``.
165163
:lines: 3-
166164
:emphasize-lines: 33-37,47
167165

168-
You can get a dictionary of the aggregated factories using the ``.factories`` attribute of the
169-
``FactoryAggregate``. To get a game factories dictionary from the previous example you can use
166+
You can get a dictionary of the aggregated factories using the ``.factories`` attribute.
167+
To get a game factories dictionary from the previous example you can use
170168
``game_factory.factories`` attribute.
171169

172170
You can also access an aggregated factory as an attribute. To create the ``Chess`` object from the
173-
previous example you can do ``chess = game_factory.chess('John', 'Jane')``.
171+
previous example you can do ``chess = game_factory.chess("John", "Jane")``.
174172

175173
.. note::
176174
You can not override the ``FactoryAggregate`` provider.
177175

178176
.. note::
179177
When you inject the ``FactoryAggregate`` provider it is passed "as is".
180178

179+
To use non-string keys or keys with ``.`` and ``-`` you can provide a dictionary as a positional argument:
180+
181+
.. code-block:: python
182+
183+
providers.FactoryAggregate({
184+
SomeClass: providers.Factory(...),
185+
"key.with.periods": providers.Factory(...),
186+
"key-with-dashes": providers.Factory(...),
187+
})
188+
189+
Example:
190+
191+
.. literalinclude:: ../../examples/providers/factory_aggregate_non_string_keys.py
192+
:language: python
193+
:lines: 3-
194+
:emphasize-lines: 30-33,39-40
195+
196+
181197
.. disqus::
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
"""`FactoryAggregate` provider with non-string keys example."""
2+
3+
from dependency_injector import containers, providers
4+
5+
6+
class Command:
7+
...
8+
9+
10+
class CommandA(Command):
11+
...
12+
13+
14+
class CommandB(Command):
15+
...
16+
17+
18+
class Handler:
19+
...
20+
21+
22+
class HandlerA(Handler):
23+
...
24+
25+
26+
class HandlerB(Handler):
27+
...
28+
29+
30+
class Container(containers.DeclarativeContainer):
31+
32+
handler_factory = providers.FactoryAggregate({
33+
CommandA: providers.Factory(HandlerA),
34+
CommandB: providers.Factory(HandlerB),
35+
})
36+
37+
38+
if __name__ == "__main__":
39+
container = Container()
40+
41+
handler_a = container.handler_factory(CommandA)
42+
handler_b = container.handler_factory(CommandB)
43+
44+
assert isinstance(handler_a, HandlerA)
45+
assert isinstance(handler_b, HandlerB)

‎src/dependency_injector/providers.c

Lines changed: 3903 additions & 3764 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

‎src/dependency_injector/providers.pxd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ cdef class FactoryDelegate(Delegate):
142142
cdef class FactoryAggregate(Provider):
143143
cdef dict __factories
144144

145-
cdef Factory __get_factory(self, str factory_name)
145+
cdef Factory __get_factory(self, object factory_name)
146146

147147

148148
# Singleton providers

‎src/dependency_injector/providers.pyi

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -282,19 +282,19 @@ class FactoryDelegate(Delegate):
282282
def __init__(self, factory: Factory): ...
283283

284284

285-
class FactoryAggregate(Provider):
286-
def __init__(self, **factories: Factory): ...
287-
def __getattr__(self, factory_name: str) -> Factory: ...
285+
class FactoryAggregate(Provider[T]):
286+
def __init__(self, dict_: Optional[_Dict[Any, Factory[T]]] = None, **factories: Factory[T]): ...
287+
def __getattr__(self, factory_name: Any) -> Factory[T]: ...
288288

289289
@overload
290-
def __call__(self, factory_name: str, *args: Injection, **kwargs: Injection) -> Any: ...
290+
def __call__(self, factory_name: Any, *args: Injection, **kwargs: Injection) -> T: ...
291291
@overload
292-
def __call__(self, factory_name: str, *args: Injection, **kwargs: Injection) -> Awaitable[Any]: ...
293-
def async_(self, factory_name: str, *args: Injection, **kwargs: Injection) -> Awaitable[Any]: ...
292+
def __call__(self, factory_name: Any, *args: Injection, **kwargs: Injection) -> Awaitable[T]: ...
293+
def async_(self, factory_name: Any, *args: Injection, **kwargs: Injection) -> Awaitable[T]: ...
294294

295295
@property
296-
def factories(self) -> _Dict[str, Factory]: ...
297-
def set_factories(self, **factories: Factory) -> FactoryAggregate: ...
296+
def factories(self) -> _Dict[Any, Factory[T]]: ...
297+
def set_factories(self, dict_: Optional[_Dict[Any, Factory[T]]] = None, **factories: Factory[T]) -> FactoryAggregate[T]: ...
298298

299299

300300
class BaseSingleton(Provider[T]):

‎src/dependency_injector/providers.pyx

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2486,10 +2486,10 @@ cdef class FactoryAggregate(Provider):
24862486

24872487
__IS_DELEGATED__ = True
24882488

2489-
def __init__(self, **factories):
2489+
def __init__(self, factories_dict_=None, **factories_kwargs):
24902490
"""Initialize provider."""
24912491
self.__factories = {}
2492-
self.set_factories(**factories)
2492+
self.set_factories(factories_dict_, **factories_kwargs)
24932493
super(FactoryAggregate, self).__init__()
24942494

24952495
def __deepcopy__(self, memo):
@@ -2499,7 +2499,7 @@ cdef class FactoryAggregate(Provider):
24992499
return copied
25002500

25012501
copied = _memorized_duplicate(self, memo)
2502-
copied.set_factories(**deepcopy(self.factories, memo))
2502+
copied.set_factories(deepcopy(self.factories, memo))
25032503

25042504
self._copy_overridings(copied, memo)
25052505

@@ -2521,13 +2521,23 @@ cdef class FactoryAggregate(Provider):
25212521
"""Return dictionary of factories, read-only."""
25222522
return self.__factories
25232523

2524-
def set_factories(self, **factories):
2524+
def set_factories(self, factories_dict_=None, **factories_kwargs):
25252525
"""Set factories."""
2526+
factories = {}
2527+
factories.update(factories_kwargs)
2528+
if factories_dict_:
2529+
factories.update(factories_dict_)
2530+
25262531
for factory in factories.values():
25272532
if isinstance(factory, Factory) is False:
25282533
raise Error(
2529-
'{0} can aggregate only instances of {1}, given - {2}'
2530-
.format(self.__class__, Factory, factory))
2534+
'{0} can aggregate only instances of {1}, given - {2}'.format(
2535+
self.__class__,
2536+
Factory,
2537+
factory,
2538+
),
2539+
)
2540+
25312541
self.__factories = factories
25322542
return self
25332543

@@ -2539,8 +2549,7 @@ cdef class FactoryAggregate(Provider):
25392549
:return: Overriding context.
25402550
:rtype: :py:class:`OverridingContext`
25412551
"""
2542-
raise Error(
2543-
'{0} providers could not be overridden'.format(self.__class__))
2552+
raise Error('{0} providers could not be overridden'.format(self.__class__))
25442553

25452554
@property
25462555
def related(self):
@@ -2561,12 +2570,10 @@ cdef class FactoryAggregate(Provider):
25612570

25622571
return self.__get_factory(factory_name)(*args, **kwargs)
25632572

2564-
cdef Factory __get_factory(self, str factory_name):
2565-
if factory_name not in self.__factories:
2566-
raise NoSuchProviderError(
2567-
'{0} does not contain factory with name {1}'.format(
2568-
self, factory_name))
2569-
return <Factory> self.__factories[factory_name]
2573+
cdef Factory __get_factory(self, object factory_key):
2574+
if factory_key not in self.__factories:
2575+
raise NoSuchProviderError('{0} does not contain factory with name {1}'.format(self, factory_key))
2576+
return <Factory> self.__factories[factory_key]
25702577

25712578

25722579
cdef class BaseSingleton(Provider):

‎tests/typing/factory.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,26 @@ def create(cls) -> Animal:
5555
provider8 = providers.FactoryDelegate(providers.Factory(object))
5656

5757
# Test 9: to check FactoryAggregate provider
58-
provider9 = providers.FactoryAggregate(
59-
a=providers.Factory(object),
60-
b=providers.Factory(object),
58+
provider9: providers.FactoryAggregate[str] = providers.FactoryAggregate(
59+
a=providers.Factory(str, "str1"),
60+
b=providers.Factory(str, "str2"),
6161
)
62-
factory_a_9: providers.Factory = provider9.a
63-
factory_b_9: providers.Factory = provider9.b
64-
val9: Any = provider9('a')
62+
factory_a_9: providers.Factory[str] = provider9.a
63+
factory_b_9: providers.Factory[str] = provider9.b
64+
val9: str = provider9('a')
65+
66+
provider9_set_non_string_keys: providers.FactoryAggregate[str] = providers.FactoryAggregate()
67+
provider9_set_non_string_keys.set_factories({Cat: providers.Factory(str, "str")})
68+
factory_set_non_string_9: providers.Factory[str] = provider9_set_non_string_keys.factories[Cat]
69+
70+
provider9_new_non_string_keys: providers.FactoryAggregate[str] = providers.FactoryAggregate(
71+
{Cat: providers.Factory(str, "str")},
72+
)
73+
factory_new_non_string_9: providers.Factory[str] = provider9_new_non_string_keys.factories[Cat]
74+
75+
provider9_no_explicit_typing = providers.FactoryAggregate(a=providers.Factory(str, "str"))
76+
provider9_no_explicit_typing_factory: providers.Factory[str] = provider9_no_explicit_typing.factories["a"]
77+
provider9_no_explicit_typing_object: str = provider9_no_explicit_typing("a")
6578

6679
# Test 10: to check the explicit typing
6780
factory10: providers.Provider[Animal] = providers.Factory(Cat)

‎tests/unit/providers/test_factories_py2_py3.py

Lines changed: 81 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import unittest
66

77
from dependency_injector import (
8+
containers,
89
providers,
910
errors,
1011
)
@@ -498,14 +499,44 @@ def setUp(self):
498499
self.example_b_factory = providers.Factory(self.ExampleB)
499500
self.factory_aggregate = providers.FactoryAggregate(
500501
example_a=self.example_a_factory,
501-
example_b=self.example_b_factory)
502+
example_b=self.example_b_factory,
503+
)
502504

503505
def test_is_provider(self):
504506
self.assertTrue(providers.is_provider(self.factory_aggregate))
505507

506508
def test_is_delegated_provider(self):
507509
self.assertTrue(providers.is_delegated(self.factory_aggregate))
508510

511+
def test_init_with_non_string_keys(self):
512+
factory = providers.FactoryAggregate({
513+
self.ExampleA: self.example_a_factory,
514+
self.ExampleB: self.example_b_factory,
515+
})
516+
517+
object_a = factory(self.ExampleA, 1, 2, init_arg3=3, init_arg4=4)
518+
object_b = factory(self.ExampleB, 11, 22, init_arg3=33, init_arg4=44)
519+
520+
self.assertIsInstance(object_a, self.ExampleA)
521+
self.assertEqual(object_a.init_arg1, 1)
522+
self.assertEqual(object_a.init_arg2, 2)
523+
self.assertEqual(object_a.init_arg3, 3)
524+
self.assertEqual(object_a.init_arg4, 4)
525+
526+
self.assertIsInstance(object_b, self.ExampleB)
527+
self.assertEqual(object_b.init_arg1, 11)
528+
self.assertEqual(object_b.init_arg2, 22)
529+
self.assertEqual(object_b.init_arg3, 33)
530+
self.assertEqual(object_b.init_arg4, 44)
531+
532+
self.assertEqual(
533+
factory.factories,
534+
{
535+
self.ExampleA: self.example_a_factory,
536+
self.ExampleB: self.example_b_factory,
537+
},
538+
)
539+
509540
def test_init_with_not_a_factory(self):
510541
with self.assertRaises(errors.Error):
511542
providers.FactoryAggregate(
@@ -528,7 +559,37 @@ def test_init_optional_factories(self):
528559
self.assertIsInstance(provider('example_a'), self.ExampleA)
529560
self.assertIsInstance(provider('example_b'), self.ExampleB)
530561

531-
def test_set_provides_returns_self(self):
562+
def test_set_factories_with_non_string_keys(self):
563+
factory = providers.FactoryAggregate()
564+
factory.set_factories({
565+
self.ExampleA: self.example_a_factory,
566+
self.ExampleB: self.example_b_factory,
567+
})
568+
569+
object_a = factory(self.ExampleA, 1, 2, init_arg3=3, init_arg4=4)
570+
object_b = factory(self.ExampleB, 11, 22, init_arg3=33, init_arg4=44)
571+
572+
self.assertIsInstance(object_a, self.ExampleA)
573+
self.assertEqual(object_a.init_arg1, 1)
574+
self.assertEqual(object_a.init_arg2, 2)
575+
self.assertEqual(object_a.init_arg3, 3)
576+
self.assertEqual(object_a.init_arg4, 4)
577+
578+
self.assertIsInstance(object_b, self.ExampleB)
579+
self.assertEqual(object_b.init_arg1, 11)
580+
self.assertEqual(object_b.init_arg2, 22)
581+
self.assertEqual(object_b.init_arg3, 33)
582+
self.assertEqual(object_b.init_arg4, 44)
583+
584+
self.assertEqual(
585+
factory.factories,
586+
{
587+
self.ExampleA: self.example_a_factory,
588+
self.ExampleB: self.example_b_factory,
589+
},
590+
)
591+
592+
def test_set_factories_returns_self(self):
532593
provider = providers.FactoryAggregate()
533594
self.assertIs(provider.set_factories(example_a=self.example_a_factory), provider)
534595

@@ -603,6 +664,24 @@ def test_deepcopy(self):
603664
self.assertIsInstance(self.factory_aggregate.example_b, type(provider_copy.example_b))
604665
self.assertIs(self.factory_aggregate.example_b.cls, provider_copy.example_b.cls)
605666

667+
def test_deepcopy_with_non_string_keys(self):
668+
factory_aggregate = providers.FactoryAggregate({
669+
self.ExampleA: self.example_a_factory,
670+
self.ExampleB: self.example_b_factory,
671+
})
672+
provider_copy = providers.deepcopy(factory_aggregate)
673+
674+
self.assertIsNot(factory_aggregate, provider_copy)
675+
self.assertIsInstance(provider_copy, type(factory_aggregate))
676+
677+
self.assertIsNot(factory_aggregate.factories[self.ExampleA], provider_copy.factories[self.ExampleA])
678+
self.assertIsInstance(factory_aggregate.factories[self.ExampleA], type(provider_copy.factories[self.ExampleA]))
679+
self.assertIs(factory_aggregate.factories[self.ExampleA].cls, provider_copy.factories[self.ExampleA].cls)
680+
681+
self.assertIsNot(factory_aggregate.factories[self.ExampleB], provider_copy.factories[self.ExampleB])
682+
self.assertIsInstance(factory_aggregate.factories[self.ExampleB], type(provider_copy.factories[self.ExampleB]))
683+
self.assertIs(factory_aggregate.factories[self.ExampleB].cls, provider_copy.factories[self.ExampleB].cls)
684+
606685
def test_repr(self):
607686
self.assertEqual(repr(self.factory_aggregate),
608687
'<dependency_injector.providers.'

0 commit comments

Comments
 (0)
Please sign in to comment.