@@ -123,7 +123,7 @@ def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerIter
123
123
124
124
return transp
125
125
126
- def joint_OT_mapping_linear (xs ,xt ,mu = 1 ,eta = 0.001 ,bias = False ,verbose = False ,verbose2 = False ,numItermax = 100 ,numInnerItermax = 20 ,stopInnerThr = 1e-9 ,stopThr = 1e-6 ,log = False ,** kwargs ):
126
+ def joint_OT_mapping_linear (xs ,xt ,mu = 1 ,eta = 0.001 ,bias = False ,verbose = False ,verbose2 = False ,numItermax = 100 ,numInnerItermax = 10 ,stopInnerThr = 1e-6 ,stopThr = 1e-5 ,log = False ,** kwargs ):
127
127
"""Joint Ot and mapping estimation (uniform weights and )
128
128
"""
129
129
@@ -209,15 +209,15 @@ def df(G):
209
209
if verbose :
210
210
if it % 20 == 0 :
211
211
print ('{:5s}|{:12s}|{:8s}' .format ('It.' ,'Loss' ,'Delta loss' )+ '\n ' + '-' * 32 )
212
- print ('{:5d}|{:8e}|{:8e}' .format (it ,vloss [- 1 ],abs (vloss [- 1 ]- vloss [- 2 ])/ abs (vloss [- 2 ])))
212
+ print ('{:5d}|{:8e}|{:8e}' .format (it ,vloss [- 1 ],(vloss [- 1 ]- vloss [- 2 ])/ abs (vloss [- 2 ])))
213
213
if log :
214
214
log ['loss' ]= vloss
215
215
return G ,L ,log
216
216
else :
217
217
return G ,L
218
218
219
219
220
- def joint_OT_mapping_kernel (xs ,xt ,mu = 1 ,eta = 0.001 ,kerneltype = 'gaussian' ,sigma = 1 ,bias = False ,verbose = False ,verbose2 = False ,numItermax = 100 ,numInnerItermax = 20 ,stopInnerThr = 1e-9 ,stopThr = 1e-6 ,log = False ,** kwargs ):
220
+ def joint_OT_mapping_kernel (xs ,xt ,mu = 1 ,eta = 0.001 ,kerneltype = 'gaussian' ,sigma = 1 ,bias = False ,verbose = False ,verbose2 = False ,numItermax = 100 ,numInnerItermax = 10 ,stopInnerThr = 1e-6 ,stopThr = 1e-5 ,log = False ,** kwargs ):
221
221
"""Joint Ot and mapping estimation (uniform weights and )
222
222
"""
223
223
@@ -228,15 +228,31 @@ def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kerneltype='gaussian',sigma=1,b
228
228
K1 = np .hstack ((K ,np .ones ((ns ,1 ))))
229
229
I = np .eye (ns + 1 )
230
230
I [- 1 ]= 0
231
- K0 = K1 .T .dot (K1 )+ eta * I
232
- Kreg = I
233
- sel = lambda x : x [:- 1 ,:]
231
+ Kp = np .eye (ns + 1 )
232
+ Kp [:ns ,:ns ]= K
233
+
234
+ # ls regu
235
+ #K0 = K1.T.dot(K1)+eta*I
236
+ #Kreg=I
237
+
238
+ # RKHS regul
239
+ K0 = K1 .T .dot (K1 )+ eta * Kp
240
+ Kreg = Kp
241
+
234
242
else :
235
243
K1 = K
236
244
I = np .eye (ns )
245
+
246
+ # ls regul
247
+ #K0 = K1.T.dot(K1)+eta*I
248
+ #Kreg=I
249
+
250
+ # proper kernel ridge
237
251
K0 = K + eta * I
238
252
Kreg = K
239
- sel = lambda x : x
253
+
254
+
255
+
240
256
241
257
if log :
242
258
log = {'err' :[]}
@@ -313,7 +329,7 @@ def df(G):
313
329
if verbose :
314
330
if it % 20 == 0 :
315
331
print ('{:5s}|{:12s}|{:8s}' .format ('It.' ,'Loss' ,'Delta loss' )+ '\n ' + '-' * 32 )
316
- print ('{:5d}|{:8e}|{:8e}' .format (it ,vloss [- 1 ],abs (vloss [- 1 ]- vloss [- 2 ])/ abs (vloss [- 2 ])))
332
+ print ('{:5d}|{:8e}|{:8e}' .format (it ,vloss [- 1 ],(vloss [- 1 ]- vloss [- 2 ])/ abs (vloss [- 2 ])))
317
333
if log :
318
334
log ['loss' ]= vloss
319
335
return G ,L ,log
0 commit comments