Wasserstein Descent \( \dot{\mathbb{H}}^1 \)-Ascent (WDHA) Documentation

Welcome to the WDHA documentation! This documentation describes the implementation of the Wasserstein Descent \( \dot{\mathbb{H}}^1 \)-Ascent (WDHA) algorithm for computing the optimal transport barycenter (a.k.a. Wasserstein barycenter), introduced in our paper:

"Optimal Transport Barycenter via Nonconvex-Concave Minimax Optimization" [Link]
Kaheon Kim, Rentian Yao, Changbo Zhu, and Xiaohui Chen
International Conference on Machine Learning (ICML), 2025

1. Introduction

1.1. Formulation

We are interested in finding the Wasserstein barycenter \(\bar{\mu}\) of given probability densities \(\left\{\mu_1,\cdots,\mu_n\right\}\) on a domain \(\Omega\):

\[\bar{\mu} = \arg\min_{\nu\in\mathcal{P}(\Omega)}\frac{1}{n}\sum_{i=1}^n\mathcal{W}_2^2(\nu,\mu_i)\]
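In one dimension, the barycenter problem above has a well-known closed form: the quantile function of \(\bar{\mu}\) is the average of the quantile functions of the \(\mu_i\). The minimal sketch below uses this standard fact as a ground-truth check; it is not WDHA. It assumes equally weighted empirical measures that all have the same number of atoms, and the function names are illustrative.

```python
import numpy as np


def w2_squared_1d(a, b):
    """Squared 2-Wasserstein distance between two 1D empirical measures with the
    same number of equally weighted atoms: the sorted (monotone) matching is optimal."""
    a = np.sort(np.asarray(a, dtype=float))
    b = np.sort(np.asarray(b, dtype=float))
    return float(np.mean((a - b) ** 2))


def barycenter_1d(samples):
    """Equal-weight 1D Wasserstein barycenter of empirical measures with m atoms each:
    average the sorted samples, i.e. the empirical quantile functions."""
    sorted_samples = np.sort(np.asarray(samples, dtype=float), axis=1)
    return sorted_samples.mean(axis=0)


def barycenter_objective(nu, samples):
    """(1/n) * sum_i W_2^2(nu, mu_i), the objective minimized by the barycenter."""
    return float(np.mean([w2_squared_1d(nu, s) for s in samples]))


mus = [[0.0, 1.0, 2.0], [10.0, 12.0, 11.0]]
bar = barycenter_1d(mus)
print(bar)                                   # [5. 6. 7.]
print(barycenter_objective(bar, mus))        # 25.0
print(barycenter_objective(bar + 0.5, mus))  # 25.25: a shifted candidate does worse
```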

Combining this with the Kantorovich dual formulation, we cast the Wasserstein barycenter problem as a nonconvex-concave minimax problem over the Wasserstein space (for \(\nu\)) and the Sobolev space \(\dot{\mathbb{H}}^1\) (for the potentials \(\varphi_i\)).

\[ \begin{align*} \bar{\mu} &= \arg\min_{\nu\in\mathcal{P}(\Omega)}\frac{1}{n}\sum_{i=1}^n\mathcal{W}_2^2(\nu,\mu_i) \\ &= \arg\min_{\nu\in\mathcal{P}(\Omega)}\frac{1}{n}\sum_{i=1}^n\max_{\varphi_i : \text{convex}} \underbrace{ \int\left(\frac{\|x\|_2^2}{2}-\varphi_i(x)\right) d\nu(x) + \int\left(\frac{\|y\|_2^2}{2}-\varphi_i^*(y)\right) d\mu_i(y) }_{\mathcal{I}_{\nu}^{\mu_i}(\varphi_i)} \end{align*} \]
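To make the dual objective concrete, the sketch below evaluates \(\mathcal{I}_{\nu}^{\mu}(\varphi)\) for discrete measures on a 1D grid and computes \(\varphi^*\) by brute force over the grid. If \(\mu\) is a translate of \(\nu\) by \(d\), the potential \(\varphi(x) = \|x\|_2^2/2 + d\,x\) attains the value \(d^2/2\). This reflects the fact that, with this normalization, the inner maximum equals \(\tfrac12\mathcal{W}_2^2(\nu,\mu)\), which leaves the arg min unchanged. The grid, the measures, and the helper names are illustrative assumptions.

```python
import numpy as np


def conjugate(phi, x, y):
    """Discrete convex conjugate: phi*(y_j) = max_k (x_k * y_j - phi(x_k)) over the grid x."""
    return np.max(np.outer(y, x) - phi[None, :], axis=1)


def dual_objective(phi, x, nu, mu):
    """I_nu^mu(phi) for probability vectors nu, mu on the 1D grid x (phi* is evaluated on x too)."""
    phi_star = conjugate(phi, x, x)
    return float(np.sum((x**2 / 2 - phi) * nu) + np.sum((x**2 / 2 - phi_star) * mu))


h = 0.01
x = -1.0 + h * np.arange(201)        # uniform grid on [-1, 1]
nu = np.zeros_like(x)
nu[80:101] = 1.0                     # uniform weights on [-0.2, 0.0]
nu /= nu.sum()
s = 30                               # shift by s grid cells, i.e. d = 0.3
d = s * h
mu = np.roll(nu, s)                  # exact translate: nu vanishes near x = 1, so nothing wraps

phi_opt = x**2 / 2 + d * x           # in the continuum, phi*(y) = (y - d)^2 / 2
print(dual_objective(phi_opt, x, nu, mu), d**2 / 2)  # both approximately 0.045 = W_2^2 / 2

# Any other potential gives a value no larger than d^2 / 2 (weak duality).
rng = np.random.default_rng(0)
phi_rand = x**2 / 2 + 0.05 * rng.standard_normal(x.size)
print(dual_objective(phi_rand, x, nu, mu) <= d**2 / 2)  # True
```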

1.2. Algorithm

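To tackle this minimax problem, we write \(\mathcal{J}(\nu, \boldsymbol{\varphi}) = \frac{1}{n}\sum_{i=1}^n \mathcal{I}_{\nu}^{\mu_i}(\varphi_i)\) for the objective, with \(\boldsymbol{\varphi} = (\varphi_1,\dots,\varphi_n)\). We then introduce the Wasserstein Descent \( \dot{\mathbb{H}}^1 \)-Ascent (WDHA) algorithm, which performs gradient descent in \(\nu\) with respect to the Wasserstein geometry and gradient ascent in each \(\varphi_i\) with respect to the \(\dot{\mathbb{H}}^1\) geometry: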

WDHA Algorithm

  • Wasserstein Gradient: \(\boldsymbol{\nabla} \mathcal{J}(\nu, \boldsymbol{\varphi}) = \text{id} - \nabla \bar{\varphi}, \text{ where } \bar{\varphi} = \frac{1}{n}\sum_{i=1}^n \varphi_i\)
  • \(\dot{\mathbb{H}}^1\)-Gradient: \(\nabla^{\dot{\mathbb{H}}^1}_{\varphi_i} \mathcal{J}(\nu, \boldsymbol{\varphi}) = \frac{1}{n}(-\Delta)^{-1}\left(-\nu + (\nabla \varphi_i^*)_\# \mu_i\right)\)
  • Convex Hull: \( (\cdot)^{**}\), where \( \varphi^*(y) = \sup_{x\in \Omega} \left\{\left< x,y \right>-\varphi(x)\right\}\) is the convex conjugate of \(\varphi:\Omega\rightarrow \mathbb{R}\)
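The convex-hull step \((\cdot)^{**}\) can be illustrated by brute force in 1D. Taking the discrete conjugate twice replaces a nonconvex \(\varphi\) by its convex envelope on the grid. This quadratic-cost sketch is for intuition only: the package builds on the c-transform from BFM (Section 2). The grid, the slope range, and the function names are assumptions.

```python
import numpy as np


def conjugate(f, x, y):
    """Discrete convex conjugate: f*(y_j) = max_k (x_k * y_j - f(x_k))."""
    return np.max(np.outer(y, x) - f[None, :], axis=1)


def convex_hull(phi, x, slopes):
    """Double conjugate phi** on the grid x; phi* is sampled on the grid `slopes`."""
    phi_star = conjugate(phi, x, slopes)
    return conjugate(phi_star, slopes, x)


x = np.linspace(-1.0, 1.0, 201)
phi = (x**2 - 0.25) ** 2                 # nonconvex double well, minima at x = +-0.5
slopes = np.linspace(-5.0, 5.0, 2001)    # contains all slopes of phi on [-1, 1] (|phi'| <= 3)
phi_hull = convex_hull(phi, x, slopes)

print(np.all(phi_hull <= phi + 1e-12))          # True: phi** never exceeds phi
print(np.all(np.diff(phi_hull, 2) >= -1e-10))   # True: nonnegative second differences
print(phi[100], abs(phi_hull[100]) < 1e-6)      # 0.0625 True: the well around 0 is filled in
```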
More details on the formulation and the algorithm can be found in Section 3, Nonconvex-Concave Minimax Formulation for Optimal Transport Barycenter.
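The \(\dot{\mathbb{H}}^1\)-gradient listed above requires inverting \(-\Delta\). A common choice on a rectangular grid is a Neumann Poisson solve diagonalized by the discrete cosine transform; this is an assumption here, not a description of the package's internals. The right-hand side \(-\nu + (\nabla\varphi_i^*)_\#\mu_i\) integrates to zero because both terms are probability densities, and that is exactly the solvability condition of the Neumann problem. The helper names, the square cell size `h`, and the use of `scipy.fft` are illustrative. The pushforward density is assumed to be computed elsewhere.

```python
import numpy as np
from scipy.fft import dctn, idctn


def inverse_neg_laplacian(f, h):
    """Solve -Lap(u) = f on a cell-centred grid with spacing h (square cells) and
    homogeneous Neumann boundary conditions via the DCT-II. f must have zero mean;
    the returned u has zero mean (the constant mode is not determined)."""
    n2, n1 = f.shape
    k2 = np.arange(n2)[:, None]
    k1 = np.arange(n1)[None, :]
    # Eigenvalues of the 5-point -Laplacian with reflecting boundaries in the DCT-II basis
    lam = ((2 - 2 * np.cos(np.pi * k1 / n1)) + (2 - 2 * np.cos(np.pi * k2 / n2))) / h**2
    lam[0, 0] = 1.0                      # placeholder; the constant mode is zeroed below
    u_hat = dctn(f, type=2, norm="ortho") / lam
    u_hat[0, 0] = 0.0
    return idctn(u_hat, type=2, norm="ortho")


def neg_laplacian(u, h):
    """5-point -Laplacian with edge padding (reflecting ghost cells = Neumann BC)."""
    p = np.pad(u, 1, mode="edge")
    return (4 * u - p[:-2, 1:-1] - p[2:, 1:-1] - p[1:-1, :-2] - p[1:-1, 2:]) / h**2


def h1_gradient(nu, pushed_mu_i, n, h):
    """Hypothetical helper: (1/n) (-Lap)^{-1} (-nu + (grad phi_i^*)_# mu_i)."""
    rhs = pushed_mu_i - nu
    rhs = rhs - rhs.mean()               # removes small mass drift from discretization
    return inverse_neg_laplacian(rhs, h) / n


# Check on random zero-mean data: applying -Lap to the solution recovers f.
rng = np.random.default_rng(0)
m = 64
f = rng.standard_normal((m, m))
f -= f.mean()
u = inverse_neg_laplacian(f, 1.0 / m)
print(np.allclose(neg_laplacian(u, 1.0 / m), f), abs(u.mean()) < 1e-12)  # True True
```

With such a helper, an ascent step on a potential could read `phi_i = phi_i + step * h1_gradient(nu, pushed_mu_i, n, h)`. The step size and how this combines with the convex-hull step are assumptions here; see Section 3 for the actual update.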

2. Implementation

The code is written in Python and builds on the core functionality (c-transform and pushforward of measures) provided by Flavien Leger in the BFM package. The code is also available in the GitHub repository as .ipynb and .py files.


  !git clone https://github.com/Math-Jacobs/bfm
  !pip install bfm/python
  !pip install pot
  !wget -O functions.py https://raw.githubusercontent.com/kaheonkim/WDHA/main/implementation2D/functions.py
  !wget -O metric.py https://raw.githubusercontent.com/kaheonkim/WDHA/main/implementation2D/metric.py

  import numpy as np
  import matplotlib.pyplot as plt  # used for plotting in Example 2
  from metric import *
  from functions import *

2.1. Example 1: Uniform Distributions with Varying Support

Four distinct shapes are placed at four different locations: a square at the top-left, a heart at the bottom-left, a cross at the bottom-right, and a circle at the top-right. These shapes represent four uniform probability densities with distinct supports. We demonstrate the application of WDHA on these four synthetic uniform distributions.


  n1, n2 = 1024, 1024
  x, y = np.meshgrid(np.linspace(0.5/n1, 1-0.5/n1, n1),
                     np.linspace(0.5/n2, 1-0.5/n2, n2))
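  r = 0.1  # size parameter of the circle and the cross
  # Circle of radius r centered at (0.8, 0.8)
  mu1 = np.zeros((n2, n1))
  mu1[(x-0.8)**2 + (y-0.8)**2 < r**2] = 1
  # Cross centered at (0.8, 0.3): a vertical bar and a horizontal bar
  mu2 = np.zeros((n2, n1))
  mu2[(0.8-r/2.5 < x) & (x < 0.8+r/2.5) & (0.3-r < y) & (y < 0.3+r)] = 1
  mu2[(0.3-r/2.5 < y) & (y < 0.3+r/2.5) & (0.8-r < x) & (x < 0.8+r)] = 1
  
  # Normalize so that each density has mean 1 over the grid,
  # i.e. integral 1 over [0, 1]^2 with cell area 1/(n1*n2)
  mu1 *= n1*n2 / np.sum(mu1)
  mu2 *= n1*n2 / np.sum(mu2)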
  
  # Heart centered near (0.2, 0.3)
  heart = np.zeros((n2, n1))
  heart[((10*x-2)**2+(10*(y-0.3))**2-1)**3 - (10*x-2)**2*(10*(y-0.3))**3 < 0] = 1
  heart *= n1 * n2 / np.sum(heart)
  
  # Square [0.1, 0.3] x [0.7, 0.9]
  rectangle = np.zeros((n2, n1))
  rectangle[(x < 0.3) & (x > 0.1) & (y > 0.7) & (y < 0.9)] = 1
  rectangle *= n1*n2 / np.sum(rectangle)
  
  mu = [mu1, mu2, heart, rectangle]
  plotting(mu, np.zeros((n2,n1)), '_', save_option=False)
4 Uniform Densities
The Wasserstein barycenter of the four uniform distributions, computed by WDHA:

  mu_WGHA = frechet_mean(mu, 300, 'MU', save_option=False, return_option=True)
Wasserstein Barycenter computed by WDHA
You can run the code for this example directly in the Google Colab notebook.
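A cheap sanity check for a computed barycenter uses a standard property of the Wasserstein barycenter that is not specific to WDHA: its mean equals the average of the means of the \(\mu_i\). The sketch below computes the mean of a density stored as in Example 1 (cell-centered grid, normalized so that the array mean is 1). So that it runs on its own, it rebuilds two of the shapes on a coarser grid. The helper names `density_mean` and `uniform_density` are hypothetical.

```python
import numpy as np


def density_mean(rho, x, y):
    """Mean point (E[x], E[y]) of a density rho on the cell-centered grid (x, y),
    normalized as in Example 1 so that rho.mean() == 1."""
    return np.array([np.mean(rho * x), np.mean(rho * y)])


def uniform_density(mask):
    """Uniform density on a boolean mask, normalized so that its mean over the grid is 1."""
    rho = mask.astype(float)
    return rho * rho.size / rho.sum()


n = 256                                  # coarser than the 1024 x 1024 grid of Example 1
t = np.linspace(0.5 / n, 1 - 0.5 / n, n)
x, y = np.meshgrid(t, t)
disk = uniform_density((x - 0.8) ** 2 + (y - 0.8) ** 2 < 0.1 ** 2)
square = uniform_density((x > 0.1) & (x < 0.3) & (y > 0.7) & (y < 0.9))

# The mean of the barycenter of {disk, square} is the average of the individual means.
target = np.mean([density_mean(r, x, y) for r in (disk, square)], axis=0)
print(target)                            # close to [0.5, 0.8]
```

For Example 1, one could compare `density_mean(mu_WGHA, x, y)` with the average of `density_mean(m, x, y)` over `m in mu` on the 1024 x 1024 grid defined there. This assumes, without checking `functions.py`, that `frechet_mean` returns an `(n2, n1)` density normalized like its inputs. Agreement up to discretization error is expected, not exact equality.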

2.2. Example 2: High-resolution Handwritten Digit Images

WDHA is applied to the barycenter problem for 100 high-resolution (500×500) images of the digit 8, provided by Cédric Beaulac and Jeffrey S. Rosenthal in their article "Analysis of a high-resolution hand-written digits data set with writer characteristics". The dataset files, Images(500x500).npy and WriterInfo.npy, are available via the Google Drive link.


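  # In Google Colab, mount Google Drive first so that the files below are accessible
  from google.colab import drive
  drive.mount('/content/drive')

  Images = np.load('/content/drive/My Drive/WDHA/Images(500x500).npy')
  WriterInfo = np.load('/content/drive/My Drive/WDHA/WriterInfo.npy')
  digit = WriterInfo[:, 0]   # digit label of each image
  user = WriterInfo[:, -1]   # writer column, not used below
  num_image = 100
  num_iter = 300
  # Invert the grayscale values so that the dark ink carries the mass
  numbers8 = 255 - Images[(digit == 8)][:num_image].astype('float64')
  
  # Normalize each image to a density with mean 1 over the 500x500 grid
  for j in range(num_image):
      numbers8[j] /= np.sum(numbers8[j])
      numbers8[j] *= 500 * 500
  del Images, WriterInfo, user, digit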

  fig, axes = plt.subplots(1, 3, figsize=(10, 4))  # 1 row, 3 columns
  
  plotting_mnist(numbers8[0], '', ax=axes[0])
  plotting_mnist(numbers8[1], '', ax=axes[1])
  plotting_mnist(numbers8[2], '', ax=axes[2])
  
  plt.tight_layout()
  plt.show()
Exemplary Digits 8
The Wasserstein barycenter of the 100 handwritten images of the digit 8, computed by WDHA:

  bary8 = frechet_mean(numbers8, num_iter, 'mnist', plot_option=False, save_option=False, return_option=True)
  plotting_mnist(bary8, '')
Wasserstein Barycenter computed by WDHA
A detailed analysis and a comparison with Sinkhorn-type methods can be found in Section 4, Numerical Studies.