@@ -49,23 +49,26 @@ def _to_16bit(inputs):
             else:
                 return inputs

-        self.specs.append(lambda inputss: tuple(_to_16bit(inputs) for inputs in inputss))
+        self.specs.append(lambda inputss: tuple(tuple(_to_16bit(inputs)) for inputs in inputss))

         # Embedding layer
         self.specs.append(TiedLayerSpec('embed',
                                         EmbeddingPipe,
                                         args.hidden_size,
                                         args.padded_vocab_size,
                                         args.hidden_dropout,
+                                        forward_fn=lambda module, inputs, targets: (module(*inputs), module(*targets)),
                                         init_method=init_method,
                                         num_tokentypes=num_tokentypes,
                                         tied_weight_attr='word_embeddings_weight'))

         assert hasattr(args, 'attn_mask'), "Deepspeed integration should have attention masks"
+        # Drop everything beside tokens
+        self.specs.append(lambda inputs, targets: (inputs[0], targets[0]))
         if args.fp32_residual_connection:
-            self.specs.append(lambda x: x.transpose(0, 1).contiguous().float())
+            self.specs.append(lambda input_tokens, target_tokens: (input_tokens.transpose(0, 1).contiguous().float(), target_tokens.transpose(0, 1).contiguous().float()))
         else:
-            self.specs.append(lambda x: x.transpose(0, 1).contiguous())
+            self.specs.append(lambda input_tokens, target_tokens: (input_tokens.transpose(0, 1).contiguous(), target_tokens.transpose(0, 1).contiguous()))

         ### ----- Encoder -----
         for layer_idx in range(args.num_layers):
@@ -74,22 +77,21 @@ def _to_16bit(inputs):
                     f"block_{layer_idx}",
                     ParallelTransformerLayerPipe,
                     init_method=init_method,
-                    # Inputs: (input_tokens, target_tokens,
-                    forward_fn=lambda module, *inputs: ,
+                    forward_fn=lambda module, input_tokens, target_tokens: (module(input_tokens), target_tokens),
                     output_layer_init_method=scaled_init_method_normal(args.init_method_std,
                                                                        args.num_layers),
                     layer_type=LayerType.encoder,
                     layer_number=layer_idx,
                     self_attn_mask_type=AttnMaskType.causal,
-                    tied_weight_attrs=["input_layernorm", "self_attention", "post_attention_layernorm", "mlp"]
+                    tied_weight_attrs=["self_attention", "mlp"]
                 ))

         # Final layernorm after encoder layers
         self.specs.append(
-            TiedLayerSpec(
-                "final_layer_norm",
+            LayerSpec(
                 LayerNorm,
                 args.hidden_size,
+                forward_fn=lambda module, input_tokens, target_tokens: (module(input_tokens), target_tokens),
                 eps=args.layernorm_epsilon
             ))

@@ -100,19 +102,22 @@ def _to_16bit(inputs):
                     f"block_{layer_idx}",
                     ParallelTransformerLayerPipe,
                     init_method=init_method,
+                    forward_fn=lambda module, encoded_tokens, target_tokens: (encoded_tokens, module(target_tokens, encoder_output=encoded_tokens)),
                     output_layer_init_method=scaled_init_method_normal(args.init_method_std,
                                                                        args.num_layers),
                     layer_number=layer_idx,
                     layer_type=LayerType.decoder,
                     self_attn_mask_type=AttnMaskType.padding,
-                    tied_weight_attrs=["input_layernorm", "self_attention", "post_attention_layernorm", "mlp"]
+                    tied_weight_attrs=["self_attention", "mlp"]
                 )
             )

+        # Drop encoded tokens
+        self.specs.append(lambda encoded_tokens, target_tokens: target_tokens)
+
         # Final layernorm after decoder layers
         self.specs.append(
-            TiedLayerSpec(
-                "final_layer_norm",
+            LayerSpec(
                 LayerNorm,
                 args.hidden_size,
                 eps=args.layernorm_epsilon
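Read as a whole, each self.specs.append(...) above registers one pipeline stage, and every stage after the embedding receives and returns the full (inputs, targets) pair until one of the bare lambdas narrows it. The sketch below is a minimal, self-contained illustration of that tuple-threading pattern only; it is not the DeepSpeed LayerSpec/PipelineModule machinery, and the Stage class, encoder_block, decoder_block, and tensor shapes are hypothetical stand-ins.

# Minimal sketch of the tuple-threading pattern used by the forward_fn lambdas
# in the diff. Stage, encoder_block and decoder_block are hypothetical stand-ins,
# not the DeepSpeed LayerSpec / PipelineModule API.
import torch
import torch.nn as nn

class Stage:
    """One pipeline step: an optional module plus a routing function that
    decides how the incoming (inputs, targets) tuple is fed to the module."""
    def __init__(self, module=None, forward_fn=None):
        self.module = module
        self.forward_fn = forward_fn

    def __call__(self, *tensors):
        if self.module is None:
            # Bare lambda specs, e.g. the "drop encoded tokens" step.
            return self.forward_fn(*tensors)
        return self.forward_fn(self.module, *tensors)

hidden = 8
encoder_block = nn.Linear(hidden, hidden)   # stand-in for an encoder ParallelTransformerLayerPipe
decoder_block = nn.Linear(hidden, hidden)   # stand-in for a decoder ParallelTransformerLayerPipe
final_norm = nn.LayerNorm(hidden)           # stand-in for the final LayerNorm spec

specs = [
    # Encoder-style stages transform the input stream and pass targets through untouched.
    Stage(encoder_block, lambda m, inp, tgt: (m(inp), tgt)),
    Stage(final_norm, lambda m, inp, tgt: (m(inp), tgt)),
    # Decoder-style stage keeps the encoded stream and updates the target stream
    # (the real decoder module additionally receives encoder_output=encoded_tokens).
    Stage(decoder_block, lambda m, enc, tgt: (enc, m(tgt))),
    # Drop the encoded stream once the decoder no longer needs it.
    Stage(None, lambda enc, tgt: tgt),
]

x = (torch.randn(4, hidden), torch.randn(4, hidden))  # stand-ins for (input_tokens, target_tokens)
for stage in specs:
    x = stage(*x) if isinstance(x, tuple) else stage(x)
print(x.shape)  # torch.Size([4, 8]) -- only the target stream is left

In the actual model the pipeline engine performs this unpacking between stages; the sketch only shows why every spec here has to accept and return the full pair, and why the extra "drop" lambdas are needed to narrow it again.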