@@ -1569,8 +1569,26 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
1569
1569
1570
1570
.. [23] Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative Models with Sinkhorn Divergences, Proceedings of the Twenty-First International Conference on Artficial Intelligence and Statistics, (AISTATS) 21, 2018
1571
1571
'''
1572
+ if log :
1573
+ sinkhorn_loss_ab , log_ab = empirical_sinkhorn2 (X_s , X_t , reg , a , b , metric = metric , numIterMax = numIterMax , stopThr = 1e-9 , verbose = verbose , log = log , ** kwargs )
1574
+
1575
+ sinkhorn_loss_a , log_a = empirical_sinkhorn2 (X_s , X_s , reg , a , b , metric = metric , numIterMax = numIterMax , stopThr = 1e-9 , verbose = verbose , log = log , ** kwargs )
1576
+
1577
+ sinkhorn_loss_b , log_b = empirical_sinkhorn2 (X_t , X_t , reg , a , b , metric = metric , numIterMax = numIterMax , stopThr = 1e-9 , verbose = verbose , log = log , ** kwargs )
1572
1578
1573
- sinkhorn_div = (2 * empirical_sinkhorn2 (X_s , X_t , reg , a , b , metric = metric , numIterMax = numIterMax , stopThr = 1e-9 , verbose = verbose , log = log , ** kwargs ) -
1574
- empirical_sinkhorn2 (X_s , X_s , reg , a , b , metric = metric , numIterMax = numIterMax , stopThr = 1e-9 , verbose = verbose , log = log , ** kwargs ) -
1575
- empirical_sinkhorn2 (X_t , X_t , reg , a , b , metric = metric , numIterMax = numIterMax , stopThr = 1e-9 , verbose = verbose , log = log , ** kwargs ))
1576
- return max (0 , sinkhorn_div )
1579
+ sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b )
1580
+
1581
+ log = {}
1582
+ log ['sinkhorn_loss_ab' ] = sinkhorn_loss_ab
1583
+ log ['sinkhorn_loss_a' ] = sinkhorn_loss_a
1584
+ log ['sinkhorn_loss_b' ] = sinkhorn_loss_b
1585
+ log ['log_sinkhorn_ab' ] = log_ab
1586
+ log ['log_sinkhorn_a' ] = log_a
1587
+ log ['log_sinkhorn_b' ] = log_b
1588
+
1589
+ return max (0 , sinkhorn_div ), log
1590
+ else :
1591
+ sinkhorn_div = (empirical_sinkhorn2 (X_s , X_t , reg , a , b , metric = metric , numIterMax = numIterMax , stopThr = 1e-9 , verbose = verbose , log = log , ** kwargs ) -
1592
+ 1 / 2 * empirical_sinkhorn2 (X_s , X_s , reg , a , b , metric = metric , numIterMax = numIterMax , stopThr = 1e-9 , verbose = verbose , log = log , ** kwargs ) -
1593
+ 1 / 2 * empirical_sinkhorn2 (X_t , X_t , reg , a , b , metric = metric , numIterMax = numIterMax , stopThr = 1e-9 , verbose = verbose , log = log , ** kwargs ))
1594
+ return max (0 , sinkhorn_div )
0 commit comments