Tutorial
If you have time, go read [4] to get a sense of the big picture (prototypes) before reading this page.
Glossary
Attention
Seq2Seq Model [1] [3]
Sequence to Sequence Learning with Neural Networks (2014)
- Seq2Seq modeling
- A typical example of Seq2Seq is NMT (neural machine translation). The specific task doesn't matter as long as both input and output are sequences. NMT is mentioned because it is the original task these mechanisms were designed for.
- Input and output sequences are $\mathbf{x} = [x_1, x_2, \dots, x_n]$ and $\mathbf{y} = [y_1, y_2, \dots, y_m]$ respectively. Note that the lengths $n$ and $m$ can differ.
- encoder-decoder architecture
- encoder hidden state vector $h_i$, decoder hidden state vector $s_t$
- encoder
- An encoder computes a context vector $c$ (aka sentence embedding or “thought” vector): usually the last hidden state $h_n$ of the encoder is passed to the decoder. The RNN encoder's per-step output (target) vectors can therefore be disregarded [4]. [1]
- An apparent problem is the inability to remember long sentences, since the whole input must be compressed into one fixed-length vector $c$.
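To make the encoder bullets concrete, here is a minimal sketch assuming a vanilla (Elman) RNN cell $h_i = \tanh(W_x x_i + W_h h_{i-1})$ with $h_0 = 0$. The cell type, the names `rnn_encode`, `context_vector`, `W_X`, `W_H`, and the toy weights are illustrative assumptions (the 2014 paper itself used multilayer LSTMs). It shows that every hidden state is computed, but only the last one becomes $c$.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def matvec(m: Matrix, v: Vector) -> Vector:
    """Multiply a (rows x cols) matrix by a length-cols vector."""
    return [sum(w * x for w, x in zip(row, v)) for row in m]


def rnn_encode(xs: List[Vector], w_x: Matrix, w_h: Matrix) -> List[Vector]:
    """Vanilla RNN (assumed cell): h_i = tanh(W_x x_i + W_h h_{i-1}), h_0 = 0.

    Returns every hidden state h_1..h_n. The loop is strictly sequential:
    h_i cannot be computed before h_{i-1}.
    """
    h = [0.0] * len(w_h)
    states = []
    for x in xs:
        pre = [a + b for a, b in zip(matvec(w_x, x), matvec(w_h, h))]
        h = [math.tanh(p) for p in pre]
        states.append(h)
    return states


def context_vector(xs: List[Vector], w_x: Matrix, w_h: Matrix) -> Vector:
    """Classic Seq2Seq: only the last hidden state h_n is handed to the decoder."""
    return rnn_encode(xs, w_x, w_h)[-1]


# Hypothetical toy weights: 2-d inputs, 2-d hidden state.
W_X = [[0.5, -0.5], [0.3, 0.8]]
W_H = [[0.1, 0.0], [0.0, 0.1]]
xs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
c = context_vector(xs, W_X, W_H)
print(len(rnn_encode(xs, W_X, W_H)), len(c))  # 3 hidden states, each 2-d
```

The intermediate states $h_1, h_2$ are thrown away here, which is exactly the information the attention mechanism later recovers.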
- decoder
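- teacher forcing training: the decoder is trained to turn the target sequence into the same sequence offset by one timestep into the future, i.e. at each step it is fed the ground-truth previous token (not its own prediction) and predicts the next one.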
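- Usually $f(s_{t-1}, y_t) \to (s_t, l_t)$, where $l_t$ is the output logit vector over the vocabulary, used to predict $y_{t+1}$.
- "language model head" layer: say the RNN output vector (target vector) is $o$ and $Y^{vocab}$ is the matrix whose rows are the vocabulary embeddings. Then $Y^{vocab} o$ is a vector of dot-product similarity scores between each vocabulary embedding and the output vector (equal to cosine similarity only when both are normalized), and can therefore serve as the logit vector [4].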
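- The decoder essentially models $P(y_i \mid y_{0:i-1}, c)$ and uses the chain rule to get the probability of any output sequence, $P(\mathbf{y} \mid c) = \prod_{i=1}^{m} P(y_i \mid y_{0:i-1}, c)$, where $y_0$ is the start-of-sequence token [4].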
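- The training loss can simply be the categorical cross-entropy between each predicted distribution and the next ground-truth token. Summed over the sequence, this is the negative log of the chain-rule probability.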
- Causal mask: since the decoder predicts the next token, when computing $s_t$ it should only have access to $y_1 \dots y_t$, not $y_{t+1}$ or anything later.
- Inference sampling: start from a sequence of length 1 consisting of the start-of-sequence token. Repeatedly feed the sequence into the decoder, sample the next token from the predicted probability distribution, and append it. Stop when the end-of-sequence token (`<eos>`) is produced.
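The training-side decoder bullets above (teacher forcing, the language model head, the chain-rule likelihood with cross-entropy, and the causal mask) can be sketched in plain Python. This is a minimal sketch under assumptions: tokens are integer ids, `sos` is a hypothetical id for the start-of-sequence token, and per-step logits are taken as given rather than produced by a real RNN. For an RNN decoder, causality already follows from the recurrence; an explicit mask matrix like `causal_mask` is what non-recurrent decoders (e.g. Transformers) use.

```python
import math
from typing import List, Tuple

Vector = List[float]
Matrix = List[List[float]]


def lm_head(vocab_emb: Matrix, o: Vector) -> Vector:
    """Logits = Y^vocab o: one dot-product score per vocabulary embedding (row)."""
    return [sum(e * x for e, x in zip(row, o)) for row in vocab_emb]


def softmax(logits: Vector) -> Vector:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def teacher_forcing_pairs(target: List[int], sos: int) -> List[Tuple[int, int]]:
    """Decoder inputs are y_0=<sos>, y_1..y_{m-1}; labels are y_1..y_m (one step ahead)."""
    inputs = [sos] + target[:-1]
    return list(zip(inputs, target))


def sequence_nll(step_logits: List[Vector], labels: List[int]) -> float:
    """Summed categorical cross-entropy = -log prod_i P(y_i | y_{0:i-1}, c)."""
    return -sum(math.log(softmax(logits)[y]) for logits, y in zip(step_logits, labels))


def causal_mask(n: int) -> List[List[bool]]:
    """mask[t][j] is True iff step t may see position j, i.e. j <= t."""
    return [[j <= t for j in range(n)] for t in range(n)]


print(teacher_forcing_pairs([5, 6, 7], sos=0))  # [(0, 5), (5, 6), (6, 7)]
print(lm_head([[1.0, 0.0], [0.0, 1.0]], [2.0, 3.0]))  # [2.0, 3.0]
print(causal_mask(3))
```

Summing per-step cross-entropies rather than multiplying probabilities keeps the computation numerically stable for long sequences.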
- Both the encoder and decoder are **Recurrent Neural Networks (RNNs)**. An apparent problem with RNNs is speed: each step depends on the previous hidden state, so computation cannot be parallelized across time steps and training is slow.
- Output layer [4]
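The inference-sampling bullet in the decoder list amounts to a short loop. A minimal sketch, assuming a hypothetical `step` function that wraps one decoder step and returns logits over the vocabulary, and a hypothetical `toy_step` for demonstration. With `rng=None` it takes the most likely token (greedy decoding, a common deterministic variant); otherwise it samples from the softmax distribution as the bullet describes.

```python
import math
import random
from typing import Callable, List, Optional, Tuple

# Hypothetical decoder step: (state, previous token id) -> (new state, logits over the vocabulary).
Step = Callable[[object, int], Tuple[object, List[float]]]


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def decode(step: Step, init_state: object, sos: int, eos: int,
           max_len: int = 50, rng: Optional[random.Random] = None) -> List[int]:
    """Autoregressive decoding starting from [<sos>].

    rng=None  -> greedy: take the argmax token each step.
    rng given -> sample the next token from softmax(logits).
    Stops after emitting <eos> or after max_len tokens.
    """
    tokens = [sos]
    state = init_state  # e.g. the encoder's context vector c
    for _ in range(max_len):
        state, logits = step(state, tokens[-1])
        if rng is None:
            nxt = max(range(len(logits)), key=logits.__getitem__)
        else:
            nxt = rng.choices(range(len(logits)), weights=softmax(logits))[0]
        tokens.append(nxt)
        if nxt == eos:
            break
    return tokens[1:]  # drop <sos>; keeps <eos> if it was produced


# Toy step (hypothetical): always prefers "previous token + 1",
# vocab = {0: <sos>, 1, 2, 3, 4: <eos>}.
def toy_step(state, prev):
    logits = [0.0] * 5
    logits[(prev + 1) % 5] = 10.0
    return state, logits


print(decode(toy_step, None, sos=0, eos=4))  # [1, 2, 3, 4]
```

The `max_len` cap guards against a model that never emits `<eos>`.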
Attention mechanism [1] [2]
Neural Machine Translation by Jointly Learning to Align and Translate (2015)
- The encoder is a bidirectional RNN to “include both the preceding and following words in the annotation of one word”.
- The single fixed-length context vector turned out to be a bottleneck when dealing with long sequences.
- Instead of passing only the last hidden state of the encoding stage, the encoder now passes all the hidden states to the decoder (animation: https://jalammar.github.io/images/seq2seq_7.mp4).
- “the secret sauce invented by attention is to create shortcuts between the context vector and the entire source input. The weights (aka alignment score) of these shortcut connections are customizable for each output element.”
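- At each decoder time step $t$, the decoder computes a different context vector $c_t$ that pays attention to different parts of all the encoder hidden states, as a weighted sum $c_t = \sum_{i=1}^{n} \alpha_{t,i} h_i$ whose weights $\alpha_{t,i}$ sum to 1.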
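In "Neural Machine Translation by Jointly Learning to Align and Translate", the weights come from a small feed-forward alignment model that scores the previous decoder state $s_{t-1}$ against each encoder hidden state $h_i$ (additive attention), followed by a softmax over source positions. A minimal sketch follows; the weight names `w_s`, `w_h`, `v` and the identity-matrix toy values are illustrative assumptions. In the paper each $h_i$ would be the concatenation of the forward and backward RNN states.

```python
import math
from typing import List, Tuple

Vector = List[float]
Matrix = List[List[float]]


def matvec(m: Matrix, v: Vector) -> Vector:
    """Multiply a (rows x cols) matrix by a length-cols vector."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def softmax(z: Vector) -> Vector:
    mx = max(z)  # subtract the max for numerical stability
    e = [math.exp(x - mx) for x in z]
    s = sum(e)
    return [x / s for x in e]


def additive_attention(s_prev: Vector, hs: List[Vector], w_s: Matrix,
                       w_h: Matrix, v: Vector) -> Tuple[Vector, Vector]:
    """Additive (Bahdanau-style) attention for one decoder step t.

    score_i = v . tanh(W_s s_{t-1} + W_h h_i)   # alignment score
    alpha   = softmax(score)                     # weights over source positions
    c_t     = sum_i alpha_i * h_i                # context vector for step t
    """
    q = matvec(w_s, s_prev)
    scores = []
    for h in hs:
        k = matvec(w_h, h)
        scores.append(sum(vi * math.tanh(qi + ki) for vi, qi, ki in zip(v, q, k)))
    alpha = softmax(scores)
    c_t = [sum(a * h[d] for a, h in zip(alpha, hs)) for d in range(len(hs[0]))]
    return c_t, alpha


I2 = [[1.0, 0.0], [0.0, 1.0]]  # hypothetical weights: identity matrices
hs = [[1.0, 0.0], [0.0, 1.0]]  # two encoder hidden states
c_t, alpha = additive_attention([0.0, 0.0], hs, I2, I2, [1.0, 1.0])
print(alpha, c_t)  # equal scores -> [0.5, 0.5] [0.5, 0.5]
```

Because $c_t$ is recomputed from $s_{t-1}$ at every step, each output token can draw on a different mix of source positions instead of a single fixed vector.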