Skip to content

Commit 263a36f

Browse files
authored
[MRG] Update pymanopt requirement and API for ot.dr (#443)
* updayte pymanopt API step 1 * add realease information * update requireents for tests on windows
1 parent a6d5d75 commit 263a36f

File tree

5 files changed

+21
-20
lines changed

5 files changed

+21
-20
lines changed

.github/requirements_test_windows.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@ scipy>=1.3
33
cython
44
matplotlib
55
autograd
6-
pymanopt==0.2.4; python_version <'3'
7-
pymanopt==0.2.6rc1; python_version >= '3'
6+
pymanopt
87
cvxopt
98
scikit-learn
109
pytest

RELEASES.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
- Backend version of `ot.partial` and `ot.smooth` (PR #388)
1616
- Added argument for warmstart of dual potentials in Sinkhorn-based methods in `ot.bregman` (PR #437)
1717
- Add parameters method in `ot.da.SinkhornTransport` (PR #440)
18+
- `ot.dr` now uses the new Pymanopt API and POT is compatible with current
19+
Pymanopt (PR #443)
1820

1921
#### Closed issues
2022

docs/requirements_rtd.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ scipy>=1.0
99
cython
1010
matplotlib
1111
autograd
12-
pymanopt==0.2.4; python_version <'3'
13-
pymanopt; python_version >= '3'
12+
pymanopt
1413
cvxopt
1514
scikit-learn

ot/dr.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717

1818
from scipy import linalg
1919
import autograd.numpy as np
20-
from pymanopt.function import Autograd
21-
from pymanopt.manifolds import Stiefel
22-
from pymanopt import Problem
23-
from pymanopt.solvers import SteepestDescent, TrustRegions
20+
21+
import pymanopt
22+
import pymanopt.manifolds
23+
import pymanopt.optimizers
2424

2525

2626
def dist(x1, x2):
@@ -38,8 +38,8 @@ def sinkhorn(w1, w2, M, reg, k):
3838
ui = np.ones((M.shape[0],))
3939
vi = np.ones((M.shape[1],))
4040
for i in range(k):
41-
vi = w2 / (np.dot(K.T, ui))
42-
ui = w1 / (np.dot(K, vi))
41+
vi = w2 / (np.dot(K.T, ui) + 1e-50)
42+
ui = w1 / (np.dot(K, vi) + 1e-50)
4343
G = ui.reshape((M.shape[0], 1)) * K * vi.reshape((1, M.shape[1]))
4444
return G
4545

@@ -222,7 +222,9 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, sinkhorn_method='sinkhorn', maxiter
222222
else:
223223
regmean = np.ones((len(xc), len(xc)))
224224

225-
@Autograd
225+
manifold = pymanopt.manifolds.Stiefel(d, p)
226+
227+
@pymanopt.function.autograd(manifold)
226228
def cost(P):
227229
# wda loss
228230
loss_b = 0
@@ -243,21 +245,21 @@ def cost(P):
243245
return loss_w / loss_b
244246

245247
# declare manifold and problem
246-
manifold = Stiefel(d, p)
247-
problem = Problem(manifold=manifold, cost=cost)
248+
249+
problem = pymanopt.Problem(manifold=manifold, cost=cost)
248250

249251
# declare solver and solve
250252
if solver is None:
251-
solver = SteepestDescent(maxiter=maxiter, logverbosity=verbose)
253+
solver = pymanopt.optimizers.SteepestDescent(max_iterations=maxiter, log_verbosity=verbose)
252254
elif solver in ['tr', 'TrustRegions']:
253-
solver = TrustRegions(maxiter=maxiter, logverbosity=verbose)
255+
solver = pymanopt.optimizers.TrustRegions(max_iterations=maxiter, log_verbosity=verbose)
254256

255-
Popt = solver.solve(problem, x=P0)
257+
Popt = solver.run(problem, initial_point=P0)
256258

257259
def proj(X):
258-
return (X - mx.reshape((1, -1))).dot(Popt)
260+
return (X - mx.reshape((1, -1))).dot(Popt.point)
259261

260-
return Popt, proj
262+
return Popt.point, proj
261263

262264

263265
def projection_robust_wasserstein(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopThr=1e-3, maxiter=100, verbose=0):

requirements.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@ numpy>=1.20
22
scipy>=1.3
33
matplotlib
44
autograd
5-
pymanopt==0.2.4; python_version <'3'
6-
pymanopt==0.2.6rc1; python_version >= '3'
5+
pymanopt
76
cvxopt
87
scikit-learn
98
torch

0 commit comments

Comments
 (0)