Tidbits on Transformers
... 3 years too late?
These are some common but tricky-to-google questions about the Transformer architecture.
What are the input and output formats of Transformers?
Consider the sentence: Seoul deserves a better slogan than "I Seoul U".
During training:
During training, we need three sequences: the encoder input, the decoder input, and the decoder target (labels).
- Encoder inputs: these may or may not have <bos> and <eos>, depending on the architecture. Which is to say it doesn't matter much, as long as training and inference use the same format (the input sequence gets all jumbled up by the encoder anyway and becomes an unrecognizable hidden state).
  Seoul deserves a better slogan than "I Seoul U"
  or
  <bos> Seoul deserves a better slogan than "I Seoul U" <eos>
- Decoder inputs: must be prepended with <bos>.
  <bos> Seoul deserves a better slogan than "I Seoul U"
- Decoder targets (labels): must be appended with <eos> (i.e., left shifted by one position with respect to the decoder input).
  Seoul deserves a better slogan than "I Seoul U" <eos>
Remember it like this: we feed <bos> to the decoder to signal the start of the decoding process. Upon seeing <bos>, the decoder must output the first non-<bos> token of the target (here, Seoul), so we structure the target labels this way.
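To make the shift concrete, here is a minimal Python sketch that builds all three training sequences from one sentence. It assumes naive whitespace tokenization and literal "<bos>"/"<eos>" strings; the function name make_training_example is hypothetical, and real pipelines work on token IDs from a proper tokenizer.

```python
# Minimal sketch: build (encoder input, decoder input, decoder target) for one pair.
# Assumptions: whitespace tokenization, literal "<bos>"/"<eos>" strings.
BOS, EOS = "<bos>", "<eos>"


def make_training_example(source: str, target: str, wrap_encoder: bool = False):
    """Hypothetical helper returning the three token lists used in training."""
    src = source.split()
    tgt = target.split()
    # Encoder input: special tokens are optional, depending on the architecture.
    encoder_input = [BOS] + src + [EOS] if wrap_encoder else src
    # Decoder input: target tokens with <bos> prepended.
    decoder_input = [BOS] + tgt
    # Decoder target: the same tokens shifted left by one, with <eos> appended.
    decoder_target = tgt + [EOS]
    return encoder_input, decoder_input, decoder_target


sentence = 'Seoul deserves a better slogan than "I Seoul U"'
enc, dec_in, dec_tgt = make_training_example(sentence, sentence)
for step, (seen, label) in enumerate(zip(dec_in, dec_tgt)):
    print(f"step {step}: decoder sees {seen} -> should predict {label}")
```

At every position the label is simply the token that follows the decoder input at that position, which is why the target ends with <eos> and the input starts with <bos>.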
During testing:
There are only two inputs: the encoder input and the decoder input. Makes sense, because if we knew the decoder output we wouldn't be decoding in the first place.
- Encoder input: same as in training.
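- Decoder input: same format as in training (it starts with <bos>), but at first it is only <bos>, and it grows by one generated token per step (see below).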
Is the decoding process in Transformers different during training and inference?
Yes.
Remember: the training process is not autoregressive (meaning we don't run the whole decoder again before generating each token in a sequence)!
The whole sentence is processed through the decoder at once. We could make decoder training autoregressive like an RNN, but it's pointless because we want teacher forcing anyway (pretending the decoder produced the correct token at every previous time step). Feeding the ground-truth target through the decoder in a single pass, with the triangular no-peek mask (more on it below) and the loss computed at every position in parallel, automatically gives us teacher forcing.
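One way to convince yourself that the single parallel pass really is teacher forcing: with the no-peek mask in place, each position's output depends only on the ground-truth tokens up to that position, so it matches what you would get by rerunning the decoder on each ground-truth prefix separately. The NumPy sketch below checks this for one single-head masked self-attention layer with random weights (a simplification; stacking more masked layers and position-wise feed-forward layers keeps the property). The function names are illustrative, not from any library.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def masked_self_attention(x, w_q, w_k, w_v):
    """Single-head scaled dot-product self-attention with a no-peek mask; x is (T, d)."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = q @ k.T / np.sqrt(k.shape[-1])
    T = x.shape[0]
    # Position i may only attend to positions 0..i; hide everything above the diagonal.
    future = np.triu(np.ones((T, T), dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)
    return softmax(scores) @ v


rng = np.random.default_rng(0)
T, d = 6, 8
x = rng.normal(size=(T, d))  # stand-in embeddings of the ground-truth "<bos> y1 ... y5"
w_q, w_k, w_v = (rng.normal(size=(d, d)) for _ in range(3))

# Training as actually done: one parallel pass over the whole ground-truth decoder input.
parallel = masked_self_attention(x, w_q, w_k, w_v)

# "Autoregressive" training: rerun on each ground-truth prefix and keep only its last row.
stepwise = np.stack(
    [masked_self_attention(x[: t + 1], w_q, w_k, w_v)[-1] for t in range(T)]
)

print(np.allclose(parallel, stepwise))
```

The rerun version does T times the work and produces the same numbers, which is exactly why nobody trains the decoder that way.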
However, decoding at testing (inference) time is autoregressive, because at first all you have as the decoder input is <bos>. After you get the initial output (a single token) with <bos> as decoder input, you append that token to the decoder input for the next step. Repeat until the decoder outputs <eos>.
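The inference loop can be sketched as greedy decoding in plain Python. next_token_fn is a hypothetical stand-in for "run the encoder and the whole decoder on the current prefix and take the most likely next token"; the toy copy model exists only so the sketch runs. Greedy argmax is just one way to pick the next token.

```python
from typing import Callable, List

BOS, EOS = "<bos>", "<eos>"


def greedy_decode(
    encoder_input: List[str],
    next_token_fn: Callable[[List[str], List[str]], str],
    max_len: int = 50,
) -> List[str]:
    """Autoregressive greedy decoding.

    next_token_fn is a hypothetical stand-in for "run the full model on
    (encoder_input, decoder_input) and return the most likely next token".
    """
    decoder_input = [BOS]  # at first, all we have is <bos>
    for _ in range(max_len):
        token = next_token_fn(encoder_input, decoder_input)
        if token == EOS:  # stop once the decoder emits <eos>
            break
        decoder_input.append(token)  # feed the new token back in for the next step
    return decoder_input[1:]  # drop <bos>


def toy_copy_model(encoder_input: List[str], decoder_input: List[str]) -> str:
    # Toy "model" that copies the source one token per step, then emits <eos>.
    i = len(decoder_input) - 1
    return encoder_input[i] if i < len(encoder_input) else EOS


src = 'Seoul deserves a better slogan than "I Seoul U"'.split()
print(greedy_decode(src, toy_copy_model))
```

The max_len cap guards against a model that never emits <eos>. Note that every iteration reruns the decoder on the whole, growing prefix.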
You could make the process non-autoregressive if you were confident that providing just <bos> as the decoder input would generate a meaningful token at every decoder timestep.
How can the triangular decoder attention mask (a.k.a. no-lookahead mask, a.k.a. no-peek mask) be 2d while the encoder attention mask isn't?
Remember: attention masks are not applied to the input sequence! They are applied to the attention matrix!!
The self-attention matrix is square. Makes sense, because to score how much attention every token pays to every token (including itself), you need a sequence-length × sequence-length matrix.
So it is natural that the decoder attention mask is 2d. And it is triangular because we don't want tokens attending to future tokens.
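The encoder attention mask is not 2d: it is just a padding mask with one entry per key token, marking which tokens are real and which are padding. It is broadcast to 2d when applied to the attention scores (typically by setting the masked scores to -inf before the softmax), so the same padded columns are hidden from every query row.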
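Here is a small NumPy sketch of both masks acting on the attention scores, assuming the convention True = "may attend" and masking by replacing scores with -inf before the softmax (some libraries use the opposite boolean convention or an additive float mask). It uses a single sequence with no batch or head dimensions.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


T = 5
rng = np.random.default_rng(0)
scores = rng.normal(size=(T, T))  # pre-softmax attention scores, shape (query, key)

# Decoder no-peek mask: genuinely 2d. True = allowed (lower triangle incl. diagonal).
causal_mask = np.tril(np.ones((T, T), dtype=bool))

# Encoder padding mask: 1d, one entry per key token. Assume the last 2 tokens are <pad>.
padding_mask = np.array([True, True, True, False, False])

# Masks act on the attention scores, not on the input sequence.
dec_weights = softmax(np.where(causal_mask, scores, -np.inf))
# Shape (T,) broadcasts against (T, T) as (1, T): same key columns hidden for every row.
enc_weights = softmax(np.where(padding_mask, scores, -np.inf))

print(np.round(dec_weights, 2))
print(np.round(enc_weights, 2))
```

In both cases the masked entries end up with exactly zero attention weight, and every row still sums to 1.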
WTF are Query, Key, and Value?
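Just remember: Q and K are used to compute the attention matrix discussed above, softmax(Q Kᵀ / sqrt(d_k)), and V is what that matrix mixes: each output token is an attention-weighted sum of the rows of V.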
- Self attention in encoder:
  Query: encoder input * some learned weight matrix
  Key: encoder input * another learned weight matrix
  Value: encoder input * another learned weight matrix
- Self attention in decoder (with the no-peek mask):
  Query: decoder input * some learned weight matrix
  Key: decoder input * another learned weight matrix
  Value: decoder input * another learned weight matrix
- Cross attention from encoder to decoder:
  Query: decoder hidden state (output of the decoder's masked self-attention sublayer) * some learned weight matrix
  Key: encoder hidden state (output of the final encoder layer) * another learned weight matrix
  Value: encoder hidden state (output of the final encoder layer) * another learned weight matrix
The learned matrices exist so that the network can actually learn the representations we assume exist.
Swap "input" for "hidden state" (the previous layer's output) in layers > 1.
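Putting Q, K, and V together, here is a heavily simplified NumPy sketch of a 2-layer encoder and decoder. It assumes single-head attention and random (untrained) weights, and omits multi-head splitting, residual connections, layer normalization, feed-forward sublayers, and positional encodings; the class and function names are hypothetical. Note that the cross-attention matrix is (target length × source length), so unlike self-attention it need not be square.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def attention(q, k, v, mask=None):
    """softmax(Q K^T / sqrt(d_k)) V; mask is boolean with True = may attend."""
    scores = q @ k.T / np.sqrt(k.shape[-1])
    if mask is not None:
        scores = np.where(mask, scores, -np.inf)
    return softmax(scores) @ v


class Projections:
    """Hypothetical holder for W_q, W_k, W_v (random here, learned in practice)."""

    def __init__(self, d, rng):
        self.w_q, self.w_k, self.w_v = (
            rng.normal(size=(d, d)) / np.sqrt(d) for _ in range(3)
        )


def encoder_layer(h, p):
    # Self-attention: Q, K, V all come from the encoder input (layer 1) or hidden state.
    return attention(h @ p.w_q, h @ p.w_k, h @ p.w_v)


def decoder_layer(h, enc_out, p_self, p_cross):
    T = h.shape[0]
    no_peek = np.tril(np.ones((T, T), dtype=bool))
    # Masked self-attention: Q, K, V all come from the decoder input / hidden state.
    h = attention(h @ p_self.w_q, h @ p_self.w_k, h @ p_self.w_v, mask=no_peek)
    # Cross-attention: Q from the decoder side, K and V from the final encoder output.
    return attention(h @ p_cross.w_q, enc_out @ p_cross.w_k, enc_out @ p_cross.w_v)


rng = np.random.default_rng(0)
d, src_len, tgt_len, n_layers = 8, 7, 5, 2
src = rng.normal(size=(src_len, d))  # stand-in for the embedded encoder input
tgt = rng.normal(size=(tgt_len, d))  # stand-in for the embedded decoder input

h = src
for _ in range(n_layers):
    h = encoder_layer(h, Projections(d, rng))  # layers > 1 see the previous hidden state
enc_out = h

h = tgt
for _ in range(n_layers):
    h = decoder_layer(h, enc_out, Projections(d, rng), Projections(d, rng))

print(enc_out.shape, h.shape)  # (7, 8) (5, 8)
```

Every decoder layer attends to the same final encoder output, while its own Q (and the self-attention K and V) come from the decoder's previous layer.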