
Commit 0c08f69

complete encoder decoder and transformer model
1 parent 0a307db commit 0c08f69

11 files changed: +1023 -1 lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -1,5 +1,6 @@
 build*
 xcode*
+.vscode/
 .DS_Store
 .idea
 cmake-build-*

models/CMakeLists.txt

Lines changed: 9 additions & 1 deletion
@@ -1,7 +1,15 @@
 cmake_minimum_required(VERSION 3.1.0 FATAL_ERROR)
 project(models)
 
-add_subdirectory(darknet)
+# Recurse into each model mlpack provides.
+set(DIRS
+  darknet
+  transformer
+)
+
+foreach(dir ${DIRS})
+  add_subdirectory(${dir})
+endforeach()
 
 # Add directory name to sources.
 set(DIR_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/)

models/models.hpp

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+/**
+ * @file models.hpp
+ * @author Mrityunjay Tripathi
+ *
+ * This includes various models.
+ */
+
+#include "transformer/encoder.hpp"
+#include "transformer/decoder.hpp"

models/transformer/CMakeLists.txt

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
+cmake_minimum_required(VERSION 3.1.0 FATAL_ERROR)
+project(transformer)
+
+set(DIR_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/)
+include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../")
+
+set(SOURCES
+  decoder.hpp
+  decoder_impl.hpp
+  encoder.hpp
+  encoder_impl.hpp
+  transformer.hpp
+  transformer_impl.hpp
+)
+
+foreach(file ${SOURCES})
+  set(DIR_SRCS ${DIR_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/${file})
+endforeach()
+
+set(DIRS ${DIRS} ${DIR_SRCS} PARENT_SCOPE)

models/transformer/decoder.hpp

Lines changed: 230 additions & 0 deletions
@@ -0,0 +1,230 @@
+/**
+ * @file models/transformer/decoder.hpp
+ * @author Mikhail Lozhnikov
+ * @author Mrityunjay Tripathi
+ *
+ * Definition of the Transformer Decoder layer.
+ *
+ * mlpack is free software; you may redistribute it and/or modify it under the
+ * terms of the 3-clause BSD license. You should have received a copy of the
+ * 3-clause BSD license along with mlpack. If not, see
+ * http://www.opensource.org/licenses/BSD-3-Clause for more information.
+ */
+
+#ifndef MODELS_TRANSFORMER_DECODER_HPP
+#define MODELS_TRANSFORMER_DECODER_HPP
+
+#include <mlpack/prereqs.hpp>
+#include <mlpack/methods/ann/layer/layer_types.hpp>
+#include <mlpack/methods/ann/layer/base_layer.hpp>
+#include <mlpack/methods/ann/regularizer/no_regularizer.hpp>
+
+namespace mlpack {
+namespace ann /** Artificial Neural Network. */ {
+
+/**
+ * In addition to the two sub-layers in each encoder layer, the decoder inserts
+ * a third sub-layer, which performs multi-head attention over the output of
+ * the encoder stack. Similar to the encoder, we employ residual connections
+ * around each of the sub-layers, followed by layer normalization. We also
+ * modify the self-attention sub-layer in the decoder stack to prevent
+ * positions from attending to subsequent positions. This masking, combined
+ * with the fact that the output embeddings are offset by one position,
+ * ensures that the predictions for position i can depend only on the known
+ * outputs at positions less than i.
+ *
+ * @tparam ActivationFunction The type of the activation function to be used
+ *     in the position-wise feed forward neural network.
+ * @tparam RegularizerType The type of regularizer to be applied to layer
+ *     parameters.
+ * @tparam InputDataType Type of the input data (arma::colvec, arma::mat,
+ *     arma::sp_mat or arma::cube).
+ * @tparam OutputDataType Type of the output data (arma::colvec, arma::mat,
+ *     arma::sp_mat or arma::cube).
+ */
+template <
+    typename ActivationFunction = ReLULayer<>,
+    typename RegularizerType = NoRegularizer,
+    typename InputDataType = arma::mat,
+    typename OutputDataType = arma::mat
+>
+class TransformerDecoder
+{
+ public:
+  TransformerDecoder();
+
+  /**
+   * Create the TransformerDecoder object using the specified parameters.
+   *
+   * @param numLayers The number of decoder blocks.
+   * @param tgtSeqLen Target sequence length.
+   * @param srcSeqLen Source sequence length.
+   * @param dModel The number of features in the input. Also the same as the
+   *     'embedDim' of the 'MultiheadAttention' layer.
+   * @param numHeads The number of attention heads.
+   * @param dimFFN The dimensionality of the feedforward network.
+   * @param dropout The dropout rate.
+   * @param attentionMask The attention mask used to mask out future positions.
+   * @param keyPaddingMask The padding mask used to mask out padding tokens.
+   */
+  TransformerDecoder(const size_t numLayers,
+                     const size_t tgtSeqLen,
+                     const size_t srcSeqLen,
+                     const size_t dModel = 512,
+                     const size_t numHeads = 8,
+                     const size_t dimFFN = 1024,
+                     const double dropout = 0.1,
+                     const InputDataType& attentionMask = InputDataType(),
+                     const InputDataType& keyPaddingMask = InputDataType());
+
+  /**
+   * Get the Transformer Decoder model.
+   */
+  Sequential<InputDataType, OutputDataType, false>* Model() { return decoder; }
+
+  /**
+   * Load the network from a local directory.
+   *
+   * @param filepath The location of the stored model.
+   */
+  void LoadModel(const std::string& filepath);
+
+  /**
+   * Save the network locally.
+   *
+   * @param filepath The location where the model is to be saved.
+   */
+  void SaveModel(const std::string& filepath);
+
+  //! Get the key matrix, the output of the Transformer Encoder.
+  InputDataType const& Key() const { return key; }
+
+  //! Modify the key matrix.
+  InputDataType& Key() { return key; }
+
+ private:
+  /**
+   * This method adds the attention block to the decoder.
+   */
+  void AttentionBlock()
+  {
+    Sequential<>* decoderBlockBottom = new Sequential<>();
+    decoderBlockBottom->Add<Subview<>>(1, 0, dModel * tgtSeqLen - 1, 0, -1);
+
+    // Broadcast the incoming input to the decoder,
+    // i.e. query, into (query, key, value).
+    Concat<>* decoderInput = new Concat<>();
+    decoderInput->Add<IdentityLayer<>>();
+    decoderInput->Add<IdentityLayer<>>();
+    decoderInput->Add<IdentityLayer<>>();
+
+    // Masked self-attention layer.
+    Sequential<>* maskedSelfAttention = new Sequential<>();
+    maskedSelfAttention->Add(decoderInput);
+    maskedSelfAttention->Add<MultiheadAttention<
+        InputDataType, OutputDataType, RegularizerType>>(
+        tgtSeqLen,
+        tgtSeqLen,
+        dModel,
+        numHeads,
+        attentionMask);
+
+    // Residual connection around the masked self-attention.
+    AddMerge<>* selfAttentionResidual = new AddMerge<>();
+    selfAttentionResidual->Add(maskedSelfAttention);
+    selfAttentionResidual->Add<IdentityLayer<>>();
+
+    decoderBlockBottom->Add(selfAttentionResidual);
+
+    // Add the LayerNorm layer with required parameters.
+    decoderBlockBottom->Add<LayerNorm<>>(dModel * tgtSeqLen);
+
+    // This layer broadcasts the output of the encoder, i.e. key, into
+    // (key, value).
+    Concat<>* broadcastEncoderOutput = new Concat<>();
+    broadcastEncoderOutput->Add<Subview<>>(1, dModel * tgtSeqLen, -1, 0, -1);
+    broadcastEncoderOutput->Add<Subview<>>(1, dModel * tgtSeqLen, -1, 0, -1);
+
+    // This layer concatenates the output of the bottom decoder block (query)
+    // and the output of the encoder (key, value).
+    Concat<>* encoderDecoderAttentionInput = new Concat<>();
+    encoderDecoderAttentionInput->Add(decoderBlockBottom);
+    encoderDecoderAttentionInput->Add(broadcastEncoderOutput);
+
+    // Encoder-decoder attention.
+    Sequential<>* encoderDecoderAttention = new Sequential<>();
+    encoderDecoderAttention->Add(encoderDecoderAttentionInput);
+    encoderDecoderAttention->Add<MultiheadAttention<
+        InputDataType, OutputDataType, RegularizerType>>(
+        tgtSeqLen,
+        srcSeqLen,
+        dModel,
+        numHeads,
+        InputDataType(), // No attention mask for encoder-decoder attention.
+        keyPaddingMask);
+
+    // Residual connection around the encoder-decoder attention.
+    AddMerge<>* encoderDecoderResidual = new AddMerge<>();
+    encoderDecoderResidual->Add(encoderDecoderAttention);
+    encoderDecoderResidual->Add<IdentityLayer<>>();
+
+    decoder->Add(encoderDecoderResidual);
+    decoder->Add<LayerNorm<>>(dModel * tgtSeqLen);
+  }
+
+  /**
+   * This method adds the position-wise feed forward network to the decoder.
+   */
+  void PositionWiseFFNBlock()
+  {
+    Sequential<>* positionWiseFFN = new Sequential<>();
+    positionWiseFFN->Add<Linear3D<>>(dModel, dimFFN);
+    positionWiseFFN->Add<ActivationFunction>();
+    positionWiseFFN->Add<Linear3D<>>(dimFFN, dModel);
+    positionWiseFFN->Add<Dropout<>>(dropout);
+
+    // Residual connection around the feedforward network.
+    AddMerge<>* residualAdd = new AddMerge<>();
+    residualAdd->Add(positionWiseFFN);
+    residualAdd->Add<IdentityLayer<>>();
+    decoder->Add(residualAdd);
+  }
+
+  //! Locally-stored number of decoder layers.
+  size_t numLayers;
+
+  //! Locally-stored target sequence length.
+  size_t tgtSeqLen;
+
+  //! Locally-stored source sequence length.
+  size_t srcSeqLen;
+
+  //! Locally-stored number of features in the input (embedding dimension).
+  size_t dModel;
+
+  //! Locally-stored number of attention heads.
+  size_t numHeads;
+
+  //! Locally-stored dimensionality of the feedforward network.
+  size_t dimFFN;
+
+  //! Locally-stored dropout rate.
+  double dropout;
+
+  //! Locally-stored attention mask.
+  InputDataType attentionMask;
+
+  //! Locally-stored key padding mask.
+  InputDataType keyPaddingMask;
+
+  //! Locally-stored key matrix (the output of the Transformer Encoder).
+  InputDataType key;
+
+  //! Locally-stored complete decoder network.
+  Sequential<InputDataType, OutputDataType, false>* decoder;
+}; // class TransformerDecoder
+
+} // namespace ann
+} // namespace mlpack
+
+// Include implementation.
+#include "decoder_impl.hpp"
+
+#endif
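
For context, here is a minimal usage sketch of the decoder defined above (not part of the commit). It builds an additive attention mask, constructs a TransformerDecoder, and adds the underlying Sequential<> network to an mlpack FFN<> model. The mask convention, the hyperparameter values, and the FFN<> wiring are illustrative assumptions, not something this commit specifies.

// Hypothetical usage sketch; values and conventions below are assumptions.
#include <mlpack/core.hpp>
#include <mlpack/methods/ann/ffn.hpp>
#include "models/transformer/decoder.hpp"

using namespace mlpack::ann;

int main()
{
  const size_t numLayers = 2;   // Number of decoder blocks.
  const size_t tgtSeqLen = 10;  // Target sequence length.
  const size_t srcSeqLen = 12;  // Source sequence length.

  // One common convention for the attention mask is additive: 0 where
  // attention is allowed, a large negative value at future positions. The
  // exact convention expected by MultiheadAttention is not shown here.
  arma::mat attentionMask =
      (arma::trimatl(arma::ones(tgtSeqLen, tgtSeqLen)) - 1) * 1e9;

  // The remaining arguments match the defaults declared above
  // (dModel = 512, numHeads = 8, dimFFN = 1024, dropout = 0.1).
  TransformerDecoder<> decoder(numLayers, tgtSeqLen, srcSeqLen,
      512, 8, 1024, 0.1, attentionMask);

  // The decoder's input is expected to be the target embedding concatenated
  // with the encoder output (see the Subview<> splits in AttentionBlock()).
  FFN<> model;
  model.Add(decoder.Model());

  return 0;
}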

models/transformer/decoder_impl.hpp

Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
+/**
+ * @file models/transformer/decoder_impl.hpp
+ * @author Mikhail Lozhnikov
+ * @author Mrityunjay Tripathi
+ *
+ * Implementation of the Transformer Decoder class.
+ *
+ * mlpack is free software; you may redistribute it and/or modify it under the
+ * terms of the 3-clause BSD license. You should have received a copy of the
+ * 3-clause BSD license along with mlpack. If not, see
+ * http://www.opensource.org/licenses/BSD-3-Clause for more information.
+ */
+
+#ifndef MODELS_TRANSFORMER_DECODER_IMPL_HPP
+#define MODELS_TRANSFORMER_DECODER_IMPL_HPP
+
+#include "decoder.hpp"
+
+namespace mlpack {
+namespace ann /** Artificial Neural Network. */ {
+
+template<typename ActivationFunction, typename RegularizerType,
+         typename InputDataType, typename OutputDataType>
+TransformerDecoder<ActivationFunction, RegularizerType, InputDataType,
+OutputDataType>::TransformerDecoder() :
+    numLayers(0),
+    tgtSeqLen(0),
+    srcSeqLen(0),
+    dModel(0),
+    numHeads(0),
+    dimFFN(0),
+    dropout(0),
+    decoder(NULL)
+{
+  // Nothing to do here.
+}
+
+template<typename ActivationFunction, typename RegularizerType,
+         typename InputDataType, typename OutputDataType>
+TransformerDecoder<ActivationFunction, RegularizerType, InputDataType,
+OutputDataType>::TransformerDecoder(
+    const size_t numLayers,
+    const size_t tgtSeqLen,
+    const size_t srcSeqLen,
+    const size_t dModel,
+    const size_t numHeads,
+    const size_t dimFFN,
+    const double dropout,
+    const InputDataType& attentionMask,
+    const InputDataType& keyPaddingMask) :
+    numLayers(numLayers),
+    tgtSeqLen(tgtSeqLen),
+    srcSeqLen(srcSeqLen),
+    dModel(dModel),
+    numHeads(numHeads),
+    dimFFN(dimFFN),
+    dropout(dropout),
+    attentionMask(attentionMask),
+    keyPaddingMask(keyPaddingMask)
+{
+  decoder = new Sequential<InputDataType, OutputDataType, false>();
+
+  for (size_t N = 0; N < numLayers; ++N)
+  {
+    AttentionBlock();
+    PositionWiseFFNBlock();
+  }
+}
+
+template<typename ActivationFunction, typename RegularizerType,
+         typename InputDataType, typename OutputDataType>
+void TransformerDecoder<ActivationFunction, RegularizerType,
+InputDataType, OutputDataType>::LoadModel(const std::string& filepath)
+{
+  data::Load(filepath, "TransformerDecoder", decoder);
+  std::cout << "Loaded model" << std::endl;
+}
+
+template<typename ActivationFunction, typename RegularizerType,
+         typename InputDataType, typename OutputDataType>
+void TransformerDecoder<ActivationFunction, RegularizerType,
+InputDataType, OutputDataType>::SaveModel(const std::string& filepath)
+{
+  std::cout << "Saving model" << std::endl;
+  data::Save(filepath, "TransformerDecoder", decoder);
+  std::cout << "Model saved in " << filepath << std::endl;
+}
+
+} // namespace ann
+} // namespace mlpack
+
+#endif
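
And a short, hypothetical sketch of the serialization helpers implemented above; the file name "decoder.bin" is illustrative only, and data::Save()/data::Load() typically pick the serialization format from the file extension.

// Hypothetical sketch of SaveModel()/LoadModel(); file name is illustrative.
#include <mlpack/core.hpp>
#include "models/transformer/decoder.hpp"

using namespace mlpack::ann;

int main()
{
  // Build a small decoder and write its Sequential<> network to disk.
  TransformerDecoder<> decoder(2 /* numLayers */, 10 /* tgtSeqLen */,
      12 /* srcSeqLen */);
  decoder.SaveModel("decoder.bin");

  // Restore it later into a default-constructed object.
  TransformerDecoder<> restored;
  restored.LoadModel("decoder.bin");

  return 0;
}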
