From 70ebaa6a27a1e95a9317fd71846808a6a65bdd75 Mon Sep 17 00:00:00 2001
From: abambade
Date: Mon, 29 Jul 2024 16:26:55 +0200
Subject: [PATCH] qplayer - infeasible case: fix dimensional typo for double-sided inequalities

---
 bindings/python/proxsuite/torch/qplayer.py | 40 ++++++++++++++++------
 1 file changed, 29 insertions(+), 11 deletions(-)

diff --git a/bindings/python/proxsuite/torch/qplayer.py b/bindings/python/proxsuite/torch/qplayer.py
index 1c9955946..c7076eaa5 100644
--- a/bindings/python/proxsuite/torch/qplayer.py
+++ b/bindings/python/proxsuite/torch/qplayer.py
@@ -256,6 +256,7 @@ def backward(ctx, dl_dzhat, dl_dlams, dl_dnus):
 class QPFunctionFn_infeas(Function):
     @staticmethod
     def forward(ctx, Q_, p_, A_, b_, G_, l_, u_):
+        n_in, nz = G_.size()  # number of original double-sided inequalities
         nBatch = extract_nBatch(Q_, p_, A_, b_, G_, l_, u_)
 
         Q, _ = expandParam(Q_, nBatch, 3)
@@ -276,6 +277,7 @@ def forward(ctx, Q_, p_, A_, b_, G_, l_, u_):
 
         zhats = torch.empty((nBatch, ctx.nz), dtype=Q.dtype)
         nus = torch.empty((nBatch, ctx.nineq), dtype=Q.dtype)
+        nus_sol = torch.empty((nBatch, n_in), dtype=Q.dtype)  # multipliers of the original double-sided inequalities
         lams = (
             torch.empty(nBatch, ctx.neq, dtype=Q.dtype)
             if ctx.neq > 0
@@ -287,7 +289,7 @@ def forward(ctx, Q_, p_, A_, b_, G_, l_, u_):
             else torch.empty()
         )
         slacks = torch.empty((nBatch, ctx.nineq), dtype=Q.dtype)
-        s_i = torch.empty((nBatch, ctx.nineq), dtype=Q.dtype)
+        s_i = torch.empty((nBatch, n_in), dtype=Q.dtype)  # sized for the original double-sided inequalities (n_in)
 
         vector_of_qps = proxsuite.proxqp.dense.BatchQP()
@@ -339,20 +341,23 @@ def forward(ctx, Q_, p_, A_, b_, G_, l_, u_):
             vector_of_qps.get(i).solve()
 
         for i in range(nBatch):
-            si = -h[i] + G[i] @ vector_of_qps.get(i).results.x
             zhats[i] = torch.tensor(vector_of_qps.get(i).results.x)
-            nus[i] = torch.tensor(vector_of_qps.get(i).results.z)
-            slacks[i] = si.clone().detach()
+            if nineq > 0:
+                # convert the solution back to the double-sided inequality formulation
+                slack = -h[i] + G[i] @ vector_of_qps.get(i).results.x
+                nus_sol[i] = torch.tensor(-vector_of_qps.get(i).results.z[:n_in] + vector_of_qps.get(i).results.z[n_in:])  # de-projecting may lose information when the solution is inexact
+                nus[i] = torch.tensor(vector_of_qps.get(i).results.z)
+                slacks[i] = slack.clone().detach()
+                s_i[i] = torch.tensor(-vector_of_qps.get(i).results.si[:n_in] + vector_of_qps.get(i).results.si[n_in:])
             if neq > 0:
                 lams[i] = torch.tensor(vector_of_qps.get(i).results.y)
                 s_e[i] = torch.tensor(vector_of_qps.get(i).results.se)
-                s_i[i] = torch.tensor(vector_of_qps.get(i).results.si)
-
+
         ctx.lams = lams
         ctx.nus = nus
         ctx.slacks = slacks
         ctx.save_for_backward(zhats, s_e, Q_, p_, G_, l_, u_, A_, b_)
-        return zhats, lams, nus, s_e, s_i
+        return zhats, lams, nus_sol, s_e, s_i
 
     @staticmethod
     def backward(ctx, dl_dzhat, dl_dlams, dl_dnus, dl_ds_e, dl_ds_i):
@@ -371,6 +376,8 @@ def backward(ctx, dl_dzhat, dl_dlams, dl_dnus, dl_ds_e, dl_ds_i):
         G = torch.cat((-G, G), axis=1)
 
         neq, nineq = ctx.neq, ctx.nineq
+        # size of the original double-sided inequality problem
+        n_in_sol = nineq // 2
         dx = torch.zeros((nBatch, Q.shape[1]))
         dnu = None
         b_5 = None
@@ -457,15 +464,26 @@ def backward(ctx, dl_dzhat, dl_dlams, dl_dnus, dl_ds_e, dl_ds_i):
             rhs = np.zeros(kkt.shape[0])
             rhs[:dim] = -dl_dzhat[i]
             if dl_dlams != None:
-                rhs[dim : dim + n_eq] = -dl_dlams[i]
+                if n_eq != 0:
+                    rhs[dim : dim + n_eq] = -dl_dlams[i]
+            active_set = None
+            if n_in != 0:
+                active_set = -z_i[:n_in_sol] + z_i[n_in_sol:] >= 0
             if dl_dnus != None:
-                rhs[dim + n_eq : dim + n_eq + n_in] = -dl_dnus[i]
+                if n_in != 0:
+                    # convert dl_dnus to the single-sided formulation
+                    # by reconstructing the active set
+                    rhs[dim + n_eq : dim + n_eq + n_in_sol][~active_set] = dl_dnus[i][~active_set]
+                    rhs[dim + n_eq + n_in_sol : dim + n_eq + n_in][active_set] = -dl_dnus[i][active_set]
             if dl_ds_e != None:
                 if dl_ds_e.shape[0] != 0:
                     rhs[dim + n_eq + n_in : dim + 2 * n_eq + n_in] = -dl_ds_e[i]
             if dl_ds_i != None:
                 if dl_ds_i.shape[0] != 0:
-                    rhs[dim + 2 * n_eq + n_in :] = -dl_ds_i[i]
+                    # convert dl_ds_i to the single-sided formulation
+                    # by reconstructing the active set
+                    rhs[dim + 2 * n_eq + n_in : dim + 2 * n_eq + n_in + n_in_sol][~active_set] = dl_ds_i[i][~active_set]
+                    rhs[dim + 2 * n_eq + n_in + n_in_sol :][active_set] = -dl_ds_i[i][active_set]
 
             l = np.zeros(0)
             u = np.zeros(0)
@@ -562,7 +580,7 @@ def backward(ctx, dl_dzhat, dl_dlams, dl_dnus, dl_ds_e, dl_ds_i):
         if p_e:
             dps = dps.mean(0)
 
-        grads = (dQs, dps, dAs, dbs, dGs[nineq:, :], -dhs[:nineq], dhs[nineq:])
+        grads = (dQs, dps, dAs, dbs, dGs[n_in_sol:, :], -dhs[:n_in_sol], dhs[n_in_sol:])
 
         return grads
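
Note: the forward-pass change above hinges on how `QPFunctionFn_infeas` rewrites the double-sided problem l <= Gx <= u as the stacked single-sided problem [-G; G]x <= [-l; u], so the internal solver returns 2*n_in multipliers z. The patch de-projects them back to one signed multiplier per original constraint via -z[:n_in] + z[n_in:]. A minimal sketch of that mapping, assuming NumPy arrays; the helper name `deproject_multipliers` is illustrative, not part of the proxsuite API.

```python
import numpy as np

def deproject_multipliers(z_stacked: np.ndarray, n_in: int) -> np.ndarray:
    """Map the 2*n_in multipliers of the stacked single-sided QP
    [-G; G] x <= [-l; u] back to n_in signed multipliers for the
    original double-sided constraints l <= G x <= u: a negative entry
    marks an active lower bound, a positive one an active upper bound."""
    return -z_stacked[:n_in] + z_stacked[n_in:]

# Toy check: lower bound of constraint 0 and upper bound of constraint 2 active.
z = np.array([0.7, 0.0, 0.0,   # multipliers of -G x <= -l (lower bounds)
              0.0, 0.0, 1.3])  # multipliers of  G x <=  u (upper bounds)
print(deproject_multipliers(z, 3))  # [-0.7  0.   1.3]
```

As the inline comment in the patch warns, this de-projection is lossy for inexact solutions: if both halves of a multiplier pair are slightly positive, the two contributions partially cancel.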
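
In the backward pass the patch goes the other way: the incoming gradient dl_dnus has one entry per original double-sided constraint, while the KKT right-hand side is laid out for the stacked system, so the gradient must be scattered into the lower-bound or upper-bound rows depending on which side is active (read off the sign of the de-projected multipliers z_i). A hedged sketch of that scatter on a standalone rhs vector, mirroring the patch's sign convention; the layout arguments (dim, n_eq, n_in_sol) and the helper itself are assumptions for illustration only.

```python
import numpy as np

def scatter_double_sided_grad(rhs, dl_dnus, z_i, dim, n_eq, n_in_sol):
    """Scatter a double-sided gradient (length n_in_sol) into the rhs
    slots of the stacked single-sided KKT system (n_in = 2 * n_in_sol).
    active_set[k] is True when the de-projected multiplier of constraint k
    is nonnegative, i.e. the upper bound is the (weakly) active side."""
    n_in = 2 * n_in_sol
    active_set = -z_i[:n_in_sol] + z_i[n_in_sol:] >= 0
    lower = rhs[dim + n_eq : dim + n_eq + n_in_sol]         # view on lower-bound rows
    upper = rhs[dim + n_eq + n_in_sol : dim + n_eq + n_in]  # view on upper-bound rows
    lower[~active_set] = dl_dnus[~active_set]  # lower side active: write to lower rows
    upper[active_set] = -dl_dnus[active_set]   # upper side active: write (negated) to upper rows
    return rhs
```

The patch applies this same masking a second time for dl_ds_i, at the slack block of the rhs starting at dim + 2 * n_eq + n_in, which is why the active-set logic appears twice in the diff.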