Skip to content

Commit 6ac8d40

Browse files
committed
add all pages in documentation
1 parent 45d232f commit 6ac8d40

File tree

7 files changed

+1147
-0
lines changed

7 files changed

+1147
-0
lines changed
Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"metadata": {
7+
"collapsed": false
8+
},
9+
"outputs": [],
10+
"source": [
11+
"%matplotlib inline"
12+
]
13+
},
14+
{
15+
"cell_type": "markdown",
16+
"metadata": {},
17+
"source": [
18+
"\n# OT for multi-source target shift\n\n\nThis example introduces a target shift problem with two 2D source and 1 target domain.\n\n\n"
19+
]
20+
},
21+
{
22+
"cell_type": "code",
23+
"execution_count": null,
24+
"metadata": {
25+
"collapsed": false
26+
},
27+
"outputs": [],
28+
"source": [
29+
"# Authors: Remi Flamary <[email protected]>\n# Ievgen Redko <[email protected]>\n#\n# License: MIT License\n\nimport pylab as pl\nimport numpy as np\nimport ot\nfrom ot.datasets import make_data_classif"
30+
]
31+
},
32+
{
33+
"cell_type": "markdown",
34+
"metadata": {},
35+
"source": [
36+
"Generate data\n-------------\n\n"
37+
]
38+
},
39+
{
40+
"cell_type": "code",
41+
"execution_count": null,
42+
"metadata": {
43+
"collapsed": false
44+
},
45+
"outputs": [],
46+
"source": [
47+
"n = 50\nsigma = 0.3\nnp.random.seed(1985)\n\np1 = .2\ndec1 = [0, 2]\n\np2 = .9\ndec2 = [0, -2]\n\npt = .4\ndect = [4, 0]\n\nxs1, ys1 = make_data_classif('2gauss_prop', n, nz=sigma, p=p1, bias=dec1)\nxs2, ys2 = make_data_classif('2gauss_prop', n + 1, nz=sigma, p=p2, bias=dec2)\nxt, yt = make_data_classif('2gauss_prop', n, nz=sigma, p=pt, bias=dect)\n\nall_Xr = [xs1, xs2]\nall_Yr = [ys1, ys2]"
48+
]
49+
},
50+
{
51+
"cell_type": "code",
52+
"execution_count": null,
53+
"metadata": {
54+
"collapsed": false
55+
},
56+
"outputs": [],
57+
"source": [
58+
"da = 1.5\n\n\ndef plot_ax(dec, name):\n pl.plot([dec[0], dec[0]], [dec[1] - da, dec[1] + da], 'k', alpha=0.5)\n pl.plot([dec[0] - da, dec[0] + da], [dec[1], dec[1]], 'k', alpha=0.5)\n pl.text(dec[0] - .5, dec[1] + 2, name)"
59+
]
60+
},
61+
{
62+
"cell_type": "markdown",
63+
"metadata": {},
64+
"source": [
65+
"Fig 1 : plots source and target samples\n---------------------------------------\n\n"
66+
]
67+
},
68+
{
69+
"cell_type": "code",
70+
"execution_count": null,
71+
"metadata": {
72+
"collapsed": false
73+
},
74+
"outputs": [],
75+
"source": [
76+
"pl.figure(1)\npl.clf()\nplot_ax(dec1, 'Source 1')\nplot_ax(dec2, 'Source 2')\nplot_ax(dect, 'Target')\npl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9,\n label='Source 1 ({:1.2f}, {:1.2f})'.format(1 - p1, p1))\npl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9,\n label='Source 2 ({:1.2f}, {:1.2f})'.format(1 - p2, p2))\npl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9,\n label='Target ({:1.2f}, {:1.2f})'.format(1 - pt, pt))\npl.title('Data')\n\npl.legend()\npl.axis('equal')\npl.axis('off')"
77+
]
78+
},
79+
{
80+
"cell_type": "markdown",
81+
"metadata": {},
82+
"source": [
83+
"Instantiate Sinkhorn transport algorithm and fit them for all source domains\n----------------------------------------------------------------------------\n\n"
84+
]
85+
},
86+
{
87+
"cell_type": "code",
88+
"execution_count": null,
89+
"metadata": {
90+
"collapsed": false
91+
},
92+
"outputs": [],
93+
"source": [
94+
"ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1, metric='sqeuclidean')\n\n\ndef print_G(G, xs, ys, xt):\n for i in range(G.shape[0]):\n for j in range(G.shape[1]):\n if G[i, j] > 5e-4:\n if ys[i]:\n c = 'b'\n else:\n c = 'r'\n pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], c, alpha=.2)"
95+
]
96+
},
97+
{
98+
"cell_type": "markdown",
99+
"metadata": {},
100+
"source": [
101+
"Fig 2 : plot optimal couplings and transported samples\n------------------------------------------------------\n\n"
102+
]
103+
},
104+
{
105+
"cell_type": "code",
106+
"execution_count": null,
107+
"metadata": {
108+
"collapsed": false
109+
},
110+
"outputs": [],
111+
"source": [
112+
"pl.figure(2)\npl.clf()\nplot_ax(dec1, 'Source 1')\nplot_ax(dec2, 'Source 2')\nplot_ax(dect, 'Target')\nprint_G(ot_sinkhorn.fit(Xs=xs1, Xt=xt).coupling_, xs1, ys1, xt)\nprint_G(ot_sinkhorn.fit(Xs=xs2, Xt=xt).coupling_, xs2, ys2, xt)\npl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9)\npl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9)\npl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9)\n\npl.plot([], [], 'r', alpha=.2, label='Mass from Class 1')\npl.plot([], [], 'b', alpha=.2, label='Mass from Class 2')\n\npl.title('Independent OT')\n\npl.legend()\npl.axis('equal')\npl.axis('off')"
113+
]
114+
},
115+
{
116+
"cell_type": "markdown",
117+
"metadata": {},
118+
"source": [
119+
"Instantiate JCPOT adaptation algorithm and fit it\n----------------------------------------------------------------------------\n\n"
120+
]
121+
},
122+
{
123+
"cell_type": "code",
124+
"execution_count": null,
125+
"metadata": {
126+
"collapsed": false
127+
},
128+
"outputs": [],
129+
"source": [
130+
"otda = ot.da.JCPOTTransport(reg_e=1, max_iter=1000, metric='sqeuclidean', tol=1e-9, verbose=True, log=True)\notda.fit(all_Xr, all_Yr, xt)\n\nws1 = otda.proportions_.dot(otda.log_['D2'][0])\nws2 = otda.proportions_.dot(otda.log_['D2'][1])\n\npl.figure(3)\npl.clf()\nplot_ax(dec1, 'Source 1')\nplot_ax(dec2, 'Source 2')\nplot_ax(dect, 'Target')\nprint_G(ot.bregman.sinkhorn(ws1, [], otda.log_['M'][0], reg=1e-1), xs1, ys1, xt)\nprint_G(ot.bregman.sinkhorn(ws2, [], otda.log_['M'][1], reg=1e-1), xs2, ys2, xt)\npl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9)\npl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9)\npl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9)\n\npl.plot([], [], 'r', alpha=.2, label='Mass from Class 1')\npl.plot([], [], 'b', alpha=.2, label='Mass from Class 2')\n\npl.title('OT with prop estimation ({:1.3f},{:1.3f})'.format(otda.proportions_[0], otda.proportions_[1]))\n\npl.legend()\npl.axis('equal')\npl.axis('off')"
131+
]
132+
},
133+
{
134+
"cell_type": "markdown",
135+
"metadata": {},
136+
"source": [
137+
"Run oracle transport algorithm with known proportions\n----------------------------------------------------------------------------\n\n"
138+
]
139+
},
140+
{
141+
"cell_type": "code",
142+
"execution_count": null,
143+
"metadata": {
144+
"collapsed": false
145+
},
146+
"outputs": [],
147+
"source": [
148+
"h_res = np.array([1 - pt, pt])\n\nws1 = h_res.dot(otda.log_['D2'][0])\nws2 = h_res.dot(otda.log_['D2'][1])\n\npl.figure(4)\npl.clf()\nplot_ax(dec1, 'Source 1')\nplot_ax(dec2, 'Source 2')\nplot_ax(dect, 'Target')\nprint_G(ot.bregman.sinkhorn(ws1, [], otda.log_['M'][0], reg=1e-1), xs1, ys1, xt)\nprint_G(ot.bregman.sinkhorn(ws2, [], otda.log_['M'][1], reg=1e-1), xs2, ys2, xt)\npl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9)\npl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9)\npl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9)\n\npl.plot([], [], 'r', alpha=.2, label='Mass from Class 1')\npl.plot([], [], 'b', alpha=.2, label='Mass from Class 2')\n\npl.title('OT with known proportion ({:1.1f},{:1.1f})'.format(h_res[0], h_res[1]))\n\npl.legend()\npl.axis('equal')\npl.axis('off')\npl.show()"
149+
]
150+
}
151+
],
152+
"metadata": {
153+
"kernelspec": {
154+
"display_name": "Python 3",
155+
"language": "python",
156+
"name": "python3"
157+
},
158+
"language_info": {
159+
"codemirror_mode": {
160+
"name": "ipython",
161+
"version": 3
162+
},
163+
"file_extension": ".py",
164+
"mimetype": "text/x-python",
165+
"name": "python",
166+
"nbconvert_exporter": "python",
167+
"pygments_lexer": "ipython3",
168+
"version": "3.6.9"
169+
}
170+
},
171+
"nbformat": 4,
172+
"nbformat_minor": 0
173+
}
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
================================
4+
OT for multi-source target shift
5+
================================
6+
7+
This example introduces a target shift problem with two 2D source domains and one target domain.
8+
9+
"""
10+
11+
# Authors: Remi Flamary <[email protected]>
12+
# Ievgen Redko <[email protected]>
13+
#
14+
# License: MIT License
15+
16+
import pylab as pl
17+
import numpy as np
18+
import ot
19+
from ot.datasets import make_data_classif
20+
21+
##############################################################################
22+
# Generate data
23+
# -------------
24+
# Base sample count and noise level handed to make_data_classif (as its
# second positional argument and ``nz`` respectively).
n, sigma = 50, 0.3
# Seed NumPy's global generator before any data is drawn.
np.random.seed(1985)

# Per-domain settings: ``p`` is the class proportion reported as (1 - p, p)
# in the figure legends, ``dec`` is the [x, y] point handed to ``bias`` and
# later used as the centre of the domain's reference cross.
p1, dec1 = .2, [0, 2]
p2, dec2 = .9, [0, -2]
pt, dect = .4, [4, 0]

# The second source is requested with n + 1, so the two sources are not
# requested with the same size.
xs1, ys1 = make_data_classif('2gauss_prop', n, nz=sigma, p=p1, bias=dec1)
xs2, ys2 = make_data_classif('2gauss_prop', n + 1, nz=sigma, p=p2, bias=dec2)
xt, yt = make_data_classif('2gauss_prop', n, nz=sigma, p=pt, bias=dect)

# Source samples and labels grouped per domain, in the list form passed to
# the multi-source fit further down.
all_Xr, all_Yr = [xs1, xs2], [ys1, ys2]
43+
# %%
44+
45+
# Half-length of each arm of the reference cross drawn by plot_ax.
da = 1.5


def plot_ax(dec, name):
    """Draw a captioned reference cross centred on ``dec`` in the current axes.

    Parameters
    ----------
    dec : sequence
        Centre of the cross; only ``dec[0]`` (x) and ``dec[1]`` (y) are read.
    name : str
        Caption written at ``(dec[0] - 0.5, dec[1] + 2)``.

    Notes
    -----
    Both arms extend ``da`` (module-level) on either side of the centre and
    are drawn in black at 50% opacity.
    """
    cx, cy = dec[0], dec[1]
    # Vertical arm first, then the horizontal one.
    pl.plot([cx, cx], [cy - da, cy + da], 'k', alpha=0.5)
    pl.plot([cx - da, cx + da], [cy, cy], 'k', alpha=0.5)
    pl.text(cx - .5, cy + 2, name)
52+
53+
54+
##############################################################################
55+
# Fig 1 : plots source and target samples
56+
# ---------------------------------------
57+
58+
pl.figure(1)
pl.clf()

# Reference cross and caption for each of the three domains.
for _centre, _caption in ((dec1, 'Source 1'), (dec2, 'Source 2'), (dect, 'Target')):
    plot_ax(_centre, _caption)

# One scatter per domain: (samples, labels, marker, legend entry).
# The legend entry reports the domain's class proportions as (1 - p, p).
_domains = (
    (xs1, ys1, 'x', 'Source 1 ({:1.2f}, {:1.2f})'.format(1 - p1, p1)),
    (xs2, ys2, '+', 'Source 2 ({:1.2f}, {:1.2f})'.format(1 - p2, p2)),
    (xt, yt, 'o', 'Target ({:1.2f}, {:1.2f})'.format(1 - pt, pt)),
)
for _samples, _labels, _marker, _legend in _domains:
    # Points are coloured by label through the 'Set1' colormap.
    pl.scatter(_samples[:, 0], _samples[:, 1], c=_labels, s=35, marker=_marker,
               cmap='Set1', vmax=9, label=_legend)

pl.title('Data')

pl.legend()
pl.axis('equal')
pl.axis('off')
74+
75+
##############################################################################
76+
# Instantiate Sinkhorn transport algorithm and fit it for all source domains
77+
# ----------------------------------------------------------------------------
78+
# Single Sinkhorn transport estimator, fitted once per source domain when
# the independent-OT figure is drawn.
ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1, metric='sqeuclidean')


def print_G(G, xs, ys, xt):
    """Draw one faint segment per coupling entry that exceeds ``5e-4``.

    Despite its name this function prints nothing: it only plots on the
    current axes.

    Parameters
    ----------
    G : 2-D array-like
        Coupling matrix; entry ``G[i, j]`` links source sample ``i`` to
        target sample ``j``.
    xs : 2-D array-like
        Source samples; columns 0 and 1 are used as x and y.
    ys : sequence
        Source labels; a truthy ``ys[i]`` gives a blue segment, a falsy one
        a red segment.
    xt : 2-D array-like
        Target samples; columns 0 and 1 are used as x and y.
    """
    for src in range(G.shape[0]):
        for tgt in range(G.shape[1]):
            # Skip negligible mass so the figure is not flooded with lines.
            if not G[src, tgt] > 5e-4:
                continue
            colour = 'b' if ys[src] else 'r'
            pl.plot([xs[src, 0], xt[tgt, 0]], [xs[src, 1], xt[tgt, 1]], colour, alpha=.2)
90+
91+
92+
##############################################################################
93+
# Fig 2 : plot optimal couplings and transported samples
94+
# ------------------------------------------------------
95+
def _open_domain_figure(fig_num):
    """Select and clear figure ``fig_num``, then draw the three domain crosses."""
    pl.figure(fig_num)
    pl.clf()
    plot_ax(dec1, 'Source 1')
    plot_ax(dec2, 'Source 2')
    plot_ax(dect, 'Target')


def _finish_domain_figure(title):
    """Overlay all samples, the coupling legend and ``title`` on the current figure."""
    pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9)
    pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9)
    pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9)

    # Empty lines exist only to create legend entries in the colours and
    # opacity that print_G uses for its segments.
    pl.plot([], [], 'r', alpha=.2, label='Mass from Class 1')
    pl.plot([], [], 'b', alpha=.2, label='Mass from Class 2')

    pl.title(title)

    pl.legend()
    pl.axis('equal')
    pl.axis('off')


_open_domain_figure(2)
# The same estimator is refitted for each source; every coupling is drawn
# immediately after its own fit, before the next fit runs.
print_G(ot_sinkhorn.fit(Xs=xs1, Xt=xt).coupling_, xs1, ys1, xt)
print_G(ot_sinkhorn.fit(Xs=xs2, Xt=xt).coupling_, xs2, ys2, xt)
_finish_domain_figure('Independent OT')

##############################################################################
# Instantiate JCPOT adaptation algorithm and fit it
# ----------------------------------------------------------------------------
otda = ot.da.JCPOTTransport(reg_e=1, max_iter=1000, metric='sqeuclidean', tol=1e-9, verbose=True, log=True)
otda.fit(all_Xr, all_Yr, xt)

# Per-source weight vectors derived from the estimated proportions.
# NOTE(review): presumably log_['D2'][k] maps class proportions to
# per-sample weights for source k - confirm against the ot.da documentation.
ws1 = otda.proportions_.dot(otda.log_['D2'][0])
ws2 = otda.proportions_.dot(otda.log_['D2'][1])

_open_domain_figure(3)
# Couplings are recomputed with the weights above, using the logged
# log_['M'][k] entry for source k as the cost argument.
# NOTE(review): the empty list is the target-weights argument; presumably POT
# treats it as uniform weights - confirm in the ot.bregman.sinkhorn docs.
print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['M'][0], reg=1e-1), xs1, ys1, xt)
print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['M'][1], reg=1e-1), xs2, ys2, xt)
_finish_domain_figure(
    'OT with prop estimation ({:1.3f},{:1.3f})'.format(otda.proportions_[0], otda.proportions_[1]))

##############################################################################
# Run oracle transport algorithm with known proportions
# ----------------------------------------------------------------------------
# Proportions built from pt, the value the target data was generated with,
# instead of the estimated ones.
h_res = np.array([1 - pt, pt])

ws1 = h_res.dot(otda.log_['D2'][0])
ws2 = h_res.dot(otda.log_['D2'][1])

_open_domain_figure(4)
print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['M'][0], reg=1e-1), xs1, ys1, xt)
print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['M'][1], reg=1e-1), xs2, ys2, xt)
_finish_domain_figure('OT with known proportion ({:1.1f},{:1.1f})'.format(h_res[0], h_res[1]))
pl.show()

0 commit comments

Comments
 (0)