Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 107 additions & 12 deletions learning/train_jax_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,16 @@
_LOAD_CHECKPOINT_PATH = flags.DEFINE_string(
"load_checkpoint_path", None, "Path to load checkpoint from"
)
_SAVE_PARAMS_PATH = flags.DEFINE_string(
"save_params_path", None, "Path to save parameters to"
)
_SUFFIX = flags.DEFINE_string("suffix", None, "Suffix for the experiment name")
_PLAY_ONLY = flags.DEFINE_boolean(
"play_only", False, "If true, only play with the model and do not train"
)
_RENDER_FINAL_POLICY = flags.DEFINE_boolean(
"render_final_policy", True, "If true, render the final policy"
)
_USE_WANDB = flags.DEFINE_boolean(
"use_wandb",
False,
Expand Down Expand Up @@ -132,6 +138,33 @@
"policy_obs_key", "state", "Policy obs key"
)
_VALUE_OBS_KEY = flags.DEFINE_string("value_obs_key", "state", "Value obs key")
_RSCOPE_ENVS = flags.DEFINE_integer(
"rscope_envs",
None,
"Number of parallel environment rollouts to save for the rscope viewer",
)
_DETERMINISTIC_RSCOPE = flags.DEFINE_boolean(
"deterministic_rscope",
True,
"Run deterministic rollouts for the rscope viewer",
)
_RUN_EVALS = flags.DEFINE_boolean(
"run_evals",
True,
"Run evaluation rollouts between policy updates.",
)
_LOG_TRAINING_METRICS = flags.DEFINE_boolean(
"log_training_metrics",
False,
"Whether to log training metrics and callback to progress_fn. Significantly"
" slows down training if too frequent.",
)
_TRAINING_METRICS_STEPS = flags.DEFINE_integer(
"training_metrics_steps",
1_000_000,
"Number of steps between logging training metrics. Increase if training"
" experiences slowdown.",
)


def get_rl_config(env_name: str) -> config_dict.ConfigDict:
Expand All @@ -151,6 +184,24 @@ def get_rl_config(env_name: str) -> config_dict.ConfigDict:
raise ValueError(f"Env {env_name} not found in {registry.ALL_ENVS}.")


def rscope_fn(full_states, obs, rew, done):
"""
All arrays are of shape (unroll_length, rscope_envs, ...)
full_states: dict with keys 'qpos', 'qvel', 'time', 'metrics'
obs: nd.array or dict obs based on env configuration
rew: nd.array rewards
done: nd.array done flags
"""
# Calculate cumulative rewards per episode, stopping at first done flag
done_mask = jp.cumsum(done, axis=0)
valid_rewards = rew * (done_mask == 0)
episode_rewards = jp.sum(valid_rewards, axis=0)
print(
"Collected rscope rollouts with reward"
f" {episode_rewards.mean():.3f} +- {episode_rewards.std():.3f}"
)


def main(argv):
"""Run training and evaluation for the specified environment."""

Expand Down Expand Up @@ -209,11 +260,16 @@ def main(argv):
ppo_params.network_factory.policy_obs_key = _POLICY_OBS_KEY.value
if _VALUE_OBS_KEY.present:
ppo_params.network_factory.value_obs_key = _VALUE_OBS_KEY.value

if _VISION.value:
env_cfg.vision = True
env_cfg.vision_config.render_batch_size = ppo_params.num_envs
env = registry.load(_ENV_NAME.value, config=env_cfg)
if _RUN_EVALS.present:
ppo_params.run_evals = _RUN_EVALS.value
if _LOG_TRAINING_METRICS.present:
ppo_params.log_training_metrics = _LOG_TRAINING_METRICS.value
if _TRAINING_METRICS_STEPS.present:
ppo_params.training_metrics_steps = _TRAINING_METRICS_STEPS.value

print(f"Environment Config:\n{env_cfg}")
print(f"PPO Training Parameters:\n{ppo_params}")
Expand Down Expand Up @@ -260,21 +316,17 @@ def main(argv):
restore_checkpoint_path = None

# Set up checkpoint directory
ckpt_path = logdir / "checkpoints"
if _SAVE_PARAMS_PATH.value is not None:
ckpt_path = epath.Path(_SAVE_PARAMS_PATH.value).resolve() / "checkpoints"
else:
ckpt_path = logdir / "checkpoints"
ckpt_path.mkdir(parents=True, exist_ok=True)
print(f"Checkpoint path: {ckpt_path}")

# Save environment configuration
with open(ckpt_path / "config.json", "w", encoding="utf-8") as fp:
json.dump(env_cfg.to_dict(), fp, indent=4)

# Define policy parameters function for saving checkpoints
def policy_params_fn(current_step, make_policy, params): # pylint: disable=unused-argument
orbax_checkpointer = ocp.PyTreeCheckpointer()
save_args = orbax_utils.save_args_from_target(params)
path = ckpt_path / f"{current_step}"
orbax_checkpointer.save(path, params, force=True, save_args=save_args)

training_params = dict(ppo_params)
if "network_factory" in training_params:
del training_params["network_factory"]
Expand Down Expand Up @@ -319,9 +371,9 @@ def policy_params_fn(current_step, make_policy, params): # pylint: disable=unus
ppo.train,
**training_params,
network_factory=network_factory,
policy_params_fn=policy_params_fn,
seed=_SEED.value,
restore_checkpoint_path=restore_checkpoint_path,
save_checkpoint_path=ckpt_path,
wrap_env_fn=None if _VISION.value else wrapper.wrap_for_brax_training,
num_eval_envs=num_eval_envs,
)
Expand All @@ -341,18 +393,55 @@ def progress(num_steps, metrics):
for key, value in metrics.items():
writer.add_scalar(key, value, num_steps)
writer.flush()

print(f"{num_steps}: reward={metrics['eval/episode_reward']:.3f}")
if _RUN_EVALS.value:
print(f"{num_steps}: reward={metrics['eval/episode_reward']:.3f}")
if _LOG_TRAINING_METRICS.value:
if "episode/sum_reward" in metrics:
print(
f"{num_steps}: mean episode"
f" reward={metrics['episode/sum_reward']:.3f}"
)

# Load evaluation environment
eval_env = (
None if _VISION.value else registry.load(_ENV_NAME.value, config=env_cfg)
)

policy_params_fn = lambda *args: None
if _RSCOPE_ENVS.value:
# Interactive visualisation of policy checkpoints
from rscope import brax as rscope_utils

if not _VISION.value:
rscope_env = registry.load(_ENV_NAME.value, config=env_cfg)
rscope_env = wrapper.wrap_for_brax_training(
rscope_env,
episode_length=ppo_params.episode_length,
action_repeat=ppo_params.action_repeat,
randomization_fn=training_params.get("randomization_fn"),
)
else:
rscope_env = env

rscope_handle = rscope_utils.BraxRolloutSaver(
rscope_env,
ppo_params,
_VISION.value,
_RSCOPE_ENVS.value,
_DETERMINISTIC_RSCOPE.value,
jax.random.PRNGKey(_SEED.value),
rscope_fn,
)

def policy_params_fn(current_step, make_policy, params): # pylint: disable=unused-argument
rscope_handle.set_make_policy(make_policy)
rscope_handle.dump_rollout(params)

# Train or load the model
make_inference_fn, params, _ = train_fn( # pylint: disable=no-value-for-parameter
environment=env,
progress_fn=progress,
policy_params_fn=policy_params_fn,
eval_env=None if _VISION.value else eval_env,
)

Expand All @@ -361,13 +450,19 @@ def progress(num_steps, metrics):
print(f"Time to JIT compile: {times[1] - times[0]}")
print(f"Time to train: {times[-1] - times[1]}")

if not _RENDER_FINAL_POLICY.value:
return

print("Starting inference...")

# Create inference function
inference_fn = make_inference_fn(params, deterministic=True)
jit_inference_fn = jax.jit(inference_fn)

# Prepare for evaluation
eval_env = (
None if _VISION.value else registry.load(_ENV_NAME.value, config=env_cfg)
)
num_envs = 1
if _VISION.value:
eval_env = env
Expand Down
1 change: 1 addition & 0 deletions mujoco_playground/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from mujoco_playground._src.mjx_env import render_array
from mujoco_playground._src.mjx_env import State
from mujoco_playground._src.mjx_env import step

# pylint: enable=g-importing-member

__all__ = [
Expand Down
4 changes: 3 additions & 1 deletion mujoco_playground/_src/dm_control_suite/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,8 @@ def load(
An instance of the environment.
"""
if env_name not in _envs:
raise ValueError(f"Env '{env_name}' not found. Available envs: {_cfgs.keys()}")
raise ValueError(
f"Env '{env_name}' not found. Available envs: {_cfgs.keys()}"
)
config = config or get_default_config(env_name)
return _envs[env_name](config=config, config_overrides=config_overrides)
4 changes: 3 additions & 1 deletion mujoco_playground/_src/locomotion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,9 @@ def load(
An instance of the environment.
"""
if env_name not in _envs:
raise ValueError(f"Env '{env_name}' not found. Available envs: {_cfgs.keys()}")
raise ValueError(
f"Env '{env_name}' not found. Available envs: {_cfgs.keys()}"
)
config = config or get_default_config(env_name)
return _envs[env_name](config=config, config_overrides=config_overrides)

Expand Down
3 changes: 1 addition & 2 deletions mujoco_playground/_src/locomotion/t1/randomize.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from mujoco import mjx
import numpy as np


FLOOR_GEOM_ID = 0
TORSO_BODY_ID = 1
ANKLE_JOINT_IDS = np.array([[21, 22, 27, 28]])
Expand All @@ -30,7 +29,7 @@ def rand_dynamics(rng):
# Floor friction: =U(0.4, 1.0).
rng, key = jax.random.split(rng)
geom_friction = model.geom_friction.at[FLOOR_GEOM_ID, 0].set(
jax.random.uniform(key, minval=0.2, maxval=.6)
jax.random.uniform(key, minval=0.2, maxval=0.6)
)

rng, key = jax.random.split(rng)
Expand Down
18 changes: 14 additions & 4 deletions mujoco_playground/_src/manipulation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
from mujoco import mjx

from mujoco_playground._src import mjx_env
from mujoco_playground._src.manipulation.aloha import distillation as aloha_distillation
from mujoco_playground._src.manipulation.aloha import handover as aloha_handover
from mujoco_playground._src.manipulation.aloha import single_peg_insertion as aloha_peg
from mujoco_playground._src.manipulation.aloha import peg_insertion as aloha_peg_insertion
from mujoco_playground._src.manipulation.aloha import pick as aloha_pick
from mujoco_playground._src.manipulation.franka_emika_panda import open_cabinet as panda_open_cabinet
from mujoco_playground._src.manipulation.franka_emika_panda import pick as panda_pick
from mujoco_playground._src.manipulation.franka_emika_panda import pick_cartesian as panda_pick_cartesian
Expand All @@ -31,7 +33,9 @@

_envs = {
"AlohaHandOver": aloha_handover.HandOver,
"AlohaSinglePegInsertion": aloha_peg.SinglePegInsertion,
"AlohaPick": aloha_pick.Pick,
"AlohaPegInsertion": aloha_peg_insertion.SinglePegInsertion,
"AlohaPegInsertionDistill": aloha_distillation.DistillPegInsertion,
"PandaPickCube": panda_pick.PandaPickCube,
"PandaPickCubeOrientation": panda_pick.PandaPickCubeOrientation,
"PandaPickCubeCartesian": panda_pick_cartesian.PandaPickCubeCartesian,
Expand All @@ -43,7 +47,9 @@

_cfgs = {
"AlohaHandOver": aloha_handover.default_config,
"AlohaSinglePegInsertion": aloha_peg.default_config,
"AlohaPick": aloha_pick.default_config,
"AlohaPegInsertion": aloha_peg_insertion.default_config,
"AlohaPegInsertionDistill": aloha_distillation.default_config,
"PandaPickCube": panda_pick.default_config,
"PandaPickCubeOrientation": panda_pick.default_config,
"PandaPickCubeCartesian": panda_pick_cartesian.default_config,
Expand All @@ -56,6 +62,8 @@
_randomizer = {
"LeapCubeRotateZAxis": leap_rotate_z.domain_randomize,
"LeapCubeReorient": leap_cube_reorient.domain_randomize,
"AlohaPick": aloha_pick.domain_randomize,
"AlohaPegInsertionDistill": aloha_distillation.domain_randomize,
}


Expand Down Expand Up @@ -108,7 +116,9 @@ def load(
An instance of the environment.
"""
if env_name not in _envs:
raise ValueError(f"Env '{env_name}' not found. Available envs: {_cfgs.keys()}")
raise ValueError(
f"Env '{env_name}' not found. Available envs: {_cfgs.keys()}"
)
config = config or get_default_config(env_name)
return _envs[env_name](config=config, config_overrides=config_overrides)

Expand Down
84 changes: 84 additions & 0 deletions mujoco_playground/_src/manipulation/aloha/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
### Quickstart


**Pre-requisites**

- *Handover, Pick, Peg Insertion:* The standard Playground setup
- *Behaviour Cloning for Peg Insertion:* Madrona MJX
- *Jax-to-ONNX Conversion:* Onnx, Tensorflow, tf2onnx

```bash
# Train Aloha Handover. Documentation at https://github.com/google-deepmind/mujoco_playground/pull/29
python learning/train_jax_ppo.py --env_name AlohaHandOver
```

```bash
# Plots for pick and peg-insertion at https://github.com/google-deepmind/mujoco_playground/pull/76
cd <PATH_TO_YOUR_CLONE>
export PARAMS_PATH=mujoco_playground/_src/manipulation/aloha/params

# Train a single arm to pick up a cube.
python learning/train_jax_ppo.py --env_name AlohaPick --domain_randomization --norender_final_policy --save_params_path $PARAMS_PATH/AlohaPick
sleep 0.5

# Train a biarm to insert a peg into a socket. Requires above policy.
python learning/train_jax_ppo.py --env_name AlohaPegInsertion --save_params_path $PARAMS_PATH/AlohaPegInsertion
sleep 0.5

# Train a student policy to insert a peg into a socket using *pixel inputs*. Requires above policy.
python mujoco_playground/experimental/bc_peg_insertion.py --domain-randomization --num-evals 0 --print-loss

# Convert checkpoints from the above run to ONNX for easy robot deployment.
# ONNX policies are written to `experimental/jax2onnx/onnx_policies`.
python mujoco_playground/experimental/jax2onnx/aloha_nets_to_onnx.py --checkpoint_path <YOUR_DISTILL_CHECKPOINT_DIR>
```

### Sim-to-Real Transfer of a Bi-Arm RL Policy via Pixel-Based Behaviour Cloning

https://github.com/user-attachments/assets/205fe8b9-1773-4715-8025-5de13490d0da

---

**Distillation**

In this module, we demonstrate policy distillation: a straightforward method for deploying a simulation-trained reinforcement learning policy that initially uses privileged state observations (such as object positions). The process involves two steps:

1. **Teacher Policy Training:** A state-based teacher policy is trained using RL.
2. **Student Policy Distillation:** The teacher is then distilled into a student policy via behaviour cloning (BC), where the student learns to map its observations $o_s(x)$ (e.g., exteroceptive RGBD images) to the teacher’s deterministic actions $\pi_t(o_t(x))$. For example, while both policies observe joint angles, the student uses RGBD images, whereas the teacher directly accesses (noisy) object positions.

The distillation process—where the student uses left and right wrist-mounted RGBD cameras for exteroception—takes about **3 minutes** on an RTX4090. This rapid turnaround is due to three factors:

1. [Very fast rendering](https://github.com/google-deepmind/mujoco_playground/blob/main/mujoco_playground/experimental/madrona_benchmarking/figures/cartpole_benchmark_full.png) provided by Madrona MJX.
2. The sample efficiency of behaviour cloning.
3. The use of low-resolution (32×32) rendering, which is sufficient for precise alignment given the wrist camera placement.

For further details on the teacher policy and RGBD sim-to-real techniques, please refer to the [technical report](https://docs.google.com/presentation/d/1v50Vg-SJdy5HV5JmPHALSwph9mcVI2RSPRdrxYR3Bkg/edit?usp=sharing).

---

**A Note on Sample Efficiency**

Behaviour cloning (BC) can be orders of magnitude more sample-efficient than reinforcement learning. In our approach, we use an L2 loss defined as:

$|| \pi_s(o_s(x)) - \pi_t(o_t(x)) ||$

In contrast, the policy gradient in RL generally takes the form:

![Equation](https://latex.codecogs.com/svg.latex?\nabla_\theta%20J(\theta)%20=%20\mathbb{E}_{\tau%20\sim%20\theta}%20\left[\sum_t%20\nabla_\theta%20\log%20\pi_\theta(a_t%20|%20s_t)%20R(\tau)\right])

Two key observations highlight why BC’s direct supervision is more efficient:

- **Explicit Loss Signal:** The BC loss compares against the teacher action, giving explicit feedback on how the action should be adjusted. In contrast, the policy gradient only provides directional guidance, instructing the optimizer to increase or decrease an action’s likelihood based solely on its downstream rewards.
- **Per-Dimension Supervision:** While the policy gradient applies a uniform weighting across all action dimensions, BC supplies per-dimension information, making it easier to scale to high-dimensional action spaces.

---

**Frozen Encoders**

*VisionMLP2ChanCIFAR10_OCP* is an Orbax checkpoint of [NatureCNN](https://github.com/google/brax/blob/241f9bc5bbd003f9cfc9ded7613388e2fe125af6/brax/training/networks.py#L153) (AtariCNN) pre-trained on CIFAR10 to achieve over 70% classification accuracy. We omit the supervised training code, see [this tutorial](https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/JAX/tutorial5/Inception_ResNet_DenseNet.html) for reference.

---

**Aloha Deployment Setup**

For deployment, the ONNX policy is executed on the Aloha robot using a custom fork of [OpenPI](https://github.com/Physical-Intelligence/openpi) along with the Interbotix Aloha ROS packages. Acknowledgements to Kevin Zakka, Laura Smith and the Levine Lab for robot deployment setup!
Loading