Merge pull request #308 from fabinsch/qplayer-tensor
Replace torch.Tensor and fix handling of batched ineq. constraints in QPlayer
jcarpent authored Mar 14, 2024
2 parents 64c17d4 + 41fa052 commit 0398b24
Showing 4 changed files with 57 additions and 71 deletions.
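The core of the change is the allocation pattern below — a minimal sketch that is not part of the diff, using standard PyTorch calls: `torch.Tensor(n, m)` allocates an uninitialized float32 tensor whose dtype then has to be fixed with `.type_as(...)`, whereas `torch.empty` takes the dtype explicitly and `torch.tensor` copies existing data (e.g. a solver result held in NumPy).

```python
# Minimal sketch (not part of the diff) of the allocation pattern this commit adopts.
import torch
import numpy as np

Q = torch.zeros(3, 3, dtype=torch.float64)

# Old style: uninitialized float32 tensor, dtype fixed afterwards by a cast.
zhats_old = torch.Tensor(4, 3).type_as(Q)

# New style: allocate once with the dtype stated explicitly.
zhats_new = torch.empty((4, 3), dtype=Q.dtype)

# Copying solver output (here a NumPy array) into a tensor: torch.tensor copies the data.
x_np = np.ones(3)
zhats_new[0] = torch.tensor(x_np)

assert zhats_old.dtype == zhats_new.dtype == torch.float64
```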
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
## [Unreleased]

### What's Changed
* Change from torch.Tensor to torch.empty or torch.tensor and specify type explicitly ([#308](https://github.com/Simple-Robotics/proxsuite/pull/308))
* Fix handling of batches of inequality constraints in `QPFunctionFn_infeas`. The derivations in qplayer were done for single-sided constraints (hence the concatenation), but the expansion of the batch-size dimension was not handled correctly ([#308](https://github.com/Simple-Robotics/proxsuite/pull/308))
* Switch from self-hosted runner for macos-14-ARM to runner from github ([#306](https://github.com/Simple-Robotics/proxsuite/pull/306))

## [0.6.4] - 2024-03-01
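A hedged illustration of the batching fix described in the second CHANGELOG entry above (the tensor names are illustrative, not the library's API): folding double-sided constraints l <= Gx <= u into the single-sided form [-G; G]x <= [-l; u] has to concatenate along the per-sample row axis (dim 1) once the inputs carry a batch dimension; concatenating along dim 0 doubles the batch instead.

```python
import torch

nBatch, nineq, nz = 2, 3, 4
G = torch.randn(nBatch, nineq, nz, dtype=torch.float64)
l = -torch.ones(nBatch, nineq, dtype=torch.float64)
u = torch.ones(nBatch, nineq, dtype=torch.float64)

# Correct for batched inputs: stack the two one-sided blocks per batch element.
G_single = torch.cat((-G, G), dim=1)  # (nBatch, 2 * nineq, nz)
h_single = torch.cat((-l, u), dim=1)  # (nBatch, 2 * nineq)

# Concatenating along dim 0 doubles the batch size instead of the constraint count.
G_wrong = torch.cat((-G, G), dim=0)   # (2 * nBatch, nineq, nz)

assert G_single.shape == (nBatch, 2 * nineq, nz)
assert h_single.shape == (nBatch, 2 * nineq)
assert G_wrong.shape == (2 * nBatch, nineq, nz)
```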
62 changes: 34 additions & 28 deletions bindings/python/proxsuite/torch/qplayer.py
@@ -6,7 +6,7 @@
import proxsuite

from torch.autograd import Function
from .utils import expandParam, extract_nBatch, extract_nBatch_double_sided, bger
from .utils import expandParam, extract_nBatch, bger


def QPFunction(
@@ -92,7 +92,7 @@ def QPFunction(
class QPFunctionFn(Function):
@staticmethod
def forward(ctx, Q_, p_, A_, b_, G_, l_, u_):
nBatch = extract_nBatch_double_sided(Q_, p_, A_, b_, G_, l_, u_)
nBatch = extract_nBatch(Q_, p_, A_, b_, G_, l_, u_)
Q, _ = expandParam(Q_, nBatch, 3)
p, _ = expandParam(p_, nBatch, 2)
G, _ = expandParam(G_, nBatch, 3)
@@ -114,9 +114,9 @@ def forward(ctx, Q_, p_, A_, b_, G_, l_, u_):
if ctx.cpu is not None:
ctx.cpu = max(1, int(ctx.cpu / 2))

zhats = torch.empty((nBatch, ctx.nz)).type_as(Q)
lams = torch.empty((nBatch, ctx.neq)).type_as(Q)
nus = torch.empty((nBatch, ctx.nineq)).type_as(Q)
zhats = torch.empty((nBatch, ctx.nz), dtype=Q.dtype)
lams = torch.empty((nBatch, ctx.neq), dtype=Q.dtype)
nus = torch.empty((nBatch, ctx.nineq), dtype=Q.dtype)

for i in range(nBatch):
qp = ctx.vector_of_qps.init_qp_in_place(ctx.nz, ctx.neq, ctx.nineq)
@@ -255,37 +255,39 @@ def backward(ctx, dl_dzhat, dl_dlams, dl_dnus):

class QPFunctionFn_infeas(Function):
@staticmethod
# def forward(ctx, Q_, p_, G_, h_, A_, b_):
def forward(ctx, Q_, p_, A_, b_, G_, l_, u_):
h_ = torch.cat((-l_, u_), 0)
G_ = torch.cat((-G_, G_), 0)
nBatch = extract_nBatch(Q_, p_, G_, h_, A_, b_)
nBatch = extract_nBatch(Q_, p_, A_, b_, G_, l_, u_)

Q, _ = expandParam(Q_, nBatch, 3)
p, _ = expandParam(p_, nBatch, 2)
G, _ = expandParam(G_, nBatch, 3)
h, _ = expandParam(h_, nBatch, 2)
u, _ = expandParam(u_, nBatch, 2)
l, _ = expandParam(l_, nBatch, 2)
A, _ = expandParam(A_, nBatch, 3)
b, _ = expandParam(b_, nBatch, 2)

h = torch.cat((-l, u), axis=1) # single-sided inequality
G = torch.cat((-G, G), axis=1) # single-sided inequality

_, nineq, nz = G.size()
neq = A.size(1) if A.nelement() > 0 else 0
assert neq > 0 or nineq > 0
ctx.neq, ctx.nineq, ctx.nz = neq, nineq, nz

zhats = torch.Tensor(nBatch, ctx.nz).type_as(Q)
nus = torch.Tensor(nBatch, ctx.nineq).type_as(Q)
zhats = torch.empty((nBatch, ctx.nz), dtype=Q.dtype)
nus = torch.empty((nBatch, ctx.nineq), dtype=Q.dtype)
lams = (
torch.Tensor(nBatch, ctx.neq).type_as(Q)
torch.empty(nBatch, ctx.neq, dtype=Q.dtype)
if ctx.neq > 0
else torch.Tensor()
else torch.empty(0)
)
s_e = (
torch.Tensor(nBatch, ctx.neq).type_as(Q)
torch.empty(nBatch, ctx.neq, dtype=Q.dtype)
if ctx.neq > 0
else torch.Tensor()
else torch.empty(0)
)
slacks = torch.Tensor(nBatch, ctx.nineq).type_as(Q)
s_i = torch.Tensor(nBatch, ctx.nineq).type_as(Q)
slacks = torch.empty((nBatch, ctx.nineq), dtype=Q.dtype)
s_i = torch.empty((nBatch, ctx.nineq), dtype=Q.dtype)

vector_of_qps = proxsuite.proxqp.dense.BatchQP()

@@ -338,32 +340,36 @@ def forward(ctx, Q_, p_, A_, b_, G_, l_, u_):

for i in range(nBatch):
si = -h[i] + G[i] @ vector_of_qps.get(i).results.x
zhats[i] = torch.Tensor(vector_of_qps.get(i).results.x)
nus[i] = torch.Tensor(vector_of_qps.get(i).results.z)
slacks[i] = torch.Tensor(si)
zhats[i] = torch.tensor(vector_of_qps.get(i).results.x)
nus[i] = torch.tensor(vector_of_qps.get(i).results.z)
slacks[i] = si.clone().detach()
if neq > 0:
lams[i] = torch.Tensor(vector_of_qps.get(i).results.y)
s_e[i] = torch.Tensor(vector_of_qps.get(i).results.se)
s_i[i] = torch.Tensor(vector_of_qps.get(i).results.si)
lams[i] = torch.tensor(vector_of_qps.get(i).results.y)
s_e[i] = torch.tensor(vector_of_qps.get(i).results.se)
s_i[i] = torch.tensor(vector_of_qps.get(i).results.si)

ctx.lams = lams
ctx.nus = nus
ctx.slacks = slacks
ctx.save_for_backward(zhats, s_e, Q_, p_, G_, h_, A_, b_)
ctx.save_for_backward(zhats, s_e, Q_, p_, G_, l_, u_, A_, b_)
return zhats, lams, nus, s_e, s_i

@staticmethod
def backward(ctx, dl_dzhat, dl_dlams, dl_dnus, dl_ds_e, dl_ds_i):
zhats, s_e, Q, p, G, h, A, b = ctx.saved_tensors
nBatch = extract_nBatch(Q, p, G, h, A, b)
zhats, s_e, Q, p, G, l, u, A, b = ctx.saved_tensors
nBatch = extract_nBatch(Q, p, A, b, G, l, u)

Q, Q_e = expandParam(Q, nBatch, 3)
p, p_e = expandParam(p, nBatch, 2)
G, G_e = expandParam(G, nBatch, 3)
h, h_e = expandParam(h, nBatch, 2)
_, u_e = expandParam(u, nBatch, 2)
_, l_e = expandParam(l, nBatch, 2)
A, A_e = expandParam(A, nBatch, 3)
b, b_e = expandParam(b, nBatch, 2)

h_e = l_e or u_e
G = torch.cat((-G, G), axis=1)

neq, nineq = ctx.neq, ctx.nineq
dx = torch.zeros((nBatch, Q.shape[1]))
dnu = None
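For context, a hedged usage sketch of the layer defined in this file; the toy QP and the call with default arguments are assumptions for illustration, not taken from the diff. The argument order (Q, p, A, b, G, l, u) matches the forward signature shown above.

```python
import torch
from proxsuite.torch.qplayer import QPFunction

# Toy problem: min 1/2 x^T Q x  s.t.  x0 + x1 = 1,  -1 <= x <= 1  (solution x = [0.5, 0.5]).
dtype = torch.float64
Q = torch.eye(2, dtype=dtype, requires_grad=True)
p = torch.zeros(2, dtype=dtype)
A = torch.tensor([[1.0, 1.0]], dtype=dtype)
b = torch.ones(1, dtype=dtype)
G = torch.eye(2, dtype=dtype)
l = -torch.ones(2, dtype=dtype)
u = torch.ones(2, dtype=dtype)

x, *duals = QPFunction()(Q, p, A, b, G, l, u)
print(x)            # roughly [[0.5, 0.5]]; a leading batch dimension of 1 is added
x.sum().backward()  # gradients flow back to Q through the QP layer
```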
20 changes: 1 addition & 19 deletions bindings/python/proxsuite/torch/utils.py
@@ -2,7 +2,7 @@
import numpy as np


def extract_nBatch_double_sided(Q, p, A, b, G, l, u):
def extract_nBatch(Q, p, A, b, G, l, u):
dims = [3, 2, 3, 2, 3, 2, 2]
params = [Q, p, A, b, G, l, u]
for param, dim in zip(params, dims):
@@ -11,15 +11,6 @@ def extract_nBatch_double_sided(Q, p, A, b, G, l, u):
return 1


def extract_nBatch_(Q, p, G, A, b):
dims = [3, 2, 3, 3, 2]
params = [Q, p, G, A, b]
for param, dim in zip(params, dims):
if param.ndimension() == dim:
return param.size(0)
return 1


# from qpth: https://github.com/locuslab/qpth/blob/master/qpth/util.py
def print_header(msg):
print("===>", msg)
@@ -67,12 +58,3 @@ def expandParam(X, nBatch, nDim):
return X.unsqueeze(0).expand(*([nBatch] + list(X.size()))), True
else:
raise RuntimeError("Unexpected number of dimensions.")


def extract_nBatch(Q, p, G, h, A, b):
dims = [3, 2, 3, 2, 3, 2]
params = [Q, p, G, h, A, b]
for param, dim in zip(params, dims):
if param.ndimension() == dim:
return param.size(0)
return 1
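A small standalone sketch (the tensors are illustrative) of how the unified helpers now work together: `extract_nBatch` infers the batch size from whichever argument already carries a batch dimension, and `expandParam` then broadcasts the unbatched ones.

```python
import torch
from proxsuite.torch.utils import extract_nBatch, expandParam

nz, neq, nineq, nBatch = 4, 1, 3, 5
Q = torch.eye(nz, dtype=torch.float64)            # shared across the batch (no batch dim)
p = torch.zeros(nBatch, nz, dtype=torch.float64)  # batched
A = torch.ones(neq, nz, dtype=torch.float64)
b = torch.ones(neq, dtype=torch.float64)
G = torch.zeros(nineq, nz, dtype=torch.float64)
l = -torch.ones(nineq, dtype=torch.float64)
u = torch.ones(nineq, dtype=torch.float64)

n = extract_nBatch(Q, p, A, b, G, l, u)   # 5, taken from p
Q_b, expanded = expandParam(Q, n, 3)      # Q broadcast to (5, 4, 4), expanded is True
assert n == nBatch and Q_b.shape == (nBatch, nz, nz) and expanded
```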
44 changes: 20 additions & 24 deletions examples/python/qplayer_sudoku.py
@@ -52,13 +52,13 @@ def __init__(self, n, omp_parallel=False, maxIter=1000):
self.maxIter = maxIter
self.omp_parallel = omp_parallel
nx = (n**2) ** 3
self.Q = Variable(torch.zeros(nx, nx).double())
self.G = Variable(-torch.eye(nx).double())
self.u = Variable(torch.zeros(nx).double())
self.l = Variable(-1.0e20 * torch.ones(nx).double())
self.Q = torch.zeros(nx, nx, dtype=torch.float64)
self.G = -torch.eye(nx, dtype=torch.float64)
self.u = torch.zeros(nx, dtype=torch.float64)
self.l = -1.0e20 * torch.ones(nx, dtype=torch.float64)
t = get_sudoku_matrix(n)
self.A = Parameter(torch.rand(t.shape).double())
self.log_z0 = Parameter(torch.zeros(nx).double())
self.A = Parameter(torch.rand(t.shape, dtype=torch.float64))
self.log_z0 = Parameter(torch.zeros(nx, dtype=torch.float64))

def forward(self, puzzles):
nBatch = puzzles.size(0)
@@ -78,14 +78,14 @@ def __init__(self, n, omp_parallel=False, maxIter=1000):

nx = (n**2) ** 3
Qpenalty = 0.0
self.Q = Variable(Qpenalty * torch.eye(nx).double())
self.Q = Qpenalty * torch.eye(nx, dtype=torch.float64)

self.G = Variable(-torch.eye(nx).double())
self.h = Variable(torch.zeros(nx).double())
self.l = Variable(-1.0e20 * torch.ones(nx).double())
self.G = -torch.eye(nx, dtype=torch.float64)
self.h = torch.zeros(nx, dtype=torch.float64)
self.l = -1.0e20 * torch.ones(nx, dtype=torch.float64)
t = get_sudoku_matrix(n)
self.A = Parameter(torch.rand(t.shape).double())
self.b = Variable(torch.ones(self.A.size(0)).double())
self.A = Parameter(torch.rand(t.shape, dtype=torch.float64))
self.b = torch.ones(self.A.size(0), dtype=torch.float64)

def forward(self, puzzles):
nBatch = puzzles.size(0)
@@ -103,15 +103,13 @@ def forward(self, puzzles):
def train(args, epoch, model, trainX, trainY, optimizer):
batchSz = args.batchSz

batch_data_t = torch.FloatTensor(
batchSz, trainX.size(1), trainX.size(2), trainX.size(3)
batch_data = torch.empty(
(batchSz, trainX.size(1), trainX.size(2), trainX.size(3)), dtype=torch.float32
)
batch_targets_t = torch.FloatTensor(
batchSz, trainY.size(1), trainX.size(2), trainX.size(3)
batch_targets = torch.empty(
(batchSz, trainY.size(1), trainX.size(2), trainX.size(3)), dtype=torch.float32
)

batch_data = Variable(batch_data_t, requires_grad=False)
batch_targets = Variable(batch_targets_t, requires_grad=False)
for i in range(0, trainX.size(0), batchSz):
start = time.time()
batch_data.data[:] = trainX[i : i + batchSz]
@@ -140,14 +138,12 @@ def test(args, epoch, model, testX, testY):
batchSz = args.testBatchSz

test_loss = 0
batch_data_t = torch.FloatTensor(
batchSz, testX.size(1), testX.size(2), testX.size(3)
batch_data = torch.empty(
(batchSz, testX.size(1), testX.size(2), testX.size(3)), dtype=torch.float32
)
batch_targets_t = torch.FloatTensor(
batchSz, testY.size(1), testX.size(2), testX.size(3)
batch_targets = torch.empty(
(batchSz, testY.size(1), testX.size(2), testX.size(3)), dtype=torch.float32
)
batch_data = Variable(batch_data_t)
batch_targets = Variable(batch_targets_t)

nErr = 0
for i in range(0, testX.size(0), batchSz):
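Background for the `Variable` removal in this example, as a short sketch with assumed shapes: since PyTorch 0.4, `Variable(t)` simply returns a tensor, so wrapping constants adds nothing, while `Parameter` is still what marks a tensor as learnable; the dtype is now passed to the factory instead of calling `.double()`.

```python
import torch
from torch.nn.parameter import Parameter

nx = 16
# Old: Q = Variable(torch.zeros(nx, nx).double())
Q = torch.zeros(nx, nx, dtype=torch.float64)  # plain tensor, explicit dtype

# Learnable matrix: Parameter still registers it with module.parameters().
A = Parameter(torch.rand(10, nx, dtype=torch.float64))

assert A.requires_grad and not Q.requires_grad
```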
