@@ -22,6 +22,7 @@ def __init__(
2222 automatic_batching = None ,
2323 num_workers = None ,
2424 pin_memory = None ,
25+ shuffle = None ,
2526 ** kwargs ,
2627 ):
2728 """
@@ -34,13 +35,13 @@ def __init__(
3435 If ``batch_size=None`` all
3536 samples are loaded and data are not batched, defaults to None.
3637 :type batch_size: int | None
37- :param train_size: percentage of elements in the train dataset
38+ :param train_size: Percentage of elements in the train dataset.
3839 :type train_size: float
39- :param test_size: percentage of elements in the test dataset
40+ :param test_size: Percentage of elements in the test dataset.
4041 :type test_size: float
41- :param val_size: percentage of elements in the val dataset
42+ :param val_size: Percentage of elements in the val dataset.
4243 :type val_size: float
43- :param predict_size: percentage of elements in the predict dataset
44+ :param predict_size: Percentage of elements in the predict dataset.
4445 :type predict_size: float
4546 :param compile: if True model is compiled before training,
4647 default False. For Windows users compilation is always disabled.
@@ -49,9 +50,13 @@ def __init__(
4950 performed. Please avoid using automatic batching when batch_size is
5051 large, default False.
5152 :type automatic_batching: bool
52- :param num_workers: Number of worker threads for data loading. Default 0 (serial loading)
53+ :param num_workers: Number of worker threads for data loading.
54+ Default 0 (serial loading).
5355 :type num_workers: int
54- :param pin_memory: Whether to use pinned memory for faster data transfer to GPU. (Default False)
56+ :param pin_memory: Whether to use pinned memory for faster data
57+ transfer to GPU. Default False.
58+ :type pin_memory: bool
59+ :param shuffle: Whether to shuffle the data for training. Default False.
5560 :type pin_memory: bool
5661
5762 :Keyword Arguments:
@@ -77,6 +82,10 @@ def __init__(
7782 check_consistency (pin_memory , int )
7883 else :
7984 num_workers = 0
85+ if shuffle is not None :
86+ check_consistency (shuffle , bool )
87+ else :
88+ shuffle = False
8089 if train_size + test_size + val_size + predict_size > 1 :
8190 raise ValueError (
8291 "train_size, test_size, val_size and predict_size "
@@ -131,6 +140,7 @@ def __init__(
131140 automatic_batching ,
132141 pin_memory ,
133142 num_workers ,
143+ shuffle ,
134144 )
135145
136146 # logging
@@ -166,6 +176,7 @@ def _create_datamodule(
166176 automatic_batching ,
167177 pin_memory ,
168178 num_workers ,
179+ shuffle ,
169180 ):
170181 """
171182 This method is used here because is resampling is needed
@@ -196,6 +207,7 @@ def _create_datamodule(
196207 automatic_batching = automatic_batching ,
197208 num_workers = num_workers ,
198209 pin_memory = pin_memory ,
210+ shuffle = shuffle ,
199211 )
200212
201213 def train (self , ** kwargs ):
0 commit comments