Skip to content

Commit 244b66b

Browse files
Jiayu Yetensorflower-gardener
Jiayu Ye
authored andcommitted
Internal change
PiperOrigin-RevId: 436647758
1 parent 5d40c99 commit 244b66b

File tree

1 file changed

+7
-2
lines changed
  • official/nlp/modeling/models

1 file changed

+7
-2
lines changed

official/nlp/modeling/models/t5.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -1004,6 +1004,7 @@ class T5TransformerParams:
10041004
num_heads: int
10051005
d_ff: int
10061006
vocab_size: int
1007+
target_vocab_size: Optional[int] = None
10071008
dropout_rate: float = 0.0
10081009
layer_norm_epsilon: float = 1e-6
10091010
shared_embedding: bool = False
@@ -1159,11 +1160,15 @@ def __init__(self,
11591160
self.compute_dtype = compute_dtype
11601161
if self.config.num_decoder_layers is None:
11611162
self.config.num_decoder_layers = self.config.num_layers
1163+
if not hasattr(
1164+
self.config,
1165+
"target_vocab_size") or self.config.target_vocab_size is None:
1166+
self.config.target_vocab_size = self.config.vocab_size
11621167
with self.name_scope:
11631168
# Target Embedding.
11641169
if shared_embedding is None:
11651170
self.target_embed = Embed(
1166-
vocab_size=self.config.vocab_size,
1171+
vocab_size=self.config.target_vocab_size,
11671172
features=self.config.d_model,
11681173
embeddings_initializer=self.config.vocab_embeddings_initializer,
11691174
dtype=self.dtype,
@@ -1211,7 +1216,7 @@ def __init__(self,
12111216
if not self.config.logits_via_embedding:
12121217
self.logits_dense = Linear(
12131218
in_features=self.config.d_model,
1214-
out_features=self.config.vocab_size,
1219+
out_features=self.config.target_vocab_size,
12151220
use_bias=False,
12161221
dtype=self.dtype,
12171222
name="logits")

0 commit comments

Comments
 (0)