|
1 | 1 | from collections.abc import Callable |
| 2 | +from typing import TypedDict, type_check_only |
| 3 | +from typing_extensions import TypeAlias, Unpack |
2 | 4 |
|
3 | 5 | import numpy as np |
4 | 6 |
|
5 | 7 | __all__ = ["get_filter", "clear_cache", "sinc_window"] |
6 | 8 |
|
7 | 9 | # Dictionary to cache loaded filters |
8 | | -FILTER_CACHE: dict[str, tuple[np.ndarray[tuple[int, ...], np.dtype[np.floating]], int, float]] |
| 10 | +FILTER_CACHE: dict[str, tuple[np.ndarray[tuple[int], np.dtype[np.float64]], int, float]] |
9 | 11 |
|
10 | 12 | # List of filter functions available |
11 | 13 | FILTER_FUNCTIONS: list[str] |
12 | 14 |
|
| 15 | +_FilterType: TypeAlias = str | Callable[[int], np.ndarray[tuple[int], np.dtype[np.float64]]] |
| 16 | + |
| 17 | +@type_check_only |
| 18 | +class _FilterKwArgs(TypedDict, total=False): |
| 19 | + num_zeros: int |
| 20 | + precision: int |
| 21 | + rolloff: float |
| 22 | + |
13 | 23 | def sinc_window( |
14 | 24 | num_zeros: int = 64, |
15 | 25 | precision: int = 9, |
16 | | - window: Callable[..., np.ndarray[tuple[int, ...], np.dtype[np.floating]]] | None = None, |
| 26 | + window: Callable[[int], np.ndarray[tuple[int], np.dtype[np.float64]]] | None = None, |
17 | 27 | rolloff: float = 0.945, |
18 | | -) -> tuple[np.ndarray[tuple[int, ...], np.dtype[np.floating]], int, float]: ... |
| 28 | +) -> tuple[np.ndarray[tuple[int], np.dtype[np.float64]], int, float]: ... |
19 | 29 | def get_filter( |
20 | | - name_or_function: str | Callable[..., tuple[np.ndarray[tuple[int, ...], np.dtype[np.floating]], int, float]], **kwargs |
21 | | -) -> tuple[np.ndarray[tuple[int, ...], np.dtype[np.floating]], int, float]: ... |
22 | | -def load_filter(filter_name: str) -> tuple[np.ndarray[tuple[int, ...], np.dtype[np.floating]], int, float]: ... |
| 30 | + name_or_function: _FilterType, **kwargs: Unpack[_FilterKwArgs] |
| 31 | +) -> tuple[np.ndarray[tuple[int], np.dtype[np.float64]], int, float]: ... |
| 32 | +def load_filter(filter_name: str) -> tuple[np.ndarray[tuple[int], np.dtype[np.float64]], int, float]: ... |
23 | 33 | def clear_cache() -> None: ... |
0 commit comments