Attention and Transformers
The Transformer architecture (Vaswani et al., 2017) has become the dominant neural network design for language modeling, vision, and increasingly all of machine learning. Its core innovation -- the attention mechanism -- replaces sequential processing (RNNs) with parallel global interactions, enabling efficient training on modern hardware. This chapter covers the mathematical foundations of attention, the full Transformer architecture, and the computational considerations that dominate modern LLM systems.
Scaled Dot-Product Attention
Scaled dot-product attention is defined as

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V,$$

where $d_k$ is the dimension of the keys. The output is a weighted sum of value vectors, where the weight for each key-value pair is the softmax-normalized dot product between the query and key.
Without scaling, with query and key components drawn i.i.d. with zero mean and unit variance, the dot products would have standard deviation $\sqrt{d_k}$, and the softmax would saturate -- almost all attention weight on one key. With scaling, the standard deviation is $\approx 1$, and the attention distribution remains informative.
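The variance argument is easy to check numerically; a minimal sketch, assuming i.i.d. unit-variance query and key components:

```python
import numpy as np

rng = np.random.default_rng(0)
d_k = 512
q = rng.normal(size=(10000, d_k))   # unit-variance query components
k = rng.normal(size=(10000, d_k))   # unit-variance key components
dots = (q * k).sum(axis=1)          # 10,000 sample dot products q . k
# dots.std() is close to sqrt(d_k) ~ 22.6; after the 1/sqrt(d_k) scale it is close to 1
scaled = dots / np.sqrt(d_k)
```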
- Query-key matching: computes similarity scores between each query and all keys (like looking up a key in a dictionary).
- Normalization: Softmax converts scores to a probability distribution (soft selection instead of hard lookup).
- Value retrieval: Multiply attention weights by values to get a weighted sum (the "retrieved" information).
This is analogous to a hash table with soft collisions: instead of returning one value for a query, it returns a convex combination of all values, weighted by key similarity. The weights are learned to make the "dictionary" useful for the task.
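The three steps above can be sketched in a few lines of NumPy (a minimal single-query-set version, ignoring batching and masking):

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V):
    """Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)               # query-key similarity scores
    scores -= scores.max(axis=-1, keepdims=True)  # stabilize the exponentials
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # softmax: soft selection over keys
    return weights @ V, weights                     # weighted sum of values

rng = np.random.default_rng(0)
Q = rng.normal(size=(4, 8))   # 4 queries
K = rng.normal(size=(6, 8))   # 6 keys
V = rng.normal(size=(6, 8))   # 6 values
out, w = scaled_dot_product_attention(Q, K, V)
```

Each row of `w` is the convex-combination weights for one query, exactly the "soft collision" picture above.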
Attention can also be viewed as kernel smoothing:

$$\mathrm{Attn}(q_i) = \sum_j \frac{\kappa(q_i, k_j)}{\sum_{j'} \kappa(q_i, k_{j'})}\, v_j,$$

where $\kappa(q, k) = \exp(q^\top k / \sqrt{d_k})$ is the softmax kernel. This is a Nadaraya-Watson kernel regression estimator with an exponential kernel. Linear attention variants (Katharopoulos et al., 2020) replace the softmax kernel with a decomposable kernel $\kappa(q, k) = \phi(q)^\top \phi(k)$, reducing complexity from $O(n^2 d)$ to $O(n d^2)$.
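A minimal sketch of the decomposable-kernel idea, using the $\mathrm{elu}(x)+1$ feature map from Katharopoulos et al. (2020); the quadratic computation is included only to check agreement:

```python
import numpy as np

def phi(x):
    # elu(x) + 1: a simple positive feature map (the choice in Katharopoulos et al., 2020)
    return np.where(x > 0, x + 1.0, np.exp(x))

def linear_attention(Q, K, V):
    # kappa(q, k) = phi(q)^T phi(k) lets the sums over keys be precomputed once,
    # so the cost is O(n d^2) instead of O(n^2 d)
    Qf, Kf = phi(Q), phi(K)
    KV = Kf.T @ V                      # (d, d_v): one pass over keys and values
    Z = Qf @ Kf.sum(axis=0)            # (n,): per-query normalizers
    return (Qf @ KV) / Z[:, None]

rng = np.random.default_rng(0)
Q, K, V = rng.normal(size=(5, 4)), rng.normal(size=(7, 4)), rng.normal(size=(7, 3))
out = linear_attention(Q, K, V)

# Same estimator computed the quadratic way, for comparison
W = phi(Q) @ phi(K).T
ref = (W / W.sum(axis=-1, keepdims=True)) @ V
```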
Multi-Head Attention
Multi-head attention runs $h$ attention operations in parallel over learned projections:

$$\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)\, W^O, \qquad \mathrm{head}_i = \mathrm{Attention}(Q W_i^Q,\, K W_i^K,\, V W_i^V),$$

where $W_i^Q, W_i^K \in \mathbb{R}^{d_{\mathrm{model}} \times d_k}$, $W_i^V \in \mathbb{R}^{d_{\mathrm{model}} \times d_v}$, and $W^O \in \mathbb{R}^{h d_v \times d_{\mathrm{model}}}$.
Typically $d_k = d_v = d_{\mathrm{model}}/h$, so the total computation is the same as single-head attention with full dimensionality.
- Head specialization: Empirically, different heads learn to focus on syntactic structure, semantic similarity, positional proximity, or specific linguistic patterns (e.g., subject-verb agreement, coreference).
- Rank argument: Single-head attention with softmax output is approximately rank-1 (each query produces a peaked distribution over keys). Multi-head attention can attend to multiple positions simultaneously, achieving higher effective rank: with $h$ heads, the combined attention matrix has rank up to $h$.
- Information routing: The output projection $W^O$ learns to combine head outputs, effectively routing different types of information through different subspaces.
Grouped-Query Attention (GQA) (Ainslie et al., 2023) uses fewer key-value heads than query heads (e.g., 8 KV heads for 32 query heads), reducing KV-cache memory while maintaining most of the quality.
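A sketch of the KV-head sharing behind GQA (the shapes and `softmax` helper are illustrative, not any particular library's API):

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def grouped_query_attention(Q, K, V):
    """Q: (h_q, n, d); K, V: (h_kv, n, d) with h_kv dividing h_q."""
    h_q, n, d = Q.shape
    group = h_q // K.shape[0]
    K = np.repeat(K, group, axis=0)    # each KV head serves `group` query heads
    V = np.repeat(V, group, axis=0)
    w = softmax(Q @ K.transpose(0, 2, 1) / np.sqrt(d))
    return w @ V

rng = np.random.default_rng(0)
Q = rng.normal(size=(4, 6, 8))         # 4 query heads
Q[1] = Q[0]                            # make heads 0 and 1 identical on purpose
K = rng.normal(size=(2, 6, 8))         # only 2 KV heads to cache
V = rng.normal(size=(2, 6, 8))
out = grouped_query_attention(Q, K, V)
```

Only the 2 KV heads need to be cached during decoding; query heads 0 and 1 share KV head 0, so their outputs coincide here.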
The Transformer Block
Each block contains a position-wise feed-forward network $\mathrm{FFN}(x) = W_2\,\max(0, W_1 x + b_1) + b_2$ with hidden dimension typically $d_{ff} = 4 d_{\mathrm{model}}$. Modern LLMs use SwiGLU activation (Shazeer, 2020): $\mathrm{FFN}(x) = W_2\,(\mathrm{SiLU}(W_1 x) \odot W_3 x)$ with hidden dimension $d_{ff} = \tfrac{8}{3} d_{\mathrm{model}}$ (rounded to a multiple of 256).
| Component | Original (Vaswani et al., 2017) | Modern LLMs (LLaMA-style) |
|---|---|---|
| Normalization | Post-LayerNorm | Pre-RMSNorm |
| Activation | ReLU | SwiGLU |
| Position encoding | Sinusoidal (absolute) | RoPE (relative) |
| Attention | MHA ($h$ KV heads) | GQA (fewer KV heads) |
| FFN hidden dim | $4 d_{\mathrm{model}}$ | $\tfrac{8}{3} d_{\mathrm{model}}$ (with gate) |
| Bias | Yes | No |
| Dropout | Yes | No (at scale) |
| Vocab embedding | Separate input/output | Tied input/output |
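The modern column of the table can be sketched as a single pre-norm block (a hypothetical minimal version; the weights, shapes, and the stand-in linear "attention" are illustrative only):

```python
import numpy as np

def rms_norm(x, g, eps=1e-6):
    return x / np.sqrt((x * x).mean(axis=-1, keepdims=True) + eps) * g

def silu(x):
    return x / (1.0 + np.exp(-x))

def swiglu_ffn(x, W1, W3, W2):
    # W1: gate projection, W3: up projection, W2: down projection (no biases)
    return (silu(x @ W1) * (x @ W3)) @ W2

def pre_norm_block(x, attn_fn, ffn_weights, g1, g2):
    # LLaMA-style ordering: x + Attn(RMSNorm(x)), then x + FFN(RMSNorm(x))
    x = x + attn_fn(rms_norm(x, g1))
    x = x + swiglu_ffn(rms_norm(x, g2), *ffn_weights)
    return x

rng = np.random.default_rng(1)
d, d_ff, n = 16, 40, 5
W1 = 0.1 * rng.normal(size=(d, d_ff))
W3 = 0.1 * rng.normal(size=(d, d_ff))
W2 = 0.1 * rng.normal(size=(d_ff, d))
Wa = 0.1 * rng.normal(size=(d, d))     # stand-in "attention": a plain linear map
x = rng.normal(size=(n, d))
y = pre_norm_block(x, lambda h: h @ Wa, (W1, W3, W2), np.ones(d), np.ones(d))
```

The residual stream `x` is only ever added to, which is what makes the pre-norm arrangement stable at depth.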
Positional Encoding
Transformers are permutation-equivariant without positional information: if you permute the input tokens, the output is permuted in the same way. Positional encodings break this symmetry.
Sinusoidal (absolute) (Vaswani et al., 2017):

$$PE_{(pos,\, 2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{\mathrm{model}}}}\right), \qquad PE_{(pos,\, 2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{\mathrm{model}}}}\right)$$

Each dimension oscillates at a different frequency, creating a unique "fingerprint" for each position. The key property: $PE_{pos+k}$ can be written as a linear function of $PE_{pos}$, so the model can learn to attend to relative positions.
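A minimal implementation of the sinusoidal encoding:

```python
import numpy as np

def sinusoidal_pe(n, d_model):
    # PE[pos, 2i] = sin(pos / 10000^(2i/d)); PE[pos, 2i+1] = cos(pos / 10000^(2i/d))
    pos = np.arange(n)[:, None]
    i = np.arange(d_model // 2)[None, :]
    angles = pos / (10000.0 ** (2 * i / d_model))
    pe = np.zeros((n, d_model))
    pe[:, 0::2] = np.sin(angles)   # even dimensions: sine
    pe[:, 1::2] = np.cos(angles)   # odd dimensions: cosine
    return pe

pe = sinusoidal_pe(50, 16)
# Position 0 has the fingerprint [0, 1, 0, 1, ...] since sin(0)=0, cos(0)=1
```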
Rotary Position Embeddings (RoPE) (Su et al., 2024): Encode position by rotating the query and key vectors in 2D subspaces. At position $m$:

$$\begin{pmatrix} x'_{2i} \\ x'_{2i+1} \end{pmatrix} = \begin{pmatrix} \cos m\theta_i & -\sin m\theta_i \\ \sin m\theta_i & \cos m\theta_i \end{pmatrix} \begin{pmatrix} x_{2i} \\ x_{2i+1} \end{pmatrix},$$

applied independently to pairs of dimensions. Different pairs use different base frequencies: $\theta_i = 10000^{-2i/d}$.
- Relative position: The dot product $\langle \mathrm{RoPE}(q, m), \mathrm{RoPE}(k, n) \rangle$ depends only on the relative position $m - n$, not on the absolute positions.
- Decay with distance: The inner product naturally decays with distance for most query-key pairs, implementing a soft locality bias.
- Length generalization: RoPE can extrapolate to longer sequences than seen during training, especially with techniques like NTK-aware scaling, YaRN (Peng et al., 2023), or dynamic NTK interpolation.
- No additional parameters: RoPE modifies the attention computation directly, adding no trainable parameters.
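A sketch of RoPE applied to a batch of vectors, including a check of the relative-position property (the `(5, 3)` vs `(7, 5)` positions are arbitrary; both have offset 2):

```python
import numpy as np

def apply_rope(x, positions, base=10000.0):
    """Rotate each consecutive pair of dimensions of x (shape (n, d)) by pos * theta_i."""
    n, d = x.shape
    theta = base ** (-2.0 * np.arange(d // 2) / d)   # per-pair base frequencies
    ang = positions[:, None] * theta[None, :]        # (n, d/2) rotation angles
    cos, sin = np.cos(ang), np.sin(ang)
    out = np.empty_like(x)
    out[:, 0::2] = x[:, 0::2] * cos - x[:, 1::2] * sin
    out[:, 1::2] = x[:, 0::2] * sin + x[:, 1::2] * cos
    return out

rng = np.random.default_rng(0)
q, k = rng.normal(size=(1, 8)), rng.normal(size=(1, 8))
# Scores at positions (5, 3) and (7, 5) agree: only the offset m - n matters
d1 = apply_rope(q, np.array([5])) @ apply_rope(k, np.array([3])).T
d2 = apply_rope(q, np.array([7])) @ apply_rope(k, np.array([5])).T
```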
Complexity and KV-Cache
Self-attention has $O(n^2 d)$ time complexity and $O(n^2)$ space for sequence length $n$. During autoregressive generation, the KV-cache avoids recomputing previous keys and values:
| Phase | Computation | Memory | Bottleneck |
|---|---|---|---|
| Prefill (prompt, length $n$) | $O(n^2 d)$ FLOPs | $O(nd)$ for KV-cache | Compute-bound (matmul) |
| Each decode step | $O(nd)$ per token | KV-cache grows by $2d_{\mathrm{model}}$ per layer per token | Memory-bound (KV-cache read) |
| Total decode ($m$ tokens) | $O(nmd + m^2 d)$ FLOPs | $O((n+m)d)$ KV-cache capacity | Memory-bound |
For LLaMA-70B ($L = 80$ layers, $d_{\mathrm{model}} = 8192$, GQA with 8 KV heads of head dimension 128) at $n = 4096$, batch size 1, in FP16: the KV-cache occupies $2 \times 80 \times 8 \times 128 \times 4096 \times 2$ bytes $\approx 1.25$ GB.
At $n = 128\mathrm{K}$, this grows to $\approx 40$ GB per sequence -- at realistic batch sizes it can exceed the model weights themselves. This is why KV-cache compression (quantization, eviction, paged attention) is critical for long-context inference.
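The arithmetic generalizes to a one-line estimator (the config values below are assumed LLaMA-2-70B-style numbers: 80 layers, 8 KV heads under GQA, head dimension 128):

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, batch=1, dtype_bytes=2):
    # factor 2: one tensor of keys plus one of values; dtype_bytes=2 for FP16
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch * dtype_bytes

gib_4k = kv_cache_bytes(80, 8, 128, 4096) / 2**30      # ~1.25 GiB at a 4K context
gib_128k = kv_cache_bytes(80, 8, 128, 131072) / 2**30  # ~40 GiB at a 128K context
```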
FlashAttention
- Tiling: Divide $Q$, $K$, $V$ into blocks of size $B_r \times d$ and $B_c \times d$ that fit in SRAM (on-chip shared memory, configurable up to 228 KB per SM on H100).
- Online softmax: For each $Q$ block, iterate over the $K$, $V$ blocks, maintaining running softmax statistics (max and sum of exponentials) using the online softmax trick: $m_{\mathrm{new}} = \max(m_{\mathrm{old}}, m_{\mathrm{block}})$, then rescale partial sums by $e^{m_{\mathrm{old}} - m_{\mathrm{new}}}$.
- No materialization: The full $n \times n$ attention matrix is never stored in HBM -- each block is computed in SRAM, used immediately, and discarded.
The IO complexity drops from $\Theta(nd + n^2)$ HBM reads/writes (standard attention) to $\Theta(n^2 d^2 / M)$, where $M$ is the SRAM size.
- Enables long contexts: Without FlashAttention, materializing the attention matrix in FP16 requires $2n^2$ bytes per head (e.g., $\approx 8$ GB per head at $n = 64\mathrm{K}$). FlashAttention needs only $O(n)$ extra memory.
- Backward pass: FlashAttention recomputes the attention matrix during the backward pass rather than storing it, trading computation for memory. This is a form of gradient checkpointing applied specifically to attention.
- FlashAttention-2 (Dao, 2023) further optimizes parallelism across sequence length (not just batch and heads), achieving closer to peak GPU throughput.
- FlashAttention-3 targets Hopper architecture features (TMA, FP8, warp-specialized kernels).
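The running-statistics trick at the heart of FlashAttention can be demonstrated on a 1D vector of scores (statistics only; the real kernel also accumulates the softmax-weighted values block by block):

```python
import numpy as np

def online_softmax_stats(scores, block=4):
    """One pass over `scores` in blocks, keeping a running max m and running sum s."""
    m, s = -np.inf, 0.0
    for start in range(0, len(scores), block):
        b = scores[start:start + block]
        m_new = max(m, b.max())
        s = s * np.exp(m - m_new) + np.exp(b - m_new).sum()  # rescale the old sum
        m = m_new
    # softmax(scores) could now be recovered as exp(scores - m) / s
    return m, s

x = np.random.default_rng(2).normal(size=13)
m, s = online_softmax_stats(x)
```

The rescaling step is what lets each block be processed once and discarded, without ever holding all the exponentials at the same time.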
Efficient Attention Variants
| Method | Complexity | Exact? | Key Idea |
|---|---|---|---|
| Standard attention | $O(n^2 d)$ | Yes | Full pairwise computation |
| FlashAttention | $O(n^2 d)$ FLOPs, $O(n)$ memory | Yes | Memory-efficient tiling |
| Linear attention (Katharopoulos et al., 2020) | $O(n d^2)$ | No | Replace softmax with $\phi(q)^\top \phi(k)$ |
| Sparse attention (Child et al., 2019) | $O(n \sqrt{n}\, d)$ | No | Attend only to local + strided positions |
| Ring attention (Liu et al., 2023) | $O(n^2 d)$ FLOPs, distributed | Yes | Distribute sequence across devices |
| Multi-query attention (Shazeer, 2019) | $O(n^2 d)$ | Yes | Share one KV head across all query heads |
| Grouped-query attention (Ainslie et al., 2023) | $O(n^2 d)$ | Yes | Few KV head groups, each shared |
| Sliding window (Beltagy et al., 2020) | $O(n w d)$ | No | Each token attends to $w$ neighbors |
- Making attention faster: FlashAttention and other hardware-aware implementations.
- Reducing KV-cache memory: GQA, MQA, KV-cache quantization (e.g., FP8 KV), PagedAttention for serving.
- Hybrid architectures: Combine local attention (sliding window) with sparse global attention, or interleave attention layers with linear recurrence layers (e.g., Mamba, RWKV).
- Context extension: RoPE scaling, YaRN, and continued pretraining to extend context from 4K to 128K+ tokens.
The Complete Transformer
| Component | Parameters per layer | Total |
|---|---|---|
| Self-attention ($W^Q, W^K, W^V, W^O$) | $4 d_{\mathrm{model}}^2$ (with MHA) | $4 L d_{\mathrm{model}}^2$ |
| FFN ($W_1, W_2$) | $2 d_{\mathrm{model}} d_{ff} = 8 d_{\mathrm{model}}^2$ | $8 L d_{\mathrm{model}}^2$ |
| FFN with SwiGLU ($W_1, W_2, W_3$) | $3 d_{\mathrm{model}} d_{ff}$ | $3 L d_{\mathrm{model}} d_{ff}$ |
| LayerNorm / RMSNorm | $2 d_{\mathrm{model}}$ or $d_{\mathrm{model}}$ | $O(L d_{\mathrm{model}})$ |
| Embedding + LM head | -- | $\lvert V \rvert\, d_{\mathrm{model}}$ (tied) |
For LLaMA-7B ($d_{\mathrm{model}} = 4096$, $L = 32$, $d_{ff} = 11008$, $\lvert V \rvert = 32000$): $\approx 6.7$B parameters. The dominant cost is the FFN (with SwiGLU), followed by the attention projections.
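A back-of-the-envelope counter matching the table (tied embeddings assumed, as in the table; the LLaMA-7B config values are assumed, not official tallies, so expect small deviations from reported totals):

```python
def llama_params(d_model, n_layers, d_ff, vocab):
    attn = 4 * d_model * d_model    # W^Q, W^K, W^V, W^O (MHA, no biases)
    ffn = 3 * d_model * d_ff        # W1, W2, W3 for SwiGLU
    norms = 2 * d_model             # two RMSNorms per block
    embed = vocab * d_model         # tied input/output embedding, counted once
    final_norm = d_model
    return n_layers * (attn + ffn + norms) + embed + final_norm

# Assumed LLaMA-7B-style config: d=4096, L=32, d_ff=11008, vocab=32000
total = llama_params(4096, 32, 11008, 32000)   # ~6.6e9
```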
Notation Summary
| Symbol | Meaning |
|---|---|
| $Q, K, V$ | Query, key, value matrices |
| $d_k, d_v$ | Key and value dimensions per head |
| $h$ | Number of attention heads |
| $n$ | Sequence length |
| $d_{\mathrm{model}}$ | Model dimension (hidden size) |
| $d_{ff}$ | Feed-forward hidden dimension |
| $L$ | Number of Transformer layers |
| $W^Q, W^K, W^V, W^O$ | Attention projection matrices |
| $W_1, W_2, W_3$ | FFN weight matrices |
| RoPE | Rotary position embeddings |
| GQA | Grouped-query attention |
| MQA | Multi-query attention |
| $M$ | SRAM (shared memory) size |
| KV-cache | Stored keys and values for autoregressive decoding |
References
- Joshua Ainslie, James Lee-Thorp, Michiel de Jong (2023). GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints. EMNLP.
- Iz Beltagy, Matthew E. Peters, Arman Cohan (2020). Longformer: The Long-Document Transformer. arXiv.
- Rewon Child, Scott Gray, Alec Radford, Ilya Sutskever (2019). Generating Long Sequences with Sparse Transformers. arXiv.
- Tri Dao (2023). FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. ICLR.
- Angelos Katharopoulos, Apoorv Vyas, Nikolaos Pappas, Francois Fleuret (2020). Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention. ICML.
- Hao Liu, Matei Zaharia, Pieter Abbeel (2023). Ring Attention with Blockwise Transformers for Near-Infinite Context. arXiv.
- Bowen Peng, Jeffrey Quesnelle, Honglu Fan, Enrico Shippole (2023). YaRN: Efficient Context Window Extension of Large Language Models. arXiv.
- Noam Shazeer (2019). Fast Transformer Decoding: One Write-Head is All You Need. arXiv.
- Noam Shazeer (2020). GLU Variants Improve Transformer. arXiv.
- Jianlin Su, Murtadha Ahmed, Yu Lu, Shengfeng Pan, Wen Bo, Yunfeng Liu (2024). RoFormer: Enhanced Transformer with Rotary Position Embedding. Neurocomputing.
- Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, Illia Polosukhin (2017). Attention Is All You Need. NeurIPS.