From Silicon to PyTorch
Every call to torch.matmul(A, B) triggers a chain of events spanning multiple layers of abstraction, from Python down to transistors. Understanding this stack explains why certain operations are fast, why operator fusion matters, and how to reason about performance.
The Abstraction Stack
| Layer | Example | What Happens |
|---|---|---|
| Python API | torch.matmul(A, B) | User-facing call; type-checks, dispatches |
| ATen Dispatcher | Operator dispatch by dtype, device, autograd | Routes to correct backend implementation |
| Kernel Library | cuBLAS cublasGemmEx | Optimized BLAS kernel selected by matrix size and dtype |
| CUDA Runtime | Grid/block/thread scheduling | Maps logical threads to physical SMs |
| PTX / SASS | GPU assembly instructions | Register allocation, instruction scheduling |
| Hardware | Tensor Cores, CUDA Cores, memory controllers | Actual silicon execution |
- Python: `torch.matmul` is a Python function that calls into C++ via `torch._C`
- Dispatcher: The ATen dispatcher checks that dtype is `float16`, device is `cuda`, and `requires_grad=True` on one input, then routes to the autograd-wrapped CUDA kernel
- cuBLAS: For the given matrix shapes, cuBLAS selects a tile size and algorithm (e.g., `CUBLAS_GEMM_DEFAULT_TENSOR_OP`) that maps to Tensor Cores
- Tensor Cores: Execute small fixed-size matrix multiply-accumulate operations in a single clock cycle (warp-level WMMA instructions aggregate multiple Tensor Core operations across several cycles)
- Result: Written back through the memory hierarchy to HBM, wrapped in a new PyTorch tensor with a backward function registered for autograd
BLAS: The Computational Backbone
| Level | Operation Type | Complexity | FLOPs/Word | Example | Library Call |
|---|---|---|---|---|---|
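| 1 | Vector-vector | $O(n)$ | $O(1)$ | $y \leftarrow \alpha x + y$ | axpy |
| 2 | Matrix-vector | $O(n^2)$ | $O(1)$ | $y \leftarrow \alpha A x + \beta y$ | gemv |
| 3 | Matrix-matrix | $O(n^3)$ | $O(n)$ | $C \leftarrow \alpha A B + \beta C$ | gemm |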
Level 3 operations have arithmetic intensity that grows with matrix size, making them compute-bound on modern hardware once the matrices are large enough. Level 1 and 2 operations are always memory-bound because each element is touched only $O(1)$ times.
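The difference between the three BLAS levels can be checked with a few lines of arithmetic. The sketch below is an idealized model, not a measurement: it assumes square operands of size n, counts one multiply and one add as two FLOPs, and assumes every operand moves between memory and the processor exactly once (perfect reuse). The function names are illustrative.

```python
def axpy_intensity(n):
    """FLOPs per word for y <- a*x + y on length-n vectors."""
    flops = 2 * n              # one multiply and one add per element
    words = 3 * n              # read x, read y, write y
    return flops / words


def gemv_intensity(n):
    """FLOPs per word for y <- A x with a square n x n matrix."""
    flops = 2 * n * n          # n multiply-adds per output element
    words = n * n + 2 * n      # read A, read x, write y
    return flops / words


def gemm_intensity(n):
    """FLOPs per word for C <- A B with square n x n matrices."""
    flops = 2 * n ** 3         # n multiply-adds per output element
    words = 3 * n * n          # read A, read B, write C (ideal reuse)
    return flops / words


for n in (64, 1024, 8192):
    print(n,
          round(axpy_intensity(n), 2),
          round(gemv_intensity(n), 2),
          round(gemm_intensity(n), 2))
```

Under these assumptions the Level 1 and Level 2 ratios stay below 2 FLOPs per word no matter how large n gets, while the Level 3 ratio is 2n/3 and keeps growing.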
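The general matrix multiply (GEMM) computes

$$C \leftarrow \alpha \, \mathrm{op}(A) \, \mathrm{op}(B) + \beta C$$

where $\mathrm{op}(A) \in \mathbb{R}^{m \times k}$, $\mathrm{op}(B) \in \mathbb{R}^{k \times n}$, $C \in \mathbb{R}^{m \times n}$, and $\mathrm{op}$ is identity or transpose. The total FLOP count is $2mnk$ (each of the $mn$ output elements requires $k$ multiply-adds).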
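To make the GEMM definition and its FLOP count concrete, here is a minimal reference implementation on nested Python lists that also counts multiply-adds. It is a sketch for illustration only: the function name and the `trans_a`/`trans_b` flags are assumptions of this example, not the BLAS calling convention, and real kernels tile these loops for cache and register reuse.

```python
def gemm(alpha, A, B, beta, C, trans_a=False, trans_b=False):
    """Reference C <- alpha * op(A) @ op(B) + beta * C on nested lists.

    Returns (result, multiply_add_count).
    """
    def op(M, trans):
        # op is either the identity or the transpose
        return [list(row) for row in zip(*M)] if trans else M

    A, B = op(A, trans_a), op(B, trans_b)
    m, k, n = len(A), len(B), len(B[0])
    assert len(A[0]) == k, "inner dimensions must match"

    out = [[0.0] * n for _ in range(m)]
    fmas = 0
    for i in range(m):
        for j in range(n):
            acc = 0.0
            for p in range(k):
                acc += A[i][p] * B[p][j]   # one multiply-add
                fmas += 1
            out[i][j] = alpha * acc + beta * C[i][j]
    return out, fmas


A = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]       # 2 x 3
B = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]     # 3 x 2
C = [[1.0, 1.0], [1.0, 1.0]]                 # 2 x 2
result, fmas = gemm(2.0, A, B, 0.5, C)
print(result)      # [[8.5, 10.5], [20.5, 22.5]]
print(2 * fmas)    # 24 FLOPs, i.e. 2 * m * n * k with m=2, n=2, k=3
```

The inner loop runs once per multiply-add, so the counter ends at m·n·k and the FLOP total is twice that.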
LAPACK (Linear Algebra PACKage) builds on BLAS to provide higher-level operations: eigenvalue decomposition, SVD, Cholesky factorization, QR decomposition, and linear system solvers. On GPUs, NVIDIA provides cuSOLVER as the LAPACK equivalent.
| Library | Platform | Notes |
|---|---|---|
| OpenBLAS | CPU (multi-platform) | Open-source, hand-tuned assembly kernels |
| MKL | CPU (Intel) | Intel-optimized, often fastest on Intel hardware |
| cuBLAS | GPU (NVIDIA) | Exploits Tensor Cores for mixed-precision GEMM |
| cuSOLVER | GPU (NVIDIA) | GPU LAPACK equivalent |
| CUTLASS | GPU (NVIDIA) | Template-based, customizable GEMM kernels |
| Triton | GPU (NVIDIA) | Python DSL for writing fused GPU kernels |
Neural Network Operations as Linear Algebra
Almost every neural network computation reduces to BLAS calls, which is why optimizing GEMM directly translates to faster training and inference.
| NN Operation | Mathematical Form | BLAS Call | Notes |
|---|---|---|---|
| Linear layer (single input) | $y = Wx + b$ | GEMV | Memory-bound |
| Linear layer (batch) | $Y = XW^\top + b$ | GEMM | Compute-bound for large batches |
| Attention scores | $\mathrm{softmax}(QK^\top / \sqrt{d_k})$ | GEMM + element-wise | FlashAttention fuses softmax |
| Attention output | $PV$ | GEMM | $P$ is the softmax output |
| 2D convolution | $Y = W \cdot \mathrm{im2col}(X)$ | GEMM | im2col unrolls patches into columns |
| Depthwise convolution | Per-channel ops | Specialized kernel | Does not map well to GEMM |
| Multi-head attention | Batched $QK^\top$, $PV$ | Batched GEMM | torch.baddbmm |
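The im2col row of the table can be illustrated with a small pure-Python sketch. It assumes a single-channel image, square filters, stride 1, and no padding, and it computes cross-correlation (no filter flip), as deep learning frameworks do. All function names here are illustrative; production libraries use far more elaborate layouts and often avoid materializing the column matrix.

```python
def im2col(image, k):
    """Unroll every k x k patch of a 2D image into one column."""
    h, w = len(image), len(image[0])
    out_h, out_w = h - k + 1, w - k + 1
    cols = [[0.0] * (out_h * out_w) for _ in range(k * k)]
    for i in range(out_h):
        for j in range(out_w):
            col = i * out_w + j
            for di in range(k):
                for dj in range(k):
                    cols[di * k + dj][col] = image[i + di][j + dj]
    return cols, out_h, out_w


def matmul(A, B):
    """Plain matrix product on nested lists (stands in for GEMM)."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)]
            for row in A]


def conv2d_gemm(image, filters):
    """Convolution as one matrix product: (F x k*k) @ (k*k x patches)."""
    k = len(filters[0])
    cols, out_h, out_w = im2col(image, k)
    weight = [[v for row in f for v in row] for f in filters]
    flat = matmul(weight, cols)
    # Reshape each filter's row of outputs back into an out_h x out_w map
    return [[row[i * out_w:(i + 1) * out_w] for i in range(out_h)]
            for row in flat]


def conv2d_direct(image, filters):
    """Sliding-window reference used to check the GEMM formulation."""
    k = len(filters[0])
    out_h, out_w = len(image) - k + 1, len(image[0]) - k + 1
    return [[[sum(f[di][dj] * image[i + di][j + dj]
                  for di in range(k) for dj in range(k))
              for j in range(out_w)]
             for i in range(out_h)]
            for f in filters]


image = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
filters = [[[1, 0], [0, 1]], [[1, 1], [1, 1]]]
print(conv2d_gemm(image, filters))
```

The trade-off is visible in `im2col`: overlapping patches are copied, so the column matrix holds each input pixel up to k·k times in exchange for turning the whole convolution into a single large matrix product.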
Operator Fusion
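Consider the compound operation $\mathrm{ReLU}(\mathrm{BatchNorm}(XW))$:
Without fusion (3 separate kernels):
- GEMM kernel: compute $Y = XW$, write $Y$ to HBM
- BatchNorm kernel: read $Y$ from HBM, compute $Z = \mathrm{BatchNorm}(Y)$, write $Z$ to HBM
- ReLU kernel: read $Z$ from HBM, compute $\mathrm{ReLU}(Z)$, write the result to HBM
Total memory traffic: roughly $6N$ reads/writes to HBM (where $N$ is the output size), counting one read and one write of about $N$ elements per kernel.
With fusion (1 kernel):
Total memory traffic: roughly $2N$ (read inputs, write outputs). This is about a $3\times$ reduction in memory traffic.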
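The traffic comparison above can be turned into a back-of-the-envelope estimate. The sketch below assumes the simple model in which every kernel reads about N words and writes N words. The output shape, the 2-byte word size, and the 1000 GB/s bandwidth are hypothetical values chosen for illustration, not figures for any particular GPU.

```python
def hbm_traffic(num_kernels, n_out):
    """Words moved if every kernel reads ~n_out words and writes n_out words."""
    return num_kernels * 2 * n_out


def transfer_time_ms(words, bytes_per_word=2, bandwidth_gb_s=1000.0):
    """Time to move the data at an assumed (hypothetical) HBM bandwidth."""
    return words * bytes_per_word / (bandwidth_gb_s * 1e9) * 1e3


n_out = 4096 * 4096                    # hypothetical output size N
unfused = hbm_traffic(3, n_out)        # GEMM, BatchNorm, ReLU as separate kernels
fused = hbm_traffic(1, n_out)          # one kernel: read inputs, write outputs

print(unfused // n_out, fused // n_out)   # traffic in multiples of N
print(unfused / fused)                    # reduction factor
print(round(transfer_time_ms(unfused), 3), round(transfer_time_ms(fused), 3))
```

Because the element-wise stages do almost no arithmetic, their cost is dominated by this transfer time, which is why removing the intermediate round trips to HBM pays off.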
FlashAttention applies the same idea to the full attention computation:
- Load tiles of $Q$, $K$, $V$ into SRAM (shared memory)
- Compute partial attention scores and outputs in SRAM
- Use online softmax (tracking running max and sum) to combine tiles
- Write only the final output to HBM
Result: memory usage drops from $O(n^2)$ to $O(n)$ in the sequence length $n$, and wall-clock time improves 2-4x for long sequences despite performing more FLOPs (due to recomputation in the backward pass). This is a striking example of trading compute for memory bandwidth.
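The online softmax step is the part that makes tiling possible, so it is worth seeing in isolation. The following is a minimal sketch for a single query row in pure Python, not the FlashAttention kernel itself: it omits the usual score scaling, masking, and dropout, and the function names and tile size are illustrative. It processes keys and values tile by tile while tracking a running max and running sum, then checks the result against a reference that materializes all scores.

```python
import math


def attention_row_reference(q, K, V):
    """Softmax attention for one query, materializing every score."""
    scores = [sum(a * b for a, b in zip(q, k)) for k in K]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[c] for w, v in zip(weights, V)) / total
            for c in range(len(V[0]))]


def attention_row_tiled(q, K, V, tile=2):
    """Same result, visiting K and V one tile at a time."""
    running_max = float("-inf")
    running_sum = 0.0
    acc = [0.0] * len(V[0])               # unnormalized output so far
    for start in range(0, len(K), tile):
        k_tile = K[start:start + tile]
        v_tile = V[start:start + tile]
        scores = [sum(a * b for a, b in zip(q, k)) for k in k_tile]
        new_max = max(running_max, max(scores))
        # Rescale earlier partial results to the new running max
        scale = math.exp(running_max - new_max)
        weights = [math.exp(s - new_max) for s in scores]
        running_sum = running_sum * scale + sum(weights)
        acc = [a * scale + sum(w * v[c] for w, v in zip(weights, v_tile))
               for c, a in enumerate(acc)]
        running_max = new_max
    return [a / running_sum for a in acc]


q = [1.0, 0.5]
K = [[0.2, 0.1], [1.5, -0.3], [0.0, 2.0], [-1.0, 0.4], [3.0, 1.0]]
V = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 0.5], [0.5, -0.5]]
print(attention_row_reference(q, K, V))
print(attention_row_tiled(q, K, V, tile=2))
```

Only one tile of scores exists at any time, which is what lets the full score matrix stay out of HBM. The two printed rows should agree up to floating-point rounding.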
| Fused Operation | Separate Ops | Speedup | Framework Support |
|---|---|---|---|
| GEMM + bias + ReLU | 3 kernels | ~2x | cuBLAS, cuDNN |
| LayerNorm + residual | 3 kernels | ~2-3x | Apex, Triton |
| Softmax + mask + dropout | 3 kernels | ~2x | FlashAttention, xFormers |
| GEMM + GELU | 2 kernels | ~1.5x | Megatron-LM |
| Full attention block | ~10 kernels | ~2-4x | FlashAttention |
Torch Compilation
Notation Summary
| Symbol | Meaning |
|---|---|
| BLAS | Basic Linear Algebra Subprograms |
| GEMM | General Matrix Multiply: $C \leftarrow \alpha \, \mathrm{op}(A) \, \mathrm{op}(B) + \beta C$ |
| GEMV | General Matrix-Vector Multiply |
| LAPACK | Linear Algebra Package |
| cuBLAS | NVIDIA GPU BLAS library |
| cuSOLVER | NVIDIA GPU LAPACK library |
| CUTLASS | NVIDIA template GEMM library |
| PTX | Parallel Thread Execution (GPU intermediate representation) |
| SASS | GPU machine code (hardware-specific) |
| HBM | High Bandwidth Memory |
| SRAM | Static RAM (shared memory, L1 cache) |
| $\alpha, \beta$ | Scalar coefficients in GEMM |