Trajectory Optimization API #46
Merged

Changes from all commits - 28 commits:
49267de - extremely generic API for trajectory optimization (alberthli)
1035dd4 - fix spacing (alberthli)
22f9867 - remove unnecessary __init__ method (brainfart) (alberthli)
b8d7e8e - some additional massaging for best abstraction + useless scaffolding … (alberthli)
f5610fc - add some informative comments to base.py (alberthli)
17642f6 - [maybe] added cost function to base API (alberthli)
4e1ae77 - shooting method APIs, no cost function implemented yet (alberthli)
8c9cd6b - define generic CostFunction API (alberthli)
0d175f0 - expose a CostFunction field of shooting methods (alberthli)
987e334 - add specific implementation of (non-sparse) quadratic CostFunction wi… (alberthli)
2b268f8 - fixed missing field in docstrings (alberthli)
13efcfd - fixed missing field in docstrings (alberthli)
b6ea473 - pass non-kwarg functions to vmap since vmap breaks in that case (alberthli)
6bf291d - added preliminary dead simple predictive sampling example (alberthli)
2cfdbfd - minor, get from property (alberthli)
dc5d247 - fixed some params + simplify sampling of us to one line (alberthli)
ff6dbcc - [DIRTY] commit that contains commented out code for pre-allocating da… (alberthli)
9f12d74 - revert to allocating data at shooting time (alberthli)
3fda8ba - [DIRTY] some additions to the example script for profiling help (alberthli)
0f507e3 - refactor API to take xs instead of qs and vs (alberthli)
b2d8296 - tests for cost function and its derivatives (alberthli)
0e24057 - added smoke test for vanilla predictive sampler (alberthli)
adf4243 - remove the example since it's implemented as a benchmark in an upstre… (alberthli)
773be6a - minor docstring edit (alberthli)
544f964 - ensure predictive sampler also accounts for cost of guess vs. samples (alberthli)
4e22073 - add sanity check for predictive sampling (alberthli)
0c49505 - added fixture to tests (alberthli)
90ea8f4 - fix stray parenthetical (alberthli)
alberthli File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
ambersim/trajopt/base.py (new file, +172 lines):
````python
from typing import Tuple

import jax
from flax import struct
from jax import grad, hessian

# ####################### #
# TRAJECTORY OPTIMIZATION #
# ####################### #


@struct.dataclass
class TrajectoryOptimizerParams:
    """The parameters for generic trajectory optimization algorithms.

    Parameters we may want to optimize should be included here.

    This is left completely empty to allow for maximum flexibility in the API. Some examples:
    - A direct collocation method might have parameters for the number of collocation points, the collocation
      scheme, and the number of optimization iterations.
    - A shooting method might have parameters for the number of shooting points, the shooting scheme, and the
      number of optimization iterations.
    The parameters also include initial iterates for each type of algorithm. Some examples:
    - A direct collocation method might have initial iterates for the controls and the state trajectory.
    - A shooting method might have initial iterates for the controls only.

    Parameters which we want to remain untouched by JAX transformations can be marked by pytree_node=False, e.g.,
    ```
    @struct.dataclass
    class ChildParams:
        ...
        # example field
        example: int = struct.field(pytree_node=False)
        ...
    ```
    """
````
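For illustration, here is a minimal sketch of a child params class for a sampling-based shooting method. The name `SamplingParams` and its fields are hypothetical (not part of this diff); the sketch just shows how pytree_node=False separates static configuration from traced iterates:

```python
import jax
from flax import struct

from ambersim.trajopt.base import TrajectoryOptimizerParams


@struct.dataclass
class SamplingParams(TrajectoryOptimizerParams):  # hypothetical example, not in this PR
    """Static config is marked pytree_node=False; initial iterates remain pytree leaves."""

    # static configuration: stored in the treedef, untouched by JAX transformations
    num_samples: int = struct.field(pytree_node=False)
    # initial iterate for the controls, shape (N, nu): an ordinary traced leaf
    us_guess: jax.Array = struct.field(pytree_node=True)
```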
```python
@struct.dataclass
class TrajectoryOptimizer:
    """The API for generic trajectory optimization algorithms on mechanical systems.

    We choose to implement this as a flax dataclass (as opposed to a regular class whose functions operate on pytree
    nodes) because:
    (1) the OOP formalism allows us to define coherent abstractions through inheritance;
    (2) struct.dataclass registers dataclasses as pytree nodes, so we can deal with awkward issues like the `self`
        variable when using JAX transformations on methods of the dataclass.

    Further, we choose not to specify the mjx.Model as either a field of this dataclass or as a parameter, because
    we want to allow for maximum flexibility in the API. Two motivating scenarios:
    (1) we want to domain randomize over the model parameters and potentially optimize them. In this case, it makes
        sense to specify the mjx.Model as a parameter that gets passed as an input to the optimize function.
    (2) we want to fix the model and only randomize/optimize over non-model parameters. For instance, this is the
        situation in vanilla predictive sampling. If we don't need to pass the model, we instead initialize it as a
        field of this dataclass, which makes the optimize function more performant, since it can just reference the
        fixed model attribute of the optimizer instead of applying JAX transformations to the entire large model
        pytree.
    Similar logic applies to not specifying the role of the CostFunction - we trust that the user will either use
    the provided API or will ignore it and still end up implementing something custom and reasonable.

    Finally, abstract dataclasses are weird, so instead we make all children implement the functions below by
    raising a NotImplementedError.
    """

    def optimize(self, params: TrajectoryOptimizerParams) -> Tuple[jax.Array, jax.Array]:
        """Optimizes a trajectory.

        The shapes of the outputs include (?) because we may choose to return non-zero-order-hold parameterizations
        of the optimized trajectories (for example, we could choose to return a cubic spline parameterization of the
        control inputs over the trajectory, as is done in the gradient-based methods of MJPC).

        Args:
            params: The parameters of the trajectory optimizer.

        Returns:
            xs_star (shape=(N + 1, nq + nv) or (?)): The optimized state trajectory.
            us_star (shape=(N, nu) or (?)): The optimized controls.
        """
        raise NotImplementedError
```
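To make the intended usage concrete, here is a hedged sketch of a trivial child optimizer that keeps everything as pytree nodes so jit works directly on the method. Everything below (the class names, fields, and the zero-control "optimizer") is hypothetical and not code from this PR:

```python
import jax
import jax.numpy as jnp
from flax import struct

from ambersim.trajopt.base import TrajectoryOptimizer, TrajectoryOptimizerParams


@struct.dataclass
class ZeroControlParams(TrajectoryOptimizerParams):  # hypothetical
    x0: jax.Array  # (nq + nv,) initial state: traced leaf
    N: int = struct.field(pytree_node=False)  # horizon: static
    nu: int = struct.field(pytree_node=False)  # control dimension: static


@struct.dataclass
class ZeroControlOptimizer(TrajectoryOptimizer):  # hypothetical: returns zero controls
    def optimize(self, params):
        us_star = jnp.zeros((params.N, params.nu))
        xs_star = jnp.tile(params.x0, (params.N + 1, 1))  # placeholder "rollout"
        return xs_star, us_star


# both the optimizer and its params are pytree nodes, so `self` poses no problem for jit
opt = ZeroControlOptimizer()
params = ZeroControlParams(x0=jnp.zeros(4), N=10, nu=2)
xs, us = jax.jit(ZeroControlOptimizer.optimize)(opt, params)
```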
```python
# ############# #
# COST FUNCTION #
# ############# #


@struct.dataclass
class CostFunctionParams:
    """Generic parameters for cost functions."""


@struct.dataclass
class CostFunction:
    """The API for generic cost functions for trajectory optimization problems on mechanical systems.

    Rationale behind CostFunctionParams in this generic API:
    (1) computation of higher-order derivatives could depend on results or intermediates from lower-order
        derivatives, so we can flexibly cache the requisite values to avoid repeated computation;
    (2) we may want to randomize or optimize the cost function parameters themselves, so specifying a generic
        pytree as input accounts for all possibilities;
    (3) there could simply be parameters that cannot be easily specified in advance but are key for cost
        evaluation, like a time-varying reference trajectory that gets updated in real time;
    (4) histories of higher-order derivatives can be useful for updating their current estimates, e.g., BFGS.
    """

    def cost(self, xs: jax.Array, us: jax.Array, params: CostFunctionParams) -> Tuple[jax.Array, CostFunctionParams]:
        """Computes the cost of a trajectory.

        Args:
            xs (shape=(N + 1, nq + nv)): The state trajectory.
            us (shape=(N, nu)): The controls over the trajectory.
            params: The parameters of the cost function.

        Returns:
            val (shape=()): The cost of the trajectory.
            new_params: The updated parameters of the cost function.
        """
        raise NotImplementedError

    def grad(
        self, xs: jax.Array, us: jax.Array, params: CostFunctionParams
    ) -> Tuple[jax.Array, jax.Array, CostFunctionParams, CostFunctionParams]:
        """Computes the gradient of the cost of a trajectory.

        The default implementation uses JAX's autodiff. Simply override this function if you would like to supply
        an analytical gradient.

        Args:
            xs (shape=(N + 1, nq + nv)): The state trajectory.
            us (shape=(N, nu)): The controls over the trajectory.
            params: The parameters of the cost function.

        Returns:
            gcost_xs (shape=(N + 1, nq + nv)): The gradient of the cost wrt xs.
            gcost_us (shape=(N, nu)): The gradient of the cost wrt us.
            gcost_params: The gradient of the cost wrt params.
            new_params: The updated parameters of the cost function.
        """
        _fn = lambda xs, us, params: self.cost(xs, us, params)[0]  # only differentiate wrt the cost val
        return grad(_fn, argnums=(0, 1, 2))(xs, us, params) + (params,)

    def hess(
        self, xs: jax.Array, us: jax.Array, params: CostFunctionParams
    ) -> Tuple[
        jax.Array, jax.Array, CostFunctionParams, jax.Array, CostFunctionParams, CostFunctionParams, CostFunctionParams
    ]:
        """Computes the Hessian of the cost of a trajectory.

        The default implementation uses JAX's autodiff. Simply override this function if you would like to supply
        an analytical Hessian.

        Let t, s be times 0, 1, 2, etc. Then, d^2H/(da_{t,i} db_{s,j}) = Hcost_asbs[t, i, s, j].

        Args:
            xs (shape=(N + 1, nq + nv)): The state trajectory.
            us (shape=(N, nu)): The controls over the trajectory.
            params: The parameters of the cost function.

        Returns:
            Hcost_xsxs (shape=(N + 1, nq + nv, N + 1, nq + nv)): The Hessian of the cost wrt xs.
            Hcost_xsus (shape=(N + 1, nq + nv, N, nu)): The Hessian of the cost wrt xs and us.
            Hcost_xsparams: The Hessian of the cost wrt xs and params.
            Hcost_usus (shape=(N, nu, N, nu)): The Hessian of the cost wrt us.
            Hcost_usparams: The Hessian of the cost wrt us and params.
            Hcost_paramsall: The Hessians of the cost wrt params and everything else.
            new_params: The updated parameters of the cost function.
        """
        _fn = lambda xs, us, params: self.cost(xs, us, params)[0]  # only differentiate wrt the cost val
        hessians = hessian(_fn, argnums=(0, 1, 2))(xs, us, params)
        Hcost_xsxs, Hcost_xsus, Hcost_xsparams = hessians[0]
        _, Hcost_usus, Hcost_usparams = hessians[1]
        Hcost_paramsall = hessians[2]
        return Hcost_xsxs, Hcost_xsus, Hcost_xsparams, Hcost_usus, Hcost_usparams, Hcost_paramsall, params
```
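As a quick illustration of the default autodiff path, a child only has to implement cost to get grad and hess for free. The toy subclass below (name and cost are hypothetical, not part of this diff) is a minimal sketch:

```python
import jax.numpy as jnp
from flax import struct

from ambersim.trajopt.base import CostFunction, CostFunctionParams


@struct.dataclass
class SumSquaresCost(CostFunction):  # hypothetical toy cost, not in this PR
    def cost(self, xs, us, params):
        # 0.5 * (||xs||^2 + ||us||^2); params are passed through untouched
        val = 0.5 * (jnp.sum(xs**2) + jnp.sum(us**2))
        return val, params


c = SumSquaresCost()
xs, us, p = jnp.ones((3, 2)), jnp.ones((2, 1)), CostFunctionParams()
gxs, gus, gp, new_p = c.grad(xs, us, p)  # inherited autodiff default
assert jnp.allclose(gxs, xs) and jnp.allclose(gus, us)  # d/dz 0.5*||z||^2 = z
```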
The second new file (+178 lines) implements a concrete cost function on top of this API:
```python
"""A collection of common cost functions."""

from typing import Tuple

import jax
import jax.numpy as jnp
from flax import struct
from jax import lax, vmap

from ambersim.trajopt.base import CostFunction, CostFunctionParams


@struct.dataclass
class StaticGoalQuadraticCost(CostFunction):
    """A quadratic cost function that penalizes the distance to a static goal.

    This is the most vanilla possible quadratic cost. The cost matrices are static (defined at init time) and so is
    the single, fixed goal. The gradient is as compressed as it can be in general (one matrix multiplication), but
    the Hessian could be far more compressed by simply referencing Q, Qf, and R - this implementation is
    inefficient and dense.

    Fields:
        Q (shape=(nx, nx)): The state cost matrix.
        Qf (shape=(nx, nx)): The final state cost matrix.
        R (shape=(nu, nu)): The control cost matrix.
        xg (shape=(nx,)): The goal state, where nx = nq + nv.
    """

    # declared as dataclass fields rather than set in a custom __init__,
    # since the parent struct.dataclass is frozen
    Q: jax.Array
    Qf: jax.Array
    R: jax.Array
    xg: jax.Array

    @staticmethod
    def batch_quadform(bs: jax.Array, A: jax.Array) -> jax.Array:
        """Computes a batched quadratic form for a single instance of A.

        Args:
            bs (shape=(..., n)): The batch of vectors.
            A (shape=(n, n)): The matrix.

        Returns:
            val (shape=(...,)): The batch of quadratic forms.
        """
        return jnp.einsum("...i,ij,...j->...", bs, A, bs)

    @staticmethod
    def batch_matmul(bs: jax.Array, A: jax.Array) -> jax.Array:
        """Computes a batched matrix-vector product for a single instance of A.

        Args:
            bs (shape=(..., n)): The batch of vectors.
            A (shape=(n, n)): The matrix.

        Returns:
            val (shape=(..., n)): The batch of products bs @ A.
        """
        return jnp.einsum("...i,ij->...j", bs, A)

    def cost(self, xs: jax.Array, us: jax.Array, params: CostFunctionParams) -> Tuple[jax.Array, CostFunctionParams]:
        """Computes the cost of a trajectory.

        cost = 0.5 * sum_t (x_t - xg)' @ Q @ (x_t - xg) + 0.5 * (x_N - xg)' @ Qf @ (x_N - xg)
               + 0.5 * sum_t u_t' @ R @ u_t

        Args:
            xs (shape=(N + 1, nq + nv)): The state trajectory.
            us (shape=(N, nu)): The controls over the trajectory.
            params: Unused. Included for API compliance.

        Returns:
            cost_val: The cost of the trajectory.
            new_params: Unused. Included for API compliance.
        """
        xs_err = xs[:-1, :] - self.xg  # errors before the terminal state
        xf_err = xs[-1, :] - self.xg  # terminal state error
        val = 0.5 * jnp.squeeze(
            jnp.sum(self.batch_quadform(xs_err, self.Q))
            + self.batch_quadform(xf_err, self.Qf)
            + jnp.sum(self.batch_quadform(us, self.R))
        )
        return val, params

    def grad(
        self, xs: jax.Array, us: jax.Array, params: CostFunctionParams
    ) -> Tuple[jax.Array, jax.Array, CostFunctionParams, CostFunctionParams]:
        """Computes the gradient of the cost of a trajectory.

        Args:
            xs (shape=(N + 1, nq + nv)): The state trajectory.
            us (shape=(N, nu)): The controls over the trajectory.
            params: Unused. Included for API compliance.

        Returns:
            gcost_xs (shape=(N + 1, nq + nv)): The gradient of the cost wrt xs.
            gcost_us (shape=(N, nu)): The gradient of the cost wrt us.
            gcost_params: Unused. Included for API compliance.
            new_params: Unused. Included for API compliance.
        """
        xs_err = xs[:-1, :] - self.xg  # errors before the terminal state
        xf_err = xs[-1, :] - self.xg
        gcost_xs = jnp.concatenate(
            (
                self.batch_matmul(xs_err, self.Q),
                (self.Qf @ xf_err)[None, :],
            ),
            axis=-2,
        )
        gcost_us = self.batch_matmul(us, self.R)
        return gcost_xs, gcost_us, params, params

    def hess(
        self, xs: jax.Array, us: jax.Array, params: CostFunctionParams
    ) -> Tuple[
        jax.Array, jax.Array, CostFunctionParams, jax.Array, CostFunctionParams, CostFunctionParams, CostFunctionParams
    ]:
        """Computes the Hessian of the cost of a trajectory.

        Let t, s be times 0, 1, 2, etc. Then, d^2H/(da_{t,i} db_{s,j}) = Hcost_asbs[t, i, s, j].

        Args:
            xs (shape=(N + 1, nq + nv)): The state trajectory.
            us (shape=(N, nu)): The controls over the trajectory.
            params: Unused. Included for API compliance.

        Returns:
            Hcost_xsxs (shape=(N + 1, nq + nv, N + 1, nq + nv)): The Hessian of the cost wrt xs.
            Hcost_xsus (shape=(N + 1, nq + nv, N, nu)): The Hessian of the cost wrt xs and us.
            Hcost_xsparams: The Hessian of the cost wrt xs and params.
            Hcost_usus (shape=(N, nu, N, nu)): The Hessian of the cost wrt us.
            Hcost_usparams: The Hessian of the cost wrt us and params.
            Hcost_paramsall: The Hessian of the cost wrt params and everything else.
            new_params: Unused. Included for API compliance.
        """
        # setting up
        nx = self.Q.shape[0]
        N, nu = us.shape
        Q = self.Q
        Qf = self.Qf
        R = self.R
        dummy_params = CostFunctionParams()

        # Hessian wrt the states: block diagonal, only the blocks [i, :, i, :] are nonzero
        Hcost_xsxs = vmap(
            lambda i: lax.dynamic_update_slice(jnp.zeros((nx, N + 1, nx)), Q[:, None, :], (0, i, 0))
        )(jnp.arange(N + 1))
        Hcost_xsxs = Hcost_xsxs.at[-1, :, -1, :].set(Qf)  # the terminal block uses Qf

        # trivial cross-terms of the Hessian
        Hcost_xsus = jnp.zeros((N + 1, nx, N, nu))
        Hcost_xsparams = dummy_params

        # Hessian wrt the control inputs: block diagonal, only the blocks [i, :, i, :] are nonzero
        Hcost_usus = vmap(
            lambda i: lax.dynamic_update_slice(jnp.zeros((nu, N, nu)), R[:, None, :], (0, i, 0))
        )(jnp.arange(N))

        # trivial cross-terms and Hessian for params
        Hcost_usparams = dummy_params
        Hcost_paramsall = dummy_params
        return Hcost_xsxs, Hcost_xsus, Hcost_xsparams, Hcost_usus, Hcost_usparams, Hcost_paramsall, params
```
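A brief usage sketch (hypothetical shapes and values, not from the PR; it assumes the class above is in scope) that also cross-checks the hand-written gradient against an autodiff gradient of cost, mirroring the derivative tests added in this PR:

```python
import jax.numpy as jnp
from jax import grad as jgrad

nx, nu, N = 4, 2, 10
cost_fn = StaticGoalQuadraticCost(
    Q=jnp.eye(nx), Qf=10.0 * jnp.eye(nx), R=0.1 * jnp.eye(nu), xg=jnp.ones(nx)
)

xs = jnp.zeros((N + 1, nx))  # state trajectory guess
us = jnp.zeros((N, nu))  # control guess
params = CostFunctionParams()

val, _ = cost_fn.cost(xs, us, params)  # scalar trajectory cost

# the analytic gradient should match autodiff applied to cost() (for symmetric Q, Qf)
gxs, gus, _, _ = cost_fn.grad(xs, us, params)
gxs_ad = jgrad(lambda x: cost_fn.cost(x, us, params)[0])(xs)
assert jnp.allclose(gxs, gxs_ad)
```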