Skip to content

Commit 2993d12

Browse files
committed
feat: different dispatch instances on same function
Signed-off-by: nstarman <[email protected]>
1 parent 70212b7 commit 2993d12

File tree

2 files changed

+225
-4
lines changed

2 files changed

+225
-4
lines changed

plum/dispatcher.py

+187-4
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import sys
2-
from dataclasses import dataclass, field
2+
from abc import ABCMeta, abstractmethod
3+
from dataclasses import dataclass, field, replace
34
from functools import partial
4-
from typing import Any, Dict, Optional, Tuple, TypeVar, Union, overload
5+
from itertools import chain
6+
from typing import Any, Dict, Optional, Tuple, TypeVar, Union, final, overload
57

68
from .function import Function
79
from .overload import get_overloads
@@ -19,7 +21,39 @@
1921

2022

2123
@dataclass(frozen=True, **_dataclass_kw_args)
22-
class Dispatcher:
24+
class AbstractDispatcher(metaclass=ABCMeta):
25+
"""An abstract dispatcher."""
26+
27+
@overload
28+
def __call__(self, method: T, precedence: int = ...) -> T: ...
29+
30+
@overload
31+
def __call__(self, method: None, precedence: int) -> Callable[[T], T]: ...
32+
33+
@abstractmethod
34+
def __call__(
35+
self, method: Optional[T] = None, precedence: int = 0
36+
) -> Union[T, Callable[[T], T]]: ...
37+
38+
@abstractmethod
39+
def abstract(self, method: Callable) -> Function:
40+
"""Decorator for an abstract function definition. The abstract function
41+
definition does not implement any methods."""
42+
43+
@abstractmethod
44+
def multi(
45+
self, *signatures: Union[Signature, Tuple[TypeHint, ...]]
46+
) -> Callable[[Callable], Function]:
47+
"""Decorator to register multiple signatures at once."""
48+
49+
@abstractmethod
50+
def clear_cache(self) -> None:
51+
"""Clear cache."""
52+
53+
54+
@final
55+
@dataclass(frozen=True, **_dataclass_kw_args)
56+
class Dispatcher(AbstractDispatcher):
2357
"""A namespace for functions.
2458
2559
Args:
@@ -140,11 +174,160 @@ def _add_method(
140174
f.register(method, signature, precedence)
141175
return f
142176

143-
def clear_cache(self):
177+
def clear_cache(self) -> None:
144178
"""Clear cache."""
145179
for f in self.functions.values():
146180
f.clear_cache()
147181

182+
def __or__(self, other: "AbstractDispatcher") -> "DispatcherBundle":
183+
if not isinstance(other, AbstractDispatcher):
184+
raise ValueError(f"Cannot combine `Dispatcher` with `{type(other)}`.")
185+
return DispatcherBundle.from_dispatchers(self, other)
186+
187+
188+
@final
189+
@dataclass(frozen=True, **_dataclass_kw_args)
190+
class DispatcherBundle(AbstractDispatcher):
191+
"""A bundle of dispatchers.
192+
193+
Examples
194+
--------
195+
>>> from plum import Dispatcher, DispatcherBundle
196+
197+
>>> dispatch1 = Dispatcher()
198+
>>> dispatch2 = Dispatcher()
199+
200+
>>> dispatchbundle = dispatch1 | dispatch2
201+
202+
Some Notes:
203+
204+
At least one dispatcher must be provided to `DispatcherBundle`.
205+
206+
>>> try:
207+
... DispatcherBundle()
208+
... except ValueError as e:
209+
... print(e)
210+
At least one dispatcher must be provided to DispatcherBundle.
211+
212+
213+
A `DispatcherBundle` can be created from a sequence of dispatchers.
214+
215+
>>> dispatchbundle = DispatcherBundle.from_dispatchers(dispatch1, dispatch2)
216+
217+
A nested `DispatcherBundle` can be flattened.
218+
219+
>>> dispatch3 = Dispatcher()
220+
>>> dispatchbundle = DispatcherBundle.from_dispatchers(dispatchbundle, dispatch3)
221+
>>> dispatchbundle
222+
DispatcherBundle(dispatchers=(
223+
<DispatcherBundle(dispatchers=(
224+
<Dispatcher functions={}, classes={} warn_redefinition=False>,
225+
<Dispatcher functions={}, classes={} warn_redefinition=False>))>,
226+
<Dispatcher functions={}, classes={} warn_redefinition=False>)
227+
)
228+
229+
230+
>>> dispatchbundle = dispatchbundle.flatten()
231+
DispatcherBundle(dispatchers=(
232+
<Dispatcher functions={}, classes={} warn_redefinition=False>,
233+
<Dispatcher functions={}, classes={} warn_redefinition=False>,
234+
<Dispatcher functions={}, classes={} warn_redefinition=False>)
235+
)
236+
237+
DispatchBundles can be combined with `|`.
238+
239+
>>> dispatch4 = Dispatcher()
240+
>>> dispatchbundle1 = dispatch1 | dispatch2
241+
>>> dispatchbundle2 = dispatch3 | dispatch4
242+
>>> dispatchbundle = dispatchbundle1 | dispatchbundle2
243+
>>> dispatchbundle
244+
DispatcherBundle(dispatchers=(
245+
<DispatcherBundle(dispatchers=(
246+
<Dispatcher functions={}, classes={} warn_redefinition=False>,
247+
<Dispatcher functions={}, classes={} warn_redefinition=False>))>,
248+
<DispatcherBundle(dispatchers=(
249+
<Dispatcher functions={}, classes={} warn_redefinition=False>,
250+
<Dispatcher functions={}, classes={} warn_redefinition=False>))>)
251+
)
252+
253+
"""
254+
255+
dispatchers: Tuple[AbstractDispatcher, ...]
256+
257+
def __post_init__(self) -> None:
258+
if not self.dispatchers:
259+
msg = "At least one dispatcher must be provided to DispatcherBundle."
260+
raise ValueError(msg)
261+
262+
@classmethod
263+
def from_dispatchers(cls, *dispatchers: AbstractDispatcher) -> "DispatcherBundle":
264+
"""Create a `DispatcherBundle` from a sequence of dispatchers.
265+
266+
This also flattens nested `DispatcherBundle`s.
267+
"""
268+
269+
return cls(dispatchers).flatten()
270+
271+
def flatten(self) -> "DispatcherBundle":
272+
"""Flatten the bundle."""
273+
274+
def as_seq(x: AbstractDispatcher) -> Tuple[AbstractDispatcher, ...]:
275+
return x.dispatchers if isinstance(x, DispatcherBundle) else (x,)
276+
277+
return replace(
278+
self, dispatchers=tuple(chain.from_iterable(map(as_seq, self.dispatchers)))
279+
)
280+
281+
@overload
282+
def __call__(self, method: T, precedence: int = ...) -> T: ...
283+
284+
@overload
285+
def __call__(self, method: None, precedence: int) -> Callable[[T], T]: ...
286+
287+
def __call__(
288+
self, method: Optional[T] = None, precedence: int = 0
289+
) -> Union[T, Callable[[T], T]]:
290+
for dispatcher in self.dispatchers:
291+
f = dispatcher(method, precedence=precedence)
292+
return f
293+
294+
def abstract(self, method: Callable) -> Function:
295+
"""Decorator for an abstract function definition. The abstract function
296+
definition does not implement any methods."""
297+
for dispatcher in self.dispatchers:
298+
f = dispatcher.abstract(method)
299+
return f
300+
301+
def multi(
302+
self, *signatures: Union[Signature, Tuple[TypeHint, ...]]
303+
) -> Callable[[Callable], Function]:
304+
"""Decorator to register multiple signatures at once.
305+
306+
Args:
307+
*signatures (tuple or :class:`.signature.Signature`): Signatures to
308+
register.
309+
310+
Returns:
311+
function: Decorator.
312+
"""
313+
314+
def decorator(method: Callable) -> Function:
315+
for dispatcher in self.dispatchers:
316+
f = dispatcher.multi(*signatures)(method)
317+
return f
318+
319+
return decorator
320+
321+
def clear_cache(self) -> None:
322+
"""Clear cache."""
323+
for dispatcher in self.dispatchers:
324+
dispatcher.clear_cache()
325+
326+
def __or__(self, other: "AbstractDispatcher") -> "DispatcherBundle":
327+
if not isinstance(other, AbstractDispatcher):
328+
return NotImplemented
329+
return self.from_dispatchers(self, other)
330+
148331

149332
def clear_all_cache():
150333
"""Clear all cache, including the cache of subclass checks. This should be called

tests/test_dispatcher.py

+38
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
from numbers import Number
2+
from types import SimpleNamespace
3+
14
import pytest
25

36
from plum import Dispatcher
@@ -70,3 +73,38 @@ def f(x):
7073

7174
assert f.__doc__ == "Docs"
7275
assert f.methods == []
76+
77+
78+
def test_multiple_dispatchers_on_same_function():
79+
dispatch1 = Dispatcher()
80+
dispatch2 = Dispatcher()
81+
82+
@dispatch1.abstract
83+
def f(x: Number, y: Number):
84+
return x - 2 * y
85+
86+
@dispatch2.abstract
87+
def f(x: Number, y: Number):
88+
return x - y
89+
90+
@(dispatch2 | dispatch1)
91+
def f(x: int, y: float):
92+
return x + y
93+
94+
@dispatch1
95+
def f(x: str):
96+
return x
97+
98+
ns1 = SimpleNamespace(f=f)
99+
100+
@dispatch2
101+
def f(x: int):
102+
return x
103+
104+
ns2 = SimpleNamespace(f=f)
105+
106+
assert ns1.f("a") == "a"
107+
assert ns1.f(1, 1.0) == 2.0
108+
109+
assert ns2.f(1) == 1
110+
assert ns2.f(1, 1.0) == 2.0

0 commit comments

Comments
 (0)