|
11 | 11 | import pickle
|
12 | 12 | import sys
|
13 | 13 | from typing import Literal
|
| 14 | +from typing import Optional |
| 15 | +from typing import Union |
14 | 16 |
|
15 | 17 | import numpy as np
|
16 | 18 | from numpy.random import RandomState
|
@@ -335,8 +337,8 @@ def train_test_validate(
|
335 | 337 | 'mixed-set', 'drug-blind', 'cancer-blind'
|
336 | 338 | ]='mixed-set',
|
337 | 339 | ratio: tuple[int, int, int]=(8,1,1),
|
338 |
| - stratify_by: (str | None)=None, |
339 |
| - random_state: (int | RandomState | None)=None, |
| 340 | + stratify_by: Optional[str]=None, |
| 341 | + random_state: Optional[Union[int,RandomState]]=None, |
340 | 342 | **kwargs: dict,
|
341 | 343 | ) -> Split:
|
342 | 344 |
|
@@ -386,7 +388,7 @@ def save(self, path: Path) -> None:
|
386 | 388 |
|
387 | 389 | def load(
|
388 | 390 | name: str,
|
389 |
| - local_path: str|Path=Path.cwd(), |
| 391 | + local_path: Union[str,Path]=Path.cwd(), |
390 | 392 | from_pickle:bool=False
|
391 | 393 | ) -> Dataset:
|
392 | 394 | """
|
@@ -669,8 +671,8 @@ def train_test_validate(
|
669 | 671 | 'mixed-set', 'drug-blind', 'cancer-blind'
|
670 | 672 | ]='mixed-set',
|
671 | 673 | ratio: tuple[int, int, int]=(8,1,1),
|
672 |
| - stratify_by: (str | None)=None, |
673 |
| - random_state: (int | RandomState | None)=None, |
| 674 | + stratify_by: Optional[str]=None, |
| 675 | + random_state: Optional[Union[int,RandomState]]=None, |
674 | 676 | **kwargs: dict,
|
675 | 677 | ) -> Split:
|
676 | 678 | """
|
@@ -1015,7 +1017,7 @@ def _load_file(file_path: Path) -> pd.DataFrame:
|
1015 | 1017 | )
|
1016 | 1018 |
|
1017 | 1019 |
|
1018 |
| -def _determine_delimiter(file_path): |
| 1020 | +def _determine_delimiter(file_path: Path) -> str: |
1019 | 1021 | if '.tsv' in file_path.suffixes:
|
1020 | 1022 | return '\t'
|
1021 | 1023 | else:
|
|
0 commit comments