
Optimal Transport

Optimal transport (OT) is the mathematical theory of moving mass efficiently between distributions. It provides geometrically meaningful distances between probability distributions that, unlike KL divergence, respect the underlying geometry of the sample space. In machine learning, OT appears in generative models (Wasserstein GANs), domain adaptation, fairness, and increasingly as a foundation for modern generative frameworks like flow matching and rectified flows.

The Transport Problem

Given two probability distributions $\mu$ (source) and $\nu$ (target) over a space $\mathcal{X}$:

**Monge formulation (1781):** Find a transport map $T: \mathcal{X} \to \mathcal{X}$ that pushes $\mu$ forward to $\nu$ (i.e., $T_\# \mu = \nu$, meaning $\nu(B) = \mu(T^{-1}(B))$ for all measurable $B$) and minimizes:

$$\inf_{T: T_\# \mu = \nu} \int c(x, T(x)) \, d\mu(x)$$

**Kantorovich formulation (1942):** Relax to a transport plan $\gamma \in \Pi(\mu, \nu)$ -- a joint distribution with marginals $\mu$ and $\nu$ -- that allows mass splitting:

$$\min_{\gamma \in \Pi(\mu, \nu)} \int_{\mathcal{X} \times \mathcal{X}} c(x, y) \, d\gamma(x, y)$$

where $c(x, y)$ is the cost of moving a unit of mass from $x$ to $y$, and $\Pi(\mu, \nu) = \{\gamma : \int \gamma(x, y)\, dy = \mu(x), \int \gamma(x, y)\, dx = \nu(y)\}$.

**Monge vs. Kantorovich.** Monge's formulation requires a deterministic map (each source point goes to exactly one destination), which may not exist (e.g., transporting a point mass to two point masses). Kantorovich's relaxation always has a solution and is a linear program. When the cost is $c(x,y) = \|x-y\|^2$ and $\mu$ is absolutely continuous, the Monge and Kantorovich solutions coincide: the optimal plan is supported on the graph of a map $T^*(x) = \nabla \phi(x)$ for a convex function $\phi$ (Brenier's theorem).

Wasserstein Distance

The **$p$-Wasserstein distance** uses cost $c(x,y) = \|x - y\|^p$:

$$W_p(\mu, \nu) = \left(\inf_{\gamma \in \Pi(\mu, \nu)} \int \|x - y\|^p \, d\gamma(x, y)\right)^{1/p}$$

The 1-Wasserstein distance ($p=1$) is also called the Earth Mover's Distance (EMD): the minimum "work" (mass $\times$ distance) needed to reshape one pile of dirt into another. The 2-Wasserstein distance ($p=2$) has nicer geometric properties and connections to optimal maps.

**Why Wasserstein is better than KL for some tasks.**

| Property | KL Divergence | Wasserstein Distance |
| --- | --- | --- |
| Metric? | No (asymmetric, no triangle ineq.) | Yes (symmetric, triangle ineq.) |
| Non-overlapping support | $D_{\text{KL}} = \infty$ | Finite (uses geometry) |
| Sensitivity to geometry | None (only ratios $p/q$) | Respects distances in $\mathcal{X}$ |
| Gradient quality | Vanishes when supports don't overlap | Provides useful gradients |
| Computation | $O(n)$ (sample-based) | $O(n^3)$ exact, $O(n^2/\epsilon^2)$ entropic |

The key advantage: when $\mu$ and $\nu$ have disjoint supports (common early in GAN training, when the generator produces images far from real data), KL divergence is infinite and provides no gradient signal. Wasserstein distance is finite and provides a meaningful gradient pointing $\mu$ toward $\nu$.

**1D Wasserstein.** In one dimension, the Wasserstein distance has a beautiful closed form:

$$W_p(\mu, \nu) = \left(\int_0^1 |F_\mu^{-1}(t) - F_\nu^{-1}(t)|^p \, dt\right)^{1/p}$$

where $F^{-1}$ is the quantile function (inverse CDF). For $p = 1$: $W_1 = \int_{-\infty}^{\infty} |F_\mu(x) - F_\nu(x)| \, dx$. The optimal transport map is $T = F_\nu^{-1} \circ F_\mu$ (map quantiles to quantiles). This is why 1D optimal transport is $O(n \log n)$ (just sort and pair up).
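The quantile-pairing argument translates directly into code. A minimal sketch (the helper name `wasserstein_1d` is illustrative, not a library function), assuming two equal-size samples so that sorting pairs empirical quantiles:

```python
import numpy as np

def wasserstein_1d(x, y, p=1):
    """Empirical p-Wasserstein distance between equal-size 1D samples.

    Sorting both samples pairs the i-th empirical quantiles, which is the
    optimal 1D coupling; no linear program is needed.
    """
    xs, ys = np.sort(x), np.sort(y)
    return np.mean(np.abs(xs - ys) ** p) ** (1.0 / p)

rng = np.random.default_rng(0)
a = rng.normal(0.0, 1.0, 5000)   # mu = N(0, 1)
b = rng.normal(3.0, 1.0, 5000)   # nu = N(3, 1)
print(wasserstein_1d(a, b))      # close to 3: a pure shift moves every quantile by 3
```

For unequal sample sizes or weighted samples, one would interpolate the two quantile functions on a common grid instead of pairing sorted arrays directly.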

**Wasserstein between Gaussians.** For $\mu = \mathcal{N}(m_1, \Sigma_1)$ and $\nu = \mathcal{N}(m_2, \Sigma_2)$, the 2-Wasserstein distance has a closed form (the **Bures metric**):

$$W_2^2(\mu, \nu) = \|m_1 - m_2\|^2 + \text{tr}\left(\Sigma_1 + \Sigma_2 - 2\left(\Sigma_1^{1/2} \Sigma_2 \Sigma_1^{1/2}\right)^{1/2}\right)$$

For diagonal covariances: $W_2^2 = \|m_1 - m_2\|^2 + \sum_i (\sigma_{1,i} - \sigma_{2,i})^2$. This is used in the FID (Fréchet Inception Distance) metric for evaluating generative models, where the Inception features of real and generated images are modeled as Gaussians.
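The diagonal-covariance formula is easy to sanity-check numerically. A small sketch (helper name is illustrative; the real FID computation fits Gaussians with full covariances to Inception features):

```python
import numpy as np

def w2_sq_diag_gaussians(m1, s1, m2, s2):
    """Squared 2-Wasserstein distance between Gaussians with diagonal covariance.

    m1, m2: mean vectors; s1, s2: per-dimension standard deviations.
    """
    m1, s1, m2, s2 = map(np.asarray, (m1, s1, m2, s2))
    return float(np.sum((m1 - m2) ** 2) + np.sum((s1 - s2) ** 2))

# Pure mean shift: only the ||m1 - m2||^2 term contributes.
print(w2_sq_diag_gaussians([0, 0], [1, 1], [3, 4], [1, 1]))  # 25.0
```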

Kantorovich Duality

The **dual formulation** of the 1-Wasserstein distance is:

$$W_1(\mu, \nu) = \sup_{\|f\|_L \leq 1} \left(\mathbb{E}_{x \sim \mu}[f(x)] - \mathbb{E}_{y \sim \nu}[f(y)]\right)$$

where the supremum is over all 1-Lipschitz functions $f$ (satisfying $|f(x) - f(y)| \leq \|x - y\|$ for all $x, y$).

More generally, for cost $c(x,y)$, the dual is:

$$\sup_{\phi, \psi} \left\{\mathbb{E}_\mu[\phi(x)] + \mathbb{E}_\nu[\psi(y)] : \phi(x) + \psi(y) \leq c(x,y) \; \forall x, y\right\}$$

where $(\phi, \psi)$ are **Kantorovich potentials**. For $c = \|x-y\|$, the constraint forces $\phi = -\psi$ and $\phi$ to be 1-Lipschitz.

**From duality to WGAN.** The Kantorovich dual replaces optimization over transport plans (high-dimensional joint distributions) with optimization over a single function. This is exactly what the **Wasserstein GAN (WGAN)** [@arjovsky2017wgan] exploits:
  • **Critic (discriminator):** Parameterize a 1-Lipschitz function $f_\omega$ and maximize $\mathbb{E}_{x \sim p_{\text{data}}}[f_\omega(x)] - \mathbb{E}_{z \sim p_z}[f_\omega(G_\theta(z))]$.
  • **Generator:** Minimize the same objective (push generated samples closer to real data in the Wasserstein sense).

The Lipschitz constraint is enforced via:

  • **Weight clipping** (original WGAN): Clip $\omega \in [-c, c]$. Simple but biases toward simple functions.
  • **Gradient penalty** (WGAN-GP) [@gulrajani2017improved]: Add $\lambda \mathbb{E}[(\|\nabla f_\omega(\hat{x})\| - 1)^2]$ where $\hat{x}$ interpolates between real and fake samples. More stable.
  • **Spectral normalization** [@miyato2018spectral]: Normalize weight matrices by their spectral norm. Efficient and widely used.
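A numpy sketch of the power-iteration estimate behind spectral normalization (names are illustrative; in practice frameworks keep a persistent `u` across steps and run only one or two iterations per update):

```python
import numpy as np

def spectral_normalize(W, n_iters=50, seed=0):
    """Scale W so its largest singular value (spectral norm) is ~1.

    The spectral norm of each weight matrix bounds that layer's Lipschitz
    constant, so normalizing it pushes the critic toward 1-Lipschitz.
    """
    rng = np.random.default_rng(seed)
    u = rng.normal(size=W.shape[0])
    for _ in range(n_iters):      # power iteration on W W^T
        v = W.T @ u
        v /= np.linalg.norm(v)
        u = W @ v
        u /= np.linalg.norm(u)
    sigma = u @ W @ v             # estimated top singular value
    return W / sigma

W = np.random.default_rng(1).normal(size=(64, 32))
print(np.linalg.norm(spectral_normalize(W), 2))  # ~1.0
```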

Entropic Optimal Transport

Adding an **entropic regularization** term makes OT computationally tractable:

$$W_\epsilon(\mu, \nu) = \min_{\gamma \in \Pi(\mu, \nu)} \sum_{i,j} c_{ij} \gamma_{ij} + \epsilon \sum_{i,j} \gamma_{ij} \log \gamma_{ij}$$

The entropic penalty equals $-\epsilon H(\gamma)$ (with $H(\gamma) = -\sum_{i,j} \gamma_{ij} \log \gamma_{ij}$), so it encourages the transport plan to spread out (higher entropy = more diffuse plan). As $\epsilon \to 0$, $W_\epsilon \to W$ (exact OT). As $\epsilon \to \infty$, the plan approaches the independent coupling $\gamma = \mu \otimes \nu$.

**Input:** Cost matrix $C \in \mathbb{R}^{n \times m}$, marginals $a \in \Delta^{n-1}$, $b \in \Delta^{m-1}$, regularization $\epsilon > 0$

  1. Compute Gibbs kernel: $K_{ij} = \exp(-C_{ij}/\epsilon)$
  2. Initialize: $v = \mathbf{1}_m$
  3. **for** $\ell = 1, 2, \ldots$ until convergence **do**
  4.     $u \leftarrow a \oslash (Kv)$     (row normalization: $u_i = a_i / \sum_j K_{ij} v_j$)
  5.     $v \leftarrow b \oslash (K^\top u)$     (column normalization: $v_j = b_j / \sum_i K_{ij} u_i$)
  6. **end for**

**Output:** Transport plan $\gamma^*_{ij} = u_i K_{ij} v_j$, cost $W_\epsilon = \langle C, \gamma^* \rangle$
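The loop above is a few lines of numpy. A plain-domain sketch for moderate $\epsilon$ (for very small $\epsilon$ the kernel underflows and a log-domain version is needed):

```python
import numpy as np

def sinkhorn(C, a, b, eps=0.05, n_iters=1000):
    """Entropic OT via Sinkhorn scaling iterations (plain domain).

    C: (n, m) cost matrix; a, b: marginal weights summing to 1.
    Returns the transport plan gamma and the transport cost <C, gamma>.
    """
    K = np.exp(-C / eps)          # Gibbs kernel
    v = np.ones_like(b)
    for _ in range(n_iters):
        u = a / (K @ v)           # row normalization
        v = b / (K.T @ u)         # column normalization
    gamma = u[:, None] * K * v[None, :]
    return gamma, float(np.sum(gamma * C))

# Toy problem: two small point clouds on the line.
x = np.linspace(0.0, 1.0, 5)
y = np.linspace(0.2, 1.2, 5)
C = (x[:, None] - y[None, :]) ** 2
a = b = np.full(5, 0.2)
gamma, cost = sinkhorn(C, a, b)
print(gamma.sum(axis=1))  # close to a: the marginal constraints are satisfied
```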

**Sinkhorn algorithm properties:**
  • **Convergence:** Linear rate $O((1-\delta)^\ell)$ where $\delta$ depends on $\epsilon$ and the cost matrix. Smaller $\epsilon$ means slower convergence but more accurate OT.
  • **GPU-friendly:** Each iteration is a matrix-vector multiply $Kv$ plus elementwise operations -- perfectly suited for GPU parallelism. For $n$ points, each iteration is $O(n^2)$.
  • **Differentiable:** Both the forward and backward passes are differentiable, enabling end-to-end learning with OT losses. Automatic differentiation through the Sinkhorn iterations is straightforward.
  • **Log-domain stability:** For small $\epsilon$, the kernel $K$ has entries that underflow to 0 and the scalings $u, v$ can overflow. Computing with $\log u, \log v, \log K$ avoids numerical underflow/overflow.
  • **Minibatch Sinkhorn:** For large-scale problems, compute Sinkhorn on random minibatches of source and target points. This gives a biased but useful estimator.
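The log-domain variant can be sketched by iterating on dual potentials $f = \epsilon \log u$, $g = \epsilon \log v$ rather than the scalings themselves (function names are illustrative):

```python
import numpy as np

def sinkhorn_log(C, a, b, eps=0.001, n_iters=200):
    """Log-domain Sinkhorn: stable even when exp(-C/eps) underflows.

    Iterates dual potentials f, g using logsumexp instead of scaling vectors.
    """
    def logsumexp(M, axis):
        mx = M.max(axis=axis, keepdims=True)
        return (mx + np.log(np.exp(M - mx).sum(axis=axis, keepdims=True))).squeeze(axis)

    f, g = np.zeros_like(a), np.zeros_like(b)
    for _ in range(n_iters):
        f = eps * (np.log(a) - logsumexp((g[None, :] - C) / eps, axis=1))
        g = eps * (np.log(b) - logsumexp((f[:, None] - C) / eps, axis=0))
    return np.exp((f[:, None] + g[None, :] - C) / eps)  # transport plan

x = np.linspace(0.0, 1.0, 4)
C = (x[:, None] - x[None, :]) ** 2
a = b = np.full(4, 0.25)
gamma = sinkhorn_log(C, a, b)  # eps this small would underflow in plain domain
print(np.round(gamma, 3))      # close to 0.25 * I: near-exact OT fixes each point
```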

Optimal Transport Maps and Displacement Interpolation

Given the optimal transport map $T^*$ from $\mu$ to $\nu$, the **displacement interpolation** (McCann interpolation) defines a geodesic path between distributions:

$$\mu_t = ((1-t)\,\text{Id} + tT^*)_\# \mu, \quad t \in [0, 1]$$

At $t = 0$: $\mu_0 = \mu$. At $t = 1$: $\mu_1 = \nu$. Each intermediate $\mu_t$ moves mass along the optimal transport paths, creating a natural interpolation in distribution space.
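In 1D, displacement interpolation is again just sorting. A sketch (illustrative helper, equal-size samples):

```python
import numpy as np

def displacement_interp_1d(x, y, t):
    """Samples from the displacement interpolation mu_t between 1D samples.

    Sorting pairs each source point with its optimal destination; every
    point then travels a fraction t along its straight transport path.
    """
    return (1 - t) * np.sort(x) + t * np.sort(y)

rng = np.random.default_rng(0)
mu = rng.normal(-2.0, 0.5, 2000)
nu = rng.normal(2.0, 0.5, 2000)
mid = displacement_interp_1d(mu, nu, 0.5)
print(mid.mean())  # close to 0: a single mode halfway between, not a bimodal mixture
```

Contrast this with mixture averaging $0.5\mu + 0.5\nu$, which keeps both modes; the displacement midpoint physically moves the mass.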

**Displacement interpolation in generative models.** Flow matching [@lipman2023flow] and rectified flows [@liu2023flow] learn the velocity field of the displacement interpolation:

$$x_t = (1-t)x_0 + t x_1, \quad \text{where } x_0 \sim \mu \text{ (noise)}, \; x_1 \sim \nu \text{ (data)}$$

The conditional velocity field $v_t = x_1 - x_0$ produces straight-line trajectories from noise to data. Training a model $v_\theta(x_t, t) \approx x_1 - x_0$ with an MSE loss gives the flow matching objective. This is simpler than diffusion (no noise schedule, no SDE/ODE theory) and produces straighter flows that require fewer sampling steps.

The connection to OT: if $(x_0, x_1)$ are coupled via the optimal transport plan (rather than sampled independently), the trajectories are non-crossing and the learned velocity field is smoother, leading to better generation quality.
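Building flow-matching training targets takes a few lines; this sketch uses the independent coupling (swapping in an OT coupling would change only how the rows of `x0` and `x1` are paired; the helper name is illustrative):

```python
import numpy as np

def flow_matching_batch(x0, x1, rng):
    """Build (x_t, t, velocity target) triples for the flow matching loss.

    x0: noise minibatch, x1: data minibatch, paired row by row (the coupling).
    A network v_theta(x_t, t) would regress onto v_target with MSE.
    """
    t = rng.uniform(size=(x0.shape[0], 1))
    xt = (1 - t) * x0 + t * x1    # point on the straight-line path
    v_target = x1 - x0            # constant velocity of that path
    return xt, t, v_target

rng = np.random.default_rng(0)
x0 = rng.normal(size=(128, 2))             # noise
x1 = rng.normal(loc=5.0, size=(128, 2))    # stand-in for data
xt, t, v = flow_matching_batch(x0, x1, rng)
```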

Wasserstein Barycenters

The **Wasserstein barycenter** of distributions $\mu_1, \ldots, \mu_K$ with weights $\lambda_1, \ldots, \lambda_K$ is:

$$\bar{\mu} = \arg\min_\mu \sum_{k=1}^K \lambda_k W_2^2(\mu, \mu_k)$$

The barycenter is the "average" distribution in the Wasserstein sense, which preserves geometric structure better than mixture-based averaging.

**Applications:** Wasserstein barycenters are used for texture mixing, shape interpolation, multi-source domain adaptation (average the source domains into a barycenter, then adapt to the target), and aggregating predictions from multiple generative models. The entropic-regularized barycenter can be computed via iterative Sinkhorn projections.
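In 1D no iterative machinery is needed: the $W_2$ barycenter's quantile function is the weighted average of the inputs' quantile functions, so averaging sorted samples suffices (illustrative helper, equal sample sizes):

```python
import numpy as np

def barycenter_1d(samples, weights):
    """W2 barycenter of 1D empirical distributions with equal sample sizes.

    Averages the sorted samples, i.e. the empirical quantile functions.
    """
    return np.average([np.sort(s) for s in samples], axis=0, weights=weights)

rng = np.random.default_rng(0)
bary = barycenter_1d([rng.normal(-3, 1, 1000), rng.normal(3, 1, 1000)], [0.5, 0.5])
# The barycenter of N(-3,1) and N(3,1) is close to N(0,1): a single mode,
# unlike the bimodal 50/50 mixture.
print(bary.mean(), bary.std())
```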

ML Applications

| Application | How OT is Used | Key Advantage |
| --- | --- | --- |
| Wasserstein GAN | Critic approximates $W_1$ via Kantorovich duality | Stable training, meaningful loss |
| Flow matching | OT coupling gives straight trajectories | Few-step generation |
| Domain adaptation | Minimize $W_2$ between source and target features | Geometry-aware alignment |
| FID score | $W_2^2$ between Gaussians fit to Inception features | Standard generative model metric |
| Distributional RL | Model return as a distribution; use $W_p$ for comparison | Preserves return distribution shape |
| Fairness | $W_1$ between demographic group distributions | Measure and mitigate distributional disparity |
| Data augmentation | Displacement interpolation between classes | Meaningful between-class interpolation |
| Sliced Wasserstein | $\mathbb{E}_\theta[W_1(\text{proj}_\theta \mu, \text{proj}_\theta \nu)]$ | Scalable: 1D OT on random projections |
| Graph matching | OT between node feature distributions | Permutation-invariant comparison |
| Point cloud registration | OT map aligns two point sets | Correspondence without labels |
**Computational complexity of OT.**

| Method | Complexity | Accuracy | Use Case |
| --- | --- | --- | --- |
| Exact (linear program) | $O(n^3 \log n)$ | Exact | Small problems ($n < 1000$) |
| Sinkhorn (entropic) | $O(n^2 / \epsilon^2)$ | $\epsilon$-approximate | Medium problems, differentiable |
| Sliced Wasserstein | $O(Ln\log n)$ ($L$ projections) | Approximation | Large-scale, any dimension |
| Minibatch Sinkhorn | $O(b^2)$ per batch | Biased estimator | Very large scale |
| Neural OT | $O(\text{forward pass})$ | Amortized | Continuous distributions |

For high-dimensional problems, sliced Wasserstein computes $W_1$ on random 1D projections (where OT is just sorting) and averages. This avoids the curse of dimensionality and scales to millions of points.
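A Monte Carlo sketch of sliced $W_1$ (helper name illustrative, equal-size point clouds assumed):

```python
import numpy as np

def sliced_wasserstein(X, Y, n_proj=100, seed=0):
    """Monte Carlo sliced 1-Wasserstein between equal-size point clouds.

    Each random direction reduces OT to a sorted 1D pairing; results are
    averaged over n_proj directions.
    """
    rng = np.random.default_rng(seed)
    total = 0.0
    for _ in range(n_proj):
        theta = rng.normal(size=X.shape[1])
        theta /= np.linalg.norm(theta)                   # random unit direction
        px, py = np.sort(X @ theta), np.sort(Y @ theta)  # 1D projections
        total += np.mean(np.abs(px - py))                # 1D W1 by sorting
    return total / n_proj

rng = np.random.default_rng(1)
X = rng.normal(size=(500, 10))
print(sliced_wasserstein(X, X))        # 0.0 for identical clouds
print(sliced_wasserstein(X, X + 2.0))  # positive under a shift
```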

Notation Summary

| Symbol | Meaning |
| --- | --- |
| $\mu, \nu$ | Source and target distributions |
| $\gamma$ | Transport plan (coupling) |
| $T$ | Transport map (Monge) |
| $T_\# \mu$ | Pushforward of $\mu$ by $T$ |
| $\Pi(\mu, \nu)$ | Set of all couplings with marginals $\mu, \nu$ |
| $c(x, y)$ | Transport cost function |
| $W_p$ | $p$-Wasserstein distance |
| $\epsilon$ | Entropic regularization parameter |
| $K$ | Gibbs kernel: $K_{ij} = e^{-c_{ij}/\epsilon}$ |
| $u, v$ | Sinkhorn scaling vectors |
| $\phi, \psi$ | Kantorovich dual potentials |
| $\bar{\mu}$ | Wasserstein barycenter |
| EMD | Earth Mover's Distance ($W_1$) |
| FID | Fréchet Inception Distance (Bures-$W_2$) |