
From Silicon to PyTorch

Every call to torch.matmul(A, B) triggers a chain of events spanning multiple layers of abstraction -- from Python down to transistors. Understanding this stack explains why certain operations are fast, why operator fusion matters, and how to reason about performance.

The Abstraction Stack

| Layer | Example | What Happens |
|---|---|---|
| Python API | `torch.matmul(A, B)` | User-facing call; type-checks, dispatches |
| ATen Dispatcher | Operator dispatch by dtype, device, autograd | Routes to the correct backend implementation |
| Kernel Library | cuBLAS `cublasGemmEx` | Optimized BLAS kernel selected by matrix size and dtype |
| CUDA Runtime | Grid/block/thread scheduling | Maps logical threads to physical SMs |
| PTX / SASS | GPU assembly instructions | Register allocation, instruction scheduling |
| Hardware | Tensor Cores, CUDA Cores, memory controllers | Actual silicon execution |

Each layer trades generality for performance. The user writes generic Python; the system compiles it down to hardware-specific instructions that exploit the exact memory layout, data type, and compute units available. Performance bugs usually live at one specific layer -- understanding the stack tells you where to look.

**Tracing a matmul call.** When you call `torch.matmul(A, B)` where both tensors are `float16` on `cuda:0`:
  1. Python: `torch.matmul` is a Python function that calls into C++ via `torch._C`
  2. Dispatcher: the ATen dispatcher checks: dtype is `float16`, device is CUDA, `requires_grad=True` on one input -- routes to the autograd-wrapped CUDA kernel
  3. cuBLAS: for matrices of shape $(m, k) \times (k, n)$, cuBLAS selects a tile size and algorithm (e.g., `CUBLAS_GEMM_DEFAULT_TENSOR_OP`) that maps to Tensor Cores
  4. Tensor Cores: execute $4 \times 4 \times 4$ (or $16 \times 16 \times 16$) matrix multiply-accumulate operations in a single clock cycle
  5. Result: written back through the memory hierarchy to HBM, wrapped in a new PyTorch tensor with a backward function registered for autograd

BLAS: The Computational Backbone

**BLAS** is a standardized interface for basic linear algebra operations, organized into three levels by computational intensity. Originally specified in Fortran in the 1970s, it remains the foundation of all numerical computing libraries.
| Level | Operation Type | Complexity | FLOPs/Word | Example | Library Call |
|---|---|---|---|---|---|
| 1 | Vector-vector | $O(n)$ | $O(1)$ | $y \leftarrow \alpha x + y$ | `axpy` |
| 2 | Matrix-vector | $O(n^2)$ | $O(1)$ | $y \leftarrow Ax + y$ | `gemv` |
| 3 | Matrix-matrix | $O(n^3)$ | $O(n)$ | $C \leftarrow \alpha AB + \beta C$ | `gemm` |

Level 3 operations have arithmetic intensity that grows with matrix size, making them compute-bound on modern hardware. Level 1 and 2 operations are always memory-bound because each element is touched only $O(1)$ times.
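The intensity gap can be made concrete with back-of-the-envelope counts. The following sketch computes FLOPs per word of memory traffic for the three levels at size $n$; the traffic models (which operands are read and written) are simplifying assumptions, not measurements.

```python
# Arithmetic intensity (FLOPs per word moved) for the three BLAS levels,
# under a simple traffic model. Illustrative sketch only.

def intensity_axpy(n):
    # y <- alpha*x + y: 2n FLOPs; ~3n words (read x, read y, write y)
    return (2 * n) / (3 * n)

def intensity_gemv(n):
    # y <- A x + y: 2n^2 FLOPs; A dominates traffic with n^2 words
    return (2 * n * n) / (n * n + 3 * n)

def intensity_gemm(n):
    # C <- A B + C: 2n^3 FLOPs; ~4n^2 words (read A, B, C; write C)
    return (2 * n ** 3) / (4 * n ** 2)

for n in (64, 1024):
    print(n, intensity_axpy(n), intensity_gemv(n), intensity_gemm(n))
```

Levels 1 and 2 stay pinned near a constant intensity no matter how large $n$ gets, while level 3 grows as $n/2$ -- which is exactly why only GEMM can saturate the compute units.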

The **GEMM** operation is the workhorse of deep learning:

$$C \leftarrow \alpha \, \text{op}(A) \, \text{op}(B) + \beta \, C$$

where $A \in \mathbb{R}^{m \times k}$, $B \in \mathbb{R}^{k \times n}$, $C \in \mathbb{R}^{m \times n}$, and $\text{op}(\cdot)$ is identity or transpose. The total FLOP count is $2mkn$ (each output element requires $k$ multiply-adds).
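A reference implementation makes the semantics and the FLOP count explicit. This is a pure-Python sketch (no transposes, lists of lists) of what `gemm` computes, not how an optimized kernel computes it:

```python
# Reference GEMM: C <- alpha*A*B + beta*C, with A (m x k), B (k x n),
# C (m x n) as lists of lists. A semantic sketch, not a fast implementation.

def gemm(alpha, A, B, beta, C):
    m, k, n = len(A), len(B), len(B[0])
    for i in range(m):
        for j in range(n):
            acc = 0.0
            for p in range(k):              # k multiply-adds per output element
                acc += A[i][p] * B[p][j]
            C[i][j] = alpha * acc + beta * C[i][j]
    return C

A = [[1.0, 2.0], [3.0, 4.0]]
B = [[5.0, 6.0], [7.0, 8.0]]
C = [[0.0, 0.0], [0.0, 0.0]]
gemm(1.0, A, B, 0.0, C)                     # C = A @ B
# m*n outputs, k multiply-adds each, 2 FLOPs per multiply-add => 2*m*k*n FLOPs
```

The triple loop touches each element of $A$ and $B$ many times ($n$ and $m$ times respectively), which is the reuse that tiled kernels exploit by staging blocks in fast memory.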

LAPACK (Linear Algebra PACKage) builds on BLAS to provide higher-level operations: eigenvalue decomposition, SVD, Cholesky factorization, QR decomposition, and linear system solvers. On GPUs, NVIDIA provides cuSOLVER as the LAPACK equivalent.

| Library | Platform | Notes |
|---|---|---|
| OpenBLAS | CPU (multi-platform) | Open-source, hand-tuned assembly kernels |
| MKL | CPU (Intel) | Intel-optimized, often fastest on Intel hardware |
| cuBLAS | GPU (NVIDIA) | Exploits Tensor Cores for mixed-precision GEMM |
| cuSOLVER | GPU (NVIDIA) | GPU LAPACK equivalent |
| CUTLASS | GPU (NVIDIA) | Template-based, customizable GEMM kernels |
| Triton | GPU (NVIDIA) | Python DSL for writing fused GPU kernels |

Neural Network Operations as Linear Algebra

Almost every neural network computation reduces to BLAS calls, which is why optimizing GEMM directly translates to faster training and inference.

| NN Operation | Mathematical Form | BLAS Call | Notes |
|---|---|---|---|
| Linear layer (single input) | $y = Wx + b$ | GEMV | Memory-bound for small $x$ |
| Linear layer (batch) | $Y = XW^\top + \mathbf{1}b^\top$ | GEMM | Compute-bound for large batches |
| Attention scores | $S = QK^\top / \sqrt{d_k}$ | GEMM + element-wise | FlashAttention fuses softmax |
| Attention output | $O = \text{softmax}(S) \cdot V$ | GEMM | |
| 2D convolution | $\text{im2col}(X) \cdot W_{\text{reshape}}$ | GEMM | im2col unrolls patches into columns |
| Depthwise convolution | Per-channel ops | Specialized kernel | Does not map well to GEMM |
| Multi-head attention | Batched $Q_i K_i^\top$, $\text{softmax} \cdot V_i$ | Batched GEMM | `torch.baddbmm` |

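The convolution-as-GEMM row deserves a concrete illustration. The sketch below computes a single-channel "valid" 2D convolution (cross-correlation, as in deep learning) two ways -- directly, and via im2col followed by a dot product against the flattened kernel -- on a tiny hand-picked input. Pure Python, illustrative only:

```python
# Conv-as-GEMM: direct 2D cross-correlation vs. im2col + matrix product.

def conv2d_direct(x, w):
    kh, kw = len(w), len(w[0])
    oh, ow = len(x) - kh + 1, len(x[0]) - kw + 1
    return [[sum(x[i + a][j + b] * w[a][b]
                 for a in range(kh) for b in range(kw))
             for j in range(ow)] for i in range(oh)]

def im2col(x, kh, kw):
    # One row per output position, each row an unrolled kh*kw patch.
    oh, ow = len(x) - kh + 1, len(x[0]) - kw + 1
    return [[x[i + a][j + b] for a in range(kh) for b in range(kw)]
            for i in range(oh) for j in range(ow)]

x = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
w = [[1, 0], [0, -1]]
cols = im2col(x, 2, 2)                               # 4 patches x 4 weights
flat_w = [v for row in w for v in row]
via_gemm = [sum(c * wv for c, wv in zip(col, flat_w)) for col in cols]
direct = [v for row in conv2d_direct(x, w) for v in row]
```

With multiple filters, `flat_w` becomes a matrix and the product is exactly a GEMM; the cost is that im2col duplicates overlapping patches in memory.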
In a standard transformer training step, GEMM operations (linear projections, attention scores, attention output, FFN) account for roughly 60-70% of total FLOPs. The remaining 30-40% (layer norm, softmax, dropout, activation functions) are element-wise and memory-bound. This is why Tensor Cores, which accelerate GEMM by 8-16x, produce only 2-3x end-to-end speedup -- the memory-bound operations become the bottleneck.
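The 2-3x figure follows directly from Amdahl's law. The sketch below works through the arithmetic; the 70% GEMM fraction and 8x Tensor Core speedup are illustrative assumptions taken from the ranges above, not measurements:

```python
# Amdahl's-law arithmetic: speeding up only the GEMM fraction of a
# training step. Fractions and speedups are illustrative assumptions.

def end_to_end_speedup(gemm_fraction, gemm_speedup):
    # Normalized step time: accelerated GEMM part + untouched remainder.
    return 1.0 / (gemm_fraction / gemm_speedup + (1.0 - gemm_fraction))

print(round(end_to_end_speedup(0.7, 8.0), 2))   # prints 2.58
print(round(end_to_end_speedup(0.7, 1e9), 2))   # prints 3.33 (the 1/0.3 limit)
```

Even an infinitely fast GEMM caps out at about 3.3x when 30% of the step is memory-bound, which is precisely the motivation for fusing those remaining operations.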

Operator Fusion

**Operator fusion** combines multiple sequential operations into a single GPU kernel, eliminating intermediate reads from and writes to global memory (HBM). Instead of writing intermediate results to HBM between each operation, the fused kernel keeps data in registers or shared memory (SRAM).

Consider the compound operation $\text{ReLU}(\text{BatchNorm}(Wx + b))$:

Without fusion (3 separate kernels):

  1. GEMM kernel: compute $z = Wx + b$, write $z$ to HBM
  2. BatchNorm kernel: read $z$ from HBM, compute $\hat{z} = \gamma \frac{z - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$, write $\hat{z}$ to HBM
  3. ReLU kernel: read $\hat{z}$ from HBM, compute $\max(0, \hat{z})$, write to HBM

Total memory traffic: $6 \times n$ reads/writes to HBM (where $n$ is the output size).

With fusion (1 kernel):

$$y = \max\!\left(0, \;\gamma \frac{Wx + b - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta\right) \quad \text{(one round-trip to HBM)}$$

Total memory traffic: $2 \times n$ (read inputs, write outputs). This is a $3\times$ reduction in memory traffic.
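The bookkeeping can be checked mechanically with a toy instrumented "HBM" that tallies elements read and written. Weights and BatchNorm statistics are ignored so the counts track only the size-$n$ activations, matching the accounting above; this is a sketch, not a performance model:

```python
# Toy HBM-traffic model: count elements read/written per kernel launch.

class HBM:
    def __init__(self):
        self.reads = self.writes = 0
        self.data = {}
    def write(self, name, values):
        self.writes += len(values)
        self.data[name] = list(values)
    def read(self, name):
        values = self.data[name]
        self.reads += len(values)
        return values

n = 1024
hbm = HBM()
hbm.write("x", [0.0] * n)             # input already resident; reset counters
hbm.reads = hbm.writes = 0

# Unfused: each of the three kernels reads its input and writes its output.
z = hbm.read("x");      hbm.write("z", z)          # GEMM
zh = hbm.read("z");     hbm.write("z_hat", zh)     # BatchNorm
y = hbm.read("z_hat");  hbm.write("y", y)          # ReLU
unfused_traffic = hbm.reads + hbm.writes           # 6n

hbm.reads = hbm.writes = 0
# Fused: one kernel; intermediates stay in registers/SRAM.
y = hbm.read("x");      hbm.write("y", y)
fused_traffic = hbm.reads + hbm.writes             # 2n
```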

**FlashAttention as operator fusion** [@dao2022flashattention]. Standard attention computes and materializes the $T \times T$ attention matrix $S = QK^\top/\sqrt{d}$, which requires $O(T^2)$ HBM memory. FlashAttention fuses the entire $\text{softmax}(QK^\top/\sqrt{d}) \cdot V$ computation into tiled SRAM operations:
  1. Load tiles of $Q$, $K$, $V$ into SRAM (shared memory)
  2. Compute partial attention scores and outputs in SRAM
  3. Use online softmax (tracking running max and sum) to combine tiles
  4. Write only the final output $O$ to HBM

Result: memory usage drops from $O(T^2)$ to $O(T)$, and wall-clock time improves 2-4x for long sequences despite performing more FLOPs (due to recomputation in the backward pass). This is a striking example of trading compute for memory bandwidth.
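The online softmax in step 3 is the trick that makes tiling possible, and it fits in a few lines. The sketch below processes a score vector in chunks while maintaining a running max $m$ and running sum $l$, and checks the result against the direct computation (it is a one-dimensional sketch of the rescaling idea, not the full tiled attention kernel):

```python
# Online (streaming) softmax: combine per-tile partial results by rescaling
# the running sum whenever a new, larger max is seen. Pure-Python sketch.
import math

def softmax(x):
    m = max(x)
    e = [math.exp(v - m) for v in x]
    s = sum(e)
    return [v / s for v in e]

def online_softmax(x, tile=4):
    m, l = float("-inf"), 0.0
    for start in range(0, len(x), tile):
        chunk = x[start:start + tile]
        m_new = max(m, max(chunk))
        # Rescale the running sum to the new max, then fold in the chunk.
        l = l * math.exp(m - m_new) + sum(math.exp(v - m_new) for v in chunk)
        m = m_new
    return [math.exp(v - m) / l for v in x]

x = [0.5, -1.2, 3.0, 0.1, 2.2, -0.7, 1.4, 0.9, -2.0]
assert all(abs(a - b) < 1e-12 for a, b in zip(softmax(x), online_softmax(x)))
```

Because each tile only needs the current $(m, l)$ pair, the full score row never has to exist in memory at once -- the attention kernel applies the same rescaling to its partial output accumulator.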

| Fused Operation | Separate Ops | Speedup | Framework Support |
|---|---|---|---|
| GEMM + bias + ReLU | 3 kernels | ~2x | cuBLAS, cuDNN |
| LayerNorm + residual | 3 kernels | ~2-3x | Apex, Triton |
| Softmax + mask + dropout | 3 kernels | ~2x | FlashAttention, xFormers |
| GEMM + GELU | 2 kernels | ~1.5x | Megatron-LM |
| Full attention block | ~10 kernels | ~2-4x | FlashAttention |

Torch Compilation

Modern PyTorch provides `torch.compile()` [@ansel2024pytorch2], which automatically traces the computation graph and applies operator fusion, kernel selection, and memory planning. Under the hood, it uses TorchInductor to generate Triton kernels that fuse element-wise operations. This often achieves 1.5-2x speedup with a single line of code, though it does not yet match hand-written fusions like FlashAttention for complex patterns.

Notation Summary

| Symbol | Meaning |
|---|---|
| BLAS | Basic Linear Algebra Subprograms |
| GEMM | General Matrix Multiply: $C \leftarrow \alpha AB + \beta C$ |
| GEMV | General Matrix-Vector Multiply |
| LAPACK | Linear Algebra PACKage |
| cuBLAS | NVIDIA GPU BLAS library |
| cuSOLVER | NVIDIA GPU LAPACK library |
| CUTLASS | NVIDIA template GEMM library |
| PTX | Parallel Thread Execution (GPU intermediate representation) |
| SASS | GPU machine code (hardware-specific) |
| HBM | High Bandwidth Memory |
| SRAM | Static RAM (shared memory, L1 cache) |
| $\alpha, \beta$ | Scalar coefficients in GEMM |