@@ -1281,7 +1281,7 @@ def get_reg(n): # exponential decreasing
1281
1281
regi = get_reg (ii )
1282
1282
1283
1283
G , logi = sinkhorn_stabilized (a , b , M , regi ,
1284
- numItermax = numInnerItermax , stopThr = 1e-9 ,
1284
+ numItermax = numInnerItermax , stopThr = stopThr ,
1285
1285
warmstart = (alpha , beta ), verbose = False ,
1286
1286
print_period = 20 , tau = tau , log = True )
1287
1287
@@ -3306,17 +3306,17 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
3306
3306
if log :
3307
3307
sinkhorn_loss_ab , log_ab = empirical_sinkhorn2 (X_s , X_t , reg , a , b , metric = metric ,
3308
3308
numIterMax = numIterMax ,
3309
- stopThr = 1e-9 , verbose = verbose ,
3309
+ stopThr = stopThr , verbose = verbose ,
3310
3310
log = log , warn = warn , ** kwargs )
3311
3311
3312
3312
sinkhorn_loss_a , log_a = empirical_sinkhorn2 (X_s , X_s , reg , a , a , metric = metric ,
3313
3313
numIterMax = numIterMax ,
3314
- stopThr = 1e-9 , verbose = verbose ,
3314
+ stopThr = stopThr , verbose = verbose ,
3315
3315
log = log , warn = warn , ** kwargs )
3316
3316
3317
3317
sinkhorn_loss_b , log_b = empirical_sinkhorn2 (X_t , X_t , reg , b , b , metric = metric ,
3318
3318
numIterMax = numIterMax ,
3319
- stopThr = 1e-9 , verbose = verbose ,
3319
+ stopThr = stopThr , verbose = verbose ,
3320
3320
log = log , warn = warn , ** kwargs )
3321
3321
3322
3322
sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b )
@@ -3333,17 +3333,17 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
3333
3333
3334
3334
else :
3335
3335
sinkhorn_loss_ab = empirical_sinkhorn2 (X_s , X_t , reg , a , b , metric = metric ,
3336
- numIterMax = numIterMax , stopThr = 1e-9 ,
3336
+ numIterMax = numIterMax , stopThr = stopThr ,
3337
3337
verbose = verbose , log = log ,
3338
3338
warn = warn , ** kwargs )
3339
3339
3340
3340
sinkhorn_loss_a = empirical_sinkhorn2 (X_s , X_s , reg , a , a , metric = metric ,
3341
- numIterMax = numIterMax , stopThr = 1e-9 ,
3341
+ numIterMax = numIterMax , stopThr = stopThr ,
3342
3342
verbose = verbose , log = log ,
3343
3343
warn = warn , ** kwargs )
3344
3344
3345
3345
sinkhorn_loss_b = empirical_sinkhorn2 (X_t , X_t , reg , b , b , metric = metric ,
3346
- numIterMax = numIterMax , stopThr = 1e-9 ,
3346
+ numIterMax = numIterMax , stopThr = stopThr ,
3347
3347
verbose = verbose , log = log ,
3348
3348
warn = warn , ** kwargs )
3349
3349
0 commit comments