|
1 | 1 | from __future__ import annotations
|
2 | 2 | from pydantic import BaseModel
|
3 | 3 | from typing import (
|
4 |
| - Tuple, Callable, Union, List, TypeVar, Generic, Dict, Optional |
| 4 | + Tuple, Callable, Union, List, TypeVar, Generic, Dict, Optional, Iterable |
5 | 5 | )
|
6 | 6 | from pathlib import Path
|
7 | 7 | from functools import lru_cache
|
@@ -89,9 +89,16 @@ def from_dataframe(dataframe: pd.DataFrame) -> Dataset[pd.Series]:
|
89 | 89 | get_item=lambda df, index: df.iloc[index],
|
90 | 90 | )
|
91 | 91 |
|
92 |
| - def __getitem__(self: Dataset[T], index: int) -> T: |
93 |
| - '''Get an example ``T`` from the ``Dataset[T]``''' |
94 |
| - return self.get_item(self.dataframe, index) |
| 92 | + def __getitem__( |
| 93 | + self: Dataset[T], |
| 94 | + select: Union[int, slice, Iterable, Callable[[pd.DataFrame], Iterable[int]]] |
| 95 | + ) -> Union[T, Dataset[T]]: |
| 96 | + '''Get selection from the ``Dataset[T]``''' |
| 97 | + if np.issubdtype(type(select), np.integer): |
| 98 | + return self.get_item(self.dataframe, select) |
| 99 | + else: |
| 100 | + dataframe = self.dataframe.iloc[select] |
| 101 | + return self.replace(dataframe=dataframe, length=len(dataframe)) |
95 | 102 |
|
96 | 103 | def __len__(self):
|
97 | 104 | return self.length
|
@@ -198,27 +205,8 @@ def subset(
|
198 | 205 | ... )[-1]
|
199 | 206 | 2
|
200 | 207 | '''
|
201 |
| - |
202 |
| - mask = mask_fn(self.dataframe) |
203 |
| - if isinstance(mask, list): |
204 |
| - mask = np.array(mask) |
205 |
| - elif isinstance(mask, pd.Series): |
206 |
| - mask = mask.values |
207 |
| - |
208 |
| - if len(mask.shape) != 1: |
209 |
| - raise AssertionError('Expected single dimension in mask') |
210 |
| - |
211 |
| - if len(mask) != len(self): |
212 |
| - raise AssertionError( |
213 |
| - 'Expected mask to have the same length as the dataset' |
214 |
| - ) |
215 |
| - |
216 |
| - indices = np.argwhere(mask).squeeze(1) |
217 |
| - return Dataset( |
218 |
| - dataframe=self.dataframe.iloc[indices], |
219 |
| - length=len(indices), |
220 |
| - get_item=self.get_item, |
221 |
| - ) |
| 208 | + dataframe = self.dataframe[mask_fn(self.dataframe)] |
| 209 | + return self.replace(dataframe=dataframe, length=len(dataframe)) |
222 | 210 |
|
223 | 211 | def split(
|
224 | 212 | self,
|
|
0 commit comments