|
11 | 11 |
|
12 | 12 |
|
13 | 13 | from functools import lru_cache
|
14 |
| -from typing import Callable, overload |
| 14 | +from typing import Any, Callable, Protocol, TypeVar |
15 | 15 |
|
16 | 16 | import numpy as np
|
17 | 17 | from scipy import special
|
18 | 18 |
|
19 | 19 |
|
20 | 20 | def binary_search(
|
21 |
| - function: Callable[[int | float], int | float], |
22 |
| - target: int | float, |
23 |
| - lower_bound: int | float, |
24 |
| - upper_bound: int | float, |
25 |
| - tolerance: int | float = 1e-4, |
26 |
| -) -> int | float | None: |
| 21 | + function: Callable[[float], float], |
| 22 | + target: float, |
| 23 | + lower_bound: float, |
| 24 | + upper_bound: float, |
| 25 | + tolerance: float = 1e-4, |
| 26 | +) -> float | None: |
27 | 27 | """Searches for a value in a range by repeatedly dividing the range in half.
|
28 | 28 |
|
29 | 29 | To be more precise, performs numerical binary search to determine the
|
@@ -92,15 +92,16 @@ def choose(n: int, k: int) -> int:
|
92 | 92 | return value
|
93 | 93 |
|
94 | 94 |
|
95 |
| -@overload |
96 |
| -def clip(a: float, min_a: float, max_a: float) -> float: ... |
| 95 | +class Comparable(Protocol): |
| 96 | + def __lt__(self, other: Any) -> bool: ... |
97 | 97 |
|
| 98 | + def __gt__(self, other: Any) -> bool: ... |
98 | 99 |
|
99 |
| -@overload |
100 |
| -def clip(a: str, min_a: str, max_a: str) -> str: ... |
101 | 100 |
|
| 101 | +ComparableT = TypeVar("ComparableT", bound=Comparable) # noqa: Y001 |
102 | 102 |
|
103 |
| -def clip(a, min_a, max_a): # type: ignore[no-untyped-def] |
| 103 | + |
| 104 | +def clip(a: ComparableT, min_a: ComparableT, max_a: ComparableT) -> ComparableT: |
104 | 105 | """Clips ``a`` to the interval [``min_a``, ``max_a``].
|
105 | 106 |
|
106 | 107 | Accepts any comparable objects (i.e. those that support <, >).
|
|
0 commit comments