Skip to content

Commit

Permalink
Merge branch 'huggingface:main' into triton-version
Browse files Browse the repository at this point in the history
  • Loading branch information
faaany authored Jan 17, 2025
2 parents 6d2b507 + fbfa53b commit 25fc103
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
10 changes: 10 additions & 0 deletions src/accelerate/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import math
from contextlib import suppress
from typing import Callable, List, Optional, Union

import torch
from packaging import version
from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler

from .logging import get_logger
Expand All @@ -25,6 +27,7 @@
RNGType,
broadcast,
broadcast_object_list,
compare_versions,
concatenate,
find_batch_size,
get_data_structure,
Expand Down Expand Up @@ -415,6 +418,13 @@ def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, *
"StatefulDataLoader is not available. Please install torchdata version 0.8.0 or higher to use it."
)
if use_stateful_dataloader:
torchdata_version = version.parse(importlib.metadata.version("torchdata"))
if (
"in_order" in kwargs
and compare_versions(torchdata_version, "<", "0.11")
and is_torch_version(">=", "2.6.0")
):
kwargs.pop("in_order")
self.base_dataloader = StatefulDataLoader(dataset, batch_sampler=batch_sampler, **kwargs)
else:
self.base_dataloader = DataLoader(dataset, batch_sampler=batch_sampler, **kwargs)
Expand Down
2 changes: 1 addition & 1 deletion src/accelerate/utils/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ class GradScalerKwargs(KwargsHandler):
from accelerate import Accelerator
from accelerate.utils import GradScalerKwargs
kwargs = GradScalerKwargs(backoff_filter=0.25)
kwargs = GradScalerKwargs(backoff_factor=0.25)
accelerator = Accelerator(kwargs_handlers=[kwargs])
```
"""
Expand Down

0 comments on commit 25fc103

Please sign in to comment.