@@ -1587,8 +1587,13 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
1587
1587
log ['log_sinkhorn_b' ] = log_b
1588
1588
1589
1589
return max (0 , sinkhorn_div ), log
1590
+
1590
1591
else :
1591
- sinkhorn_div = (empirical_sinkhorn2 (X_s , X_t , reg , a , b , metric = metric , numIterMax = numIterMax , stopThr = 1e-9 , verbose = verbose , log = log , ** kwargs ) -
1592
- 1 / 2 * empirical_sinkhorn2 (X_s , X_s , reg , a , b , metric = metric , numIterMax = numIterMax , stopThr = 1e-9 , verbose = verbose , log = log , ** kwargs ) -
1593
- 1 / 2 * empirical_sinkhorn2 (X_t , X_t , reg , a , b , metric = metric , numIterMax = numIterMax , stopThr = 1e-9 , verbose = verbose , log = log , ** kwargs ))
1592
+ sinkhorn_loss_ab = empirical_sinkhorn2 (X_s , X_t , reg , a , b , metric = metric , numIterMax = numIterMax , stopThr = 1e-9 , verbose = verbose , log = log , ** kwargs )
1593
+
1594
+ sinkhorn_loss_a = empirical_sinkhorn2 (X_s , X_s , reg , a , b , metric = metric , numIterMax = numIterMax , stopThr = 1e-9 , verbose = verbose , log = log , ** kwargs )
1595
+
1596
+ sinkhorn_loss_b = empirical_sinkhorn2 (X_t , X_t , reg , a , b , metric = metric , numIterMax = numIterMax , stopThr = 1e-9 , verbose = verbose , log = log , ** kwargs )
1597
+
1598
+ sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b )
1594
1599
return max (0 , sinkhorn_div )
0 commit comments