Skip to content

Commit 62dcfbf

Browse files
authored
Merge pull request #28 from Slasnista/domain_adaptation_corrections
Domain adaptation corrections, closes #26
2 parents 1669704 + 2097116 commit 62dcfbf

File tree

3 files changed

+384
-163
lines changed

3 files changed

+384
-163
lines changed
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
# -*- coding: utf-8 -*-
"""
============================================
OTDA unsupervised vs semi-supervised setting
============================================

This example introduces a semi supervised domain adaptation in a 2D setting.
It explicits the problem of semi supervised domain adaptation and introduces
some optimal transport approaches to solve it.

Quantities such as optimal couplings, greater coupling coefficients and
transported samples are represented in order to give a visual understanding
of what the transport methods are doing.
"""

# Authors: Remi Flamary <[email protected]>
#          Stanislas Chambon <[email protected]>
#
# License: MIT License

import matplotlib.pylab as pl
import ot


##############################################################################
# generate data
##############################################################################

n_samples_source = 150
n_samples_target = 150

Xs, ys = ot.datasets.get_data_classif('3gauss', n_samples_source)
Xt, yt = ot.datasets.get_data_classif('3gauss2', n_samples_target)


##############################################################################
# Transport source samples onto target samples
##############################################################################

# unsupervised domain adaptation
ot_sinkhorn_un = ot.da.SinkhornTransport(reg_e=1e-1)
ot_sinkhorn_un.fit(Xs=Xs, Xt=Xt)
transp_Xs_sinkhorn_un = ot_sinkhorn_un.transform(Xs=Xs)

# semi-supervised domain adaptation
ot_sinkhorn_semi = ot.da.SinkhornTransport(reg_e=1e-1)
ot_sinkhorn_semi.fit(Xs=Xs, Xt=Xt, ys=ys, yt=yt)
transp_Xs_sinkhorn_semi = ot_sinkhorn_semi.transform(Xs=Xs)

# semi-supervised DA uses the available labeled target samples to modify the
# cost matrix involved in the OT problem. The cost of transporting a source
# sample of class A onto a target sample of class B != A is set to infinite,
# or a very large value

# note that in the present case we consider that all the target samples are
# labeled. For daily applications, some target samples might not have labels,
# in this case the elements of yt corresponding to these samples should be
# filled with -1.

# Warning: we recall that -1 cannot be used as a class label


##############################################################################
# Fig 1 : plots source and target samples + matrix of pairwise distance
##############################################################################

pl.figure(1, figsize=(10, 10))
pl.subplot(2, 2, 1)
pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')
pl.xticks([])
pl.yticks([])
pl.legend(loc=0)
pl.title('Source samples')

pl.subplot(2, 2, 2)
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')
pl.xticks([])
pl.yticks([])
pl.legend(loc=0)
pl.title('Target samples')

pl.subplot(2, 2, 3)
pl.imshow(ot_sinkhorn_un.cost_, interpolation='nearest')
pl.xticks([])
pl.yticks([])
pl.title('Cost matrix - unsupervised DA')

pl.subplot(2, 2, 4)
pl.imshow(ot_sinkhorn_semi.cost_, interpolation='nearest')
pl.xticks([])
pl.yticks([])
pl.title('Cost matrix - semisupervised DA')

pl.tight_layout()

# the optimal coupling in the semi-supervised DA case will exhibit a shape
# similar to the cost matrix (block diagonal matrix)


##############################################################################
# Fig 2 : plots optimal couplings for the different methods
##############################################################################

pl.figure(2, figsize=(8, 4))

pl.subplot(1, 2, 1)
pl.imshow(ot_sinkhorn_un.coupling_, interpolation='nearest')
pl.xticks([])
pl.yticks([])
pl.title('Optimal coupling\nUnsupervised DA')

pl.subplot(1, 2, 2)
pl.imshow(ot_sinkhorn_semi.coupling_, interpolation='nearest')
pl.xticks([])
pl.yticks([])
pl.title('Optimal coupling\nSemi-supervised DA')

pl.tight_layout()


##############################################################################
# Fig 3 : plot transported samples
##############################################################################

# display transported samples
# note: figure number fixed (was 4) and titles corrected — both transports
# are SinkhornTransport; they differ only in the (semi-)supervision setting
pl.figure(3, figsize=(8, 4))
pl.subplot(1, 2, 1)
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
           label='Target samples', alpha=0.5)
pl.scatter(transp_Xs_sinkhorn_un[:, 0], transp_Xs_sinkhorn_un[:, 1], c=ys,
           marker='+', label='Transp samples', s=30)
pl.title('Transported samples\nUnsupervised DA')
pl.legend(loc=0)
pl.xticks([])
pl.yticks([])

pl.subplot(1, 2, 2)
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
           label='Target samples', alpha=0.5)
pl.scatter(transp_Xs_sinkhorn_semi[:, 0], transp_Xs_sinkhorn_semi[:, 1], c=ys,
           marker='+', label='Transp samples', s=30)
pl.title('Transported samples\nSemi-supervised DA')
pl.legend(loc=0)
pl.xticks([])
pl.yticks([])

pl.tight_layout()
pl.show()

ot/da.py

Lines changed: 55 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -966,8 +966,12 @@ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
966966
The class labels
967967
Xt : array-like, shape (n_target_samples, n_features)
968968
The training input samples.
969-
yt : array-like, shape (n_labeled_target_samples,)
970-
The class labels
969+
yt : array-like, shape (n_target_samples,)
970+
The class labels. If some target samples are unlabeled, fill the
971+
corresponding elements of yt with -1.
972+
973+
Warning: Note that, due to this convention -1 cannot be used as a
974+
class label
971975
972976
Returns
973977
-------
@@ -989,7 +993,7 @@ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
989993

990994
# assumes labeled source samples occupy the first rows
991995
# and labeled target samples occupy the first columns
992-
classes = np.unique(ys)
996+
classes = [c for c in np.unique(ys) if c != -1]
993997
for c in classes:
994998
idx_s = np.where((ys != c) & (ys != -1))
995999
idx_t = np.where(yt == c)
@@ -1023,8 +1027,12 @@ def fit_transform(self, Xs=None, ys=None, Xt=None, yt=None):
10231027
The class labels
10241028
Xt : array-like, shape (n_target_samples, n_features)
10251029
The training input samples.
1026-
yt : array-like, shape (n_labeled_target_samples,)
1027-
The class labels
1030+
yt : array-like, shape (n_target_samples,)
1031+
The class labels. If some target samples are unlabeled, fill the
1032+
yt's elements with -1.
1033+
1034+
Warning: Note that, due to this convention -1 cannot be used as a
1035+
class label
10281036
10291037
Returns
10301038
-------
@@ -1045,8 +1053,12 @@ def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128):
10451053
The class labels
10461054
Xt : array-like, shape (n_target_samples, n_features)
10471055
The training input samples.
1048-
yt : array-like, shape (n_labeled_target_samples,)
1049-
The class labels
1056+
yt : array-like, shape (n_target_samples,)
1057+
The class labels. If some target samples are unlabeled, fill the
1058+
yt's elements with -1.
1059+
1060+
Warning: Note that, due to this convention -1 cannot be used as a
1061+
class label
10501062
batch_size : int, optional (default=128)
10511063
The batch size for out of sample inverse transform
10521064
@@ -1110,8 +1122,12 @@ def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None,
11101122
The class labels
11111123
Xt : array-like, shape (n_target_samples, n_features)
11121124
The training input samples.
1113-
yt : array-like, shape (n_labeled_target_samples,)
1114-
The class labels
1125+
yt : array-like, shape (n_target_samples,)
1126+
The class labels. If some target samples are unlabeled, fill the
1127+
yt's elements with -1.
1128+
1129+
Warning: Note that, due to this convention -1 cannot be used as a
1130+
class label
11151131
batch_size : int, optional (default=128)
11161132
The batch size for out of sample inverse transform
11171133
@@ -1241,8 +1257,12 @@ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
12411257
The class labels
12421258
Xt : array-like, shape (n_target_samples, n_features)
12431259
The training input samples.
1244-
yt : array-like, shape (n_labeled_target_samples,)
1245-
The class labels
1260+
yt : array-like, shape (n_target_samples,)
1261+
The class labels. If some target samples are unlabeled, fill the
1262+
yt's elements with -1.
1263+
1264+
Warning: Note that, due to this convention -1 cannot be used as a
1265+
class label
12461266
12471267
Returns
12481268
-------
@@ -1333,8 +1353,12 @@ def fit(self, Xs, ys=None, Xt=None, yt=None):
13331353
The class labels
13341354
Xt : array-like, shape (n_target_samples, n_features)
13351355
The training input samples.
1336-
yt : array-like, shape (n_labeled_target_samples,)
1337-
The class labels
1356+
yt : array-like, shape (n_target_samples,)
1357+
The class labels. If some target samples are unlabeled, fill the
1358+
yt's elements with -1.
1359+
1360+
Warning: Note that, due to this convention -1 cannot be used as a
1361+
class label
13381362
13391363
Returns
13401364
-------
@@ -1434,8 +1458,12 @@ def fit(self, Xs, ys=None, Xt=None, yt=None):
14341458
The class labels
14351459
Xt : array-like, shape (n_target_samples, n_features)
14361460
The training input samples.
1437-
yt : array-like, shape (n_labeled_target_samples,)
1438-
The class labels
1461+
yt : array-like, shape (n_target_samples,)
1462+
The class labels. If some target samples are unlabeled, fill the
1463+
yt's elements with -1.
1464+
1465+
Warning: Note that, due to this convention -1 cannot be used as a
1466+
class label
14391467
14401468
Returns
14411469
-------
@@ -1545,8 +1573,12 @@ def fit(self, Xs, ys=None, Xt=None, yt=None):
15451573
The class labels
15461574
Xt : array-like, shape (n_target_samples, n_features)
15471575
The training input samples.
1548-
yt : array-like, shape (n_labeled_target_samples,)
1549-
The class labels
1576+
yt : array-like, shape (n_target_samples,)
1577+
The class labels. If some target samples are unlabeled, fill the
1578+
yt's elements with -1.
1579+
1580+
Warning: Note that, due to this convention -1 cannot be used as a
1581+
class label
15501582
15511583
Returns
15521584
-------
@@ -1662,8 +1694,12 @@ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
16621694
The class labels
16631695
Xt : array-like, shape (n_target_samples, n_features)
16641696
The training input samples.
1665-
yt : array-like, shape (n_labeled_target_samples,)
1666-
The class labels
1697+
yt : array-like, shape (n_target_samples,)
1698+
The class labels. If some target samples are unlabeled, fill the
1699+
yt's elements with -1.
1700+
1701+
Warning: Note that, due to this convention -1 cannot be used as a
1702+
class label
16671703
16681704
Returns
16691705
-------

0 commit comments

Comments
 (0)