qplayer - infeasible case: fix dimensional typo for double-sided inequalities
abambade authored and jcarpent committed Aug 7, 2024
1 parent 942ec03 commit 471d8ad
Showing 1 changed file with 29 additions and 11 deletions.

bindings/python/proxsuite/torch/qplayer.py
@@ -256,6 +256,7 @@ def backward(ctx, dl_dzhat, dl_dlams, dl_dnus):
 class QPFunctionFn_infeas(Function):
     @staticmethod
     def forward(ctx, Q_, p_, A_, b_, G_, l_, u_):
+        n_in, nz = G_.size()  # true double-sided inequality size
         nBatch = extract_nBatch(Q_, p_, A_, b_, G_, l_, u_)
 
         Q, _ = expandParam(Q_, nBatch, 3)
@@ -276,6 +277,7 @@ def forward(ctx, Q_, p_, A_, b_, G_, l_, u_):

         zhats = torch.empty((nBatch, ctx.nz), dtype=Q.dtype)
         nus = torch.empty((nBatch, ctx.nineq), dtype=Q.dtype)
+        nus_sol = torch.empty((nBatch, n_in), dtype=Q.dtype)  # double-sided inequality multiplier
         lams = (
             torch.empty(nBatch, ctx.neq, dtype=Q.dtype)
             if ctx.neq > 0
@@ -287,7 +289,7 @@ def forward(ctx, Q_, p_, A_, b_, G_, l_, u_):
             else torch.empty()
         )
         slacks = torch.empty((nBatch, ctx.nineq), dtype=Q.dtype)
-        s_i = torch.empty((nBatch, ctx.nineq), dtype=Q.dtype)
+        s_i = torch.empty((nBatch, n_in), dtype=Q.dtype)  # sized to the original double-sided n_in
 
         vector_of_qps = proxsuite.proxqp.dense.BatchQP()
 
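For context on the sizes: this layer stacks the double-sided constraint l <= G x <= u into a one-sided form with twice as many rows (see G = torch.cat((-G, G), axis=1) in the backward hunk below), so the internal multiplier and slack vectors have ctx.nineq == 2 * n_in rows while the user-facing quantities keep the original n_in. A minimal unbatched sketch of that stacking; stack_double_sided is an illustrative helper, not part of proxsuite:

import torch

# Illustrative helper (not the library's API): rewrite the double-sided
# constraint  l <= G x <= u  as the stacked one-sided form  C x <= h:
#   -G x <= -l   (lower bounds)
#    G x <=  u   (upper bounds)
def stack_double_sided(G, l, u):
    C = torch.cat((-G, G), dim=0)  # shape (2 * n_in, nz)
    h = torch.cat((-l, u), dim=0)  # shape (2 * n_in,)
    return C, h

G = torch.tensor([[1.0, 2.0], [3.0, 4.0]])  # n_in = 2, nz = 2
l = torch.tensor([-1.0, -2.0])
u = torch.tensor([1.0, 2.0])
C, h = stack_double_sided(G, l, u)
assert C.shape == (4, 2) and h.shape == (4,)  # nineq == 2 * n_in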
@@ -339,20 +341,23 @@ def forward(ctx, Q_, p_, A_, b_, G_, l_, u_):
             vector_of_qps.get(i).solve()
 
         for i in range(nBatch):
-            si = -h[i] + G[i] @ vector_of_qps.get(i).results.x
             zhats[i] = torch.tensor(vector_of_qps.get(i).results.x)
-            nus[i] = torch.tensor(vector_of_qps.get(i).results.z)
-            slacks[i] = si.clone().detach()
+            if nineq > 0:
+                # re-convert the solution of the stacked QP back to the double-sided inequality QP
+                slack = -h[i] + G[i] @ vector_of_qps.get(i).results.x
+                nus_sol[i] = torch.Tensor(-vector_of_qps.get(i).results.z[:n_in] + vector_of_qps.get(i).results.z[n_in:])  # de-projecting may lose information when the solution is inexact
+                nus[i] = torch.tensor(vector_of_qps.get(i).results.z)
+                slacks[i] = slack.clone().detach()
+                s_i[i] = torch.tensor(-vector_of_qps.get(i).results.si[:n_in] + vector_of_qps.get(i).results.si[n_in:])
             if neq > 0:
                 lams[i] = torch.tensor(vector_of_qps.get(i).results.y)
                 s_e[i] = torch.tensor(vector_of_qps.get(i).results.se)
-            s_i[i] = torch.tensor(vector_of_qps.get(i).results.si)
 
         ctx.lams = lams
         ctx.nus = nus
         ctx.slacks = slacks
         ctx.save_for_backward(zhats, s_e, Q_, p_, G_, l_, u_, A_, b_)
-        return zhats, lams, nus, s_e, s_i
+        return zhats, lams, nus_sol, s_e, s_i
 
     @staticmethod
     def backward(ctx, dl_dzhat, dl_dlams, dl_dnus, dl_ds_e, dl_ds_i):
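The new nus_sol line recombines the stacked multipliers into the multiplier of the original double-sided constraint: the lower-bound rows were negated, so their multipliers enter with a minus sign. A small illustrative sketch of this recombination (recombine_multipliers is a hypothetical helper):

import torch

# Illustrative sketch: z = [z_lower; z_upper] holds the multipliers of the
# stacked rows; the signed multiplier of  l <= G x <= u  is z_upper - z_lower
# (negative entries mean the lower bound is active).
def recombine_multipliers(z, n_in):
    return -z[:n_in] + z[n_in:]

z = torch.tensor([0.3, 0.0, 0.0, 0.7])  # n_in = 2: row 0 lower-active, row 1 upper-active
print(recombine_multipliers(z, 2))      # tensor([-0.3000,  0.7000])

As the inline comment notes, with an inexact solution both halves of a row can be nonzero at once, so this de-projection can lose information.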
@@ -371,6 +376,8 @@ def backward(ctx, dl_dzhat, dl_dlams, dl_dnus, dl_ds_e, dl_ds_i):
         G = torch.cat((-G, G), axis=1)
 
         neq, nineq = ctx.neq, ctx.nineq
+        # true (double-sided) size
+        n_in_sol = int(nineq / 2)
         dx = torch.zeros((nBatch, Q.shape[1]))
         dnu = None
         b_5 = None
@@ -457,15 +464,26 @@ def backward(ctx, dl_dzhat, dl_dlams, dl_dnus, dl_ds_e, dl_ds_i):
             rhs = np.zeros(kkt.shape[0])
             rhs[:dim] = -dl_dzhat[i]
             if dl_dlams != None:
-                rhs[dim : dim + n_eq] = -dl_dlams[i]
+                if n_eq != 0:
+                    rhs[dim : dim + n_eq] = -dl_dlams[i]
+            active_set = None
+            if n_in != 0:
+                active_set = -z_i[:n_in_sol] + z_i[n_in_sol:] >= 0
             if dl_dnus != None:
-                rhs[dim + n_eq : dim + n_eq + n_in] = -dl_dnus[i]
+                if n_in != 0:
+                    # convert dl_dnus to the single-sided convention
+                    # by reconstituting the active set
+                    rhs[dim + n_eq : dim + n_eq + n_in_sol][~active_set] = dl_dnus[i][~active_set]
+                    rhs[dim + n_eq + n_in_sol : dim + n_eq + n_in][active_set] = -dl_dnus[i][active_set]
             if dl_ds_e != None:
                 if dl_ds_e.shape[0] != 0:
                     rhs[dim + n_eq + n_in : dim + 2 * n_eq + n_in] = -dl_ds_e[i]
             if dl_ds_i != None:
                 if dl_ds_i.shape[0] != 0:
-                    rhs[dim + 2 * n_eq + n_in :] = -dl_ds_i[i]
+                    # convert dl_ds_i to the single-sided convention
+                    # by reconstituting the active set
+                    rhs[dim + 2 * n_eq + n_in : dim + 2 * n_eq + n_in + n_in_sol][~active_set] = dl_ds_i[i][~active_set]
+                    rhs[dim + 2 * n_eq + n_in + n_in_sol :][active_set] = -dl_ds_i[i][active_set]
 
             l = np.zeros(0)
             u = np.zeros(0)
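The active-set reconstruction above routes the double-sided gradients dl_dnus and dl_ds_i back into the stacked right-hand side: where the recombined multiplier is nonnegative, the upper row is treated as active; otherwise the lower row, whose sign is flipped because its constraint was negated. A toy numpy sketch of the routing, with made-up values:

import numpy as np

n_in_sol = 2
z_i = np.array([0.3, 0.0, 0.0, 0.7])                # stacked multipliers [lower; upper]
active_set = -z_i[:n_in_sol] + z_i[n_in_sol:] >= 0  # True -> upper bound active

dl_dnus_i = np.array([0.5, -0.2])  # gradient w.r.t. the double-sided multiplier
rhs_block = np.zeros(2 * n_in_sol)
rhs_block[:n_in_sol][~active_set] = dl_dnus_i[~active_set]  # lower rows (negated constraint)
rhs_block[n_in_sol:][active_set] = -dl_dnus_i[active_set]   # upper rows
print(rhs_block)  # [0.5 0.  0.  0.2]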
@@ -562,7 +580,7 @@ def backward(ctx, dl_dzhat, dl_dlams, dl_dnus, dl_ds_e, dl_ds_i):
         if p_e:
             dps = dps.mean(0)
 
-        grads = (dQs, dps, dAs, dbs, dGs[nineq:, :], -dhs[:nineq], dhs[nineq:])
+        grads = (dQs, dps, dAs, dbs, dGs[n_in_sol:, :], -dhs[:n_in_sol], dhs[n_in_sol:])
 
         return grads
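The corrected grads line slices the stacked gradients back to the user-facing sizes: dhs is the gradient w.r.t. the stacked h = [-l; u], so -dhs[:n_in_sol] corresponds to l and dhs[n_in_sol:] to u, while the gradient for G is read from the upper (non-negated) block. A shape-only, unbatched sketch under those assumptions:

import numpy as np

n_in_sol = 2
dGs = np.arange(8.0).reshape(4, 2)    # gradient w.r.t. the stacked [-G; G]
dhs = np.array([0.1, 0.2, 0.3, 0.4])  # gradient w.r.t. the stacked h = [-l; u]

dG = dGs[n_in_sol:, :]  # upper block corresponds to  G x <= u
dl = -dhs[:n_in_sol]    # h's lower block is -l, hence the sign flip
du = dhs[n_in_sol:]
print(dG.shape, dl, du)  # (2, 2) [-0.1 -0.2] [0.3 0.4]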

