
Commit 8499c92

Gemma: Add logit soft-capping to score function. (#1712)

RyanMullins authored and mattdangerw committed
1 parent 7e56dbd

1 file changed, 6 insertions(+), 0 deletions(-)

keras_nlp/src/models/gemma/gemma_causal_lm.py

@@ -445,6 +445,12 @@ def default_layer_intercept_fn(x, unused_i):
         x = self.backbone.layer_norm(x)
         logits = self.backbone.token_embedding(x, reverse=True)

+        if self.backbone.final_logit_soft_cap is not None:
+            logits = ops.divide(logits, self.backbone.final_logit_soft_cap)
+            logits = ops.multiply(
+                ops.tanh(logits), self.backbone.final_logit_soft_cap
+            )
+
         if scoring_mode == "logits":
             return logits
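The soft-capping applied in this diff squashes logits into the range (−cap, +cap) via `cap * tanh(logits / cap)`, leaving small values nearly unchanged while saturating large ones. A minimal NumPy sketch of the same transformation (the `soft_cap` helper, the cap value, and the sample logits here are illustrative, not part of the commit):

```python
import numpy as np

def soft_cap(logits, cap):
    """Bound logits to (-cap, cap) via cap * tanh(logits / cap)."""
    return cap * np.tanh(logits / cap)

logits = np.array([-100.0, -5.0, 0.0, 5.0, 100.0])
capped = soft_cap(logits, cap=30.0)
# Extreme values saturate toward +/-30; values well inside the cap
# (e.g. 5.0) pass through almost unchanged.
```

Because tanh is smooth and monotonic, this preserves the ordering of logits (and therefore argmax) while keeping gradients finite for very large activations, which is why the scoring path needs the same treatment as the forward pass.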
