@@ -199,14 +199,12 @@ def free_support_barycenter(
199
199
measures_weights : list of N (k_i,) array-like
200
200
Numpy arrays where each numpy array has :math:`k_i` non-negatives values summing to one
201
201
representing the weights of each discrete input measure
202
-
203
202
X_init : (k,d) array-like
204
203
Initialization of the support locations (on `k` atoms) of the barycenter
205
204
b : (k,) array-like
206
205
Initialization of the weights of the barycenter (non-negatives, sum to 1)
207
206
weights : (N,) array-like
208
207
Initialization of the coefficients of the barycenter (non-negatives, sum to 1)
209
-
210
208
numItermax : int, optional
211
209
Max number of iterations
212
210
stopThr : float, optional
@@ -219,13 +217,11 @@ def free_support_barycenter(
219
217
If compiled with OpenMP, chooses the number of threads to parallelize.
220
218
"max" selects the highest number possible.
221
219
222
-
223
220
Returns
224
221
-------
225
222
X : (k,d) array-like
226
223
Support locations (on k atoms) of the barycenter
227
224
228
-
229
225
.. _references-free-support-barycenter:
230
226
References
231
227
----------
@@ -428,20 +424,20 @@ def generalized_free_support_barycenter(
428
424
return Y
429
425
430
426
431
- class StoppingCriterionReached (Exception ):
432
- pass
433
-
434
-
435
427
def free_support_barycenter_generic_costs (
436
428
measure_locations ,
437
429
measure_weights ,
438
430
X_init ,
439
431
cost_list ,
440
- B ,
432
+ ground_bary = None ,
441
433
a = None ,
442
434
numItermax = 100 ,
443
435
stopThr = 1e-5 ,
444
436
log = False ,
437
+ ground_bary_lr = 1e-2 ,
438
+ ground_bary_numItermax = 100 ,
439
+ ground_bary_stopThr = 1e-5 ,
440
+ ground_bary_solver = "SGD" ,
445
441
):
446
442
r"""
447
443
Solves the OT barycenter problem for generic costs using the fixed point
@@ -507,14 +503,15 @@ def free_support_barycenter_generic_costs(
507
503
List of K arrays of measure weights, each of shape (m_k).
508
504
X_init : array-like
509
505
Array of shape (n, d) representing initial barycenter points.
510
- cost_list : list of callable
506
+ cost_list : list of callable or callable
511
507
List of K cost functions :math:`c_k: \mathbb{R}^{n\times
512
508
d}\times\mathbb{R}^{m_k\times d_k} \rightarrow \mathbb{R}_+^{n\times
513
- m_k}`.
514
- B : callable
509
+ m_k}`. If cost_list is a single callable, the same cost is used K times.
510
+ ground_bary : callable or None, optional
515
511
Function List(array(n, d_k)) -> array(n, d) accepting a list of K arrays
516
512
of shape (n\times d_K), computing the ground barycenters (broadcasted
517
- over n).
513
+ over n). If not provided, done with Adam on PyTorch (requires PyTorch
514
+ backend)
518
515
a : array-like, optional
519
516
Array of shape (n,) representing weights of the barycenter
520
517
measure.Defaults to uniform.
@@ -524,6 +521,16 @@ def free_support_barycenter_generic_costs(
524
521
If the iterations move less than this, terminate (default is 1e-5).
525
522
log : bool, optional
526
523
Whether to return the log dictionary (default is False).
524
+ ground_bary_lr : float, optional
525
+ Learning rate for the ground barycenter solver (if auto is used).
526
+ ground_bary_numItermax : int, optional
527
+ Maximum number of iterations for the ground barycenter solver (if auto
528
+ is used).
529
+ ground_bary_stopThr : float, optional
530
+ Stop threshold for the ground barycenter solver (if auto is used).
531
+ ground_bary_solver : str, optional
532
+ Solver for auto ground bary solver (torch SGD or Adam). Default is
533
+ "SGD".
527
534
528
535
Returns
529
536
-------
@@ -549,49 +556,85 @@ def free_support_barycenter_generic_costs(
549
556
See Also
550
557
--------
551
558
ot.lp.free_support_barycenter : Free support solver for the case where
552
- :math:`c_k(x,y) = \|x-y\|_2^2`. ot.lp.generalized_free_support_barycenter :
553
- Free support solver for the case where :math:`c_k(x,y) = \|P_kx-y\|_2^2`
554
- with :math:`P_k` linear.
559
+ :math:`c_k(x,y) = \lambda_k\ |x-y\|_2^2`.
560
+ ot.lp.generalized_free_support_barycenter : Free support solver for the case
561
+ where :math:`c_k(x,y) = \|P_kx-y\|_2^2` with :math:`P_k` linear.
555
562
"""
556
563
nx = get_backend (X_init , measure_locations [0 ])
557
564
K = len (measure_locations )
558
565
n = X_init .shape [0 ]
559
566
if a is None :
560
567
a = nx .ones (n , type_as = X_init ) / n
568
+ if callable (cost_list ): # use the given cost for all K pairs
569
+ cost_list = [cost_list ] * K
570
+ auto_ground_bary = False
571
+
572
+ if ground_bary is None :
573
+ auto_ground_bary = True
574
+ assert str (nx ) == "torch" , (
575
+ f"Backend { str (nx )} is not compatible with ground_bary=None, it"
576
+ "must be provided if not using PyTorch backend"
577
+ )
578
+ try :
579
+ import torch
580
+ from torch .optim import Adam , SGD
581
+
582
+ def ground_bary (y , x_init ):
583
+ x = x_init .clone ().detach ().requires_grad_ (True )
584
+ solver = Adam if ground_bary_solver == "Adam" else SGD
585
+ opt = solver ([x ], lr = ground_bary_lr )
586
+ for _ in range (ground_bary_numItermax ):
587
+ x_prev = x .data .clone ()
588
+ opt .zero_grad ()
589
+ # inefficient cost computation but compatible
590
+ # with the choice of cost_list[k] giving the cost matrix
591
+ loss = torch .sum (
592
+ torch .stack (
593
+ [torch .diag (cost_list [k ](x , y [k ])) for k in range (K )]
594
+ )
595
+ )
596
+ loss .backward ()
597
+ opt .step ()
598
+ diff = torch .sum ((x .data - x_prev ) ** 2 )
599
+ if diff < ground_bary_stopThr :
600
+ break
601
+ return x .detach ()
602
+
603
+ except ImportError :
604
+ raise ImportError ("PyTorch is required to use ground_bary=None" )
605
+
561
606
X_list = [X_init ] if log else [] # store the iterations
562
607
X = X_init
563
608
dX_list = [] # store the displacement squared norms
564
- exit_status = "Unknown"
565
-
566
- try :
567
- for _ in range (numItermax ):
568
- pi_list = [ # compute the pairwise transport plans
569
- emd (a , measure_weights [k ], cost_list [k ](X , measure_locations [k ]))
570
- for k in range (K )
571
- ]
572
- Y_perm = []
573
- for k in range (K ): # compute barycentric projections
574
- Y_perm .append (n * pi_list [k ] @ measure_locations [k ])
575
- X_next = B (Y_perm )
576
-
577
- if log :
578
- X_list .append (X_next )
609
+ exit_status = "Max iterations reached"
610
+
611
+ for _ in range (numItermax ):
612
+ pi_list = [ # compute the pairwise transport plans
613
+ emd (a , measure_weights [k ], cost_list [k ](X , measure_locations [k ]))
614
+ for k in range (K )
615
+ ]
616
+ Y_perm = []
617
+ for k in range (K ): # compute barycentric projections
618
+ Y_perm .append (n * pi_list [k ] @ measure_locations [k ])
619
+ if auto_ground_bary : # use previous position as initialization
620
+ X_next = ground_bary (Y_perm , X )
621
+ else :
622
+ X_next = ground_bary (Y_perm )
579
623
580
- # stationary criterion: move less than the threshold
581
- dX = nx .sum ((X - X_next ) ** 2 )
582
- X = X_next
624
+ if log :
625
+ X_list .append (X_next )
583
626
584
- if log :
585
- dX_list .append (dX )
627
+ # stationary criterion: move less than the threshold
628
+ dX = nx .sum ((X - X_next ) ** 2 )
629
+ X = X_next
586
630
587
- if dX < stopThr :
588
- exit_status = "Stationary Point"
589
- raise StoppingCriterionReached
631
+ if log :
632
+ dX_list .append (dX )
590
633
591
- exit_status = "Max iterations reached"
592
- raise StoppingCriterionReached
634
+ if dX < stopThr :
635
+ exit_status = "Stationary Point"
636
+ break
593
637
594
- except StoppingCriterionReached :
595
- if log :
596
- return X , {"X_list" : X_list , "exit_status" : exit_status , "dX_list" : dX_list }
597
- return X
638
+ if log :
639
+ return X , {"X_list" : X_list , "exit_status" : exit_status , "dX_list" : dX_list }
640
+ return X
0 commit comments