Skip to content

Commit 200322b

Browse files
framunozFrancisco Muñoz
and
Francisco Muñoz
authored
[MRG] Refactor: improve readibility in the convolutional module (#709)
* feat: add _get_convol_img_fn * refactor: add warning msg * refactor: encapsulate the report printing in a function * docs: add some documentation in the function _get_convol_img_fn * docs: add realise * refactor: change function _get_convol_img_fn for more clarity * refactor: run pre-commit * test: refactor tests to delete the error for unavailable backends * feat: delete not implemented error in convolutional module * revert the last two commits * docs: add comments with the reason of the error --------- Co-authored-by: Francisco Muñoz <[email protected]>
1 parent 39cd6ec commit 200322b

File tree

2 files changed

+70
-93
lines changed

2 files changed

+70
-93
lines changed

RELEASES.md

+1
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ This release also contains few bug fixes, concerning the support of any metric i
4040
- Notes before depreciating partial Gromov-Wasserstein function in `ot.partial` moved to ot.gromov (PR #663)
4141
- Create `ot.gromov._partial` add new features `loss_fun = "kl_loss"` and `symmetry=False` to all solvers while increasing speed + updating adequatly `ot.solvers` (PR #663)
4242
- Added `ot.unbalanced.sinkhorn_unbalanced_translation_invariant` (PR #676)
43+
- Refactored `ot.bregman._convolutional` to improve readability (PR #709)
4344

4445
#### Closed issues
4546
- Fixed `ot.gaussian` ignoring weights when computing means (PR #649, Issue #648)

ot/bregman/_convolutional.py

+69-93
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,53 @@
1010

1111
import warnings
1212

13-
from ..utils import list_to_array
1413
from ..backend import get_backend
14+
from ..utils import list_to_array
15+
16+
_warning_msg = (
17+
"Convolutional Sinkhorn did not converge. "
18+
"Try a larger number of iterations `numItermax` "
19+
"or a larger entropy `reg`."
20+
)
21+
22+
23+
def _get_convol_img_fn(nx, width, height, reg, type_as, log_domain=False):
24+
"""Return the convolution operator for 2D images.
25+
26+
The function constructed is equivalent to blurring on horizontal then vertical directions."""
27+
t1 = nx.linspace(0, 1, width, type_as=type_as)
28+
Y1, X1 = nx.meshgrid(t1, t1)
29+
M1 = -((X1 - Y1) ** 2) / reg
30+
31+
t2 = nx.linspace(0, 1, height, type_as=type_as)
32+
Y2, X2 = nx.meshgrid(t2, t2)
33+
M2 = -((X2 - Y2) ** 2) / reg
34+
35+
# If normal domain is selected, we can use M1 and M2 to compute the convolution
36+
if not log_domain:
37+
K1, K2 = nx.exp(M1), nx.exp(M2)
38+
39+
def convol_imgs(imgs):
40+
kx = nx.einsum("...ij,kjl->kil", K1, imgs)
41+
kxy = nx.einsum("...ij,klj->kli", K2, kx)
42+
return kxy
43+
44+
# Else, we can use M1 and M2 to compute the convolution in log-domain
45+
else:
46+
47+
def convol_imgs(log_imgs):
48+
log_imgs = nx.logsumexp(M1[:, :, None] + log_imgs[None], axis=1)
49+
log_imgs = nx.logsumexp(M2[:, :, None] + log_imgs.T[None], axis=1).T
50+
return log_imgs
51+
52+
return convol_imgs
53+
54+
55+
def _print_report(ii, err):
56+
"""Print the report of the iteration."""
57+
if ii % 200 == 0:
58+
print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19)
59+
print("{:5d}|{:8e}|".format(ii, err))
1560

1661

1762
def convolutional_barycenter2d(
@@ -133,37 +178,26 @@ def _convolutional_barycenter2d(
133178
"""
134179

135180
A = list_to_array(A)
181+
n_hists, width, height = A.shape
136182

137183
nx = get_backend(A)
138184

139185
if weights is None:
140-
weights = nx.ones((A.shape[0],), type_as=A) / A.shape[0]
186+
weights = nx.ones((n_hists,), type_as=A) / n_hists
141187
else:
142-
assert len(weights) == A.shape[0]
188+
assert len(weights) == n_hists
143189

144190
if log:
145191
log = {"err": []}
146192

147-
bar = nx.ones(A.shape[1:], type_as=A)
193+
bar = nx.ones((width, height), type_as=A)
148194
bar /= nx.sum(bar)
149195
U = nx.ones(A.shape, type_as=A)
150196
V = nx.ones(A.shape, type_as=A)
151197
err = 1
152198

153199
# build the convolution operator
154-
# this is equivalent to blurring on horizontal then vertical directions
155-
t = nx.linspace(0, 1, A.shape[1], type_as=A)
156-
[Y, X] = nx.meshgrid(t, t)
157-
K1 = nx.exp(-((X - Y) ** 2) / reg)
158-
159-
t = nx.linspace(0, 1, A.shape[2], type_as=A)
160-
[Y, X] = nx.meshgrid(t, t)
161-
K2 = nx.exp(-((X - Y) ** 2) / reg)
162-
163-
def convol_imgs(imgs):
164-
kx = nx.einsum("...ij,kjl->kil", K1, imgs)
165-
kxy = nx.einsum("...ij,klj->kli", K2, kx)
166-
return kxy
200+
convol_imgs = _get_convol_img_fn(nx, width, height, reg, type_as=A)
167201

168202
KU = convol_imgs(U)
169203
for ii in range(numItermax):
@@ -177,24 +211,18 @@ def convol_imgs(imgs):
177211
# log and verbose print
178212
if log:
179213
log["err"].append(err)
180-
181214
if verbose:
182-
if ii % 200 == 0:
183-
print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19)
184-
print("{:5d}|{:8e}|".format(ii, err))
215+
_print_report(ii, err)
185216
if err < stopThr:
186217
break
187218

188219
else:
189220
if warn:
190-
warnings.warn(
191-
"Convolutional Sinkhorn did not converge. "
192-
"Try a larger number of iterations `numItermax` "
193-
"or a larger entropy `reg`."
194-
)
221+
warnings.warn(_warning_msg)
195222
if log:
196223
log["niter"] = ii
197224
log["U"] = U
225+
log["V"] = V
198226
return bar, log
199227
else:
200228
return bar
@@ -218,6 +246,8 @@ def _convolutional_barycenter2d_log(
218246
A = list_to_array(A)
219247

220248
nx = get_backend(A)
249+
# This error is raised because we are using mutable assignment in the line
250+
# `log_KU[k] = ...` which is not allowed in Jax and TF.
221251
if nx.__name__ in ("jax", "tf"):
222252
raise NotImplementedError(
223253
"Log-domain functions are not yet implemented"
@@ -236,19 +266,7 @@ def _convolutional_barycenter2d_log(
236266

237267
err = 1
238268
# build the convolution operator
239-
# this is equivalent to blurring on horizontal then vertical directions
240-
t = nx.linspace(0, 1, width, type_as=A)
241-
[Y, X] = nx.meshgrid(t, t)
242-
M1 = -((X - Y) ** 2) / reg
243-
244-
t = nx.linspace(0, 1, height, type_as=A)
245-
[Y, X] = nx.meshgrid(t, t)
246-
M2 = -((X - Y) ** 2) / reg
247-
248-
def convol_img(log_img):
249-
log_img = nx.logsumexp(M1[:, :, None] + log_img[None], axis=1)
250-
log_img = nx.logsumexp(M2[:, :, None] + log_img.T[None], axis=1).T
251-
return log_img
269+
convol_img = _get_convol_img_fn(nx, width, height, reg, type_as=A, log_domain=True)
252270

253271
logA = nx.log(A + stabThr)
254272
log_KU, G, F = nx.zeros((3, *logA.shape), type_as=A)
@@ -265,22 +283,15 @@ def convol_img(log_img):
265283
# log and verbose print
266284
if log:
267285
log["err"].append(err)
268-
269286
if verbose:
270-
if ii % 200 == 0:
271-
print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19)
272-
print("{:5d}|{:8e}|".format(ii, err))
287+
_print_report(ii, err)
273288
if err < stopThr:
274289
break
275290
G = log_bar[None, :, :] - log_KU
276291

277292
else:
278293
if warn:
279-
warnings.warn(
280-
"Convolutional Sinkhorn did not converge. "
281-
"Try a larger number of iterations `numItermax` "
282-
"or a larger entropy `reg`."
283-
)
294+
warnings.warn(_warning_msg)
284295
if log:
285296
log["niter"] = ii
286297
return nx.exp(log_bar), log
@@ -417,23 +428,11 @@ def _convolutional_barycenter2d_debiased(
417428
bar /= width * height
418429
U = nx.ones(A.shape, type_as=A)
419430
V = nx.ones(A.shape, type_as=A)
420-
c = nx.ones(A.shape[1:], type_as=A)
431+
c = nx.ones((width, height), type_as=A)
421432
err = 1
422433

423434
# build the convolution operator
424-
# this is equivalent to blurring on horizontal then vertical directions
425-
t = nx.linspace(0, 1, width, type_as=A)
426-
[Y, X] = nx.meshgrid(t, t)
427-
K1 = nx.exp(-((X - Y) ** 2) / reg)
428-
429-
t = nx.linspace(0, 1, height, type_as=A)
430-
[Y, X] = nx.meshgrid(t, t)
431-
K2 = nx.exp(-((X - Y) ** 2) / reg)
432-
433-
def convol_imgs(imgs):
434-
kx = nx.einsum("...ij,kjl->kil", K1, imgs)
435-
kxy = nx.einsum("...ij,klj->kli", K2, kx)
436-
return kxy
435+
convol_imgs = _get_convol_img_fn(nx, width, height, reg, type_as=A)
437436

438437
KU = convol_imgs(U)
439438
for ii in range(numItermax):
@@ -451,26 +450,20 @@ def convol_imgs(imgs):
451450
# log and verbose print
452451
if log:
453452
log["err"].append(err)
454-
455453
if verbose:
456-
if ii % 200 == 0:
457-
print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19)
458-
print("{:5d}|{:8e}|".format(ii, err))
454+
_print_report(ii, err)
459455

460456
# debiased Sinkhorn does not converge monotonically
461457
# guarantee a few iterations are done before stopping
462458
if err < stopThr and ii > 20:
463459
break
464460
else:
465461
if warn:
466-
warnings.warn(
467-
"Sinkhorn did not converge. You might want to "
468-
"increase the number of iterations `numItermax` "
469-
"or the regularization parameter `reg`."
470-
)
462+
warnings.warn(_warning_msg)
471463
if log:
472464
log["niter"] = ii
473465
log["U"] = U
466+
log["V"] = V
474467
return bar, log
475468
else:
476469
return bar
@@ -492,6 +485,8 @@ def _convolutional_barycenter2d_debiased_log(
492485
A = list_to_array(A)
493486
n_hists, width, height = A.shape
494487
nx = get_backend(A)
488+
# This error is raised because we are using mutable assignment in the line
489+
# `log_KU[k] = ...` which is not allowed in Jax and TF.
495490
if nx.__name__ in ("jax", "tf"):
496491
raise NotImplementedError(
497492
"Log-domain functions are not yet implemented"
@@ -507,19 +502,7 @@ def _convolutional_barycenter2d_debiased_log(
507502

508503
err = 1
509504
# build the convolution operator
510-
# this is equivalent to blurring on horizontal then vertical directions
511-
t = nx.linspace(0, 1, width, type_as=A)
512-
[Y, X] = nx.meshgrid(t, t)
513-
M1 = -((X - Y) ** 2) / reg
514-
515-
t = nx.linspace(0, 1, height, type_as=A)
516-
[Y, X] = nx.meshgrid(t, t)
517-
M2 = -((X - Y) ** 2) / reg
518-
519-
def convol_img(log_img):
520-
log_img = nx.logsumexp(M1[:, :, None] + log_img[None], axis=1)
521-
log_img = nx.logsumexp(M2[:, :, None] + log_img.T[None], axis=1).T
522-
return log_img
505+
convol_img = _get_convol_img_fn(nx, width, height, reg, type_as=A, log_domain=True)
523506

524507
logA = nx.log(A + stabThr)
525508
log_bar, c = nx.zeros((2, width, height), type_as=A)
@@ -540,22 +523,15 @@ def convol_img(log_img):
540523
# log and verbose print
541524
if log:
542525
log["err"].append(err)
543-
544526
if verbose:
545-
if ii % 200 == 0:
546-
print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19)
547-
print("{:5d}|{:8e}|".format(ii, err))
527+
_print_report(ii, err)
548528
if err < stopThr and ii > 20:
549529
break
550530
G = log_bar[None, :, :] - log_KU
551531

552532
else:
553533
if warn:
554-
warnings.warn(
555-
"Convolutional Sinkhorn did not converge. "
556-
"Try a larger number of iterations `numItermax` "
557-
"or a larger entropy `reg`."
558-
)
534+
warnings.warn(_warning_msg)
559535
if log:
560536
log["niter"] = ii
561537
return nx.exp(log_bar), log

0 commit comments

Comments
 (0)