Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
49267de
extremely generic API for trajectory optimization
alberthli Dec 2, 2023
1035dd4
fix spacing
alberthli Dec 2, 2023
22f9867
remove unnecessary __init__ method (brainfart)
alberthli Dec 2, 2023
b8d7e8e
some additional massaging for best abstraction + useless scaffolding …
alberthli Dec 2, 2023
f5610fc
add some informative comments to base.py
alberthli Dec 3, 2023
17642f6
[maybe] added cost function to base API
alberthli Dec 3, 2023
4e1ae77
shooting method APIs, no cost function implemented yet
alberthli Dec 3, 2023
8c9cd6b
define generic CostFunction API
alberthli Dec 3, 2023
0d175f0
expose a CostFunction field of shooting methods
alberthli Dec 3, 2023
987e334
add specific implementation of (non-sparse) quadratic CostFunction wi…
alberthli Dec 3, 2023
2b268f8
fixed missing field in docstrings
alberthli Dec 4, 2023
13efcfd
fixed missing field in docstrings
alberthli Dec 4, 2023
b6ea473
pass non-kwarg functions to vmap since vmap breaks in that case
alberthli Dec 4, 2023
6bf291d
added preliminary dead simple predictive sampling example
alberthli Dec 4, 2023
2cfdbfd
minor, get from property
alberthli Dec 4, 2023
dc5d247
fixed some params + simplify sampling of us to one line
alberthli Dec 4, 2023
ff6dbcc
[DIRTY] commit that contains commented out code for pre-allocating da…
alberthli Dec 4, 2023
9f12d74
revert to allocating data at shooting time
alberthli Dec 4, 2023
3fda8ba
[DIRTY] some additions to the example script for profiling help
alberthli Dec 4, 2023
20687cb
added dependencies for profiling
alberthli Dec 5, 2023
441c688
README instructions updated to include profiling guidelines
alberthli Dec 5, 2023
2e6c649
moved timing script to benchmarks
alberthli Dec 5, 2023
288cfb4
remove accidentally committed pngs
alberthli Dec 5, 2023
68fdbbd
delete some comments
alberthli Dec 5, 2023
8f32b49
remove some dependencies to upgrade to cuda 12
alberthli Dec 5, 2023
94409c4
run code checks when ready for review
alberthli Dec 5, 2023
b51d2de
pray that tensorflow 2.13 passes tests
alberthli Dec 5, 2023
6a39ae2
quick gutcheck to see what is causing test failure
alberthli Dec 5, 2023
f164fe5
sanity check 2: whether pinning to an older version causes install er…
alberthli Dec 5, 2023
3d121ca
typo
alberthli Dec 5, 2023
0f507e3
refactor API to take xs instead of qs and vs
alberthli Dec 5, 2023
b2d8296
tests for cost function and its derivatives
alberthli Dec 5, 2023
0e24057
added smoke test for vanilla predictive sampler
alberthli Dec 5, 2023
adf4243
remove the example since it's implemented as a benchmark in an upstre…
alberthli Dec 5, 2023
773be6a
minor docstring edit
alberthli Dec 5, 2023
039a4f2
Merge branch 'trajopt-api' of github.com:Caltech-AMBER/ambersim into …
alberthli Dec 5, 2023
544f964
ensure predictive sampler also accounts for cost of guess vs. samples
alberthli Dec 5, 2023
4e22073
add sanity check for predictive sampling
alberthli Dec 5, 2023
a4da8e8
Merge branch 'trajopt-api' of github.com:Caltech-AMBER/ambersim into …
alberthli Dec 5, 2023
0c49505
added fixture to tests
alberthli Dec 5, 2023
90ea8f4
fix stray parenthetical
alberthli Dec 5, 2023
e2cc740
Merge branch 'trajopt-api' of github.com:Caltech-AMBER/ambersim into …
alberthli Dec 5, 2023
bd3add2
update README + pull upstream changes
alberthli Dec 5, 2023
1236f5f
Merge branch 'main' of github.com:Caltech-AMBER/ambersim into trajopt…
alberthli Dec 5, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/code_checks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ on:
branches: [main]
pull_request:
branches: [main]
types: [ready_for_review]

permissions:
contents: read
Expand Down
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -164,4 +164,7 @@ cython_debug/

# mujoco
MUJOCO_LOG.TXT
mujoco
mujoco

# media
*.png
37 changes: 37 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,43 @@ Major versioning decisions:
* `python=3.11.5`. `torch`, `jax`, and `mujoco` all support it and there are major reported speed improvements over `python` 3.10.
* `cuda==11.8`. Both `torch` and `jax` support `cuda12`; however, they annoyingly support different minor versions which makes them [incompatible in the same environment](https://github.com/google/jax/issues/18032). Once this is resolved, we will upgrade to `cuda-12.2` or later. It seems most likely that `torch` will support `cuda-12.3` once they do upgrade, since that is the most recent release.

### Code Profiling
The majority of the profiling we do will be on JAX code. Here's a generic template for profiling code using `tensorboard` (you need the test dependencies):
```
# create the function you want to profile - we recommend jitting it, since this typically changes the profiling results
def fn_to_profile():
...

jit_fn = jit(fn_to_profile)

# choose a path to store the profiling results
with jax.profiler.trace("/dir/to/profiling/results"):
jit_fn(inputs)
```
To view the profiling results, run
```
tensorboard --logdir=/dir/to/profiling/results --port <port>
```
where `--port` should be some open port like `8008`. In the top right dropdown menu which should say "Inactive," scroll down and select "Profile." Select the run you'd like to analyze and under tools, the most useful tab will usually be "trace_viewer."
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You might consider adding some screenshots for folks who are unfamiliar with flame charts


Sometimes, we want to expose certain subroutines to the profiler. We can do so with the following:
```
# in one file
def fn():
# stuff that we don't want to profile
fn1()

# stuff we do want to specifically profile
with jax.named_scope("name_of_your_choice"):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Quick q: I know TraceAnnotation didn't work here but it's maybe good to note this here since the jax profiling tool recommends it

fn2()

# in another file containing the jitted function to profile
jit_fn = jit(fn)
with jax.profiler.trace("/dir/to/profiling/results"):
jit_fn()
```
Now, the traced results will specifically show the time spent in `fn2` under the name you chose. Note that you can also use `jax.profiler.TraceAnnotation` or `jax.profiler.annotate_function()` instead, [as recommended](https://jax.readthedocs.io/en/latest/profiling.html#adding-custom-trace-events).

### Tooling
We use various tools to ensure code quality.

Expand Down
99 changes: 99 additions & 0 deletions benchmarks/trajopt/bm_predictive_sampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import timeit

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from jax import jit
from mujoco import mjx
from mujoco.mjx._src.types import DisableBit

from ambersim.trajopt.cost import CostFunction, StaticGoalQuadraticCost
from ambersim.trajopt.shooting import VanillaPredictiveSampler, VanillaPredictiveSamplerParams
from ambersim.utils.io_utils import load_mjx_model_and_data_from_file


def make_ps(model: mjx.Model, cost_function: CostFunction, nsamples: int) -> VanillaPredictiveSampler:
"""Makes a predictive sampler for this quick and dirty timing script."""
stdev = 0.01
ps = VanillaPredictiveSampler(model=model, cost_function=cost_function, nsamples=nsamples, stdev=stdev)
return ps


if __name__ == "__main__":
# initializing the model
model, _ = load_mjx_model_and_data_from_file("models/barrett_hand/bh280.xml", force_float=False)
model = model.replace(
opt=model.opt.replace(
timestep=0.002, # dt
iterations=1, # number of Newton steps to take during solve
ls_iterations=4, # number of line search iterations along step direction
integrator=0, # Euler semi-implicit integration
solver=2, # Newton solver
disableflags=DisableBit.CONTACT, # [IMPORTANT] disable contact for this example
)
)

# initializing the cost function
cost_function = StaticGoalQuadraticCost(
Q=jnp.eye(model.nq + model.nv),
Qf=10.0 * jnp.eye(model.nq + model.nv),
R=0.01 * jnp.eye(model.nu),
# qg=jnp.zeros(model.nq).at[6].set(1.0), # if force_float=True
qg=jnp.zeros(model.nq),
vg=jnp.zeros(model.nv),
)

# sampler parameters we pass in independent of the number of samples
key = jax.random.PRNGKey(0) # random seed for the predictive sampler
q0 = jnp.zeros(model.nq).at[6].set(1.0)
v0 = jnp.zeros(model.nv)
num_steps = 10
us_guess = jnp.zeros((num_steps, model.nu))
params = VanillaPredictiveSamplerParams(key=key, q0=q0, v0=v0, us_guess=us_guess)

nsamples_list = [1, 10, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000, 2000, 3000, 4000, 5000, 10000]
runtimes = []
throughputs = []
for nsamples in nsamples_list:
print(f"Running with nsamples={nsamples}...")
ps = make_ps(model, cost_function, nsamples)
optimize_fn = jit(ps.optimize)

# # [DEBUG] profiling with tensorboard
# qs_star, vs_star, us_star = optimize_fn(params) # JIT compiling
# with jax.profiler.trace("/home/albert/tensorboard"):
# qs_star, vs_star, us_star = optimize_fn(params) # after JIT

def _time_fn(fn=optimize_fn) -> None:
"""Function to time runtime."""
qs_star, vs_star, us_star = fn(params)
qs_star.block_until_ready()
vs_star.block_until_ready()
us_star.block_until_ready()

compile_time = timeit.timeit(_time_fn, number=1)
print(f" Compile time: {compile_time}")

num_timing_iters = 100
time = timeit.timeit(_time_fn, number=num_timing_iters)
print(f" Avg. runtime: {time / num_timing_iters}") # returns TOTAL time, so compute the average ourselves

runtimes.append(time / num_timing_iters)
throughputs.append(nsamples / (time / num_timing_iters))

plt.scatter(np.array(nsamples_list), np.array(runtimes))
plt.xlabel("number of samples")
plt.ylabel("runtime (s)")
plt.title("Predictive Sampling: Number of Samples vs. Runtime")
plt.xlim([-100, max(nsamples_list) + 100])
plt.ylim([0, max(runtimes) + 0.01])
plt.show()

plt.scatter(np.array(nsamples_list), np.array(throughputs))
plt.xlabel("number of samples")
plt.ylabel("samples per second (s)")
plt.title("Predictive Sampling: Sampling Throughput vs. Number of Samples")
plt.xlim([-100, max(nsamples_list) + 100])
plt.ylim([0, max(throughputs) + 10000])
plt.show()
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
channels:
- nvidia/label/cuda-11.8.0
- nvidia/label/cuda-12.3.0
- conda-forge
dependencies:
- cuda
Expand Down
12 changes: 7 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@ dependencies = [
"coacd>=1.0.0",
"dm_control>=1.0.0",
"flax>=0.7.5",
"jax[cuda11_local]>=0.4.1",
"jax[cuda12_local]>=0.4.1",
"jaxlib>=0.4.1",
"matplotlib>=3.5.2",
"mujoco>=3.0.0",
"mujoco-mjx>=3.0.0",
"numpy>=1.23.1",
"scipy>=1.10.0",
"torch>=1.13.1",
"tensorboard>=2.15.1",
# "torch>=1.13.1",
"tensorboard>=2.13.0", # [Dec. 4, 2023] https://github.com/tensorflow/tensorflow/issues/62075#issuecomment-1808652131
Copy link
Contributor Author

@alberthli alberthli Dec 5, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[NOTE] As it turns out, if I pin the version to 2.13.0, the code check workflow fails. However, we know that things are broken with the most recent tensorflow version, so we should probably not merge this yet.

]

[project.optional-dependencies]
Expand All @@ -43,10 +43,12 @@ dev = [

# Test-specific packages for verification
test = [
"cvxpy>=1.4.1",
"drake>=1.21.0",
# "cvxpy>=1.4.1",
# "drake>=1.21.0",
"libigl>=2.4.0",
"pin>=2.6.20",
"tensorflow>=2.13.0", # [Dec. 4, 2023] https://github.com/tensorflow/tensorflow/issues/62075#issuecomment-1808652131
"tensorboard-plugin-profile>=2.13.0",
Comment on lines +50 to +51
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See above comment about pinning versions.

]

# All packages
Expand Down