@@ -31,6 +31,7 @@ def nearest_brenier_potential_fit(
31
31
its = 100 ,
32
32
log = False ,
33
33
init_method = "barycentric" ,
34
+ solver = None ,
34
35
):
35
36
r"""
36
37
Computes optimal values and gradients at X for a strongly convex potential :math:`\varphi` with Lipschitz gradients
@@ -87,7 +88,10 @@ def nearest_brenier_potential_fit(
87
88
log : bool, optional
88
89
record log if true
89
90
init_method : str, optional
90
- 'target' initialises G=V, 'barycentric' initialises at the image of X by the barycentric projection
91
+ 'target' initialises G=V, 'barycentric' initialises at the image of X by
92
+ the barycentric projection
93
+ solver : str, optional
94
+ The CVXPY solver to use
91
95
92
96
Returns
93
97
-------
@@ -173,7 +177,7 @@ def nearest_brenier_potential_fit(
173
177
- c3 * (G [j ] - G [i ]).T @ (X [j ] - X [i ])
174
178
]
175
179
problem = cvx .Problem (objective , constraints )
176
- problem .solve (solver = cvx . ECOS )
180
+ problem .solve (solver = solver )
177
181
phi_val , G_val = phi .value , G .value
178
182
it_log_dict = {
179
183
"solve_time" : problem .solver_stats .solve_time ,
@@ -231,6 +235,7 @@ def nearest_brenier_potential_predict_bounds(
231
235
strongly_convex_constant = 0.6 ,
232
236
gradient_lipschitz_constant = 1.4 ,
233
237
log = False ,
238
+ solver = None ,
234
239
):
235
240
r"""
236
241
Compute the values of the lower and upper bounding potentials at the input points Y, using the potential optimal
@@ -290,6 +295,8 @@ def nearest_brenier_potential_predict_bounds(
290
295
constant for the Lipschitz property of the input gradient G, defaults to 1.4
291
296
log : bool, optional
292
297
record log if true
298
+ solver : str, optional
299
+ The CVXPY solver to use
293
300
294
301
Returns
295
302
-------
@@ -368,7 +375,7 @@ def nearest_brenier_potential_predict_bounds(
368
375
- c3 * (G [j ] - G_l_y ).T @ (X [j ] - Y [y_idx ])
369
376
]
370
377
problem = cvx .Problem (objective , constraints )
371
- problem .solve (solver = cvx . ECOS )
378
+ problem .solve (solver = solver )
372
379
phi_lu [0 , y_idx ] = phi_l_y .value
373
380
G_lu [0 , y_idx ] = G_l_y .value
374
381
if log :
@@ -395,7 +402,7 @@ def nearest_brenier_potential_predict_bounds(
395
402
- c3 * (G_u_y - G [i ]).T @ (Y [y_idx ] - X [i ])
396
403
]
397
404
problem = cvx .Problem (objective , constraints )
398
- problem .solve (solver = cvx . ECOS )
405
+ problem .solve (solver = solver )
399
406
phi_lu [1 , y_idx ] = phi_u_y .value
400
407
G_lu [1 , y_idx ] = G_u_y .value
401
408
if log :
0 commit comments