1 parent ac4cf44 commit 486b0d6
RELEASES.md
```diff
@@ -5,7 +5,7 @@
 
 #### New features
 
-- remode deprecated `ot.gpu` submodule (PR #361)
+- Remove deprecated `ot.gpu` submodule (PR #361)
 - Update examples in the gallery (PR #359).
 - Add stochastic loss and OT plan computation for regularized OT and
   backend examples(PR #360).
@@ -23,6 +23,8 @@
 
 #### Closed issues
 
+- Fix mass gradient of `ot.emd2` and `ot.gromov_wasserstein2` so that they are
+  centered (Issue #364, PR #363)
 - Fix bug in instantiating an `autograd` function `ValFunction` (Issue #337,
   PR #338)
 - Fix POT ABI compatibility with old and new numpy (Issue #346, PR #349)
```
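
The release note above summarizes the behavioural change. As a quick illustration of what "centered" means in practice, here is a minimal sketch (not part of the commit; it assumes a POT build containing this patch and the PyTorch backend): after `ot.emd2(...).backward()`, the gradients with respect to the weight vectors are the dual potentials shifted to zero mean, so they sum to approximately zero.

```python
# Minimal sketch, not part of the commit; assumes a POT build with this patch
# and the PyTorch backend.  After the fix, the gradient of ot.emd2 w.r.t. each
# weight vector is the corresponding centered dual potential (zero mean).
import numpy as np
import torch
import ot

rng = np.random.RandomState(0)
n = 5
a = ot.unif(n)                  # uniform weights on the simplex
x, y = rng.randn(n, 2), rng.randn(n, 2)
M = ot.dist(x, y)               # squared Euclidean cost matrix

a1 = torch.tensor(a, requires_grad=True)
b1 = torch.tensor(a, requires_grad=True)
M1 = torch.tensor(M, requires_grad=True)

val = ot.emd2(a1, b1, M1)
val.backward()

print(a1.grad.sum().item(), b1.grad.sum().item())  # both ~0 after the fix
```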
ot/gromov.py
```diff
@@ -551,7 +551,8 @@ def df(G):
         gC1 = nx.from_numpy(gC1, type_as=C10)
         gC2 = nx.from_numpy(gC2, type_as=C10)
         gw = nx.set_gradients(gw, (p0, q0, C10, C20),
-                              (log_gw['u'], log_gw['v'], gC1, gC2))
+                              (log_gw['u'] - nx.mean(log_gw['u']),
+                               log_gw['v'] - nx.mean(log_gw['v']), gC1, gC2))
 
     if log:
         return gw, log_gw
@@ -793,7 +794,9 @@ def df(G):
         fgw_dist = nx.set_gradients(fgw_dist, (p0, q0, C10, C20, M0),
-                                    (log_fgw['u'], log_fgw['v'], alpha * gC1, alpha * gC2, (1 - alpha) * T0))
+                                    (log_fgw['u'] - nx.mean(log_fgw['u']),
+                                     log_fgw['v'] - nx.mean(log_fgw['v']),
+                                     alpha * gC1, alpha * gC2, (1 - alpha) * T0))
 
     return fgw_dist, log_fgw
```
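
Both hunks apply the same idea to the (fused) Gromov-Wasserstein losses: the potentials stored in `log_gw['u']`/`log_fgw['u']` and the corresponding `['v']` entries are only defined up to an additive constant, so the patch reports their zero-mean versions as the gradients with respect to the marginals. A hedged sketch of the resulting behaviour (assuming the 0.8.x `ot.gromov.gromov_wasserstein2(C1, C2, p, q, loss_fun)` signature and the PyTorch backend):

```python
# Hedged sketch, not part of the commit; assumes the 0.8.x API
# ot.gromov.gromov_wasserstein2(C1, C2, p, q, loss_fun) and the PyTorch
# backend.  With the patch, the gradients w.r.t. p and q have zero mean.
import numpy as np
import torch
import ot

rng = np.random.RandomState(42)
n = 10
xs, xt = rng.randn(n, 2), rng.randn(n, 2)
C1 = ot.dist(xs, xs)
C2 = ot.dist(xt, xt)
C1 /= C1.max()
C2 /= C2.max()

p = torch.tensor(ot.unif(n), requires_grad=True)
q = torch.tensor(ot.unif(n), requires_grad=True)
C1t = torch.tensor(C1, requires_grad=True)
C2t = torch.tensor(C2, requires_grad=True)

gw = ot.gromov.gromov_wasserstein2(C1t, C2t, p, q, 'square_loss')
gw.backward()

print(p.grad.mean().item(), q.grad.mean().item())  # both ~0 after the fix
```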
ot/lp/__init__.py
```diff
@@ -517,7 +517,8 @@ def f(b):
             log['warning'] = result_code_string
             log['result_code'] = result_code
             cost = nx.set_gradients(nx.from_numpy(cost, type_as=type_as),
-                                    (a0, b0, M0), (log['u'], log['v'], G))
+                                    (a0, b0, M0), (log['u'] - nx.mean(log['u']),
+                                                   log['v'] - nx.mean(log['v']), G))
             return [cost, log]
     else:
         def f(b):
@@ -540,8 +541,8 @@ def f(b):
             )
             G = nx.from_numpy(G, type_as=type_as)
             cost = nx.set_gradients(nx.from_numpy(cost, type_as=type_as),
-                                    (a0, b0, M0), (nx.from_numpy(u, type_as=type_as),
-                                                   nx.from_numpy(v, type_as=type_as), G))
+                                    (a0, b0, M0), (nx.from_numpy(u - np.mean(u), type_as=type_as),
+                                                   nx.from_numpy(v - np.mean(v), type_as=type_as), G))
             check_result(result_code)
             return cost
```
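
For context (not part of the commit): the Kantorovich potentials `(u, v)` of the underlying LP are only unique up to the shift `(u + c, v - c)` when `a` and `b` carry the same total mass, so subtracting the mean simply selects one canonical representative and does not change the dual value. A small numerical illustration of that invariance, using only public POT calls:

```python
# Hedged numerical illustration, not part of the commit.  The dual value
# <u, a> + <v, b> is invariant under the shift (u + c, v - c) whenever a and b
# have the same total mass, so the potentials are defined up to a constant and
# centering them just picks one canonical gradient.
import numpy as np
import ot

rng = np.random.RandomState(1)
n = 6
a, b = ot.unif(n), ot.unif(n)
M = ot.dist(rng.randn(n, 2), rng.randn(n, 2))

cost, log = ot.emd2(a, b, M, log=True)
u, v = log['u'], log['v']

c = 3.7  # arbitrary shift of the potentials
print(np.allclose(u @ a + v @ b, (u + c) @ a + (v - c) @ b))  # True
print(np.allclose(u @ a + v @ b, cost))                       # strong duality
```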
test/test_ot.py
```diff
@@ -147,14 +147,20 @@ def test_emd2_gradients():
         b1 = torch.tensor(a, requires_grad=True)
         M1 = torch.tensor(M, requires_grad=True)
 
-        val = ot.emd2(a1, b1, M1)
+        val, log = ot.emd2(a1, b1, M1, log=True)
 
         val.backward()
 
         assert a1.shape == a1.grad.shape
         assert b1.shape == b1.grad.shape
         assert M1.shape == M1.grad.shape
 
+        assert np.allclose(a1.grad.cpu().detach().numpy(),
+                           log['u'].cpu().detach().numpy() - log['u'].cpu().detach().numpy().mean())
+
+        assert np.allclose(b1.grad.cpu().detach().numpy(),
+                           log['v'].cpu().detach().numpy() - log['v'].cpu().detach().numpy().mean())
+
         # Testing for bug #309, checking for scaling of gradient
         a2 = torch.tensor(a, requires_grad=True)
         b2 = torch.tensor(a, requires_grad=True)
```
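
Since the commit patches both the `log=True` and `log=False` code paths in `ot/lp/__init__.py`, and both branches wrap the same `emd_c` solve, the mass gradients of `ot.emd2` should agree whichever path is taken. A hedged consistency check along those lines (not part of the commit):

```python
# Hedged consistency check, not part of the commit.  Both the log=True and
# log=False branches were patched to return centered potentials, so the mass
# gradients should match between the two calls below.
import numpy as np
import torch
import ot

rng = np.random.RandomState(2)
n = 7
a = ot.unif(n)
M = ot.dist(rng.randn(n, 2), rng.randn(n, 2))

grads = []
for with_log in (False, True):
    a1 = torch.tensor(a, requires_grad=True)
    b1 = torch.tensor(a, requires_grad=True)
    M1 = torch.tensor(M, requires_grad=True)
    out = ot.emd2(a1, b1, M1, log=with_log)
    val = out[0] if with_log else out
    val.backward()
    grads.append((a1.grad.clone(), b1.grad.clone()))

print(torch.allclose(grads[0][0], grads[1][0]))  # expected True
print(torch.allclose(grads[0][1], grads[1][1]))  # expected True
```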