4
4
"""
5
5
6
6
# Author: Laetitia Chapel <[email protected] >
7
- # License: MIT License
7
+
8
+ # Cédric Vincent-Cuaz <[email protected] >
8
9
9
- import numpy as np
10
- from .lp import emd
11
- from .backend import get_backend
12
10
from .utils import list_to_array
11
+ from .backend import get_backend
12
+ from .lp import emd
13
+ import numpy as np
14
+
15
+ # License: MIT License
13
16
14
17
15
18
def partial_wasserstein_lagrange (a , b , M , reg_m = None , nb_dummies = 1 , log = False ,
@@ -581,7 +584,7 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None,
581
584
" equal than min(|a|_1, |b|_1)." )
582
585
583
586
if G0 is None :
584
- G0 = np .outer (p , q )
587
+ G0 = np .outer (p , q ) * m / ( np . sum ( p ) * np . sum ( q )) # make sure |G0|=m, G01_m\leq p, G0.T1_n\leq q.
585
588
586
589
dim_G_extended = (len (p ) + nb_dummies , len (q ) + nb_dummies )
587
590
q_extended = np .append (q , [(np .sum (p ) - m ) / nb_dummies ] * nb_dummies )
@@ -597,7 +600,7 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None,
597
600
598
601
Gprev = np .copy (G0 )
599
602
600
- M = gwgrad_partial (C1 , C2 , G0 )
603
+ M = 0.5 * gwgrad_partial (C1 , C2 , G0 ) # rescaling the gradient with 0.5 for line-search while not changing Gc
601
604
M_emd = np .zeros (dim_G_extended )
602
605
M_emd [:len (p ), :len (q )] = M
603
606
M_emd [- nb_dummies :, - nb_dummies :] = np .max (M ) * 1e2
0 commit comments