Skip to content

Commit adc5570

Browse files
authored
Merge pull request #137 from ievred/jcpot
[MRG] Jcpot : Multi source DA with target shift
2 parents 4cd4e09 + 7889484 commit adc5570

File tree

11 files changed

+908
-55
lines changed

11 files changed

+908
-55
lines changed

.travis.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ install:
3535
- pip install -r requirements.txt
3636
- pip install -U "numpy>=1.14" "scipy<1.3" # for numpy array formatting in doctests + scipy version: otherwise, pymanopt fails, cf <https://github.com/pymanopt/pymanopt/issues/77>
3737
- pip install flake8 pytest "pytest-cov<2.6"
38+
- pip install -U "sklearn"
3839
- pip install .
3940
# command to run tests + check syntax style
4041
services:

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ It provides the following solvers:
2929
* Non regularized free support Wasserstein barycenters [20].
3030
* Unbalanced OT with KL relaxation distance and barycenter [10, 25].
3131
* Screening Sinkhorn Algorithm for OT [26].
32+
* JCPOT algorithm for multi-source domain adaptation with target shift [27].
3233

3334
Some demonstrations (both in Python and Jupyter Notebook format) are available in the examples folder.
3435

@@ -257,3 +258,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t
257258
[25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. (2015). [Learning with a Wasserstein Loss](http://cbcl.mit.edu/wasserstein/) Advances in Neural Information Processing Systems (NIPS).
258259

259260
[26] Alaya M. Z., Bérar M., Gasso G., Rakotomamonjy A. (2019). [Screening Sinkhorn Algorithm for Regularized Optimal Transport](https://papers.nips.cc/paper/9386-screening-sinkhorn-algorithm-for-regularized-optimal-transport), Advances in Neural Information Processing Systems 32 (NeurIPS).
261+
262+
[27] Redko I., Courty N., Flamary R., Tuia D. (2019). [Optimal Transport for Multi-source Domain Adaptation under Target Shift](http://proceedings.mlr.press/v89/redko19a.html), Proceedings of the Twenty-Second International Conference on Artificial Intelligence and Statistics (AISTATS) 22, 2019.

examples/plot_otda_classes.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import matplotlib.pylab as pl
1818
import ot
1919

20-
2120
##############################################################################
2221
# Generate data
2322
# -------------

examples/plot_otda_jcpot.py

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
================================
4+
OT for multi-source target shift
5+
================================
6+
7+
This example introduces a target shift problem with two 2D source domains and one target domain.
8+
9+
"""
10+
11+
# Authors: Remi Flamary <[email protected]>
12+
# Ievgen Redko <[email protected]>
13+
#
14+
# License: MIT License
15+
16+
import pylab as pl
17+
import numpy as np
18+
import ot
19+
from ot.datasets import make_data_classif
20+
21+
##############################################################################
22+
# Generate data
23+
# -------------
24+
n = 50
25+
sigma = 0.3
26+
np.random.seed(1985)
27+
28+
p1 = .2
29+
dec1 = [0, 2]
30+
31+
p2 = .9
32+
dec2 = [0, -2]
33+
34+
pt = .4
35+
dect = [4, 0]
36+
37+
xs1, ys1 = make_data_classif('2gauss_prop', n, nz=sigma, p=p1, bias=dec1)
38+
xs2, ys2 = make_data_classif('2gauss_prop', n + 1, nz=sigma, p=p2, bias=dec2)
39+
xt, yt = make_data_classif('2gauss_prop', n, nz=sigma, p=pt, bias=dect)
40+
41+
all_Xr = [xs1, xs2]
42+
all_Yr = [ys1, ys2]
43+
# %%
44+
45+
da = 1.5
46+
47+
48+
def plot_ax(dec, name):
    """Draw a faint cross centred on ``dec`` and label it with ``name``.

    Parameters
    ----------
    dec : sequence
        Centre of the cross; ``dec[0]`` is the x coordinate and ``dec[1]``
        the y coordinate.
    name : str
        Text written near the cross.

    Notes
    -----
    The half-length of each arm is the module-level value ``da``. Drawing
    is done on the current pylab figure.
    """
    cx, cy = dec[0], dec[1]

    # Vertical arm, then horizontal arm, both in semi-transparent black.
    pl.plot([cx, cx], [cy - da, cy + da], 'k', alpha=0.5)
    pl.plot([cx - da, cx + da], [cy, cy], 'k', alpha=0.5)

    # Label placed slightly to the left of and above the centre.
    pl.text(cx - .5, cy + 2, name)
52+
53+
54+
##############################################################################
55+
# Fig 1 : plots source and target samples
56+
# ---------------------------------------
57+
58+
pl.figure(1)
59+
pl.clf()
60+
plot_ax(dec1, 'Source 1')
61+
plot_ax(dec2, 'Source 2')
62+
plot_ax(dect, 'Target')
63+
pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9,
64+
label='Source 1 ({:1.2f}, {:1.2f})'.format(1 - p1, p1))
65+
pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9,
66+
label='Source 2 ({:1.2f}, {:1.2f})'.format(1 - p2, p2))
67+
pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9,
68+
label='Target ({:1.2f}, {:1.2f})'.format(1 - pt, pt))
69+
pl.title('Data')
70+
71+
pl.legend()
72+
pl.axis('equal')
73+
pl.axis('off')
74+
75+
##############################################################################
76+
# Instantiate Sinkhorn transport algorithm and fit it for all source domains
77+
# ----------------------------------------------------------------------------
78+
ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1, metric='sqeuclidean')
79+
80+
81+
def print_G(G, xs, ys, xt):
    """Draw the significant entries of a coupling matrix as line segments.

    For every pair ``(i, j)`` with ``G[i, j] > 5e-4`` a semi-transparent
    segment is drawn from ``xs[i]`` to ``xt[j]`` on the current pylab
    figure. The segment is blue when ``ys[i]`` is truthy and red otherwise.

    Parameters
    ----------
    G : array-like, indexed as ``G[i, j]`` and exposing ``shape``
        Coupling weights between source sample ``i`` and target sample ``j``.
    xs : array-like, indexed as ``xs[i, 0]`` / ``xs[i, 1]``
        Source sample coordinates.
    ys : sequence
        Source labels; only their truthiness is used.
    xt : array-like, indexed as ``xt[j, 0]`` / ``xt[j, 1]``
        Target sample coordinates.
    """
    n_rows, n_cols = G.shape[0], G.shape[1]
    for i in range(n_rows):
        for j in range(n_cols):
            # Skip couplings that carry a negligible amount of mass.
            if not (G[i, j] > 5e-4):
                continue
            colour = 'b' if ys[i] else 'r'
            pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], colour, alpha=.2)
90+
91+
92+
##############################################################################
93+
# Fig 2 : plot optimal couplings and transported samples
94+
# ------------------------------------------------------
95+
pl.figure(2)
96+
pl.clf()
97+
plot_ax(dec1, 'Source 1')
98+
plot_ax(dec2, 'Source 2')
99+
plot_ax(dect, 'Target')
100+
print_G(ot_sinkhorn.fit(Xs=xs1, Xt=xt).coupling_, xs1, ys1, xt)
101+
print_G(ot_sinkhorn.fit(Xs=xs2, Xt=xt).coupling_, xs2, ys2, xt)
102+
pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9)
103+
pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9)
104+
pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9)
105+
106+
pl.plot([], [], 'r', alpha=.2, label='Mass from Class 1')
107+
pl.plot([], [], 'b', alpha=.2, label='Mass from Class 2')
108+
109+
pl.title('Independent OT')
110+
111+
pl.legend()
112+
pl.axis('equal')
113+
pl.axis('off')
114+
115+
##############################################################################
116+
# Instantiate JCPOT adaptation algorithm and fit it
117+
# ----------------------------------------------------------------------------
118+
otda = ot.da.JCPOTTransport(reg_e=1, max_iter=1000, metric='sqeuclidean', tol=1e-9, verbose=True, log=True)
119+
otda.fit(all_Xr, all_Yr, xt)
120+
121+
ws1 = otda.proportions_.dot(otda.log_['D2'][0])
122+
ws2 = otda.proportions_.dot(otda.log_['D2'][1])
123+
124+
pl.figure(3)
125+
pl.clf()
126+
plot_ax(dec1, 'Source 1')
127+
plot_ax(dec2, 'Source 2')
128+
plot_ax(dect, 'Target')
129+
print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['M'][0], reg=1e-1), xs1, ys1, xt)
130+
print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['M'][1], reg=1e-1), xs2, ys2, xt)
131+
pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9)
132+
pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9)
133+
pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9)
134+
135+
pl.plot([], [], 'r', alpha=.2, label='Mass from Class 1')
136+
pl.plot([], [], 'b', alpha=.2, label='Mass from Class 2')
137+
138+
pl.title('OT with prop estimation ({:1.3f},{:1.3f})'.format(otda.proportions_[0], otda.proportions_[1]))
139+
140+
pl.legend()
141+
pl.axis('equal')
142+
pl.axis('off')
143+
144+
##############################################################################
145+
# Run oracle transport algorithm with known proportions
146+
# ----------------------------------------------------------------------------
147+
h_res = np.array([1 - pt, pt])
148+
149+
ws1 = h_res.dot(otda.log_['D2'][0])
150+
ws2 = h_res.dot(otda.log_['D2'][1])
151+
152+
pl.figure(4)
153+
pl.clf()
154+
plot_ax(dec1, 'Source 1')
155+
plot_ax(dec2, 'Source 2')
156+
plot_ax(dect, 'Target')
157+
print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['M'][0], reg=1e-1), xs1, ys1, xt)
158+
print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['M'][1], reg=1e-1), xs2, ys2, xt)
159+
pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9)
160+
pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9)
161+
pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9)
162+
163+
pl.plot([], [], 'r', alpha=.2, label='Mass from Class 1')
164+
pl.plot([], [], 'b', alpha=.2, label='Mass from Class 2')
165+
166+
pl.title('OT with known proportion ({:1.1f},{:1.1f})'.format(h_res[0], h_res[1]))
167+
168+
pl.legend()
169+
pl.axis('equal')
170+
pl.axis('off')
171+
pl.show()

0 commit comments

Comments
 (0)