Skip to content

Commit 3cbbaff

Browse files
committed
Add docstrings
1 parent 6b0cb5d commit 3cbbaff

File tree

3 files changed

+46
-8
lines changed

3 files changed

+46
-8
lines changed

scripts/datasets/medical/check_dsad.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,7 @@
99

1010

1111
def check_dsad():
12-
# from util import ROOT
13-
ROOT = "/media/anwai/ANWAI/data"
12+
from util import ROOT
1413

1514
loader = get_dsad_loader(
1615
path=os.path.join(ROOT, "dsad"),

scripts/datasets/medical/check_ircadb.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@ def check_ircadb():
1616
loader = get_ircadb_loader(
1717
path=os.path.join(ROOT, "3d_ircadb"),
1818
batch_size=2,
19-
patch_shape=(8, 512, 512),
19+
patch_shape=(1, 512, 512),
2020
label_choice="liver",
2121
split="train",
22-
ndim=3,
22+
ndim=2,
2323
download=True,
2424
resize_inputs=True,
2525
sampler=MinInstanceSampler(),

torch_em/data/datasets/medical/ircadb.py

+43-4
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,14 @@ def _preprocess_inputs(path):
7171

7272

7373
def get_ircadb_data(path: Union[os.PathLike, str], download: bool = False) -> str:
74-
"""
74+
"""Download the IRCADb dataset.
75+
76+
Args:
77+
path: Filepath to a folder where the data is downloaded for further processing.
78+
download: Whether to download the data if it is not present.
79+
80+
Returns:
81+
Filepath where the data is downloaded.
7582
"""
7683
data_dir = os.path.join(path, "data")
7784
if os.path.exists(data_dir):
@@ -91,7 +98,14 @@ def get_ircadb_data(path: Union[os.PathLike, str], download: bool = False) -> st
9198
def get_ircadb_paths(
9299
path: Union[os.PathLike, str], split: Optional[Literal["train", "val", "test"]] = None, download: bool = False,
93100
) -> List[str]:
94-
"""
101+
"""Get paths to the IRCADb data.
102+
103+
Args:
104+
path: Filepath to a folder where the data is downloaded for further processing.
105+
download: Whether to download the data if it is not present.
106+
107+
Returns:
108+
List of filepaths for the volumetric data.
95109
"""
96110

97111
data_dir = get_ircadb_data(path, download)
@@ -120,7 +134,19 @@ def get_ircadb_dataset(
120134
download: bool = False,
121135
**kwargs
122136
) -> Dataset:
123-
"""
137+
"""Get the IRCADb dataset for liver (and other organ) segmentation.
138+
139+
Args:
140+
path: Filepath to a folder where the data is downloaded for further processing.
141+
patch_shape: The patch shape to use for training.
142+
label_choice: The choice of labelled organs.
143+
split: The choice of data split.
144+
resize_inputs: Whether to resize the inputs to the expected patch shape.
145+
download: Whether to download the data if it is not present.
146+
kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
147+
148+
Returns:
149+
The segmentation dataset.
124150
"""
125151
volume_paths = get_ircadb_paths(path, split, download)
126152

@@ -155,7 +181,20 @@ def get_ircadb_loader(
155181
download: bool = False,
156182
**kwargs
157183
) -> DataLoader:
158-
"""
184+
"""Get the IRCADb dataloader for liver (and other organ) segmentation.
185+
186+
Args:
187+
path: Filepath to a folder where the data is downloaded for further processing.
188+
batch_size: The batch size for training.
189+
patch_shape: The patch shape to use for training.
190+
label_choice: The choice of labelled organs.
191+
split: The choice of data split.
192+
resize_inputs: Whether to resize the inputs to the expected patch shape.
193+
download: Whether to download the data if it is not present.
194+
kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
195+
196+
Returns:
197+
The DataLoader.
159198
"""
160199
ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
161200
dataset = get_ircadb_dataset(path, patch_shape, label_choice, split, resize_inputs, download, **ds_kwargs)

0 commit comments

Comments
 (0)