
Commit 566645a

add mapping estimation (still debugging)

1 parent 9813511 · commit 566645a

File tree

3 files changed: +234 additions, -53 deletions

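For orientation before the diffs: the alternating scheme in joint_OT_mapping_linear (and its in-progress kernel variant) minimizes, reading off the loss(L,G) helper below, the joint objective (the matrix notation X_s, X_t, \Pi(a,b) is ours, not the commit's):

    \min_{G \in \Pi(a,b),\; L}\; \lVert X_s L - n_s\, G X_t \rVert_F^2 \;+\; \mu\, \langle G,\, M \rangle_F \;+\; \eta\, \lVert L - I \rVert_F^2

where M is the squared-distance matrix scaled by n_s (the new M=dist(xs,xt)*ns line). solve_L handles the L-step as a closed-form regularized least square; solve_G handles the G-step with the conditional-gradient solver cg, with regularization strength 1/mu.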

ot/da.py

Lines changed: 165 additions & 9 deletions
@@ -6,13 +6,15 @@
 import numpy as np
 from .bregman import sinkhorn
 from .lp import emd
-from .utils import unif,dist
+from .utils import unif,dist,kernel
 from .optim import cg


 def indices(a, func):
     return [i for (i, val) in enumerate(a) if func(val)]

+
+
 def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerItermax = 200,stopInnerThr=1e-9,verbose=False,log=False):
     """
     Solve the entropic regularization optimal transport problem with nonconvex group lasso regularization
@@ -129,34 +131,38 @@ def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbos

     if bias:
         xs1=np.hstack((xs,np.ones((ns,1))))
-        I=eta*np.eye(d+1)
+        xstxs=xs1.T.dot(xs1)
+        I=np.eye(d+1)
         I[-1]=0
         I0=I[:,:-1]
         sel=lambda x : x[:-1,:]
     else:
         xs1=xs
-        I=eta*np.eye(d)
+        xstxs=xs1.T.dot(xs1)
+        I=np.eye(d)
         I0=I
         sel=lambda x : x

     if log:
         log={'err':[]}

     a,b=unif(ns),unif(nt)
-    M=dist(xs,xt)
+    M=dist(xs,xt)*ns
     G=emd(a,b,M)

     vloss=[]

     def loss(L,G):
+        """Compute full loss"""
         return np.sum((xs1.dot(L)-ns*G.dot(xt))**2)+mu*np.sum(G*M)+eta*np.sum(sel(L-I0)**2)

     def solve_L(G):
-        """ solve problem with fixed G"""
+        """ solve L problem with fixed G (least square)"""
         xst=ns*G.dot(xt)
-        return np.linalg.solve(xs1.T.dot(xs1)+I,xs1.T.dot(xst)+I0)
+        return np.linalg.solve(xstxs+eta*I,xs1.T.dot(xst)+eta*I0)

     def solve_G(L,G0):
+        """Update G with CG algorithm"""
         xsi=xs1.dot(L)
         def f(G):
             return np.sum((xsi-ns*G.dot(xt))**2)
@@ -175,8 +181,11 @@ def df(G):
         print('{:5d}|{:8e}|{:8e}'.format(0,vloss[-1],0))


-    # regul matrix
-    loop=1
+    # init loop
+    if numItermax>0:
+        loop=1
+    else:
+        loop=0
     it=0

     while loop:
@@ -191,18 +200,116 @@ def df(G):

         vloss.append(loss(L,G))

+        if it>=numItermax:
+            loop=0
+
         if abs(vloss[-1]-vloss[-2])<stopThr:
             loop=0

         if verbose:
             if it%20==0:
                 print('{:5s}|{:12s}|{:8s}'.format('It.','Loss','Delta loss')+'\n'+'-'*32)
             print('{:5d}|{:8e}|{:8e}'.format(it,vloss[-1],abs(vloss[-1]-vloss[-2])/abs(vloss[-2])))
+    if log:
+        log['loss']=vloss
+        return G,L,log
+    else:
+        return G,L

-    return G,L

+def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kernel='gaussian',sigma=1,bias=False,verbose=False,verbose2=False,numItermax = 100,numInnerItermax = 20,stopInnerThr=1e-9,stopThr=1e-6,log=False,**kwargs):
+    """Joint OT and mapping estimation (uniform weights)
+    """

+    ns,nt,d=xs.shape[0],xt.shape[0],xt.shape[1]

+    if bias:
+        K=  # left incomplete in this commit ("still debugging"): presumably the kernel Gram matrix of xs
+        xs1=np.hstack((xs,np.ones((ns,1))))
+        xstxs=xs1.T.dot(xs1)
+        I=np.eye(d+1)
+        I[-1]=0
+        I0=I[:,:-1]
+        sel=lambda x : x[:-1,:]
+    else:
+        xs1=xs
+        xstxs=xs1.T.dot(xs1)
+        I=np.eye(d)
+        I0=I
+        sel=lambda x : x
+
+    if log:
+        log={'err':[]}
+
+    a,b=unif(ns),unif(nt)
+    M=dist(xs,xt)*ns
+    G=emd(a,b,M)
+
+    vloss=[]
+
+    def loss(L,G):
+        """Compute full loss"""
+        return np.sum((xs1.dot(L)-ns*G.dot(xt))**2)+mu*np.sum(G*M)+eta*np.sum(sel(L-I0)**2)
+
+    def solve_L(G):
+        """ solve L problem with fixed G (least square)"""
+        xst=ns*G.dot(xt)
+        return np.linalg.solve(xstxs+eta*I,xs1.T.dot(xst)+eta*I0)
+
+    def solve_G(L,G0):
+        """Update G with CG algorithm"""
+        xsi=xs1.dot(L)
+        def f(G):
+            return np.sum((xsi-ns*G.dot(xt))**2)
+        def df(G):
+            return -2*ns*(xsi-ns*G.dot(xt)).dot(xt.T)
+        G=cg(a,b,M,1.0/mu,f,df,G0=G0,numItermax=numInnerItermax,stopThr=stopInnerThr)
+        return G
+
+
+    L=solve_L(G)
+
+    vloss.append(loss(L,G))
+
+    if verbose:
+        print('{:5s}|{:12s}|{:8s}'.format('It.','Loss','Delta loss')+'\n'+'-'*32)
+        print('{:5d}|{:8e}|{:8e}'.format(0,vloss[-1],0))
+
+
+    # init loop
+    if numItermax>0:
+        loop=1
+    else:
+        loop=0
+    it=0
+
+    while loop:
+
+        it+=1
+
+        # update G
+        G=solve_G(L,G)
+
+        # update L
+        L=solve_L(G)
+
+        vloss.append(loss(L,G))
+
+        if it>=numItermax:
+            loop=0
+
+        if abs(vloss[-1]-vloss[-2])<stopThr:
+            loop=0
+
+        if verbose:
+            if it%20==0:
+                print('{:5s}|{:12s}|{:8s}'.format('It.','Loss','Delta loss')+'\n'+'-'*32)
+            print('{:5d}|{:8e}|{:8e}'.format(it,vloss[-1],abs(vloss[-1]-vloss[-2])/abs(vloss[-2])))
+    if log:
+        log['loss']=vloss
+        return G,L,log
+    else:
+        return G,L


 class OTDA(object):
@@ -294,6 +401,7 @@ def predict(self,x,direction=1):

 class OTDA_sinkhorn(OTDA):
     """Class for domain adaptation with optimal transport with entropic regularization"""
+
     def fit(self,xs,xt,reg=1,ws=None,wt=None,**kwargs):
         """ Fit domain adaptation between samples in xs and xt (with optional
         weights)"""
@@ -335,3 +443,51 @@ def fit(self,xs,ys,xt,reg=1,eta=1,ws=None,wt=None,**kwargs):
         self.G=sinkhorn_lpl1_mm(ws,ys,wt,self.M,reg,eta,**kwargs)
         self.computed=True

+class OTDA_mapping(OTDA):
+    """Class for optimal transport with joint linear mapping estimation"""
+
+
+    def __init__(self,metric='sqeuclidean'):
+        """ Class initialization"""
+
+        self.xs=0
+        self.xt=0
+        self.G=0
+        self.L=0
+        self.bias=False
+        self.metric=metric
+        self.computed=False
+
+    def fit(self,xs,xt,mu=1,eta=1,bias=False,**kwargs):
+        """ Fit domain adaptation between samples in xs and xt (with optional
+        weights)"""
+        self.xs=xs
+        self.xt=xt
+        self.bias=bias
+
+        self.ws=unif(xs.shape[0])
+        self.wt=unif(xt.shape[0])
+
+        self.G,self.L=joint_OT_mapping_linear(xs,xt,mu=mu,eta=eta,bias=bias,**kwargs)
+        self.computed=True
+
+    def mapping(self):
+        return lambda x: self.predict(x)
+
+
+    def predict(self,x):
+        """ Out-of-sample mapping: apply the estimated linear map L
+        (with optional bias term) to the new samples x.
+        """
+        if self.computed:
+            if self.bias:
+                x=np.hstack((x,np.ones((x.shape[0],1))))
+            return x.dot(self.L) # apply the estimated linear mapping
+        else:
+            print("Warning, model not fitted yet, returning None")
+            return None
+
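A minimal usage sketch for the new OTDA_mapping class (ours, not part of the commit; it assumes the module imports as ot.da and that the still-debugged linear solver converges on toy data):

    import numpy as np
    from ot.da import OTDA_mapping

    np.random.seed(0)
    xs = np.random.randn(50, 2)         # source samples
    xt = np.random.randn(60, 2) + 2.0   # target samples, shifted

    da = OTDA_mapping()
    da.fit(xs, xt, mu=1, eta=1e-3, bias=True, verbose=True)

    # out-of-sample points are mapped with the estimated linear map L (+ bias)
    xnew = np.random.randn(5, 2)
    print(da.predict(xnew).shape)       # expected: (5, 2)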

ot/datasets.py

Lines changed: 40 additions & 22 deletions
@@ -8,8 +8,8 @@


 def get_1D_gauss(n,m,s):
-    """return a 1D histogram for a gaussian distribution (n bins, mean m and std s)
-
+    """return a 1D histogram for a gaussian distribution (n bins, mean m and std s)
+
     Parameters
     ----------
@@ -20,21 +20,21 @@ def get_1D_gauss(n,m,s):
     s : float
         standard deviation of the gaussian distribution
-
+
     Returns
     -------
     h : np.array (n,)
-        1D histogram for a gaussian distribution
-
+        1D histogram for a gaussian distribution
+
     """
     x=np.arange(n,dtype=np.float64)
     h=np.exp(-(x-m)**2/(2*s**2))
     return h/h.sum()
-
-
+
+
 def get_2D_samples_gauss(n,m,sigma):
-    """return n samples drawn from 2D gaussian N(m,sigma)
-
+    """return n samples drawn from 2D gaussian N(m,sigma)
+
     Parameters
     ----------
4545
sigma : np.array (2,2)
4646
covariance matrix of the gaussian distribution
4747
48-
48+
4949
Returns
5050
-------
5151
X : np.array (n,2)
52-
n samples drawn from N(m,sigma)
53-
52+
n samples drawn from N(m,sigma)
53+
5454
"""
5555
if np.isscalar(sigma):
5656
sigma=np.array([sigma,])
@@ -61,9 +61,10 @@ def get_2D_samples_gauss(n,m,sigma):
         res= np.random.randn(n,2)*np.sqrt(sigma)+m
     return res

-def get_data_classif(dataset,n,nz=.5,**kwargs):
+
+def get_data_classif(dataset,n,nz=.5,theta=0,**kwargs):
     """ dataset generation for classification problems
-
+
     Parameters
     ----------
@@ -74,13 +75,13 @@ def get_data_classif(dataset,n,nz=.5,**kwargs):
     nz : float
         noise level (>0)
-
+
     Returns
     -------
     X : np.array (n,d)
-        n observations of size d
+        n observations of size d
     y : np.array (n,)
-        labels of the samples
+        labels of the samples

     """
     if dataset.lower()=='3gauss':
@@ -90,10 +91,10 @@ def get_data_classif(dataset,n,nz=.5,**kwargs):
         x[y==1,0]=-1.; x[y==1,1]=-1.
         x[y==2,0]=-1.; x[y==2,1]=1.
         x[y==3,0]=1. ; x[y==3,1]=0
-
+
         x[y!=3,:]+=1.5*nz*np.random.randn(sum(y!=3),2)
         x[y==3,:]+=2*nz*np.random.randn(sum(y==3),2)
-
+
     elif dataset.lower()=='3gauss2':
         y=np.floor((np.arange(n)*1.0/n*3))+1
         x=np.zeros((n,2))
@@ -102,12 +103,29 @@ def get_data_classif(dataset,n,nz=.5,**kwargs):
         x[y==1,0]=-2.; x[y==1,1]=-2.
         x[y==2,0]=-2.; x[y==2,1]=2.
         x[y==3,0]=2. ; x[y==3,1]=0
-
+
         x[y!=3,:]+=nz*np.random.randn(sum(y!=3),2)
-        x[y==3,:]+=2*nz*np.random.randn(sum(y==3),2)
+        x[y==3,:]+=2*nz*np.random.randn(sum(y==3),2)
+
+    elif dataset.lower()=='gaussrot' :
+        rot=np.array([[np.cos(theta),-np.sin(theta)],[np.sin(theta),np.cos(theta)]])
+        m1=np.array([-1,-1])
+        m2=np.array([1,1])
+        y=np.floor((np.arange(n)*1.0/n*2))+1
+        n1=np.sum(y==1)
+        n2=np.sum(y==2)
+        x=np.zeros((n,2))
+
+        x[y==1,:]=get_2D_samples_gauss(n1,m1,nz)
+        x[y==2,:]=get_2D_samples_gauss(n2,m2,nz)
+
+        x=x.dot(rot)
+
+
     else:
         x=0
         y=0
         print("unknown dataset")
-
+
     return x,y.astype(int)
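A quick sketch exercising the new 'gaussrot' dataset (our example, not from the commit): two Gaussian classes centered at (-1,-1) and (1,1), jointly rotated by theta.

    import numpy as np
    from ot.datasets import get_data_classif

    # 100 samples, noise level nz, rotation angle pi/4
    x, y = get_data_classif('gaussrot', 100, nz=0.3, theta=np.pi/4)
    print(x.shape, np.unique(y))        # expected: (100, 2) [1 2]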
