Skip to content

Commit c9a525a

Browse files
committed
Quantity ABC try 2, part 2
1 parent 448eddf commit c9a525a

13 files changed

+293
-608
lines changed

src/frequenz/sdk/microgrid/_data_pipeline.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
#
3535
# pylint: disable=import-outside-toplevel
3636
if typing.TYPE_CHECKING:
37-
from ..actor import ComponentMetricRequest, ResamplerConfig, _power_managing
37+
from ..actor import ComponentMetricRequest, ResamplingActorConfig, _power_managing
3838
from ..actor.power_distributing import ( # noqa: F401 (imports used by string type hints)
3939
ComponentPoolStatus,
4040
PowerDistributingActor,
@@ -78,7 +78,7 @@ class _DataPipeline: # pylint: disable=too-many-instance-attributes
7878

7979
def __init__(
8080
self,
81-
resampler_config: ResamplerConfig[float],
81+
resampler_config: ResamplingActorConfig,
8282
) -> None:
8383
"""Create a `DataPipeline` instance.
8484
@@ -384,7 +384,7 @@ async def _stop(self) -> None:
384384
_DATA_PIPELINE: _DataPipeline | None = None
385385

386386

387-
async def initialize(resampler_config: ResamplerConfig[float]) -> None:
387+
async def initialize(resampler_config: ResamplingActorConfig) -> None:
388388
"""Initialize a `DataPipeline` instance.
389389
390390
Args:

src/frequenz/sdk/timeseries/_base_types.py

Lines changed: 26 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -6,41 +6,22 @@
66
import dataclasses
77
import enum
88
import functools
9-
import typing
109
from collections.abc import Callable, Iterator
1110
from dataclasses import dataclass
1211
from datetime import datetime, timezone
13-
from typing import Generic, Protocol, Self, SupportsFloat, overload
12+
from typing import Generic, Self, SupportsFloat, TypeVar, overload
1413

1514
from ._quantities import Power
1615

17-
UNIX_EPOCH = datetime.fromtimestamp(0.0, tz=timezone.utc)
18-
"""The UNIX epoch (in UTC)."""
19-
20-
21-
class Comparable(Protocol):
22-
def __lt__(self, other: Self) -> bool:
23-
...
24-
25-
def __gt__(self, other: Self) -> bool:
26-
...
27-
28-
def __le__(self, other: Self) -> bool:
29-
...
30-
31-
def __ge__(self, other: Self) -> bool:
32-
...
33-
34-
35-
_T = typing.TypeVar("_T")
36-
SupportsFloatT = typing.TypeVar("SupportsFloatT", bound=SupportsFloat)
16+
SupportsFloatT = TypeVar("SupportsFloatT", bound=SupportsFloat)
3717
"""Type variable for types that support conversion to float."""
3818

39-
ComparableT = typing.TypeVar("ComparableT", bound=Comparable)
19+
UNIX_EPOCH = datetime.fromtimestamp(0.0, tz=timezone.utc)
20+
"""The UNIX epoch (in UTC)."""
4021

4122

4223
@dataclass(frozen=True, order=True)
43-
class Sample(Generic[_T]):
24+
class Sample(Generic[SupportsFloatT]):
4425
"""A measurement taken at a particular point in time.
4526
4627
The `value` could be `None` if a component is malfunctioning or data is
@@ -51,12 +32,12 @@ class Sample(Generic[_T]):
5132
timestamp: datetime
5233
"""The time when this sample was generated."""
5334

54-
value: _T | None = None
35+
value: SupportsFloatT | None = None
5536
"""The value of this sample."""
5637

5738

5839
@dataclass(frozen=True)
59-
class Sample3Phase(Generic[ComparableT]):
40+
class Sample3Phase(Generic[SupportsFloatT]):
6041
"""A 3-phase measurement made at a particular point in time.
6142
6243
Each of the `value` fields could be `None` if a component is malfunctioning
@@ -67,16 +48,16 @@ class Sample3Phase(Generic[ComparableT]):
6748

6849
timestamp: datetime
6950
"""The time when this sample was generated."""
70-
value_p1: ComparableT | None
51+
value_p1: SupportsFloatT | None
7152
"""The value of the 1st phase in this sample."""
7253

73-
value_p2: ComparableT | None
54+
value_p2: SupportsFloatT | None
7455
"""The value of the 2nd phase in this sample."""
7556

76-
value_p3: ComparableT | None
57+
value_p3: SupportsFloatT | None
7758
"""The value of the 3rd phase in this sample."""
7859

79-
def __iter__(self) -> Iterator[ComparableT | None]:
60+
def __iter__(self) -> Iterator[SupportsFloatT | None]:
8061
"""Return an iterator that yields values from each of the phases.
8162
8263
Yields:
@@ -87,14 +68,14 @@ def __iter__(self) -> Iterator[ComparableT | None]:
8768
yield self.value_p3
8869

8970
@overload
90-
def max(self, default: ComparableT) -> ComparableT:
71+
def max(self, default: SupportsFloatT) -> SupportsFloatT:
9172
...
9273

9374
@overload
94-
def max(self, default: None = None) -> ComparableT | None:
75+
def max(self, default: None = None) -> SupportsFloatT | None:
9576
...
9677

97-
def max(self, default: ComparableT | None = None) -> ComparableT | None:
78+
def max(self, default: SupportsFloatT | None = None) -> SupportsFloatT | None:
9879
"""Return the max value among all phases, or default if they are all `None`.
9980
10081
Args:
@@ -105,21 +86,21 @@ def max(self, default: ComparableT | None = None) -> ComparableT | None:
10586
"""
10687
if not any(self):
10788
return default
108-
value: ComparableT = functools.reduce(
109-
lambda x, y: x if x > y else y,
89+
value: SupportsFloatT = functools.reduce(
90+
lambda x, y: x if float(x) > float(y) else y,
11091
filter(None, self),
11192
)
11293
return value
11394

11495
@overload
115-
def min(self, default: ComparableT) -> ComparableT:
96+
def min(self, default: SupportsFloatT) -> SupportsFloatT:
11697
...
11798

11899
@overload
119-
def min(self, default: None = None) -> ComparableT | None:
100+
def min(self, default: None = None) -> SupportsFloatT | None:
120101
...
121102

122-
def min(self, default: ComparableT | None = None) -> ComparableT | None:
103+
def min(self, default: SupportsFloatT | None = None) -> SupportsFloatT | None:
123104
"""Return the min value among all phases, or default if they are all `None`.
124105
125106
Args:
@@ -130,16 +111,16 @@ def min(self, default: ComparableT | None = None) -> ComparableT | None:
130111
"""
131112
if not any(self):
132113
return default
133-
value: ComparableT = functools.reduce(
134-
lambda x, y: x if x < y else y,
114+
value: SupportsFloatT = functools.reduce(
115+
lambda x, y: x if float(x) < float(y) else y,
135116
filter(None, self),
136117
)
137118
return value
138119

139120
def map(
140121
self,
141-
function: Callable[[ComparableT], ComparableT],
142-
default: ComparableT | None = None,
122+
function: Callable[[SupportsFloatT], SupportsFloatT],
123+
default: SupportsFloatT | None = None,
143124
) -> Self:
144125
"""Apply the given function on each of the phase values and return the result.
145126
@@ -161,6 +142,9 @@ def map(
161142
)
162143

163144

145+
_T = TypeVar("_T")
146+
147+
164148
@dataclass(frozen=True)
165149
class Bounds(Generic[_T]):
166150
"""Lower and upper bound values."""

src/frequenz/sdk/timeseries/_moving_window.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,14 @@
1616
from numpy.typing import ArrayLike
1717

1818
from ..actor._background_service import BackgroundService
19-
from ._base_types import UNIX_EPOCH, Sample
20-
from ._quantities import QuantityT
19+
from ._base_types import UNIX_EPOCH, Sample, SupportsFloatT
2120
from ._resampling import Resampler, ResamplerConfig
2221
from ._ringbuffer import OrderedRingBuffer
2322

2423
_logger = logging.getLogger(__name__)
2524

2625

27-
class MovingWindow(BackgroundService, Generic[QuantityT]):
26+
class MovingWindow(BackgroundService, Generic[SupportsFloatT]):
2827
"""
2928
A data window that moves with the latest datapoints of a data stream.
3029
@@ -130,9 +129,9 @@ async def run() -> None:
130129
def __init__( # pylint: disable=too-many-arguments
131130
self,
132131
size: timedelta,
133-
resampled_data_recv: Receiver[Sample[QuantityT]],
132+
resampled_data_recv: Receiver[Sample[SupportsFloatT]],
134133
input_sampling_period: timedelta,
135-
resampler_config: ResamplerConfig[QuantityT] | None = None,
134+
resampler_config: ResamplerConfig[SupportsFloatT] | None = None,
136135
align_to: datetime = UNIX_EPOCH,
137136
*,
138137
name: str | None = None,
@@ -166,8 +165,8 @@ def __init__( # pylint: disable=too-many-arguments
166165

167166
self._sampling_period = input_sampling_period
168167

169-
self._resampler: Resampler[QuantityT] | None = None
170-
self._resampler_sender: Sender[Sample[QuantityT]] | None = None
168+
self._resampler: Resampler[SupportsFloatT] | None = None
169+
self._resampler_sender: Sender[Sample[SupportsFloatT]] | None = None
171170

172171
if resampler_config:
173172
assert (
@@ -182,7 +181,9 @@ def __init__( # pylint: disable=too-many-arguments
182181
size.total_seconds() / self._sampling_period.total_seconds()
183182
)
184183

185-
self._resampled_data_recv: Receiver[Sample[QuantityT]] = resampled_data_recv
184+
self._resampled_data_recv: Receiver[
185+
Sample[SupportsFloatT]
186+
] = resampled_data_recv
186187
self._buffer = OrderedRingBuffer(
187188
np.empty(shape=num_samples, dtype=float),
188189
sampling_period=self._sampling_period,
@@ -341,11 +342,11 @@ def _configure_resampler(self) -> None:
341342
"""Configure the components needed to run the resampler."""
342343
assert self._resampler is not None
343344

344-
async def sink_buffer(sample: Sample[QuantityT]) -> None:
345+
async def sink_buffer(sample: Sample[SupportsFloatT]) -> None:
345346
if sample.value is not None:
346347
self._buffer.update(sample)
347348

348-
resampler_channel = Broadcast[Sample[QuantityT]]("average")
349+
resampler_channel = Broadcast[Sample[SupportsFloatT]]("average")
349350
self._resampler_sender = resampler_channel.new_sender()
350351
self._resampler.add_timeseries(
351352
"avg",

src/frequenz/sdk/timeseries/_periodic_feature_extractor.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222
from numpy.typing import NDArray
2323

2424
from .._internal._math import is_close_to_zero
25-
from ..timeseries._quantities import QuantityT
2625
from ._moving_window import MovingWindow
26+
from ._quantities import SupportsFloatT
2727
from ._ringbuffer import OrderedRingBuffer
2828

2929
_logger = logging.getLogger(__name__)
@@ -50,7 +50,7 @@ class RelativePositions:
5050
"""The relative position of the next incoming sample."""
5151

5252

53-
class PeriodicFeatureExtractor(Generic[QuantityT]):
53+
class PeriodicFeatureExtractor(Generic[SupportsFloatT]):
5454
"""
5555
A feature extractor for historical timeseries data.
5656
@@ -108,7 +108,7 @@ class PeriodicFeatureExtractor(Generic[QuantityT]):
108108

109109
def __init__(
110110
self,
111-
moving_window: MovingWindow[QuantityT],
111+
moving_window: MovingWindow[SupportsFloatT],
112112
period: timedelta,
113113
) -> None:
114114
"""
@@ -121,7 +121,7 @@ def __init__(
121121
Raises:
122122
ValueError: If the MovingWindow size is not a integer multiple of the period.
123123
"""
124-
self._moving_window: MovingWindow[QuantityT] = moving_window
124+
self._moving_window: MovingWindow[SupportsFloatT] = moving_window
125125

126126
self._sampling_period = self._moving_window.sampling_period
127127
"""The sampling_period as float to use it for indexing of samples."""

0 commit comments

Comments
 (0)