From Silicon to PyTorch
Every call to torch.matmul(A, B) triggers a chain of events spanning multiple layers of abstraction -- from Python down to transistors. Understanding this stack explains why certain operations are fast, why operator fusion matters, and how to reason about performance.
The Abstraction Stack
| Layer | Example | What Happens |
|---|---|---|
| Python API | torch.matmul(A, B) | User-facing call; type-checks, dispatches |
| ATen Dispatcher | Operator dispatch by dtype, device, autograd | Routes to correct backend implementation |
| Kernel Library | cuBLAS cublasGemmEx | Optimized BLAS kernel selected by matrix size and dtype |
| CUDA Runtime | Grid/block/thread scheduling | Maps logical threads to physical SMs |
| PTX / SASS | GPU assembly instructions | Register allocation, instruction scheduling |
| Hardware | Tensor Cores, CUDA Cores, memory controllers | Actual silicon execution |
- Python: `torch.matmul` is a Python function that calls into C++ via `torch._C`
- Dispatcher: The ATen dispatcher checks the call: dtype is `float16`, device is `cuda`, `requires_grad=True` on one input -- and routes to the autograd-wrapped CUDA kernel
- cuBLAS: For matrices of shape $M \times K$ and $K \times N$, cuBLAS selects a tile size and algorithm (e.g., `CUBLAS_GEMM_DEFAULT_TENSOR_OP`) that maps to Tensor Cores
- Tensor Cores: Execute $4 \times 4 \times 4$ (or, at the warp level, $16 \times 16 \times 16$) matrix multiply-accumulate operations in a single clock cycle
- Result: Written back through the memory hierarchy to HBM, wrapped in a new PyTorch tensor with a backward function registered for autograd
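The dispatch step can be illustrated with a toy sketch in pure Python. All names here are hypothetical -- the real ATen dispatcher is C++ and keys on far more than `(device, dtype)` (autograd state, layouts, backend fallbacks) -- but the routing idea is the same:

```python
# Toy model of operator dispatch: route a call to a backend
# implementation keyed by (op, device, dtype), loosely mimicking
# what the ATen dispatcher does for torch.matmul.
KERNELS = {}

def register(op, device, dtype):
    def wrap(fn):
        KERNELS[(op, device, dtype)] = fn
        return fn
    return wrap

@register("matmul", "cpu", "float32")
def matmul_cpu_f32(a, b):
    # Naive triple loop standing in for an OpenBLAS/MKL sgemm call
    rows, inner, cols = len(a), len(b), len(b[0])
    return [[sum(a[i][k] * b[k][j] for k in range(inner))
             for j in range(cols)] for i in range(rows)]

def dispatch(op, device, dtype, *args):
    try:
        return KERNELS[(op, device, dtype)](*args)
    except KeyError:
        raise NotImplementedError(f"no kernel for {op} on {device}/{dtype}")

result = dispatch("matmul", "cpu", "float32", [[1.0, 2.0]], [[3.0], [4.0]])
```

A call with an unregistered `(device, dtype)` pair fails loudly, which mirrors the "could not run ... with arguments from the ... backend" errors PyTorch raises when no kernel matches.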
BLAS: The Computational Backbone
| Level | Operation Type | Complexity | FLOPs/Word | Example | Library Call |
|---|---|---|---|---|---|
| 1 | Vector-vector | $O(n)$ | $O(1)$ | $y \leftarrow \alpha x + y$ | axpy |
| 2 | Matrix-vector | $O(n^2)$ | $O(1)$ | $y \leftarrow \alpha A x + \beta y$ | gemv |
| 3 | Matrix-matrix | $O(n^3)$ | $O(n)$ | $C \leftarrow \alpha A B + \beta C$ | gemm |
Level 3 operations have arithmetic intensity that grows with matrix size, making them compute-bound on modern hardware. Level 1 and 2 operations are always memory-bound because each element is touched only $O(1)$ times.
GEMM computes $C \leftarrow \alpha\,\mathrm{op}(A)\,\mathrm{op}(B) + \beta C$, where $A \in \mathbb{R}^{M \times K}$, $B \in \mathbb{R}^{K \times N}$, $C \in \mathbb{R}^{M \times N}$, and $\mathrm{op}(\cdot)$ is identity or transpose. The total FLOP count is $2MNK$ (each output element requires $K$ multiply-adds).
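The $2MNK$ count and the Level-3 intensity claim can be checked numerically. This is a sketch under simplifying assumptions: each matrix is moved exactly once (no cache reuse, no tiling) and the word size is a parameter:

```python
def gemm_flops(m, n, k):
    # Each of the M*N output elements needs K multiply-adds = 2K FLOPs
    return 2 * m * n * k

def arithmetic_intensity(m, n, k, bytes_per_word=2):
    # FLOPs per byte moved, assuming A and B are read once and C written once
    traffic = (m * k + k * n + m * n) * bytes_per_word
    return gemm_flops(m, n, k) / traffic

# For square matrices of side n, intensity is 2n^3 / (3n^2 words):
# it grows linearly with n, which is why large GEMMs become compute-bound.
square_intensity = arithmetic_intensity(1024, 1024, 1024)
```

By contrast, an axpy moves $3n$ words for $2n$ FLOPs no matter how large $n$ gets, which is the Level-1 memory-bound case from the table.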
LAPACK (Linear Algebra PACKage) builds on BLAS to provide higher-level operations: eigenvalue decomposition, SVD, Cholesky factorization, QR decomposition, and linear system solvers. On GPUs, NVIDIA provides cuSOLVER as the LAPACK equivalent.
| Library | Platform | Notes |
|---|---|---|
| OpenBLAS | CPU (multi-platform) | Open-source, hand-tuned assembly kernels |
| MKL | CPU (Intel) | Intel-optimized, often fastest on Intel hardware |
| cuBLAS | GPU (NVIDIA) | Exploits Tensor Cores for mixed-precision GEMM |
| cuSOLVER | GPU (NVIDIA) | GPU LAPACK equivalent |
| CUTLASS | GPU (NVIDIA) | Template-based, customizable GEMM kernels |
| Triton | GPU (NVIDIA) | Python DSL for writing fused GPU kernels |
Neural Network Operations as Linear Algebra
Almost every neural network computation reduces to BLAS calls, which is why optimizing GEMM directly translates to faster training and inference.
| NN Operation | Mathematical Form | BLAS Call | Notes |
|---|---|---|---|
| Linear layer (single input) | $y = Wx + b$ | GEMV | Memory-bound for small batch sizes |
| Linear layer (batch) | $Y = XW^\top + b$ | GEMM | Compute-bound for large batches |
| Attention scores | $S = QK^\top / \sqrt{d_k}$ | GEMM + element-wise | FlashAttention fuses softmax |
| Attention output | $O = \mathrm{softmax}(S)V$ | GEMM | |
| 2D convolution | $Y = W \cdot \mathrm{im2col}(X)$ | GEMM | im2col unrolls patches into columns |
| Depthwise convolution | Per-channel ops | Specialized kernel | Does not map well to GEMM |
| Multi-head attention | Batched $QK^\top$, $\mathrm{softmax}(S)V$ | Batched GEMM | torch.baddbmm |
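The im2col trick from the table can be sketched in NumPy. This is a toy version -- single channel, stride 1, no padding -- whereas production libraries like cuDNN typically use tuned implicit-GEMM kernels that never materialize the unrolled matrix:

```python
import numpy as np

def im2col(x, kh, kw):
    # Unroll every kh x kw patch of a 2D input into one column
    h, w = x.shape
    cols = [x[i:i + kh, j:j + kw].ravel()
            for i in range(h - kh + 1)
            for j in range(w - kw + 1)]
    return np.stack(cols, axis=1)  # shape: (kh*kw, out_h*out_w)

def conv2d_as_gemm(x, kernel):
    # Convolution (cross-correlation) becomes one matrix-vector product:
    # the flattened kernel times the unrolled patch matrix
    kh, kw = kernel.shape
    out_h, out_w = x.shape[0] - kh + 1, x.shape[1] - kw + 1
    return (kernel.ravel() @ im2col(x, kh, kw)).reshape(out_h, out_w)
```

The memory cost is visible here: im2col duplicates each input element up to $kh \cdot kw$ times, which is the price paid to reuse the highly optimized GEMM path.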
Operator Fusion
Consider the compound operation $y = \mathrm{ReLU}(\mathrm{BatchNorm}(Wx))$:
Without fusion (3 separate kernels):
- GEMM kernel: compute $Z = Wx$, write $Z$ to HBM
- BatchNorm kernel: read $Z$ from HBM, compute $\hat{Z} = \mathrm{BN}(Z)$, write $\hat{Z}$ to HBM
- ReLU kernel: read $\hat{Z}$ from HBM, compute $y = \max(0, \hat{Z})$, write $y$ to HBM
Total memory traffic: roughly $6n$ reads and writes to HBM (where $n$ is the output size), since each kernel reads its input from HBM and writes its result back.
With fusion (1 kernel):
Total memory traffic: $\sim 2n$ (read the inputs, write the final outputs), a roughly 3x reduction compared with the three separate kernels, since the intermediates $Z$ and $\hat{Z}$ live only in registers.
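In pure Python, fusion is simply doing all three steps per element in one pass instead of materializing intermediates. This toy sketch uses a scalar weight and fixed BatchNorm statistics (hypothetical simplifications standing in for the real per-batch computation):

```python
import math

def unfused(x, w, mean=0.0, var=1.0, eps=1e-5):
    # Three passes: each materializes a full-size intermediate list,
    # standing in for a kernel that round-trips through HBM
    z = [w * v for v in x]                                 # "GEMM" (toy 1x1 weight)
    zhat = [(v - mean) / math.sqrt(var + eps) for v in z]  # BatchNorm (fixed stats)
    return [max(0.0, v) for v in zhat]                     # ReLU

def fused(x, w, mean=0.0, var=1.0, eps=1e-5):
    # One pass: each element is read once, transformed, written once
    return [max(0.0, (w * v - mean) / math.sqrt(var + eps)) for v in x]
```

Both produce identical results; only the number of passes over memory differs, which is exactly what a fused GPU kernel saves.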
FlashAttention pushes this further, fusing the entire attention computation into one kernel:
- Load tiles of $Q$, $K$, $V$ into SRAM (shared memory)
- Compute partial attention scores and outputs in SRAM
- Use online softmax (tracking running max and sum) to combine tiles
- Write only the final output to HBM
Result: memory usage drops from $O(N^2)$ to $O(N)$ for sequence length $N$, and wall-clock time improves 2-4x for long sequences despite performing more FLOPs (due to recomputation in the backward pass). This is a striking example of trading compute for memory bandwidth.
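The online-softmax trick in step 3 fits in a few lines. This is a sketch of the running-max/running-sum recurrence over a stream of chunks; the real kernel applies it per tile, vectorized across the sequence dimension:

```python
import math

def online_softmax(chunks):
    # One streaming pass over chunks of scores: maintain the running
    # max m and the running sum s of exp(x - m), rescaling s whenever
    # a new chunk raises the max. Never materializes the full row.
    m = float("-inf")
    s = 0.0
    for chunk in chunks:
        new_m = max(m, max(chunk))
        s = s * math.exp(m - new_m) + sum(math.exp(x - new_m) for x in chunk)
        m = new_m
    return m, s  # softmax(x_i) = exp(x_i - m) / s
```

Because each tile only needs `(m, s)` from the previous tiles, the $N \times N$ score matrix never has to exist in HBM -- which is where the $O(N^2) \to O(N)$ memory saving comes from.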
| Fused Operation | Separate Ops | Speedup | Framework Support |
|---|---|---|---|
| GEMM + bias + ReLU | 3 kernels | ~2x | cuBLAS, cuDNN |
| LayerNorm + residual | 3 kernels | ~2-3x | Apex, Triton |
| Softmax + mask + dropout | 3 kernels | ~2x | FlashAttention, xFormers |
| GEMM + GELU | 2 kernels | ~1.5x | Megatron-LM |
| Full attention block | ~10 kernels | ~2-4x | FlashAttention |
Torch Compilation
torch.compile (PyTorch 2.x) automates the fusion patterns above: TorchDynamo captures the Python program as an FX graph, and the default TorchInductor backend generates fused Triton kernels for GPU (or C++/OpenMP code for CPU), so many of the hand-written fusions in the table are applied automatically.
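A minimal sketch of the entry point, assuming PyTorch 2.0 or later is installed. `backend="eager"` is chosen here so the example runs without a compiler toolchain; the default `"inductor"` backend is what actually emits fused kernels:

```python
import torch

def bias_gelu(x, w, b):
    # GEMM + bias + GELU: a standard epilogue-fusion candidate
    return torch.nn.functional.gelu(x @ w + b)

# TorchDynamo captures the function as an FX graph and hands it to a
# backend; "eager" simply replays the captured graph (useful for
# debugging), while the default "inductor" backend generates fused
# Triton (GPU) or C++ (CPU) kernels.
compiled = torch.compile(bias_gelu, backend="eager")

x, w, b = torch.randn(32, 64), torch.randn(64, 16), torch.randn(16)
out = compiled(x, w, b)
```

The compiled function is numerically equivalent to the eager one; the difference shows up as fewer kernel launches and less HBM traffic when the real codegen backend is used.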
Notation Summary
| Symbol | Meaning |
|---|---|
| BLAS | Basic Linear Algebra Subprograms |
| GEMM | General Matrix Multiply: $C \leftarrow \alpha AB + \beta C$ |
| GEMV | General Matrix-Vector Multiply |
| LAPACK | Linear Algebra Package |
| cuBLAS | NVIDIA GPU BLAS library |
| cuSOLVER | NVIDIA GPU LAPACK library |
| CUTLASS | NVIDIA template GEMM library |
| PTX | Parallel Thread Execution (GPU intermediate representation) |
| SASS | GPU machine code (hardware-specific) |
| HBM | High Bandwidth Memory |
| SRAM | Static RAM (shared memory, L1 cache) |
| $\alpha, \beta$ | Scalar coefficients in GEMM |