Skip to content

Commit ea841c7

Browse files
[MRG] Restructure and Augment Partial Gromov-Wasserstein solvers (#663)
* init commit srgw bary * init commit - restructure partial GW * remove overlap with adjacent srGW PR * update tests * update old exemple * up * correct line-search + augment generic cg + complete tests * fix issues from change in gcg * trying to fix bugs with kwargs and args * update optim file * fix pep8 * updates * completing tests partial gw * complete partial tests * fix pep8 * up * fix prints in docs * up * fix precision tests * fixing tests * up * update tests * update tests * up * fasten partial cg * fasten partial cg * put back old partial gw functions * fix partial * improve doc of optim.py * merge with master * Update RELEASES.md * improving doc for optim.py and adding breaking change to release.md * improving doc for optim.py and adding breaking change to release.md * tipos in doc --------- Co-authored-by: Rémi Flamary <[email protected]>
1 parent 1a6c790 commit ea841c7

14 files changed

+1704
-288
lines changed

Diff for: RELEASES.md

+5
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22

33
## 0.9.5dev
44

5+
#### Breaking change
6+
- Custom functions provided as parameter `line_search` to `ot.optim.generic_conditional_gradient` must now have the signature `line_search(cost, G, deltaG, Mi, cost_G, df_G, **kwargs)`, adding as input `df_G` the gradient of the regularizer evaluated at the transport plan `G`. This change aims at improving speed of solvers having quadratic polynomial functions as regularizer such as the Gromov-Wassertein loss (PR #663).
7+
58
#### New features
69
- Added feature `mass=True` for `nx.kl_div` (PR #654)
710
- Implemented Gaussian Mixture Model OT `ot.gmm` (PR #649)
@@ -13,6 +16,8 @@
1316
- Restructured `ot.unbalanced` module (PR #658)
1417
- Added `ot.unbalanced.lbfgsb_unbalanced2` and add flexible reference measure `c` in all unbalanced solvers (PR #658)
1518
- Implemented Fused unbalanced Gromov-Wasserstein and unbalanced Co-Optimal Transport (PR #677)
19+
- Notes before depreciating partial Gromov-Wasserstein function in `ot.partial` moved to ot.gromov (PR #663)
20+
- Create `ot.gromov._partial` add new features `loss_fun = "kl_loss"` and `symmetry=False` to all solvers while increasing speed + updating adequatly `ot.solvers` (PR #663)
1621
- Added `ot.unbalanced.sinkhorn_unbalanced_translation_invariant` (PR #676)
1722

1823
#### Closed issues

Diff for: examples/unbalanced-partial/plot_partial_wass_and_gromov.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -125,10 +125,10 @@
125125
# transport 100% of the mass
126126
print('------m = 1')
127127
m = 1
128-
res0, log0 = ot.partial.partial_gromov_wasserstein(C1, C2, p, q, m=m, log=True)
129-
res, log = ot.partial.entropic_partial_gromov_wasserstein(C1, C2, p, q, 10,
130-
m=m, log=True,
131-
verbose=True)
128+
res0, log0 = ot.gromov.partial_gromov_wasserstein(C1, C2, p, q, m=m, log=True)
129+
res, log = ot.gromov.entropic_partial_gromov_wasserstein(C1, C2, p, q, 10,
130+
m=m, log=True,
131+
verbose=True)
132132

133133
print('Wasserstein distance (m = 1): ' + str(log0['partial_gw_dist']))
134134
print('Entropic Wasserstein distance (m = 1): ' + str(log['partial_gw_dist']))
@@ -146,11 +146,11 @@
146146
# transport 2/3 of the mass
147147
print('------m = 2/3')
148148
m = 2 / 3
149-
res0, log0 = ot.partial.partial_gromov_wasserstein(C1, C2, p, q, m=m, log=True,
150-
verbose=True)
151-
res, log = ot.partial.entropic_partial_gromov_wasserstein(C1, C2, p, q, 10,
152-
m=m, log=True,
153-
verbose=True)
149+
res0, log0 = ot.gromov.partial_gromov_wasserstein(C1, C2, p, q, m=m, log=True,
150+
verbose=True)
151+
res, log = ot.gromov.entropic_partial_gromov_wasserstein(C1, C2, p, q, 10,
152+
m=m, log=True,
153+
verbose=True)
154154

155155
print('Partial Wasserstein distance (m = 2/3): ' +
156156
str(log0['partial_gw_dist']))

Diff for: ot/gromov/__init__.py

+40-18
Original file line numberDiff line numberDiff line change
@@ -72,33 +72,55 @@
7272
unbalanced_co_optimal_transport2,
7373
fused_unbalanced_across_spaces_divergence)
7474

75+
from ._partial import (partial_gromov_wasserstein,
76+
partial_gromov_wasserstein2,
77+
solve_partial_gromov_linesearch,
78+
entropic_partial_gromov_wasserstein,
79+
entropic_partial_gromov_wasserstein2)
80+
81+
7582
__all__ = ['init_matrix', 'tensor_product', 'gwloss', 'gwggrad',
7683
'init_matrix_semirelaxed', 'semirelaxed_init_plan',
7784
'update_barycenter_structure', 'update_barycenter_feature',
7885
'div_between_product', 'div_to_product', 'fused_unbalanced_across_spaces_cost',
7986
'uot_cost_matrix', 'uot_parameters_and_measures',
80-
'gromov_wasserstein', 'gromov_wasserstein2', 'fused_gromov_wasserstein',
81-
'fused_gromov_wasserstein2', 'solve_gromov_linesearch', 'gromov_barycenters',
82-
'fgw_barycenters', 'entropic_gromov_wasserstein', 'entropic_gromov_wasserstein2',
83-
'BAPG_gromov_wasserstein', 'BAPG_gromov_wasserstein2',
84-
'entropic_gromov_barycenters', 'entropic_fused_gromov_wasserstein',
87+
'gromov_wasserstein', 'gromov_wasserstein2',
88+
'fused_gromov_wasserstein', 'fused_gromov_wasserstein2',
89+
'solve_gromov_linesearch', 'gromov_barycenters',
90+
'fgw_barycenters', 'entropic_gromov_wasserstein',
91+
'entropic_gromov_wasserstein2', 'BAPG_gromov_wasserstein',
92+
'BAPG_gromov_wasserstein2', 'entropic_gromov_barycenters',
93+
'entropic_fused_gromov_wasserstein',
8594
'entropic_fused_gromov_wasserstein2', 'BAPG_fused_gromov_wasserstein',
8695
'BAPG_fused_gromov_wasserstein2', 'entropic_fused_gromov_barycenters',
87-
'GW_distance_estimation', 'pointwise_gromov_wasserstein', 'sampled_gromov_wasserstein',
96+
'GW_distance_estimation', 'pointwise_gromov_wasserstein',
97+
'sampled_gromov_wasserstein',
8898
'semirelaxed_gromov_wasserstein', 'semirelaxed_gromov_wasserstein2',
89-
'semirelaxed_fused_gromov_wasserstein', 'semirelaxed_fused_gromov_wasserstein2',
90-
'solve_semirelaxed_gromov_linesearch', 'entropic_semirelaxed_gromov_wasserstein',
91-
'entropic_semirelaxed_gromov_wasserstein2', 'entropic_semirelaxed_fused_gromov_wasserstein',
99+
'semirelaxed_fused_gromov_wasserstein',
100+
'semirelaxed_fused_gromov_wasserstein2',
101+
'solve_semirelaxed_gromov_linesearch',
102+
'entropic_semirelaxed_gromov_wasserstein',
103+
'entropic_semirelaxed_gromov_wasserstein2',
104+
'entropic_semirelaxed_fused_gromov_wasserstein',
92105
'entropic_semirelaxed_fused_gromov_wasserstein2',
93106
'semirelaxed_fgw_barycenters', 'semirelaxed_gromov_barycenters',
94107
'gromov_wasserstein_dictionary_learning',
95-
'gromov_wasserstein_linear_unmixing', 'fused_gromov_wasserstein_dictionary_learning',
96-
'fused_gromov_wasserstein_linear_unmixing', 'lowrank_gromov_wasserstein_samples',
97-
'quantized_fused_gromov_wasserstein_partitioned', 'get_graph_partition',
98-
'get_graph_representants', 'format_partitioned_graph',
99-
'quantized_fused_gromov_wasserstein', 'get_partition_and_representants_samples',
100-
'format_partitioned_samples', 'quantized_fused_gromov_wasserstein_samples',
101-
'fused_unbalanced_gromov_wasserstein', 'fused_unbalanced_gromov_wasserstein2',
102-
'unbalanced_co_optimal_transport', 'unbalanced_co_optimal_transport2',
103-
'fused_unbalanced_across_spaces_divergence'
108+
'gromov_wasserstein_linear_unmixing',
109+
'fused_gromov_wasserstein_dictionary_learning',
110+
'fused_gromov_wasserstein_linear_unmixing',
111+
'lowrank_gromov_wasserstein_samples',
112+
'quantized_fused_gromov_wasserstein_partitioned',
113+
'get_graph_partition', 'get_graph_representants',
114+
'format_partitioned_graph', 'quantized_fused_gromov_wasserstein',
115+
'get_partition_and_representants_samples', 'format_partitioned_samples',
116+
'quantized_fused_gromov_wasserstein_samples',
117+
'fused_unbalanced_gromov_wasserstein',
118+
'fused_unbalanced_gromov_wasserstein2',
119+
'unbalanced_co_optimal_transport',
120+
'unbalanced_co_optimal_transport2',
121+
'fused_unbalanced_across_spaces_divergence',
122+
'partial_gromov_wasserstein', 'partial_gromov_wasserstein2',
123+
'solve_partial_gromov_linesearch',
124+
'entropic_partial_gromov_wasserstein',
125+
'entropic_partial_gromov_wasserstein2'
104126
]

Diff for: ot/gromov/_gw.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -167,10 +167,10 @@ def df(G):
167167
return 0.5 * (gwggrad(constC, hC1, hC2, G, np_) + gwggrad(constCt, hC1t, hC2t, G, np_))
168168

169169
if armijo:
170-
def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
170+
def line_search(cost, G, deltaG, Mi, cost_G, df_G, **kwargs):
171171
return line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=np_, **kwargs)
172172
else:
173-
def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
173+
def line_search(cost, G, deltaG, Mi, cost_G, df_G, **kwargs):
174174
return solve_gromov_linesearch(G, deltaG, cost_G, hC1, hC2, M=0., reg=1., nx=np_, symmetric=symmetric, **kwargs)
175175

176176
if not nx.is_floating_point(C10):
@@ -475,11 +475,12 @@ def df(G):
475475
return 0.5 * (gwggrad(constC, hC1, hC2, G, np_) + gwggrad(constCt, hC1t, hC2t, G, np_))
476476

477477
if armijo:
478-
def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
478+
def line_search(cost, G, deltaG, Mi, cost_G, df_G, **kwargs):
479479
return line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=np_, **kwargs)
480480
else:
481-
def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
481+
def line_search(cost, G, deltaG, Mi, cost_G, df_G, **kwargs):
482482
return solve_gromov_linesearch(G, deltaG, cost_G, hC1, hC2, M=(1 - alpha) * M, reg=alpha, nx=np_, symmetric=symmetric, **kwargs)
483+
483484
if not nx.is_floating_point(M0):
484485
warnings.warn(
485486
"Input feature matrix consists of integer. The transport plan will be "
@@ -625,9 +626,11 @@ def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss',
625626
if loss_fun == 'square_loss':
626627
gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T))
627628
gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T))
629+
628630
elif loss_fun == 'kl_loss':
629631
gC1 = nx.log(C1 + 1e-15) * nx.outer(p, p) - nx.dot(T, nx.dot(nx.log(C2 + 1e-15), T.T))
630632
gC2 = - nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q)
633+
631634
if isinstance(alpha, int) or isinstance(alpha, float):
632635
fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M),
633636
(log_fgw['u'] - nx.mean(log_fgw['u']),

0 commit comments

Comments
 (0)