@@ -323,15 +323,18 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
323
323
if len (b .shape ) < 2 :
324
324
if method .lower () == 'sinkhorn' :
325
325
res = sinkhorn_knopp (a , b , M , reg , numItermax = numItermax ,
326
- stopThr = stopThr , verbose = verbose , log = log ,
326
+ stopThr = stopThr , verbose = verbose ,
327
+ log = log , warn = warn ,
327
328
** kwargs )
328
329
elif method .lower () == 'sinkhorn_log' :
329
330
res = sinkhorn_log (a , b , M , reg , numItermax = numItermax ,
330
- stopThr = stopThr , verbose = verbose , log = log ,
331
+ stopThr = stopThr , verbose = verbose ,
332
+ log = log , warn = warn ,
331
333
** kwargs )
332
334
elif method .lower () == 'sinkhorn_stabilized' :
333
335
res = sinkhorn_stabilized (a , b , M , reg , numItermax = numItermax ,
334
- stopThr = stopThr , verbose = verbose , log = log ,
336
+ stopThr = stopThr , verbose = verbose ,
337
+ log = log , warn = warn ,
335
338
** kwargs )
336
339
else :
337
340
raise ValueError ("Unknown method '%s'." % method )
@@ -344,15 +347,18 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
344
347
345
348
if method .lower () == 'sinkhorn' :
346
349
return sinkhorn_knopp (a , b , M , reg , numItermax = numItermax ,
347
- stopThr = stopThr , verbose = verbose , log = log ,
350
+ stopThr = stopThr , verbose = verbose ,
351
+ log = log , warn = warn ,
348
352
** kwargs )
349
353
elif method .lower () == 'sinkhorn_log' :
350
354
return sinkhorn_log (a , b , M , reg , numItermax = numItermax ,
351
- stopThr = stopThr , verbose = verbose , log = log ,
355
+ stopThr = stopThr , verbose = verbose ,
356
+ log = log , warn = warn ,
352
357
** kwargs )
353
358
elif method .lower () == 'sinkhorn_stabilized' :
354
359
return sinkhorn_stabilized (a , b , M , reg , numItermax = numItermax ,
355
- stopThr = stopThr , verbose = verbose , log = log ,
360
+ stopThr = stopThr , verbose = verbose ,
361
+ log = log , warn = warn ,
356
362
** kwargs )
357
363
else :
358
364
raise ValueError ("Unknown method '%s'." % method )
0 commit comments