Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/authors.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
* [Rafael Rêgo](mailto:[email protected])
* [Raphael Schrader](mailto:[email protected])
* [João S. O. Bueno](mailto:[email protected])
* [Rodrigo Nogueira](mailto:[email protected])


## Scaffolding
Expand Down
71 changes: 64 additions & 7 deletions statemachine/signature.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from functools import partial
from inspect import BoundArguments
from inspect import Parameter
Expand All @@ -6,6 +8,12 @@
from itertools import chain
from types import MethodType
from typing import Any
from typing import FrozenSet
from typing import Optional
from typing import Tuple

BindCacheKey = Tuple[int, FrozenSet[str]]
BindTemplate = Tuple[Tuple[str, ...], Optional[str]] # noqa: UP007


def _make_key(method):
Expand Down Expand Up @@ -44,6 +52,11 @@ def cached_function(cls, method):

class SignatureAdapter(Signature):
is_coroutine: bool = False
_bind_cache: dict[BindCacheKey, BindTemplate]

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._bind_cache = {}

@classmethod
@signature_cache
Expand All @@ -60,19 +73,57 @@ def from_callable(cls, method):
adapter.is_coroutine = iscoroutinefunction(method)
return adapter

def bind_expected(self, *args: Any, **kwargs: Any) -> BoundArguments: # noqa: C901
def bind_expected(self, *args: Any, **kwargs: Any) -> BoundArguments:
cache_key: BindCacheKey = (len(args), frozenset(kwargs.keys()))
template = self._bind_cache.get(cache_key)

if template is not None:
return self._fast_bind(args, kwargs, template)

result = self._full_bind(cache_key, *args, **kwargs)
return result

def _fast_bind(
self,
args: tuple[Any, ...],
kwargs: dict[str, Any],
template: BindTemplate,
) -> BoundArguments:
param_names, kwargs_param_name = template
arguments: dict[str, Any] = {}

for i, name in enumerate(param_names):
if i < len(args):
arguments[name] = args[i]
else:
arguments[name] = kwargs.get(name)

if kwargs_param_name is not None:
matched = set(param_names)
arguments[kwargs_param_name] = {k: v for k, v in kwargs.items() if k not in matched}

return BoundArguments(self, arguments) # type: ignore[arg-type]

def _full_bind( # noqa: C901
self,
cache_key: BindCacheKey,
*args: Any,
**kwargs: Any,
) -> BoundArguments:
"""Get a BoundArguments object, that maps the passed `args`
and `kwargs` to the function's signature. It avoids to raise `TypeError`
trying to fill all the required arguments and ignoring the unknown ones.

Adapted from the internal `inspect.Signature._bind`.
"""
arguments = {}
arguments: dict[str, Any] = {}
param_names_used: list[str] = []

parameters = iter(self.parameters.values())
arg_vals = iter(args)
parameters_ex: Any = ()
kwargs_param = None
kwargs_param_name: str | None = None

while True:
# Let's iterate through the positional arguments and corresponding
Expand All @@ -95,8 +146,7 @@ def bind_expected(self, *args: Any, **kwargs: Any) -> BoundArguments: # noqa: C
elif param.name in kwargs:
if param.kind == Parameter.POSITIONAL_ONLY:
msg = (
"{arg!r} parameter is positional only, "
"but was passed as a keyword"
"{arg!r} parameter is positional only, but was passed as a keyword"
)
msg = msg.format(arg=param.name)
raise TypeError(msg) from None
Expand Down Expand Up @@ -141,12 +191,14 @@ def bind_expected(self, *args: Any, **kwargs: Any) -> BoundArguments: # noqa: C
values = [arg_val]
values.extend(arg_vals)
arguments[param.name] = tuple(values)
param_names_used.append(param.name)
break

if param.name in kwargs and param.kind != Parameter.POSITIONAL_ONLY:
arguments[param.name] = kwargs.pop(param.name)
else:
arguments[param.name] = arg_val
param_names_used.append(param.name)

# Now, we iterate through the remaining parameters to process
# keyword arguments
Expand All @@ -172,14 +224,19 @@ def bind_expected(self, *args: Any, **kwargs: Any) -> BoundArguments: # noqa: C
# arguments.
pass
else:
arguments[param_name] = arg_val #
arguments[param_name] = arg_val
param_names_used.append(param_name)

if kwargs:
if kwargs_param is not None:
# Process our '**kwargs'-like parameter
arguments[kwargs_param.name] = kwargs # type: ignore [assignment]
arguments[kwargs_param.name] = kwargs # type: ignore[assignment]
kwargs_param_name = kwargs_param.name
else:
# 'ignoring we got an unexpected keyword argument'
pass

return BoundArguments(self, arguments) # type: ignore [arg-type]
template: BindTemplate = (tuple(param_names_used), kwargs_param_name)
self._bind_cache[cache_key] = template

return BoundArguments(self, arguments) # type: ignore[arg-type]
36 changes: 35 additions & 1 deletion tests/test_signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from functools import partial

import pytest

from statemachine.dispatcher import callable_method
from statemachine.signature import SignatureAdapter


def single_positional_param(a):
Expand Down Expand Up @@ -162,3 +162,37 @@ def test_support_for_partial(self):

assert wrapped_func("A", "B") == ("A", "B", "activated")
assert wrapped_func.__name__ == positional_and_kw_arguments.__name__


def named_and_kwargs(source, **kwargs):
return source, kwargs


class TestCachedBindExpected:
"""Tests that exercise the cache fast-path by calling the same
wrapped function twice with the same argument shape."""

def setup_method(self):
SignatureAdapter.from_callable.clear_cache()

def test_named_param_not_leaked_into_kwargs(self):
"""Named params should not appear in the **kwargs dict on cache hit."""
wrapped = callable_method(named_and_kwargs)

# 1st call: cache miss -> _full_bind
result1 = wrapped(source="A", target="B", event="go")
assert result1 == ("A", {"target": "B", "event": "go"})

# 2nd call: cache hit -> _fast_bind
result2 = wrapped(source="X", target="Y", event="stop")
assert result2 == ("X", {"target": "Y", "event": "stop"})

def test_kwargs_only_receives_unmatched_keys_with_positional(self):
"""When mixing positional and keyword args with **kwargs."""
wrapped = callable_method(named_and_kwargs)

result1 = wrapped("A", target="B")
assert result1 == ("A", {"target": "B"})

result2 = wrapped("X", target="Y")
assert result2 == ("X", {"target": "Y"})
Loading