Skip to content

Commit 9813511

Browse files
committed
add mapping estimation (still debugging)
1 parent 3c4944c commit 9813511

File tree

2 files changed

+208
-126
lines changed

2 files changed

+208
-126
lines changed

ot/da.py

Lines changed: 155 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from .bregman import sinkhorn
88
from .lp import emd
99
from .utils import unif,dist
10+
from .optim import cg
1011

1112

1213
def indices(a, func):
@@ -15,81 +16,81 @@ def indices(a, func):
1516
def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerItermax = 200,stopInnerThr=1e-9,verbose=False,log=False):
1617
"""
1718
Solve the entropic regularization optimal transport problem with nonconvex group lasso regularization
18-
19+
1920
The function solves the following optimization problem:
20-
21+
2122
.. math::
2223
\gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega_e(\gamma)+ \eta \Omega_g(\gamma)
23-
24+
2425
s.t. \gamma 1 = a
25-
26-
\gamma^T 1= b
27-
26+
27+
\gamma^T 1= b
28+
2829
\gamma\geq 0
2930
where :
30-
31+
3132
- M is the (ns,nt) metric cost matrix
3233
- :math:`\Omega_e` is the entropic regularization term :math:`\Omega_e(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
3334
- :math:`\Omega_g` is the group lasso regulaization term :math:`\Omega_g(\gamma)=\sum_{i,c} \|\gamma_{i,\mathcal{I}_c}\|^{1/2}_1` where :math:`\mathcal{I}_c` are the index of samples from class c in the source domain.
3435
- a and b are source and target weights (sum to 1)
35-
36+
3637
The algorithm used for solving the problem is the generalised conditional gradient as proposed in [5]_ [7]_
37-
38-
38+
39+
3940
Parameters
4041
----------
4142
a : np.ndarray (ns,)
4243
samples weights in the source domain
4344
labels_a : np.ndarray (ns,)
44-
labels of samples in the source domain
45+
labels of samples in the source domain
4546
b : np.ndarray (nt,)
4647
samples in the target domain
4748
M : np.ndarray (ns,nt)
48-
loss matrix
49+
loss matrix
4950
reg: float
5051
Regularization term for entropic regularization >0
5152
eta: float, optional
52-
Regularization term for group lasso regularization >0
53+
Regularization term for group lasso regularization >0
5354
numItermax: int, optional
5455
Max number of iterations
5556
numInnerItermax: int, optional
5657
Max number of iterations (inner sinkhorn solver)
5758
stopInnerThr: float, optional
58-
Stop threshold on error (inner sinkhorn solver) (>0)
59+
Stop threshold on error (inner sinkhorn solver) (>0)
5960
verbose : bool, optional
6061
Print information along iterations
6162
log : bool, optional
62-
record log if True
63-
64-
63+
record log if True
64+
65+
6566
Returns
6667
-------
6768
gamma: (ns x nt) ndarray
6869
Optimal transportation matrix for the given parameters
6970
log: dict
70-
log dictionary return only if log==True in parameters
71-
72-
71+
log dictionary return only if log==True in parameters
72+
73+
7374
References
7475
----------
75-
76+
7677
.. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
7778
.. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567.
78-
79+
7980
See Also
8081
--------
8182
ot.lp.emd : Unregularized OT
8283
ot.bregman.sinkhorn : Entropic regularized OT
8384
ot.optim.cg : General regularized OT
84-
85-
"""
85+
86+
"""
8687
p=0.5
8788
epsilon = 1e-3
8889

8990
# init data
9091
Nini = len(a)
9192
Nfin = len(b)
92-
93+
9394
indices_labels = []
9495
idx_begin = np.min(labels_a)
9596
for c in range(idx_begin,np.max(labels_a)+1):
@@ -117,139 +118,220 @@ def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerIter
117118
# do it only for unlabbled data
118119
if idx_begin==-1:
119120
W[indices_labels[0],t]=np.min(all_maj)
120-
121+
121122
return transp
122123

124+
def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False, verbose2=False, numItermax=100, numInnerItermax=20, stopInnerThr=1e-9, stopThr=1e-6, log=False, **kwargs):
    """Jointly estimate a transport plan G and a linear mapping L (uniform weights)

    The function alternates between two updates until the loss stalls or
    numItermax outer iterations have been performed:

    - L update: closed-form regularized least squares with G fixed,
    - G update: a call to cg() with L fixed, using f(G) = ||xs1 L - ns G xt||^2
      as the extra term and 1/mu as its weight.

    The monitored loss is

        ||xs1 L - ns G xt||_F^2 + mu <G, M>_F + eta ||sel(L - I0)||_F^2

    where xs1 is xs (with a column of ones appended when bias is True), M is
    the cost matrix returned by dist(xs, xt), I0 is the identity (with a zero
    bias row when bias is True) and sel() drops the bias row so that the bias
    is not regularized.

    Parameters
    ----------
    xs : np.ndarray (ns,d)
        samples in the source domain
    xt : np.ndarray (nt,d)
        samples in the target domain
    mu : float, optional
        weight of the transport cost term <G, M> in the loss
    eta : float, optional
        weight of the regularization pulling L towards the identity
    bias : bool, optional
        estimate an additional bias row in L (xs is augmented with ones)
    verbose : bool, optional
        print the loss along iterations
    verbose2 : bool, optional
        currently unused
    numItermax : int, optional
        max number of outer (alternating) iterations
    numInnerItermax : int, optional
        max number of iterations passed to the inner cg solver
    stopInnerThr : float, optional
        stop threshold passed to the inner cg solver
    stopThr : float, optional
        stop threshold on the absolute loss variation between two outer
        iterations
    log : bool, optional
        record log if True
    **kwargs
        currently unused

    Returns
    -------
    G : (ns x nt) ndarray
        transport plan after the last outer iteration
    L : (d x d) or (d+1 x d) ndarray
        linear mapping (last row is the bias when bias is True)
    log : dict
        log dictionary (key 'loss': list of loss values), returned only if
        log==True in parameters
    """
    ns, d = xs.shape[0], xt.shape[1]

    # Id regularizes the normal equations, I0 is the target L is pulled to;
    # both are kept unscaled and eta is applied explicitly where needed
    if bias:
        xs1 = np.hstack((xs, np.ones((ns, 1))))
        Id = np.eye(d + 1)
        Id[-1] = 0  # no regularization on the bias row
        I0 = Id[:, :-1]

        def sel(x):
            return x[:-1, :]
    else:
        xs1 = xs
        Id = np.eye(d)
        I0 = Id

        def sel(x):
            return x

    a, b = unif(ns), unif(nt_weights_size(xt))  if False else (unif(ns), unif(xt.shape[0]))
    M = dist(xs, xt)
    G = emd(a, b, M)

    # left-hand side of the normal equations does not depend on G
    lhs = xs1.T.dot(xs1) + eta * Id

    vloss = []

    def loss(L, G):
        """full objective value for the current (L, G)"""
        return np.sum((xs1.dot(L) - ns * G.dot(xt))**2) + mu * np.sum(G * M) + eta * np.sum(sel(L - I0)**2)

    def solve_L(G):
        """solve the regularized least squares problem in L with fixed G"""
        xst = ns * G.dot(xt)
        return np.linalg.solve(lhs, xs1.T.dot(xst) + eta * I0)

    def solve_G(L, G0):
        """update G with cg, L being fixed and G0 used as starting point"""
        xsi = xs1.dot(L)

        def f(G):
            return np.sum((xsi - ns * G.dot(xt))**2)

        def df(G):
            return -2 * ns * (xsi - ns * G.dot(xt)).dot(xt.T)

        return cg(a, b, M, 1.0 / mu, f, df, G0=G0, numItermax=numInnerItermax, stopThr=stopInnerThr)

    L = solve_L(G)

    vloss.append(loss(L, G))

    if verbose:
        print('{:5s}|{:12s}|{:8s}'.format('It.','Loss','Delta loss')+'\n'+'-'*32)
        print('{:5d}|{:8e}|{:8e}'.format(0,vloss[-1],0))

    it = 0
    # numItermax bounds the alternation so a stalled-but-oscillating loss
    # cannot loop forever
    while it < numItermax:

        it += 1

        # update G
        G = solve_G(L, G)

        # update L
        L = solve_L(G)

        vloss.append(loss(L, G))
        delta = abs(vloss[-1] - vloss[-2])

        if verbose:
            if it % 20 == 0:
                print('{:5s}|{:12s}|{:8s}'.format('It.','Loss','Delta loss')+'\n'+'-'*32)
            # the printed delta is relative, the stopping test is absolute
            print('{:5d}|{:8e}|{:8e}'.format(it,vloss[-1],delta/abs(vloss[-2])))

        if delta < stopThr:
            break

    if log:
        return G, L, {'loss': vloss}
    return G, L
203+
204+
205+
123206

124207

125208
class OTDA(object):
    """Class for domain adaptation with optimal transport"""

    def __init__(self, metric='sqeuclidean'):
        """Class initialization

        Parameters
        ----------
        metric : str, optional
            metric name forwarded to dist() when the cost matrix is built
        """
        self.xs = 0
        self.xt = 0
        self.G = 0
        self.metric = metric
        self.computed = False

    def fit(self, xs, xt, ws=None, wt=None):
        """Fit domain adaptation between samples in xs and xt (with optional
        weights)

        Stores the samples and weights, the cost matrix returned by dist()
        in self.M and the coupling returned by emd() in self.G.

        Parameters
        ----------
        xs : np.ndarray
            source samples (one sample per row)
        xt : np.ndarray
            target samples (one sample per row)
        ws : np.ndarray, optional
            source weights, unif(xs.shape[0]) when None
        wt : np.ndarray, optional
            target weights, unif(xt.shape[0]) when None
        """
        self.xs = xs
        self.xt = xt

        if wt is None:
            wt = unif(xt.shape[0])
        if ws is None:
            ws = unif(xs.shape[0])

        self.ws = ws
        self.wt = wt

        self.M = dist(xs, xt, metric=self.metric)
        self.G = emd(ws, wt, self.M)
        self.computed = True

    def interp(self, direction=1):
        r"""Barycentric interpolation for the source (1) or target (-1)

        This Barycentric interpolation solves for each source (resp target)
        sample xs (resp xt) the following optimization problem:

        .. math::
            arg\min_x \sum_i \gamma_{k,i} c(x,x_i^t)

        where k is the index of the sample in xs

        For the moment only squared euclidean distance is provided but more
        metric could be used in the future.

        Parameters
        ----------
        direction : int, optional
            >0 interpolates the source samples, otherwise the target samples

        Returns
        -------
        np.ndarray or None
            weighted mean np.dot(G/w, x) of the samples of the other domain,
            or None (with a printed warning) when the model is not fitted
        """
        # check first: ws/wt only exist (and xs/xt are only arrays) after fit
        if not self.computed:
            print("Warning, model not fitted yet, returning None")
            return None

        if direction > 0:  # >0 then source to target
            G = self.G
            w = self.ws.reshape((self.xs.shape[0], 1))
            x = self.xt
        else:
            G = self.G.T
            w = self.wt.reshape((self.xt.shape[0], 1))
            x = self.xs

        if self.metric != 'sqeuclidean':
            print("Warning, metric not handled yet, using weighted average")
        return np.dot(G / w, x)  # weighted mean

    def predict(self, x, direction=1):
        """Out of sample mapping using the formulation from Ferradans

        It basically finds the fitted sample nearest to each new sample and
        applies the difference to the displaced (interpolated) fitted sample.

        Parameters
        ----------
        x : np.ndarray
            new samples (one sample per row)
        direction : int, optional
            >0 maps from source to target, otherwise from target to source

        Returns
        -------
        np.ndarray or None
            mapped samples, or None (with a printed warning) when the model
            is not fitted
        """
        if not self.computed:
            print("Warning, model not fitted yet, returning None")
            return None

        if direction > 0:  # >0 then source to target
            x0 = self.xs
        else:
            x0 = self.xt

        D0 = dist(x, x0)  # dist between new samples and fitted samples
        idx = np.argmin(D0, 1)  # closest one
        xf = self.interp(direction)  # interp the fitted samples
        return xf[idx, :] + x - x0[idx, :]  # apply the delta to the interpolation
209-
210-
292+
293+
211294

212295
class OTDA_sinkhorn(OTDA):
    """Class for domain adaptation with optimal transport with entropic regularization"""

    def fit(self, xs, xt, reg=1, ws=None, wt=None, **kwargs):
        """Store the samples and weights, then compute the cost matrix and
        the coupling with the sinkhorn solver.

        Parameters
        ----------
        xs : np.ndarray
            source samples (one sample per row)
        xt : np.ndarray
            target samples (one sample per row)
        reg : float, optional
            regularization value forwarded to sinkhorn()
        ws : np.ndarray, optional
            source weights, unif(xs.shape[0]) when None
        wt : np.ndarray, optional
            target weights, unif(xt.shape[0]) when None
        **kwargs
            extra keyword arguments forwarded to sinkhorn()
        """
        self.xs, self.xt = xs, xt

        # fall back on the weights produced by unif() when none are given
        target_weights = unif(xt.shape[0]) if wt is None else wt
        source_weights = unif(xs.shape[0]) if ws is None else ws
        self.ws, self.wt = source_weights, target_weights

        cost = dist(xs, xt, metric=self.metric)
        self.M = cost
        self.G = sinkhorn(source_weights, target_weights, cost, reg, **kwargs)
        self.computed = True
314+
315+
233316
class OTDA_lpl1(OTDA):
    """Class for domain adaptation with optimal transport with entropic an group regularization"""

    def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, **kwargs):
        """Store the samples and weights, then compute the cost matrix and
        the coupling with sinkhorn_lpl1_mm.

        Parameters
        ----------
        xs : np.ndarray
            source samples (one sample per row)
        ys : np.ndarray
            labels of the source samples, forwarded to sinkhorn_lpl1_mm()
        xt : np.ndarray
            target samples (one sample per row)
        reg : float, optional
            entropic regularization value forwarded to sinkhorn_lpl1_mm()
        eta : float, optional
            group regularization value forwarded to sinkhorn_lpl1_mm()
        ws : np.ndarray, optional
            source weights, unif(xs.shape[0]) when None
        wt : np.ndarray, optional
            target weights, unif(xt.shape[0]) when None
        **kwargs
            extra keyword arguments forwarded to sinkhorn_lpl1_mm()
        """
        self.xs, self.xt = xs, xt

        # fall back on the weights produced by unif() when none are given
        target_weights = unif(xt.shape[0]) if wt is None else wt
        source_weights = unif(xs.shape[0]) if ws is None else ws
        self.ws, self.wt = source_weights, target_weights

        cost = dist(xs, xt, metric=self.metric)
        self.M = cost
        self.G = sinkhorn_lpl1_mm(source_weights, ys, target_weights, cost, reg, eta, **kwargs)
        self.computed = True
337+

0 commit comments

Comments
 (0)