Commit 5e572fa

use mutator method to set mask in mha
1 parent 0aa3f28 commit 5e572fa
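
The change is mechanical across all three attention layers: instead of threading attentionMask and keyPaddingMask through the MultiheadAttention constructor, each layer is now built from its four shape parameters alone, and the masks are assigned afterwards through mutator methods. The assignments mha->AttentionMask() = attentionMask and mha->KeyPaddingMask() = keyPaddingMask in the diff imply that these accessors return non-const references. A minimal sketch of that idiom, with illustrative member names, not mlpack's actual class:

#include <armadillo>

// Sketch of the reference-returning accessor pattern this commit relies
// on; member names are illustrative, not copied from mlpack.
class MultiheadAttentionSketch
{
 public:
  //! Read-only access to the attention mask.
  const arma::mat& AttentionMask() const { return attentionMask; }
  //! Mutable access: `layer.AttentionMask() = mask;` writes the member.
  arma::mat& AttentionMask() { return attentionMask; }

  //! Read-only access to the key padding mask.
  const arma::mat& KeyPaddingMask() const { return keyPaddingMask; }
  //! Mutable access to the key padding mask.
  arma::mat& KeyPaddingMask() { return keyPaddingMask; }

 private:
  arma::mat attentionMask;
  arma::mat keyPaddingMask;
};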

File tree

4 files changed: +28 / -27 lines

models/transformer/decoder.hpp (15 additions, 15 deletions)

@@ -147,13 +147,14 @@ class TransformerDecoder
     // Masked Self attention layer.
     Sequential<>* maskedSelfAttention = new Sequential<>(false);
     maskedSelfAttention->Add(decoderInput);
-    maskedSelfAttention->Add<MultiheadAttention<
-        arma::mat, arma::mat, RegularizerType>>(
-        tgtSeqLen,
-        tgtSeqLen,
-        dModel,
-        numHeads,
-        attentionMask);
+
+    MultiheadAttention<>* mha1 = new MultiheadAttention<>(tgtSeqLen,
+        tgtSeqLen,
+        dModel,
+        numHeads);
+    mha1->AttentionMask() = attentionMask;
+
+    maskedSelfAttention->Add(mha1);

     // Residual connection.
     AddMerge<>* residualAdd1 = new AddMerge<>();

@@ -179,14 +180,13 @@ class TransformerDecoder
     // Encoder-decoder attention.
     Sequential<>* encoderDecoderAttention = new Sequential<>(false);
     encoderDecoderAttention->Add(encoderDecoderAttentionInput);
-    encoderDecoderAttention->Add<MultiheadAttention<
-        arma::mat, arma::mat, RegularizerType>>(
-        tgtSeqLen,
-        srcSeqLen,
-        dModel,
-        numHeads,
-        arma::mat(), // No attention mask to encoder-decoder attention.
-        keyPaddingMask);
+
+    MultiheadAttention<>* mha2 = new MultiheadAttention<>(tgtSeqLen,
+        srcSeqLen,
+        dModel,
+        numHeads);
+    mha2->KeyPaddingMask() = keyPaddingMask;
+    encoderDecoderAttention->Add(mha2);

     // Residual connection.
     AddMerge<>* residualAdd2 = new AddMerge<>();
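
One detail the second hunk cleans up: with constructor arguments, the encoder-decoder attention had to pass an empty arma::mat() placeholder for the attention mask just to reach the trailing keyPaddingMask parameter. With mutators, only the mask that is actually used gets touched. A minimal sketch of the resulting call pattern (sizes hypothetical):

// Hypothetical dimensions, for illustration only.
const size_t tgtSeqLen = 10, srcSeqLen = 12, dModel = 16, numHeads = 4;

// Build the layer from its shape parameters alone...
MultiheadAttention<>* mha2 = new MultiheadAttention<>(
    tgtSeqLen, srcSeqLen, dModel, numHeads);

// ...then set only the key padding mask; the attention mask keeps its
// default (empty) value, so no arma::mat() placeholder is needed.
arma::mat keyPaddingMask; // assumed to be filled in by the caller
mha2->KeyPaddingMask() = keyPaddingMask;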

models/transformer/encoder.hpp (9 additions, 8 deletions)

@@ -95,7 +95,7 @@ class TransformerEncoder
   /**
    * Get the Transformer Encoder Model.
    */
-  Sequential<arma::mat, arma::mat, false>* Model()
+  Sequential<>* Model()
   {
     return encoder;
   }

@@ -140,13 +140,14 @@ class TransformerEncoder
     /* Self attention layer. */
     Sequential<>* selfAttn = new Sequential<>(false);
     selfAttn->Add(input);
-    selfAttn->Add<MultiheadAttention<arma::mat, arma::mat, RegularizerType>>(
-        srcSeqLen,
-        srcSeqLen,
-        dModel,
-        numHeads,
-        attentionMask,
-        keyPaddingMask);
+
+    MultiheadAttention<>* mha = new MultiheadAttention<>(srcSeqLen,
+        srcSeqLen,
+        dModel,
+        numHeads);
+    mha->AttentionMask() = attentionMask;
+    mha->KeyPaddingMask() = keyPaddingMask;
+    selfAttn->Add(mha);

     /* This layer adds a residual connection. */
     AddMerge<>* residualAdd = new AddMerge<>();
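
The first hunk, like the constructor change in encoder_impl.hpp below, drops the explicit template arguments from Sequential. Assuming the class is declared with mlpack 3.x's documented defaults, the shorter spelling names exactly the same type, so this is a readability change only:

#include <type_traits>

// Assuming Sequential is declared roughly as:
//   template<typename InputDataType = arma::mat,
//            typename OutputDataType = arma::mat,
//            bool Residual = false>
//   class Sequential;
// Sequential<> and the fully spelled-out form are the same type.
static_assert(std::is_same<
    mlpack::ann::Sequential<>,
    mlpack::ann::Sequential<arma::mat, arma::mat, false>>::value,
    "Sequential<> keeps the documented defaults.");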

models/transformer/encoder_impl.hpp (1 addition, 1 deletion)

@@ -40,7 +40,7 @@ TransformerEncoder<ActivationFunction, RegularizerType>::TransformerEncoder(
     keyPaddingMask(keyPaddingMask),
     ownMemory(ownMemory)
 {
-  encoder = new Sequential<arma::mat, arma::mat, false>(false);
+  encoder = new Sequential<>(false);

   for (size_t n = 0; n < numLayers; ++n)
   {

tests/ffn_model_tests.cpp (3 additions, 3 deletions)

@@ -68,7 +68,7 @@ BOOST_AUTO_TEST_CASE(TransformerEncoderTest)
   mlpack::ann::TransformerEncoder<> encoder(numLayers, srcSeqLen,
       dModel, numHeads, dimFFN, dropout);

-  FFN<> model;
+  FFN<NegativeLogLikelihood<>, XavierInitialization> model;

   model.Add(encoder.Model());
   model.Add<Linear<>>(dModel * srcSeqLen, vocabSize);

@@ -103,7 +103,7 @@ BOOST_AUTO_TEST_CASE(TransformerDecoderTest)
   mlpack::ann::TransformerDecoder<> decoder(numLayers, tgtSeqLen, srcSeqLen,
       dModel, numHeads, dimFFN, dropout);

-  FFN<> model;
+  FFN<NegativeLogLikelihood<>, XavierInitialization> model;

   model.Add(decoder.Model());
   model.Add<Linear<>>(dModel * tgtSeqLen, vocabSize);

@@ -148,7 +148,7 @@ BOOST_AUTO_TEST_CASE(TransformerTest)
   mlpack::ann::Transformer<> transformer(numLayers, tgtSeqLen, srcSeqLen,
       tgtVocabSize, srcVocabSize, dModel, numHeads, dimFFN, dropout);

-  FFN<> model;
+  FFN<NegativeLogLikelihood<>, XavierInitialization> model;

   model.Add(transformer.Model());
   model.Add<Linear<>>(dModel * tgtSeqLen, tgtVocabSize);
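
In the API these tests target, FFN's first template parameter already defaults to NegativeLogLikelihood<>, so the effective change in all three tests is the second parameter: XavierInitialization replaces the default RandomInitialization, presumably to give the deep transformer stacks a better-conditioned starting point. A minimal sketch of the new model setup (header paths assumed from the mlpack 3.x layout, sizes hypothetical):

#include <mlpack/methods/ann/ffn.hpp>
#include <mlpack/methods/ann/layer/layer.hpp>
#include <mlpack/methods/ann/init_rules/glorot_init.hpp> // defines XavierInitialization
#include <mlpack/methods/ann/loss_functions/negative_log_likelihood.hpp>

using namespace mlpack::ann;

int main()
{
  // Same loss as FFN<>'s default, but Xavier (Glorot) weight
  // initialization instead of the default random initialization.
  FFN<NegativeLogLikelihood<>, XavierInitialization> model;

  // Hypothetical sizes mirroring the tests above.
  const size_t dModel = 16, tgtSeqLen = 10, tgtVocabSize = 20;
  model.Add<Linear<>>(dModel * tgtSeqLen, tgtVocabSize);
  model.Add<LogSoftMax<>>();
}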
