
Commit 140baad

add WDA
1 parent b30a380 commit 140baad

File tree

examples/plot_WDA.py
ot/dr.py

2 files changed: +164 -0 lines changed

examples/plot_WDA.py

Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
# -*- coding: utf-8 -*-
"""
=================================
Wasserstein Discriminant Analysis
=================================

@author: rflamary
"""

import numpy as np
import matplotlib.pylab as pl
import ot
from ot.dr import wda

#%% parameters

n = 1000  # nb samples in source and target datasets
nz = 0.2  # noise level of the class clusters
xs, ys = ot.datasets.get_data_classif('3gauss', n, nz)
xt, yt = ot.datasets.get_data_classif('3gauss', n, nz)

nbnoise = 8  # number of pure-noise dimensions appended to the 2D data

xs = np.hstack((xs, np.random.randn(n, nbnoise)))
xt = np.hstack((xt, np.random.randn(n, nbnoise)))

#%% plot samples

pl.figure(1)

pl.scatter(xs[:, 0], xs[:, 1], c=ys, marker='+', label='Source samples')
pl.legend(loc=0)
pl.title('Discriminant dimensions')

#%% compute the WDA projection

p = 2  # dimension of the projection subspace
reg = 1  # entropic regularization strength
k = 10  # number of inner Sinkhorn iterations
maxiter = 100

P, proj = wda(xs, ys, p, reg, k, maxiter=maxiter)

#%% plot projected samples

xsp = proj(xs)
xtp = proj(xt)

pl.figure(2, (10, 5))

pl.subplot(1, 2, 1)
pl.scatter(xsp[:, 0], xsp[:, 1], c=ys, marker='+', label='Projected samples')
pl.legend(loc=0)
pl.title('Projected training samples')

pl.subplot(1, 2, 2)
pl.scatter(xtp[:, 0], xtp[:, 1], c=yt, marker='+', label='Projected samples')
pl.legend(loc=0)
pl.title('Projected test samples')
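
A quick way to see what the projection buys is to compare 1-nearest-neighbour accuracy in the raw noisy space and in the 2D WDA subspace. This is a sketch, not part of the commit: it reuses the variables from the script above and assumes scikit-learn is available.

# Sanity check (illustrative only): 1-NN accuracy before and after WDA.
# Assumes scikit-learn; reuses xs, ys, xt, yt, xsp, xtp from the script above.
from sklearn.neighbors import KNeighborsClassifier

knn = KNeighborsClassifier(n_neighbors=1)
acc_raw = knn.fit(xs, ys).score(xt, yt)    # 10D features including noise dims
acc_wda = knn.fit(xsp, ys).score(xtp, yt)  # 2D discriminant subspace
print('1-NN accuracy: raw {:.2f} / wda {:.2f}'.format(acc_raw, acc_wda))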

ot/dr.py

Lines changed: 101 additions & 0 deletions
@@ -0,0 +1,101 @@
# -*- coding: utf-8 -*-
"""
Dimension reduction with optimal transport
"""

import autograd.numpy as np
from pymanopt.manifolds import Stiefel
from pymanopt import Problem
from pymanopt.solvers import SteepestDescent, TrustRegions


def dist(x1, x2):
    """ Compute the matrix of squared euclidean distances between samples
    """
    x1p2 = np.sum(np.square(x1), 1)
    x2p2 = np.sum(np.square(x2), 1)
    return x1p2.reshape((-1, 1)) + x2p2.reshape((1, -1)) - 2 * np.dot(x1, x2.T)


def sinkhorn(w1, w2, M, reg, k):
    """ Simple solver for the Sinkhorn algorithm with a fixed number of
    iterations (no stopping criterion, so it stays autograd-differentiable)
    """
    K = np.exp(-M / reg)
    ui = np.ones((M.shape[0],))
    vi = np.ones((M.shape[1],))
    for i in range(k):
        vi = w2 / (np.dot(K.T, ui))
        ui = w1 / (np.dot(K, vi))
    G = ui.reshape((M.shape[0], 1)) * K * vi.reshape((1, M.shape[1]))
    return G


def split_classes(X, y):
    """ Split the samples in X according to the class labels in y
    """
    lstsclass = np.unique(y)
    return [X[y == i, :].astype(np.float32) for i in lstsclass]


def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0):
    r"""
    Wasserstein Discriminant Analysis

    The function solves the following optimization problem:

    .. math::

        P = \arg\min_P \frac{\sum_i W(PX^i, PX^i)}{\sum_{i, j \neq i} W(PX^i, PX^j)}

    where:

    - :math:`W` is the entropic regularized Wasserstein distance
    - :math:`X^i` are the samples in the dataset belonging to class i

    """
    mx = np.mean(X, axis=0)  # per-feature mean used for centering
    X = X - mx.reshape((1, -1))  # center without modifying the caller's array

    # data split between classes
    d = X.shape[1]
    xc = split_classes(X, y)
    # compute uniform weights
    wc = [np.ones((x.shape[0]), dtype=np.float32) / x.shape[0] for x in xc]

    def cost(P):
        # wda loss: within-class over between-class transport cost
        loss_b = 0
        loss_w = 0

        for i, xi in enumerate(xc):
            xi = np.dot(xi, P)
            for j, xj in enumerate(xc[i:]):
                xj = np.dot(xj, P)
                M = dist(xi, xj)
                G = sinkhorn(wc[i], wc[j + i], M, reg, k)
                if j == 0:  # same class: within-class cost
                    loss_w += np.sum(G * M)
                else:  # distinct classes: between-class cost
                    loss_b += np.sum(G * M)

        # ratio is inverted because the solver minimizes
        return loss_w / loss_b

    # declare the manifold and the problem
    manifold = Stiefel(d, p)
    problem = Problem(manifold=manifold, cost=cost)

    # declare the solver and solve
    if solver is None:
        solver = SteepestDescent(maxiter=maxiter, logverbosity=verbose)
    elif solver in ['tr', 'TrustRegions']:
        solver = TrustRegions(maxiter=maxiter, logverbosity=verbose)

    Popt = solver.solve(problem)

    def proj(X):
        return (X - mx.reshape((1, -1))).dot(Popt)

    return Popt, proj
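
For reference, here is a minimal standalone check of the sinkhorn helper above. This is a sketch on made-up toy data, not part of the commit; the sizes and reg/k values are arbitrary.

# Toy check of the fixed-iteration Sinkhorn solver (illustrative only).
import numpy as np

x1 = np.random.randn(5, 2)
x2 = np.random.randn(7, 2)
M = dist(x1, x2)           # squared euclidean cost matrix, shape (5, 7)
w1 = np.ones(5) / 5        # uniform source weights
w2 = np.ones(7) / 7        # uniform target weights
G = sinkhorn(w1, w2, M, reg=1.0, k=10)
print(G.sum())             # total transported mass, close to 1
print(G.sum(axis=1) - w1)  # row marginals match w1 exactly after the last u-update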
