Skip to content

Commit 1a1f6cd

Browse files
quic-swatiashubhagr-quic
authored andcommitted
Passing device type in torch GradScaler (#345)
Passing device type in torch GradScaler(), in case of CUDA and CPU so that it picks up the correct device in case of CPU. Signed-off-by: Swati Allabadi <[email protected]> Co-authored-by: Swati Allabadi <[email protected]>
1 parent 5f09b72 commit 1a1f6cd

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

QEfficient/finetune/utils/train_utils.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ def train(
8383
best_val_loss = float("inf")
8484
total_train_steps = 0
8585
max_steps_reached = False # Flag to indicate max training steps reached
86+
device_type = device.split(":")[0]
8687

8788
tensorboard_updates = None
8889
if train_config.enable_ddp:
@@ -95,7 +96,7 @@ def train(
9596
if device.startswith("qaic"):
9697
scaler = QAicGradScaler()
9798
else:
98-
scaler = GradScaler()
99+
scaler = GradScaler(device_type)
99100

100101
loss_0_counter = torch.tensor([0]).to(device)
101102

0 commit comments

Comments
 (0)