Skip to content

Commit 15f4b29

Browse files
authored
Merge branch 'master' into stochastic_OT
2 parents 63b34bf + 5180023 commit 15f4b29

26 files changed

+3279
-292
lines changed

Makefile

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11

22

33
PYTHON=python3
4+
branch := $(shell git symbolic-ref --short -q HEAD)
45

56
help :
67
@echo "The following make targets are available:"
@@ -57,6 +58,16 @@ rdoc :
5758
notebook :
5859
ipython notebook --matplotlib=inline --notebook-dir=notebooks/
5960

61+
bench :
62+
@git stash >/dev/null 2>&1
63+
@echo 'Branch master'
64+
@git checkout master >/dev/null 2>&1
65+
python3 $(script)
66+
@echo 'Branch $(branch)'
67+
@git checkout $(branch) >/dev/null 2>&1
68+
python3 $(script)
69+
@git stash apply >/dev/null 2>&1
70+
6071
autopep8 :
6172
autopep8 -ir test ot examples --jobs -1
6273

Loading
Loading
Loading
Loading
Loading
Loading
Loading
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"metadata": {
7+
"collapsed": false
8+
},
9+
"outputs": [],
10+
"source": [
11+
"%matplotlib inline"
12+
]
13+
},
14+
{
15+
"cell_type": "markdown",
16+
"metadata": {},
17+
"source": [
18+
"\n# 1D Wasserstein barycenter comparison between exact LP and entropic regularization\n\n\nThis example illustrates the computation of regularized Wasserstein Barycenter\nas proposed in [3] and exact LP barycenters using standard LP solver.\n\nIt reproduces approximately Figure 3.1 and 3.2 from the following paper:\nCuturi, M., & Peyr\u00e9, G. (2016). A smoothed dual approach for variational\nWasserstein problems. SIAM Journal on Imaging Sciences, 9(1), 320-343.\n\n[3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyr\u00e9, G. (2015).\nIterative Bregman projections for regularized transportation problems\nSIAM Journal on Scientific Computing, 37(2), A1111-A1138.\n\n\n"
19+
]
20+
},
21+
{
22+
"cell_type": "code",
23+
"execution_count": null,
24+
"metadata": {
25+
"collapsed": false
26+
},
27+
"outputs": [],
28+
"source": [
29+
"# Author: Remi Flamary <[email protected]>\n#\n# License: MIT License\n\nimport numpy as np\nimport matplotlib.pylab as pl\nimport ot\n# necessary for 3d plot even if not used\nfrom mpl_toolkits.mplot3d import Axes3D # noqa\nfrom matplotlib.collections import PolyCollection # noqa\n\n#import ot.lp.cvx as cvx"
30+
]
31+
},
32+
{
33+
"cell_type": "markdown",
34+
"metadata": {},
35+
"source": [
36+
"Gaussian Data\n-------------\n\n"
37+
]
38+
},
39+
{
40+
"cell_type": "code",
41+
"execution_count": null,
42+
"metadata": {
43+
"collapsed": false
44+
},
45+
"outputs": [],
46+
"source": [
47+
"#%% parameters\n\nproblems = []\n\nn = 100 # nb bins\n\n# bin positions\nx = np.arange(n, dtype=np.float64)\n\n# Gaussian distributions\n# Gaussian distributions\na1 = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std\na2 = ot.datasets.make_1D_gauss(n, m=60, s=8)\n\n# creating matrix A containing all distributions\nA = np.vstack((a1, a2)).T\nn_distributions = A.shape[1]\n\n# loss matrix + normalization\nM = ot.utils.dist0(n)\nM /= M.max()\n\n\n#%% plot the distributions\n\npl.figure(1, figsize=(6.4, 3))\nfor i in range(n_distributions):\n pl.plot(x, A[:, i])\npl.title('Distributions')\npl.tight_layout()\n\n#%% barycenter computation\n\nalpha = 0.5 # 0<=alpha<=1\nweights = np.array([1 - alpha, alpha])\n\n# l2bary\nbary_l2 = A.dot(weights)\n\n# wasserstein\nreg = 1e-3\not.tic()\nbary_wass = ot.bregman.barycenter(A, M, reg, weights)\not.toc()\n\n\not.tic()\nbary_wass2 = ot.lp.barycenter(A, M, weights, solver='interior-point', verbose=True)\not.toc()\n\npl.figure(2)\npl.clf()\npl.subplot(2, 1, 1)\nfor i in range(n_distributions):\n pl.plot(x, A[:, i])\npl.title('Distributions')\n\npl.subplot(2, 1, 2)\npl.plot(x, bary_l2, 'r', label='l2')\npl.plot(x, bary_wass, 'g', label='Reg Wasserstein')\npl.plot(x, bary_wass2, 'b', label='LP Wasserstein')\npl.legend()\npl.title('Barycenters')\npl.tight_layout()\n\nproblems.append([A, [bary_l2, bary_wass, bary_wass2]])"
48+
]
49+
},
50+
{
51+
"cell_type": "markdown",
52+
"metadata": {},
53+
"source": [
54+
"Dirac Data\n----------\n\n"
55+
]
56+
},
57+
{
58+
"cell_type": "code",
59+
"execution_count": null,
60+
"metadata": {
61+
"collapsed": false
62+
},
63+
"outputs": [],
64+
"source": [
65+
"#%% parameters\n\na1 = 1.0 * (x > 10) * (x < 50)\na2 = 1.0 * (x > 60) * (x < 80)\n\na1 /= a1.sum()\na2 /= a2.sum()\n\n# creating matrix A containing all distributions\nA = np.vstack((a1, a2)).T\nn_distributions = A.shape[1]\n\n# loss matrix + normalization\nM = ot.utils.dist0(n)\nM /= M.max()\n\n\n#%% plot the distributions\n\npl.figure(1, figsize=(6.4, 3))\nfor i in range(n_distributions):\n pl.plot(x, A[:, i])\npl.title('Distributions')\npl.tight_layout()\n\n\n#%% barycenter computation\n\nalpha = 0.5 # 0<=alpha<=1\nweights = np.array([1 - alpha, alpha])\n\n# l2bary\nbary_l2 = A.dot(weights)\n\n# wasserstein\nreg = 1e-3\not.tic()\nbary_wass = ot.bregman.barycenter(A, M, reg, weights)\not.toc()\n\n\not.tic()\nbary_wass2 = ot.lp.barycenter(A, M, weights, solver='interior-point', verbose=True)\not.toc()\n\n\nproblems.append([A, [bary_l2, bary_wass, bary_wass2]])\n\npl.figure(2)\npl.clf()\npl.subplot(2, 1, 1)\nfor i in range(n_distributions):\n pl.plot(x, A[:, i])\npl.title('Distributions')\n\npl.subplot(2, 1, 2)\npl.plot(x, bary_l2, 'r', label='l2')\npl.plot(x, bary_wass, 'g', label='Reg Wasserstein')\npl.plot(x, bary_wass2, 'b', label='LP Wasserstein')\npl.legend()\npl.title('Barycenters')\npl.tight_layout()\n\n#%% parameters\n\na1 = np.zeros(n)\na2 = np.zeros(n)\n\na1[10] = .25\na1[20] = .5\na1[30] = .25\na2[80] = 1\n\n\na1 /= a1.sum()\na2 /= a2.sum()\n\n# creating matrix A containing all distributions\nA = np.vstack((a1, a2)).T\nn_distributions = A.shape[1]\n\n# loss matrix + normalization\nM = ot.utils.dist0(n)\nM /= M.max()\n\n\n#%% plot the distributions\n\npl.figure(1, figsize=(6.4, 3))\nfor i in range(n_distributions):\n pl.plot(x, A[:, i])\npl.title('Distributions')\npl.tight_layout()\n\n\n#%% barycenter computation\n\nalpha = 0.5 # 0<=alpha<=1\nweights = np.array([1 - alpha, alpha])\n\n# l2bary\nbary_l2 = A.dot(weights)\n\n# wasserstein\nreg = 1e-3\not.tic()\nbary_wass = ot.bregman.barycenter(A, M, reg, weights)\not.toc()\n\n\not.tic()\nbary_wass2 = ot.lp.barycenter(A, M, weights, solver='interior-point', verbose=True)\not.toc()\n\n\nproblems.append([A, [bary_l2, bary_wass, bary_wass2]])\n\npl.figure(2)\npl.clf()\npl.subplot(2, 1, 1)\nfor i in range(n_distributions):\n pl.plot(x, A[:, i])\npl.title('Distributions')\n\npl.subplot(2, 1, 2)\npl.plot(x, bary_l2, 'r', label='l2')\npl.plot(x, bary_wass, 'g', label='Reg Wasserstein')\npl.plot(x, bary_wass2, 'b', label='LP Wasserstein')\npl.legend()\npl.title('Barycenters')\npl.tight_layout()"
66+
]
67+
},
68+
{
69+
"cell_type": "markdown",
70+
"metadata": {},
71+
"source": [
72+
"Final figure\n------------\n\n\n"
73+
]
74+
},
75+
{
76+
"cell_type": "code",
77+
"execution_count": null,
78+
"metadata": {
79+
"collapsed": false
80+
},
81+
"outputs": [],
82+
"source": [
83+
"#%% plot\n\nnbm = len(problems)\nnbm2 = (nbm // 2)\n\n\npl.figure(2, (20, 6))\npl.clf()\n\nfor i in range(nbm):\n\n A = problems[i][0]\n bary_l2 = problems[i][1][0]\n bary_wass = problems[i][1][1]\n bary_wass2 = problems[i][1][2]\n\n pl.subplot(2, nbm, 1 + i)\n for j in range(n_distributions):\n pl.plot(x, A[:, j])\n if i == nbm2:\n pl.title('Distributions')\n pl.xticks(())\n pl.yticks(())\n\n pl.subplot(2, nbm, 1 + i + nbm)\n\n pl.plot(x, bary_l2, 'r', label='L2 (Euclidean)')\n pl.plot(x, bary_wass, 'g', label='Reg Wasserstein')\n pl.plot(x, bary_wass2, 'b', label='LP Wasserstein')\n if i == nbm - 1:\n pl.legend()\n if i == nbm2:\n pl.title('Barycenters')\n\n pl.xticks(())\n pl.yticks(())"
84+
]
85+
}
86+
],
87+
"metadata": {
88+
"kernelspec": {
89+
"display_name": "Python 3",
90+
"language": "python",
91+
"name": "python3"
92+
},
93+
"language_info": {
94+
"codemirror_mode": {
95+
"name": "ipython",
96+
"version": 3
97+
},
98+
"file_extension": ".py",
99+
"mimetype": "text/x-python",
100+
"name": "python",
101+
"nbconvert_exporter": "python",
102+
"pygments_lexer": "ipython3",
103+
"version": "3.6.5"
104+
}
105+
},
106+
"nbformat": 4,
107+
"nbformat_minor": 0
108+
}

0 commit comments

Comments
 (0)