Skip to content

Commit f4b363d

Browse files
authored
[WIP] Fix gradient scaling bug in emd (#310)
* orrect gradient bug in emd2 * small comment in test * deploy properly on tag release * subplot fail
1 parent 0c58991 commit f4b363d

File tree

4 files changed

+33
-20
lines changed

4 files changed

+33
-20
lines changed

.circleci/config.yml

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -134,24 +134,21 @@ jobs:
134134
name: Deploy docs
135135
command: |
136136
set -e;
137-
if [ "${CIRCLE_BRANCH}" == "master" ]; then
138-
git config --global user.email "[email protected]";
139-
git config --global user.name "Circle CI";
140-
cd ~/PythonOT.github.io;
141-
git checkout master
142-
git remote -v
143-
git fetch origin
144-
git reset --hard origin/master
145-
git clean -xdf
146-
echo "Deploying dev docs for ${CIRCLE_BRANCH}.";
147-
cp -a /tmp/build/html/* .;
148-
touch .nojekyll;
149-
git add -A;
150-
git commit -m "CircleCI update of dev docs (${CIRCLE_BUILD_NUM}).";
151-
git push origin master;
152-
else
153-
echo "No deployment (build: ${CIRCLE_BRANCH}).";
154-
fi
137+
git config --global user.email "[email protected]";
138+
git config --global user.name "Circle CI";
139+
cd ~/PythonOT.github.io;
140+
git checkout master
141+
git remote -v
142+
git fetch origin
143+
git reset --hard origin/master
144+
git clean -xdf
145+
echo "Deploying dev docs for ${CIRCLE_BRANCH}.";
146+
cp -a /tmp/build/html/* .;
147+
touch .nojekyll;
148+
git add -A;
149+
git commit -m "CircleCI update of dev docs (${CIRCLE_BUILD_NUM}).";
150+
git push origin master;
151+
155152
156153
157154
workflows:

examples/plot_Intro_OT.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@
327327
time_sinkhorn_reg[k] = time.time() - start
328328

329329
if k % 4 == 0 and k > 0: # we only plot a few
330-
ax = pl.subplot(1, 5, k / 4)
330+
ax = pl.subplot(1, 5, k // 4)
331331
im = pl.imshow(ot_sinkhorn, vmin=0, vmax=max_ot)
332332
pl.title('reg={0:.2g}'.format(reg_parameter[k]))
333333
pl.xlabel('Cafés')

ot/backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1203,7 +1203,7 @@ def forward(ctx, val, grads, *inputs):
12031203
@staticmethod
12041204
def backward(ctx, grad_output):
12051205
# the gradients are grad
1206-
return (None, None) + ctx.grads
1206+
return (None, None) + tuple(g * grad_output for g in ctx.grads)
12071207

12081208
self.ValFunction = ValFunction
12091209

test/test_ot.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,22 @@ def test_emd2_gradients():
126126
assert b1.shape == b1.grad.shape
127127
assert M1.shape == M1.grad.shape
128128

129+
# Testing for bug #309, checking for scaling of gradient
130+
a2 = torch.tensor(a, requires_grad=True)
131+
b2 = torch.tensor(a, requires_grad=True)
132+
M2 = torch.tensor(M, requires_grad=True)
133+
134+
val = 10.0 * ot.emd2(a2, b2, M2)
135+
136+
val.backward()
137+
138+
assert np.allclose(10.0 * a1.grad.cpu().detach().numpy(),
139+
a2.grad.cpu().detach().numpy())
140+
assert np.allclose(10.0 * b1.grad.cpu().detach().numpy(),
141+
b2.grad.cpu().detach().numpy())
142+
assert np.allclose(10.0 * M1.grad.cpu().detach().numpy(),
143+
M2.grad.cpu().detach().numpy())
144+
129145

130146
def test_emd_emd2():
131147
# test emd and emd2 for simple identity

0 commit comments

Comments
 (0)