
Commit 9b4a741

Fix shared T5
1 parent 86aa795 commit 9b4a741

File tree

1 file changed: +16 -11 lines changed


megatron/model/shared_t5_model.py

+16 -11
@@ -49,23 +49,26 @@ def _to_16bit(inputs):
             else:
                 return inputs
 
-        self.specs.append(lambda inputss: tuple(_to_16bit(inputs) for inputs in inputss))
+        self.specs.append(lambda inputss: tuple(tuple(_to_16bit(inputs)) for inputs in inputss))
 
         # Embedding layer
         self.specs.append(TiedLayerSpec('embed',
                                         EmbeddingPipe,
                                         args.hidden_size,
                                         args.padded_vocab_size,
                                         args.hidden_dropout,
+                                        forward_fn=lambda module, inputs, targets: (module(*inputs), module(*targets)),
                                         init_method=init_method,
                                         num_tokentypes=num_tokentypes,
                                         tied_weight_attr='word_embeddings_weight'))
 
         assert hasattr(args, 'attn_mask'), "Deepspeed integration should have attention mask s"
+        # Drop everything beside tokens
+        self.specs.append(lambda inputs, targets: (inputs[0], targets[0]))
         if args.fp32_residual_connection:
-            self.specs.append(lambda x: x.transpose(0, 1).contiguous().float())
+            self.specs.append(lambda input_tokens, target_tokens: (input_tokens.transpose(0, 1).contiguous().float(), target_tokens.transpose(0, 1).contiguous().float()))
         else:
-            self.specs.append(lambda x: x.transpose(0, 1).contiguous())
+            self.specs.append(lambda input_tokens, target_tokens: (input_tokens.transpose(0, 1).contiguous(), target_tokens.transpose(0, 1).contiguous()))
 
         ### ----- Encoder -----
         for layer_idx in range(args.num_layers):
@@ -74,22 +77,21 @@ def _to_16bit(inputs):
                     f"block_{layer_idx}",
                     ParallelTransformerLayerPipe,
                     init_method=init_method,
-                    # Inputs: (input_tokens, target_tokens,
-                    forward_fn=lambda module, *inputs: ,
+                    forward_fn=lambda module, input_tokens, target_tokens: (module(input_tokens), target_tokens),
                     output_layer_init_method=scaled_init_method_normal(args.init_method_std,
                                                                        args.num_layers),
                     layer_type=LayerType.encoder,
                     layer_number=layer_idx,
                     self_attn_mask_type=AttnMaskType.causal,
-                    tied_weight_attrs=["input_layernorm", "self_attention", "post_attention_layernorm", "mlp"]
+                    tied_weight_attrs=["self_attention", "mlp"]
                 ))
 
         # Final layernorm after encoder layers
         self.specs.append(
-            TiedLayerSpec(
-                "final_layer_norm",
+            LayerSpec(
                 LayerNorm,
                 args.hidden_size,
+                forward_fn=lambda module, input_tokens, target_tokens: (module(input_tokens), target_tokens),
                 eps=args.layernorm_epsilon
             ))
 
@@ -100,19 +102,22 @@ def _to_16bit(inputs):
                     f"block_{layer_idx}",
                     ParallelTransformerLayerPipe,
                     init_method=init_method,
+                    forward_fn=lambda module, encoded_tokens, target_tokens: (encoded_tokens, module(target_tokens, encoder_output=encoded_tokens)),
                     output_layer_init_method=scaled_init_method_normal(args.init_method_std,
                                                                        args.num_layers),
                     layer_number=layer_idx,
                     layer_type=LayerType.decoder,
                     self_attn_mask_type=AttnMaskType.padding,
-                    tied_weight_attrs=["input_layernorm", "self_attention", "post_attention_layernorm", "mlp"]
+                    tied_weight_attrs=["self_attention", "mlp"]
                 )
             )
 
+        # Drop encoded tokens
+        self.specs.append(lambda encoded_tokens, target_tokens: target_tokens)
+
         # Final layernorm after decoder layers
         self.specs.append(
-            TiedLayerSpec(
-                "final_layer_norm",
+            LayerSpec(
                 LayerNorm,
                 args.hidden_size,
                 eps=args.layernorm_epsilon