Hey there! For context: I've tried asking the same question on the typing Gitter and had no luck there. I believe visibility there is quite low, and nobody has replied to my message so far. I'm trying my luck here to understand whether what I'm trying to achieve is possible at all, or whether I shouldn't bother.

The idea seems simple: we need to preserve all the types of a function that a decorator with arguments is applied to. The problem comes when the decorator has to work with both sync and async functions. The closest I've come is the following:

However, I have also tried
Here's a potential solution. It type checks without errors in pyright. Mypy produces a few overload-related errors that appear to be false positives, so you could add `# type: ignore` comments to suppress them.

Code sample in pyright playground:

```python
from typing import Awaitable, Callable, Protocol, overload
from typing_extensions import TypeIs
from inspect import iscoroutinefunction
from functools import wraps


# TypeIs (unlike TypeGuard) narrows both the if and the else branch.
def is_coroutine[**P, R](
    func: Callable[P, R | Awaitable[R]],
) -> TypeIs[Callable[P, Awaitable[R]]]:
    return iscoroutinefunction(func)


# Type of the decorator that my_dec returns:
# an async function stays async, a sync function stays sync.
class SyncOrAsync(Protocol):
    @overload
    def __call__[**P, R](
        self, _func: Callable[P, Awaitable[R]]
    ) -> Callable[P, Awaitable[R]]:
        ...

    @overload
    def __call__[**P, R](self, _func: Callable[P, R]) -> Callable[P, R]:
        ...

    def __call__[**P, R](
        self, _func: Callable[P, Awaitable[R]] | Callable[P, R]
    ) -> Callable[P, Awaitable[R]] | Callable[P, R]:
        ...


def my_dec(param1: str, param2: int | None = None) -> SyncOrAsync:
    # param1/param2 are unused here; real logic would go in the wrappers.
    @overload
    def decorator[**P, R](
        _func: Callable[P, Awaitable[R]],
    ) -> Callable[P, Awaitable[R]]:
        ...

    @overload
    def decorator[**P, R](
        _func: Callable[P, R],
    ) -> Callable[P, R]:
        ...

    def decorator[**P, R](
        _func: Callable[P, Awaitable[R]] | Callable[P, R],
    ) -> Callable[P, Awaitable[R]] | Callable[P, R]:
        if is_coroutine(_func):
            _awaitable_func = _func

            @wraps(_awaitable_func)
            async def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
                return await _awaitable_func(*args, **kwargs)

            return _async_wrapper
        else:
            @wraps(_func)
            def _sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
                return _func(*args, **kwargs)

            return _sync_wrapper

    return decorator


@my_dec(param1="test")
def test_sync() -> str:
    return "test return"


v_sync = test_sync()


@my_dec(param1="test")
async def test_async() -> str:
    return "test return"


# For type checking only: this coroutine is never awaited.
v_async = test_async()
```