Skip to content

Commit 0d23718

Browse files
author
Hicham Janati
committed
remove square in convergence check
1 parent 952503e commit 0d23718

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

ot/unbalanced.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -371,8 +371,9 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
371371
if cpt % 10 == 0:
372372
# we can speed up the process by checking for the error only all
373373
# the 10th iterations
374-
err = np.sum((u - uprev)**2) / np.sum((u)**2) + \
375-
np.sum((v - vprev)**2) / np.sum((v)**2)
374+
err_u = abs(u - uprev).max() / max(abs(u), abs(uprev), 1.)
375+
err_v = abs(v - vprev).max() / max(abs(v), abs(vprev), 1.)
376+
err = 0.5 * (err_u + err_v)
376377
if log:
377378
log['err'].append(err)
378379
if verbose:
@@ -498,8 +499,9 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
498499
if cpt % 10 == 0:
499500
# we can speed up the process by checking for the error only all
500501
# the 10th iterations
501-
err = np.sum((u - uprev) ** 2) / np.sum((u) ** 2) + \
502-
np.sum((v - vprev) ** 2) / np.sum((v) ** 2)
502+
err_u = abs(u - uprev).max() / max(abs(u), abs(uprev), 1.)
503+
err_v = abs(v - vprev).max() / max(abs(v), abs(vprev), 1.)
504+
err = 0.5 * (err_u + err_v)
503505
if log:
504506
log['err'].append(err)
505507
if verbose:

0 commit comments

Comments
 (0)