@@ -3173,8 +3173,7 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
3173
3173
return loss
3174
3174
3175
3175
else :
3176
- M = dist (nx .to_numpy (X_s ), nx .to_numpy (X_t ), metric = metric )
3177
- M = nx .from_numpy (M , type_as = a )
3176
+ M = dist (X_s , X_t , metric = metric )
3178
3177
3179
3178
if log :
3180
3179
sinkhorn_loss , log = sinkhorn2 (a , b , M , reg , numItermax = numIterMax ,
@@ -3287,6 +3286,10 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
3287
3286
International Conference on Artficial Intelligence and Statistics,
3288
3287
(AISTATS) 21, 2018
3289
3288
'''
3289
+ X_s , X_t = list_to_array (X_s , X_t )
3290
+
3291
+ nx = get_backend (X_s , X_t )
3292
+
3290
3293
if log :
3291
3294
sinkhorn_loss_ab , log_ab = empirical_sinkhorn2 (X_s , X_t , reg , a , b , metric = metric ,
3292
3295
numIterMax = numIterMax ,
@@ -3313,7 +3316,7 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
3313
3316
log ['log_sinkhorn_a' ] = log_a
3314
3317
log ['log_sinkhorn_b' ] = log_b
3315
3318
3316
- return max (0 , sinkhorn_div ), log
3319
+ return nx . maximum (0 , sinkhorn_div ), log
3317
3320
3318
3321
else :
3319
3322
sinkhorn_loss_ab = empirical_sinkhorn2 (X_s , X_t , reg , a , b , metric = metric ,
@@ -3332,7 +3335,7 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
3332
3335
warn = warn , ** kwargs )
3333
3336
3334
3337
sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b )
3335
- return max (0 , sinkhorn_div )
3338
+ return nx . maximum (0 , sinkhorn_div )
3336
3339
3337
3340
3338
3341
def screenkhorn (a , b , M , reg , ns_budget = None , nt_budget = None , uniform = False ,
0 commit comments