Merged
Changes from all commits
28 commits
49267de
extremely generic API for trajectory optimization
alberthli Dec 2, 2023
1035dd4
fix spacing
alberthli Dec 2, 2023
22f9867
remove unnecessary __init__ method (brainfart)
alberthli Dec 2, 2023
b8d7e8e
some additional massaging for best abstraction + useless scaffolding …
alberthli Dec 2, 2023
f5610fc
add some informative comments to base.py
alberthli Dec 3, 2023
17642f6
[maybe] added cost function to base API
alberthli Dec 3, 2023
4e1ae77
shooting method APIs, no cost function implemented yet
alberthli Dec 3, 2023
8c9cd6b
define generic CostFunction API
alberthli Dec 3, 2023
0d175f0
expose a CostFunction field of shooting methods
alberthli Dec 3, 2023
987e334
add specific implementation of (non-sparse) quadratic CostFunction wi…
alberthli Dec 3, 2023
2b268f8
fixed missing field in docstrings
alberthli Dec 4, 2023
13efcfd
fixed missing field in docstrings
alberthli Dec 4, 2023
b6ea473
pass non-kwarg functions to vmap since vmap breaks in that case
alberthli Dec 4, 2023
6bf291d
added preliminary dead simple predictive sampling example
alberthli Dec 4, 2023
2cfdbfd
minor, get from property
alberthli Dec 4, 2023
dc5d247
fixed some params + simplify sampling of us to one line
alberthli Dec 4, 2023
ff6dbcc
[DIRTY] commit that contains commented out code for pre-allocating da…
alberthli Dec 4, 2023
9f12d74
revert to allocating data at shooting time
alberthli Dec 4, 2023
3fda8ba
[DIRTY] some additions to the example script for profiling help
alberthli Dec 4, 2023
0f507e3
refactor API to take xs instead of qs and vs
alberthli Dec 5, 2023
b2d8296
tests for cost function and its derivatives
alberthli Dec 5, 2023
0e24057
added smoke test for vanilla predictive sampler
alberthli Dec 5, 2023
adf4243
remove the example since it's implemented as a benchmark in an upstre…
alberthli Dec 5, 2023
773be6a
minor docstring edit
alberthli Dec 5, 2023
544f964
ensure predictive sampler also accounts for cost of guess vs. samples
alberthli Dec 5, 2023
4e22073
add sanity check for predictive sampling
alberthli Dec 5, 2023
0c49505
added fixture to tests
alberthli Dec 5, 2023
90ea8f4
fix stray parenthetical
alberthli Dec 5, 2023
172 changes: 172 additions & 0 deletions ambersim/trajopt/base.py
@@ -0,0 +1,172 @@
from typing import Tuple

import jax
from flax import struct
from jax import grad, hessian

# ####################### #
# TRAJECTORY OPTIMIZATION #
# ####################### #


@struct.dataclass
class TrajectoryOptimizerParams:
"""The parameters for generic trajectory optimization algorithms.

Parameters we may want to optimize should be included here.

This is left completely empty to allow for maximum flexibility in the API. Some examples:
- A direct collocation method might have parameters for the number of collocation points, the collocation
scheme, and the number of optimization iterations.
- A shooting method might have parameters for the number of shooting points, the shooting scheme, and the number
of optimization iterations.
The parameters also include initial iterates for each type of algorithm. Some examples:
- A direct collocation method might have initial iterates for the controls and the state trajectory.
- A shooting method might have initial iterates for the controls only.

Parameters which we want to remain untouched by JAX transformations can be marked by pytree_node=False, e.g.,
```
@struct.dataclass
class ChildParams:
...
# example field
example: int = struct.field(pytree_node=False)
...
```
"""


@struct.dataclass
class TrajectoryOptimizer:
"""The API for generic trajectory optimization algorithms on mechanical systems.

We choose to implement this as a flax dataclass (as opposed to a regular class whose functions operate on pytree
nodes) because:
(1) the OOP formalism allows us to define coherent abstractions through inheritance;
(2) struct.dataclass registers dataclasses as pytree nodes, which sidesteps awkward issues like handling the `self`
variable when applying JAX transformations to methods of the dataclass.

Further, we choose not to specify the mjx.Model as either a field of this dataclass or as a parameter, because we
want to allow maximum flexibility in the API. Two motivating scenarios:
(1) we want to domain randomize over the model parameters and potentially optimize for them. In this case, it makes
sense to specify the mjx.Model as a parameter that gets passed as an input into the optimize function.
(2) we want to fix the model and only randomize/optimize over non-model-parameters. For instance, this is the
situation in vanilla predictive sampling. If we don't need to pass the model, we instead initialize it as a
field of this dataclass, which makes the optimize function more performant, since it can just reference the
fixed model attribute of the optimizer instead of applying JAX transformations to the entire large model pytree.
Similar logic applies to not prescribing the role of the CostFunction: we trust that the user will either use the
provided API or ignore it and still implement something custom and reasonable.

Finally, abstract dataclasses are awkward, so instead of using ABCs we have the base methods below raise
NotImplementedError, requiring all children to implement them.
"""

def optimize(self, params: TrajectoryOptimizerParams) -> Tuple[jax.Array, jax.Array]:
"""Optimizes a trajectory.

The shapes of the outputs include (?) because we may choose to return non-zero-order-hold parameterizations of
the optimized trajectories (for example, we could choose to return a cubic spline parameterization of the
control inputs over the trajectory as is done in the gradient-based methods of MJPC).

Args:
params: The parameters of the trajectory optimizer.

Returns:
xs_star (shape=(N + 1, nq + nv) or (?)): The optimized trajectory.
us_star (shape=(N, nu) or (?)): The optimized controls.
"""
raise NotImplementedError


# ############# #
# COST FUNCTION #
# ############# #


@struct.dataclass
class CostFunctionParams:
"""Generic parameters for cost functions."""


@struct.dataclass
class CostFunction:
"""The API for generic cost functions for trajectory optimization problems for mechanical systems.

Rationale behind CostFunctionParams in this generic API:
(1) computation of higher-order derivatives could depend on results or intermediates from lower-order derivatives.
So, we can flexibly cache the requisite values to avoid repeated computation;
(2) we may want to randomize or optimize the cost function parameters themselves, so accepting a generic pytree
as input accounts for all such possibilities;
(3) there may be parameters that cannot easily be specified in advance but are key for cost evaluation, like a
time-varying reference trajectory that gets updated in real time;
(4) histories of higher-order derivatives can be useful for updating their current estimates, e.g., in BFGS.
"""

def cost(self, xs: jax.Array, us: jax.Array, params: CostFunctionParams) -> Tuple[jax.Array, CostFunctionParams]:
"""Computes the cost of a trajectory.

Args:
xs (shape=(N + 1, nq + nv)): The state trajectory.
us (shape=(N, nu)): The controls over the trajectory.
params: The parameters of the cost function.

Returns:
val (shape=(,)): The cost of the trajectory.
new_params: The updated parameters of the cost function.
"""
raise NotImplementedError

def grad(
self, xs: jax.Array, us: jax.Array, params: CostFunctionParams
) -> Tuple[jax.Array, jax.Array, CostFunctionParams, CostFunctionParams]:
"""Computes the gradient of the cost of a trajectory.

The default implementation of this function uses JAX's autodiff. Simply override this function if you would like
to supply an analytical gradient.

Args:
xs (shape=(N + 1, nq + nv)): The state trajectory.
us (shape=(N, nu)): The controls over the trajectory.
params: The parameters of the cost function.

Returns:
gcost_xs (shape=(N + 1, nq + nv)): The gradient of the cost wrt xs.
gcost_us (shape=(N, nu)): The gradient of the cost wrt us.
gcost_params: The gradient of the cost wrt params.
new_params: The updated parameters of the cost function.
"""
_fn = lambda xs, us, params: self.cost(xs, us, params)[0] # only differentiate wrt the cost val
return grad(_fn, argnums=(0, 1, 2))(xs, us, params) + (params,)

def hess(
self, xs: jax.Array, us: jax.Array, params: CostFunctionParams
) -> Tuple[
jax.Array, jax.Array, CostFunctionParams, jax.Array, CostFunctionParams, CostFunctionParams, CostFunctionParams
]:
"""Computes the Hessian of the cost of a trajectory.

The default implementation of this function uses JAX's autodiff. Simply override this function if you would like
to supply an analytical Hessian.

Let t, s index the time steps 0, 1, 2, etc., and let a, b each stand for x or u. Then,
d^2(cost)/(da_{t,i} db_{s,j}) = Hcost_asbs[t, i, s, j].

Args:
xs (shape=(N + 1, nq + nv)): The state trajectory.
us (shape=(N, nu)): The controls over the trajectory.
params: The parameters of the cost function.

Returns:
Hcost_xsxs (shape=(N + 1, nq + nv, N + 1, nq + nv)): The Hessian of the cost wrt xs.
Hcost_xsus (shape=(N + 1, nq + nv, N, nu)): The Hessian of the cost wrt xs and us.
Hcost_xsparams: The Hessian of the cost wrt xs and params.
Hcost_usus (shape=(N, nu, N, nu)): The Hessian of the cost wrt us.
Hcost_usparams: The Hessian of the cost wrt us and params.
Hcost_paramsall: The Hessian of the cost wrt params and everything else.
new_params: The updated parameters of the cost function.
"""
_fn = lambda xs, us, params: self.cost(xs, us, params)[0] # only differentiate wrt the cost val
hessians = hessian(_fn, argnums=(0, 1, 2))(xs, us, params)
Hcost_xsxs, Hcost_xsus, Hcost_xsparams = hessians[0]
_, Hcost_usus, Hcost_usparams = hessians[1]
Hcost_paramsall = hessians[2]
return Hcost_xsxs, Hcost_xsus, Hcost_xsparams, Hcost_usus, Hcost_usparams, Hcost_paramsall, params
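
To make the intended subclassing pattern concrete, here is a minimal sketch (not part of this PR's diff) of a cost function and an optimizer built on the base API above. The names NormCost, NormCostParams, HoldGuessParams, and HoldGuessOptimizer are hypothetical and their logic is purely illustrative; the grad call at the end exercises the autodiff fallback defined in base.py.

```
# Hypothetical sketch of subclassing the base API (names not from this PR).
from typing import Tuple

import jax
import jax.numpy as jnp
from flax import struct

from ambersim.trajopt.base import (
    CostFunction,
    CostFunctionParams,
    TrajectoryOptimizer,
    TrajectoryOptimizerParams,
)


@struct.dataclass
class NormCostParams(CostFunctionParams):
    """Hypothetical params: a scalar weight on the control effort term."""

    u_weight: float


@struct.dataclass
class NormCost(CostFunction):
    """Hypothetical cost: squared state norm plus weighted squared control norm."""

    def cost(
        self, xs: jax.Array, us: jax.Array, params: NormCostParams
    ) -> Tuple[jax.Array, NormCostParams]:
        val = jnp.sum(xs**2) + params.u_weight * jnp.sum(us**2)
        return val, params


@struct.dataclass
class HoldGuessParams(TrajectoryOptimizerParams):
    """Hypothetical params: initial iterates for the states and controls."""

    xs_guess: jax.Array  # shape=(N + 1, nq + nv)
    us_guess: jax.Array  # shape=(N, nu)


@struct.dataclass
class HoldGuessOptimizer(TrajectoryOptimizer):
    """Hypothetical 'optimizer' that just returns its initial iterates unchanged."""

    cost_function: CostFunction = struct.field(pytree_node=False)

    def optimize(self, params: HoldGuessParams) -> Tuple[jax.Array, jax.Array]:
        return params.xs_guess, params.us_guess


# The default grad/hess in base.py fall back to jax autodiff on cost():
cost_fn = NormCost()
xs = jnp.ones((11, 4))  # N = 10, nq + nv = 4
us = jnp.ones((10, 2))  # nu = 2
cparams = NormCostParams(u_weight=0.1)
val, _ = cost_fn.cost(xs, us, cparams)
gcost_xs, gcost_us, gcost_params, _ = cost_fn.grad(xs, us, cparams)

opt = HoldGuessOptimizer(cost_function=cost_fn)
xs_star, us_star = opt.optimize(HoldGuessParams(xs_guess=xs, us_guess=us))
```

A real optimizer would implement optimize() with an actual search (e.g., the predictive sampling referenced in the commits above), but the pattern of pairing a params dataclass with its optimizer is the same.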
178 changes: 178 additions & 0 deletions ambersim/trajopt/cost.py
@@ -0,0 +1,178 @@
from typing import Tuple

import jax
import jax.numpy as jnp
from flax import struct
from jax import lax, vmap

from ambersim.trajopt.base import CostFunction, CostFunctionParams

"""A collection of common cost functions."""


class StaticGoalQuadraticCost(CostFunction):
"""A quadratic cost function that penalizes the distance to a static goal.

This is the most vanilla possible quadratic cost. The cost matrices are static (defined at init time) and so is the
single, fixed goal. The gradient is as compressed as it can be in general (one matrix multiplication), but the
Hessian can be far more compressed by simply referencing Q, Qf, and R; this implementation is inefficient and
dense.
"""

def __init__(self, Q: jax.Array, Qf: jax.Array, R: jax.Array, xg: jax.Array) -> None:
"""Initializes a quadratic cost function.

Args:
Q (shape=(nx, nx)): The state cost matrix.
Qf (shape=(nx, nx)): The final state cost matrix.
R (shape=(nu, nu)): The control cost matrix.
xg (shape=(nx,)): The goal state.
"""
self.Q = Q
self.Qf = Qf
self.R = R
self.xg = xg

@staticmethod
def batch_quadform(bs: jax.Array, A: jax.Array) -> jax.Array:
"""Computes a batched quadratic form for a single instance of A.

Args:
bs (shape=(..., n)): The batch of vectors.
A (shape=(n, n)): The matrix.

Returns:
val (shape=(...,)): The batch of quadratic forms.
"""
return jnp.einsum("...i,ij,...j->...", bs, A, bs)

@staticmethod
def batch_matmul(bs: jax.Array, A: jax.Array) -> jax.Array:
"""Computes a batched matrix multiplication for a single instance of A.

Args:
bs (shape=(..., n)): The batch of vectors.
A (shape=(n, n)): The matrix.

Returns:
val (shape=(..., n)): The batch of matrix multiplications.
"""
return jnp.einsum("...i,ij->...j", bs, A)

def cost(self, xs: jax.Array, us: jax.Array, params: CostFunctionParams) -> Tuple[jax.Array, CostFunctionParams]:
"""Computes the cost of a trajectory.

cost = 0.5 * sum_{t=0}^{N-1} [(x_t - xg)' @ Q @ (x_t - xg) + u_t' @ R @ u_t] + 0.5 * (x_N - xg)' @ Qf @ (x_N - xg)

Args:
xs (shape=(N + 1, nq + nv)): The state trajectory.
us (shape=(N, nu)): The controls over the trajectory.
params: Unused. Included for API compliance.

Returns:
cost_val: The cost of the trajectory.
new_params: Unused. Included for API compliance.
"""
xs_err = xs[:-1, :] - self.xg # errors before the terminal state
xf_err = xs[-1, :] - self.xg
val = 0.5 * jnp.squeeze(
(
jnp.sum(self.batch_quadform(xs_err, self.Q))
+ self.batch_quadform(xf_err, self.Qf)
+ jnp.sum(self.batch_quadform(us, self.R))
)
)
return val, params

def grad(
self, xs: jax.Array, us: jax.Array, params: CostFunctionParams
) -> Tuple[jax.Array, jax.Array, CostFunctionParams, CostFunctionParams]:
"""Computes the gradient of the cost of a trajectory.

Args:
xs (shape=(N + 1, nq + nv)): The state trajectory.
us (shape=(N, nu)): The controls over the trajectory.
params: Unused. Included for API compliance.

Returns:
gcost_xs (shape=(N + 1, nq + nv)): The gradient of the cost wrt xs.
gcost_us (shape=(N, nu)): The gradient of the cost wrt us.
gcost_params: Unused. Included for API compliance.
new_params: Unused. Included for API compliance.
"""
xs_err = xs[:-1, :] - self.xg # errors before the terminal state
xf_err = xs[-1, :] - self.xg
gcost_xs = jnp.concatenate(
(
self.batch_matmul(xs_err, self.Q),
(self.Qf @ xf_err)[None, :],
),
axis=-2,
)
gcost_us = self.batch_matmul(us, self.R)
return gcost_xs, gcost_us, params, params

def hess(
self, xs: jax.Array, us: jax.Array, params: CostFunctionParams
) -> Tuple[
jax.Array, jax.Array, CostFunctionParams, jax.Array, CostFunctionParams, CostFunctionParams, CostFunctionParams
]:
"""Computes the gradient of the cost of a trajectory.

Let t, s index the time steps 0, 1, 2, etc., and let a, b each stand for x or u. Then,
d^2(cost)/(da_{t,i} db_{s,j}) = Hcost_asbs[t, i, s, j].

Args:
xs (shape=(N + 1, nq + nv)): The state trajectory.
us (shape=(N, nu)): The controls over the trajectory.
params: Unused. Included for API compliance.

Returns:
Hcost_xsxs (shape=(N + 1, nq + nv, N + 1, nq + nv)): The Hessian of the cost wrt xs.
Hcost_xsus (shape=(N + 1, nq + nv, N, nu)): The Hessian of the cost wrt xs and us.
Hcost_xsparams: The Hessian of the cost wrt xs and params.
Hcost_usus (shape=(N, nu, N, nu)): The Hessian of the cost wrt us.
Hcost_usparams: The Hessian of the cost wrt us and params.
Hcost_paramsall: The Hessian of the cost wrt params and everything else.
new_params: The updated parameters of the cost function.
"""
# setting up
nx = self.Q.shape[0]
N, nu = us.shape
Q = self.Q
Qf = self.Qf
R = self.R
dummy_params = CostFunctionParams()

# Hessian for state
Hcost_xsxs = jnp.zeros((N + 1, nx, N + 1, nx))
Hcost_xsxs = vmap(
lambda i: lax.dynamic_update_slice(
jnp.zeros((nx, N + 1, nx)),
Q[:, None, :],
(0, i, 0),
)
)(
jnp.arange(N + 1)
) # only the terms [i, :, i, :] are nonzero
Hcost_xsxs = Hcost_xsxs.at[-1, :, -1, :].set(Qf) # last one is different

# trivial cross-terms of Hessian
Hcost_xsus = jnp.zeros((N + 1, nx, N, nu))
Hcost_xsparams = dummy_params

# Hessian for control inputs
Hcost_usus = jnp.zeros((N, nu, N, nu))
Hcost_usus = vmap(
lambda i: lax.dynamic_update_slice(
jnp.zeros((nu, N, nu)),
R[:, None, :],
(0, i, 0),
)
)(
jnp.arange(N)
) # only the terms [i, :, i, :] are nonzero

# trivial cross-terms and Hessian for params
Hcost_usparams = dummy_params
Hcost_paramsall = dummy_params
return Hcost_xsxs, Hcost_xsus, Hcost_xsparams, Hcost_usus, Hcost_usparams, Hcost_paramsall, params
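
As a usage sketch (not part of this PR's diff), the snippet below constructs a small StaticGoalQuadraticCost and cross-checks its hand-written gradient against the autodiff default inherited from CostFunction. The dimensions and tolerances are arbitrary illustrative choices.

```
# Hypothetical usage sketch: compare the analytical gradient with the autodiff default.
import jax.numpy as jnp
import numpy as np

from ambersim.trajopt.base import CostFunction, CostFunctionParams
from ambersim.trajopt.cost import StaticGoalQuadraticCost

nx, nu, N = 4, 2, 10
cost_fn = StaticGoalQuadraticCost(
    Q=jnp.eye(nx),          # running state cost
    Qf=10.0 * jnp.eye(nx),  # terminal state cost
    R=0.1 * jnp.eye(nu),    # control cost
    xg=jnp.ones(nx),        # static goal state
)

xs = jnp.zeros((N + 1, nx))
us = 0.5 * jnp.ones((N, nu))
params = CostFunctionParams()

val, _ = cost_fn.cost(xs, us, params)

# Analytical gradient from cost.py...
gxs, gus, _, _ = cost_fn.grad(xs, us, params)

# ...versus the generic autodiff implementation inherited from base.py.
gxs_ad, gus_ad, _, _ = CostFunction.grad(cost_fn, xs, us, params)
np.testing.assert_allclose(np.asarray(gxs), np.asarray(gxs_ad), atol=1e-6)
np.testing.assert_allclose(np.asarray(gus), np.asarray(gus_ad), atol=1e-6)
```

This mirrors the kind of sanity check suggested by the "tests for cost function and its derivatives" commit above.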