# Trajopt Profiling #47
**`.gitignore`**
```
@@ -164,4 +164,7 @@ cython_debug/
 # mujoco
 MUJOCO_LOG.TXT
 mujoco
+
+# media
+*.png
```
**`README.md`** — adds a "Code Profiling" section after the existing versioning notes:

* `python=3.11.5`. `torch`, `jax`, and `mujoco` all support it and there are major reported speed improvements over `python` 3.10.
* `cuda==11.8`. Both `torch` and `jax` support `cuda12`; however, they annoyingly support different minor versions, which makes them [incompatible in the same environment](https://github.com/google/jax/issues/18032). Once this is resolved, we will upgrade to `cuda-12.2` or later. It seems most likely that `torch` will support `cuda-12.3` once they do upgrade, since that is the most recent release.
### Code Profiling
The majority of the profiling we do will be on JAX code. Here's a generic template for profiling code using `tensorboard` (you need the test dependencies):
```
# create the function you want to profile - we recommend jitting it, since this typically changes the profiling results
def fn_to_profile():
    ...

jit_fn = jit(fn_to_profile)

# choose a path to store the profiling results
with jax.profiler.trace("/dir/to/profiling/results"):
    jit_fn(inputs)
```
To view the profiling results, run
```
tensorboard --logdir=/dir/to/profiling/results --port <port>
```
where `--port` should be some open port like `8008`. In the top-right dropdown menu, which should initially read "Inactive," scroll down and select "Profile." Then select the run you'd like to analyze; under "Tools," the most useful tab is usually "trace_viewer."
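If the region you want to trace doesn't fit neatly inside a `with` block, `jax.profiler.start_trace`/`jax.profiler.stop_trace` are the equivalent imperative API; here is a minimal sketch (the matmul and log directory are placeholders, not from this repo):
```
import jax
import jax.numpy as jnp

jax.profiler.start_trace("/tmp/tensorboard")  # same directory you later point tensorboard at
x = jnp.ones((1000, 1000))
y = (x @ x).block_until_ready()  # block so the work finishes inside the trace
jax.profiler.stop_trace()
```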
Sometimes, we want to expose certain subroutines to the profiler. We can do so with the following:
```
# in one file
def fn():
    # stuff that we don't want to profile
    fn1()

    # stuff we do want to specifically profile
    with jax.named_scope("name_of_your_choice"):
        fn2()

# in another file containing the jitted function to profile
jit_fn = jit(fn)
with jax.profiler.trace("/dir/to/profiling/results"):
    jit_fn()
```
Now, the traced results will specifically show the time spent in `fn2` under the name you chose. Note that you can also use `jax.profiler.TraceAnnotation` or `jax.profiler.annotate_function()` instead, [as recommended](https://jax.readthedocs.io/en/latest/profiling.html#adding-custom-trace-events).
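For reference, a minimal sketch of those two alternatives (the label strings are arbitrary, and `fn2` here is a stand-in for your own subroutine; run it under `jax.profiler.trace` as above for the events to be recorded):
```
from functools import partial

import jax
import jax.numpy as jnp


def fn2():  # stand-in for the subroutine you want to label
    return jnp.ones(8).sum()


# context-manager form: label a block of code
with jax.profiler.TraceAnnotation("fn2_block"):
    out = fn2()


# decorator form: label every call to the function
@partial(jax.profiler.annotate_function, name="fn2_call")
def wrapped_fn2():
    return fn2()
```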
### Tooling
We use various tools to ensure code quality.

**New file** (@@ -0,0 +1,99 @@) — a quick and dirty script timing the predictive sampler against the number of samples:
```
import timeit

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from jax import jit
from mujoco import mjx
from mujoco.mjx._src.types import DisableBit

from ambersim.trajopt.cost import CostFunction, StaticGoalQuadraticCost
from ambersim.trajopt.shooting import VanillaPredictiveSampler, VanillaPredictiveSamplerParams
from ambersim.utils.io_utils import load_mjx_model_and_data_from_file


def make_ps(model: mjx.Model, cost_function: CostFunction, nsamples: int) -> VanillaPredictiveSampler:
    """Makes a predictive sampler for this quick and dirty timing script."""
    stdev = 0.01
    ps = VanillaPredictiveSampler(model=model, cost_function=cost_function, nsamples=nsamples, stdev=stdev)
    return ps


if __name__ == "__main__":
    # initializing the model
    model, _ = load_mjx_model_and_data_from_file("models/barrett_hand/bh280.xml", force_float=False)
    model = model.replace(
        opt=model.opt.replace(
            timestep=0.002,  # dt
            iterations=1,  # number of Newton steps to take during solve
            ls_iterations=4,  # number of line search iterations along step direction
            integrator=0,  # Euler semi-implicit integration
            solver=2,  # Newton solver
            disableflags=DisableBit.CONTACT,  # [IMPORTANT] disable contact for this example
        )
    )

    # initializing the cost function
    cost_function = StaticGoalQuadraticCost(
        Q=jnp.eye(model.nq + model.nv),
        Qf=10.0 * jnp.eye(model.nq + model.nv),
        R=0.01 * jnp.eye(model.nu),
        # qg=jnp.zeros(model.nq).at[6].set(1.0),  # if force_float=True
        qg=jnp.zeros(model.nq),
        vg=jnp.zeros(model.nv),
    )

    # sampler parameters we pass in independent of the number of samples
    key = jax.random.PRNGKey(0)  # random seed for the predictive sampler
    q0 = jnp.zeros(model.nq).at[6].set(1.0)
    v0 = jnp.zeros(model.nv)
    num_steps = 10
    us_guess = jnp.zeros((num_steps, model.nu))
    params = VanillaPredictiveSamplerParams(key=key, q0=q0, v0=v0, us_guess=us_guess)

    nsamples_list = [1, 10, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000, 2000, 3000, 4000, 5000, 10000]
    runtimes = []
    throughputs = []
    for nsamples in nsamples_list:
        print(f"Running with nsamples={nsamples}...")
        ps = make_ps(model, cost_function, nsamples)
        optimize_fn = jit(ps.optimize)

        # # [DEBUG] profiling with tensorboard
        # qs_star, vs_star, us_star = optimize_fn(params)  # JIT compiling
        # with jax.profiler.trace("/home/albert/tensorboard"):
        #     qs_star, vs_star, us_star = optimize_fn(params)  # after JIT

        def _time_fn(fn=optimize_fn) -> None:
            """Function to time runtime."""
            qs_star, vs_star, us_star = fn(params)
            qs_star.block_until_ready()
            vs_star.block_until_ready()
            us_star.block_until_ready()

        compile_time = timeit.timeit(_time_fn, number=1)
        print(f"  Compile time: {compile_time}")

        num_timing_iters = 100
        time = timeit.timeit(_time_fn, number=num_timing_iters)
        print(f"  Avg. runtime: {time / num_timing_iters}")  # timeit returns TOTAL time, so compute the average ourselves

        runtimes.append(time / num_timing_iters)
        throughputs.append(nsamples / (time / num_timing_iters))

    plt.scatter(np.array(nsamples_list), np.array(runtimes))
    plt.xlabel("number of samples")
    plt.ylabel("runtime (s)")
    plt.title("Predictive Sampling: Number of Samples vs. Runtime")
    plt.xlim([-100, max(nsamples_list) + 100])
    plt.ylim([0, max(runtimes) + 0.01])
    plt.show()

    plt.scatter(np.array(nsamples_list), np.array(throughputs))
    plt.xlabel("number of samples")
    plt.ylabel("throughput (samples/s)")
    plt.title("Predictive Sampling: Sampling Throughput vs. Number of Samples")
    plt.xlim([-100, max(nsamples_list) + 100])
    plt.ylim([0, max(throughputs) + 10000])
    plt.show()
```
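As an aside, if separating compile time from runtime ever needs to be more explicit than the warm-up call the script uses, JAX's ahead-of-time API can do it; here is a minimal sketch with placeholder shapes (not part of this script):
```
import jax
import jax.numpy as jnp


def f(x):
    return (x @ x).sum()


x = jnp.ones((512, 512))
compiled = jax.jit(f).lower(x).compile()  # compilation cost is paid here, once
result = compiled(x).block_until_ready()  # subsequent calls are pure runtime
```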
**Conda environment file** — bump the CUDA channel:
```
@@ -1,5 +1,5 @@
 channels:
-- nvidia/label/cuda-11.8.0
+- nvidia/label/cuda-12.3.0
 - conda-forge
 dependencies:
 - cuda
```
**`pyproject.toml`**
```
@@ -18,15 +18,15 @@ dependencies = [
     "coacd>=1.0.0",
     "dm_control>=1.0.0",
     "flax>=0.7.5",
-    "jax[cuda11_local]>=0.4.1",
+    "jax[cuda12_local]>=0.4.1",
     "jaxlib>=0.4.1",
     "matplotlib>=3.5.2",
     "mujoco>=3.0.0",
     "mujoco-mjx>=3.0.0",
     "numpy>=1.23.1",
     "scipy>=1.10.0",
-    "torch>=1.13.1",
-    "tensorboard>=2.15.1",
+    # "torch>=1.13.1",
+    "tensorboard>=2.13.0",  # [Dec. 4, 2023] https://github.com/tensorflow/tensorflow/issues/62075#issuecomment-1808652131
 ]

 [project.optional-dependencies]
@@ -43,10 +43,12 @@ dev = [

 # Test-specific packages for verification
 test = [
-    "cvxpy>=1.4.1",
-    "drake>=1.21.0",
+    # "cvxpy>=1.4.1",
+    # "drake>=1.21.0",
     "libigl>=2.4.0",
     "pin>=2.6.20",
+    "tensorflow>=2.13.0",  # [Dec. 4, 2023] https://github.com/tensorflow/tensorflow/issues/62075#issuecomment-1808652131
+    "tensorboard-plugin-profile>=2.13.0",
 ]

 # All packages
```
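Since the profiling instructions above rely on the `test` extras, one way to pull them in is an editable install from the repo root, e.g. `pip install -e ".[test]"` (the exact install command depends on your setup).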
**Review comment:** You might consider adding some screenshots for folks who are unfamiliar with flame charts.