From f975a0151f2264ec9a1bdfb87c37ae876134f35d Mon Sep 17 00:00:00 2001 From: T5X Team Date: Fri, 21 Jun 2024 15:37:55 -0700 Subject: [PATCH] Prevent divide by zero error in loss computation. This occurs when trying to export a model. Contains no functional changes to running code. PiperOrigin-RevId: 645521424 --- t5x/losses.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/t5x/losses.py b/t5x/losses.py index a5e365476..8915c813f 100644 --- a/t5x/losses.py +++ b/t5x/losses.py @@ -160,7 +160,10 @@ def compute_weighted_cross_entropy( ) vocab_size = logits.shape[-1] confidence = 1.0 - label_smoothing - low_confidence = (1.0 - confidence) / (vocab_size - 1) + if vocab_size == 1: + low_confidence = 1.0 - confidence + else: + low_confidence = (1.0 - confidence) / (vocab_size - 1) normalizing_constant = -( confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)