1010
1111import  warnings 
1212
13- from  ..utils  import  list_to_array 
1413from  ..backend  import  get_backend 
14+ from  ..utils  import  list_to_array 
15+ 
16+ _warning_msg  =  (
17+     "Convolutional Sinkhorn did not converge. " 
18+     "Try a larger number of iterations `numItermax` " 
19+     "or a larger entropy `reg`." 
20+ )
21+ 
22+ 
23+ def  _get_convol_img_fn (nx , width , height , reg , type_as , log_domain = False ):
24+     """Return the convolution operator for 2D images. 
25+ 
26+     The function constructed is equivalent to blurring on horizontal then vertical directions.""" 
27+     t1  =  nx .linspace (0 , 1 , width , type_as = type_as )
28+     Y1 , X1  =  nx .meshgrid (t1 , t1 )
29+     M1  =  - ((X1  -  Y1 ) **  2 ) /  reg 
30+ 
31+     t2  =  nx .linspace (0 , 1 , height , type_as = type_as )
32+     Y2 , X2  =  nx .meshgrid (t2 , t2 )
33+     M2  =  - ((X2  -  Y2 ) **  2 ) /  reg 
34+ 
35+     # If normal domain is selected, we can use M1 and M2 to compute the convolution 
36+     if  not  log_domain :
37+         K1 , K2  =  nx .exp (M1 ), nx .exp (M2 )
38+ 
39+         def  convol_imgs (imgs ):
40+             kx  =  nx .einsum ("...ij,kjl->kil" , K1 , imgs )
41+             kxy  =  nx .einsum ("...ij,klj->kli" , K2 , kx )
42+             return  kxy 
43+ 
44+     # Else, we can use M1 and M2 to compute the convolution in log-domain 
45+     else :
46+ 
47+         def  convol_imgs (log_imgs ):
48+             log_imgs  =  nx .logsumexp (M1 [:, :, None ] +  log_imgs [None ], axis = 1 )
49+             log_imgs  =  nx .logsumexp (M2 [:, :, None ] +  log_imgs .T [None ], axis = 1 ).T 
50+             return  log_imgs 
51+ 
52+     return  convol_imgs 
53+ 
54+ 
55+ def  _print_report (ii , err ):
56+     """Print the report of the iteration.""" 
57+     if  ii  %  200  ==  0 :
58+         print ("{:5s}|{:12s}" .format ("It." , "Err" ) +  "\n "  +  "-"  *  19 )
59+     print ("{:5d}|{:8e}|" .format (ii , err ))
1560
1661
1762def  convolutional_barycenter2d (
@@ -133,37 +178,26 @@ def _convolutional_barycenter2d(
133178    """ 
134179
135180    A  =  list_to_array (A )
181+     n_hists , width , height  =  A .shape 
136182
137183    nx  =  get_backend (A )
138184
139185    if  weights  is  None :
140-         weights  =  nx .ones ((A . shape [ 0 ] ,), type_as = A ) /  A . shape [ 0 ] 
186+         weights  =  nx .ones ((n_hists ,), type_as = A ) /  n_hists 
141187    else :
142-         assert  len (weights ) ==  A . shape [ 0 ] 
188+         assert  len (weights ) ==  n_hists 
143189
144190    if  log :
145191        log  =  {"err" : []}
146192
147-     bar  =  nx .ones (A . shape [ 1 :] , type_as = A )
193+     bar  =  nx .ones (( width ,  height ) , type_as = A )
148194    bar  /=  nx .sum (bar )
149195    U  =  nx .ones (A .shape , type_as = A )
150196    V  =  nx .ones (A .shape , type_as = A )
151197    err  =  1 
152198
153199    # build the convolution operator 
154-     # this is equivalent to blurring on horizontal then vertical directions 
155-     t  =  nx .linspace (0 , 1 , A .shape [1 ], type_as = A )
156-     [Y , X ] =  nx .meshgrid (t , t )
157-     K1  =  nx .exp (- ((X  -  Y ) **  2 ) /  reg )
158- 
159-     t  =  nx .linspace (0 , 1 , A .shape [2 ], type_as = A )
160-     [Y , X ] =  nx .meshgrid (t , t )
161-     K2  =  nx .exp (- ((X  -  Y ) **  2 ) /  reg )
162- 
163-     def  convol_imgs (imgs ):
164-         kx  =  nx .einsum ("...ij,kjl->kil" , K1 , imgs )
165-         kxy  =  nx .einsum ("...ij,klj->kli" , K2 , kx )
166-         return  kxy 
200+     convol_imgs  =  _get_convol_img_fn (nx , width , height , reg , type_as = A )
167201
168202    KU  =  convol_imgs (U )
169203    for  ii  in  range (numItermax ):
@@ -177,24 +211,18 @@ def convol_imgs(imgs):
177211            # log and verbose print 
178212            if  log :
179213                log ["err" ].append (err )
180- 
181214            if  verbose :
182-                 if  ii  %  200  ==  0 :
183-                     print ("{:5s}|{:12s}" .format ("It." , "Err" ) +  "\n "  +  "-"  *  19 )
184-                 print ("{:5d}|{:8e}|" .format (ii , err ))
215+                 _print_report (ii , err )
185216            if  err  <  stopThr :
186217                break 
187218
188219    else :
189220        if  warn :
190-             warnings .warn (
191-                 "Convolutional Sinkhorn did not converge. " 
192-                 "Try a larger number of iterations `numItermax` " 
193-                 "or a larger entropy `reg`." 
194-             )
221+             warnings .warn (_warning_msg )
195222    if  log :
196223        log ["niter" ] =  ii 
197224        log ["U" ] =  U 
225+         log ["V" ] =  V 
198226        return  bar , log 
199227    else :
200228        return  bar 
@@ -218,6 +246,8 @@ def _convolutional_barycenter2d_log(
218246    A  =  list_to_array (A )
219247
220248    nx  =  get_backend (A )
249+     # This error is raised because we are using mutable assignment in the line 
250+     #  `log_KU[k] = ...` which is not allowed in Jax and TF. 
221251    if  nx .__name__  in  ("jax" , "tf" ):
222252        raise  NotImplementedError (
223253            "Log-domain functions are not yet implemented" 
@@ -236,19 +266,7 @@ def _convolutional_barycenter2d_log(
236266
237267    err  =  1 
238268    # build the convolution operator 
239-     # this is equivalent to blurring on horizontal then vertical directions 
240-     t  =  nx .linspace (0 , 1 , width , type_as = A )
241-     [Y , X ] =  nx .meshgrid (t , t )
242-     M1  =  - ((X  -  Y ) **  2 ) /  reg 
243- 
244-     t  =  nx .linspace (0 , 1 , height , type_as = A )
245-     [Y , X ] =  nx .meshgrid (t , t )
246-     M2  =  - ((X  -  Y ) **  2 ) /  reg 
247- 
248-     def  convol_img (log_img ):
249-         log_img  =  nx .logsumexp (M1 [:, :, None ] +  log_img [None ], axis = 1 )
250-         log_img  =  nx .logsumexp (M2 [:, :, None ] +  log_img .T [None ], axis = 1 ).T 
251-         return  log_img 
269+     convol_img  =  _get_convol_img_fn (nx , width , height , reg , type_as = A , log_domain = True )
252270
253271    logA  =  nx .log (A  +  stabThr )
254272    log_KU , G , F  =  nx .zeros ((3 , * logA .shape ), type_as = A )
@@ -265,22 +283,15 @@ def convol_img(log_img):
265283            # log and verbose print 
266284            if  log :
267285                log ["err" ].append (err )
268- 
269286            if  verbose :
270-                 if  ii  %  200  ==  0 :
271-                     print ("{:5s}|{:12s}" .format ("It." , "Err" ) +  "\n "  +  "-"  *  19 )
272-                 print ("{:5d}|{:8e}|" .format (ii , err ))
287+                 _print_report (ii , err )
273288            if  err  <  stopThr :
274289                break 
275290        G  =  log_bar [None , :, :] -  log_KU 
276291
277292    else :
278293        if  warn :
279-             warnings .warn (
280-                 "Convolutional Sinkhorn did not converge. " 
281-                 "Try a larger number of iterations `numItermax` " 
282-                 "or a larger entropy `reg`." 
283-             )
294+             warnings .warn (_warning_msg )
284295    if  log :
285296        log ["niter" ] =  ii 
286297        return  nx .exp (log_bar ), log 
@@ -417,23 +428,11 @@ def _convolutional_barycenter2d_debiased(
417428    bar  /=  width  *  height 
418429    U  =  nx .ones (A .shape , type_as = A )
419430    V  =  nx .ones (A .shape , type_as = A )
420-     c  =  nx .ones (A . shape [ 1 :] , type_as = A )
431+     c  =  nx .ones (( width ,  height ) , type_as = A )
421432    err  =  1 
422433
423434    # build the convolution operator 
424-     # this is equivalent to blurring on horizontal then vertical directions 
425-     t  =  nx .linspace (0 , 1 , width , type_as = A )
426-     [Y , X ] =  nx .meshgrid (t , t )
427-     K1  =  nx .exp (- ((X  -  Y ) **  2 ) /  reg )
428- 
429-     t  =  nx .linspace (0 , 1 , height , type_as = A )
430-     [Y , X ] =  nx .meshgrid (t , t )
431-     K2  =  nx .exp (- ((X  -  Y ) **  2 ) /  reg )
432- 
433-     def  convol_imgs (imgs ):
434-         kx  =  nx .einsum ("...ij,kjl->kil" , K1 , imgs )
435-         kxy  =  nx .einsum ("...ij,klj->kli" , K2 , kx )
436-         return  kxy 
435+     convol_imgs  =  _get_convol_img_fn (nx , width , height , reg , type_as = A )
437436
438437    KU  =  convol_imgs (U )
439438    for  ii  in  range (numItermax ):
@@ -451,26 +450,20 @@ def convol_imgs(imgs):
451450            # log and verbose print 
452451            if  log :
453452                log ["err" ].append (err )
454- 
455453            if  verbose :
456-                 if  ii  %  200  ==  0 :
457-                     print ("{:5s}|{:12s}" .format ("It." , "Err" ) +  "\n "  +  "-"  *  19 )
458-                 print ("{:5d}|{:8e}|" .format (ii , err ))
454+                 _print_report (ii , err )
459455
460456            # debiased Sinkhorn does not converge monotonically 
461457            # guarantee a few iterations are done before stopping 
462458            if  err  <  stopThr  and  ii  >  20 :
463459                break 
464460    else :
465461        if  warn :
466-             warnings .warn (
467-                 "Sinkhorn did not converge. You might want to " 
468-                 "increase the number of iterations `numItermax` " 
469-                 "or the regularization parameter `reg`." 
470-             )
462+             warnings .warn (_warning_msg )
471463    if  log :
472464        log ["niter" ] =  ii 
473465        log ["U" ] =  U 
466+         log ["V" ] =  V 
474467        return  bar , log 
475468    else :
476469        return  bar 
@@ -492,6 +485,8 @@ def _convolutional_barycenter2d_debiased_log(
492485    A  =  list_to_array (A )
493486    n_hists , width , height  =  A .shape 
494487    nx  =  get_backend (A )
488+     # This error is raised because we are using mutable assignment in the line 
489+     #  `log_KU[k] = ...` which is not allowed in Jax and TF. 
495490    if  nx .__name__  in  ("jax" , "tf" ):
496491        raise  NotImplementedError (
497492            "Log-domain functions are not yet implemented" 
@@ -507,19 +502,7 @@ def _convolutional_barycenter2d_debiased_log(
507502
508503    err  =  1 
509504    # build the convolution operator 
510-     # this is equivalent to blurring on horizontal then vertical directions 
511-     t  =  nx .linspace (0 , 1 , width , type_as = A )
512-     [Y , X ] =  nx .meshgrid (t , t )
513-     M1  =  - ((X  -  Y ) **  2 ) /  reg 
514- 
515-     t  =  nx .linspace (0 , 1 , height , type_as = A )
516-     [Y , X ] =  nx .meshgrid (t , t )
517-     M2  =  - ((X  -  Y ) **  2 ) /  reg 
518- 
519-     def  convol_img (log_img ):
520-         log_img  =  nx .logsumexp (M1 [:, :, None ] +  log_img [None ], axis = 1 )
521-         log_img  =  nx .logsumexp (M2 [:, :, None ] +  log_img .T [None ], axis = 1 ).T 
522-         return  log_img 
505+     convol_img  =  _get_convol_img_fn (nx , width , height , reg , type_as = A , log_domain = True )
523506
524507    logA  =  nx .log (A  +  stabThr )
525508    log_bar , c  =  nx .zeros ((2 , width , height ), type_as = A )
@@ -540,22 +523,15 @@ def convol_img(log_img):
540523            # log and verbose print 
541524            if  log :
542525                log ["err" ].append (err )
543- 
544526            if  verbose :
545-                 if  ii  %  200  ==  0 :
546-                     print ("{:5s}|{:12s}" .format ("It." , "Err" ) +  "\n "  +  "-"  *  19 )
547-                 print ("{:5d}|{:8e}|" .format (ii , err ))
527+                 _print_report (ii , err )
548528            if  err  <  stopThr  and  ii  >  20 :
549529                break 
550530        G  =  log_bar [None , :, :] -  log_KU 
551531
552532    else :
553533        if  warn :
554-             warnings .warn (
555-                 "Convolutional Sinkhorn did not converge. " 
556-                 "Try a larger number of iterations `numItermax` " 
557-                 "or a larger entropy `reg`." 
558-             )
534+             warnings .warn (_warning_msg )
559535    if  log :
560536        log ["niter" ] =  ii 
561537        return  nx .exp (log_bar ), log 
0 commit comments