diff --git a/t5x/losses.py b/t5x/losses.py index a5e365476..8915c813f 100644 --- a/t5x/losses.py +++ b/t5x/losses.py @@ -160,7 +160,10 @@ def compute_weighted_cross_entropy( ) vocab_size = logits.shape[-1] confidence = 1.0 - label_smoothing - low_confidence = (1.0 - confidence) / (vocab_size - 1) + if vocab_size == 1: + low_confidence = 1.0 - confidence + else: + low_confidence = (1.0 - confidence) / (vocab_size - 1) normalizing_constant = -( confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)