
Commit 7843286

Allow for samplers to be seedable and reproducible (#2057)
* bookmark
* Works!
* Working!
* Fully working now
* Cover dataset
* Needed for dispatch
* Check both
* Bring back pop, fix hang
* Fully working
* Change back to epoch
* Adjust for new methods
* Clean
* Fix tests
* Avoid circular import
* Clean
* Fix test
* Comment
* Add a comment
* Comment
* Use yield from instead
1 parent 11e2e99 commit 7843286

File tree: 6 files changed, +158 −15 lines

Diff for: src/accelerate/accelerator.py (+7 lines)
@@ -2796,6 +2796,9 @@ def _inner(folder):
         elif self.distributed_type not in [DistributedType.MEGATRON_LM]:
             schedulers = self._schedulers
 
+        # Save the samplers of the dataloaders
+        dataloaders = self._dataloaders
+
         # Call model loading hooks that might have been registered with
         # accelerator.register_model_state_hook
         for hook in self._save_model_state_pre_hook.values():
@@ -2806,6 +2809,7 @@ def _inner(folder):
             weights,
             optimizers,
             schedulers,
+            dataloaders,
             self.state.process_index,
             self.scaler,
             save_on_each_node=self.project_configuration.save_on_each_node,
@@ -2935,6 +2939,8 @@ def _inner(folder):
         elif self.distributed_type not in [DistributedType.MEGATRON_LM]:
             schedulers = self._schedulers
 
+        dataloaders = self._dataloaders
+
         # Call model loading hooks that might have been registered with
         # accelerator.register_model_state_hook
         for hook in self._load_model_state_pre_hook.values():
@@ -2955,6 +2961,7 @@ def _inner(folder):
             models,
             optimizers,
             schedulers,
+            dataloaders,
             self.state.process_index,
             self.scaler,
             map_location,
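
With this change, `Accelerator.save_state` and `Accelerator.load_state` carry the sampler state of every prepared dataloader alongside the model, optimizer and scheduler state. A minimal round-trip sketch of what that looks like from the user side (the checkpoint path and the toy model/data below are illustrative only):

import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator()

dataset = TensorDataset(torch.randn(64, 4), torch.randn(64, 1))  # toy data
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)     # RandomSampler under the hood
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

# Saving now also writes one sampler file per prepared dataloader, but only
# when the sampler is the new SeedableRandomSampler (i.e. shuffle=True under
# a multi-process launch).
accelerator.save_state("my_checkpoint")  # illustrative path

# Resuming restores the sampler together with model/optimizer/scheduler state,
# so shuffling picks up the same reproducible stream.
accelerator.load_state("my_checkpoint")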

Diff for: src/accelerate/checkpointing.py (+40 lines)
@@ -20,11 +20,13 @@
 import numpy as np
 import torch
 from torch.cuda.amp import GradScaler
+from torch.utils.data import BatchSampler
 
 from .utils import (
     MODEL_NAME,
     OPTIMIZER_NAME,
     RNG_STATE_NAME,
+    SAMPLER_NAME,
     SCALER_NAME,
     SCHEDULER_NAME,
     get_pretty_name,
@@ -49,6 +51,7 @@ def save_accelerator_state(
     model_states: List[dict],
     optimizers: list,
     schedulers: list,
+    dataloaders: list,
     process_index: int,
     scaler: GradScaler = None,
     save_on_each_node: bool = False,
@@ -65,6 +68,8 @@ def save_accelerator_state(
             A list of optimizer instances
         schedulers (`List[torch.optim.lr_scheduler._LRScheduler]`):
             A list of learning rate schedulers
+        dataloaders (`List[torch.utils.data.DataLoader]`):
+            A list of dataloader instances to save their sampler states
         process_index (`int`):
             The current process index in the Accelerator state
         scaler (`torch.cuda.amp.GradScaler`, *optional*):
@@ -92,6 +97,22 @@ def save_accelerator_state(
         output_scheduler_file = os.path.join(output_dir, scheduler_name)
         save(state, output_scheduler_file, save_on_each_node=save_on_each_node)
         logger.info(f"Scheduler state saved in {output_scheduler_file}")
+    # DataLoader states
+    for i, dataloader in enumerate(dataloaders):
+        sampler_name = f"{SAMPLER_NAME}.bin" if i == 0 else f"{SAMPLER_NAME}_{i}.bin"
+        output_sampler_file = os.path.join(output_dir, sampler_name)
+        # Only save if we have our custom sampler
+        sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
+        if sampler_is_batch_sampler:
+            sampler = dataloader.sampler.sampler
+        else:
+            sampler = dataloader.batch_sampler.sampler
+        from .data_loader import SeedableRandomSampler
+
+        if isinstance(sampler, SeedableRandomSampler):
+            save(sampler, output_sampler_file, save_on_each_node=save_on_each_node)
+            logger.info(f"Sampler state for dataloader {i} saved in {output_sampler_file}")
+
     # GradScaler state
     if scaler is not None:
         state = scaler.state_dict()
@@ -121,6 +142,7 @@ def load_accelerator_state(
     models,
     optimizers,
     schedulers,
+    dataloaders,
     process_index,
     scaler=None,
     map_location=None,
@@ -177,6 +199,24 @@
         scheduler.load_state_dict(torch.load(input_scheduler_file))
     logger.info("All scheduler states loaded successfully")
 
+    for i, dataloader in enumerate(dataloaders):
+        sampler_name = f"{SAMPLER_NAME}.bin" if i == 0 else f"{SAMPLER_NAME}_{i}.bin"
+        input_sampler_file = os.path.join(input_dir, sampler_name)
+        # Only load if we have our custom sampler
+        sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
+        if sampler_is_batch_sampler:
+            sampler = dataloader.sampler.sampler
+        else:
+            sampler = dataloader.batch_sampler.sampler
+        from .data_loader import SeedableRandomSampler
+
+        if isinstance(sampler, SeedableRandomSampler):
+            if sampler_is_batch_sampler:
+                dataloader.sampler.sampler = torch.load(input_sampler_file)
+            else:
+                dataloader.batch_sampler.sampler = torch.load(input_sampler_file)
+    logger.info("All dataloader sampler states loaded successfully")
+
     # GradScaler state
     if scaler is not None:
         input_scaler_file = os.path.join(input_dir, SCALER_NAME)
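
To make the file-naming and save/load logic above concrete, here is a standalone sketch that mirrors the two new loops using plain `torch.save`/`torch.load` in place of accelerate's internal `save` helper (the directory name is illustrative, and the simplified `dataloader.sampler` access skips the batch-sampler branch handled in the real code):

import os
import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate.data_loader import SeedableRandomSampler

SAMPLER_NAME = "sampler"  # mirrors accelerate.utils.SAMPLER_NAME
output_dir = "ckpt"       # illustrative path
os.makedirs(output_dir, exist_ok=True)

dataset = TensorDataset(torch.arange(100).float())
sampler = SeedableRandomSampler(data_source=dataset, generator=torch.Generator().manual_seed(42))
dataloaders = [DataLoader(dataset, batch_size=10, sampler=sampler)]

# Save: the first dataloader gets "sampler.bin", later ones "sampler_1.bin", ...
for i, dataloader in enumerate(dataloaders):
    sampler_name = f"{SAMPLER_NAME}.bin" if i == 0 else f"{SAMPLER_NAME}_{i}.bin"
    if isinstance(dataloader.sampler, SeedableRandomSampler):
        torch.save(dataloader.sampler, os.path.join(output_dir, sampler_name))

# Load: the whole sampler object, including its epoch counter, comes back.
restored = torch.load(os.path.join(output_dir, "sampler.bin"))
assert restored.epoch == sampler.epoch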

Diff for: src/accelerate/data_loader.py (+97 −13 lines)
@@ -17,7 +17,7 @@
 from typing import Callable, List, Optional, Union
 
 import torch
-from torch.utils.data import BatchSampler, DataLoader, IterableDataset
+from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler
 
 from .logging import get_logger
 from .state import AcceleratorState, DistributedType, GradientState, is_tpu_available
@@ -64,6 +64,41 @@
         _PYTORCH_DATALOADER_KWARGS.update(additional_kwargs)
 
 
+class SeedableRandomSampler(RandomSampler):
+    """
+    Same as a random sampler, except that in `__iter__` a seed can be used.
+
+    Needed specifically in distributed cases, when the random generator for each GPU needs to start from the same seed
+    and be fully reproducable on multiple iterations.
+
+    If a custom `generator` is passed, it will rely on its initial seed as well as the current iteration it is on
+    (stored in `self.epoch`).
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.epoch = 0
+
+    def __iter__(self):
+        g = torch.Generator()
+        if self.generator is not None:
+            seed = self.epoch + self.generator.initial_seed()
+        else:
+            seed = self.epoch
+        g.manual_seed(seed)
+        n = len(self.data_source)
+        # Taken 1:1 from torch.utils.data.sampler.RandomSampler.__iter__
+        if self.replacement:
+            for _ in range(self.num_samples // 32):
+                yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=g).tolist()
+        else:
+            yield from torch.randperm(n, generator=g).tolist()
+
+    def set_epoch(self, epoch: int):
+        "Sets the current iteration of the sampler."
+        self.epoch = epoch
+
+
 class BatchSamplerShard(BatchSampler):
     """
     Wraps a PyTorch `BatchSampler` to generate batches for one of the processes only. Instances of this class will
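
The key property of `SeedableRandomSampler` is that the permutation depends only on the generator's initial seed and the `epoch` counter, not on any per-process RNG state. A quick single-process sketch of what that buys, with two sampler instances standing in for two ranks that were handed the same seed:

import torch
from accelerate.data_loader import SeedableRandomSampler

data = list(range(10))

def make_sampler():
    # Each "rank" builds its sampler from the same initial seed.
    return SeedableRandomSampler(data_source=data, generator=torch.Generator().manual_seed(42))

rank0, rank1 = make_sampler(), make_sampler()

# Same epoch -> identical order everywhere, and re-iterating gives the same order again.
assert list(rank0) == list(rank1) == list(rank0)

# Bumping the epoch reshuffles, but still identically on every "rank".
rank0.set_epoch(1)
rank1.set_epoch(1)
assert list(rank0) == list(rank1)
print(list(rank0))  # a different (but shared) permutation than at epoch 0, almost surely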
@@ -271,6 +306,11 @@ def __init__(
         self.process_index = process_index
         self.split_batches = split_batches
 
+    def set_epoch(self, epoch):
+        self.epoch = epoch
+        if hasattr(self.dataset, "set_epoch"):
+            self.dataset.set_epoch(epoch)
+
     def __len__(self):
         # We will just raise the downstream error if the underlying dataset is not sized
         if self.drop_last:
@@ -279,6 +319,12 @@ def __len__(self):
         return math.ceil(len(self.dataset) / (self.batch_size * self.num_processes)) * self.batch_size
 
     def __iter__(self):
+        if (
+            not hasattr(self.dataset, "set_epoch")
+            and hasattr(self.dataset, "generator")
+            and isinstance(self.dataset.generator, torch.Generator)
+        ):
+            self.dataset.generator.manual_seed(self.epoch)
         real_batch_size = self.batch_size if self.split_batches else (self.batch_size * self.num_processes)
         process_batch_size = (self.batch_size // self.num_processes) if self.split_batches else self.batch_size
         process_slice = range(self.process_index * process_batch_size, (self.process_index + 1) * process_batch_size)
@@ -391,11 +437,14 @@ def __init__(
         self.skip_batches = skip_batches
         self.gradient_state = GradientState()
         self._drop_last = _drop_last
+        self.iteration = 0
 
     def __iter__(self):
         if self.rng_types is not None:
             synchronize_rng_states(self.rng_types, self.synchronized_generator)
         self.begin()
+
+        self.set_epoch(self.iteration)
         dataloader_iter = super().__iter__()
         # We iterate one batch ahead to check when we are at the end
         try:
@@ -419,8 +468,21 @@ def __iter__(self):
                     if batch_index >= self.skip_batches:
                         yield current_batch
                     break
+
+        self.iteration += 1
         self.end()
 
+    def set_epoch(self, epoch: int):
+        # In case it is manually passed in, the user can set it to what they like
+        if self.iteration != epoch:
+            self.iteration = epoch
+        if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, "set_epoch"):
+            self.batch_sampler.sampler.set_epoch(epoch)
+        # We support if a custom `Dataset` implementation has `set_epoch`
+        # or in general HF datasets `Datasets`
+        elif hasattr(self.dataset, "set_epoch"):
+            self.dataset.set_epoch(epoch)
+
     @property
     def total_batch_size(self):
         batch_sampler = self.sampler if isinstance(self.sampler, BatchSampler) else self.batch_sampler
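
With these hooks in place, the prepared `DataLoaderShard` bumps its own `iteration` counter after every full pass, and also exposes `set_epoch` so a training loop can pin the epoch explicitly, which is mainly useful right after resuming from a checkpoint. A hedged usage sketch on top of an already prepared dataloader (`start_epoch` and `num_epochs` are placeholders):

# `train_dl` is assumed to be the output of `accelerator.prepare(...)`.
for epoch in range(start_epoch, num_epochs):
    # Optional: the shard advances its iteration automatically each pass, but
    # setting it explicitly keeps shuffling aligned with the resumed epoch.
    train_dl.set_epoch(epoch)
    for batch in train_dl:
        ...  # forward/backward as usual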
@@ -524,6 +586,7 @@ def __init__(
         self.skip_batches = skip_batches
 
         self.slice_fn = slice_tensors if slice_fn is None else slice_fn
+        self.iteration = 0
 
     def _fetch_batches(self, iterator):
         batches, batch = None, None
@@ -564,6 +627,7 @@ def _fetch_batches(self, iterator):
 
     def __iter__(self):
         self.begin()
+        self.set_epoch(self.iteration)
         main_iterator = None
         if is_torch_version(">=", "2.0.1"):
             # NOTE PyTorch DataLoader adds forward compatibilities for DataPipes, which broadcasts
@@ -633,8 +697,18 @@ def __iter__(self):
             if batch_index >= self.skip_batches:
                 yield batch
             batch_index += 1
+        self.iteration += 1
         self.end()
 
+    def set_epoch(self, epoch: int):
+        # In case it is manually passed in, the user can set it to what they like
+        if self.iteration != epoch:
+            self.iteration = epoch
+        if hasattr(self.batch_sampler.sampler, "set_epoch"):
+            self.batch_sampler.sampler.set_epoch(epoch)
+        elif hasattr(self.dataset, "set_epoch"):
+            self.dataset.set_epoch(epoch)
+
     def __len__(self):
         whole_length = super().__len__()
         if self.split_batches:
@@ -757,6 +831,23 @@ def prepare_data_loader(
     new_batch_sampler = dataloader.batch_sampler if not isinstance(new_dataset, IterableDataset) else None
     sampler_is_batch_sampler = False
     synchronized_generator = None
+    sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
+    if sampler_is_batch_sampler:
+        sampler = dataloader.sampler.sampler
+    else:
+        sampler = dataloader.batch_sampler.sampler
+    if isinstance(sampler, RandomSampler) and num_processes > 1:
+        # When iterating through the dataloader during distributed processes
+        # we want to ensure that on each process we are iterating through the same
+        # samples in the same order if a seed is set. This requires a tweak
+        # to the `torch.utils.data.RandomSampler` class (if used).
+        sampler = SeedableRandomSampler(
+            data_source=sampler.data_source,
+            replacement=sampler.replacement,
+            num_samples=sampler._num_samples,
+            generator=getattr(sampler, "generator", torch.Generator()),
+        )
+
     # No change if no multiprocess
     if (num_processes != 1 or state.distributed_type == DistributedType.MEGATRON_LM) and not dispatch_batches:
         if isinstance(new_dataset, IterableDataset):
@@ -771,17 +862,6 @@
                 split_batches=split_batches,
             )
         else:
-            # New batch sampler for the current process.
-            sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
-            if sampler_is_batch_sampler:
-                sampler = dataloader.sampler.sampler
-            else:
-                sampler = dataloader.batch_sampler.sampler
-            if hasattr(sampler, "generator"):
-                if sampler.generator is None:
-                    sampler.generator = torch.Generator()
-                synchronized_generator = sampler.generator
-
             batch_sampler = dataloader.sampler if sampler_is_batch_sampler else dataloader.batch_sampler
             new_batch_sampler = BatchSamplerShard(
                 batch_sampler,
@@ -815,7 +895,11 @@
     kwargs["batch_size"] = (
         dataloader.batch_size // num_processes if split_batches and not dispatch_batches else dataloader.batch_size
     )
-
+    if isinstance(sampler, SeedableRandomSampler):
+        if sampler_is_batch_sampler:
+            dataloader.sampler.sampler = sampler
+        else:
+            dataloader.batch_sampler.sampler = sampler
     if dispatch_batches:
         kwargs.pop("generator")
         dataloader = DataLoaderDispatcher(
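
The net effect in `prepare_data_loader`: when a plain `torch.utils.data.RandomSampler` is found and more than one process is running, it is rebuilt as a `SeedableRandomSampler` that keeps the same data source, replacement mode, sample count and generator, and is then swapped back onto the dataloader. A standalone sketch of that conversion outside of accelerate (the `num_processes` check is simulated here):

import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset
from accelerate.data_loader import SeedableRandomSampler

dataset = TensorDataset(torch.arange(20).float())
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)  # shuffle=True -> RandomSampler

num_processes = 2  # pretend we are under a distributed launch
sampler = dataloader.batch_sampler.sampler
if isinstance(sampler, RandomSampler) and num_processes > 1:
    # Rebuild the sampler so every process draws the same seeded permutation.
    sampler = SeedableRandomSampler(
        data_source=sampler.data_source,
        replacement=sampler.replacement,
        num_samples=sampler._num_samples,
        generator=getattr(sampler, "generator", torch.Generator()),
    )
    dataloader.batch_sampler.sampler = sampler

print(type(dataloader.batch_sampler.sampler).__name__)  # SeedableRandomSampler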

Diff for: src/accelerate/test_utils/scripts/test_script.py (+12 −2 lines)
@@ -25,7 +25,7 @@
 from torch.utils.data import DataLoader
 
 from accelerate import Accelerator
-from accelerate.data_loader import prepare_data_loader
+from accelerate.data_loader import SeedableRandomSampler, prepare_data_loader
 from accelerate.state import AcceleratorState
 from accelerate.test_utils import RegressionDataset, are_the_same_tensors
 from accelerate.utils import (
@@ -292,7 +292,17 @@ def mock_training(length, batch_size, generator):
     set_seed(42)
     generator.manual_seed(42)
     train_set = RegressionDataset(length=length, seed=42)
-    train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
+    if AcceleratorState().num_processes > 1:
+        # The SeedableRandomSampler is needed during distributed setups
+        # for full reproducability across processes with the `DataLoader`
+        sampler = SeedableRandomSampler(
+            generator=generator,
+            data_source=train_set,
+            num_samples=len(train_set),
+        )
+        train_dl = DataLoader(train_set, batch_size=batch_size, sampler=sampler)
+    else:
+        train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
     model = RegressionModel()
     optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
     for epoch in range(3):
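
Outside of the test suite, the same behaviour comes for free through the public API: a shuffling `DataLoader` passed through `Accelerator.prepare` ends up with the seedable sampler under multi-process launches. A simplified end-to-end sketch (dataset, model and hyperparameters are placeholders):

import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator
from accelerate.utils import set_seed

set_seed(42)  # seed python/numpy/torch RNGs on every process
accelerator = Accelerator()

dataset = TensorDataset(torch.randn(128, 4), torch.randn(128, 1))  # placeholder data
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

for epoch in range(3):
    # The underlying permutation is identical on every rank (each rank then
    # takes its shard of it) and changes, identically, every epoch.
    for inputs, targets in dataloader:
        loss = torch.nn.functional.mse_loss(model(inputs), targets)
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()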

Diff for: src/accelerate/utils/__init__.py (+1 line)
@@ -4,6 +4,7 @@
     RNG_STATE_NAME,
     SAFE_WEIGHTS_INDEX_NAME,
     SAFE_WEIGHTS_NAME,
+    SAMPLER_NAME,
     SCALER_NAME,
     SCHEDULER_NAME,
     TORCH_DISTRIBUTED_OPERATION_TYPES,

Diff for: src/accelerate/utils/constants.py (+1 line)
@@ -20,6 +20,7 @@
 RNG_STATE_NAME = "random_states"
 OPTIMIZER_NAME = "optimizer"
 SCHEDULER_NAME = "scheduler"
+SAMPLER_NAME = "sampler"
 WEIGHTS_NAME = "pytorch_model.bin"
 WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
 SAFE_WEIGHTS_NAME = "model.safetensors"
