
Attention and Transformers

The Transformer architecture (Vaswani et al., 2017) has become the dominant neural network design for language modeling, vision, and, increasingly, machine learning at large. Its core innovation -- the attention mechanism -- replaces the sequential processing of RNNs with parallel global interactions, enabling efficient training on modern hardware. This chapter covers the mathematical foundations of attention, the full Transformer architecture, and the computational considerations that dominate modern LLM systems.

Scaled Dot-Product Attention

Given queries $Q \in \mathbb{R}^{n \times d_k}$, keys $K \in \mathbb{R}^{m \times d_k}$, and values $V \in \mathbb{R}^{m \times d_v}$ (Vaswani et al., 2017):

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right) V$$

where $d_k$ is the dimension of the keys. The output is a weighted sum of value vectors, where the weight for each key-value pair is the softmax-normalized dot product between the query and key.

**Why $\sqrt{d_k}$ scaling?** If query and key entries are i.i.d. with mean 0 and variance 1, then $q^\top k = \sum_{i=1}^{d_k} q_i k_i$ has mean 0 and variance $d_k$ (sum of $d_k$ products of unit-variance terms). For large $d_k$, the dot products can be very large in magnitude, pushing the softmax into regions where the gradient is near zero. Dividing by $\sqrt{d_k}$ normalizes the variance to 1, keeping the softmax in its sensitive regime.

Without scaling, with $d_k = 128$, the dot products would have standard deviation $\approx 11$, and $\text{softmax}(11, -11) \approx (1, 0)$ -- almost all attention weight on one key. With scaling, the standard deviation is $\approx 1$, and the attention distribution remains informative.
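The variance argument above is easy to check numerically. A minimal NumPy sketch (sample count and seed are arbitrary):

```python
import numpy as np

# Empirical check of the sqrt(d_k) scaling argument: for i.i.d. mean-0,
# unit-variance entries, q . k has variance d_k, so its std is sqrt(d_k).
rng = np.random.default_rng(0)
d_k = 128
q = rng.standard_normal((10_000, d_k))
k = rng.standard_normal((10_000, d_k))

scores = np.einsum("nd,nd->n", q, k)    # one dot product per row pair
print(scores.std())                     # close to sqrt(128) ~ 11.3
print((scores / np.sqrt(d_k)).std())    # close to 1.0
```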

**Attention as soft dictionary lookup.** Attention can be understood as a differentiable key-value lookup:
  1. Query-key matching: $QK^\top$ computes similarity scores between each query and all keys (like looking up a key in a dictionary).
  2. Normalization: Softmax converts scores to a probability distribution (soft selection instead of hard lookup).
  3. Value retrieval: Multiply attention weights by values to get a weighted sum (the "retrieved" information).

This is analogous to a hash table with soft collisions: instead of returning one value for a query, it returns a convex combination of all values, weighted by key similarity. The weights $W^Q, W^K, W^V$ are learned to make the "dictionary" useful for the task.
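The three steps map directly onto a few lines of NumPy. A minimal sketch (single batch, no masking):

```python
import numpy as np

def attention(Q, K, V):
    """Scaled dot-product attention as a soft dictionary lookup."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                  # 1. query-key matching
    scores -= scores.max(axis=-1, keepdims=True)     #    (stabilize the softmax)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)   # 2. softmax normalization
    return weights @ V                               # 3. weighted value retrieval
```

Each output row is a convex combination of the rows of $V$, which is the "soft collision" view: every output entry lies between the column-wise min and max of $V$.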

**Attention as kernel smoothing.** Attention can be written as:

$$\text{Attn}(q_i) = \frac{\sum_j \exp(q_i^\top k_j / \sqrt{d_k}) \, v_j}{\sum_{j'} \exp(q_i^\top k_{j'} / \sqrt{d_k})} = \sum_j \frac{\kappa(q_i, k_j)}{\sum_{j'} \kappa(q_i, k_{j'})} \, v_j$$

where $\kappa(q, k) = \exp(q^\top k / \sqrt{d_k})$ is the softmax kernel. This is a Nadaraya-Watson kernel regression estimator with an exponential kernel. Linear attention variants (Katharopoulos et al., 2020) replace the softmax kernel with a decomposable kernel $\kappa(q, k) = \phi(q)^\top \phi(k)$, reducing complexity from $O(n^2)$ to $O(n)$.
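The kernel view makes the linear-attention trick concrete: with a decomposable kernel, the sums over keys can be precomputed once and reused for every query. A sketch, with a simple illustrative feature map `phi` (the original paper uses $\text{elu}(x) + 1$; ReLU$+1$ here stands in for it):

```python
import numpy as np

def linear_attention(Q, K, V, phi=lambda x: np.maximum(x, 0.0) + 1.0):
    """Kernelized attention with kappa(q, k) = phi(q)^T phi(k).
    Cost is O(n * d * d_v): the n x m score matrix is never formed."""
    Qp, Kp = phi(Q), phi(K)          # feature maps, shapes (n, d) and (m, d)
    KV = Kp.T @ V                    # sum_j phi(k_j) v_j^T, shape (d, d_v)
    Z = Qp @ Kp.sum(axis=0)          # normalizers sum_j phi(q_i)^T phi(k_j)
    return (Qp @ KV) / Z[:, None]
```

The result matches the naive quadratic computation exactly for this kernel; the savings come purely from reassociating the matrix products.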

Multi-Head Attention

Instead of one attention function with $d$-dimensional keys/values, project into $h$ independent heads:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h) W^O$$

$$\text{head}_i = \text{Attention}(Q W_i^Q, K W_i^K, V W_i^V)$$

where $W_i^Q \in \mathbb{R}^{d \times d_k}$, $W_i^K \in \mathbb{R}^{d \times d_k}$, $W_i^V \in \mathbb{R}^{d \times d_v}$, and $W^O \in \mathbb{R}^{h d_v \times d}$.

Typically $d_k = d_v = d / h$, so the total computation is the same as single-head attention with full dimensionality.
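In code, the per-head projections are usually realized as one matmul followed by a reshape into heads. A minimal NumPy sketch (hypothetical function, full $d \times d$ fused weights, no batching or masking):

```python
import numpy as np

def multi_head_attention(x, W_q, W_k, W_v, W_o, h):
    """Multi-head self-attention for x of shape (n, d), with d_k = d // h."""
    n, d = x.shape
    d_k = d // h

    def split(M):                                  # (n, d) -> (h, n, d_k)
        return M.reshape(n, h, d_k).transpose(1, 0, 2)

    Q, K, V = split(x @ W_q), split(x @ W_k), split(x @ W_v)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_k)   # (h, n, n)
    scores -= scores.max(axis=-1, keepdims=True)
    w = np.exp(scores)
    w /= w.sum(axis=-1, keepdims=True)                 # softmax per head
    heads = w @ V                                      # (h, n, d_k)
    concat = heads.transpose(1, 0, 2).reshape(n, d)    # Concat(head_1..head_h)
    return concat @ W_o                                # output projection W^O
```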

**Why multiple heads?** Each head can attend to different aspects of the input:
  • Head specialization: Empirically, different heads learn to focus on syntactic structure, semantic similarity, positional proximity, or specific linguistic patterns (e.g., subject-verb agreement, coreference).
  • Rank argument: Single-head attention with softmax output is approximately rank-1 (each query produces a peaked distribution over keys). Multi-head attention can attend to multiple positions simultaneously, achieving higher effective rank. With $h$ heads, the combined attention matrix has rank up to $h$.
  • Information routing: The output projection $W^O$ learns to combine head outputs, effectively routing different types of information through different subspaces.

Grouped-Query Attention (GQA) (Ainslie et al., 2023) uses fewer key-value heads than query heads (e.g., 8 KV heads for 32 query heads), reducing KV-cache memory while maintaining most of the quality.
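GQA is commonly implemented by caching only the $h_{kv}$ key/value heads and repeating (or broadcasting) them across query-head groups at attention time. A shape-level sketch with illustrative dimensions:

```python
import numpy as np

# GQA sketch: 32 query heads share 8 KV heads (4 query heads per group).
# Only the (h_kv, n, d_k) tensors are cached: an h_q / h_kv memory saving.
h_q, h_kv, n, d_k = 32, 8, 16, 128
rng = np.random.default_rng(0)
Q = rng.standard_normal((h_q, n, d_k))
K = rng.standard_normal((h_kv, n, d_k))
V = rng.standard_normal((h_kv, n, d_k))

group = h_q // h_kv                    # query heads per KV head
K_rep = np.repeat(K, group, axis=0)    # (h_q, n, d_k), materialized for the matmul
V_rep = np.repeat(V, group, axis=0)
scores = Q @ K_rep.transpose(0, 2, 1) / np.sqrt(d_k)   # (h_q, n, n)
```

Setting `h_kv = 1` recovers multi-query attention; `h_kv = h_q` recovers standard MHA.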

The Transformer Block

A single Transformer block applies multi-head attention followed by a position-wise feed-forward network, with residual connections and layer normalization:

$$x' = \text{LayerNorm}(x + \text{MultiHead}(x, x, x)) \quad \text{(self-attention)}$$

$$\text{out} = \text{LayerNorm}(x' + \text{FFN}(x')) \quad \text{(feed-forward)}$$

where $\text{FFN}(x) = \text{activation}(xW_1 + b_1)W_2 + b_2$ with hidden dimension typically $4d$. Modern LLMs use the SwiGLU activation (Shazeer, 2020): $\text{FFN}(x) = (\text{Swish}(xW_1) \odot xW_3)W_2$ with hidden dimension $\frac{8d}{3}$ (rounded to a multiple of 256).

| Component | Original (Vaswani et al., 2017) | Modern LLMs (LLaMA-style) |
|---|---|---|
| Normalization | Post-LayerNorm | Pre-RMSNorm |
| Activation | ReLU | SwiGLU |
| Position encoding | Sinusoidal (absolute) | RoPE (relative) |
| Attention | MHA ($h$ KV heads) | GQA (fewer KV heads) |
| FFN hidden dim | $4d$ | $\frac{8d}{3}$ (with gate) |
| Bias | Yes | No |
| Dropout | Yes | No (at scale) |
| Vocab embedding | Separate input/output | Tied input/output |

**Pre-norm vs. post-norm.** The original Transformer uses post-norm: $\text{LN}(x + \text{Sublayer}(x))$. Modern LLMs use pre-norm: $x + \text{Sublayer}(\text{LN}(x))$. Pre-norm is more stable during training because the residual stream is not normalized, allowing gradients to flow unchanged through skip connections. The tradeoff: pre-norm can lead to representation collapse in very deep networks (the residual dominates the sublayer output), which is mitigated by proper initialization scaling.
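The pre-norm arrangement can be sketched with stand-in sublayers; note that the residual stream `x` itself is never normalized, only the sublayer inputs are (`rms_norm` here omits the learned gain for brevity):

```python
import numpy as np

def rms_norm(x, eps=1e-6):
    """RMSNorm without the learned scale, for illustration."""
    return x / np.sqrt((x ** 2).mean(axis=-1, keepdims=True) + eps)

def pre_norm_block(x, attn, ffn):
    """Pre-norm residual block (LLaMA-style): x + Sublayer(LN(x)).
    Gradients flow unchanged through the two skip connections."""
    x = x + attn(rms_norm(x))   # self-attention sublayer
    x = x + ffn(rms_norm(x))    # feed-forward sublayer
    return x
```

With zero sublayers the block is the identity, which is exactly why pre-norm trains stably: the residual path is untouched.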

Positional Encoding

Transformers are permutation-equivariant without positional information: if you permute the input tokens, the output is permuted in the same way. Positional encodings break this symmetry.

Sinusoidal (absolute) (Vaswani et al., 2017):

$$PE_{(pos, 2i)} = \sin\!\left(\frac{pos}{10000^{2i/d}}\right), \quad PE_{(pos, 2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d}}\right)$$

Each dimension oscillates at a different frequency, creating a unique "fingerprint" for each position. The key property: $PE_{pos+k}$ can be written as a linear function of $PE_{pos}$, so the model can learn to attend to relative positions.
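The formula above vectorizes directly; a minimal sketch (hypothetical function name):

```python
import numpy as np

def sinusoidal_pe(n_pos, d):
    """Sinusoidal positional encodings: even dims get sin, odd dims cos,
    with geometrically spaced frequencies 10000^(-2i/d)."""
    pos = np.arange(n_pos)[:, None]        # (n_pos, 1)
    i = np.arange(d // 2)[None, :]         # (1, d // 2)
    angles = pos / 10000 ** (2 * i / d)
    pe = np.empty((n_pos, d))
    pe[:, 0::2] = np.sin(angles)
    pe[:, 1::2] = np.cos(angles)
    return pe
```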

Rotary Position Embeddings (RoPE) (Su et al., 2024): Encode position by rotating the query and key vectors in 2D subspaces:

$$f(x_m, m) = R_m x_m, \quad R_m = \begin{pmatrix} \cos(m\theta) & -\sin(m\theta) \\ \sin(m\theta) & \cos(m\theta) \end{pmatrix}$$

applied independently to pairs of dimensions. Different pairs use different base frequencies: $\theta_i = 10000^{-2i/d}$.

**RoPE advantages:**
  • Relative position: The dot product $f(q, m)^\top f(k, n) = q^\top R_{n-m}\, k$ depends only on the relative offset between positions $m$ and $n$, not on their absolute values.
  • Decay with distance: The inner product naturally decays with distance for most query-key pairs, implementing a soft locality bias.
  • Length generalization: RoPE can extrapolate to longer sequences than seen during training, especially with techniques like NTK-aware scaling, YaRN (Peng et al., 2023), or dynamic NTK interpolation.
  • No additional parameters: RoPE modifies the attention computation directly, adding no trainable parameters.
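Applying RoPE reduces to a 2-D rotation of each $(2i, 2i+1)$ dimension pair. A sketch for a single head (hypothetical function name):

```python
import numpy as np

def rope(x, pos, base=10000.0):
    """Rotate each (2i, 2i+1) pair of x (shape (n, d), d even) by the
    angle pos * base**(-2i/d), following the block-diagonal R_m above."""
    n, d = x.shape
    theta = base ** (-np.arange(0, d, 2) / d)      # (d/2,) frequencies
    angles = pos[:, None] * theta[None, :]         # (n, d/2)
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[:, 0::2], x[:, 1::2]
    out = np.empty_like(x)
    out[:, 0::2] = x1 * cos - x2 * sin
    out[:, 1::2] = x1 * sin + x2 * cos
    return out
```

Shifting both positions by the same offset leaves query-key dot products unchanged, which is the relative-position property in action.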
**ALiBi (Attention with Linear Biases)** (Press et al., 2022) adds a linear position-dependent bias directly to the attention scores: $\text{score}(q_i, k_j) = q_i^\top k_j - m \cdot |i - j|$, where $m$ is a head-specific slope. This is simpler than RoPE and also supports length extrapolation, but RoPE has become the dominant choice in modern LLMs.

Complexity and KV-Cache

Self-attention has $O(n^2 d)$ time complexity and $O(n^2 + nd)$ space for sequence length $n$. During autoregressive generation, the KV-cache avoids recomputing previous keys and values:

| Phase | Computation | Memory | Bottleneck |
|---|---|---|---|
| Prefill (prompt) | $O(n^2 d)$ FLOPs | $O(nd)$ for KV-cache | Compute-bound (matmul) |
| Each decode step | $O(nd)$ per token | KV-cache grows by $2d$ per layer per token | Memory-bound (KV-cache read) |
| Total decode ($T$ tokens) | $O(nTd)$ | $O((n+T)Ld)$ | KV-cache capacity |

**KV-cache memory dominates inference cost.** For a model with $L$ layers, $h$ key-value heads, per-head key dimension $d_k$, sequence length $n$, and batch size $B$:

$$\text{KV-cache memory} = 2 \times B \times L \times n \times h \times d_k \times \text{bytes per element}$$

For LLaMA-70B ($L = 80$, $h_{kv} = 8$, $d_k = 128$) at $n = 4096$, $B = 1$, in FP16:

$$2 \times 1 \times 80 \times 4096 \times 8 \times 128 \times 2 = 1.34 \text{ GB}$$

At $n = 128\text{K}$, this grows to $\sim 42$ GB -- often exceeding the model weights themselves. This is why KV-cache compression (quantization, eviction, paged attention) is critical for long-context inference.
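The memory formula is worth wiring into a one-liner when sizing deployments. A sketch reproducing the numbers above (hypothetical function name):

```python
def kv_cache_bytes(n_layers, n_kv_heads, d_k, seq_len, batch=1, bytes_per_elem=2):
    """KV-cache size: 2 (K and V) x B x L x n x h_kv x d_k x element size."""
    return 2 * batch * n_layers * seq_len * n_kv_heads * d_k * bytes_per_elem

# LLaMA-70B-style config (L=80, 8 KV heads, d_k=128), FP16:
print(kv_cache_bytes(80, 8, 128, 4096) / 1e9)     # 1.34 GB at n=4096
print(kv_cache_bytes(80, 8, 128, 131072) / 1e9)   # grows ~32x at n=128K
```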

FlashAttention

**FlashAttention** (Dao et al., 2022) computes exact attention with $O(n^2 d)$ FLOPs but only $O(n)$ extra memory by tiling the computation to exploit the GPU memory hierarchy:
  1. Tiling: Divide $Q$, $K$, $V$ into blocks of size $B_r \times d$ and $B_c \times d$ that fit in SRAM (on-chip shared memory, up to 228KB per SM on H100).
  2. Online softmax: For each $Q$ block, iterate over $K$, $V$ blocks, maintaining running softmax statistics (max and sum of exponentials) using the online softmax trick: $m_{\text{new}} = \max(m_{\text{old}}, m_{\text{block}})$, then rescale partial sums.
  3. No materialization: The $n \times n$ attention matrix is never stored in HBM -- each block is computed in SRAM, used immediately, and discarded.
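The online-softmax rescaling in step 2 can be sketched in NumPy as a 1-D toy over one query's scores (a didactic loop, not a tiled GPU kernel):

```python
import numpy as np

def online_softmax_weighted_sum(scores, values, block=4):
    """softmax(scores) @ values, computed block by block: keep a running
    max m, rescale the partial numerator/denominator whenever m grows,
    and never materialize the full weight vector."""
    m, denom = -np.inf, 0.0
    num = np.zeros(values.shape[-1])
    for s0 in range(0, len(scores), block):
        s = scores[s0:s0 + block]
        v = values[s0:s0 + block]
        m_new = max(m, s.max())
        scale = np.exp(m - m_new)        # rescale earlier partial sums
        p = np.exp(s - m_new)            # this block's unnormalized weights
        denom = denom * scale + p.sum()
        num = num * scale + p @ v
        m = m_new
    return num / denom
```

The final division by `denom` recovers exactly the full-softmax result, which is why FlashAttention is exact despite never holding all the weights at once.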

The IO complexity drops from $O(n^2)$ HBM reads/writes (standard attention) to $O(n^2 d^2 / M)$, where $M$ is the SRAM size.

**FlashAttention does not change the math** -- it computes the exact same result as standard attention. The speedup (2-4x on A100, more on H100) comes entirely from reducing HBM memory traffic. Key implications:
  • Enables long contexts: Without FlashAttention, $n = 128\text{K}$ requires $128\text{K}^2 \times 2$ bytes $= 32$ GB just for the attention matrix in FP16. FlashAttention needs only $O(n)$ extra memory.
  • Backward pass: FlashAttention recomputes the attention matrix during the backward pass rather than storing it, trading computation for memory. This is a form of gradient checkpointing applied specifically to attention.
  • FlashAttention-2 (Dao, 2023) further optimizes parallelism across sequence length (not just batch and heads), achieving closer to peak GPU throughput.
  • FlashAttention-3 targets Hopper architecture features (TMA, FP8, warp-specialized kernels).

Efficient Attention Variants

| Method | Complexity | Exact? | Key Idea |
|---|---|---|---|
| Standard attention | $O(n^2 d)$ | Yes | Full pairwise computation |
| FlashAttention | $O(n^2 d)$ FLOPs, $O(n)$ memory | Yes | Memory-efficient tiling |
| Linear attention (Katharopoulos et al., 2020) | $O(nd^2)$ | No | Replace softmax with $\phi(q)^\top \phi(k)$ |
| Sparse attention (Child et al., 2019) | $O(n\sqrt{n}\,d)$ | No | Attend only to local + strided positions |
| Ring attention (Liu et al., 2023) | $O(n^2 d)$ FLOPs, distributed | Yes | Distribute sequence across devices |
| Multi-query attention (Shazeer, 2019) | $O(n^2 d)$ | Yes | Share KV heads across query heads |
| Grouped-query attention (Ainslie et al., 2023) | $O(n^2 d)$ | Yes | Few KV head groups, each shared |
| Sliding window (Beltagy et al., 2020) | $O(nwd)$ | No | Each token attends to $w$ neighbors |

**The trend in efficient attention.** Rather than reducing the $O(n^2)$ complexity (which typically sacrifices quality), modern systems focus on:
  1. Making $O(n^2)$ faster: FlashAttention, hardware-aware implementations.
  2. Reducing KV-cache memory: GQA, MQA, KV-cache quantization (e.g., FP8 KV), PagedAttention for serving.
  3. Hybrid architectures: Combine local attention (sliding window) with sparse global attention, or interleave attention layers with linear recurrence layers (e.g., Mamba, RWKV).
  4. Context extension: RoPE scaling, YaRN, and continued pretraining to extend context from 4K to 128K+ tokens.

The Complete Transformer

**Parameter count of a Transformer.** For a model with $L$ layers, dimension $d$, $h$ heads, vocabulary size $V$, and FFN hidden dim $d_{\text{ff}}$:

| Component | Parameters per layer | Total |
|---|---|---|
| Self-attention ($W^Q, W^K, W^V, W^O$) | $4d^2$ (with MHA) | $4Ld^2$ |
| FFN ($W_1, W_2$) | $2d \cdot d_{\text{ff}}$ | $2Ld \cdot d_{\text{ff}}$ |
| FFN with SwiGLU ($W_1, W_2, W_3$) | $3d \cdot d_{\text{ff}}$ | $3Ld \cdot d_{\text{ff}}$ |
| LayerNorm / RMSNorm | $d$ or $2d$ | $\sim 2Ld$ |
| Embedding + LM head | $Vd$ (tied) | $Vd$ |

For LLaMA-7B ($L=32$, $d=4096$, $d_{\text{ff}}=11008$, $V=32000$): $\approx 6.7$B parameters. The dominant cost is the FFN (with SwiGLU), followed by the attention projections.
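The table totals can be checked mechanically. A sketch (hypothetical function; bias-free, with norms approximated as $2d$ per layer plus a final norm):

```python
def transformer_params(L, d, d_ff, V, swiglu=True, tied=True):
    """Approximate parameter count from the per-layer table above."""
    attn = 4 * d * d * L                        # W_Q, W_K, W_V, W_O
    ffn = (3 if swiglu else 2) * d * d_ff * L   # SwiGLU adds the gate W_3
    norms = 2 * d * L + d                       # two norms per layer + final norm
    embed = V * d if tied else 2 * V * d        # input (and output) embeddings
    return attn + ffn + norms + embed

# LLaMA-7B-style config:
print(transformer_params(32, 4096, 11008, 32000) / 1e9)
```

With tied embeddings this gives $\approx 6.6$B; counting input and output embeddings separately adds another $Vd \approx 0.13$B, matching the $\approx 6.7$B figure.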

Notation Summary

| Symbol | Meaning |
|---|---|
| $Q, K, V$ | Query, key, value matrices |
| $d_k, d_v$ | Key and value dimensions per head |
| $h$ | Number of attention heads |
| $n$ | Sequence length |
| $d$ | Model dimension (hidden size) |
| $d_{\text{ff}}$ | Feed-forward hidden dimension |
| $L$ | Number of Transformer layers |
| $W^Q, W^K, W^V, W^O$ | Attention projection matrices |
| $W_1, W_2, W_3$ | FFN weight matrices |
| RoPE | Rotary position embeddings |
| GQA | Grouped-query attention |
| MQA | Multi-query attention |
| $M$ | SRAM (shared memory) size |
| KV-cache | Stored keys and values for autoregressive decoding |

References