Merge pull request #350 from jcarpent/devel
Sync submodule CMake
jcarpent authored Aug 27, 2024
2 parents 6d708e6 + c66e666 commit 52d0095
Showing 2 changed files with 47 additions and 21 deletions.
66 changes: 46 additions & 20 deletions bindings/python/proxsuite/torch/qplayer.py
@@ -256,7 +256,7 @@ def backward(ctx, dl_dzhat, dl_dlams, dl_dnus):
class QPFunctionFn_infeas(Function):
@staticmethod
def forward(ctx, Q_, p_, A_, b_, G_, l_, u_):
- n_in, nz = G_.size() # true double-sided inequality size
+ n_in, nz = G_.size()  # true double-sided inequality size
nBatch = extract_nBatch(Q_, p_, A_, b_, G_, l_, u_)

Q, _ = expandParam(Q_, nBatch, 3)
@@ -277,7 +277,9 @@ def forward(ctx, Q_, p_, A_, b_, G_, l_, u_):

zhats = torch.empty((nBatch, ctx.nz), dtype=Q.dtype)
nus = torch.empty((nBatch, ctx.nineq), dtype=Q.dtype)
- nus_sol = torch.empty((nBatch, n_in), dtype=Q.dtype) # double-sided inequality multiplier
+ nus_sol = torch.empty(
+     (nBatch, n_in), dtype=Q.dtype
+ )  # double-sided inequality multiplier
lams = (
torch.empty(nBatch, ctx.neq, dtype=Q.dtype)
if ctx.neq > 0
@@ -289,7 +291,9 @@ def forward(ctx, Q_, p_, A_, b_, G_, l_, u_):
else torch.empty()
)
slacks = torch.empty((nBatch, ctx.nineq), dtype=Q.dtype)
- s_i = torch.empty((nBatch, n_in), dtype=Q.dtype) # this one is of size the one of the original n_in
+ s_i = torch.empty(
+     (nBatch, n_in), dtype=Q.dtype
+ )  # this one is of size the one of the original n_in

vector_of_qps = proxsuite.proxqp.dense.BatchQP()

@@ -342,17 +346,23 @@ def forward(ctx, Q_, p_, A_, b_, G_, l_, u_):

for i in range(nBatch):
zhats[i] = torch.tensor(vector_of_qps.get(i).results.x)
- if nineq>0:
+ if nineq > 0:
# we re-convert the solution to a double sided inequality QP
slack = -h[i] + G[i] @ vector_of_qps.get(i).results.x
- nus_sol[i] = torch.Tensor(-vector_of_qps.get(i).results.z[:n_in]+vector_of_qps.get(i).results.z[n_in:]) # de-projecting this one may provoke loss of information when using inexact solution
+ nus_sol[i] = torch.Tensor(
+     -vector_of_qps.get(i).results.z[:n_in]
+     + vector_of_qps.get(i).results.z[n_in:]
+ )  # de-projecting this one may provoke loss of information when using inexact solution
nus[i] = torch.tensor(vector_of_qps.get(i).results.z)
slacks[i] = slack.clone().detach()
- s_i[i] = torch.tensor(-vector_of_qps.get(i).results.si[:n_in]+vector_of_qps.get(i).results.si[n_in:])
+ s_i[i] = torch.tensor(
+     -vector_of_qps.get(i).results.si[:n_in]
+     + vector_of_qps.get(i).results.si[n_in:]
+ )
if neq > 0:
lams[i] = torch.tensor(vector_of_qps.get(i).results.y)
s_e[i] = torch.tensor(vector_of_qps.get(i).results.se)

ctx.lams = lams
ctx.nus = nus
ctx.slacks = slacks
@@ -377,7 +387,7 @@ def backward(ctx, dl_dzhat, dl_dlams, dl_dnus, dl_ds_e, dl_ds_i):

neq, nineq = ctx.neq, ctx.nineq
# true size
- n_in_sol = int(nineq/2)
+ n_in_sol = int(nineq / 2)
dx = torch.zeros((nBatch, Q.shape[1]))
dnu = None
b_5 = None
@@ -464,26 +474,34 @@ def backward(ctx, dl_dzhat, dl_dlams, dl_dnus, dl_ds_e, dl_ds_i):
rhs = np.zeros(kkt.shape[0])
rhs[:dim] = -dl_dzhat[i]
if dl_dlams != None:
- if n_eq!= 0:
+ if n_eq != 0:
rhs[dim : dim + n_eq] = -dl_dlams[i]
- active_set = None
- if n_in!=0:
- active_set = -z_i[:n_in_sol]+z_i[n_in_sol:] >= 0
+ active_set = None
+ if n_in != 0:
+ active_set = -z_i[:n_in_sol] + z_i[n_in_sol:] >= 0
if dl_dnus != None:
- if n_in !=0:
+ if n_in != 0:
# we must convert dl_dnus to a uni sided version
- # to do so we reconstitute the active set
- rhs[dim + n_eq : dim + n_eq + n_in_sol][~active_set] = dl_dnus[i][~active_set]
- rhs[dim + n_eq + n_in_sol: dim + n_eq + n_in][active_set] = -dl_dnus[i][active_set]
+ # to do so we reconstitute the active set
+ rhs[dim + n_eq : dim + n_eq + n_in_sol][~active_set] = dl_dnus[
+     i
+ ][~active_set]
+ rhs[dim + n_eq + n_in_sol : dim + n_eq + n_in][active_set] = (
+     -dl_dnus[i][active_set]
+ )
if dl_ds_e != None:
if dl_ds_e.shape[0] != 0:
rhs[dim + n_eq + n_in : dim + 2 * n_eq + n_in] = -dl_ds_e[i]
if dl_ds_i != None:
if dl_ds_i.shape[0] != 0:
# we must convert dl_dnus to a uni sided version
- # to do so we reconstitute the active set
- rhs[dim + 2 * n_eq + n_in : dim + 2 * n_eq + n_in + n_in_sol][~active_set] = dl_ds_i[i][~active_set]
- rhs[dim + 2 * n_eq + n_in + n_in_sol:][active_set] = -dl_ds_i[i][active_set]
+ # to do so we reconstitute the active set
+ rhs[dim + 2 * n_eq + n_in : dim + 2 * n_eq + n_in + n_in_sol][
+     ~active_set
+ ] = dl_ds_i[i][~active_set]
+ rhs[dim + 2 * n_eq + n_in + n_in_sol :][active_set] = -dl_ds_i[
+     i
+ ][active_set]

l = np.zeros(0)
u = np.zeros(0)
@@ -580,7 +598,15 @@ def backward(ctx, dl_dzhat, dl_dlams, dl_dnus, dl_ds_e, dl_ds_i):
if p_e:
dps = dps.mean(0)

- grads = (dQs, dps, dAs, dbs, dGs[n_in_sol:, :], -dhs[:n_in_sol], dhs[n_in_sol:])
+ grads = (
+     dQs,
+     dps,
+     dAs,
+     dbs,
+     dGs[n_in_sol:, :],
+     -dhs[:n_in_sol],
+     dhs[n_in_sol:],
+ )

return grads


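Note (illustration, not part of this commit): the comments in the hunks above refer to QPFunctionFn_infeas converting the double-sided inequality l <= G x <= u into a stacked one-sided problem (presumably [-G; G] x <= [-l; u]), so the solver's multiplier z and slack si each carry two blocks of size n_in that are folded back into double-sided quantities via -v[:n_in] + v[n_in:], which is exactly the expression being reformatted here. A minimal standalone sketch of that fold, with hypothetical names and an assumed stacking order:

import numpy as np

# Assumed convention: v_stacked = [v_lower; v_upper], each block of length n_in,
# coming from the stacked one-sided constraints for the lower and upper bounds.
def fold_double_sided(v_stacked, n_in):
    # Recover the double-sided quantity as upper block minus lower block,
    # matching the -v[:n_in] + v[n_in:] pattern used in the diff above.
    return -v_stacked[:n_in] + v_stacked[n_in:]

# Example with n_in = 2: lower bound active on row 0, upper bound active on row 1.
z_stacked = np.array([0.7, 0.0, 0.0, 1.3])
print(fold_double_sided(z_stacked, n_in=2))  # -> [-0.7  1.3]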