Optimal transport (OT) is the mathematical theory of moving mass efficiently between distributions. It provides geometrically meaningful distances between probability distributions that, unlike KL divergence, respect the underlying geometry of the sample space. In machine learning, OT appears in generative models (Wasserstein GANs), domain adaptation, fairness, and increasingly as a foundation for modern generative frameworks like flow matching and rectified flows.
Given two probability distributions $\mu$ (source) and $\nu$ (target) over a space $\mathcal{X}$:
**Monge formulation (1781).** Find a transport map $T: \mathcal{X} \to \mathcal{X}$ that pushes $\mu$ forward to $\nu$ (i.e., $T_\# \mu = \nu$, meaning $\nu(B) = \mu(T^{-1}(B))$ for all measurable $B$) and minimizes:

$$\inf_{T : T_\# \mu = \nu} \int c(x, T(x)) \, d\mu(x)$$
**Kantorovich formulation (1942).** Relax to a transport plan $\gamma \in \Pi(\mu, \nu)$ -- a joint distribution with marginals $\mu$ and $\nu$ -- that allows mass splitting:

$$\min_{\gamma \in \Pi(\mu,\nu)} \int_{\mathcal{X} \times \mathcal{X}} c(x, y) \, d\gamma(x, y)$$

where $c(x, y)$ is the cost of moving a unit of mass from $x$ to $y$, and $\Pi(\mu,\nu) = \{\gamma : \int \gamma(x, y)\, dy = \mu(x),\ \int \gamma(x, y)\, dx = \nu(y)\}$.
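For finite distributions, the Kantorovich problem is a small linear program that can be solved directly. Below is a minimal sketch assuming NumPy and SciPy are available; the point locations, marginals, and names such as `gamma` are illustrative choices, not prescribed by the theory.

```python
import numpy as np
from scipy.optimize import linprog

# Discrete Kantorovich OT as a linear program.
# Source mu on points x, target nu on points y; cost c(x, y) = |x - y|.
x = np.array([0.0, 1.0])
y = np.array([0.0, 1.0, 2.0])
mu = np.array([0.5, 0.5])
nu = np.array([0.25, 0.5, 0.25])
C = np.abs(x[:, None] - y[None, :])  # cost matrix, shape (2, 3)

n, m = C.shape
# Equality constraints on the flattened plan gamma (row-major):
A_eq = np.zeros((n + m, n * m))
for i in range(n):
    A_eq[i, i * m:(i + 1) * m] = 1.0  # sum_j gamma_ij = mu_i
for j in range(m):
    A_eq[n + j, j::m] = 1.0           # sum_i gamma_ij = nu_j
b_eq = np.concatenate([mu, nu])

res = linprog(C.ravel(), A_eq=A_eq, b_eq=b_eq, bounds=(0, None))
gamma = res.x.reshape(n, m)
print(res.fun)  # optimal transport cost
```

Both marginal constraints are encoded as equalities; `linprog` returns a vertex of the transport polytope, which for this cost is an optimal (unsplit where possible) plan.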
**Monge vs. Kantorovich.** Monge's formulation requires a deterministic map (each source point goes to exactly one destination), which may not exist (e.g., transporting a point mass to two point masses). Kantorovich's relaxation always has a solution and is a linear program. When the cost is $c(x,y) = \|x-y\|^2$ and $\mu$ is absolutely continuous, the Monge and Kantorovich solutions coincide: the optimal plan is supported on the graph of a map $T^*(x) = \nabla \phi(x)$ for a convex function $\phi$ (Brenier's theorem).
## Wasserstein Distance
The **$p$-Wasserstein distance** uses cost $c(x,y) = \|x - y\|^p$:
$$W_p(\mu, \nu) = \left( \inf_{\gamma \in \Pi(\mu,\nu)} \int \|x - y\|^p \, d\gamma(x, y) \right)^{1/p}$$
The 1-Wasserstein distance ($p = 1$) is also called the Earth Mover's Distance (EMD): the minimum "work" (mass × distance) needed to reshape one pile of dirt into another. The 2-Wasserstein distance ($p = 2$) has nicer geometric properties and connections to optimal maps.
**Why Wasserstein is better than KL for some tasks.**
| Property | KL Divergence | Wasserstein Distance |
|---|---|---|
| Metric? | No (asymmetric, no triangle ineq.) | Yes (symmetric, triangle ineq.) |
| Non-overlapping support | $D_{\mathrm{KL}} = \infty$ | Finite (uses geometry) |
| Sensitivity to geometry | None (only ratios $p/q$) | Respects distances in $\mathcal{X}$ |
| Gradient quality | Vanishes when supports don't overlap | Useful even for disjoint supports |
| Computation | $O(n)$ (sample-based) | $O(n^3)$ exact, $O(n^2/\epsilon^2)$ entropic |
The key advantage: when $\mu$ and $\nu$ have disjoint supports (common early in GAN training, when the generator produces images far from real data), KL divergence is infinite and provides no gradient signal. Wasserstein distance is finite and provides a meaningful gradient pointing $\mu$ toward $\nu$.
**1D Wasserstein.** In one dimension, the Wasserstein distance has a beautiful closed form:

$$W_p(\mu, \nu) = \left( \int_0^1 \left| F_\mu^{-1}(t) - F_\nu^{-1}(t) \right|^p \, dt \right)^{1/p}$$

where $F^{-1}$ is the quantile function (inverse CDF). For $p = 1$: $W_1 = \int_{-\infty}^{\infty} |F_\mu(x) - F_\nu(x)| \, dx$. The optimal transport map is $T = F_\nu^{-1} \circ F_\mu$ (map quantiles to quantiles). This is why 1D optimal transport costs $O(n \log n)$: just sort and pair up.
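The quantile coupling makes the empirical 1D case a one-liner: sort both samples and average the paired differences. A sketch assuming NumPy and equal sample sizes; `wasserstein_1d` is a hypothetical helper name.

```python
import numpy as np

def wasserstein_1d(samples_mu, samples_nu, p=1):
    """Empirical p-Wasserstein distance between equal-size 1D samples:
    sorting realizes the quantile coupling T = F_nu^{-1} o F_mu."""
    xs = np.sort(np.asarray(samples_mu, dtype=float))
    ys = np.sort(np.asarray(samples_nu, dtype=float))
    return np.mean(np.abs(xs - ys) ** p) ** (1.0 / p)

# Shifting a sample by a constant c gives W_p = |c| for any p.
rng = np.random.default_rng(0)
a = rng.normal(size=1000)
print(wasserstein_1d(a, a + 3.0))  # ≈ 3.0
```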
**Wasserstein between Gaussians.** For $\mu = \mathcal{N}(m_1, \Sigma_1)$ and $\nu = \mathcal{N}(m_2, \Sigma_2)$, the 2-Wasserstein distance has a closed form (the **Bures metric**):
$$W_2^2(\mu, \nu) = \|m_1 - m_2\|^2 + \operatorname{tr}\!\left( \Sigma_1 + \Sigma_2 - 2 \left( \Sigma_1^{1/2} \Sigma_2 \Sigma_1^{1/2} \right)^{1/2} \right)$$
For diagonal covariances: $W_2^2 = \|m_1 - m_2\|^2 + \sum_i (\sigma_{1,i} - \sigma_{2,i})^2$. This is used in the FID (Fréchet Inception Distance) metric for evaluating generative models, where the Inception features of real and generated images are modeled as Gaussians.
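The closed form above can be sketched directly with NumPy/SciPy (both assumed available); `gaussian_w2_squared` is an illustrative helper, and the example checks the diagonal-covariance special case by hand.

```python
import numpy as np
from scipy.linalg import sqrtm

def gaussian_w2_squared(m1, S1, m2, S2):
    """Squared 2-Wasserstein (Bures) distance between N(m1,S1) and N(m2,S2)."""
    s1_half = sqrtm(S1)
    cross = sqrtm(s1_half @ S2 @ s1_half)
    # sqrtm may return tiny imaginary parts from the Schur method; discard them.
    return float(np.sum((m1 - m2) ** 2) + np.trace(S1 + S2 - 2 * np.real(cross)))

# Diagonal check: W2^2 = |m1-m2|^2 + sum_i (sigma_{1,i} - sigma_{2,i})^2
m1, m2 = np.zeros(2), np.array([1.0, 0.0])
S1 = np.diag([1.0, 4.0])  # standard deviations 1 and 2
S2 = np.diag([4.0, 1.0])  # standard deviations 2 and 1
print(gaussian_w2_squared(m1, S1, m2, S2))  # hand calc: 1 + 1 + 1 = 3
```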
## Kantorovich Duality
The **dual formulation** of the 1-Wasserstein distance is:

$$W_1(\mu, \nu) = \sup_{\|f\|_L \le 1} \left( \mathbb{E}_{x \sim \mu}[f(x)] - \mathbb{E}_{y \sim \nu}[f(y)] \right)$$

where the supremum is over all 1-Lipschitz functions $f$ (satisfying $|f(x) - f(y)| \le \|x - y\|$ for all $x, y$).
More generally, for cost $c(x, y)$, the dual is:

$$\sup_{\phi, \psi} \left\{ \mathbb{E}_\mu[\phi(x)] + \mathbb{E}_\nu[\psi(y)] : \phi(x) + \psi(y) \le c(x, y) \ \forall x, y \right\}$$

where $(\phi, \psi)$ are Kantorovich potentials. For $c = \|x - y\|$, the constraint forces $\psi = -\phi$ with $\phi$ 1-Lipschitz.
**From duality to WGAN.** The Kantorovich dual replaces optimization over transport plans (high-dimensional joint distributions) with optimization over a single function. This is exactly what the **Wasserstein GAN (WGAN)** [@arjovsky2017wgan] exploits:
- **Critic (discriminator):** Parameterize a 1-Lipschitz function $f_\omega$ and maximize $\mathbb{E}_{x \sim p_{\text{data}}}[f_\omega(x)] - \mathbb{E}_{z \sim p_z}[f_\omega(G_\theta(z))]$.
- **Generator:** Minimize the same objective (push generated samples closer to real data in the Wasserstein sense).
The Lipschitz constraint is enforced via:
- **Weight clipping** (original WGAN): Clip $\omega \in [-c, c]$. Simple but biases the critic toward overly simple functions.
- **Gradient penalty** (WGAN-GP) [@gulrajani2017improved]: Add $\lambda \, \mathbb{E}[(\|\nabla f_\omega(\hat{x})\| - 1)^2]$ where $\hat{x}$ interpolates between real and fake samples. More stable.
- **Spectral normalization** [@miyato2018spectral]: Normalize weight matrices by their spectral norm. Efficient and widely used.
## Entropic Optimal Transport
Adding an **entropic regularization** term makes OT computationally tractable:
$$W_\epsilon(\mu, \nu) = \min_{\gamma \in \Pi(\mu,\nu)} \sum_{i,j} c_{ij} \gamma_{ij} + \epsilon \sum_{i,j} \gamma_{ij} \log \gamma_{ij}$$

The regularizer $\epsilon \sum_{i,j} \gamma_{ij} \log \gamma_{ij} = -\epsilon H(\gamma)$ penalizes low entropy, encouraging the transport plan to spread out (higher entropy = more diffuse plan). As $\epsilon \to 0$, $W_\epsilon \to W$ (exact OT). As $\epsilon \to \infty$, the plan approaches the independent coupling $\gamma = \mu \otimes \nu$.
**Sinkhorn algorithm.**

Input: cost matrix $C \in \mathbb{R}^{n \times m}$, marginals $a \in \Delta^{n-1}$, $b \in \Delta^{m-1}$, regularization $\epsilon > 0$.

- Compute the Gibbs kernel $K_{ij} = \exp(-C_{ij}/\epsilon)$.
- Initialize $v = \mathbf{1}_m$.
- Repeat until convergence:
  - $u \leftarrow a \oslash (Kv)$ (row normalization: $u_i = a_i / \sum_j K_{ij} v_j$)
  - $v \leftarrow b \oslash (K^\top u)$ (column normalization: $v_j = b_j / \sum_i K_{ij} u_i$)

Output: transport plan $\gamma^*_{ij} = u_i K_{ij} v_j$, cost $W_\epsilon = \langle C, \gamma^* \rangle$.
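The iterations can be sketched in a few lines of NumPy (assumed available). This plain-domain version has no log-domain stabilization, so it is only safe for moderate regularization; the toy problem and names are illustrative.

```python
import numpy as np

def sinkhorn(C, a, b, eps=0.1, n_iters=500):
    """Entropic OT via Sinkhorn iterations (plain domain, no stabilization)."""
    K = np.exp(-C / eps)                 # Gibbs kernel
    v = np.ones_like(b)
    for _ in range(n_iters):
        u = a / (K @ v)                  # row scaling
        v = b / (K.T @ u)                # column scaling
    gamma = u[:, None] * K * v[None, :]  # transport plan
    return gamma, np.sum(gamma * C)

# Two points on each side, uniform marginals; the optimal plan is diagonal.
x = np.array([0.0, 1.0]); y = np.array([0.0, 1.0])
C = np.abs(x[:, None] - y[None, :])
a = b = np.array([0.5, 0.5])
gamma, cost = sinkhorn(C, a, b, eps=0.01)
print(np.round(gamma, 3))  # ≈ [[0.5, 0], [0, 0.5]]
```

With small `eps` the plan concentrates on the exact OT solution; with large `eps` it blurs toward the independent coupling.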
**Sinkhorn algorithm properties:**
- **Convergence:** Linear rate $O((1 - \delta)^\ell)$ where $\delta$ depends on $\epsilon$ and the cost matrix. Smaller $\epsilon$ means slower convergence but a closer approximation to exact OT.
- **GPU-friendly:** Each iteration is a matrix-vector product $Kv$ plus elementwise operations -- perfectly suited to GPU parallelism. For $n$ points, each iteration costs $O(n^2)$.
- **Differentiable:** Both forward and backward passes are differentiable, enabling end-to-end learning with OT losses. Automatic differentiation through the Sinkhorn iterations is straightforward.
- **Log-domain stability:** For small $\epsilon$, the kernel $K$ has entries close to $0$ or very large. Computing with $\log u$, $\log v$, $\log K$ avoids numerical underflow/overflow.
- **Minibatch Sinkhorn:** For large-scale problems, run Sinkhorn on random minibatches of source and target points. This gives a biased but useful estimator.
## Optimal Transport Maps and Displacement Interpolation
Given the optimal transport map $T^*$ from $\mu$ to $\nu$, the **displacement interpolation** (McCann interpolation) defines a geodesic path between distributions:
$$\mu_t = \big( (1-t)\,\mathrm{Id} + t\,T^* \big)_\# \mu, \qquad t \in [0, 1]$$

At $t = 0$: $\mu_0 = \mu$. At $t = 1$: $\mu_1 = \nu$. Each intermediate $\mu_t$ moves mass along the optimal transport paths, creating a natural interpolation in distribution space.
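In one dimension the optimal map is the quantile map, so displacement interpolation between two equal-size samples reduces to sorting followed by a linear blend. A NumPy sketch with a hypothetical helper name:

```python
import numpy as np

def displacement_interpolate(x0, x1, t):
    """McCann interpolation between equal-size 1D samples under the
    quantile (sorted) optimal coupling: each point moves in a straight
    line toward its matched target, arriving at t = 1."""
    xs = np.sort(np.asarray(x0, dtype=float))
    ys = np.sort(np.asarray(x1, dtype=float))
    return (1 - t) * xs + t * ys

src = np.array([0.0, 0.0, 0.0, 0.0])
tgt = np.array([2.0, 2.0, 4.0, 4.0])
print(displacement_interpolate(src, tgt, 0.5))  # [1. 1. 2. 2.]
```

Unlike mixture averaging (which would put mass at both endpoints), the mass itself travels: at $t = 0.5$ it sits halfway along each transport path.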
**Displacement interpolation in generative models.** Flow matching [@lipman2023flow] and rectified flows [@liu2023flow] learn the velocity field of the displacement interpolation:
$$x_t = (1-t)\,x_0 + t\,x_1, \qquad x_0 \sim \mu \text{ (noise)}, \quad x_1 \sim \nu \text{ (data)}$$

The conditional velocity field $v_t(x_t) = x_1 - x_0$ produces straight-line trajectories from noise to data. Training a model $v_\theta(x_t, t) \approx x_1 - x_0$ with MSE loss gives the flow matching objective. This is simpler than diffusion (no noise schedule, no SDE/ODE machinery) and produces straighter flows that require fewer sampling steps.
The connection to OT: if (x0,x1) are coupled via the optimal transport plan (rather than independently), the trajectories are non-crossing and the learned velocity field is smoother, leading to better generation quality.
## Wasserstein Barycenters
The **Wasserstein barycenter** of distributions $\mu_1, \ldots, \mu_K$ with weights $\lambda_1, \ldots, \lambda_K$ is:
$$\bar{\mu} = \arg\min_\mu \sum_{k=1}^K \lambda_k \, W_2^2(\mu, \mu_k)$$
The barycenter is the "average" distribution in the Wasserstein sense, which preserves geometric structure better than mixture-based averaging.
**Applications:** Wasserstein barycenters are used for texture mixing, shape interpolation, multi-source domain adaptation (average the source domains into a barycenter, then adapt to the target), and aggregating predictions from multiple generative models. The entropic-regularized barycenter can be computed via iterative Sinkhorn projections.
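In one dimension the $W_2$ barycenter has a closed form: its quantile function is the weighted average of the inputs' quantile functions, so for equal-size empirical samples one simply averages sorted samples. A NumPy sketch with an illustrative helper name:

```python
import numpy as np

def barycenter_1d(samples, weights):
    """1D W2 barycenter of equal-size empirical distributions:
    average the sorted samples (i.e., the quantile functions)."""
    sorted_samples = [np.sort(np.asarray(s, dtype=float)) for s in samples]
    return sum(w * s for w, s in zip(weights, sorted_samples))

a = np.array([0.0, 1.0, 2.0])
b = np.array([10.0, 11.0, 12.0])
print(barycenter_1d([a, b], [0.5, 0.5]))  # [5. 6. 7.]
```

Note the contrast with the 50/50 mixture, which keeps mass near both 1 and 11; the barycenter moves the mass to the geometric midpoint.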
## ML Applications
| Application | How OT is Used | Key Advantage |
|---|---|---|
| Wasserstein GAN | Critic approximates $W_1$ via Kantorovich duality | Stable training, meaningful loss |
| Flow matching | OT coupling gives straight trajectories | Few-step generation |
| Domain adaptation | Minimize $W_2$ between source and target features | Geometry-aware alignment |
| FID score | $W_2^2$ between Gaussians fit to Inception features | Standard generative model metric |
| Distributional RL | Model return as a distribution; compare with $W_p$ | Preserves return distribution shape |
| Fairness | $W_1$ between demographic group distributions | Measure and mitigate distributional disparity |
| Data augmentation | Displacement interpolation between classes | Meaningful between-class interpolation |
| Sliced Wasserstein | $\mathbb{E}_\theta[W_1(\mathrm{proj}_\theta \mu, \mathrm{proj}_\theta \nu)]$ | Scalable: 1D OT on random projections |
| Graph matching | OT between node feature distributions | Permutation-invariant comparison |
| Point cloud registration | OT map aligns two point sets | Correspondence without labels |
**Computational complexity of OT.**
| Method | Complexity | Accuracy | Use Case |
|---|---|---|---|
| Exact (linear program) | $O(n^3 \log n)$ | Exact | Small problems ($n < 1000$) |
| Sinkhorn (entropic) | $O(n^2 / \epsilon^2)$ | $\epsilon$-approximate | Medium problems, differentiable |
| Sliced Wasserstein | $O(L n \log n)$ ($L$ projections) | Approximation | Large-scale, any dimension |
| Minibatch Sinkhorn | $O(b^2)$ per batch | Biased estimator | Very large scale |
| Neural OT | $O(\text{forward pass})$ | Amortized | Continuous distributions |
For high-dimensional problems, sliced Wasserstein computes $W_1$ on random 1D projections (where OT is just sorting) and averages. This avoids the curse of dimensionality and scales to millions of points.
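The recipe above can be sketched in NumPy (assumed available): draw random unit directions, project both point clouds, and solve each 1D problem by sorting. The function name and defaults are illustrative.

```python
import numpy as np

def sliced_wasserstein(X, Y, n_proj=200, p=1, seed=0):
    """Sliced Wasserstein distance between equal-size point clouds in R^d:
    average the 1D W_p over random projection directions."""
    rng = np.random.default_rng(seed)
    d = X.shape[1]
    total = 0.0
    for _ in range(n_proj):
        theta = rng.normal(size=d)
        theta /= np.linalg.norm(theta)        # random direction on the sphere
        px, py = np.sort(X @ theta), np.sort(Y @ theta)
        total += np.mean(np.abs(px - py) ** p) ** (1.0 / p)
    return total / n_proj

rng = np.random.default_rng(1)
X = rng.normal(size=(500, 3))
print(sliced_wasserstein(X, X.copy()))  # 0.0 for identical clouds
```

Each projection costs one matrix-vector product plus a sort, giving the $O(L n \log n)$ complexity quoted above regardless of the ambient dimension.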
## Notation Summary
| Symbol | Meaning |
|---|---|
| $\mu, \nu$ | Source and target distributions |
| $\gamma$ | Transport plan (coupling) |
| $T$ | Transport map (Monge) |
| $T_\# \mu$ | Pushforward of $\mu$ by $T$ |
| $\Pi(\mu, \nu)$ | Set of all couplings with marginals $\mu, \nu$ |
| $c(x, y)$ | Transport cost function |
| $W_p$ | $p$-Wasserstein distance |
| $\epsilon$ | Entropic regularization parameter |
| $K$ | Gibbs kernel: $K_{ij} = e^{-c_{ij}/\epsilon}$ |
| $u, v$ | Sinkhorn scaling vectors |
| $\phi, \psi$ | Kantorovich dual potentials |
| $\bar{\mu}$ | Wasserstein barycenter |
| EMD | Earth Mover's Distance ($W_1$) |
| FID | Fréchet Inception Distance (Bures $W_2$) |