@@ -243,11 +243,11 @@ def test_empirical_sinkhorn_divergence():
243
243
emp_sinkhorn_div = ot .bregman .empirical_sinkhorn_divergence (X_s , X_t , 1 )
244
244
sinkhorn_div = (ot .sinkhorn2 (a , b , M , 1 ) - 1 / 2 * ot .sinkhorn2 (a , a , M_s , 1 ) - 1 / 2 * ot .sinkhorn2 (b , b , M_t , 1 ))
245
245
246
- emp_sinkhorn_div_log , log_es = ot .bregman .empirical_sinkhorn_divergence (X_s , X_t , 0. 1 , log = True )
247
- sink_div_log , log_s = ot .sinkhorn2 (a , b , M , 1 )
248
- sink_div_log_a , log_s_a = ot .sinkhorn2 (a , a , M_s , 1 )
249
- sink_div_log_b , log_s_b = ot .sinkhorn2 (b , b , M_t , 1 )
250
- sink_div_log = sink_div_log - 1 / 2 * (sink_div_log_a + sink_div_log_b )
246
+ emp_sinkhorn_div_log , log_es = ot .bregman .empirical_sinkhorn_divergence (X_s , X_t , 1 , log = True )
247
+ sink_div_log_ab , log_s_ab = ot .sinkhorn2 (a , b , M , 1 , log = True )
248
+ sink_div_log_a , log_s_a = ot .sinkhorn2 (a , a , M_s , 1 , log = True )
249
+ sink_div_log_b , log_s_b = ot .sinkhorn2 (b , b , M_t , 1 , log = True )
250
+ sink_div_log = sink_div_log_ab - 1 / 2 * (sink_div_log_a + sink_div_log_b )
251
251
252
252
# check constratints
253
253
np .testing .assert_allclose (
0 commit comments