Skip to content

Commit 0138dcf

Browse files
authored
[MRG] Solve example throwing an error when executed on a GPU (#391)
* Solve example throwing an error when executed on a GPU * add PR to releases.md * update pep8 command * pep8
1 parent 818c7ac commit 0138dcf

File tree

6 files changed

+9
-8
lines changed

6 files changed

+9
-8
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ clean : FORCE
4242
$(PYTHON) setup.py clean
4343

4444
pep8 :
45-
flake8 examples/ ot/ test/
45+
flake8 examples/ ot/ test/ --count --max-line-length=127 --statistics --show-source
4646

4747
test : FORCE pep8
4848
$(PYTHON) -m pytest --durations=20 -v test/ --doctest-modules --ignore ot/gpu/

RELEASES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
incomplete transport plan above a certain size (slightly above 46k, its square being
2020
roughly 2^31) (PR #381)
2121
- Error raised when mass mismatch in emd2 (PR #386)
22+
- Fixed an issue where a pytorch example would throw an error if executed on a GPU (Issue #389, PR #391)
2223

2324

2425
## 0.8.2

examples/backends/plot_sliced_wass_grad_flow_pytorch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@
7474
loss_iter = []
7575

7676
# generator for random permutations
77-
gen = torch.Generator()
77+
gen = torch.Generator(device=device)
7878
gen.manual_seed(42)
7979

8080
for i in range(nb_iter_max):
@@ -136,7 +136,7 @@ def _update_plot(i):
136136
loss_iter = []
137137

138138
# generator for random permutations
139-
gen = torch.Generator()
139+
gen = torch.Generator(device=device)
140140
gen.manual_seed(42)
141141

142142
alpha = 0.5

ot/gromov.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1581,7 +1581,7 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
15811581

15821582
error = []
15831583

1584-
while(err > tol and cpt < max_iter):
1584+
while (err > tol and cpt < max_iter):
15851585
Cprev = C
15861586

15871587
T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun,
@@ -1725,7 +1725,7 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
17251725
log_['err_structure'] = []
17261726
log_['Ts_iter'] = []
17271727

1728-
while((err_feature > tol or err_structure > tol) and cpt < max_iter):
1728+
while ((err_feature > tol or err_structure > tol) and cpt < max_iter):
17291729
Cprev = C
17301730
Xprev = X
17311731

ot/lp/cvx.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-po
8080
if weights is None:
8181
weights = np.ones(A.shape[1]) / A.shape[1]
8282
else:
83-
assert(len(weights) == A.shape[1])
83+
assert len(weights) == A.shape[1]
8484

8585
n_distributions = A.shape[1]
8686
n = A.shape[0]

ot/unbalanced.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -734,7 +734,7 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
734734
if weights is None:
735735
weights = nx.ones(n_hists, type_as=A) / n_hists
736736
else:
737-
assert(len(weights) == A.shape[1])
737+
assert len(weights) == A.shape[1]
738738

739739
if log:
740740
log = {'err': []}
@@ -882,7 +882,7 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None,
882882
if weights is None:
883883
weights = nx.ones(n_hists, type_as=A) / n_hists
884884
else:
885-
assert(len(weights) == A.shape[1])
885+
assert len(weights) == A.shape[1]
886886

887887
if log:
888888
log = {'err': []}

0 commit comments

Comments
 (0)