Skip to content

Commit cc1ca05

Browse files
author
FelixAbrahamsson
committed
refactor: move split functions from dataset to tools
1 parent acdfd8b commit cc1ca05

File tree

4 files changed

+88
-71
lines changed

4 files changed

+88
-71
lines changed

datastream/dataset.py

Lines changed: 9 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
)
66
from pathlib import Path
77
from functools import lru_cache
8-
import warnings
98
import textwrap
109
import inspect
1110
import numpy as np
@@ -117,6 +116,11 @@ def __eq__(self: Dataset[T], other: Dataset[R]) -> bool:
117116
return False
118117
return True
119118

119+
def replace(self, **kwargs):
    """Return a new instance of the same class with some fields overridden.

    The constructor arguments are taken from ``self.dict()``; any keyword
    passed here takes precedence over the value of the same name from
    that mapping. ``self`` is not modified by this method.
    """
    # Later entries win in a dict display, so ``kwargs`` overrides the
    # values coming from ``self.dict()``.
    constructor_arguments = {**self.dict(), **kwargs}
    return type(self)(**constructor_arguments)
123+
120124
def map(
121125
self: Dataset[T], function: Callable[[T], R]
122126
) -> Dataset[R]:
@@ -258,7 +262,8 @@ def split(
258262
save_directory.mkdir(parents=True, exist_ok=True)
259263

260264
if stratify_column is not None:
261-
return self._stratified_split(
265+
return tools.stratified_split(
266+
self,
262267
key_column=key_column,
263268
proportions=proportions,
264269
stratify_column=stratify_column,
@@ -267,7 +272,8 @@ def split(
267272
frozen=frozen,
268273
)
269274
else:
270-
return self._unstratified_split(
275+
return tools.unstratified_split(
276+
self,
271277
key_column=key_column,
272278
proportions=proportions,
273279
filepath=(
@@ -278,74 +284,6 @@ def split(
278284
frozen=frozen,
279285
)
280286

281-
def _unstratified_split(
282-
self,
283-
key_column: str,
284-
proportions: Dict[str, float],
285-
filepath: Optional[Path] = None,
286-
seed: Optional[int] = None,
287-
frozen: Optional[bool] = False,
288-
):
289-
split_dataframes = tools.numpy_seed(seed)(tools.split_dataframes)
290-
return {
291-
split_name: Dataset(
292-
dataframe=dataframe,
293-
length=len(dataframe),
294-
get_item=self.get_item,
295-
)
296-
for split_name, dataframe in split_dataframes(
297-
self.dataframe,
298-
key_column,
299-
proportions,
300-
filepath=filepath,
301-
frozen=frozen,
302-
).items()
303-
}
304-
305-
def _stratified_split(
306-
self,
307-
key_column: str,
308-
proportions: Dict[str, float],
309-
stratify_column: Optional[str] = None,
310-
save_directory: Optional[Path] = None,
311-
seed: Optional[int] = None,
312-
frozen: Optional[bool] = False,
313-
):
314-
if (
315-
stratify_column is not None
316-
and any(self.dataframe[key_column].duplicated())
317-
):
318-
# mathematically impossible in the general case
319-
warnings.warn(
320-
'Trying to do stratified split with non-unique key column'
321-
' - cannot guarantee correct splitting of key values.'
322-
)
323-
strata = {
324-
stratum_value: self.subset(
325-
lambda df: df[stratify_column] == stratum_value
326-
)
327-
for stratum_value in self.dataframe[stratify_column].unique()
328-
}
329-
split_strata = [
330-
stratum._unstratified_split(
331-
key_column=key_column,
332-
proportions=proportions,
333-
filepath=(
334-
save_directory / f'{hash(stratum_value)}.json'
335-
if save_directory is not None else None
336-
),
337-
seed=seed,
338-
frozen=frozen,
339-
)
340-
for stratum_value, stratum in strata.items()
341-
]
342-
return {
343-
split_name: Dataset.concat(
344-
[split_stratum[split_name] for split_stratum in split_strata]
345-
)
346-
for split_name in proportions.keys()
347-
}
348-
349287
def with_columns(
350288
self: Dataset[T], **kwargs: Callable[pd.Dataframe, pd.Series]
351289
) -> Dataset[T]:

datastream/tools/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,5 @@
33
from datastream.tools.repeat_map_chain import repeat_map_chain
44
from datastream.tools.numpy_seed import numpy_seed
55
from datastream.tools.split_dataframes import split_dataframes
6+
from datastream.tools.unstratified_split import unstratified_split
7+
from datastream.tools.stratified_split import stratified_split

datastream/tools/stratified_split.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import warnings
2+
from typing import Dict, Optional
3+
from pathlib import Path
4+
from datastream import tools
5+
6+
7+
def stratified_split(
    dataset,
    key_column: str,
    proportions: Dict[str, float],
    stratify_column: Optional[str] = None,
    save_directory: Optional[Path] = None,
    seed: Optional[int] = None,
    frozen: Optional[bool] = False,
):
    """Split ``dataset`` stratum by stratum and merge the per-stratum splits.

    The dataset is partitioned by the distinct values of
    ``stratify_column``; each partition is split on its own with
    ``tools.unstratified_split`` and the pieces that share a split name are
    joined with ``type(dataset).concat``.

    Args:
        dataset: Object exposing ``dataframe``, ``subset`` and a ``concat``
            constructor on its type.
        key_column: Column forwarded to ``tools.unstratified_split`` as the
            split key. A warning is issued when it contains duplicates.
        proportions: Mapping from split name to proportion; its keys are
            the keys of the returned dict.
        stratify_column: Column whose distinct values define the strata.
            The signature defaults to ``None``, but the body indexes the
            dataframe with it unconditionally, so a real column name must
            be passed.
        save_directory: Optional directory; when given, each stratum is
            passed its own ``<name>.json`` path inside it.
        seed: Forwarded unchanged to ``tools.unstratified_split``.
        frozen: Forwarded unchanged to ``tools.unstratified_split``.

    Returns:
        Dict mapping every split name in ``proportions`` to the
        concatenation of that split across all strata.
    """
    if (
        stratify_column is not None
        and dataset.dataframe[key_column].duplicated().any()
    ):
        # mathematically impossible in the general case
        warnings.warn(
            'Trying to do stratified split with non-unique key column'
            ' - cannot guarantee correct splitting of key values.'
        )

    # NOTE(review): a NaN stratum value never compares equal to itself, so
    # such rows presumably end up in no stratum - confirm against the data.
    strata = {
        stratum_value: dataset.subset(
            # Bind the current value as a default so the predicate stays
            # correct even if ``subset`` evaluates it lazily.
            lambda df, stratum_value=stratum_value: (
                df[stratify_column] == stratum_value
            )
        )
        for stratum_value in dataset.dataframe[stratify_column].unique()
    }
    split_strata = [
        tools.unstratified_split(
            stratum,
            key_column=key_column,
            proportions=proportions,
            filepath=_stratum_filepath(save_directory, stratum_value),
            seed=seed,
            frozen=frozen,
        )
        for stratum_value, stratum in strata.items()
    ]
    return {
        split_name: type(dataset).concat(
            [split_stratum[split_name] for split_stratum in split_strata]
        )
        for split_name in proportions.keys()
    }


def _stratum_filepath(save_directory, stratum_value):
    """Return the JSON path for one stratum, or None without a directory.

    The built-in ``hash`` of ``str`` and ``bytes`` is salted per
    interpreter process, so a name derived from it cannot be found again
    in a later run. Those types get a content digest instead. Every other
    value keeps the original ``hash``-based name, and an already existing
    ``hash``-based file is always preferred so earlier output stays
    reachable.
    """
    import hashlib

    if save_directory is None:
        return None

    legacy_path = save_directory / f'{hash(stratum_value)}.json'
    if not isinstance(stratum_value, (str, bytes)) or legacy_path.exists():
        return legacy_path

    # Tag the payload with its kind so 'a' and b'a' map to different files.
    if isinstance(stratum_value, bytes):
        payload = b'bytes:' + bytes(stratum_value)
    else:
        payload = b'str:' + str(stratum_value).encode(
            'utf-8', 'surrogatepass'
        )
    return save_directory / f'{hashlib.sha256(payload).hexdigest()}.json'
datastream/tools/unstratified_split.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from typing import Dict, Optional
2+
from pathlib import Path
3+
from datastream import tools
4+
5+
6+
def unstratified_split(
    dataset,
    key_column: str,
    proportions: Dict[str, float],
    filepath: Optional[Path] = None,
    seed: Optional[int] = None,
    frozen: Optional[bool] = False,
):
    """Split ``dataset`` into one new dataset per entry of the split result.

    ``tools.split_dataframes`` is wrapped with ``tools.numpy_seed(seed)``
    and called on ``dataset.dataframe`` together with ``key_column``,
    ``proportions``, ``filepath`` and ``frozen``. Each dataframe it returns
    is turned into a dataset via ``dataset.replace``, so every other field
    of ``dataset`` is carried over.

    Args:
        dataset: Object exposing ``dataframe`` and ``replace``.
        key_column: Passed positionally to ``tools.split_dataframes``.
        proportions: Passed positionally to ``tools.split_dataframes``.
        filepath: Forwarded unchanged as the ``filepath`` keyword.
        seed: Argument given to ``tools.numpy_seed``.
        frozen: Forwarded unchanged as the ``frozen`` keyword.

    Returns:
        Dict mapping each split name to the dataset built from its
        dataframe, with ``length`` set to the dataframe's length.
    """
    seeded_splitter = tools.numpy_seed(seed)(tools.split_dataframes)
    frames_by_split = seeded_splitter(
        dataset.dataframe,
        key_column,
        proportions,
        filepath=filepath,
        frozen=frozen,
    )

    datasets_by_split = {}
    for name, frame in frames_by_split.items():
        datasets_by_split[name] = dataset.replace(
            dataframe=frame,
            length=len(frame),
        )
    return datasets_by_split

0 commit comments

Comments
 (0)