Skip to content

Commit 048ea82

Browse files
committed
fix deprecation warning
1 parent 521eb5b commit 048ea82

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

src/accelerate/accelerator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -497,7 +497,7 @@ def __init__(
497497
elif is_xpu_available():
498498
self.scaler = torch.amp.GradScaler("xpu", **kwargs)
499499
else:
500-
self.scaler = torch.cuda.amp.GradScaler(**kwargs)
500+
self.scaler = torch.amp.GradScaler("cuda", **kwargs)
501501

502502
elif self.state.mixed_precision == "bf16" and self.distributed_type not in (
503503
DistributedType.DEEPSPEED,

0 commit comments

Comments
 (0)