# Nicolas Courty <[email protected]>
# Rémi Flamary <[email protected]>
# Titouan Vayer <[email protected]>
+#
# License: MIT License

import numpy as np
@@ -351,9 +352,9 @@ def df(G):
    return cg(p, q, 0, 1, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)


-def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, **kwargs):
+def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
"""
356
- Computes the FGW distance between two graphs see [3 ]
357
+ Computes the FGW transport between two graphs see [24 ]
357
358
.. math::
358
359
\gamma = arg\min_\gamma (1-\a lpha)*<\gamma,M>_F + alpha* \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
359
360
s.t. \gamma 1 = p
@@ -377,7 +378,7 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
        distribution in the source space
    q : ndarray, shape (nt,)
        distribution in the target space
-    loss_fun : string,optionnal
+    loss_fun : string,optional
        loss function used for the solver
    max_iter : int, optional
        Max number of iterations
@@ -416,7 +417,86 @@ def f(G):
    def df(G):
        return gwggrad(constC, hC1, hC2, G)

-    return cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
+    if log:
+        res, log = cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
+        log['fgw_dist'] = log['loss'][::-1][0]
+        return res, log
+    else:
+        return cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
+
+
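Not part of the diff: a minimal usage sketch of the `log` flag added to `fused_gromov_wasserstein` above, assuming POT is installed and imported as `ot`; the toy data below (sizes, random features) is purely illustrative::

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    xs, xt = rng.randn(5, 2), rng.randn(4, 2)   # 5 source / 4 target points (toy sizes)
    M = ot.dist(xs, xt)                         # feature cost matrix, shape (ns, nt)
    C1, C2 = ot.dist(xs, xs), ot.dist(xt, xt)   # structure matrices, (ns, ns) and (nt, nt)
    p, q = ot.unif(5), ot.unif(4)               # uniform source/target weights

    # log=True also returns the solver log, with the final FGW cost under 'fgw_dist'
    T, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q,
                                                loss_fun='square_loss',
                                                alpha=0.5, log=True)
    print(T.shape, log['fgw_dist'])

Per the branch above, `log['fgw_dist']` is simply the last entry of the conditional-gradient loss history (`log['loss'][::-1][0]`).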
+def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
+    """
+    Computes the FGW distance between two graphs see [24]
+
+    .. math::
+        \gamma = arg\min_\gamma (1-\alpha)*<\gamma,M>_F + \alpha*\sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
+
+        s.t. \gamma 1 = p
+             \gamma^T 1 = q
+             \gamma \geq 0
+
+    where :
+    - M is the (ns,nt) metric cost matrix
+    - :math:`f` is the regularization term (and df is its gradient)
+    - a and b are source and target weights (sum to 1)
+    - L is a loss function to account for the misfit between the similarity matrices
+
+    The algorithm used for solving the problem is conditional gradient as discussed in [1]_
+
+    Parameters
+    ----------
+    M : ndarray, shape (ns, nt)
+        Metric cost matrix between features across domains
+    C1 : ndarray, shape (ns, ns)
+        Metric cost matrix representative of the structure in the source space
+    C2 : ndarray, shape (nt, nt)
+        Metric cost matrix representative of the structure in the target space
+    p : ndarray, shape (ns,)
+        distribution in the source space
+    q : ndarray, shape (nt,)
+        distribution in the target space
+    loss_fun : string, optional
+        loss function used for the solver
+    max_iter : int, optional
+        Max number of iterations
+    tol : float, optional
+        Stop threshold on error (>0)
+    verbose : bool, optional
+        Print information along iterations
+    log : bool, optional
+        record log if True
+    armijo : bool, optional
+        If True the step of the line-search is found via an Armijo line-search. Else closed form is used.
+        If there are convergence issues use False.
+    **kwargs : dict
+        parameters can be directly passed to the ot.optim.cg solver
+
+    Returns
+    -------
+    fgw_dist : float
+        FGW distance for the given parameters
+    log : dict
+        log dictionary returned only if log==True in parameters
+
+    References
+    ----------
+    .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
+        and Courty Nicolas
+        "Optimal Transport for structured data with application on graphs"
+        International Conference on Machine Learning (ICML). 2019.
+    """
+
+    constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
+
+    G0 = p[:, None] * q[None, :]
+
+    def f(G):
+        return gwloss(constC, hC1, hC2, G)
+
+    def df(G):
+        return gwggrad(constC, hC1, hC2, G)
+
+    res, log = cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
+    if log:
+        log['fgw_dist'] = log['loss'][::-1][0]
+        log['T'] = res
+        return log['fgw_dist'], log
+    else:
+        return log['fgw_dist']


def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs):
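Likewise outside the diff, a sketch contrasting the new `fused_gromov_wasserstein2`, which returns the scalar FGW value instead of the coupling (the coupling is stored under 'T' in the log when log=True); same illustrative toy-data assumptions as in the previous sketch::

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    xs, xt = rng.randn(5, 2), rng.randn(4, 2)
    M, C1, C2 = ot.dist(xs, xt), ot.dist(xs, xs), ot.dist(xt, xt)
    p, q = ot.unif(5), ot.unif(4)

    # Scalar FGW value only (default log=False)
    fgw_dist = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q,
                                                   loss_fun='square_loss', alpha=0.5)

    # With log=True the scalar comes with the log dict, which also holds
    # the optimal coupling under the key 'T'
    fgw_dist, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q,
                                                        loss_fun='square_loss',
                                                        alpha=0.5, log=True)
    print(fgw_dist, log['T'].shape)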
@@ -889,7 +969,7 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,


def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False,
                    p=None, loss_fun='square_loss', max_iter=100, tol=1e-9,
-                   verbose=False, log=True, init_C=None, init_X=None):
+                   verbose=False, log=False, init_C=None, init_X=None):
    """
    Compute the FGW barycenter as presented in eq. (5) of [24].
    ----------
@@ -919,7 +999,8 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
        Barycenters' features
    C : ndarray, shape (N,N)
        Barycenters' structure matrix
-    log_:
+    log_ : dictionary
+        Only returned when log=True
        T : list of (N,ns) transport matrices
        Ms : all distance matrices between the features of the barycenter and the other features dist(X,Ys) shape (N,ns)
    References
@@ -1015,14 +1096,13 @@ class UndefinedParameter(Exception):
        T = [fused_gromov_wasserstein((1 - alpha) * Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha, numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)]

        # T is N,ns
-
-        log_['Ts_iter'].append(T)
        err_feature = np.linalg.norm(X - Xprev.reshape(N, d))
        err_structure = np.linalg.norm(C - Cprev)

        if log:
            log_['err_feature'].append(err_feature)
            log_['err_structure'].append(err_structure)
+            log_['Ts_iter'].append(T)

        if verbose:
            if cpt % 200 == 0:
@@ -1032,11 +1112,15 @@ class UndefinedParameter(Exception):
            print('{:5d}|{:8e}|'.format(cpt, err_feature))

        cpt += 1
-    log_['T'] = T  # from target to Ys
-    log_['p'] = p
-    log_['Ms'] = Ms  # Ms are N,ns
+    if log:
+        log_['T'] = T  # from target to Ys
+        log_['p'] = p
+        log_['Ms'] = Ms  # Ms are N,ns

-    return X, C, log_
+    if log:
+        return X, C, log_
+    else:
+        return X, C


def update_sructure_matrix(p, lambdas, T, Cs):
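Finally, also outside the diff, a sketch of the changed return convention of `fgw_barycenters`: with the new default `log=False` only the barycenter features and structure are returned, while `log=True` additionally returns the log dictionary ('T', 'p', 'Ms', the error histories and 'Ts_iter'). The two toy input graphs below are illustrative only::

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    Ys = [rng.randn(6, 2), rng.randn(5, 2)]   # node features of two toy graphs
    Cs = [ot.dist(Y, Y) for Y in Ys]          # their structure matrices
    ps = [ot.unif(6), ot.unif(5)]             # node weights of each input graph
    lambdas = [0.5, 0.5]                      # barycenter weights
    N = 4                                     # number of barycenter nodes

    # New default log=False: only features X and structure C are returned
    X, C = ot.gromov.fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha=0.5,
                                     p=ot.unif(N))

    # log=True additionally returns the log dictionary
    X, C, log_ = ot.gromov.fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha=0.5,
                                           p=ot.unif(N), log=True)
    print(X.shape, C.shape, len(log_['T']))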