Skip to content

Commit 60f3047

Browse files
author
Negar Foroutan Eghlidi
committed
Fix a device issue.
1 parent c41730d commit 60f3047

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

src/nanotron/trainer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -695,7 +695,7 @@ def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, Ten
695695
if not lang_losses[
696696
lang
697697
]: # If the list is empty --> Set local language loss to -1 to exclude it from the global computation
698-
lang_losses[lang] = torch.tensor(-1, dtype=torch.float32)
698+
lang_losses[lang] = torch.tensor(-1, dtype=torch.float32, device="cuda")
699699
else: # If we have at least 1 loss from a given language --> compute local language loss mean
700700
lang_losses[lang] = torch.mean(torch.stack(lang_losses[lang]))
701701

0 commit comments

Comments
 (0)