# -*- coding: utf-8 -*-
r"""
===============================================================
Learning sample marginal distribution with CO-Optimal Transport
===============================================================

In this example, we illustrate how to estimate the sample marginal distribution that minimizes
the CO-Optimal Transport distance [49]_ between two matrices. More precisely, given source data
:math:`(X, \mu_x^{(s)}, \mu_x^{(f)})` and a target matrix :math:`Y` associated with a fixed
histogram on features :math:`\mu_y^{(f)}`, we want to solve the following problem

.. math::
    \min_{\mu_y^{(s)} \in \Delta} \text{COOT}\left( (X, \mu_x^{(s)}, \mu_x^{(f)}), (Y, \mu_y^{(s)}, \mu_y^{(f)}) \right)

where :math:`\Delta` is the probability simplex. This minimization is done with a
simple projected gradient descent in PyTorch. We use the automatic backend of POT,
which makes the CO-Optimal Transport distance computed by
:func:`ot.coot.co_optimal_transport2` differentiable, so the loss can be
backpropagated through.

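Concretely, with learning rate :math:`\eta`, each iteration of the loop below
performs the projected gradient update

.. math::
    \mu_y^{(s)} \leftarrow \text{proj}_{\Delta}\left( \mu_y^{(s)} - \eta \nabla_{\mu_y^{(s)}} \text{COOT} \right)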

.. [49] Redko, I., Vayer, T., Flamary, R., and Courty, N. (2020).
    `CO-Optimal Transport <https://proceedings.neurips.cc/paper/2020/file/cc384c68ad503482fb24e6d1e3b512ae-Paper.pdf>`_.
    Advances in Neural Information Processing Systems, 33.
"""

# Author: Remi Flamary <[email protected]>
#         Quang Huy Tran <[email protected]>
# License: MIT License

from matplotlib.patches import ConnectionPatch
import torch
import numpy as np

import matplotlib.pyplot as pl
import ot

from ot.coot import co_optimal_transport as coot
from ot.coot import co_optimal_transport2 as coot2


# %%
# Generate data
# -------------
# The source and clean target matrices are generated by
# :math:`X_{i,j} = \cos(\frac{i}{n_1} \pi) + \cos(\frac{j}{d_1} \pi)` and
# :math:`Y_{i,j} = \cos(\frac{i}{n_2} \pi) + \cos(\frac{j}{d_2} \pi)`.
# The target matrix is then contaminated by adding 5 rows of Gaussian noise
# (the outliers). Intuitively, we expect the estimated sample distribution
# to ignore these outliers, i.e. their weights should be zero.

# Fix the seeds for reproducibility (the noise below is drawn with torch)
np.random.seed(182)
torch.manual_seed(182)

n1, d1 = 20, 16  # source dimensions
n2, d2 = 10, 8   # clean target dimensions
n = 15           # total number of target samples (n - n2 = 5 outliers)

X = (
    torch.cos(torch.arange(n1) * torch.pi / n1)[:, None] +
    torch.cos(torch.arange(d1) * torch.pi / d1)[None, :]
)

# Generate clean target data mixed with outliers: the first n2 rows carry
# the clean signal, the remaining rows are pure noise
Y_noisy = torch.randn((n, d2)) * 10.0
Y_noisy[:n2, :] = (
    torch.cos(torch.arange(n2) * torch.pi / n2)[:, None] +
    torch.cos(torch.arange(d2) * torch.pi / d2)[None, :]
)
Y = Y_noisy[:n2, :]

X, Y_noisy, Y = X.double(), Y_noisy.double(), Y.double()

fig, axes = pl.subplots(nrows=1, ncols=3, figsize=(12, 5))
axes[0].imshow(X, vmin=-2, vmax=2)
axes[0].set_title('$X$')

axes[1].imshow(Y, vmin=-2, vmax=2)
axes[1].set_title('Clean $Y$')

axes[2].imshow(Y_noisy, vmin=-2, vmax=2)
axes[2].set_title('Noisy $Y$')

pl.tight_layout()

# %%
# Optimize the COOT distance with respect to the sample marginal distribution
# ----------------------------------------------------------------------------

losses = []
lr = 1e-3
niter = 1000

# Start from uniform weights on the target samples
b = torch.tensor(ot.unif(n), requires_grad=True)

for i in range(niter):

    loss = coot2(X, Y_noisy, wy_samp=b, log=False, verbose=False)
    losses.append(float(loss))

    loss.backward()

    with torch.no_grad():
        b -= lr * b.grad  # gradient step
        b[:] = ot.utils.proj_simplex(b)  # projection onto the simplex

    b.grad.zero_()
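
# The same loop can be written with a built-in optimizer. A minimal sketch,
# assuming plain SGD (equivalent to the manual gradient step above; the
# projection onto the simplex must still be applied by hand):
#
#     b = torch.tensor(ot.unif(n), requires_grad=True)
#     optimizer = torch.optim.SGD([b], lr=lr)
#     for _ in range(niter):
#         optimizer.zero_grad()
#         loss = coot2(X, Y_noisy, wy_samp=b, log=False, verbose=False)
#         loss.backward()
#         optimizer.step()
#         with torch.no_grad():
#             b[:] = ot.utils.proj_simplex(b)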

# Estimated sample marginal distribution and training loss curve
pl.plot(losses[10:])
pl.title('CO-Optimal Transport distance')

print(f"Marginal distribution = {b.detach().numpy()}")
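
# Sanity check: by construction, the last n - n2 rows of Y_noisy are the
# outliers, so their learned weights should be close to zero
print(f"Weights on the outlier rows = {b.detach().numpy()[n2:]}")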

# %%
# Visualizing the row and column alignments with the estimated sample marginal distribution
# ------------------------------------------------------------------------------------------
#
# The learned marginal distribution successfully ignores the 5 outliers:
# their weights are essentially zero.

X, Y_noisy = X.numpy(), Y_noisy.numpy()
b = b.detach().numpy()

pi_sample, pi_feature = coot(X, Y_noisy, wy_samp=b, log=False, verbose=True)

fig = pl.figure(4, (9, 7))
pl.clf()

ax1 = pl.subplot(2, 2, 3)
pl.imshow(X, vmin=-2, vmax=2)
pl.xlabel('$X$')

ax2 = pl.subplot(2, 2, 2)
ax2.yaxis.tick_right()
pl.imshow(np.transpose(Y_noisy), vmin=-2, vmax=2)
pl.title("Transpose(Noisy $Y$)")
ax2.xaxis.tick_top()

# Sample alignment: connect each row of X to its best-matched sample of Y
# (argmax of the sample coupling)
for i in range(n1):
    j = np.argmax(pi_sample[i, :])
    xyA = (d1 - .5, i)
    xyB = (j, d2 - .5)
    con = ConnectionPatch(xyA=xyA, xyB=xyB, coordsA=ax1.transData,
                          coordsB=ax2.transData, color="black")
    fig.add_artist(con)

# Feature alignment: connect each column of X to its best-matched feature of Y
# (argmax of the feature coupling)
for i in range(d1):
    j = np.argmax(pi_feature[i, :])
    xyA = (i, -.5)
    xyB = (-.5, j)
    con = ConnectionPatch(
        xyA=xyA, xyB=xyB, coordsA=ax1.transData, coordsB=ax2.transData, color="blue")
    fig.add_artist(con)