Transformer architecture and training tricks (for next-word prediction language models), spelled out in code
A (hopefully) simple guide, explained in torch
Google any variant of “transformer architecture explained”, and you will invariably see the same visualization copy-and-pasted from the 2017 paper “Attention Is All You Need”. Personally, I find this fairly useless, since it doesn’t help me build an intuition for the actual inputs and outputs of the model.
Here, I’ll give a short and hopefully useful introduction to the architecture of language models, as well as a few common training tricks.
This post focuses on models that, given an input sequence, predict the most probable next word. This is the type of architecture found in GPT-3 and ChatGPT’s policy model.
This is not the encoder-decoder architecture that the original paper used for machine translation (e.g. an English-to-French translator), since this is a decoder-only model.
How do we make sentences into numbers?
Firstly, how do we go from a full sentence (“I went to the store”) to some numerical representation we can mash matrix multiplications with? There are two steps: tokenization and embedding.
Tokenization - Taking a sentence and encoding its structure as integers. This could be as simple as encoding each character as a separate integer (e.g. a→1, b→2, etc.), mapping each word to an integer (aardvark → 1, etc.), or some combination of the two. There are lots of ways to tokenize sentences.
For now, let’s pretend our tokenizer maps each word (and each punctuation mark) to some integer, giving us a vocabulary of 171,146 tokens, roughly the number of words in the English dictionary. Our example sentence then gets mapped as:
“I went to the store!” → (5, 110, 79, 3, 81, 877)
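As a toy illustration (not a production tokenizer), here is a minimal word-level tokenizer in Python that treats every word and every punctuation mark as one token and assigns IDs by sorting the vocabulary. The class name WordTokenizer and the two-sentence corpus are made up for this sketch, so the IDs differ from the ones above.

```python
import re

class WordTokenizer:
    """Toy word-level tokenizer (hypothetical): each word or punctuation mark is one token."""

    def __init__(self, corpus: list[str]):
        vocab = sorted({w for text in corpus for w in self._split(text)})
        self.stoi = {w: i for i, w in enumerate(vocab)}   # string -> integer id
        self.itos = {i: w for w, i in self.stoi.items()}  # integer id -> string

    @staticmethod
    def _split(text: str) -> list[str]:
        # \w+ matches runs of word characters; [^\w\s] matches a single punctuation mark
        return re.findall(r"\w+|[^\w\s]", text)

    def encode(self, text: str) -> list[int]:
        return [self.stoi[w] for w in self._split(text)]

    def decode(self, ids: list[int]) -> str:
        return " ".join(self.itos[i] for i in ids)

tok = WordTokenizer(["I went to the store!", "The store was closed."])
ids = tok.encode("I went to the store!")
print(ids)              # [2, 9, 7, 6, 5, 0]
print(tok.decode(ids))  # I went to the store !
```

Note that “store!” becomes two tokens here (“store” and “!”), which is exactly the granularity question discussed in the Q&A below.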
Embedding - Once we tokenize our input, we generally create an embedding for each token. This is simply a lookup table, which maps a given token to a vector of fixed length. For example, with torch we could use
nn.Embedding(171_146, 64)
which creates a lookup table with 171,146 rows, each holding a randomly initialized vector of length 64. This is done as the first layer of the model, so the signal can be backpropagated and the embeddings can be learned.
Q: Why not just do embedding without tokenization?
A: You could embed each sentence, or each word split by space, but there would be a lot more (too many) things for your model to learn. For example, “store!” would be an entirely different token from “store.” or “store”. This isn’t optimal, so the granularity at which we tokenize is an important consideration.
Let T be the number of tokens in our tokenized sentence, B the batch size, and C the embedding dimension. That means our unbatched model input is a tensor of shape (T,), or batched (B, T). This is the standard input format to a generative language model.
Note: The T dimension is often called the “context length” or “sequence length” of the model
TL;DR: Sentence → mapped to integers → each integer mapped to an embedding vector in the first layer of the model. This means our sentence of shape (T,) gets batched to (B, T), then embedded to (B, T, C).
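Here is a small sketch of that shape pipeline with torch, using the vocabulary size and embedding dimension from above. The second sentence in the batch is made-up token IDs, added only so that B = 2.

```python
import torch
import torch.nn as nn

vocab_size, C = 171_146, 64  # vocabulary size and embedding dimension from the text
embedding = nn.Embedding(vocab_size, C)

# A batch of B = 2 tokenized sentences, each of length T = 6: shape (B, T)
# (the second row is made-up token ids)
tokens = torch.tensor([
    [5, 110, 79, 3, 81, 877],
    [5, 42, 79, 3, 81, 877],
])
x = embedding(tokens)
print(x.shape)  # torch.Size([2, 6, 64]), i.e. (B, T, C)

# It really is a lookup table: the vector for token 110 is row 110 of the weight matrix
assert torch.equal(x[0, 1], embedding.weight[110])
# The table is an ordinary trainable parameter, so backprop updates the embeddings
assert embedding.weight.requires_grad
```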
What does our model architecture look like?
1. Attention block
Let’s first focus on the so-called self-attention block with a head size H. This module is initialized with three linear layers without bias, each mapping from C to H (embedding size to head size), named key, query, and value.
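On a forward pass, we perform the following operation:
import torch
import torch.nn as nn
import torch.nn.functional as F
# B = batch size
# T = number of tokens in sequence
# C = embedding dimension
# H = attention head size
B, T, C, H = 4, 8, 32, 16  # example sizes
x = torch.randn(B, T, C)  # stand-in for the embedded input
key = nn.Linear(C, H, bias=False)
query = nn.Linear(C, H, bias=False)
value = nn.Linear(C, H, bias=False)
k = key(x)  # (B, T, H)
q = query(x)  # (B, T, H)
# don't transpose along batch dim!
# affinities are of shape (B, T, T); scale by 1/sqrt(H), the head size
affinity = q @ k.transpose(-2, -1) * H ** (-0.5)
affinity = F.softmax(affinity, dim=-1)
v = value(x)  # (B, T, H)
# Final output is of shape (B, T, H)
x = affinity @ v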
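To get the keys and queries, we simply apply the key and query layers to every token of every sequence in the batch. The affinities are the (batched) matrix product of the queries and the transposed keys, and the scaling factor keeps the variance of the affinities close to 1, so the softmax doesn’t saturate onto a single token.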
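A quick numerical check of why the scaling factor is there: if the entries of q and k are roughly independent with unit variance, each dot product q·k is a sum of H terms and has variance of about H, and multiplying by H ** -0.5 brings it back to about 1. The sizes below (H = 64, 10,000 random query/key pairs) are arbitrary choices for this sketch.

```python
import torch

torch.manual_seed(0)
H = 64  # head size
# Unit-variance random queries and keys, one pair per row
q = torch.randn(10_000, H)
k = torch.randn(10_000, H)

raw = (q * k).sum(dim=-1)  # 10,000 unscaled dot products q·k
scaled = raw * H ** -0.5   # the same dot products with the 1/sqrt(H) scaling

print(raw.var().item())     # roughly H = 64
print(scaled.var().item())  # roughly 1
```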
Autoregressive language models (i.e. models that predict the next word in a sequence) use masked self-attention, which introduces the following change to the calculation of attention scores:
# don't transpose along batch dim!
affinity = q @ k.transpose(-2, -1) * k.shape[-1] ** (-0.5)
# a lower triangular matrix that masks out future tokens at each step,
# e.g. tril = torch.tril(torch.ones(context_length, context_length)) built once up front
affinity = affinity.masked_fill(tril[:T, :T] == 0, float("-inf"))
affinity = F.softmax(affinity, dim=-1)
Keep in mind, the affinities are of shape (B, T, T).
So for each batch item, we have a (T, T) tensor representing the attention scores between each key and query.
Note: The “tril” matrix is a lower-triangular matrix of 1s (i.e., 1s on and below the diagonal and 0s everywhere else).
The masked fill operation uses our lower triangular matrix as a mask: wherever the mask is 0 (i.e. at future positions), the corresponding entry of each (T, T) affinity matrix is set to -inf. This is done independently for each sequence in the batch, so no information leaks between samples along the batch dimension.
This means that once we take the softmax of the affinities, the affinities become lower triangular too, since exp(-inf) = 0!
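To see the mask in action, here is a small sketch on random affinities (a stand-in for q @ k^T with B = 2 and T = 4, values chosen arbitrarily). After the masked fill and softmax, everything above the diagonal is exactly 0 and each row is still a probability distribution.

```python
import torch
import torch.nn.functional as F

T = 4
torch.manual_seed(0)
affinity = torch.randn(2, T, T)  # stand-in for q @ k^T, shape (B, T, T)

tril = torch.tril(torch.ones(T, T))  # 1s on and below the diagonal
masked = affinity.masked_fill(tril == 0, float("-inf"))  # future positions -> -inf
weights = F.softmax(masked, dim=-1)

print(weights[0])
# Entries above the diagonal are exactly 0, because exp(-inf) = 0
assert torch.equal(weights, torch.tril(weights))
# Each row still sums to 1 over the positions it is allowed to see
assert torch.allclose(weights.sum(dim=-1), torch.ones(2, T))
```

In particular, the first token can only attend to itself, so its row is [1, 0, 0, 0].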
2. Multi-headed attention
Given the number of “heads” (attention blocks) we want to use, multi-headed attention is simply:
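# assumes num_heads * H == C, so the concatenated heads have width C
projection = nn.Linear(C, C)
# Attention is the single-head block from above; nn.ModuleList registers each head's parameters
heads = nn.ModuleList([Attention() for _ in range(num_heads)])
x = torch.cat([head(x) for head in heads], dim=-1)  # (B, T, num_heads * H)
x = projection(x)  # (B, T, C)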
That is, we simply run our input through several attention blocks, concatenate the outputs of all heads for each token, and run the result through one more linear layer. Keep in mind, this is broadcast along the batch dimension, so each sequence in the batch is processed independently.
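Putting the pieces together, here is a minimal sketch of causal multi-headed attention as torch modules. The class names and constructor arguments (CausalHead, n_heads, n_embd, head_size, context_length) are my own choices for illustration, not the code from the linked repo. The output projection maps n_heads * head_size back to n_embd, which is the nn.Linear(C, C) above whenever n_heads * head_size == C.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalHead(nn.Module):
    def __init__(self, n_embd, head_size, context_length):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # a buffer moves with .to(device) but is not a trainable parameter
        self.register_buffer("tril", torch.tril(torch.ones(context_length, context_length)))

    def forward(self, x):
        T = x.shape[1]
        k, q, v = self.key(x), self.query(x), self.value(x)  # each (B, T, H)
        affinity = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5  # (B, T, T)
        affinity = affinity.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
        return F.softmax(affinity, dim=-1) @ v  # (B, T, H)

class MultiHeadAttention(nn.Module):
    def __init__(self, n_heads, n_embd, head_size, context_length):
        super().__init__()
        self.heads = nn.ModuleList(
            [CausalHead(n_embd, head_size, context_length) for _ in range(n_heads)]
        )
        self.projection = nn.Linear(n_heads * head_size, n_embd)

    def forward(self, x):
        x = torch.cat([head(x) for head in self.heads], dim=-1)  # (B, T, n_heads * H)
        return self.projection(x)  # (B, T, C)

torch.manual_seed(0)
mha = MultiHeadAttention(n_heads=4, n_embd=32, head_size=8, context_length=16)
x = torch.randn(2, 10, 32)
out = mha(x)
print(out.shape)  # torch.Size([2, 10, 32])

# Causality check: changing the last token must not change any earlier output
x2 = x.clone()
x2[:, -1] += 1.0
assert torch.allclose(mha(x2)[:, :-1], out[:, :-1])
```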
3. Transformer decoder block
Right — now that we’ve defined self-attention and multi-headed attention, we’re ready to define our transformer block.
Given the embedding dimension C and the num_heads we’d like to use in our multi-headed attention, the standard head size is floor(C/num_heads). There is no requirement to use this value, though.
The decoder block looks like the following:
head_size = n_embd // n_heads
self.attention = MultiHeadAttention(
n_heads=n_heads,
n_embd=n_embd,
head_size=head_size,
context_length=context_length,
)
# feedforward section
# a hidden-layer multiplier of 4 follows Attention Is All You Need (d_model=512, d_ff=2048)
self.ff = nn.Sequential(
nn.Linear(n_embd, 4 * n_embd),
nn.GELU(),  # the nonlinearity sits between the two linear layers
nn.Linear(4 * n_embd, n_embd),
nn.Dropout(dropout),
)
# Layernorm for training stability
self.ln1 = nn.LayerNorm(n_embd)
self.ln2 = nn.LayerNorm(n_embd)
That is, a multi-headed attention block followed by a 2-layer multi-layer perceptron with optional dropout.
The forward pass is then
x = x + self.attention(self.ln1(x))
x = x + self.ff(self.ln2(x))
Notice two things: we normalize the input before each sub-block (the “pre-LayerNorm” arrangement used by GPT-2), and we use residual connections to add each sub-block’s un-normalized input back to its output.
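For a self-contained, runnable version of the whole block, here is a sketch. One plain substitution to keep it short: it uses PyTorch’s built-in nn.MultiheadAttention instead of the hand-written multi-headed attention above (which derives the head size as n_embd // n_heads), with a boolean causal mask where True means “may not attend”. The class name Block and its arguments are illustrative.

```python
import torch
import torch.nn as nn

class Block(nn.Module):
    """Pre-LayerNorm decoder block (sketch). nn.MultiheadAttention stands in for the custom attention."""

    def __init__(self, n_embd, n_heads, context_length, dropout=0.1):
        super().__init__()
        self.attention = nn.MultiheadAttention(n_embd, n_heads, dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        # True above the diagonal = future positions that may not be attended to
        mask = torch.triu(torch.ones(context_length, context_length, dtype=torch.bool), diagonal=1)
        self.register_buffer("causal_mask", mask)

    def forward(self, x):
        T = x.shape[1]
        h = self.ln1(x)
        attn_out, _ = self.attention(h, h, h, attn_mask=self.causal_mask[:T, :T], need_weights=False)
        x = x + attn_out              # residual connection around attention
        x = x + self.ff(self.ln2(x))  # residual connection around the feedforward MLP
        return x

torch.manual_seed(0)
block = Block(n_embd=32, n_heads=4, context_length=16).eval()  # eval() disables dropout
x = torch.randn(2, 10, 32)
y = block(x)
print(y.shape)  # torch.Size([2, 10, 32]), same shape in and out

# The block is still causal: perturbing the last token leaves earlier outputs unchanged
x2 = x.clone()
x2[:, -1] += 1.0
assert torch.allclose(block(x2)[:, :-1], y[:, :-1], atol=1e-5)
```

Because the block maps (B, T, C) to (B, T, C), you can stack as many of them as you like.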
“Temperature” and generating new sequences conditioned on a prompt
The “temperature” hyperparameter, which is used at inference time only, rescales the logits before the softmax: we divide the logits by the temperature, so a temperature above 1 pushes the distribution of probabilities closer to uniform, while a temperature below 1 sharpens it toward the most probable word. In practice, a higher temperature tends to make the model’s writing less “stiff”, since the most probable word isn’t always the most “creative”.
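A tiny sketch of the effect, using made-up logits over a 4-word vocabulary: as the temperature grows, the probability of the top word shrinks toward the uniform value of 0.25.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.5, -1.0])  # made-up scores over a 4-word vocabulary

top_prob = {}
for temperature in (0.5, 1.0, 2.0):
    probs = F.softmax(logits / temperature, dim=-1)
    top_prob[temperature] = probs.max().item()
    print(temperature, [round(p, 3) for p in probs.tolist()])
```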
Naturally, one might ask how we can generate a sequence of responses based on an input prompt. We do this by running the prompt through the same pipeline as our training data:
Sentence → tokenizer → embed → pass into model and generate new tokens.
Given the model’s context length T, generating new tokens looks like
# tokenizer.encode is assumed to return a list of integer token ids
prompt = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)  # add batch dim!
for _ in range(max_new_tokens):
prompt_cond = prompt[:, -T:]  # crop to the last T tokens
logits = self(prompt_cond)
# focus only on the last time step
logits = logits[:, -1, :] / temperature
probs = F.softmax(logits, dim=-1)
# sample using multinomial, but can also just take argmax
prompt_next = torch.multinomial(probs, num_samples=1)
prompt = torch.cat((prompt, prompt_next), dim=1)
# Our prompt, continued by the model
# remove the batch dim for decoding!
text = tokenizer.decode(prompt.squeeze(0).tolist())
And hallelujah, you can converse freely with your model! Remember that the model only ever sees the last T tokens of the prompt (that is what the cropping step does), so how much you can prompt is limited by the context length T. Note: GPT-3 has T=2048.
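Here is the same loop as a standalone function, run against a hypothetical stand-in model (a bigram lookup table that returns (B, T, vocab_size) logits) so it can execute without a trained transformer. The function and model names are illustrative; @torch.no_grad() skips gradient tracking, which isn’t needed at inference time.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BigramLM(nn.Module):
    """Hypothetical stand-in model: (B, T) token ids -> (B, T, vocab_size) logits."""

    def __init__(self, vocab_size):
        super().__init__()
        self.table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx):
        return self.table(idx)

@torch.no_grad()
def generate(model, prompt, max_new_tokens, context_length, temperature=1.0):
    # prompt: (B, T0) tensor of token ids, already encoded and batched
    for _ in range(max_new_tokens):
        prompt_cond = prompt[:, -context_length:]  # the model only sees the last T tokens
        logits = model(prompt_cond)[:, -1, :] / temperature  # last time step only
        probs = F.softmax(logits, dim=-1)
        prompt_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
        prompt = torch.cat((prompt, prompt_next), dim=1)
    return prompt

torch.manual_seed(0)
model = BigramLM(vocab_size=10).eval()
prompt = torch.tensor([[5, 1, 7]])  # a made-up encoded prompt with a batch dim
out = generate(model, prompt, max_new_tokens=5, context_length=4)
print(out.shape)  # torch.Size([1, 8]): 3 prompt tokens + 5 new ones
```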
How do we input our data?
For a next-word prediction model, we need to consider how we’d like our data to be input such that calculating the loss makes sense. We do this in the following way: given a sentence of tokens [t_0, t_1, ..., t_K], we let the input be [t_0, ..., t_{K-1}] and the labels be [t_1, ..., t_K].
For example, “I went to the store” would be split into “I went to the” for the input and “went to the store” for the target. This way, we get the following input/output pairs:
I → went
I went → to
I went to → the
I went to the → store
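Purely to make those pairs concrete (the model itself never materializes them), here is how the shifted input/target split lines up in plain Python:

```python
tokens = ["I", "went", "to", "the", "store"]
inputs, targets = tokens[:-1], tokens[1:]  # drop the last token / drop the first token

# position i of the input, together with everything before it, predicts targets[i]
pairs = [(inputs[: i + 1], targets[i]) for i in range(len(inputs))]
for context, target in pairs:
    print(" ".join(context), "->", target)
# I -> went
# I went -> to
# I went to -> the
# I went to the -> store
```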
However, we never explicitly construct these pairs. Rather, we use the lower triangular (causal) attention mask in the decoder to learn all of these sequences at the same time. For our loss, the model outputs a tensor of shape (batch, seq_len, vocab_size), and our integer targets have shape (batch, seq_len).
Since we want to compute the cross-entropy across all tokens, we need to reshape our output via
output = output.reshape(-1, vocab_size)  # (batch*seq_len, vocab_size)
targets = targets.reshape(-1)  # aka (batch*seq_len)
loss = F.cross_entropy(output, targets)
Notice that at index i of the input, we are trying to predict index i of the targets, which is token i+1 of the original sentence (the targets are just the input shifted left by one word). This is the reason for the lower-triangular matrix of 1s in our attention block: it masks out the future tokens of our input by setting their affinities to -inf, so they receive zero attention weight after the softmax.
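Here is a sketch of the whole input/target/loss step with random data. The token IDs and the logits are made up (the logits stand in for model(x)). Note that x and y are non-contiguous slices, so .reshape is used; calling .view(-1) on them would raise an error. As a sanity check, uniform logits give a loss of exactly log(vocab_size), which is a handy reference for the loss at the start of training.

```python
import torch
import torch.nn.functional as F

B, T, vocab_size = 2, 5, 11
torch.manual_seed(0)
data = torch.randint(0, vocab_size, (B, T + 1))  # made-up token ids, T + 1 per sequence
x, y = data[:, :-1], data[:, 1:]                 # inputs and next-token targets, both (B, T)

logits = torch.randn(B, T, vocab_size)           # stand-in for model(x): (B, T, vocab_size)
loss = F.cross_entropy(logits.reshape(-1, vocab_size), y.reshape(-1))
print(loss.item())

# Uniform logits -> loss of log(vocab_size)
uniform = torch.zeros(B, T, vocab_size)
uniform_loss = F.cross_entropy(uniform.reshape(-1, vocab_size), y.reshape(-1))
print(uniform_loss.item(), torch.log(torch.tensor(float(vocab_size))).item())
```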
To take a look at the full model code, data loaders and a little training script using PyTorch-Lightning (with some useful things like sampling text from the model during training), check out https://github.com/jlehrer1/lightning-gpt.
Training tricks:
Warmup
Warmup is a training procedure that scales down the learning rate at the beginning of training.
We choose a number of iterations K and a base learning rate r. Then for iterations i = 1, ..., K, we set the learning rate to r*i/K before each optimizer step.
That is, we gradually increase our learning rate from r/K to r. Although unintuitive, this helps keep transformer-based architectures from diverging early in training.
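One way to get this schedule in PyTorch is torch.optim.lr_scheduler.LambdaLR, which multiplies the base learning rate by a function of the step count. In this sketch the model, AdamW, the base rate of 3e-4, and K = 4 are all arbitrary placeholders; the scheduler’s step counter starts at 0, hence the (step + 1).

```python
import torch
from torch.optim.lr_scheduler import LambdaLR

def linear_warmup(optimizer, warmup_steps):
    # multiplier goes 1/K, 2/K, ..., 1 and then stays at 1
    return LambdaLR(optimizer, lr_lambda=lambda step: min(1.0, (step + 1) / warmup_steps))

model = torch.nn.Linear(4, 4)  # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
scheduler = linear_warmup(optimizer, warmup_steps=4)

lrs = []
for step in range(6):
    lrs.append(optimizer.param_groups[0]["lr"])
    # loss.backward() would go here in a real training loop
    optimizer.step()
    scheduler.step()  # advance the warmup schedule once per iteration

print(lrs)  # ramps from 3e-4 / 4 up to 3e-4, then stays flat
```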
Downweighting attention heads at the beginning of training
In our decoder block, notice that during the forward pass we have a residual connection between the embedded input and the output of the multi-head attention layer. It turns out that initializing the weights of the attention heads to be very small, so that their contribution is minimal during the early stages of training, can help convergence quite a bit.
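A sketch of one way to do this, under a few assumptions: it uses PyTorch’s built-in nn.MultiheadAttention, it downweights only the attention block’s output projection (out_proj), and the std of 1e-3 is an arbitrary illustrative value rather than a recommended setting. The check at the end shows that the attention branch barely perturbs the residual stream at initialization.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
C, n_heads, T = 64, 4, 10
attention = nn.MultiheadAttention(C, n_heads, batch_first=True)

# Shrink the attention output projection at initialization and zero its bias
nn.init.normal_(attention.out_proj.weight, mean=0.0, std=1e-3)
nn.init.zeros_(attention.out_proj.bias)

x = torch.randn(2, T, C)  # stand-in for the embedded input, (B, T, C)
causal_mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)  # True = masked
attn_out, _ = attention(x, x, x, attn_mask=causal_mask, need_weights=False)
residual = x + attn_out  # the residual connection from the decoder block

# Relative size of the attention branch compared with the input it is added to
ratio = (attn_out.norm() / x.norm()).item()
print(f"{ratio:.4f}")  # a small number, far below 1
```

The projection weights are still trainable, so the attention branch can grow its contribution as training proceeds.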
Use a big batch size
In my experience, a large batch size (>1024) is often necessary for convergence. Transformers are weird.
Use a big model. Then make it bigger.
Empirically, using a ton of parameters (i.e. lots of attention heads + lots of decoder-block layers) helps a lot with convergence.
Browse these resources:
The following papers and blog posts have been useful, and I highly recommend you check them out: Training tricks for transformer models, How to train your transformer, A survey on efficient training of transformers.
Thanks!