Skip to content

Commit 6c1df3f

Browse files
committed
feature: dataset supports slicing
1 parent b3f045a commit 6c1df3f

File tree

1 file changed

+13
-25
lines changed

1 file changed

+13
-25
lines changed

datastream/dataset.py

Lines changed: 13 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22
from pydantic import BaseModel
33
from typing import (
4-
Tuple, Callable, Union, List, TypeVar, Generic, Dict, Optional
4+
Tuple, Callable, Union, List, TypeVar, Generic, Dict, Optional, Iterable
55
)
66
from pathlib import Path
77
from functools import lru_cache
@@ -89,9 +89,16 @@ def from_dataframe(dataframe: pd.DataFrame) -> Dataset[pd.Series]:
8989
get_item=lambda df, index: df.iloc[index],
9090
)
9191

92-
def __getitem__(self: Dataset[T], index: int) -> T:
93-
'''Get an example ``T`` from the ``Dataset[T]``'''
94-
return self.get_item(self.dataframe, index)
92+
def __getitem__(
93+
self: Dataset[T],
94+
select: Union[int, slice, Iterable, Callable[[pd.DataFrame], Iterable[int]]]
95+
) -> Union[T, Dataset[T]]:
96+
'''Get selection from the ``Dataset[T]``'''
97+
if np.issubdtype(type(select), np.integer):
98+
return self.get_item(self.dataframe, select)
99+
else:
100+
dataframe = self.dataframe.iloc[select]
101+
return self.replace(dataframe=dataframe, length=len(dataframe))
95102

96103
def __len__(self):
97104
return self.length
@@ -198,27 +205,8 @@ def subset(
198205
... )[-1]
199206
2
200207
'''
201-
202-
mask = mask_fn(self.dataframe)
203-
if isinstance(mask, list):
204-
mask = np.array(mask)
205-
elif isinstance(mask, pd.Series):
206-
mask = mask.values
207-
208-
if len(mask.shape) != 1:
209-
raise AssertionError('Expected single dimension in mask')
210-
211-
if len(mask) != len(self):
212-
raise AssertionError(
213-
'Expected mask to have the same length as the dataset'
214-
)
215-
216-
indices = np.argwhere(mask).squeeze(1)
217-
return Dataset(
218-
dataframe=self.dataframe.iloc[indices],
219-
length=len(indices),
220-
get_item=self.get_item,
221-
)
208+
dataframe = self.dataframe[mask_fn(self.dataframe)]
209+
return self.replace(dataframe=dataframe, length=len(dataframe))
222210

223211
def split(
224212
self,

0 commit comments

Comments
 (0)