images are from https://zh-v2.d2l.ai/
- Encoder-Decoder: an architecture commonly used in NLP and other types of tasks
- Encoder: takes the raw input and represents it as tensors after processing (could be word2vec, neural layers, attention, ...)
- Decoder: mainly outputs the result in the desired form ([0, 1], a probability distribution, a classification, etc.)
```python
from torch import nn


class Encoder(nn.Module):
    def __init__(self, **kwargs):
        super(Encoder, self).__init__(**kwargs)

    def forward(self, X, *args):
        raise NotImplementedError


class Decoder(nn.Module):
    def __init__(self, **kwargs):
        super(Decoder, self).__init__(**kwargs)

    def init_state(self, encoder_outputs, *args):
        raise NotImplementedError

    def forward(self, X, state):
        raise NotImplementedError
```
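The two interfaces are typically glued together by a small wrapper. The following is a minimal sketch of that composition (the class name `EncoderDecoder` and the exact call pattern are assumptions here, mirroring the base classes above):

```python
class EncoderDecoder(nn.Module):
    """Sketch: encode the raw input, use the encoder outputs to set up the
    decoder state, then decode to the desired output form."""
    def __init__(self, encoder, decoder, **kwargs):
        super(EncoderDecoder, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_X, dec_X, *args):
        enc_outputs = self.encoder(enc_X, *args)
        dec_state = self.decoder.init_state(enc_outputs, *args)
        return self.decoder(dec_X, dec_state)
```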
- Seq2seq: a specific type of task whose input and output are both sequences of arbitrary length
- Ex. Machine Translation
- Common architecture of seq2seq models:
- Machine Translation using RNN (a hedged encoder sketch follows)
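As a sketch of such an RNN-based translation encoder, an embedding layer followed by a GRU is enough to satisfy the `Encoder` interface above (the layer sizes and the GRU choice are illustrative, not prescriptive):

```python
from torch import nn


class Seq2SeqEncoder(Encoder):
    """Sketch: embed the source tokens, then run them through a GRU."""
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
                 dropout=0, **kwargs):
        super(Seq2SeqEncoder, self).__init__(**kwargs)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size, num_hiddens, num_layers, dropout=dropout)

    def forward(self, X, *args):
        # X: (batch_size, num_steps) of token indices
        X = self.embedding(X)      # (batch_size, num_steps, embed_size)
        X = X.permute(1, 0, 2)     # nn.GRU expects (num_steps, batch_size, embed_size)
        output, state = self.rnn(X)
        # output: (num_steps, batch_size, num_hiddens)
        # state:  (num_layers, batch_size, num_hiddens) -- handed to the decoder
        return output, state
```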
- BLEU (Bilingual Evaluation Understudy) for machine translation
  - formula: $\exp\left(\min\left(0, 1 - \frac{\mathrm{len}_{\text{label}}}{\mathrm{len}_{\text{pred}}}\right)\right) \prod_{n=1}^k p_n^{1/2^n}$
  - where $p_n$ represents the $n$-gram precision and $k$ is the longest $n$-gram considered
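A small sketch of the BLEU formula above on whitespace-tokenized strings (the function name and tokenization are illustrative choices; it assumes `k <= len(pred_tokens)`):

```python
import collections
import math


def bleu(pred_seq, label_seq, k):
    """BLEU = brevity penalty * product of 1/2^n-weighted n-gram precisions."""
    pred_tokens, label_tokens = pred_seq.split(' '), label_seq.split(' ')
    len_pred, len_label = len(pred_tokens), len(label_tokens)
    # brevity penalty: exp(min(0, 1 - len_label / len_pred))
    score = math.exp(min(0, 1 - len_label / len_pred))
    for n in range(1, k + 1):
        num_matches = 0
        label_counts = collections.Counter(
            ' '.join(label_tokens[i:i + n]) for i in range(len_label - n + 1))
        for i in range(len_pred - n + 1):
            ngram = ' '.join(pred_tokens[i:i + n])
            if label_counts[ngram] > 0:   # clipped n-gram match
                num_matches += 1
                label_counts[ngram] -= 1
        p_n = num_matches / (len_pred - n + 1)   # n-gram precision
        score *= math.pow(p_n, math.pow(0.5, n))
    return score


print(bleu('he is eating an apple', 'he is eating a green apple', k=2))
```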
- Attention Mechanism, KVQ
  - Key: what is presented (the nonvolitional cue attached to each sensory input)
  - Value: the sensory inputs themselves
  - Query: what we are interested in (the volitional cue)
  - The idea is to use the Query to find the "important" Keys
- Attention Score $a(\mathbf q, \mathbf k)$
  - models the relationship (importance, similarity) of Keys & Queries
  - Kernel Regression (Nadaraya-Watson)
  - Additive Attention
  - Scaled Dot-Product Attention
  - Matrix form: $\mathrm{softmax}\left(\frac{\mathbf Q \mathbf K^\top}{\sqrt{d}}\right)\mathbf V$ (a code sketch follows)
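A minimal sketch of scaled dot-product attention in its matrix form; masking of padded keys (`valid_lens`) is omitted to keep it short, so this is not a drop-in replacement for `d2l.DotProductAttention`:

```python
import math

import torch
from torch import nn


class DotProductAttention(nn.Module):
    """Sketch: softmax(Q K^T / sqrt(d)) V, without masking."""
    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens=None):
        # queries: (batch_size, num_queries, d)
        # keys:    (batch_size, num_kv_pairs, d)
        # values:  (batch_size, num_kv_pairs, value_dim)
        d = queries.shape[-1]
        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
        self.attention_weights = nn.functional.softmax(scores, dim=-1)
        return torch.bmm(self.dropout(self.attention_weights), values)
```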
- Attention in seq2seq (Bahdanau attention): models the relationship (importance, similarity) of the decoder state & the encoder outputs
  - Notice the difference with the plain RNN encoder-decoder above
  - Here, the Query comes from the decoder (its hidden state from the previous time step), while Key & Value are both the encoder's outputs (its hidden states at each time step)
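To make the roles concrete, a sketch of one decoding step using the `DotProductAttention` sketch above (tensor sizes are arbitrary placeholders):

```python
import torch

batch_size, num_steps, num_hiddens = 2, 7, 16
enc_outputs = torch.randn(batch_size, num_steps, num_hiddens)  # Keys and Values
dec_hidden = torch.randn(batch_size, num_hiddens)              # Query source

attention = DotProductAttention(dropout=0)
query = dec_hidden.unsqueeze(1)                       # (batch_size, 1, num_hiddens)
context = attention(query, enc_outputs, enc_outputs)  # (batch_size, 1, num_hiddens)
# `context` is then combined with the decoder input embedding for this time step.
```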
- Self-attention means the Queries, Keys, and Values all come from the same input sequence
  - So we are trying to find the relationship between one token and the other tokens
  - $\mathbf y_i = f(\mathbf x_i, (\mathbf x_1, \mathbf x_1), \ldots, (\mathbf x_n, \mathbf x_n))$, where $\mathbf x_i$ is the Query and each $(\mathbf x_j, \mathbf x_j)$ is a Key-Value pair
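In code, self-attention is just attention where the same tensor plays all three roles, e.g. with the scaled dot-product sketch above:

```python
import torch

X = torch.randn(2, 10, 16)               # (batch_size, num_tokens, d)
self_attn = DotProductAttention(dropout=0)
Y = self_attn(X, X, X)                   # Queries = Keys = Values = X
print(Y.shape)                           # torch.Size([2, 10, 16])
```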
- Self-attention does not contain information about relative positions (of tokens)
- Position Encoding aims to "encode" some relative position information into the input
  - A commonly used position encoding method uses these $\sin$ and $\cos$ functions:
  - $p_{i, 2j} = \sin\left(\frac{i}{10000^{2j/d}}\right)$ and $p_{i, 2j+1} = \cos\left(\frac{i}{10000^{2j/d}}\right)$ for the Position Encoding Matrix $\mathbf P \in \mathbb R^{n \times d}$
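A sketch of this sinusoidal encoding as a module that adds $\mathbf P$ to its input (hyperparameters such as `max_len` are illustrative; `num_hiddens` is assumed to be even):

```python
import torch
from torch import nn


class PositionalEncoding(nn.Module):
    """Sketch: add the fixed sin/cos position encoding matrix P to the input."""
    def __init__(self, num_hiddens, dropout, max_len=1000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)
        # P: (1, max_len, num_hiddens)
        self.P = torch.zeros((1, max_len, num_hiddens))
        X = torch.arange(max_len, dtype=torch.float32).reshape(-1, 1) / torch.pow(
            10000, torch.arange(0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
        self.P[:, :, 0::2] = torch.sin(X)   # even columns: sin
        self.P[:, :, 1::2] = torch.cos(X)   # odd columns:  cos

    def forward(self, X):
        # add the encoding for the first X.shape[1] positions
        X = X + self.P[:, :X.shape[1], :].to(X.device)
        return self.dropout(X)
```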
- Multi-head Attention aims to capture different "relationships" between Query and Key by using multiple parallel attention layers (heads) and concatenating their outputs to get the final result.
- Mathematically:
  - $\mathbf{h}_i = f(\mathbf W_i^{(q)}\mathbf q, \mathbf W_i^{(k)}\mathbf k, \mathbf W_i^{(v)}\mathbf v) \in \mathbb R^{p_v}$, where $f$ is some kind of attention function and $\mathbf h_i$ is the $i^{\text{th}}$ head
  - $\mathrm{result} = \mathbf W_o \begin{bmatrix}\mathbf h_1 \\ \vdots \\ \mathbf h_h\end{bmatrix} \in \mathbb{R}^{p_o}$
```python
import torch
from torch import nn
from d2l import torch as d2l


class MultiHeadAttention(nn.Module):
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 num_heads, dropout, bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        # assuming num_queries = num_keys = num_values
        # initial queries:
        # (batch_size, num_queries, num_hiddens)
        # transformed queries:
        # (batch_size * num_heads, num_queries, num_hiddens/num_heads)
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)

        if valid_lens is not None:
            # repeat each valid length once per head
            valid_lens = torch.repeat_interleave(
                valid_lens, repeats=self.num_heads, dim=0)

        # (batch_size * num_heads, num_queries, num_hiddens/num_heads)
        output = self.attention(queries, keys, values, valid_lens)

        # (batch_size, num_queries, num_hiddens)
        output_concat = transpose_output(output, self.num_heads)
        return self.W_o(output_concat)
```
```python
def transpose_qkv(X, num_heads):
    # (batch_size, num_queries, num_hiddens)
    # -> (batch_size, num_queries, num_heads, num_hiddens/num_heads)
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
    # -> (batch_size, num_heads, num_queries, num_hiddens/num_heads)
    X = X.permute(0, 2, 1, 3)
    # -> (batch_size * num_heads, num_queries, num_hiddens/num_heads)
    return X.reshape(-1, X.shape[2], X.shape[3])


def transpose_output(X, num_heads):
    """Reverse `transpose_qkv`."""
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    X = X.permute(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)
```
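A quick smoke test of the module above with dummy tensors (all sizes are arbitrary):

```python
import torch

num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
                               num_hiddens, num_heads, dropout=0.5)
attention.eval()

batch_size, num_queries, num_kv = 2, 4, 6
valid_lens = torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))   # queries
Y = torch.ones((batch_size, num_kv, num_hiddens))        # keys and values
print(attention(X, Y, Y, valid_lens).shape)              # torch.Size([2, 4, 100])
```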
- Shaping: annotated graph