@@ -1004,6 +1004,7 @@ class T5TransformerParams:
   num_heads: int
   d_ff: int
   vocab_size: int
+  target_vocab_size: Optional[int] = None
   dropout_rate: float = 0.0
   layer_norm_epsilon: float = 1e-6
   shared_embedding: bool = False
@@ -1159,11 +1160,15 @@ def __init__(self,
     self.compute_dtype = compute_dtype
     if self.config.num_decoder_layers is None:
       self.config.num_decoder_layers = self.config.num_layers
+    if not hasattr(
+        self.config,
+        "target_vocab_size") or self.config.target_vocab_size is None:
+      self.config.target_vocab_size = self.config.vocab_size
     with self.name_scope:
       # Target Embedding.
       if shared_embedding is None:
         self.target_embed = Embed(
-            vocab_size=self.config.vocab_size,
+            vocab_size=self.config.target_vocab_size,
             features=self.config.d_model,
             embeddings_initializer=self.config.vocab_embeddings_initializer,
             dtype=self.dtype,
@@ -1211,7 +1216,7 @@ def __init__(self,
       if not self.config.logits_via_embedding:
         self.logits_dense = Linear(
             in_features=self.config.d_model,
-            out_features=self.config.vocab_size,
+            out_features=self.config.target_vocab_size,
             use_bias=False,
             dtype=self.dtype,
             name="logits")