Skip to content

Commit 14c08ba

Browse files
yikun-baiorflamarycedricvincentcuaz
authored
[MRG] Fix Gradient scaling in Partial GW solver (#602)
* new file: ot/partial_gw.py * remove partial_gw.py to update existing file partial.py * fix pep8 --------- Co-authored-by: Rémi Flamary <[email protected]> Co-authored-by: Cédric Vincent-Cuaz <[email protected]>
1 parent 628a089 commit 14c08ba

File tree

2 files changed

+10
-6
lines changed

2 files changed

+10
-6
lines changed

RELEASES.md

+1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
- Fix same sign error for sr(F)GW conditional gradient solvers (PR #611)
2020
- Split `test/test_gromov.py` into `test/gromov/` (PR #619)
2121
- Fix (F)GW barycenter functions to support computing barycenter on 1 input + deprecate structures as lists (PR #628)
22+
- Fix line-search in partial GW and change default init to the interior of partial transport plans (PR #602)
2223

2324
## 0.9.3
2425
*January 2024*

ot/partial.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,15 @@
44
"""
55

66
# Author: Laetitia Chapel <[email protected]>
7-
# License: MIT License
7+
# Yikun Bai < [email protected] >
8+
# Cédric Vincent-Cuaz <[email protected]>
89

9-
import numpy as np
10-
from .lp import emd
11-
from .backend import get_backend
1210
from .utils import list_to_array
11+
from .backend import get_backend
12+
from .lp import emd
13+
import numpy as np
14+
15+
# License: MIT License
1316

1417

1518
def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False,
@@ -581,7 +584,7 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None,
581584
" equal than min(|a|_1, |b|_1).")
582585

583586
if G0 is None:
584-
G0 = np.outer(p, q)
587+
G0 = np.outer(p, q) * m / (np.sum(p) * np.sum(q)) # make sure |G0|=m, G01_m\leq p, G0.T1_n\leq q.
585588

586589
dim_G_extended = (len(p) + nb_dummies, len(q) + nb_dummies)
587590
q_extended = np.append(q, [(np.sum(p) - m) / nb_dummies] * nb_dummies)
@@ -597,7 +600,7 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None,
597600

598601
Gprev = np.copy(G0)
599602

600-
M = gwgrad_partial(C1, C2, G0)
603+
M = 0.5 * gwgrad_partial(C1, C2, G0) # rescaling the gradient with 0.5 for line-search while not changing Gc
601604
M_emd = np.zeros(dim_G_extended)
602605
M_emd[:len(p), :len(q)] = M
603606
M_emd[-nb_dummies:, -nb_dummies:] = np.max(M) * 1e2

0 commit comments

Comments
 (0)