|
| 1 | +import inspect |
1 | 2 | from abc import ABC, abstractmethod |
2 | 3 | from contextlib import contextmanager |
3 | 4 | from contextvars import ContextVar |
4 | 5 | from dataclasses import dataclass |
5 | 6 | from datetime import timedelta, datetime |
6 | | -from typing import Iterator |
| 7 | +from enum import Enum |
| 8 | +from functools import update_wrapper |
| 9 | +from inspect import signature, Parameter |
| 10 | +from typing import Iterator, TypedDict, Unpack, Callable, Type, ParamSpec, TypeVar, Generic, get_type_hints, \ |
| 11 | + Any, overload |
7 | 12 |
|
8 | 13 | from cadence import Client |
9 | 14 |
|
@@ -59,3 +64,100 @@ def is_set() -> bool: |
59 | 64 | @staticmethod |
60 | 65 | def get() -> 'ActivityContext': |
61 | 66 | return ActivityContext._var.get() |
| 67 | + |
| 68 | + |
| 69 | +@dataclass(frozen=True) |
| 70 | +class ActivityParameter: |
| 71 | + name: str |
| 72 | + type_hint: Type | None |
| 73 | + default_value: Any | None |
| 74 | + |
| 75 | +class ExecutionStrategy(Enum): |
| 76 | + ASYNC = "async" |
| 77 | + THREAD_POOL = "thread_pool" |
| 78 | + |
| 79 | +class ActivityDefinitionOptions(TypedDict, total=False): |
| 80 | + name: str |
| 81 | + aliases: list[str] |
| 82 | + |
| 83 | +P = ParamSpec('P') |
| 84 | +T = TypeVar('T') |
| 85 | + |
| 86 | +class ActivityDefinition(Generic[P, T]): |
| 87 | + def __init__(self, wrapped: Callable[P, T], name: str, strategy: ExecutionStrategy, params: list[ActivityParameter]): |
| 88 | + self._wrapped = wrapped |
| 89 | + self._name = name |
| 90 | + self._strategy = strategy |
| 91 | + self._params = params |
| 92 | + update_wrapper(self, wrapped) |
| 93 | + |
| 94 | + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: |
| 95 | + return self._wrapped(*args, **kwargs) |
| 96 | + |
| 97 | + @property |
| 98 | + def name(self) -> str: |
| 99 | + return self._name |
| 100 | + |
| 101 | + @property |
| 102 | + def strategy(self) -> ExecutionStrategy: |
| 103 | + return self._strategy |
| 104 | + |
| 105 | + @property |
| 106 | + def params(self) -> list[ActivityParameter]: |
| 107 | + return self._params |
| 108 | + |
| 109 | + @staticmethod |
| 110 | + def wrap(fn: Callable[P, T], opts: ActivityDefinitionOptions) -> 'ActivityDefinition[P, T]': |
| 111 | + name = fn.__qualname__ |
| 112 | + if "name" in opts and opts["name"]: |
| 113 | + name = opts["name"] |
| 114 | + |
| 115 | + strategy = ExecutionStrategy.THREAD_POOL |
| 116 | + if inspect.iscoroutinefunction(fn) or inspect.iscoroutinefunction(fn.__call__): # type: ignore |
| 117 | + strategy = ExecutionStrategy.ASYNC |
| 118 | + |
| 119 | + params = _get_params(fn) |
| 120 | + return ActivityDefinition(fn, name, strategy, params) |
| 121 | + |
| 122 | + |
| 123 | +ActivityDecorator = Callable[[Callable[P, T]], ActivityDefinition[P, T]] |
| 124 | + |
| 125 | +@overload |
| 126 | +def defn(fn: Callable[P, T]) -> ActivityDefinition[P, T]: |
| 127 | + ... |
| 128 | + |
| 129 | +@overload |
| 130 | +def defn(**kwargs: Unpack[ActivityDefinitionOptions]) -> ActivityDecorator: |
| 131 | + ... |
| 132 | + |
| 133 | +def defn(fn: Callable[P, T] | None = None, **kwargs: Unpack[ActivityDefinitionOptions]) -> ActivityDecorator | ActivityDefinition[P, T]: |
| 134 | + options = ActivityDefinitionOptions(**kwargs) |
| 135 | + def decorator(inner_fn: Callable[P, T]) -> ActivityDefinition[P, T]: |
| 136 | + return ActivityDefinition.wrap(inner_fn, options) |
| 137 | + |
| 138 | + if fn is not None: |
| 139 | + return decorator(fn) |
| 140 | + |
| 141 | + return decorator |
| 142 | + |
| 143 | + |
| 144 | +def _get_params(fn: Callable) -> list[ActivityParameter]: |
| 145 | + args = signature(fn).parameters |
| 146 | + hints = get_type_hints(fn) |
| 147 | + result = [] |
| 148 | + for name, param in args.items(): |
| 149 | + # "unbound functions" aren't a thing in the Python spec. Filter out the self parameter and hope they followed |
| 150 | + # the convention. |
| 151 | + if param.name == "self": |
| 152 | + continue |
| 153 | + default = None |
| 154 | + if param.default != Parameter.empty: |
| 155 | + default = param.default |
| 156 | + if param.kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD): |
| 157 | + type_hint = hints.get(name, None) |
| 158 | + result.append(ActivityParameter(name, type_hint, default)) |
| 159 | + |
| 160 | + else: |
| 161 | + raise ValueError(f"Parameters must be positional. {name} is {param.kind}, and not valid") |
| 162 | + |
| 163 | + return result |
0 commit comments