|
| 1 | +# -*- coding: utf-8 -*- |
| 2 | +""" |
| 3 | +===================================== |
| 4 | +OT Barycenter with Generic Costs Demo |
| 5 | +===================================== |
| 6 | +
|
| 7 | +This example illustrates the computation of an Optimal Transport barycenter |
| 8 | +for a ground cost that is not a power of a norm. We take the example of ground |
| 9 | +costs :math:`c_k(x, y) = |P_k(x)-y|^2`, where :math:`P_k` is the (non-linear) |
| 10 | +projection onto a circle k. This is an example of the fixed-point barycenter |
| 11 | +solver introduced in [74] which generalises [20]. |
| 12 | +
|
| 13 | +The ground barycenter function :math:`B(y_1, ..., y_K) = \mathrm{argmin}_{x \in |
| 14 | +\mathbb{R}^2} \sum_k \lambda_k c_k(x, y_k)` is computed by gradient descent over |
| 15 | +:math:`x` with PyTorch. |
| 16 | +
|
| 17 | +[74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). [Computing |
| 18 | +Barycentres of Measures for Generic Transport |
| 19 | +Costs](https://arxiv.org/abs/2501.04016). arXiv preprint 2501.04016 (2024) |
| 20 | +
|
| 21 | +[20] Cuturi, M. and Doucet, A. (2014) [Fast Computation of Wasserstein |
| 22 | +Barycenters](http://proceedings.mlr.press/v32/cuturi14.html). International |
| 23 | +Conference in Machine Learning |
| 24 | +
|
| 25 | +""" |
| 26 | + |
| 27 | +# Author: Eloi Tanguy <[email protected]> |
| 28 | +# |
| 29 | +# License: MIT License |
| 30 | + |
| 31 | +# sphinx_gallery_thumbnail_number = 1 |
| 32 | + |
| 33 | +# %% Generate data |
| 34 | +import torch |
| 35 | +from torch.optim import Adam |
| 36 | +from ot.utils import dist |
| 37 | +import numpy as np |
| 38 | +from ot.lp import free_support_barycenter_generic_costs |
| 39 | +import matplotlib.pyplot as plt |
| 40 | + |
| 41 | + |
# Seed the torch RNG so that every random draw below is reproducible.
torch.manual_seed(42)

# Problem sizes.
n = 100  # number of points of the barycentre
d = 2  # dimension of the barycentre points, which are stored as (n, d)
K = 4  # number of measures to barycentre
m = 50  # number of points of each measure

# Uniform weights: the same (m,) tensor is shared by the K measures.
b_list = K * [torch.ones(m) / m]  # weights of the 4 measures
weights = torch.ones(K) / K  # weights for the barycentre

# Stop threshold, used both by B and by the fixed-point algorithm.
stop_threshold = 1e-20
| 51 | + |
| 52 | + |
| 53 | +# map R^2 -> R^2 projection onto circle |
def proj_circle(X, origin, radius):
    """
    Project the rows of X onto the circle of centre origin and given radius.

    X is indexed as a 2D tensor with one point per row, and origin as a 1D
    tensor that is broadcast against the rows of X. The output has the same
    shape as X.

    A point lying exactly on the centre has no well-defined projection: the
    row norm is clamped away from zero so that such a point is mapped to the
    centre itself instead of producing NaN values through a 0 / 0 division.
    """
    diffs = X - origin[None, :]
    norms = torch.norm(diffs, dim=1)
    # Only the degenerate zero-norm rows are affected by the clamp: `tiny` is
    # the smallest positive normal number of the working dtype.
    safe_norms = norms.clamp_min(torch.finfo(diffs.dtype).tiny)
    return origin[None, :] + radius * diffs / safe_norms[:, None]
| 58 | + |
| 59 | + |
| 60 | +# circles on which to project |
# Centres of the four circles.
origin1, origin2, origin3, origin4 = (
    torch.tensor(centre)
    for centre in ([-1.0, -1.0], [-1.0, 2.0], [2.0, 2.0], [2.0, -1.0])
)
# Radius shared by the four circles.
r = np.sqrt(2)


def _make_projection(origin):
    """Return the map sending X to its projection on the circle at origin."""
    return lambda X: proj_circle(X, origin, r)


# P_list[k] projects an array of points onto circle k.
P_list = [
    _make_projection(origin) for origin in (origin1, origin2, origin3, origin4)
]
| 72 | + |
# Measure k is built by drawing m random points on the circle of centre
# (0.5, 0.5) and radius 0.5, then projecting them onto circle k.
Y_list = []
for k in range(K):
    angles = torch.rand(m) * 2 * np.pi
    unit_circle = torch.stack([torch.cos(angles), torch.sin(angles)], dim=1)
    samples = 0.5 * unit_circle + torch.tensor([0.5, 0.5])[None, :]
    Y_list.append(P_list[k](samples))
| 81 | + |
| 82 | + |
| 83 | +# %% Define costs and ground barycenter function |
| 84 | +# cost_list[k] is a function taking x (n, d) and y (n_k, d_k) and returning a |
| 85 | +# (n, n_k) matrix of costs |
def c1(x, y):
    """Cost matrix given by dist between x projected by P_list[0] and y."""
    projected = P_list[0](x)
    return dist(projected, y)


def c2(x, y):
    """Cost matrix given by dist between x projected by P_list[1] and y."""
    projected = P_list[1](x)
    return dist(projected, y)


def c3(x, y):
    """Cost matrix given by dist between x projected by P_list[2] and y."""
    projected = P_list[2](x)
    return dist(projected, y)


def c4(x, y):
    """Cost matrix given by dist between x projected by P_list[3] and y."""
    projected = P_list[3](x)
    return dist(projected, y)


# One ground cost per measure, in the same order as P_list.
cost_list = [c1, c2, c3, c4]
| 103 | + |
| 104 | + |
| 105 | +# batched total ground cost function for candidate points x (n, d) |
| 106 | +# for computation of the ground barycenter B with gradient descent |
def C(x, y):
    """
    Computes the barycenter cost for candidate points x (n, d) and
    measure supports y: List(n, d_k).

    Entry i of the output is the average over the measures k of the squared
    Euclidean distance between row i of P_list[k](x) and row i of y[k].

    Output: (n,) tensor. It is all zeros when y is empty.
    """
    n_measures = len(y)
    # Allocate the accumulator from x so that it lives on the same device and
    # has the same dtype as the candidate points.
    total = x.new_zeros(x.shape[0])
    for k, y_k in enumerate(y):
        sq_dists = torch.sum((P_list[k](x) - y_k) ** 2, dim=1)
        # Out-of-place accumulation lets type promotion apply when y_k has a
        # wider dtype than x, where an in-place add would raise.
        total = total + (1 / n_measures) * sq_dists
    return total
| 118 | + |
| 119 | + |
| 120 | +# ground barycenter function |
def B(y, its=150, lr=1, stop_threshold=stop_threshold):
    """
    Computes the ground barycenter for measure supports y: List(n, d_k).

    Starting from a random Gaussian (n, d) array, the sum of C(x, y) is
    minimised over x with Adam, for at most `its` steps with learning rate
    `lr`. The loop stops early once the squared norm of the last update
    drops below `stop_threshold`.

    Output: (n, d) array. It is the optimised tensor itself, so it still has
    requires_grad set.
    """
    x = torch.randn(n, d, requires_grad=True)
    optimizer = Adam([x], lr=lr)
    for _ in range(its):
        # Snapshot of the iterate, used to measure the size of the update.
        previous = x.detach().clone()
        optimizer.zero_grad()
        C(x, y).sum().backward()
        optimizer.step()
        sq_step = ((x.detach() - previous) ** 2).sum()
        if sq_step < stop_threshold:
            break
    return x
| 139 | + |
| 140 | + |
| 141 | +# %% Compute the barycenter measure |
# Number of iterations granted to the fixed-point solver.
fixed_point_its = 10
# Random initial barycentre support, drawn uniformly in [0, 1)^d.
X_init = torch.rand(n, d)
# Positional inputs of the solver: initial support, measure supports, measure
# weights, ground costs and ground barycenter function.
solver_inputs = (X_init, Y_list, b_list, cost_list, B)
X_bar = free_support_barycenter_generic_costs(
    *solver_inputs, max_its=fixed_point_its, stop_threshold=stop_threshold
)
| 153 | + |
| 154 | +# %% Plot Barycenter (Iteration 10) |
alpha = 0.5
labels = ["circle 1", "circle 2", "circle 3", "circle 4"]

# One scatter plot per input measure.
for support, name in zip(Y_list, labels):
    xs, ys = support.numpy().T
    plt.scatter(xs, ys, alpha=alpha, label=name)

# X_bar is detached before the conversion to numpy.
bary_x, bary_y = X_bar.detach().numpy().T
plt.scatter(bary_x, bary_y, label="Barycenter", c="black", alpha=alpha)

# Same scale on both axes, window restricted to [-0.3, 1.3]^2, axes hidden.
plt.axis("equal")
plt.xlim(-0.3, 1.3)
plt.ylim(-0.3, 1.3)
plt.axis("off")
plt.legend()
plt.tight_layout()
| 166 | + |
| 167 | +# %% |
0 commit comments