diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index 6f29d289ecc..4a92aa4617a 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -39,7 +39,7 @@ logger = get_logger(__name__) -# kwargs of the DataLoader in min version 1.4.0. +# kwargs of the DataLoader in min version 2.0 _PYTORCH_DATALOADER_KWARGS = { "batch_size": 1, "shuffle": False, @@ -55,10 +55,11 @@ "generator": None, "prefetch_factor": 2, "persistent_workers": False, + "pin_memory_device": "", } # kwargs added after by version -_PYTORCH_DATALOADER_ADDITIONAL_KWARGS = {} +_PYTORCH_DATALOADER_ADDITIONAL_KWARGS = {"2.6.0": {"in_order": True}} for v, additional_kwargs in _PYTORCH_DATALOADER_ADDITIONAL_KWARGS.items(): if is_torch_version(">=", v):