From Assembly to PyTorch
Every time you write torch.matmul(A, B), a remarkable chain of abstraction layers transforms your one-line Python call into millions of hardware operations executing on Tensor Cores. Understanding this stack -- from Python to silicon -- gives you the mental model needed to reason about performance, debug mysterious slowdowns, and decide when to drop to a lower level of abstraction.
The Abstraction Stack
Layer 7: Python torch.matmul(A, B)
↓ Python C extension call (pybind11)
Layer 6: ATen/C++ at::matmul(A, B)
↓ PyTorch dispatcher: checks device, dtype, layout
Layer 5: Backend at::cuda::matmul → cublasGemmEx(...)
↓ cuBLAS selects optimal kernel based on matrix size
Layer 4: CUDA Runtime cudaLaunchKernel(gemm_kernel, grid, block, args)
↓ Sets up GPU execution configuration
Layer 3: PTX Assembly mad.f32 %f3, %f1, %f2, %f3;
↓ NVIDIA's intermediate representation (virtual ISA)
Layer 2: SASS Hardware-specific GPU machine code
↓ Generated by ptxas (PTX assembler)
Layer 1: Silicon Tensor Cores execute 16x8x16 FP16 matrix FMAs
Each Tensor Core: 256 FP16 FMAs per clock cycle
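The per-core FMA figure composes into a chip-level peak. A back-of-envelope sketch, assuming the published A100 datasheet figures (432 Tensor Cores, ~1.41 GHz boost clock) purely for illustration:

```python
# Back-of-envelope Tensor Core throughput. The core count and clock are
# assumed A100 specs, not queried from hardware.

def tensor_core_tflops(num_cores: int, fmas_per_clock: int, clock_ghz: float) -> float:
    """Peak dense throughput in TFLOPS (1 FMA = 2 floating-point ops)."""
    flops_per_second = num_cores * fmas_per_clock * 2 * clock_ghz * 1e9
    return flops_per_second / 1e12

# A100: 432 cores x 256 FP16 FMAs/clock x 1.41 GHz ~= 312 TFLOPS dense FP16
peak = tensor_core_tflops(num_cores=432, fmas_per_clock=256, clock_ghz=1.41)
print(f"{peak:.0f} TFLOPS")  # ~312
```

The result matches the A100's advertised 312 TFLOPS dense FP16 figure, which is a useful sanity check that the 256 FMAs/clock number is real.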
| Layer | Abstraction Level | When to Use | Performance Control |
|---|---|---|---|
| Python (torch) | Highest | 95% of ML code | None (framework decides) |
| torch.compile | High | When eager mode is too slow | Operator fusion, graph optimization |
| Triton | Medium | Custom kernels, new operations | Block-level memory management |
| CUDA C++ | Low | Maximum performance, novel algorithms | Full hardware control |
| PTX/SASS | Lowest | Hardware research, micro-optimization | Instruction-level control |
The 95/5 rule: Stay at layers 6-7 (Python/ATen) for 95% of your code. Drop to lower layers only for the 5% of operations that are performance-critical and not well-served by existing libraries. Most custom kernel work today is done in Triton rather than raw CUDA.
How PyTorch Dispatches Operations
When you call a PyTorch operation, the dispatcher determines which implementation to run based on the tensor's properties:
import torch
A = torch.randn(1024, 1024, device='cuda', dtype=torch.float16)
B = torch.randn(1024, 1024, device='cuda', dtype=torch.float16)
C = torch.matmul(A, B) # What happens here?
The dispatch path:
- Python binding -- torch.matmul calls into C++ via pybind11. The overhead is ~1-5 us per call.
- Dispatcher -- examines dispatch keys: device (CUDA), dtype (float16), layout (strided), autograd (requires_grad). Each key is a bit in a dispatch key set.
- Autograd wrapper -- if requires_grad=True, wraps the operation to save inputs for the backward pass and creates an MmBackward0 node in the autograd graph.
- Backend selection -- routes to at::cuda::matmul based on the CUDA dispatch key.
- cuBLAS call -- for float16 GEMM, calls cublasGemmEx with CUBLAS_COMPUTE_16F, which selects the best Tensor Core kernel for the given matrix dimensions.
- Kernel launch -- cuBLAS launches a tiled GEMM kernel with a grid/block configuration optimized for the matrix size and GPU architecture.
- Tensor Core execution -- each Tensor Core executes 16x8x16 FP16 matrix multiply-accumulate instructions, retiring 256 FMAs per clock.
Every step above runs again on every call: in eager mode, a model with thousands of small operations pays this dispatch overhead thousands of times per training step. This is exactly the problem torch.compile solves: it traces the computation graph once, eliminating per-operation dispatch overhead, and fuses multiple operations into a single kernel.
You can observe dispatch overhead with:
# Verbose TorchDynamo logs show what is traced and where graphs break
import logging
import torch._dynamo
torch._dynamo.config.log_level = logging.DEBUG

# Or profile to see per-operation dispatch and kernel time
with torch.autograd.profiler.profile() as prof:
    output = model(input)
print(prof.key_averages().table())
Highest priority                              Lowest priority
    |                                               |
    v                                               v
  Vmap → Autograd → Autocast → Functionalize → Backend (CPU/CUDA)
Each dispatch key can intercept the operation, do some work (e.g., autograd saves tensors for backward), and delegate to the next key. This layered design allows PyTorch to compose features (autograd + autocast + distributed) without combinatorial complexity.
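The delegation chain can be sketched as a toy dispatcher in plain Python. The handler names and structure here are illustrative inventions, not PyTorch's real internals:

```python
# Toy layered dispatch: each handler may do some work, then delegates to
# the next key in priority order, mirroring autograd -> backend.

def autograd_handler(op, args, next_dispatch):
    saved = args                           # "save tensors for backward"
    result = next_dispatch(op, args)
    return result, ("grad_fn", op, saved)  # attach a backward record

def backend_handler(op, args, next_dispatch):
    if op == "matmul":
        a, b = args
        # naive nested-list matmul standing in for the CUDA backend
        return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
                 for j in range(len(b[0]))] for i in range(len(a))]
    raise NotImplementedError(op)

HANDLERS = [autograd_handler, backend_handler]  # highest -> lowest priority

def dispatch(op, args, depth=0):
    delegate = lambda op, args: dispatch(op, args, depth + 1)
    return HANDLERS[depth](op, args, delegate)

out, grad_record = dispatch("matmul", ([[1, 2]], [[3], [4]]))
print(out)  # [[11]]
```

The point of the layering is that the autograd handler never needs to know which backend runs below it, which is how PyTorch composes features without combinatorial special cases.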
torch.compile and TorchInductor
PyTorch 2.0 introduced torch.compile, which captures the computation graph and optimizes it via TorchInductor:
import math
import torch

# Without compile: one CUDA kernel launch per elementwise operation
def gelu(x):
    return 0.5 * x * (1 + torch.tanh(
        math.sqrt(2 / math.pi) * (x + 0.044715 * x ** 3)
    ))
# With compile: 1 fused Triton kernel
@torch.compile
def gelu_compiled(x):
return 0.5 * x * (1 + torch.tanh(
math.sqrt(2 / math.pi) * (x + 0.044715 * x ** 3)
))
Python code → TorchDynamo → FX Graph → TorchInductor → Triton/C++ kernel
- TorchDynamo -- captures the Python computation graph by bytecode analysis. Inserts graph breaks where it encounters unsupported Python (e.g., data-dependent control flow, print statements, calls to non-PyTorch libraries).
- AOTAutograd -- traces the forward and backward pass ahead of time, producing a joint graph. This eliminates autograd overhead during execution.
- TorchInductor -- the backend compiler. It generates optimized Triton kernels for GPU or C++/OpenMP code for CPU. Key optimizations:
- Operator fusion: combines memory-bound operations into a single kernel
- Memory planning: reuses buffers to reduce allocation overhead
- Loop tiling: schedules computation to maximize cache utilization
- Automatic vectorization: uses SIMD instructions on CPU
| Mode | Compilation Time | Runtime Speed | Graph Breaks | Use Case |
|---|---|---|---|---|
| torch.compile(mode="default") | Moderate | Good | Tolerated | General training |
| torch.compile(mode="reduce-overhead") | Higher | Better | Less tolerated | Inference, small models |
| torch.compile(mode="max-autotune") | Very high | Best | Least tolerated | Production deployment |
| torch.compile(fullgraph=True) | Moderate | Good | Errors on break | Ensures full graph capture |
# BAD: causes graph break
@torch.compile
def forward(self, x):
x = self.layer1(x)
print(f"Shape: {x.shape}") # Graph break: print with tensor
x = self.layer2(x)
return x
# BAD: causes graph break
@torch.compile
def forward(self, x):
if x.sum() > 0: # Graph break: data-dependent control flow
return self.branch_a(x)
return self.branch_b(x)
# GOOD: no graph breaks
@torch.compile
def forward(self, x):
x = self.layer1(x)
x = self.layer2(x)
return x
Use torch._dynamo.explain(model)(input) to find graph breaks and their causes.
Operator Fusion in Practice
The biggest performance wins in ML systems come from fusing memory-bound operations to eliminate intermediate global memory round-trips:
import torch
import torch.nn.functional as F

# Unfused: 3 kernel launches, 3 global memory round-trips
# Each operation reads from and writes to HBM
def unfused(x, w, b, gamma, beta):
    y = F.linear(x, w, b)       # Read x, w, b from HBM; write y to HBM
    y = F.batch_norm(y, ...)    # Read y from HBM; write y to HBM
    y = F.relu(y)               # Read y from HBM; write y to HBM
    return y
# Total HBM traffic: ~6x the size of y (3 reads + 3 writes)

# Fused (via torch.compile or custom kernel): 1 kernel, 1 round-trip
# Intermediate results stay in registers/shared memory
@torch.compile
def fused(x, w, b, gamma, beta):
    y = F.linear(x, w, b)
    y = F.batch_norm(y, ...)
    y = F.relu(y)
    return y
# Total HBM traffic: ~2x the size of y (1 read + 1 write for the fused part)
# Note: the GEMM inside linear stays a separate compute-bound kernel;
# only the memory-bound normalization and activation are fused
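The traffic estimates in the comments can be made concrete with a toy model. The tensor size and the 3-op pointwise chain are illustrative assumptions:

```python
# Rough HBM-traffic model: each unfused pointwise kernel reads and writes
# the full activation; fusion keeps intermediates in registers.

def hbm_traffic_bytes(tensor_bytes: int, num_ops: int, fused: bool) -> int:
    if fused:
        return 2 * tensor_bytes          # one read in, one write out
    return 2 * tensor_bytes * num_ops    # read + write per kernel

y_bytes = 1024 * 1024 * 2                # a 1024x1024 fp16 activation
unfused_traffic = hbm_traffic_bytes(y_bytes, num_ops=3, fused=False)
fused_traffic = hbm_traffic_bytes(y_bytes, num_ops=3, fused=True)
print(unfused_traffic // fused_traffic)  # 3x less traffic for the chain
```

Since pointwise kernels are bandwidth-bound, the traffic ratio is a first-order estimate of the speedup on that portion of the workload.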
| Pattern | Operations | Speedup | Why It Helps |
|---|---|---|---|
| Linear + activation | matmul + ReLU/GELU/SiLU | 1.2-1.5x | Activation reads matmul output from registers |
| Attention (FlashAttention) | matmul + mask + softmax + matmul | 2-4x | Avoids materializing attention matrix |
| LayerNorm + residual + dropout | Add + mean + var + normalize + mask | 2-3x | Single pass over data instead of 5 |
| Pointwise chain | Any sequence of elementwise ops | 2-10x | Eliminates all intermediate HBM writes |
| Embedding + scale | Gather + multiply | 1.5-2x | Scale applied to gathered data in registers |
| SwiGLU | Linear + Swish + Gate + Linear | 1.3-1.5x | Gate and activation fused with first linear |
CUDA Graphs
For workloads with many small kernels, CUDA Graphs capture an entire sequence of kernel launches and replay them with a single API call, eliminating per-launch overhead:
# Capture a CUDA graph
static_input = torch.randn(batch_size, seq_len, d_model, device='cuda')
# Warmup (required to populate cuBLAS plans, etc.)
for _ in range(3):
output = model(static_input)
# Capture the graph
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
static_output = model(static_input)
# Replay: launches all kernels with one API call
static_input.copy_(real_input) # Update input in-place
graph.replay()
result = static_output.clone() # Read output
CUDA Graphs are most useful for inference (fixed shapes) and less useful for training (where gradient accumulation, gradient clipping, and optimizer steps may have dynamic behavior). torch.compile(mode="reduce-overhead") uses CUDA Graphs internally.
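The payoff is easy to estimate with launch-overhead arithmetic. The ~5 us per-launch figure below is an assumed, workload-dependent number, not a measured constant:

```python
# Illustrative launch-overhead arithmetic: a graph replay pays roughly one
# launch's CPU-side overhead for the entire captured sequence.

def launch_overhead_us(num_kernels: int, per_launch_us: float, use_graph: bool) -> float:
    return per_launch_us if use_graph else num_kernels * per_launch_us

eager = launch_overhead_us(num_kernels=200, per_launch_us=5.0, use_graph=False)
graphed = launch_overhead_us(num_kernels=200, per_launch_us=5.0, use_graph=True)
print(f"{eager:.0f} us vs {graphed:.0f} us per step")  # 1000 us vs 5 us
```

For a decoder step that itself takes only a few milliseconds, a milliseconds of launch overhead is the difference between GPU-bound and CPU-bound.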
BLAS: The Foundation of Numerical Computing
All matrix operations in deep learning ultimately call BLAS (Basic Linear Algebra Subprograms), a standardized interface first defined in 1979 that remains the performance backbone of scientific computing:
Level 1 (vector-vector): O(n) operations, O(n) data
saxpy y = αx + y -- backbone of gradient updates
sdot s = x·y -- similarity computation
snrm2 s = ||x||₂ -- gradient norms
Level 2 (matrix-vector): O(n²) operations, O(n²) data
sgemv y = αAx + βy -- single-sample inference
strsv x = A⁻¹b -- triangular solve
Level 3 (matrix-matrix): O(n³) operations, O(n²) data
sgemm C = αAB + βC -- THE operation of deep learning
ssyrk C = αAA^T + βC -- covariance computation
strsm X = A⁻¹B -- triangular solve with many right-hand sides
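The level distinction matters because arithmetic intensity (FLOPs per byte moved) determines whether an operation can ever be compute-bound. A quick calculation for saxpy versus sgemm, assuming fp32 (4 bytes/element):

```python
# Arithmetic intensity: Level 1 is flat in n (always memory-bound);
# Level 3 grows linearly with n, so large GEMMs can saturate compute.

def intensity_gemm(n: int, bytes_per_elem: int = 4) -> float:
    flops = 2 * n ** 3                   # n^3 multiply-adds
    data = 3 * n ** 2 * bytes_per_elem   # read A and B, write C
    return flops / data

def intensity_axpy(n: int, bytes_per_elem: int = 4) -> float:
    flops = 2 * n                        # n multiply-adds
    data = 3 * n * bytes_per_elem        # read x and y, write y
    return flops / data

print(intensity_axpy(4096))   # ~0.17 flops/byte at any size: memory-bound
print(intensity_gemm(4096))   # ~683 flops/byte: can be compute-bound
```

This is why "THE operation of deep learning" is a Level 3 call: it is the only BLAS level whose FLOP count grows faster than its data movement.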
The naming convention encodes the operation precisely:
| Position | Options | Meaning |
|---|---|---|
| 1st letter | s / d / c / z / h | float32 / float64 / complex64 / complex128 / float16 |
| 2nd-3rd | ge / sy / tr / di | general / symmetric / triangular / diagonal |
| 4th-5th | mm / mv / sv / r / rk | matrix-matrix / matrix-vector / solve / rank-1 update / rank-k update |
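The scheme is mechanical enough to decode programmatically. A small sketch covering only the prefixes listed in the table above (real BLAS has more):

```python
# Decode a BLAS routine name per the convention in the table.

DTYPES = {"s": "float32", "d": "float64", "c": "complex64",
          "z": "complex128", "h": "float16"}
SHAPES = {"ge": "general", "sy": "symmetric", "tr": "triangular",
          "di": "diagonal"}
OPS = {"mm": "matrix-matrix", "mv": "matrix-vector", "sv": "solve",
       "r": "rank-1 update", "rk": "rank-k update"}

def decode_blas(name: str) -> str:
    dtype = DTYPES[name[0]]       # 1st letter: element type
    shape = SHAPES[name[1:3]]     # 2nd-3rd: matrix structure
    op = OPS[name[3:]]            # remainder: operation
    return f"{dtype} {shape} {op}"

print(decode_blas("sgemm"))  # float32 general matrix-matrix
print(decode_blas("dtrsv"))  # float64 triangular solve
```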
| Library | Hardware | Key Feature |
|---|---|---|
| cuBLAS | NVIDIA GPU | Tensor Core support, auto-tuning per matrix size |
| cuBLASLt | NVIDIA GPU | More flexible API, FP8 support, epilogue fusion |
| Intel MKL | Intel CPU | AVX-512 optimized, multi-threaded |
| OpenBLAS | CPU (any) | Open source, portable |
| BLIS | CPU (any) | Modern, cache-friendly, open source |
| rocBLAS | AMD GPU | HIP-based, MI300X support |
In PyTorch, setting torch.backends.cuda.matmul.allow_tf32 = True (off by default for matmul since PyTorch 1.12) enables TF32 Tensor Cores for FP32 GEMMs, giving up to ~8x theoretical throughput over FP32 CUDA cores while keeping roughly 3 significant decimal digits of precision (a 10-bit mantissa instead of float32's 23 bits).
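TF32's precision cost can be simulated by truncating a float32 mantissa to 10 bits. This is a sketch of the rounding behavior, not cuBLAS's actual implementation (hardware rounds rather than truncates):

```python
# Simulate TF32 precision: zero the low 13 of float32's 23 mantissa bits,
# leaving the 10-bit TF32 mantissa.

import struct

def to_tf32(x: float) -> float:
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits &= ~0x1FFF            # drop 13 low mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]

x = 1.2345678
tf32 = to_tf32(x)
rel_err = abs(x - tf32) / x
print(f"{tf32:.7f}, relative error {rel_err:.1e}")  # error below 2**-10
```

A relative error bounded by 2^-10 (~1e-3) is harmless for most forward/backward passes, which is why TF32 is a popular default for FP32 training on Ampere and later GPUs.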
# A kernel-level profile shows which regime a workload is in:
# Healthy profile (compute-bound):
cublasGemmEx 45% ← Good: most time in GEMM
flash_attn_kernel 25% ← Good: fused attention
elementwise_kernel 10% ← Acceptable: fused pointwise
nccl_allreduce 8% ← Acceptable: communication
other 12%
# Unhealthy profile (overhead-bound):
cudaLaunchKernel 20% ← Bad: launch overhead dominates
cudaMemcpyAsync 15% ← Bad: too many memory copies
nccl_allreduce 30% ← Bad: communication bottleneck
many_tiny_kernels 25% ← Bad: not fused
cublasGemmEx 10% ← Problem: compute is minority
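The healthy/unhealthy distinction above can be turned into a crude automatic check. The kernel names and the comparison rule are heuristic assumptions for illustration, not a real profiler API:

```python
# Toy profile classifier: if launch overhead, memcpys, and unfused tiny
# kernels eat more time than GEMM plus fused kernels, the workload is
# overhead-bound.

OVERHEAD = {"cudaLaunchKernel", "cudaMemcpyAsync", "many_tiny_kernels"}
COMPUTE = {"cublasGemmEx", "flash_attn_kernel"}

def classify(profile: dict) -> str:
    overhead = sum(v for k, v in profile.items() if k in OVERHEAD)
    compute = sum(v for k, v in profile.items() if k in COMPUTE)
    return "compute-bound" if compute > overhead else "overhead-bound"

healthy = {"cublasGemmEx": 45, "flash_attn_kernel": 25,
           "elementwise_kernel": 10, "nccl_allreduce": 8, "other": 12}
unhealthy = {"cudaLaunchKernel": 20, "cudaMemcpyAsync": 15,
             "nccl_allreduce": 30, "many_tiny_kernels": 25,
             "cublasGemmEx": 10}
print(classify(healthy), classify(unhealthy))
```

In practice you would feed this from prof.key_averages() output; the point is that "where does the time go" reduces to a handful of bucket sums.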