Wasserstein Descent \( \dot{\mathbb{H}}^1 \)-Ascent (WDHA) Documentation

Welcome to the WDHA documentation! This documentation describes the implementation of the Wasserstein Descent \( \dot{\mathbb{H}}^1 \)-Ascent (WDHA) algorithm for computing the optimal transport barycenter (a.k.a. Wasserstein barycenter), introduced in our paper:

"Optimal Transport Barycenter via Nonconvex-Concave Minimax Optimization" [Link]
Kaheon Kim, Rentian Yao, Changbo Zhu, and Xiaohui Chen
International Conference on Machine Learning (ICML), 2025

WDHA Algorithm

1. Introduction

1.1. Formulation

We are interested in finding the Wasserstein barycenter \(\bar{\mu}\) of given probability densities \(\{\mu_1,\dots,\mu_n\}\) on a domain \(\Omega\):

\[\bar{\mu} = \arg\min_{\nu\in\mathcal{P}(\Omega)}\frac{1}{n}\sum_{i=1}^n\mathcal{W}_2^2(\nu,\mu_i)\]
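
The objective is easiest to see in one dimension, where the barycenter has a closed form: the quantile function of \(\bar{\mu}\) is the average of the quantile functions of \(\mu_1,\dots,\mu_n\). The sketch below applies this fact to empirical measures that have the same number of equally weighted samples. It is an illustration only, not part of the WDHA code, and the helper names are hypothetical.

```python
import numpy as np

def w2_squared_1d(a, b):
    """Squared 2-Wasserstein distance between two 1D empirical measures with the
    same number of equally weighted atoms (optimal coupling = sorted matching)."""
    return float(np.mean((np.sort(a) - np.sort(b)) ** 2))

def barycenter_1d(samples):
    """Barycenter of equally weighted 1D empirical measures: average the sorted
    samples, i.e. average the empirical quantile functions."""
    return np.mean([np.sort(s) for s in samples], axis=0)

def objective(nu, samples):
    """Barycenter objective (1/n) * sum_i W_2^2(nu, mu_i)."""
    return float(np.mean([w2_squared_1d(nu, s) for s in samples]))

rng = np.random.default_rng(0)
# Three Gaussian samples with different means and spreads
mus = [rng.normal(loc=m, scale=s, size=2000) for m, s in [(-2, 0.5), (1, 1.0), (4, 2.0)]]
bary = barycenter_1d(mus)

print(objective(bary, mus))        # objective at the barycenter
print(objective(bary + 0.3, mus))  # a shifted candidate has a larger objective
```

Shifting the barycenter by \(c\) increases the objective by exactly \(c^2\). This holds because, summed over \(i\), the differences between the barycenter's sorted samples and those of \(\mu_i\) cancel.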

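Combining this with the Kantorovich dual formulation, we cast the Wasserstein barycenter problem as a nonconvex-concave minimax problem between the Wasserstein space and a Sobolev space. The inner maximum below equals \(\frac{1}{2}\mathcal{W}_2^2(\nu,\mu_i)\), and this constant factor does not change the minimizer: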

\[ \begin{align*} \bar{\mu} &= \arg\min_{\nu\in\mathcal{P}(\Omega)}\frac{1}{n}\sum_{i=1}^n\mathcal{W}_2^2(\nu,\mu_i) \\ &= \arg\min_{\nu\in\mathcal{P}(\Omega)}\frac{1}{n}\sum_{i=1}^n\max_{\varphi_i : \text{convex}} \underbrace{ \int\left(\frac{\|x\|_2^2}{2}-\varphi_i(x)\right) d\nu(x) + \int\left(\frac{\|y\|_2^2}{2}-\varphi_i^*(y)\right) d\mu_i(y) }_{\mathcal{I}_{\nu}^{\mu_i}(\varphi_i)} \end{align*} \]
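
The following is a quick numerical check of the dual formulation. It assumes, for simplicity, that \(\mu\) is a translate of \(\nu\) by a vector \(b\). In that case the optimal potential is \(\varphi(x)=\frac{\|x\|_2^2}{2}+\langle b,x\rangle\), its conjugate is \(\varphi^*(y)=\frac{\|y-b\|_2^2}{2}\), and \(\mathcal{I}_{\nu}^{\mu}(\varphi)\) equals \(\frac{\|b\|_2^2}{2}=\frac{1}{2}\mathcal{W}_2^2(\nu,\mu)\). The sketch estimates \(\mathcal{I}\) from samples. The function name dual_objective is hypothetical and is not part of the WDHA code.

```python
import numpy as np

def dual_objective(x_nu, y_mu, phi, phi_star):
    """Sample estimate of I_nu^mu(phi) = int(|x|^2/2 - phi) dnu + int(|y|^2/2 - phi*) dmu."""
    term_nu = np.mean(0.5 * np.sum(x_nu**2, axis=1) - phi(x_nu))
    term_mu = np.mean(0.5 * np.sum(y_mu**2, axis=1) - phi_star(y_mu))
    return float(term_nu + term_mu)

rng = np.random.default_rng(1)
b = np.array([0.3, -0.2])
x_nu = rng.uniform(size=(5000, 2))  # samples of nu (uniform on the unit square)
y_mu = x_nu + b                     # mu = nu translated by b, so W_2^2(nu, mu) = |b|^2

def phi(x):
    """Optimal convex potential for a translation: grad phi(x) = x + b."""
    return 0.5 * np.sum(x**2, axis=1) + x @ b

def phi_star(y):
    """Convex conjugate of phi."""
    return 0.5 * np.sum((y - b) ** 2, axis=1)

value = dual_objective(x_nu, y_mu, phi, phi_star)
print(value, 0.5 * b @ b)  # equal up to floating-point rounding: I = W_2^2 / 2
```

A non-optimal convex potential such as \(\varphi(x)=\frac{\|x\|_2^2}{2}\) gives a smaller value (here 0), as expected for the inner maximization.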

1.2. Algorithm

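To tackle this minimax problem, we introduce the Wasserstein Descent \( \dot{\mathbb{H}}^1 \)-Ascent (WDHA) algorithm. Write \(\mathcal{J}(\nu,\boldsymbol{\varphi}) = \frac{1}{n}\sum_{i=1}^n \mathcal{I}_{\nu}^{\mu_i}(\varphi_i)\) with \(\boldsymbol{\varphi}=(\varphi_1,\dots,\varphi_n)\). WDHA then takes gradient descent steps in \(\nu\) with respect to the Wasserstein geometry and gradient ascent steps in each \(\varphi_i\) with respect to the \(\dot{\mathbb{H}}^1\) geometry: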

WDHA Algorithm

  • Wasserstein Gradient: \(\boldsymbol{\nabla} \mathcal{J}(\nu, \boldsymbol{\varphi}) = \text{id} - \nabla \bar{\varphi}, \text{ where } \bar{\varphi} = \frac{1}{n}\sum_{i=1}^n \varphi_i\)
  • \(\dot{\mathbb{H}}^1\)-Gradient: \(\nabla^{\dot{\mathbb{H}}^1}_{\varphi_i} \mathcal{J}(\nu, \boldsymbol{\varphi}) = \frac{1}{n}(-\Delta)^{-1}\left(-\nu + (\nabla \varphi_i^*)_\# \mu_i\right)\)
  • Convex Hull: \( (\cdot)^{**}\), where \( \varphi^*(y) = \sup_{x\in \Omega} \left\{ \langle x,y \rangle-\varphi(x) \right\}\) is the convex conjugate of \(\varphi:\Omega\rightarrow \mathbb{R}\)
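
The \(\dot{\mathbb{H}}^1\)-gradient requires applying \((-\Delta)^{-1}\) to \(-\nu + (\nabla\varphi_i^*)_\#\mu_i\). This is a zero-mean function because it is the difference of two probability densities. A common way to invert the Laplacian on a grid is with the FFT. The sketch below assumes periodic boundary conditions on the unit square and a 5-point finite-difference Laplacian. These choices, and the function names, are assumptions made for illustration and may differ from the WDHA code.

```python
import numpy as np

def inverse_neg_laplacian_periodic(f):
    """Solve (-Delta) u = f on the periodic unit square, discretized by an
    n2 x n1 grid and the 5-point finite-difference Laplacian (assumption).

    f should have zero mean; the constant Fourier mode of u is set to 0,
    so the returned u has zero mean.
    """
    n2, n1 = f.shape
    # Eigenvalues of the 1D discrete -d^2/dx^2 (spacing 1/n) for each Fourier mode
    lam1 = 2.0 * n1**2 * (1.0 - np.cos(2.0 * np.pi * np.arange(n1) / n1))
    lam2 = 2.0 * n2**2 * (1.0 - np.cos(2.0 * np.pi * np.arange(n2) / n2))
    lam = lam2[:, None] + lam1[None, :]  # eigenvalues of the 2D operator
    f_hat = np.fft.fft2(f)
    u_hat = np.zeros_like(f_hat)
    nonzero = lam > 0                    # every mode except the constant one
    u_hat[nonzero] = f_hat[nonzero] / lam[nonzero]
    return np.real(np.fft.ifft2(u_hat))

def neg_laplacian_periodic(u):
    """Apply the periodic 5-point discrete -Delta (used here only as a check)."""
    n2, n1 = u.shape
    d1 = 2.0 * u - np.roll(u, 1, axis=1) - np.roll(u, -1, axis=1)
    d2 = 2.0 * u - np.roll(u, 1, axis=0) - np.roll(u, -1, axis=0)
    return n1**2 * d1 + n2**2 * d2

# Check: applying -Delta to the solution recovers a zero-mean right-hand side
rng = np.random.default_rng(0)
f = rng.standard_normal((64, 64))
f -= f.mean()
u = inverse_neg_laplacian_periodic(f)
print(np.max(np.abs(neg_laplacian_periodic(u) - f)))  # residual at round-off level
```

An ascent step on \(\varphi_i\) would then add a step size times \(\frac{1}{n}\) times this solution, with \(f = -\nu + (\nabla\varphi_i^*)_\#\mu_i\) evaluated on the grid.
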
More details on the formulation and the algorithm can be found in Section 3 (Nonconvex-Concave Minimax Formulation for Optimal Transport Barycenter) of the paper.
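
The convex conjugate and the convex hull \((\cdot)^{**}\) can be illustrated with a brute-force discrete Legendre transform on a 1D grid. This \(O(N^2)\) sketch is for illustration only. The WDHA code relies on the c-transform routine from the BFM package rather than on this function, whose name is hypothetical.

```python
import numpy as np

def legendre_transform(x, phi, y):
    """Discrete convex conjugate: phi*(y_j) = max_k (x_k * y_j - phi(x_k)).

    x, phi: grid points and values of phi; y: points where phi* is evaluated.
    Brute force over all pairs, so the cost is O(len(x) * len(y)).
    """
    return np.max(np.outer(y, x) - phi[None, :], axis=1)

x = np.linspace(-2.0, 2.0, 401)     # primal grid on Omega = [-2, 2]
s = np.linspace(-25.0, 25.0, 5001)  # dual grid, wide enough to hold every slope of phi
phi = (x**2 - 1.0) ** 2             # nonconvex double-well potential

phi_star = legendre_transform(x, phi, s)       # phi* on the dual grid
phi_hull = legendre_transform(s, phi_star, x)  # phi** back on the primal grid
```

Because \(\varphi^{**}\) is the largest convex function below \(\varphi\), the hump of the double well on \([-1,1]\) is replaced by a flat segment at height 0. The dual grid must contain the slopes of \(\varphi\). Otherwise the result is still convex and below \(\varphi\), but it is no longer the convex hull.
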

2. Implementation

The code is written in Python and builds on the core functionality (c-transform and pushforward measure) provided by Flavien Leger in the BFM package. The code is also available in the GitHub repository as .ipynb and .py files.


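  # Install the BFM package (c-transform and pushforward routines) and POT (Python Optimal Transport)
  !git clone https://github.com/Math-Jacobs/bfm
  !pip install bfm/python
  !pip install pot
  # Download the WDHA helper modules
  !wget -O functions.py https://raw.githubusercontent.com/kaheonkim/WDHA/main/implementation2D/functions.py
  !wget -O metric.py https://raw.githubusercontent.com/kaheonkim/WDHA/main/implementation2D/metric.py

  import numpy as np
  # Helpers used below (frechet_mean, plotting, plotting_mnist, ...)
  from metric import *
  from functions import *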
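
To make the role of the pushforward concrete, here is a naive sketch. It pushes a density on the cell-centred grid used in the examples through a map \(T\) by depositing each cell's mass into the cell that contains its image (nearest-cell deposit via numpy.histogram2d). This is an illustration under stated assumptions, and the function name is hypothetical. The WDHA code uses the pushforward routine from the BFM package instead of this function.

```python
import numpy as np

def pushforward_grid(density, Tx, Ty):
    """Push a density on the cell-centred n2 x n1 grid of [0, 1]^2 through the
    map (x, y) -> (Tx, Ty), sending each cell's value to the cell containing its
    image. Cells have equal area, so summing density values conserves mass.
    Images outside [0, 1]^2 are clipped to the boundary cells."""
    n2, n1 = density.shape
    tx = np.clip(Tx, 0.0, 1.0).ravel()
    ty = np.clip(Ty, 0.0, 1.0).ravel()
    # histogram2d bins its first argument along axis 0, so pass (ty, tx)
    # to keep the (row = y, column = x) layout produced by np.meshgrid
    pushed, _, _ = np.histogram2d(ty, tx, bins=[n2, n1],
                                  range=[[0.0, 1.0], [0.0, 1.0]],
                                  weights=density.ravel())
    return pushed

n1, n2 = 64, 64
x, y = np.meshgrid(np.linspace(0.5/n1, 1-0.5/n1, n1),
                   np.linspace(0.5/n2, 1-0.5/n2, n2))
nu = np.zeros((n2, n1))
nu[(x - 0.3)**2 + (y - 0.3)**2 < 0.1**2] = 1
nu *= n1 * n2 / np.sum(nu)  # same normalization as in the examples below

# Push nu through the translation T(x, y) = (x + 0.25, y + 0.25), i.e. 16 cells
pushed = pushforward_grid(nu, x + 0.25, y + 0.25)
```

For a Wasserstein descent step, one could take \(T = \text{id} - \tau\,(\text{id} - \nabla\bar{\varphi})\) with a step size \(\tau\) and \(\nabla\bar{\varphi}\) evaluated on the grid. This form of the update is shown only as an illustration of the Wasserstein gradient listed in Section 1.2.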

2.1. Example 1: Uniform Distributions with Varying Support

Four distinct shapes are placed at four different locations: a square at the top-left, a heart at the bottom-left, a cross at the bottom-right, and a circle at the top-right. These shapes represent four uniform probability densities with distinct supports. We demonstrate the application of WDHA on these four synthetic uniform distributions.


  n1, n2 = 1024, 1024
  # Cell-centred grid on [0, 1]^2; x varies along columns, y along rows
  x, y = np.meshgrid(np.linspace(0.5/n1, 1-0.5/n1, n1),
                     np.linspace(0.5/n2, 1-0.5/n2, n2))
  r = 0.1
  # Circle of radius r centred at (0.8, 0.8) (top-right)
  mu1 = np.zeros((n2, n1))
  mu1[(x-0.8)**2 + (y-0.8)**2 < r**2] = 1
  # Cross centred at (0.8, 0.3) (bottom-right): a vertical and a horizontal bar
  mu2 = np.zeros((n2, n1))
  mu2[(0.8-r/2.5 < x) & (x < 0.8+r/2.5) & (0.3-r < y) & (y < 0.3+r)] = 1
  mu2[(0.3-r/2.5 < y) & (y < 0.3+r/2.5) & (0.8-r < x) & (x < 0.8+r)] = 1
  
  # Normalize to mean 1 on the grid, i.e. total mass 1 on [0, 1]^2
  mu1 *= n1*n2 / np.sum(mu1)
  mu2 *= n1*n2 / np.sum(mu2)
  
  # Heart centred near (0.2, 0.3) (bottom-left)
  heart = np.zeros((n2, n1))
  heart[((10*x-2)**2+(10*(y-0.3))**2-1)**3 - (10*x-2)**2*(10*(y-0.3))**3 < 0] = 1
  heart *= n1 * n2 / np.sum(heart)
  
  # Square [0.1, 0.3] x [0.7, 0.9] (top-left)
  rectangle = np.zeros((n2, n1))
  rectangle[(x < 0.3) & (x > 0.1) & (y > 0.7) & (y < 0.9)] = 1
  rectangle *= n1*n2 / np.sum(rectangle)
  
  mu = [mu1, mu2, heart, rectangle]
  # Plot the four input densities
  plotting(mu, np.zeros((n2,n1)), '_', save_option=False)
4 Uniform Densities
The Wasserstein barycenter of the four uniform distributions, computed by WDHA:

  # Run 300 WDHA iterations and return the computed barycenter
  mu_WDHA = frechet_mean(mu, 300, 'MU', save_option=False, return_option=True)
Wasserstein Barycenter computed by WDHA
You can run the code for this example directly in a Google Colab notebook.

2.2. Example 2: High-resolution Handwritten Digit Images

WDHA is applied to the barycenter problem for 100 high-resolution (500×500) images of the digit 8, provided by Cédric Beaulac and Jeffrey S. Rosenthal in their article "Analysis of a high-resolution hand-written digits data set with writer characteristics". The dataset files, Images(500x500).npy and WriterInfo.npy, are available via a Google Drive link.


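  import numpy as np
  import matplotlib.pyplot as plt
  from google.colab import drive

  # Mount Google Drive (Colab) so that the dataset files can be loaded
  drive.mount('/content/drive')

  Images = np.load('/content/drive/My Drive/WDHA/Images(500x500).npy')
  WriterInfo = np.load('/content/drive/My Drive/WDHA/WriterInfo.npy')
  digit = WriterInfo[:, 0]   # digit label of each image
  user = WriterInfo[:, -1]   # writer (user) information, not used below
  num_image = 100
  num_iter = 300
  # Keep the first 100 images of the digit 8; invert gray levels (255 - pixel)
  numbers8 = 255 - Images[(digit == 8)][:num_image].astype('float64')
  
  # Rescale each image to total mass 500 * 500 (mean 1), as in Example 1
  for j in range(num_image):
      numbers8[j] /= np.sum(numbers8[j])
      numbers8[j] *= 500 * 500
  del Images, WriterInfo, user, digit   # free memory

  fig, axes = plt.subplots(1, 3, figsize=(10, 4))  # 1 row, 3 columns
  
  # Show three of the selected images
  plotting_mnist(numbers8[0], '', ax=axes[0])
  plotting_mnist(numbers8[1], '', ax=axes[1])
  plotting_mnist(numbers8[2], '', ax=axes[2])
  
  plt.tight_layout()
  plt.show()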
Exemplary Digits 8
The Wasserstein barycenter of the 100 handwritten images of the digit 8, computed by WDHA:

  bary8 = frechet_mean(numbers8, num_iter, 'mnist', plot_option=False, save_option=False, return_option=True)
  plotting_mnist(bary8, '')
Wasserstein Barycenter computed by WDHA
A detailed analysis and a comparison with Sinkhorn-type methods can be found in Section 4 (Numerical Studies) of the paper.