from typing import Callable, List, Optional, Union

import torch
- from torch.utils.data import BatchSampler, DataLoader, IterableDataset
+ from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler

from .logging import get_logger
from .state import AcceleratorState, DistributedType, GradientState, is_tpu_available


        _PYTORCH_DATALOADER_KWARGS.update(additional_kwargs)


+ class SeedableRandomSampler(RandomSampler):
+     """
+     Same as a random sampler, except that in `__iter__` a seed can be used.
+
+     Needed specifically in distributed cases, when the random generator for each GPU needs to start from the same seed
+     and be fully reproducible on multiple iterations.
+
+     If a custom `generator` is passed, it will rely on its initial seed as well as the current iteration it is on
+     (stored in `self.epoch`).
+     """
+
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.epoch = 0
+
+     def __iter__(self):
+         g = torch.Generator()
+         if self.generator is not None:
+             seed = self.epoch + self.generator.initial_seed()
+         else:
+             seed = self.epoch
+         g.manual_seed(seed)
+         n = len(self.data_source)
+         # Taken 1:1 from torch.utils.data.sampler.RandomSampler.__iter__
+         if self.replacement:
+             for _ in range(self.num_samples // 32):
+                 yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=g).tolist()
+         else:
+             yield from torch.randperm(n, generator=g).tolist()
+
+     def set_epoch(self, epoch: int):
+         "Sets the current iteration of the sampler."
+         self.epoch = epoch
+
+
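For intuition, a minimal standalone sketch of the contract this class provides (assuming `SeedableRandomSampler` is in scope, e.g. imported from `accelerate.data_loader` once this patch lands): with the same generator seed on every process, calling `set_epoch` with the same value yields the identical permutation, and bumping the epoch reshuffles deterministically.

    import torch

    # Hypothetical reproducibility check for the sampler added above.
    data = list(range(1000))
    sampler = SeedableRandomSampler(data, generator=torch.Generator().manual_seed(42))

    sampler.set_epoch(0)
    epoch0 = list(sampler)
    sampler.set_epoch(0)
    assert list(sampler) == epoch0  # same seed + same epoch -> same order on every process
    sampler.set_epoch(1)
    assert list(sampler) != epoch0  # new epoch -> a different, but still deterministic, shuffle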
class BatchSamplerShard(BatchSampler):
    """
    Wraps a PyTorch `BatchSampler` to generate batches for one of the processes only. Instances of this class will
@@ -271,6 +306,11 @@ def __init__(
        self.process_index = process_index
        self.split_batches = split_batches

+     def set_epoch(self, epoch):
+         self.epoch = epoch
+         if hasattr(self.dataset, "set_epoch"):
+             self.dataset.set_epoch(epoch)
+
    def __len__(self):
        # We will just raise the downstream error if the underlying dataset is not sized
        if self.drop_last:
@@ -279,6 +319,12 @@ def __len__(self):
            return math.ceil(len(self.dataset) / (self.batch_size * self.num_processes)) * self.batch_size

    def __iter__(self):
+         if (
+             not hasattr(self.dataset, "set_epoch")
+             and hasattr(self.dataset, "generator")
+             and isinstance(self.dataset.generator, torch.Generator)
+         ):
+             self.dataset.generator.manual_seed(self.epoch)
        real_batch_size = self.batch_size if self.split_batches else (self.batch_size * self.num_processes)
        process_batch_size = (self.batch_size // self.num_processes) if self.split_batches else self.batch_size
        process_slice = range(self.process_index * process_batch_size, (self.process_index + 1) * process_batch_size)
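For datasets that expose a `torch.Generator` but no `set_epoch` hook, the block added to `__iter__` above simply reseeds that generator with the current epoch. A toy illustration of the same idea (`TinyShuffledDataset` is hypothetical, not part of the patch):

    import torch
    from torch.utils.data import IterableDataset

    class TinyShuffledDataset(IterableDataset):
        # Hypothetical dataset that shuffles with its own generator but has no `set_epoch`.
        def __init__(self, n):
            self.n = n
            self.generator = torch.Generator()

        def __iter__(self):
            yield from torch.randperm(self.n, generator=self.generator).tolist()

    ds = TinyShuffledDataset(10)
    ds.generator.manual_seed(0)  # what the shard does at the start of epoch 0
    epoch0 = list(ds)
    ds.generator.manual_seed(0)  # reseeding with the same epoch reproduces the order
    assert list(ds) == epoch0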
@@ -391,11 +437,14 @@ def __init__(
        self.skip_batches = skip_batches
        self.gradient_state = GradientState()
        self._drop_last = _drop_last
+         self.iteration = 0

    def __iter__(self):
        if self.rng_types is not None:
            synchronize_rng_states(self.rng_types, self.synchronized_generator)
        self.begin()
+
+         self.set_epoch(self.iteration)
        dataloader_iter = super().__iter__()
        # We iterate one batch ahead to check when we are at the end
        try:
@@ -419,8 +468,21 @@ def __iter__(self):
                if batch_index >= self.skip_batches:
                    yield current_batch
                break
+
+         self.iteration += 1
        self.end()

+     def set_epoch(self, epoch: int):
+         # In case it is manually passed in, the user can set it to what they like
+         if self.iteration != epoch:
+             self.iteration = epoch
+         if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, "set_epoch"):
+             self.batch_sampler.sampler.set_epoch(epoch)
+         # We support if a custom `Dataset` implementation has `set_epoch`
+         # or in general HF datasets `Datasets`
+         elif hasattr(self.dataset, "set_epoch"):
+             self.dataset.set_epoch(epoch)
+
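For context, a hedged sketch of how a training loop could drive this hook by hand between epochs; the `Accelerator().prepare` round-trip and the variable names are illustrative, not part of the patch:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from accelerate import Accelerator

    accelerator = Accelerator()
    dataset = TensorDataset(torch.arange(100).float())
    train_dataloader = accelerator.prepare(DataLoader(dataset, batch_size=8, shuffle=True))

    for epoch in range(3):
        if hasattr(train_dataloader, "set_epoch"):
            train_dataloader.set_epoch(epoch)  # forwards to the sampler's or dataset's `set_epoch`
        for (batch,) in train_dataloader:
            pass  # training step would go here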
    @property
    def total_batch_size(self):
        batch_sampler = self.sampler if isinstance(self.sampler, BatchSampler) else self.batch_sampler
@@ -524,6 +586,7 @@ def __init__(
        self.skip_batches = skip_batches

        self.slice_fn = slice_tensors if slice_fn is None else slice_fn
+         self.iteration = 0

    def _fetch_batches(self, iterator):
        batches, batch = None, None
@@ -564,6 +627,7 @@ def _fetch_batches(self, iterator):

    def __iter__(self):
        self.begin()
+         self.set_epoch(self.iteration)
        main_iterator = None
        if is_torch_version(">=", "2.0.1"):
            # NOTE PyTorch DataLoader adds forward compatibilities for DataPipes, which broadcasts
@@ -633,8 +697,18 @@ def __iter__(self):
            if batch_index >= self.skip_batches:
                yield batch
            batch_index += 1
+         self.iteration += 1
        self.end()

+     def set_epoch(self, epoch: int):
+         # In case it is manually passed in, the user can set it to what they like
+         if self.iteration != epoch:
+             self.iteration = epoch
+         if hasattr(self.batch_sampler.sampler, "set_epoch"):
+             self.batch_sampler.sampler.set_epoch(epoch)
+         elif hasattr(self.dataset, "set_epoch"):
+             self.dataset.set_epoch(epoch)
+
    def __len__(self):
        whole_length = super().__len__()
        if self.split_batches:
@@ -757,6 +831,23 @@ def prepare_data_loader(
    new_batch_sampler = dataloader.batch_sampler if not isinstance(new_dataset, IterableDataset) else None
    sampler_is_batch_sampler = False
    synchronized_generator = None
+     sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
+     if sampler_is_batch_sampler:
+         sampler = dataloader.sampler.sampler
+     else:
+         sampler = dataloader.batch_sampler.sampler
+     if isinstance(sampler, RandomSampler) and num_processes > 1:
+         # When iterating through the dataloader during distributed processes
+         # we want to ensure that on each process we are iterating through the same
+         # samples in the same order if a seed is set. This requires a tweak
+         # to the `torch.utils.data.RandomSampler` class (if used).
+         sampler = SeedableRandomSampler(
+             data_source=sampler.data_source,
+             replacement=sampler.replacement,
+             num_samples=sampler._num_samples,
+             generator=getattr(sampler, "generator", torch.Generator()),
+         )
+
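Taken together with the reassignment near the end of `prepare_data_loader` (further down in this diff), the effect is that a plain shuffled `DataLoader` prepared for more than one process ends up backed by a `SeedableRandomSampler`. A rough single-machine sketch of that effect; the import path and the standalone `prepare_data_loader` call are assumptions, not shown in the patch:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from accelerate.data_loader import SeedableRandomSampler, prepare_data_loader  # assumed import path

    dataset = TensorDataset(torch.arange(64).float())
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True)  # shuffle=True gives a RandomSampler

    # Pretend to be rank 0 of a 2-process job; the sampler is swapped in place on the original loader.
    prepare_data_loader(dataloader, num_processes=2, process_index=0)
    assert isinstance(dataloader.batch_sampler.sampler, SeedableRandomSampler)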
    # No change if no multiprocess
    if (num_processes != 1 or state.distributed_type == DistributedType.MEGATRON_LM) and not dispatch_batches:
        if isinstance(new_dataset, IterableDataset):
@@ -771,17 +862,6 @@ def prepare_data_loader(
                split_batches=split_batches,
            )
        else:
-             # New batch sampler for the current process.
-             sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
-             if sampler_is_batch_sampler:
-                 sampler = dataloader.sampler.sampler
-             else:
-                 sampler = dataloader.batch_sampler.sampler
-             if hasattr(sampler, "generator"):
-                 if sampler.generator is None:
-                     sampler.generator = torch.Generator()
-                 synchronized_generator = sampler.generator
-
            batch_sampler = dataloader.sampler if sampler_is_batch_sampler else dataloader.batch_sampler
            new_batch_sampler = BatchSamplerShard(
                batch_sampler,
@@ -815,7 +895,11 @@ def prepare_data_loader(
    kwargs["batch_size"] = (
        dataloader.batch_size // num_processes if split_batches and not dispatch_batches else dataloader.batch_size
    )
-
+     if isinstance(sampler, SeedableRandomSampler):
+         if sampler_is_batch_sampler:
+             dataloader.sampler.sampler = sampler
+         else:
+             dataloader.batch_sampler.sampler = sampler
    if dispatch_batches:
        kwargs.pop("generator")
        dataloader = DataLoaderDispatcher(