Skip to content

Commit

Permalink
convert to float32 the loss if it is a bfloat16
Browse files Browse the repository at this point in the history
  • Loading branch information
Alessandro Sordoni committed Nov 10, 2024
1 parent df1112f commit 4e31c87
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion mttl/evaluators/loglike_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,11 @@ def evaluate(
loss_per_option = compute_loglike_loss(
logits, batch["labels"], reduction="none"
)
loss_per_option = loss_per_option.cpu().numpy()
loss_per_option = loss_per_option.cpu()

if loss_per_option.dtype in [torch.bfloat16, torch.float16]:
loss_per_option = loss_per_option.float().numpy()

loss_per_example = [
loss_per_option[
int(np.sum(num_options[:i])) : int(np.sum(num_options[: i + 1]))
Expand Down

0 comments on commit 4e31c87

Please sign in to comment.