
Commit 486b0d6

[MRG] Center gradients for mass of emd2 and gw2 (#363)

* center gradients for mass of emd2 and gw2
* debug fgw gradient
* debug fgw

1 parent ac4cf44 commit 486b0d6
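
Why center? The optimal dual potentials (u, v) of the underlying linear program are the gradients of the OT value with respect to the masses, but they are only defined up to an additive constant: replacing (u, v) by (u + c, v − c) leaves the dual objective unchanged whenever the two mass vectors have equal total mass. Subtracting the mean picks the zero-sum representative, so gradient steps on the masses stay in the tangent space of the simplex. A minimal NumPy sketch of this invariance (illustrative data and variable names, not from the commit):

```python
# Sketch: the optimal dual pair (u, v) of ot.emd2 is only unique up to a
# constant shift, which is why this commit centers the mass gradients.
import numpy as np
import ot

n = 5
rng = np.random.RandomState(0)
a, b = ot.unif(n), ot.unif(n)                  # uniform masses on the simplex
M = ot.dist(rng.randn(n, 2), rng.randn(n, 2))  # squared Euclidean cost

cost, log = ot.emd2(a, b, M, log=True)
u, v = log['u'], log['v']

# Strong duality: the primal cost equals the dual objective.
assert np.isclose(cost, u @ a + v @ b)

# Shifting (u, v) -> (u + c, v - c) leaves the dual objective unchanged,
# so only the centered potentials give a well-defined (zero-sum) gradient.
c = 3.0
assert np.isclose(u @ a + v @ b, (u + c) @ a + (v - c) @ b)
```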

File tree

4 files changed: +19 −7 lines changed


RELEASES.md (+3 −1)

@@ -5,7 +5,7 @@
 
 #### New features
 
-- remode deprecated `ot.gpu` submodule (PR #361)
+- Remove deprecated `ot.gpu` submodule (PR #361)
 - Update examples in the gallery (PR #359).
 - Add stochastic loss and OT plan computation for regularized OT and
   backend examples(PR #360).
@@ -23,6 +23,8 @@
 
 #### Closed issues
 
+- Fix mass gradient of `ot.emd2` and `ot.gromov_wasserstein2` so that they are
+  centered (Issue #364, PR #363)
 - Fix bug in instantiating an `autograd` function `ValFunction` (Issue #337,
   PR #338)
 - Fix POT ABI compatibility with old and new numpy (Issue #346, PR #349)

ot/gromov.py (+5 −2)

@@ -551,7 +551,8 @@ def df(G):
         gC1 = nx.from_numpy(gC1, type_as=C10)
         gC2 = nx.from_numpy(gC2, type_as=C10)
         gw = nx.set_gradients(gw, (p0, q0, C10, C20),
-                              (log_gw['u'], log_gw['v'], gC1, gC2))
+                              (log_gw['u'] - nx.mean(log_gw['u']),
+                               log_gw['v'] - nx.mean(log_gw['v']), gC1, gC2))
 
     if log:
         return gw, log_gw
@@ -793,7 +794,9 @@ def df(G):
         gC1 = nx.from_numpy(gC1, type_as=C10)
         gC2 = nx.from_numpy(gC2, type_as=C10)
         fgw_dist = nx.set_gradients(fgw_dist, (p0, q0, C10, C20, M0),
-                                    (log_fgw['u'], log_fgw['v'], alpha * gC1, alpha * gC2, (1 - alpha) * T0))
+                                    (log_fgw['u'] - nx.mean(log_fgw['u']),
+                                     log_fgw['v'] - nx.mean(log_fgw['v']),
+                                     alpha * gC1, alpha * gC2, (1 - alpha) * T0))
 
     if log:
         return fgw_dist, log_fgw
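
For context, a minimal sketch of what this change means for users (not part of the commit; the data and names are illustrative, and the torch backend is assumed): the gradients of `ot.gromov_wasserstein2` with respect to the masses now sum to zero, so a gradient step no longer drifts off the simplex.

```python
# A minimal sketch, assuming the torch backend: checks that the mass
# gradients of ot.gromov_wasserstein2 are centered (sum to zero).
import numpy as np
import torch
import ot

n = 10
rng = np.random.RandomState(0)
xs, xt = rng.randn(n, 2), rng.randn(n, 2)
C1 = torch.tensor(ot.dist(xs, xs))  # intra-domain structure matrices
C2 = torch.tensor(ot.dist(xt, xt))
p = torch.tensor(ot.unif(n), requires_grad=True)
q = torch.tensor(ot.unif(n), requires_grad=True)

gw = ot.gromov_wasserstein2(C1, C2, p, q)
gw.backward()

# With centered dual potentials, each mass gradient sums to ~0.
assert abs(p.grad.sum().item()) < 1e-8
assert abs(q.grad.sum().item()) < 1e-8
```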

ot/lp/__init__.py (+4 −3)

@@ -517,7 +517,8 @@ def f(b):
            log['warning'] = result_code_string
            log['result_code'] = result_code
            cost = nx.set_gradients(nx.from_numpy(cost, type_as=type_as),
-                                    (a0, b0, M0), (log['u'], log['v'], G))
+                                    (a0, b0, M0), (log['u'] - nx.mean(log['u']),
+                                                   log['v'] - nx.mean(log['v']), G))
            return [cost, log]
    else:
        def f(b):
@@ -540,8 +541,8 @@ def f(b):
            )
            G = nx.from_numpy(G, type_as=type_as)
            cost = nx.set_gradients(nx.from_numpy(cost, type_as=type_as),
-                                    (a0, b0, M0), (nx.from_numpy(u, type_as=type_as),
-                                                   nx.from_numpy(v, type_as=type_as), G))
+                                    (a0, b0, M0), (nx.from_numpy(u - np.mean(u), type_as=type_as),
+                                                   nx.from_numpy(v - np.mean(v), type_as=type_as), G))
 
            check_result(result_code)
            return cost
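
Both code paths above (with and without log) center the potentials before attaching them as gradients: `nx.mean` operates on backend tensors in the first path, while `np.mean` centers the raw NumPy potentials before conversion in the second. A quick sketch (illustrative data, torch backend assumed) of the user-visible effect:

```python
# Sketch: after this change, the gradients of ot.emd2 with respect to the
# masses sum to zero under the torch backend.
import numpy as np
import torch
import ot

n = 8
rng = np.random.RandomState(0)
M = torch.tensor(ot.dist(rng.randn(n, 2), rng.randn(n, 2)))
a = torch.tensor(ot.unif(n), requires_grad=True)
b = torch.tensor(ot.unif(n), requires_grad=True)

ot.emd2(a, b, M).backward()  # exercises the no-log code path

assert abs(a.grad.sum().item()) < 1e-8
assert abs(b.grad.sum().item()) < 1e-8
```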

test/test_ot.py (+7 −1)

@@ -147,14 +147,20 @@ def test_emd2_gradients():
         b1 = torch.tensor(a, requires_grad=True)
         M1 = torch.tensor(M, requires_grad=True)
 
-        val = ot.emd2(a1, b1, M1)
+        val, log = ot.emd2(a1, b1, M1, log=True)
 
         val.backward()
 
         assert a1.shape == a1.grad.shape
         assert b1.shape == b1.grad.shape
         assert M1.shape == M1.grad.shape
 
+        assert np.allclose(a1.grad.cpu().detach().numpy(),
+                           log['u'].cpu().detach().numpy() - log['u'].cpu().detach().numpy().mean())
+
+        assert np.allclose(b1.grad.cpu().detach().numpy(),
+                           log['v'].cpu().detach().numpy() - log['v'].cpu().detach().numpy().mean())
+
         # Testing for bug #309, checking for scaling of gradient
         a2 = torch.tensor(a, requires_grad=True)
         b2 = torch.tensor(a, requires_grad=True)
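
The commit message also mentions debugging the fgw gradient; a matching check for `ot.gromov.fused_gromov_wasserstein2` (a sketch under the same torch-backend assumption, not taken from the test suite) would look like this:

```python
# Sketch, assuming the torch backend: the fused GW mass gradients are
# centered too, after the fgw fix in this commit.
import numpy as np
import torch
import ot

n = 10
rng = np.random.RandomState(42)
xs, xt = rng.randn(n, 2), rng.randn(n, 2)
C1 = torch.tensor(ot.dist(xs, xs))
C2 = torch.tensor(ot.dist(xt, xt))
M = torch.tensor(ot.dist(xs, xt), requires_grad=True)  # feature cost
p = torch.tensor(ot.unif(n), requires_grad=True)
q = torch.tensor(ot.unif(n), requires_grad=True)

fgw = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, alpha=0.5)
fgw.backward()

assert abs(p.grad.sum().item()) < 1e-8
assert abs(q.grad.sum().item()) < 1e-8
assert M.grad.shape == M.shape  # feature-cost gradient is (1 - alpha) * T
```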
