@@ -217,25 +217,23 @@ def df(G):
217
217
return G ,L
218
218
219
219
220
- def joint_OT_mapping_kernel (xs ,xt ,mu = 1 ,eta = 0.001 ,kernel = 'gaussian' ,sigma = 1 ,bias = False ,verbose = False ,verbose2 = False ,numItermax = 100 ,numInnerItermax = 20 ,stopInnerThr = 1e-9 ,stopThr = 1e-6 ,log = False ,** kwargs ):
220
+ def joint_OT_mapping_kernel (xs ,xt ,mu = 1 ,eta = 0.001 ,kerneltype = 'gaussian' ,sigma = 1 ,bias = False ,verbose = False ,verbose2 = False ,numItermax = 100 ,numInnerItermax = 20 ,stopInnerThr = 1e-9 ,stopThr = 1e-6 ,log = False ,** kwargs ):
221
221
"""Joint Ot and mapping estimation (uniform weights and )
222
222
"""
223
223
224
224
ns ,nt ,d = xs .shape [0 ],xt .shape [0 ],xt .shape [1 ]
225
225
226
+ K = kernel (xs ,xs ,method = kerneltype ,sigma = sigma )
226
227
if bias :
227
- K =
228
- xs1 = np .hstack ((xs ,np .ones ((ns ,1 ))))
229
- xstxs = xs1 .T .dot (xs1 )
230
- I = np .eye (d + 1 )
228
+ K1 = np .hstack ((K ,np .ones ((ns ,1 ))))
229
+ I = np .eye (ns + 1 )
231
230
I [- 1 ]= 0
232
- I0 = I [:,: - 1 ]
231
+ K0 = K1 . T . dot ( K1 ) + eta * I
233
232
sel = lambda x : x [:- 1 ,:]
234
233
else :
235
- xs1 = xs
236
- xstxs = xs1 .T .dot (xs1 )
237
- I = np .eye (d )
238
- I0 = I
234
+ K1 = K
235
+ I = np .eye (ns )
236
+ K0 = K + eta * I
239
237
sel = lambda x : x
240
238
241
239
if log :
@@ -249,23 +247,32 @@ def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kernel='gaussian',sigma=1,bias=
249
247
250
248
def loss (L ,G ):
251
249
"""Compute full loss"""
252
- return np .sum ((xs1 .dot (L )- ns * G .dot (xt ))** 2 )+ mu * np .sum (G * M )+ eta * np .sum (sel (L - I0 )** 2 )
250
+ return np .sum ((K1 .dot (L )- ns * G .dot (xt ))** 2 )+ mu * np .sum (G * M )+ eta * np .sum (sel (L )** 2 )
253
251
254
- def solve_L (G ):
252
+ def solve_L_nobias (G ):
255
253
""" solve L problem with fixed G (least square)"""
256
254
xst = ns * G .dot (xt )
257
- return np .linalg .solve (xstxs + eta * I ,xs1 .T .dot (xst )+ eta * I0 )
255
+ return np .linalg .solve (K0 ,xst )
256
+
257
+ def solve_L_bias (G ):
258
+ """ solve L problem with fixed G (least square)"""
259
+ xst = ns * G .dot (xt )
260
+ return np .linalg .solve (K0 ,K1 .T .dot (xst ))
258
261
259
262
def solve_G (L ,G0 ):
260
263
"""Update G with CG algorithm"""
261
- xsi = xs1 .dot (L )
264
+ xsi = K1 .dot (L )
262
265
def f (G ):
263
266
return np .sum ((xsi - ns * G .dot (xt ))** 2 )
264
267
def df (G ):
265
268
return - 2 * ns * (xsi - ns * G .dot (xt )).dot (xt .T )
266
269
G = cg (a ,b ,M ,1.0 / mu ,f ,df ,G0 = G0 ,numItermax = numInnerItermax ,stopThr = stopInnerThr )
267
270
return G
268
271
272
+ if bias :
273
+ solve_L = solve_L_bias
274
+ else :
275
+ solve_L = solve_L_nobias
269
276
270
277
L = solve_L (G )
271
278
0 commit comments