Optimal transport (OT) is the mathematical theory of moving mass efficiently between distributions. It provides geometrically meaningful distances between probability distributions that, unlike KL divergence, respect the underlying geometry of the sample space. In machine learning, OT appears in generative models (Wasserstein GANs), domain adaptation, fairness, and increasingly as a foundation for modern generative frameworks like flow matching and rectified flows.
Given two probability distributions $\mu$ (source) and $\nu$ (target) over a space $\mathcal{X}$:
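**Monge formulation (1781):** Find a transport map $T: \mathcal{X} \to \mathcal{X}$ that pushes $\mu$ forward to $\nu$ (i.e., $T_\#\mu = \nu$, meaning $\nu(B) = \mu(T^{-1}(B))$ for all measurable $B$) and minimizes:
$$\inf_{T:\, T_\#\mu = \nu} \int c(x, T(x)) \, d\mu(x)$$
**Kantorovich formulation (1942):** Relax to a transport plan $\gamma \in \Pi(\mu, \nu)$ (a joint distribution with marginals $\mu$ and $\nu$) that allows mass splitting:
$$\min_{\gamma \in \Pi(\mu, \nu)} \int_{\mathcal{X} \times \mathcal{X}} c(x, y) \, d\gamma(x, y)$$
where $c(x, y)$ is the cost of moving a unit of mass from $x$ to $y$, and $\Pi(\mu, \nu) = \{\gamma : \int \gamma(x, y)\, dy = \mu(x),\ \int \gamma(x, y)\, dx = \nu(y)\}$.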
**Monge vs. Kantorovich.** Monge's formulation requires a deterministic map (each source point goes to exactly one destination), which may not exist (e.g., transporting a point mass to two point masses). Kantorovich's relaxation always has a solution and is a linear program. When the cost is $c(x,y) = \|x-y\|^2$ and $\mu$ is absolutely continuous, the Monge and Kantorovich solutions coincide: the optimal plan is supported on the graph of a map $T^*(x) = \nabla \psi(x)$ where $\psi$ is a convex function (Brenier's theorem). Writing $T^*(x) = x - \nabla \phi(x)$, the convex Brenier potential is $\psi(x) = \tfrac{1}{2}\|x\|^2 - \phi(x)$, not $\phi$ itself.
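To make the linear-program view concrete, here is a minimal sketch of the discrete Kantorovich problem solved with SciPy's `linprog`. The helper name `kantorovich_lp` and the toy data are illustrative assumptions, not part of the original text. The example is the case mentioned above where no Monge map exists: a single source point must split its mass between two targets.

```python
import numpy as np
from scipy.optimize import linprog

def kantorovich_lp(a, b, C):
    """Solve min <C, gamma> subject to row sums a and column sums b (hypothetical helper)."""
    n, m = C.shape
    # gamma is flattened row-major, so entry (i, j) sits at index i * m + j.
    A_rows = np.kron(np.eye(n), np.ones((1, m)))   # sum_j gamma[i, j] = a[i]
    A_cols = np.kron(np.ones((1, n)), np.eye(m))   # sum_i gamma[i, j] = b[j]
    res = linprog(
        C.ravel(),
        A_eq=np.vstack([A_rows, A_cols]),
        b_eq=np.concatenate([a, b]),
        bounds=(0, None),
        method="highs",
    )
    return res.x.reshape(n, m), res.fun

# One source point at 0, two target points at -1 and +1 with equal mass.
x = np.array([0.0])
y = np.array([-1.0, 1.0])
a = np.array([1.0])
b = np.array([0.5, 0.5])
C = (x[:, None] - y[None, :]) ** 2     # squared-distance cost matrix

plan, cost = kantorovich_lp(a, b, C)
print(plan, cost)
```

The returned plan sends half of the source mass to each target, which no deterministic map can do. The dense constraint matrix has $(n + m) \times nm$ entries, so this sketch is only sensible for small problems.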
The **$p$-Wasserstein distance** uses cost $c(x,y) = \|x - y\|^p$:
$$W_p(\mu, \nu) = \left( \inf_{\gamma \in \Pi(\mu, \nu)} \int \|x - y\|^p \, d\gamma(x, y) \right)^{1/p}$$
The 1-Wasserstein distance ($p = 1$) is also called the Earth Mover's Distance (EMD): the minimum "work" (mass × distance) needed to reshape one pile of dirt into another. The 2-Wasserstein distance ($p = 2$) has nicer geometric properties and connections to optimal maps.
**Why Wasserstein is better than KL for some tasks.**
| Property | KL Divergence | Wasserstein Distance |
|---|---|---|
| Metric? | No (asymmetric, no triangle inequality) | Yes (symmetric, triangle inequality) |
| Non-overlapping support | $D_{\mathrm{KL}} = \infty$ | Finite (uses geometry) |
| Sensitivity to geometry | None (only ratios $p/q$) | Respects distances in $\mathcal{X}$ |
| Gradient quality | Vanishes when supports don't overlap | Always provides useful gradients |
| Computation | $O(n)$ (sample-based) | $O(n^3)$ exact, $O(n^2/\epsilon^2)$ entropic |
The key advantage: when $\mu$ and $\nu$ have disjoint supports (common early in GAN training when the generator produces images far from real data), KL divergence is infinite and provides no gradient signal. Wasserstein distance is finite and provides a meaningful gradient pointing $\mu$ toward $\nu$.
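A toy illustration of this contrast, assuming two point masses on a 1D grid (the helper names `kl_divergence` and `w1_on_grid` are hypothetical): as the second point mass moves toward the first, KL stays infinite while $W_1$ shrinks in proportion to the distance. The $W_1$ helper uses the standard 1D identity that $W_1$ equals the area between the two CDFs.

```python
import numpy as np

def kl_divergence(p, q):
    """KL(p || q) for discrete distributions; infinite if p has mass where q has none."""
    mask = p > 0
    if np.any(q[mask] == 0):
        return np.inf
    return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))

def w1_on_grid(p, q, grid):
    """W_1 between two histograms on a sorted 1D grid, as the area between CDFs."""
    cdf_gap = np.abs(np.cumsum(p) - np.cumsum(q))[:-1]
    return float(np.sum(cdf_gap * np.diff(grid)))

grid = np.linspace(0.0, 10.0, 11)   # grid points 0, 1, ..., 10
p = np.zeros(11)
p[0] = 1.0                          # point mass at 0

for theta in (8, 4, 1):
    q = np.zeros(11)
    q[theta] = 1.0                  # point mass at theta, disjoint from p
    print(theta, kl_divergence(p, q), w1_on_grid(p, q, grid))
```

KL reports the same value for every position of the second point mass, so it carries no information about which direction reduces the mismatch; $W_1$ decreases as the masses approach each other.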
**1D Wasserstein.** In one dimension, the Wasserstein distance has a beautiful closed form:
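$$W_p(\mu, \nu) = \left( \int_0^1 \left| F_\mu^{-1}(t) - F_\nu^{-1}(t) \right|^p dt \right)^{1/p}$$
where $F^{-1}$ is the quantile function (inverse CDF). For $p = 1$: $W_1 = \int_{-\infty}^{\infty} |F_\mu(x) - F_\nu(x)| \, dx$. The optimal transport map is $T = F_\nu^{-1} \circ F_\mu$ (map quantiles to quantiles). This is why 1D optimal transport is $O(n \log n)$ (just sort and pair up).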
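The "sort and pair up" recipe can be sketched as follows for two empirical distributions with the same number of equally weighted samples (an assumption of this sketch; unequal sizes need quantile interpolation). The function name `wasserstein_1d` is illustrative.

```python
import numpy as np

def wasserstein_1d(x, y, p=1):
    """W_p between two equal-size, equally weighted 1D samples (hypothetical helper)."""
    x_sorted = np.sort(x)   # empirical quantiles of mu
    y_sorted = np.sort(y)   # empirical quantiles of nu
    # Pairing the i-th smallest x with the i-th smallest y is the optimal coupling in 1D.
    return float(np.mean(np.abs(x_sorted - y_sorted) ** p) ** (1.0 / p))

rng = np.random.default_rng(0)
x = rng.normal(0.0, 1.0, size=1000)
y = x + 3.0                          # a pure shift of the same sample

print(wasserstein_1d(x, y, p=1), wasserstein_1d(x, y, p=2))
```

For a pure shift, every quantile moves by the same amount, so both distances equal the shift size up to floating-point error.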
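**Wasserstein between Gaussians.** For $\mu = \mathcal{N}(m_1, \Sigma_1)$ and $\nu = \mathcal{N}(m_2, \Sigma_2)$, the 2-Wasserstein distance has a closed form (the covariance term is the squared **Bures metric**):
$$W_2^2(\mu, \nu) = \|m_1 - m_2\|^2 + \operatorname{tr}\!\left( \Sigma_1 + \Sigma_2 - 2 \left( \Sigma_1^{1/2} \Sigma_2 \Sigma_1^{1/2} \right)^{1/2} \right)$$
For diagonal covariances: $W_2^2 = \|m_1 - m_2\|^2 + \sum_i \left( \sqrt{\Sigma_{1,i}} - \sqrt{\Sigma_{2,i}} \right)^2$ (where $\Sigma_{k,i}$ denotes the $i$-th diagonal variance of $\Sigma_k$), i.e., the squared difference of standard deviations. The closed form is used in the FID (Fréchet Inception Distance) metric for evaluating generative models, where the Inception features of real and generated images are modeled as Gaussians.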
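A minimal NumPy sketch of the Gaussian closed form, $W_2^2 = \|m_1 - m_2\|^2 + \operatorname{tr}(\Sigma_1 + \Sigma_2 - 2(\Sigma_1^{1/2}\Sigma_2\Sigma_1^{1/2})^{1/2})$, assuming symmetric positive semidefinite covariances. The helper names `psd_sqrt` and `gaussian_w2_squared` are hypothetical, and the matrix square root is computed by eigendecomposition rather than by any particular library routine.

```python
import numpy as np

def psd_sqrt(S):
    """Square root of a symmetric PSD matrix via eigendecomposition."""
    w, V = np.linalg.eigh(S)
    w = np.clip(w, 0.0, None)          # guard against tiny negative eigenvalues
    return (V * np.sqrt(w)) @ V.T

def gaussian_w2_squared(m1, S1, m2, S2):
    """Squared 2-Wasserstein distance between N(m1, S1) and N(m2, S2)."""
    S1_half = psd_sqrt(S1)
    cross = psd_sqrt(S1_half @ S2 @ S1_half)
    mean_term = np.sum((m1 - m2) ** 2)
    cov_term = np.trace(S1 + S2 - 2.0 * cross)
    return float(mean_term + cov_term)

# Diagonal check: variances (1, 4) vs (9, 16), i.e. standard deviations (1, 2) vs (3, 4).
m1, S1 = np.zeros(2), np.diag([1.0, 4.0])
m2, S2 = np.array([1.0, 2.0]), np.diag([9.0, 16.0])
print(gaussian_w2_squared(m1, S1, m2, S2))
```

In the diagonal check the mean term is $1 + 4 = 5$ and the covariance term is $(1-3)^2 + (2-4)^2 = 8$, so the result should agree with $13$ up to rounding.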
The **dual formulation** of the 1-Wasserstein distance is:
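$$W_1(\mu, \nu) = \sup_{\|f\|_L \le 1} \left( \mathbb{E}_{x \sim \mu}[f(x)] - \mathbb{E}_{y \sim \nu}[f(y)] \right)$$
where the supremum is over all 1-Lipschitz functions $f$ (satisfying $|f(x) - f(y)| \le \|x - y\|$ for all $x, y$).
More generally, for cost $c(x, y)$, the dual is:
$$\sup_{\phi, \psi} \left\{ \mathbb{E}_\mu[\phi(x)] + \mathbb{E}_\nu[\psi(y)] : \phi(x) + \psi(y) \le c(x, y) \;\; \forall x, y \right\}$$
where $(\phi, \psi)$ are Kantorovich potentials. For $c = \|x - y\|$, the constraint forces $\phi = -\psi$ and $\phi$ to be 1-Lipschitz.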
**From duality to WGAN.** The Kantorovich dual replaces optimization over transport plans (high-dimensional joint distributions) with optimization over a single function. This is exactly what the **Wasserstein GAN (WGAN)** [@arjovsky2017wgan] exploits:
- **Critic (discriminator):** Parameterize a 1-Lipschitz function $f_\omega$ and maximize $\mathbb{E}_{x \sim p_{\text{data}}}[f_\omega(x)] - \mathbb{E}_{z \sim p_z}[f_\omega(G_\theta(z))]$.
- **Generator:** Minimize the same objective (push generated samples closer to real data in Wasserstein sense).
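The critic objective can be illustrated without a neural network by restricting the critic to linear functions $f(x) = w^\top x$ with $\|w\|_2 \le 1$, which are 1-Lipschitz. This is a deliberately simplified stand-in, not the WGAN critic itself, and the helper name `linear_critic_gap` is hypothetical. Because the function class is restricted, the maximized objective is only a lower bound on $W_1$; it is tight when the two distributions differ by a pure translation.

```python
import numpy as np

def linear_critic_gap(x_real, x_fake):
    """Maximize E[f(real)] - E[f(fake)] over f(x) = w @ x with ||w||_2 <= 1."""
    delta = x_real.mean(axis=0) - x_fake.mean(axis=0)
    norm = np.linalg.norm(delta)
    # The maximizer points along the difference of the sample means.
    w = delta / norm if norm > 0 else np.zeros_like(delta)
    return w, float(norm)

rng = np.random.default_rng(0)
x_real = rng.normal(size=(500, 2))
shift = np.array([3.0, 4.0])
x_fake = x_real - shift               # "generated" samples: a translated copy

w, gap = linear_critic_gap(x_real, x_fake)
print(w, gap)
```

Here the optimal critic direction is `shift / ||shift||` and the objective value matches $\|\text{shift}\|_2$, which is the true $W_1$ for a translation. For distributions that differ in shape rather than location, a linear critic underestimates $W_1$, which is one reason a more expressive critic is used in practice.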
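Entropic regularization replaces the discrete Kantorovich problem with $\min_{\gamma \in \Pi(\mu, \nu)} \langle C, \gamma \rangle - \epsilon H(\gamma)$, where $C$ is the cost matrix and $H(\gamma) = -\sum_{ij} \gamma_{ij} \log \gamma_{ij}$. The entropy term $\epsilon H(\gamma)$ encourages the transport plan to spread out (higher entropy = more diffuse plan). As $\epsilon \to 0$, $W_\epsilon \to W$ (exact OT). As $\epsilon \to \infty$, the plan approaches the independent coupling $\gamma = \mu \otimes \nu$.
The Sinkhorn algorithm solves the regularized problem by alternately rescaling two positive vectors against the Gibbs kernel $K_{ij} = e^{-C_{ij}/\epsilon}$: with marginal weight vectors $a$ (for $\mu$) and $b$ (for $\nu$), iterate $u \leftarrow a / (Kv)$ and $v \leftarrow b / (K^\top u)$ (elementwise division) until the marginals match.
**Output:** transport plan $\gamma^*_{ij} = u_i K_{ij} v_j$, cost $W_\epsilon = \langle C, \gamma^* \rangle$.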
**Sinkhorn algorithm properties:**
- **Convergence:** Linear rate $O((1-\delta)^\ell)$ after $\ell$ iterations, where $\delta$ depends on $\epsilon$ and the cost matrix. Smaller $\epsilon$ means slower convergence but a closer approximation to exact OT.
- **GPU-friendly:** Each iteration consists of matrix-vector products ($Kv$ and $K^\top u$) and elementwise operations, which map well onto GPU parallelism. For $n$ points, each iteration is $O(n^2)$.
- **Differentiable:** Every Sinkhorn iteration is a differentiable operation, so automatic differentiation can backpropagate through the unrolled iterations, enabling end-to-end learning with OT losses.
- **Log-domain stability:** For small $\epsilon$, entries of the kernel $K = e^{-C/\epsilon}$ underflow to 0 while the scaling vectors $u, v$ can grow very large. Computing with $\log u$, $\log v$, $\log K$ and log-sum-exp reductions avoids numerical underflow/overflow.
- **Minibatch Sinkhorn:** For large-scale problems, compute Sinkhorn on random minibatches of source and target points. This gives a biased but useful estimator.
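The properties above refer to the standard Sinkhorn scaling iterations on the Gibbs kernel $K = e^{-C/\epsilon}$. A minimal NumPy sketch follows, assuming discrete marginals `a` and `b` (names chosen here for illustration) and a fixed iteration count rather than a convergence test. It is the plain, non-log-domain variant, so it is only suitable when $\epsilon$ is not too small relative to the entries of $C$.

```python
import numpy as np

def sinkhorn(a, b, C, eps=0.5, n_iters=200):
    """Entropic OT by alternating scaling; returns the plan and its transport cost."""
    K = np.exp(-C / eps)              # Gibbs kernel
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iters):
        u = a / (K @ v)               # rescale to match the row marginals
        v = b / (K.T @ u)             # rescale to match the column marginals
    plan = u[:, None] * K * v[None, :]
    return plan, float(np.sum(plan * C))

x = np.linspace(0.0, 1.0, 5)
y = np.linspace(0.0, 1.0, 5)
a = np.full(5, 0.2)                           # uniform source weights
b = np.array([0.1, 0.1, 0.2, 0.3, 0.3])       # non-uniform target weights
C = (x[:, None] - y[None, :]) ** 2            # squared-distance cost matrix

plan, cost = sinkhorn(a, b, C, eps=0.5, n_iters=200)
print(plan.sum(axis=1), plan.sum(axis=0), cost)
```

Each loop iteration costs two matrix-vector products, which is the $O(n^2)$ per-iteration cost noted in the list. A production version would typically stop on a marginal-error tolerance and switch to log-domain updates for small $\epsilon$.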
Optimal Transport Maps and Displacement Interpolation
Given the optimal transport map $T^*$ from $\mu$ to $\nu$, the **displacement interpolation** (McCann interpolation) defines a geodesic path between distributions:
$$\mu_t = \left( (1 - t)\,\mathrm{Id} + t\,T^* \right)_\# \mu, \quad t \in [0, 1]$$
At $t = 0$: $\mu_0 = \mu$. At $t = 1$: $\mu_1 = \nu$. Each intermediate $\mu_t$ moves mass along the optimal transport paths, creating a natural interpolation in distribution space.
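In one dimension the optimal map pairs sorted samples, so displacement interpolation between two equal-size samples can be sketched in a few lines. The function name `displacement_interpolation_1d` and the Gaussian toy data are assumptions made for illustration.

```python
import numpy as np

def displacement_interpolation_1d(x, y, t):
    """Samples from mu_t for equal-size 1D samples x ~ mu and y ~ nu."""
    x_sorted = np.sort(x)
    y_sorted = np.sort(y)             # in 1D, T*(x_sorted[i]) = y_sorted[i]
    return (1.0 - t) * x_sorted + t * y_sorted

rng = np.random.default_rng(0)
x = rng.normal(-2.0, 0.5, size=2000)  # samples from mu
y = rng.normal(3.0, 0.5, size=2000)   # samples from nu

mid = displacement_interpolation_1d(x, y, t=0.5)
print(mid.mean(), mid.std())
```

At $t = 0.5$ the interpolated samples form a single bump located halfway between the two inputs, whereas a 50/50 mixture of $\mu$ and $\nu$ would keep two separate modes.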
**Displacement interpolation in generative models.** Flow matching [@lipman2023flow] and rectified flows [@liu2023flow] learn the velocity field of a linear interpolation between paired samples, which has the same form as the displacement interpolation:
$$x_t = (1 - t)\,x_0 + t\,x_1, \quad \text{where } x_0 \sim \mu \text{ (noise)}, \; x_1 \sim \nu \text{ (data)}$$
Along each pair the velocity $\frac{d x_t}{dt} = x_1 - x_0$ is constant, so each pair traces a straight line from noise to data. Training the model $v_\theta(x_t, t) \approx x_1 - x_0$ with MSE loss gives the flow matching objective. This is simpler to set up than diffusion (no noise schedule or SDE machinery; sampling integrates a single ODE) and tends to produce straighter flows that require fewer sampling steps.
The connection to OT: if $(x_0, x_1)$ are coupled via the optimal transport plan (rather than independently), the interpolation is exactly the displacement interpolation: the trajectories are non-crossing and the learned velocity field is smoother, which tends to improve generation quality.
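The sketch below builds flow matching training targets for one minibatch and, as an approximation to the OT coupling, re-pairs noise and data samples within the batch using SciPy's `linear_sum_assignment`. The helper names (`ot_pair`, `flow_matching_batch`, `fm_loss`) and the toy 2D data are hypothetical; a real setup would feed `x_t` and `t` to a neural network $v_\theta$, which is omitted here.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def ot_pair(x0, x1):
    """Re-pair a minibatch by the optimal assignment for squared Euclidean cost."""
    C = np.sum((x0[:, None, :] - x1[None, :, :]) ** 2, axis=-1)
    rows, cols = linear_sum_assignment(C)
    return x0[rows], x1[cols]

def flow_matching_batch(x0, x1, rng):
    """Return model inputs (x_t, t) and regression targets x1 - x0."""
    t = rng.uniform(size=(x0.shape[0], 1))
    x_t = (1.0 - t) * x0 + t * x1     # point on the straight line between the pair
    return x_t, t, x1 - x0

def fm_loss(v_pred, target):
    """Mean squared error between predicted and target velocities."""
    return float(np.mean(np.sum((v_pred - target) ** 2, axis=-1)))

rng = np.random.default_rng(0)
x0 = rng.normal(size=(64, 2))                           # noise batch
x1 = rng.normal(size=(64, 2)) + np.array([4.0, 0.0])    # stand-in "data" batch

x0_ot, x1_ot = ot_pair(x0, x1)
x_t, t, target = flow_matching_batch(x0_ot, x1_ot, rng)
print(np.sum((x0 - x1) ** 2), np.sum((x0_ot - x1_ot) ** 2))
```

The second printed number (total squared displacement under the within-batch optimal pairing) is never larger than the first (the original random pairing). A batch-level assignment only approximates the population-level OT plan, so this is a heuristic rather than an exact OT coupling.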
The **Wasserstein barycenter** of distributions $\mu_1, \ldots, \mu_K$ with weights $\lambda_1, \ldots, \lambda_K$ is:
$$\bar{\mu} = \arg\min_{\mu} \sum_{k=1}^{K} \lambda_k W_2^2(\mu, \mu_k)$$
The barycenter is the "average" distribution in the Wasserstein sense, which preserves geometric structure better than mixture-based averaging.
**Applications:** Wasserstein barycenters are used for texture mixing, shape interpolation, multi-source domain adaptation (average the source domains into a barycenter, then adapt to the target), and aggregating predictions from multiple generative models. The entropic-regularized barycenter can be computed via iterative Sinkhorn projections.
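In one dimension the $W_2$ barycenter has a simple form: its quantile function is the weighted average of the input quantile functions. For equal-size samples this reduces to averaging sorted samples, as in the sketch below. The function name `barycenter_1d` and the toy data are assumptions; the general multi-dimensional case needs an iterative solver such as the Sinkhorn-based scheme mentioned above.

```python
import numpy as np

def barycenter_1d(samples, weights):
    """W_2 barycenter of equal-size 1D samples via weighted quantile averaging."""
    sorted_samples = np.stack([np.sort(s) for s in samples])   # shape (K, n)
    w = np.asarray(weights, dtype=float)[:, None]
    return np.sum(w * sorted_samples, axis=0) / np.sum(w)

rng = np.random.default_rng(0)
left = rng.normal(-3.0, 1.0, size=2000)
right = rng.normal(3.0, 1.0, size=2000)

bary = barycenter_1d([left, right], weights=[0.5, 0.5])
mixture = np.concatenate([left, right])        # mixture-based "average" for comparison
print(bary.std(), mixture.std())
```

The barycenter is a single bump centered between the inputs with roughly their spread, while the mixture keeps both modes and therefore has a much larger standard deviation. This is the sense in which the barycenter preserves shape better than mixture averaging.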
| Application | OT formulation | Benefit |
|---|---|---|
| Distributional RL | Model return as a distribution; use $W_p$ for comparison | Preserves return distribution shape |
| Fairness | $W_1$ between demographic group distributions | Measure and mitigate distributional disparity |
| Data augmentation | Displacement interpolation between classes | Meaningful between-class interpolation |
| Sliced Wasserstein | $\mathbb{E}_\theta[W_1(\mathrm{proj}_\theta \mu, \mathrm{proj}_\theta \nu)]$ | Scalable: 1D OT on random projections |
| Graph matching | OT between node feature distributions | Permutation-invariant comparison |
| Point cloud registration | OT map aligns two point sets | Correspondence without labels |
**Computational complexity of OT.**
| Method | Complexity | Accuracy | Use Case |
|---|---|---|---|
| Exact (linear program) | $O(n^3 \log n)$ | Exact | Small problems ($n < 1000$) |
| Sinkhorn (entropic) | $O(n^2/\epsilon^2)$ | $\epsilon$-approximate | Medium problems, differentiable |
| Sliced Wasserstein | $O(L n \log n)$ ($L$ projections) | Approximation | Large-scale, any dimension |
| Minibatch Sinkhorn | $O(b^2)$ per batch | Biased estimator | Very large scale |
| Neural OT | $O(\text{forward pass})$ | Amortized | Continuous distributions |
For high-dimensional problems, sliced Wasserstein computes $W_1$ on random 1D projections (where OT is just sorting) and averages the results. Each projection costs only $O(n \log n)$, so the method sidesteps much of the curse of dimensionality and scales to millions of points.
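A minimal Monte Carlo sketch of the sliced distance, assuming two equal-size samples and directions drawn uniformly from the unit sphere by normalizing Gaussian vectors. The function name `sliced_wasserstein` and its parameters are illustrative.

```python
import numpy as np

def sliced_wasserstein(x, y, n_projections=100, rng=None):
    """Average 1D W_1 over random directions for equal-size samples x, y of shape (n, d)."""
    rng = np.random.default_rng() if rng is None else rng
    d = x.shape[1]
    theta = rng.normal(size=(n_projections, d))
    theta /= np.linalg.norm(theta, axis=1, keepdims=True)   # unit directions
    x_proj = np.sort(x @ theta.T, axis=0)    # shape (n, L), each column sorted
    y_proj = np.sort(y @ theta.T, axis=0)
    per_slice_w1 = np.mean(np.abs(x_proj - y_proj), axis=0)  # 1D W_1 per direction
    return float(np.mean(per_slice_w1))

rng = np.random.default_rng(0)
x = rng.normal(size=(1000, 50))
shift = np.zeros(50)
shift[0] = 2.0
y = x + shift                                # translated copy in 50 dimensions

sw = sliced_wasserstein(x, y, n_projections=200, rng=rng)
print(sw)
```

For a pure translation each slice contributes $|\theta^\top \text{shift}|$, so the sliced value is positive but smaller than the full $W_1 = \|\text{shift}\|_2$. Sliced and full Wasserstein distances are related but not equal, and the estimate varies with the number of projections.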