Skip to content

Commit fd1ff29

Browse files
Circle CICircle CI
authored andcommitted
CircleCI update of dev docs (2779).
1 parent 1075619 commit fd1ff29

File tree

375 files changed

+738752
-735639
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

375 files changed

+738752
-735639
lines changed
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"\n# GMM Flow\n\nIllustration of the flow of a Gaussian Mixture with\nrespect to its GMM-OT distance with respect to a\nfixed GMM.\n"
8+
]
9+
},
10+
{
11+
"cell_type": "code",
12+
"execution_count": null,
13+
"metadata": {
14+
"collapsed": false
15+
},
16+
"outputs": [],
17+
"source": [
18+
"# Author: Eloi Tanguy <eloi.tanguy@u-paris>\n# Remi Flamary <[email protected]>\n# Julie Delon <[email protected]>\n#\n# License: MIT License\n\n# sphinx_gallery_thumbnail_number = 4\n\nimport numpy as np\nimport matplotlib.pylab as pl\nfrom matplotlib import colormaps as cm\nimport ot\nimport ot.plot\nfrom ot.utils import proj_SDP, proj_simplex\nfrom ot.gmm import gmm_ot_loss\nimport torch\nfrom torch.optim import Adam\nfrom matplotlib.patches import Ellipse"
19+
]
20+
},
21+
{
22+
"cell_type": "markdown",
23+
"metadata": {},
24+
"source": [
25+
"## Generate data and plot it\n\n"
26+
]
27+
},
28+
{
29+
"cell_type": "code",
30+
"execution_count": null,
31+
"metadata": {
32+
"collapsed": false
33+
},
34+
"outputs": [],
35+
"source": [
36+
"torch.manual_seed(3)\nks = 3\nkt = 2\nd = 2\neps = 0.1\nm_s = torch.randn(ks, d)\nm_s.requires_grad_()\nm_t = torch.randn(kt, d)\nC_s = torch.randn(ks, d, d)\nC_s = torch.matmul(C_s, torch.transpose(C_s, 2, 1))\nC_s += eps * torch.eye(d)[None, :, :] * torch.ones(ks, 1, 1)\nC_s.requires_grad_()\nC_t = torch.randn(kt, d, d)\nC_t = torch.matmul(C_t, torch.transpose(C_t, 2, 1))\nC_t += eps * torch.eye(d)[None, :, :] * torch.ones(kt, 1, 1)\nw_s = torch.randn(ks)\nw_s = proj_simplex(w_s)\nw_s.requires_grad_()\nw_t = torch.tensor(ot.unif(kt))\n\n\ndef draw_cov(mu, C, color=None, label=None, nstd=1, alpha=.5):\n\n def eigsorted(cov):\n vals, vecs = np.linalg.eigh(cov)\n order = vals.argsort()[::-1]\n return vals[order], vecs[:, order]\n\n vals, vecs = eigsorted(C)\n theta = np.degrees(np.arctan2(*vecs[:, 0][::-1]))\n w, h = 2 * nstd * np.sqrt(vals)\n ell = Ellipse(xy=(mu[0], mu[1]),\n width=w, height=h, alpha=alpha,\n angle=theta, facecolor=color, edgecolor=color, label=label, fill=True)\n pl.gca().add_artist(ell)\n\n\ndef draw_gmm(ms, Cs, ws, color=None, nstd=.5, alpha=1):\n for k in range(ms.shape[0]):\n draw_cov(ms[k], Cs[k], color, None, nstd,\n alpha * ws[k])\n\n\naxis = [-3, 3, -3, 3]\npl.figure(1, (20, 10))\npl.clf()\n\npl.subplot(1, 2, 1)\npl.scatter(m_s[:, 0].detach(), m_s[:, 1].detach(), color='C0')\ndraw_gmm(m_s.detach(), C_s.detach(),\n torch.softmax(w_s, 0).detach().numpy(),\n color='C0')\npl.axis(axis)\npl.title('Source GMM')\n\npl.subplot(1, 2, 2)\npl.scatter(m_t[:, 0].detach(), m_t[:, 1].detach(), color='C1')\ndraw_gmm(m_t.detach(), C_t.detach(), w_t.numpy(), color='C1')\npl.axis(axis)\npl.title('Target GMM')"
37+
]
38+
},
39+
{
40+
"cell_type": "markdown",
41+
"metadata": {},
42+
"source": [
43+
"## Gradient descent loop\n\n"
44+
]
45+
},
46+
{
47+
"cell_type": "code",
48+
"execution_count": null,
49+
"metadata": {
50+
"collapsed": false
51+
},
52+
"outputs": [],
53+
"source": [
54+
"n_gd_its = 100\nlr = 3e-2\nopt = Adam([{'params': m_s, 'lr': 2 * lr},\n {'params': C_s, 'lr': lr},\n {'params': w_s, 'lr': lr}])\nm_list = [m_s.data.numpy().copy()]\nC_list = [C_s.data.numpy().copy()]\nw_list = [torch.softmax(w_s, 0).data.numpy().copy()]\nloss_list = []\n\nfor _ in range(n_gd_its):\n opt.zero_grad()\n loss = gmm_ot_loss(m_s, m_t, C_s, C_t,\n torch.softmax(w_s, 0), w_t)\n loss.backward()\n opt.step()\n with torch.no_grad():\n C_s.data = proj_SDP(C_s.data, vmin=1e-6)\n m_list.append(m_s.data.numpy().copy())\n C_list.append(C_s.data.numpy().copy())\n w_list.append(torch.softmax(w_s, 0).data.numpy().copy())\n loss_list.append(loss.item())\n\npl.figure(2)\npl.clf()\npl.plot(loss_list)\npl.title('Loss')\npl.xlabel('its')\npl.ylabel('loss')"
55+
]
56+
},
57+
{
58+
"cell_type": "markdown",
59+
"metadata": {},
60+
"source": [
61+
"## Last step visualisation\n\n"
62+
]
63+
},
64+
{
65+
"cell_type": "code",
66+
"execution_count": null,
67+
"metadata": {
68+
"collapsed": false
69+
},
70+
"outputs": [],
71+
"source": [
72+
"axis = [-3, 3, -3, 3]\npl.figure(3, (10, 10))\npl.clf()\npl.title('GMM flow, last step')\npl.scatter(m_list[0][:, 0], m_list[0][:, 1], color='C0', label='Source')\ndraw_gmm(m_list[0], C_list[0], w_list[0], color='C0')\npl.axis(axis)\n\npl.scatter(m_t[:, 0].detach(), m_t[:, 1].detach(), color='C1', label='Target')\ndraw_gmm(m_t.detach(), C_t.detach(), w_t.numpy(), color='C1')\npl.axis(axis)\n\nk = -1\npl.scatter(m_list[k][:, 0], m_list[k][:, 1], color='C2', alpha=1, label='Last step')\ndraw_gmm(m_list[k], C_list[k], w_list[0], color='C2', alpha=1)\n\npl.axis(axis)\npl.legend(fontsize=15)"
73+
]
74+
},
75+
{
76+
"cell_type": "markdown",
77+
"metadata": {},
78+
"source": [
79+
"## Steps visualisation\n\n"
80+
]
81+
},
82+
{
83+
"cell_type": "code",
84+
"execution_count": null,
85+
"metadata": {
86+
"collapsed": false
87+
},
88+
"outputs": [],
89+
"source": [
90+
"def index_to_color(i):\n return int(i**0.5)\n\n\nn_steps_visu = 100\npl.figure(3, (10, 10))\npl.clf()\npl.title('GMM flow, all steps')\n\nits_to_show = [int(x) for x in np.linspace(1, n_gd_its - 1, n_steps_visu)]\ncmp = cm['plasma'].resampled(index_to_color(n_steps_visu))\n\npl.scatter(m_list[0][:, 0], m_list[0][:, 1],\n color=cmp(index_to_color(0)), label='Source')\ndraw_gmm(m_list[0], C_list[0], w_list[0],\n color=cmp(index_to_color(0)))\n\npl.scatter(m_t[:, 0].detach(), m_t[:, 1].detach(),\n color=cmp(index_to_color(n_steps_visu - 1)), label='Target')\ndraw_gmm(m_t.detach(), C_t.detach(), w_t.numpy(),\n color=cmp(index_to_color(n_steps_visu - 1)))\n\n\nfor k in its_to_show:\n pl.scatter(m_list[k][:, 0], m_list[k][:, 1],\n color=cmp(index_to_color(k)), alpha=0.8)\n draw_gmm(m_list[k], C_list[k], w_list[0],\n color=cmp(index_to_color(k)), alpha=0.04)\n\npl.axis(axis)\npl.legend(fontsize=15)"
91+
]
92+
}
93+
],
94+
"metadata": {
95+
"kernelspec": {
96+
"display_name": "Python 3",
97+
"language": "python",
98+
"name": "python3"
99+
},
100+
"language_info": {
101+
"codemirror_mode": {
102+
"name": "ipython",
103+
"version": 3
104+
},
105+
"file_extension": ".py",
106+
"mimetype": "text/x-python",
107+
"name": "python",
108+
"nbconvert_exporter": "python",
109+
"pygments_lexer": "ipython3",
110+
"version": "3.10.14"
111+
}
112+
},
113+
"nbformat": 4,
114+
"nbformat_minor": 0
115+
}
Binary file not shown.

0 commit comments

Comments
 (0)