6
6
import numpy as np
7
7
from .bregman import sinkhorn
8
8
from .lp import emd
9
- from .utils import unif ,dist
9
+ from .utils import unif ,dist , kernel
10
10
from .optim import cg
11
11
12
12
13
13
def indices(a, func):
    """Return the positions of every element of ``a`` for which ``func`` is true.

    Parameters
    ----------
    a : iterable
        Sequence whose elements are tested one by one.
    func : callable
        Predicate applied to each element.

    Returns
    -------
    list of int
        Indices ``i`` such that ``func(a[i])`` evaluated truthy, in order.
    """
    matching = []
    for position, item in enumerate(a):
        if func(item):
            matching.append(position)
    return matching
15
15
16
+
17
+
16
18
def sinkhorn_lpl1_mm (a ,labels_a , b , M , reg , eta = 0.1 ,numItermax = 10 ,numInnerItermax = 200 ,stopInnerThr = 1e-9 ,verbose = False ,log = False ):
17
19
"""
18
20
Solve the entropic regularization optimal transport problem with nonconvex group lasso regularization
@@ -129,34 +131,38 @@ def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbos
129
131
130
132
if bias :
131
133
xs1 = np .hstack ((xs ,np .ones ((ns ,1 ))))
132
- I = eta * np .eye (d + 1 )
134
+ xstxs = xs1 .T .dot (xs1 )
135
+ I = np .eye (d + 1 )
133
136
I [- 1 ]= 0
134
137
I0 = I [:,:- 1 ]
135
138
sel = lambda x : x [:- 1 ,:]
136
139
else :
137
140
xs1 = xs
138
- I = eta * np .eye (d )
141
+ xstxs = xs1 .T .dot (xs1 )
142
+ I = np .eye (d )
139
143
I0 = I
140
144
sel = lambda x : x
141
145
142
146
if log :
143
147
log = {'err' :[]}
144
148
145
149
a ,b = unif (ns ),unif (nt )
146
- M = dist (xs ,xt )
150
+ M = dist (xs ,xt )* ns
147
151
G = emd (a ,b ,M )
148
152
149
153
vloss = []
150
154
151
155
def loss (L ,G ):
156
+ """Compute full loss"""
152
157
return np .sum ((xs1 .dot (L )- ns * G .dot (xt ))** 2 )+ mu * np .sum (G * M )+ eta * np .sum (sel (L - I0 )** 2 )
153
158
154
159
def solve_L (G ):
155
- """ solve problem with fixed G"""
160
+ """ solve L problem with fixed G (least square) """
156
161
xst = ns * G .dot (xt )
157
- return np .linalg .solve (xs1 . T . dot ( xs1 ) + I ,xs1 .T .dot (xst )+ I0 )
162
+ return np .linalg .solve (xstxs + eta * I ,xs1 .T .dot (xst )+ eta * I0 )
158
163
159
164
def solve_G (L ,G0 ):
165
+ """Update G with CG algorithm"""
160
166
xsi = xs1 .dot (L )
161
167
def f (G ):
162
168
return np .sum ((xsi - ns * G .dot (xt ))** 2 )
@@ -175,8 +181,11 @@ def df(G):
175
181
print ('{:5d}|{:8e}|{:8e}' .format (0 ,vloss [- 1 ],0 ))
176
182
177
183
178
- # regul matrix
179
- loop = 1
184
+ # init loop
185
+ if numItermax > 0 :
186
+ loop = 1
187
+ else :
188
+ loop = 0
180
189
it = 0
181
190
182
191
while loop :
@@ -191,18 +200,116 @@ def df(G):
191
200
192
201
vloss .append (loss (L ,G ))
193
202
203
+ if it >= numItermax :
204
+ loop = 0
205
+
194
206
if abs (vloss [- 1 ]- vloss [- 2 ])< stopThr :
195
207
loop = 0
196
208
197
209
if verbose :
198
210
if it % 20 == 0 :
199
211
print ('{:5s}|{:12s}|{:8s}' .format ('It.' ,'Loss' ,'Delta loss' )+ '\n ' + '-' * 32 )
200
212
print ('{:5d}|{:8e}|{:8e}' .format (it ,vloss [- 1 ],abs (vloss [- 1 ]- vloss [- 2 ])/ abs (vloss [- 2 ])))
213
+ if log :
214
+ log ['loss' ]= vloss
215
+ return G ,L ,log
216
+ else :
217
+ return G ,L
201
218
202
- return G ,L
203
219
220
def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kernel='gaussian', sigma=1,
                            bias=False, verbose=False, verbose2=False,
                            numItermax=100, numInnerItermax=20,
                            stopInnerThr=1e-9, stopThr=1e-6, log=False,
                            **kwargs):
    """Joint OT and mapping estimation between uniformly weighted samples.

    Alternately solves for the OT coupling ``G`` (conditional gradient on the
    quadratic fitting term) and for a mapping ``L`` (regularized least
    squares) so that ``xs1.dot(L)`` matches the barycentric image
    ``ns * G.dot(xt)`` of the source samples.

    NOTE(review): the original body contained a dangling statement ``K =``
    (a syntax error) and never used the ``kernel``/``sigma`` parameters —
    as written this routine estimates a *linear* map, identical to
    ``joint_OT_mapping_linear``. The dangling line is removed here; the
    kernel-matrix computation still needs to be implemented (TODO confirm
    intended behavior against the reference implementation).

    Parameters
    ----------
    xs : np.ndarray (ns, d)
        Source samples.
    xt : np.ndarray (nt, d)
        Target samples.
    mu : float
        Weight of the OT linear term in the objective.
    eta : float
        Regularization weight on the deviation of L from identity.
    kernel : str
        Kernel type (currently unused — see NOTE above).
    sigma : float
        Kernel bandwidth (currently unused — see NOTE above).
    bias : bool
        If True, estimate an affine map (appends a ones column to xs).
    verbose, verbose2 : bool
        Print progress information.
    numItermax : int
        Maximum number of outer (block coordinate descent) iterations.
    numInnerItermax : int
        Maximum number of inner CG iterations when updating G.
    stopInnerThr : float
        Inner CG stopping threshold.
    stopThr : float
        Outer loop stopping threshold on the loss decrease.
    log : bool
        If True, also return a dict with the loss history.

    Returns
    -------
    G : np.ndarray (ns, nt)
        Optimal transport coupling.
    L : np.ndarray
        Estimated mapping matrix ((d+1, d) when bias else (d, d)).
    log : dict, optional
        Only when ``log`` is True; contains the loss values under ``'loss'``.
    """
    ns, nt, d = xs.shape[0], xt.shape[0], xt.shape[1]

    if bias:
        # Affine map: augment source samples with a ones column; the extra
        # row of L is the offset and is excluded from the regularization.
        xs1 = np.hstack((xs, np.ones((ns, 1))))
        xstxs = xs1.T.dot(xs1)  # hoisted: Gram matrix is loop-invariant
        I = np.eye(d + 1)
        I[-1] = 0  # do not pull the bias row towards identity
        I0 = I[:, :-1]
        sel = lambda x: x[:-1, :]  # drop bias row when measuring ||L - I||
    else:
        xs1 = xs
        xstxs = xs1.T.dot(xs1)
        I = np.eye(d)
        I0 = I
        sel = lambda x: x

    if log:
        log = {'err': []}

    a, b = unif(ns), unif(nt)
    # scale by ns so the OT term is commensurate with the fitting term
    M = dist(xs, xt) * ns
    G = emd(a, b, M)

    vloss = []

    def loss(L, G):
        """Compute full loss"""
        return (np.sum((xs1.dot(L) - ns * G.dot(xt)) ** 2)
                + mu * np.sum(G * M)
                + eta * np.sum(sel(L - I0) ** 2))

    def solve_L(G):
        """Solve the L problem with fixed G (regularized least squares)."""
        xst = ns * G.dot(xt)
        return np.linalg.solve(xstxs + eta * I, xs1.T.dot(xst) + eta * I0)

    def solve_G(L, G0):
        """Update G with the CG algorithm, warm-started at G0."""
        xsi = xs1.dot(L)

        def f(G):
            return np.sum((xsi - ns * G.dot(xt)) ** 2)

        def df(G):
            return -2 * ns * (xsi - ns * G.dot(xt)).dot(xt.T)

        G = cg(a, b, M, 1.0 / mu, f, df, G0=G0,
               numItermax=numInnerItermax, stopThr=stopInnerThr)
        return G

    L = solve_L(G)

    vloss.append(loss(L, G))

    if verbose:
        print('{:5s}|{:12s}|{:8s}'.format('It.', 'Loss', 'Delta loss') + '\n ' + '-' * 32)
        print('{:5d}|{:8e}|{:8e}'.format(0, vloss[-1], 0))

    # init loop (numItermax <= 0 means no BCD iterations at all)
    if numItermax > 0:
        loop = 1
    else:
        loop = 0
    it = 0

    while loop:

        it += 1

        # update G
        G = solve_G(L, G)

        # update L
        L = solve_L(G)

        vloss.append(loss(L, G))

        if it >= numItermax:
            loop = 0

        if abs(vloss[-1] - vloss[-2]) < stopThr:
            loop = 0

        if verbose:
            if it % 20 == 0:
                print('{:5s}|{:12s}|{:8s}'.format('It.', 'Loss', 'Delta loss') + '\n ' + '-' * 32)
            print('{:5d}|{:8e}|{:8e}'.format(it, vloss[-1], abs(vloss[-1] - vloss[-2]) / abs(vloss[-2])))
    if log:
        log['loss'] = vloss
        return G, L, log
    else:
        return G, L
206
313
207
314
208
315
class OTDA (object ):
@@ -294,6 +401,7 @@ def predict(self,x,direction=1):
294
401
295
402
class OTDA_sinkhorn (OTDA ):
296
403
"""Class for domain adaptation with optimal transport with entropic regularization"""
404
+
297
405
def fit (self ,xs ,xt ,reg = 1 ,ws = None ,wt = None ,** kwargs ):
298
406
""" Fit domain adaptation between samples is xs and xt (with optional
299
407
weights)"""
@@ -335,3 +443,51 @@ def fit(self,xs,ys,xt,reg=1,eta=1,ws=None,wt=None,**kwargs):
335
443
self .G = sinkhorn_lpl1_mm (ws ,ys ,wt ,self .M ,reg ,eta ,** kwargs )
336
444
self .computed = True
337
445
446
class OTDA_mapping(OTDA):
    """Class for optimal transport with joint linear mapping estimation.

    Wraps ``joint_OT_mapping_linear``: ``fit`` estimates both the OT
    coupling ``G`` and a linear (optionally affine) map ``L`` from the
    source domain to the target domain; ``predict`` applies the map to
    new samples.
    """

    def __init__(self, metric='sqeuclidean'):
        """Class initialization"""
        # Placeholders, all overwritten by fit():
        self.xs = 0          # source samples
        self.xt = 0          # target samples
        self.G = 0           # OT coupling matrix
        self.L = 0           # estimated mapping matrix
        self.bias = False    # whether an affine offset was estimated
        # NOTE(review): metric is stored but not used by fit() below
        # (joint_OT_mapping_linear computes its own cost) — confirm.
        self.metric = metric
        self.computed = False  # True once fit() has completed

    def fit(self, xs, xt, mu=1, eta=1, bias=False, **kwargs):
        """Fit domain adaptation between samples xs and xt (uniform weights).

        Estimates self.G and self.L via joint_OT_mapping_linear; extra
        keyword arguments are forwarded to that solver.
        """
        self.xs = xs
        self.xt = xt
        self.bias = bias

        self.ws = unif(xs.shape[0])
        self.wt = unif(xt.shape[0])

        self.G, self.L = joint_OT_mapping_linear(xs, xt, mu=mu, eta=eta, bias=bias, **kwargs)
        self.computed = True

    def mapping(self):
        # Expose the out-of-sample mapping as a standalone callable.
        return lambda x: self.predict(x)

    def predict(self, x):
        """Map new samples x through the learned linear transport map.

        NOTE(review): the original docstring described Ferradans-style
        nearest-source-sample interpolation, but the code simply applies
        ``x.dot(self.L)`` (appending a ones column first when bias=True)
        — the learned linear map generalizes out of sample directly.
        Returns None (with a warning) when the model is not fitted.
        """
        if self.computed:
            if self.bias:
                # augment with ones so the affine offset row of L applies
                x = np.hstack((x, np.ones((x.shape[0], 1))))
            return x.dot(self.L)  # apply the learned map to the new samples
        else:
            print("Warning, model not fitted yet, returning None")
            return None
493
+
0 commit comments