10
10
11
11
import warnings
12
12
13
- from ..utils import list_to_array
14
13
from ..backend import get_backend
14
+ from ..utils import list_to_array
15
+
16
+ _warning_msg = (
17
+ "Convolutional Sinkhorn did not converge. "
18
+ "Try a larger number of iterations `numItermax` "
19
+ "or a larger entropy `reg`."
20
+ )
21
+
22
+
23
+ def _get_convol_img_fn (nx , width , height , reg , type_as , log_domain = False ):
24
+ """Return the convolution operator for 2D images.
25
+
26
+ The function constructed is equivalent to blurring on horizontal then vertical directions."""
27
+ t1 = nx .linspace (0 , 1 , width , type_as = type_as )
28
+ Y1 , X1 = nx .meshgrid (t1 , t1 )
29
+ M1 = - ((X1 - Y1 ) ** 2 ) / reg
30
+
31
+ t2 = nx .linspace (0 , 1 , height , type_as = type_as )
32
+ Y2 , X2 = nx .meshgrid (t2 , t2 )
33
+ M2 = - ((X2 - Y2 ) ** 2 ) / reg
34
+
35
+ # If normal domain is selected, we can use M1 and M2 to compute the convolution
36
+ if not log_domain :
37
+ K1 , K2 = nx .exp (M1 ), nx .exp (M2 )
38
+
39
+ def convol_imgs (imgs ):
40
+ kx = nx .einsum ("...ij,kjl->kil" , K1 , imgs )
41
+ kxy = nx .einsum ("...ij,klj->kli" , K2 , kx )
42
+ return kxy
43
+
44
+ # Else, we can use M1 and M2 to compute the convolution in log-domain
45
+ else :
46
+
47
+ def convol_imgs (log_imgs ):
48
+ log_imgs = nx .logsumexp (M1 [:, :, None ] + log_imgs [None ], axis = 1 )
49
+ log_imgs = nx .logsumexp (M2 [:, :, None ] + log_imgs .T [None ], axis = 1 ).T
50
+ return log_imgs
51
+
52
+ return convol_imgs
53
+
54
+
55
+ def _print_report (ii , err ):
56
+ """Print the report of the iteration."""
57
+ if ii % 200 == 0 :
58
+ print ("{:5s}|{:12s}" .format ("It." , "Err" ) + "\n " + "-" * 19 )
59
+ print ("{:5d}|{:8e}|" .format (ii , err ))
15
60
16
61
17
62
def convolutional_barycenter2d (
@@ -133,37 +178,26 @@ def _convolutional_barycenter2d(
133
178
"""
134
179
135
180
A = list_to_array (A )
181
+ n_hists , width , height = A .shape
136
182
137
183
nx = get_backend (A )
138
184
139
185
if weights is None :
140
- weights = nx .ones ((A . shape [ 0 ] ,), type_as = A ) / A . shape [ 0 ]
186
+ weights = nx .ones ((n_hists ,), type_as = A ) / n_hists
141
187
else :
142
- assert len (weights ) == A . shape [ 0 ]
188
+ assert len (weights ) == n_hists
143
189
144
190
if log :
145
191
log = {"err" : []}
146
192
147
- bar = nx .ones (A . shape [ 1 :] , type_as = A )
193
+ bar = nx .ones (( width , height ) , type_as = A )
148
194
bar /= nx .sum (bar )
149
195
U = nx .ones (A .shape , type_as = A )
150
196
V = nx .ones (A .shape , type_as = A )
151
197
err = 1
152
198
153
199
# build the convolution operator
154
- # this is equivalent to blurring on horizontal then vertical directions
155
- t = nx .linspace (0 , 1 , A .shape [1 ], type_as = A )
156
- [Y , X ] = nx .meshgrid (t , t )
157
- K1 = nx .exp (- ((X - Y ) ** 2 ) / reg )
158
-
159
- t = nx .linspace (0 , 1 , A .shape [2 ], type_as = A )
160
- [Y , X ] = nx .meshgrid (t , t )
161
- K2 = nx .exp (- ((X - Y ) ** 2 ) / reg )
162
-
163
- def convol_imgs (imgs ):
164
- kx = nx .einsum ("...ij,kjl->kil" , K1 , imgs )
165
- kxy = nx .einsum ("...ij,klj->kli" , K2 , kx )
166
- return kxy
200
+ convol_imgs = _get_convol_img_fn (nx , width , height , reg , type_as = A )
167
201
168
202
KU = convol_imgs (U )
169
203
for ii in range (numItermax ):
@@ -177,24 +211,18 @@ def convol_imgs(imgs):
177
211
# log and verbose print
178
212
if log :
179
213
log ["err" ].append (err )
180
-
181
214
if verbose :
182
- if ii % 200 == 0 :
183
- print ("{:5s}|{:12s}" .format ("It." , "Err" ) + "\n " + "-" * 19 )
184
- print ("{:5d}|{:8e}|" .format (ii , err ))
215
+ _print_report (ii , err )
185
216
if err < stopThr :
186
217
break
187
218
188
219
else :
189
220
if warn :
190
- warnings .warn (
191
- "Convolutional Sinkhorn did not converge. "
192
- "Try a larger number of iterations `numItermax` "
193
- "or a larger entropy `reg`."
194
- )
221
+ warnings .warn (_warning_msg )
195
222
if log :
196
223
log ["niter" ] = ii
197
224
log ["U" ] = U
225
+ log ["V" ] = V
198
226
return bar , log
199
227
else :
200
228
return bar
@@ -218,6 +246,8 @@ def _convolutional_barycenter2d_log(
218
246
A = list_to_array (A )
219
247
220
248
nx = get_backend (A )
249
+ # This error is raised because we are using mutable assignment in the line
250
+ # `log_KU[k] = ...` which is not allowed in Jax and TF.
221
251
if nx .__name__ in ("jax" , "tf" ):
222
252
raise NotImplementedError (
223
253
"Log-domain functions are not yet implemented"
@@ -236,19 +266,7 @@ def _convolutional_barycenter2d_log(
236
266
237
267
err = 1
238
268
# build the convolution operator
239
- # this is equivalent to blurring on horizontal then vertical directions
240
- t = nx .linspace (0 , 1 , width , type_as = A )
241
- [Y , X ] = nx .meshgrid (t , t )
242
- M1 = - ((X - Y ) ** 2 ) / reg
243
-
244
- t = nx .linspace (0 , 1 , height , type_as = A )
245
- [Y , X ] = nx .meshgrid (t , t )
246
- M2 = - ((X - Y ) ** 2 ) / reg
247
-
248
- def convol_img (log_img ):
249
- log_img = nx .logsumexp (M1 [:, :, None ] + log_img [None ], axis = 1 )
250
- log_img = nx .logsumexp (M2 [:, :, None ] + log_img .T [None ], axis = 1 ).T
251
- return log_img
269
+ convol_img = _get_convol_img_fn (nx , width , height , reg , type_as = A , log_domain = True )
252
270
253
271
logA = nx .log (A + stabThr )
254
272
log_KU , G , F = nx .zeros ((3 , * logA .shape ), type_as = A )
@@ -265,22 +283,15 @@ def convol_img(log_img):
265
283
# log and verbose print
266
284
if log :
267
285
log ["err" ].append (err )
268
-
269
286
if verbose :
270
- if ii % 200 == 0 :
271
- print ("{:5s}|{:12s}" .format ("It." , "Err" ) + "\n " + "-" * 19 )
272
- print ("{:5d}|{:8e}|" .format (ii , err ))
287
+ _print_report (ii , err )
273
288
if err < stopThr :
274
289
break
275
290
G = log_bar [None , :, :] - log_KU
276
291
277
292
else :
278
293
if warn :
279
- warnings .warn (
280
- "Convolutional Sinkhorn did not converge. "
281
- "Try a larger number of iterations `numItermax` "
282
- "or a larger entropy `reg`."
283
- )
294
+ warnings .warn (_warning_msg )
284
295
if log :
285
296
log ["niter" ] = ii
286
297
return nx .exp (log_bar ), log
@@ -417,23 +428,11 @@ def _convolutional_barycenter2d_debiased(
417
428
bar /= width * height
418
429
U = nx .ones (A .shape , type_as = A )
419
430
V = nx .ones (A .shape , type_as = A )
420
- c = nx .ones (A . shape [ 1 :] , type_as = A )
431
+ c = nx .ones (( width , height ) , type_as = A )
421
432
err = 1
422
433
423
434
# build the convolution operator
424
- # this is equivalent to blurring on horizontal then vertical directions
425
- t = nx .linspace (0 , 1 , width , type_as = A )
426
- [Y , X ] = nx .meshgrid (t , t )
427
- K1 = nx .exp (- ((X - Y ) ** 2 ) / reg )
428
-
429
- t = nx .linspace (0 , 1 , height , type_as = A )
430
- [Y , X ] = nx .meshgrid (t , t )
431
- K2 = nx .exp (- ((X - Y ) ** 2 ) / reg )
432
-
433
- def convol_imgs (imgs ):
434
- kx = nx .einsum ("...ij,kjl->kil" , K1 , imgs )
435
- kxy = nx .einsum ("...ij,klj->kli" , K2 , kx )
436
- return kxy
435
+ convol_imgs = _get_convol_img_fn (nx , width , height , reg , type_as = A )
437
436
438
437
KU = convol_imgs (U )
439
438
for ii in range (numItermax ):
@@ -451,26 +450,20 @@ def convol_imgs(imgs):
451
450
# log and verbose print
452
451
if log :
453
452
log ["err" ].append (err )
454
-
455
453
if verbose :
456
- if ii % 200 == 0 :
457
- print ("{:5s}|{:12s}" .format ("It." , "Err" ) + "\n " + "-" * 19 )
458
- print ("{:5d}|{:8e}|" .format (ii , err ))
454
+ _print_report (ii , err )
459
455
460
456
# debiased Sinkhorn does not converge monotonically
461
457
# guarantee a few iterations are done before stopping
462
458
if err < stopThr and ii > 20 :
463
459
break
464
460
else :
465
461
if warn :
466
- warnings .warn (
467
- "Sinkhorn did not converge. You might want to "
468
- "increase the number of iterations `numItermax` "
469
- "or the regularization parameter `reg`."
470
- )
462
+ warnings .warn (_warning_msg )
471
463
if log :
472
464
log ["niter" ] = ii
473
465
log ["U" ] = U
466
+ log ["V" ] = V
474
467
return bar , log
475
468
else :
476
469
return bar
@@ -492,6 +485,8 @@ def _convolutional_barycenter2d_debiased_log(
492
485
A = list_to_array (A )
493
486
n_hists , width , height = A .shape
494
487
nx = get_backend (A )
488
+ # This error is raised because we are using mutable assignment in the line
489
+ # `log_KU[k] = ...` which is not allowed in Jax and TF.
495
490
if nx .__name__ in ("jax" , "tf" ):
496
491
raise NotImplementedError (
497
492
"Log-domain functions are not yet implemented"
@@ -507,19 +502,7 @@ def _convolutional_barycenter2d_debiased_log(
507
502
508
503
err = 1
509
504
# build the convolution operator
510
- # this is equivalent to blurring on horizontal then vertical directions
511
- t = nx .linspace (0 , 1 , width , type_as = A )
512
- [Y , X ] = nx .meshgrid (t , t )
513
- M1 = - ((X - Y ) ** 2 ) / reg
514
-
515
- t = nx .linspace (0 , 1 , height , type_as = A )
516
- [Y , X ] = nx .meshgrid (t , t )
517
- M2 = - ((X - Y ) ** 2 ) / reg
518
-
519
- def convol_img (log_img ):
520
- log_img = nx .logsumexp (M1 [:, :, None ] + log_img [None ], axis = 1 )
521
- log_img = nx .logsumexp (M2 [:, :, None ] + log_img .T [None ], axis = 1 ).T
522
- return log_img
505
+ convol_img = _get_convol_img_fn (nx , width , height , reg , type_as = A , log_domain = True )
523
506
524
507
logA = nx .log (A + stabThr )
525
508
log_bar , c = nx .zeros ((2 , width , height ), type_as = A )
@@ -540,22 +523,15 @@ def convol_img(log_img):
540
523
# log and verbose print
541
524
if log :
542
525
log ["err" ].append (err )
543
-
544
526
if verbose :
545
- if ii % 200 == 0 :
546
- print ("{:5s}|{:12s}" .format ("It." , "Err" ) + "\n " + "-" * 19 )
547
- print ("{:5d}|{:8e}|" .format (ii , err ))
527
+ _print_report (ii , err )
548
528
if err < stopThr and ii > 20 :
549
529
break
550
530
G = log_bar [None , :, :] - log_KU
551
531
552
532
else :
553
533
if warn :
554
- warnings .warn (
555
- "Convolutional Sinkhorn did not converge. "
556
- "Try a larger number of iterations `numItermax` "
557
- "or a larger entropy `reg`."
558
- )
534
+ warnings .warn (_warning_msg )
559
535
if log :
560
536
log ["niter" ] = ii
561
537
return nx .exp (log_bar ), log
0 commit comments