
[NOMERG] PPO-Myo #1514


Draft · wants to merge 18 commits into main
14 changes: 13 additions & 1 deletion examples/distributed/collectors/single_machine/generic.py
@@ -32,6 +32,7 @@
from torchrl.collectors.distributed import DistributedDataCollector
from torchrl.envs import EnvCreator, ParallelEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.libs.robohive import RoboHiveEnv

parser = ArgumentParser()
parser.add_argument(
@@ -80,6 +81,16 @@
default="ALE/Pong-v5",
help="Gym environment to be run.",
)
LIBS = {
"gym": GymEnv,
"robohive": RoboHiveEnv,
}
parser.add_argument(
"--lib",
default="gym",
help="Lib backend",
choices=list(LIBS.keys()),
)
if __name__ == "__main__":
args = parser.parse_args()
num_workers = args.num_workers
@@ -89,7 +100,8 @@

device_count = torch.cuda.device_count()

make_env = EnvCreator(lambda: GymEnv(args.env))
lib = LIBS[args.lib]
make_env = EnvCreator(lambda: lib(args.env))
if args.worker_parallelism == "collector" or num_workers == 1:
action_spec = make_env().action_spec
else:
11 changes: 11 additions & 0 deletions examples/distributed/collectors/single_machine/rpc.py
@@ -27,6 +27,7 @@
from torchrl.collectors.distributed import RPCDataCollector
from torchrl.envs import EnvCreator, ParallelEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.libs.robohive import RoboHiveEnv

parser = ArgumentParser()
parser.add_argument(
@@ -63,6 +64,16 @@
default="ALE/Pong-v5",
help="Gym environment to be run.",
)
LIBS = {
"gym": GymEnv,
"robohive": RoboHiveEnv,
}
parser.add_argument(
"--lib",
default="gym",
help="Lib backend",
choices=list(LIBS.keys()),
)
if __name__ == "__main__":
args = parser.parse_args()
num_workers = args.num_workers
11 changes: 11 additions & 0 deletions examples/distributed/collectors/single_machine/sync.py
@@ -32,6 +32,7 @@
from torchrl.collectors.distributed import DistributedSyncDataCollector
from torchrl.envs import EnvCreator, ParallelEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.libs.robohive import RoboHiveEnv

parser = ArgumentParser()
parser.add_argument(
@@ -75,6 +76,16 @@
default="ALE/Pong-v5",
help="Gym environment to be run.",
)
LIBS = {
"gym": GymEnv,
"robohive": RoboHiveEnv,
}
parser.add_argument(
"--lib",
default="gym",
help="Lib backend",
choices=list(LIBS.keys()),
)
if __name__ == "__main__":
args = parser.parse_args()
num_workers = args.num_workers
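All three single-machine collector examples (generic.py, rpc.py and sync.py) gain the same --lib switch. A minimal sketch of the dispatch it enables, assuming RoboHive is installed and borrowing the task id from config_myo.yaml:

from torchrl.envs import EnvCreator
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.libs.robohive import RoboHiveEnv

LIBS = {"gym": GymEnv, "robohive": RoboHiveEnv}  # same table as in the scripts
lib = LIBS["robohive"]  # the value that --lib robohive would select
make_env = EnvCreator(lambda: lib("myoHandReachRandom-v0"))  # hypothetical task id
print(make_env().action_spec)  # both backends expose the same spec API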
46 changes: 0 additions & 46 deletions examples/ppo/config.yaml

This file was deleted.

35 changes: 35 additions & 0 deletions examples/ppo/config_atari.yaml
@@ -0,0 +1,35 @@
# Environment
env:
env_name: PongNoFrameskip-v4

# collector
collector:
frames_per_batch: 4096
total_frames: 40_000_000

# logger
logger:
backend: wandb
exp_name: Atari_Schulman17
test_interval: 40_000_000
num_test_episodes: 3

# Optim
optim:
lr: 2.5e-4
eps: 1.0e-6
weight_decay: 0.0
max_grad_norm: 0.5
anneal_lr: True

# loss
loss:
gamma: 0.99
mini_batch_size: 1024
ppo_epochs: 3
gae_lambda: 0.95
clip_epsilon: 0.1
anneal_clip_epsilon: True
critic_coef: 1.0
entropy_coef: 0.01
loss_critic_type: l2
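These keys map one-to-one onto the cfg.* accesses in ppo_atari.py. A minimal sketch of reading the file outside of hydra, assuming omegaconf is available:

from omegaconf import OmegaConf

cfg = OmegaConf.load("config_atari.yaml")
assert cfg.loss.mini_batch_size == 1024
print(cfg.optim.lr, cfg.loss.clip_epsilon)  # 2.5e-4 0.1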
43 changes: 0 additions & 43 deletions examples/ppo/config_example2.yaml

This file was deleted.

32 changes: 32 additions & 0 deletions examples/ppo/config_mujoco.yaml
@@ -0,0 +1,32 @@
# task and env
env:
env_name: HalfCheetah-v3

# collector
collector:
frames_per_batch: 2048
total_frames: 1_000_000

# logger
logger:
backend: wandb
exp_name: Mujoco_Schulman17
test_interval: 1_000_000
num_test_episodes: 5

# Optim
optim:
lr: 3e-4
weight_decay: 0.0
anneal_lr: False

# loss
loss:
gamma: 0.99
mini_batch_size: 64
ppo_epochs: 10
gae_lambda: 0.95
clip_epsilon: 0.2
critic_coef: 0.25
entropy_coef: 0.0
loss_critic_type: l2
33 changes: 33 additions & 0 deletions examples/ppo/config_myo.yaml
@@ -0,0 +1,33 @@
# task and env
env:
env_name: myoHandReachRandom-v0

# collector
collector:
frames_per_batch: 2048
total_frames: 1_000_000
num_envs: 1

# logger
logger:
backend: wandb
exp_name: myo_hand_reach
test_interval: 1_000_000
num_test_episodes: 5

# Optim
optim:
lr: 3e-4
weight_decay: 0.0
anneal_lr: False

# loss
loss:
gamma: 0.99
mini_batch_size: 64
ppo_epochs: 10
gae_lambda: 0.95
clip_epsilon: 0.2
critic_coef: 0.25
entropy_coef: 0.0
loss_critic_type: l2
182 changes: 0 additions & 182 deletions examples/ppo/ppo.py

This file was deleted.

204 changes: 204 additions & 0 deletions examples/ppo/ppo_atari.py
@@ -0,0 +1,204 @@
"""
This script reproduces the Proximal Policy Optimization (PPO) Algorithm
results from Schulman et al. 2017 on Atari environments.
"""
import hydra


@hydra.main(config_path=".", config_name="config_atari", version_base="1.1")
def main(cfg: "DictConfig"): # noqa: F821

import time

import numpy as np
import torch.optim
import tqdm

from tensordict import TensorDict
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.envs import ExplorationType, set_exploration_type
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value.advantages import GAE
from torchrl.record.loggers import generate_exp_name, get_logger
from utils_atari import make_parallel_env, make_ppo_models

device = "cpu" if not torch.cuda.is_available() else "cuda"

# Correct for frame_skip
frame_skip = 4
total_frames = cfg.collector.total_frames // frame_skip
frames_per_batch = cfg.collector.frames_per_batch // frame_skip
mini_batch_size = cfg.loss.mini_batch_size // frame_skip
test_interval = cfg.logger.test_interval // frame_skip
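# For illustration, with the config_atari.yaml defaults and frame_skip=4:
#   total_frames     = 40_000_000 // 4 = 10_000_000 env steps
#   frames_per_batch =      4_096 // 4 =      1_024
#   mini_batch_size  =      1_024 // 4 =        256
#   test_interval    = 40_000_000 // 4 = 10_000_000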

# Create models (check utils_atari.py)
actor, critic, critic_head = make_ppo_models(cfg.env.env_name)
actor, critic, critic_head = (
actor.to(device),
critic.to(device),
critic_head.to(device),
)

# Create collector
collector = SyncDataCollector(
create_env_fn=make_parallel_env(cfg.env.env_name, device),
policy=actor,
frames_per_batch=frames_per_batch,
total_frames=total_frames,
device=device,
storing_device=device,
max_frames_per_traj=-1,
)

# Create data buffer
sampler = SamplerWithoutReplacement()
data_buffer = TensorDictReplayBuffer(
storage=LazyMemmapStorage(frames_per_batch),
sampler=sampler,
batch_size=mini_batch_size,
)

# Create loss and adv modules
adv_module = GAE(
gamma=cfg.loss.gamma,
lmbda=cfg.loss.gae_lambda,
value_network=critic,
average_gae=False,
)
loss_module = ClipPPOLoss(
actor=actor,
critic=critic,
clip_epsilon=cfg.loss.clip_epsilon,
loss_critic_type=cfg.loss.loss_critic_type,
entropy_coef=cfg.loss.entropy_coef,
critic_coef=cfg.loss.critic_coef,
normalize_advantage=True,
)

# Create optimizer
optim = torch.optim.Adam(
loss_module.parameters(),
lr=cfg.optim.lr,
weight_decay=cfg.optim.weight_decay,
eps=cfg.optim.eps,
)

# Create logger
exp_name = generate_exp_name("PPO", f"{cfg.logger.exp_name}_{cfg.env.env_name}")
logger = get_logger(cfg.logger.backend, logger_name="ppo", experiment_name=exp_name)

# Create test environment
test_env = make_parallel_env(cfg.env.env_name, device, is_test=True)
test_env.eval()

# Main loop
collected_frames = 0
num_network_updates = 0
start_time = time.time()
pbar = tqdm.tqdm(total=total_frames)
num_mini_batches = frames_per_batch // mini_batch_size
total_network_updates = (
(total_frames // frames_per_batch) * cfg.loss.ppo_epochs * num_mini_batches
)
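# With the Atari defaults above: num_mini_batches = 1_024 // 256 = 4 and
# total_network_updates = (10_000_000 // 1_024) * 3 * 4 = 117_180, so the
# annealing factor alpha decays linearly from 1.0 towards 0.0 over the run.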

for data in collector:

frames_in_batch = data.numel()
collected_frames += frames_in_batch * frame_skip
pbar.update(data.numel())

# Train logging
episode_rewards = data["next", "episode_reward"][data["next", "done"]]
if len(episode_rewards) > 0:
logger.log_scalar(
"reward_train", episode_rewards.mean().item(), collected_frames
)

# Apply episodic end of life
data["done"].copy_(data["end_of_life"])
data["next", "done"].copy_(data["next", "end_of_life"])

losses = TensorDict({}, batch_size=[cfg.loss.ppo_epochs, num_mini_batches])
for j in range(cfg.loss.ppo_epochs):

# Compute GAE
with torch.no_grad():
data = adv_module(data)
data_reshape = data.reshape(-1)

# Update the data buffer
data_buffer.extend(data_reshape)

for i, batch in enumerate(data_buffer):

# Linearly decrease the learning rate and clip epsilon
alpha = 1 - (num_network_updates / total_network_updates)
if cfg.optim.anneal_lr:
for g in optim.param_groups:
g["lr"] = cfg.optim.lr * alpha
if cfg.loss.anneal_clip_epsilon:
loss_module.clip_epsilon.copy_(cfg.loss.clip_epsilon * alpha)
num_network_updates += 1

# Get a data batch
batch = batch.to(device)

# Forward pass PPO loss
loss = loss_module(batch)
losses[j, i] = loss.select(
"loss_critic", "loss_entropy", "loss_objective"
).detach()
loss_sum = (
loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"]
)

# Backward pass
loss_sum.backward()
torch.nn.utils.clip_grad_norm_(
list(loss_module.parameters()), max_norm=cfg.optim.max_grad_norm
)

# Update the networks
optim.step()
optim.zero_grad()

losses = losses.apply(lambda x: x.float().mean(), batch_size=[])
for key, value in losses.items():
logger.log_scalar(key, value.item(), collected_frames)
logger.log_scalar("lr", alpha * cfg.optim.lr, collected_frames)
logger.log_scalar(
"clip_epsilon", alpha * cfg.loss.clip_epsilon, collected_frames
)

# Test logging
with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
if (collected_frames - frames_in_batch) // test_interval < (
collected_frames // test_interval
):
actor.eval()
test_rewards = []
for _ in range(cfg.logger.num_test_episodes):
td_test = test_env.rollout(
policy=actor,
auto_reset=True,
auto_cast_to_device=True,
break_when_any_done=True,
max_steps=10_000_000,
)
reward = td_test["next", "episode_reward"][td_test["next", "done"]]
test_rewards = np.append(test_rewards, reward.cpu().numpy())
del td_test
logger.log_scalar("reward_test", test_rewards.mean(), collected_frames)
actor.train()

collector.update_policy_weights_()

end_time = time.time()
execution_time = end_time - start_time
print(f"Training took {execution_time:.2f} seconds to finish")


if __name__ == "__main__":
main()
Binary file removed examples/ppo/ppo_atari_pong.png
179 changes: 179 additions & 0 deletions examples/ppo/ppo_mujoco.py
@@ -0,0 +1,179 @@
"""
This script reproduces the Proximal Policy Optimization (PPO) Algorithm
results from Schulman et al. 2017 on MuJoCo environments.
"""
import hydra


@hydra.main(config_path=".", config_name="config_mujoco", version_base="1.1")
def main(cfg: "DictConfig"): # noqa: F821

import time

import numpy as np
import torch.optim
import tqdm
from tensordict import TensorDict
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.envs import ExplorationType, set_exploration_type
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value.advantages import GAE
from torchrl.record.loggers import generate_exp_name, get_logger
from utils_mujoco import make_env, make_ppo_models

# Define paper hyperparameters
device = "cpu" if not torch.cuda.is_available() else "cuda"
num_mini_batches = cfg.collector.frames_per_batch // cfg.loss.mini_batch_size
total_network_updates = (
(cfg.collector.total_frames // cfg.collector.frames_per_batch)
* cfg.loss.ppo_epochs
* num_mini_batches
)
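# With the config_mujoco.yaml defaults: num_mini_batches = 2_048 // 64 = 32 and
# total_network_updates = (1_000_000 // 2_048) * 10 * 32 = 156_160.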

# Create models (check utils_mujoco.py)
actor, critic = make_ppo_models(cfg.env.env_name)
actor, critic = actor.to(device), critic.to(device)

# Create collector
collector = SyncDataCollector(
create_env_fn=make_env(cfg.env.env_name, device),
policy=actor,
frames_per_batch=cfg.collector.frames_per_batch,
total_frames=cfg.collector.total_frames,
device=device,
storing_device=device,
max_frames_per_traj=-1,
)

# Create data buffer
sampler = SamplerWithoutReplacement()
data_buffer = TensorDictReplayBuffer(
storage=LazyMemmapStorage(cfg.collector.frames_per_batch, device=device),
sampler=sampler,
batch_size=cfg.loss.mini_batch_size,
)

# Create loss and adv modules
adv_module = GAE(
gamma=cfg.loss.gamma,
lmbda=cfg.loss.gae_lambda,
value_network=critic,
average_gae=False,
)
loss_module = ClipPPOLoss(
actor=actor,
critic=critic,
clip_epsilon=cfg.loss.clip_epsilon,
loss_critic_type=cfg.loss.loss_critic_type,
entropy_coef=cfg.loss.entropy_coef,
critic_coef=cfg.loss.critic_coef,
normalize_advantage=True,
)

# Create optimizers
actor_optim = torch.optim.Adam(actor.parameters(), lr=cfg.optim.lr)
critic_optim = torch.optim.Adam(critic.parameters(), lr=cfg.optim.lr)

# Create logger
exp_name = generate_exp_name("PPO", f"{cfg.logger.exp_name}_{cfg.env.env_name}")
logger = get_logger(cfg.logger.backend, logger_name="ppo", experiment_name=exp_name)

# Create test environment
test_env = make_env(cfg.env.env_name, device)
test_env.eval()

# Main loop
collected_frames = 0
num_network_updates = 0
start_time = time.time()
pbar = tqdm.tqdm(total=cfg.collector.total_frames)

for data in collector:

frames_in_batch = data.numel()
collected_frames += frames_in_batch
pbar.update(data.numel())

# Train logging
episode_rewards = data["next", "episode_reward"][data["next", "done"]]
if len(episode_rewards) > 0:
logger.log_scalar(
"reward_train", episode_rewards.mean().item(), collected_frames
)

# Compute GAE
with torch.no_grad():
data = adv_module(data)
data_reshape = data.reshape(-1)

# Update the data buffer
data_buffer.extend(data_reshape)

losses = TensorDict({}, batch_size=[cfg.loss.ppo_epochs, num_mini_batches])
for j in range(cfg.loss.ppo_epochs):

for i, batch in enumerate(data_buffer):

# Linearly decrease the learning rate and clip epsilon
if cfg.optim.anneal_lr:
alpha = 1 - (num_network_updates / total_network_updates)
for g in actor_optim.param_groups:
g["lr"] = cfg.optim.lr * alpha
for g in critic_optim.param_groups:
g["lr"] = cfg.optim.lr * alpha
num_network_updates += 1

# Forward pass PPO loss
loss = loss_module(batch)
losses[j, i] = loss.select(
"loss_critic", "loss_entropy", "loss_objective"
).detach()
critic_loss = loss["loss_critic"]
actor_loss = loss["loss_objective"] + loss["loss_entropy"]

# Backward pass
actor_loss.backward()
critic_loss.backward()

# Update the networks
actor_optim.step()
critic_optim.step()
actor_optim.zero_grad()
critic_optim.zero_grad()

losses = losses.apply(lambda x: x.float().mean(), batch_size=[])
for key, value in losses.items():
logger.log_scalar(key, value.item(), collected_frames)

# Test logging
with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
if (collected_frames - frames_in_batch) // cfg.logger.test_interval < (
collected_frames // cfg.logger.test_interval
):
actor.eval()
test_rewards = []
for _ in range(cfg.logger.num_test_episodes):
td_test = test_env.rollout(
policy=actor,
auto_reset=True,
auto_cast_to_device=True,
break_when_any_done=True,
max_steps=10_000_000,
)
reward = td_test["next", "episode_reward"][td_test["next", "done"]]
test_rewards = np.append(test_rewards, reward.cpu().numpy())
del td_test
logger.log_scalar("reward_test", test_rewards.mean(), collected_frames)
actor.train()

collector.update_policy_weights_()

end_time = time.time()
execution_time = end_time - start_time
print(f"Training took {execution_time:.2f} seconds to finish")


if __name__ == "__main__":
main()
Binary file removed examples/ppo/ppo_mujoco_halfcheetah.png
191 changes: 191 additions & 0 deletions examples/ppo/ppo_myo.py
@@ -0,0 +1,191 @@
"""
This script trains an agent with the Proximal Policy Optimization (PPO)
algorithm from Schulman et al. 2017 on MyoSuite environments (via RoboHive).
"""
import hydra

from torchrl.collectors import MultiSyncDataCollector


@hydra.main(config_path=".", config_name="config_myo", version_base="1.1")
def main(cfg: "DictConfig"): # noqa: F821

import time

import numpy as np
import torch.optim
import tqdm
from tensordict import TensorDict
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.envs import ExplorationType, set_exploration_type
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value.advantages import GAE
from torchrl.record.loggers import generate_exp_name, get_logger
from utils_myo import make_env, make_ppo_models

# Define paper hyperparameters
device = "cpu" if not torch.cuda.is_available() else "cuda"
num_mini_batches = cfg.collector.frames_per_batch // cfg.loss.mini_batch_size
total_network_updates = (
(cfg.collector.total_frames // cfg.collector.frames_per_batch)
* cfg.loss.ppo_epochs
* num_mini_batches
)

# Create models (check utils_myo.py)
actor, critic = make_ppo_models(cfg.env.env_name)
actor, critic = actor.to(device), critic.to(device)

# Create collector
if cfg.collector.num_envs == 1:
collector = SyncDataCollector(
create_env_fn=make_env(cfg.env.env_name, device),
policy=actor,
frames_per_batch=cfg.collector.frames_per_batch,
total_frames=cfg.collector.total_frames,
device=device,
storing_device=device,
max_frames_per_traj=-1,
)
else:
collector = MultiSyncDataCollector(
create_env_fn=cfg.collector.num_envs * [make_env(cfg.env.env_name, device)],
policy=actor,
frames_per_batch=cfg.collector.frames_per_batch,
total_frames=cfg.collector.total_frames,
device=device,
storing_device=device,
max_frames_per_traj=-1,
)

# Create data buffer
sampler = SamplerWithoutReplacement()
data_buffer = TensorDictReplayBuffer(
storage=LazyMemmapStorage(cfg.collector.frames_per_batch, device=device),
sampler=sampler,
batch_size=cfg.loss.mini_batch_size,
)

# Create loss and adv modules
adv_module = GAE(
gamma=cfg.loss.gamma,
lmbda=cfg.loss.gae_lambda,
value_network=critic,
average_gae=False,
)
loss_module = ClipPPOLoss(
actor=actor,
critic=critic,
clip_epsilon=cfg.loss.clip_epsilon,
loss_critic_type=cfg.loss.loss_critic_type,
entropy_coef=cfg.loss.entropy_coef,
critic_coef=cfg.loss.critic_coef,
normalize_advantage=True,
)

# Create optimizers
actor_optim = torch.optim.Adam(actor.parameters(), lr=cfg.optim.lr)
critic_optim = torch.optim.Adam(critic.parameters(), lr=cfg.optim.lr)

# Create logger
exp_name = generate_exp_name("PPO", f"{cfg.logger.exp_name}_{cfg.env.env_name}")
logger = get_logger(cfg.logger.backend, logger_name="ppo", experiment_name=exp_name)

# Create test environment
test_env = make_env(cfg.env.env_name, device)
test_env.eval()

# Main loop
collected_frames = 0
num_network_updates = 0
start_time = time.time()
pbar = tqdm.tqdm(total=cfg.collector.total_frames)

for data in collector:

frames_in_batch = data.numel()
collected_frames += frames_in_batch
pbar.update(data.numel())

# Train logging
episode_rewards = data["next", "episode_reward"][data["next", "done"]]
if len(episode_rewards) > 0:
logger.log_scalar(
"reward_train", episode_rewards.mean().item(), collected_frames
)

losses = TensorDict({}, batch_size=[cfg.loss.ppo_epochs, num_mini_batches])
for j in range(cfg.loss.ppo_epochs):
# Compute GAE
with torch.no_grad():
data = adv_module(data)
data_reshape = data.reshape(-1)
# Update the data buffer
data_buffer.empty()
data_buffer.extend(data_reshape)

for i, batch in enumerate(data_buffer):

# Linearly decrease the learning rate and clip epsilon
if cfg.optim.anneal_lr:
alpha = 1 - (num_network_updates / total_network_updates)
for g in actor_optim.param_groups:
g["lr"] = cfg.optim.lr * alpha
for g in critic_optim.param_groups:
g["lr"] = cfg.optim.lr * alpha
num_network_updates += 1

# Forward pass PPO loss
loss = loss_module(batch)
losses[j, i] = loss.select(
"loss_critic", "loss_entropy", "loss_objective"
).detach()
critic_loss = loss["loss_critic"]
actor_loss = loss["loss_objective"] + loss["loss_entropy"]

# Backward pass
actor_loss.backward()
critic_loss.backward()

# Update the networks
actor_optim.step()
critic_optim.step()
actor_optim.zero_grad()
critic_optim.zero_grad()

losses = losses.apply(lambda x: x.float().mean(), batch_size=[])
for key, value in losses.items():
logger.log_scalar(key, value.item(), collected_frames)

# Test logging
with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
if (collected_frames - frames_in_batch) // cfg.logger.test_interval < (
collected_frames // cfg.logger.test_interval
):
actor.eval()
test_rewards = []
for _ in range(cfg.logger.num_test_episodes):
td_test = test_env.rollout(
policy=actor,
auto_reset=True,
auto_cast_to_device=True,
break_when_any_done=True,
max_steps=10_000_000,
)
reward = td_test["next", "episode_reward"][td_test["next", "done"]]
test_rewards = np.append(test_rewards, reward.cpu().numpy())
del td_test
logger.log_scalar("reward_test", test_rewards.mean(), collected_frames)
actor.train()

collector.update_policy_weights_()

end_time = time.time()
execution_time = end_time - start_time
print(f"Training took {execution_time:.2f} seconds to finish")


if __name__ == "__main__":
main()
13 changes: 0 additions & 13 deletions examples/ppo/training_curves.md

This file was deleted.

473 changes: 0 additions & 473 deletions examples/ppo/utils.py

This file was deleted.

212 changes: 212 additions & 0 deletions examples/ppo/utils_atari.py
@@ -0,0 +1,212 @@
import gym
import torch.nn
import torch.optim
from tensordict.nn import TensorDictModule
from torchrl.data import CompositeSpec
from torchrl.data.tensor_specs import DiscreteBox
from torchrl.envs import (
CatFrames,
default_info_dict_reader,
DoubleToFloat,
EnvCreator,
ExplorationType,
GrayScale,
NoopResetEnv,
ParallelEnv,
Resize,
RewardClipping,
RewardSum,
StepCounter,
ToTensorImage,
TransformedEnv,
VecNorm,
)
from torchrl.envs.libs.gym import GymWrapper
from torchrl.modules import (
ActorValueOperator,
ConvNet,
MLP,
OneHotCategorical,
ProbabilisticActor,
TanhNormal,
ValueOperator,
)

# ====================================================================
# Environment utils
# --------------------------------------------------------------------


class EpisodicLifeEnv(gym.Wrapper):
def __init__(self, env):
"""Make end-of-life == end-of-episode, but only reset on true game over.
Done by DeepMind for the DQN and co. It helps value estimation.
"""
gym.Wrapper.__init__(self, env)
self.lives = 0

def step(self, action):
obs, rew, done, info = self.env.step(action)
lives = self.env.unwrapped.ale.lives()
info["end_of_life"] = False
if (lives < self.lives) or done:
info["end_of_life"] = True
self.lives = lives
return obs, rew, done, info

def reset(self, **kwargs):
reset_data = self.env.reset(**kwargs)
self.lives = self.env.unwrapped.ale.lives()
return reset_data
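# The "end_of_life" flag written above is exposed to torchrl through
# default_info_dict_reader in make_base_env below; ppo_atari.py then copies it
# into "done" so that bootstrapping stops at each lost life while the
# simulator only truly resets on game over.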


def make_base_env(
env_name="BreakoutNoFrameskip-v4", frame_skip=4, device="cpu", is_test=False
):
env = gym.make(env_name)
if not is_test:
env = EpisodicLifeEnv(env)
env = GymWrapper(
env, frame_skip=frame_skip, from_pixels=True, pixels_only=False, device=device
)
env = TransformedEnv(env)
env.append_transform(NoopResetEnv(noops=30, random=True))
reader = default_info_dict_reader(["end_of_life"])
env.set_info_dict_reader(reader)
return env


def make_parallel_env(env_name, device, is_test=False):
num_envs = 8
env = ParallelEnv(
num_envs, EnvCreator(lambda: make_base_env(env_name, device=device))
)
env = TransformedEnv(env)
env.append_transform(ToTensorImage())
env.append_transform(GrayScale())
env.append_transform(Resize(84, 84))
env.append_transform(CatFrames(N=4, dim=-3))
env.append_transform(RewardSum())
env.append_transform(StepCounter(max_steps=4500))
if not is_test:
env.append_transform(RewardClipping(-1, 1))
env.append_transform(DoubleToFloat())
env.append_transform(VecNorm(in_keys=["pixels"]))
return env
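A minimal smoke-test sketch for the wrapper stack above, assuming gym and the ALE ROMs are installed (env id taken from config_atari.yaml):

env = make_parallel_env("PongNoFrameskip-v4", device="cpu")
td = env.rollout(4)
print(td["pixels"].shape)  # roughly [8, 4, 4, 84, 84]: 8 workers, 4 steps, 4 stacked 84x84 frames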


# ====================================================================
# Model utils
# --------------------------------------------------------------------


def make_ppo_modules_pixels(proof_environment):

# Define input shape
input_shape = proof_environment.observation_spec["pixels"].shape

# Define distribution class and kwargs
if isinstance(proof_environment.action_spec.space, DiscreteBox):
num_outputs = proof_environment.action_spec.space.n
distribution_class = OneHotCategorical
distribution_kwargs = {}
else: # is ContinuousBox
num_outputs = proof_environment.action_spec.shape
distribution_class = TanhNormal
distribution_kwargs = {
"min": proof_environment.action_spec.space.minimum,
"max": proof_environment.action_spec.space.maximum,
}

# Define input keys
in_keys = ["pixels"]

# Define a shared Module and TensorDictModule (CNN + MLP)
common_cnn = ConvNet(
activation_class=torch.nn.ReLU,
num_cells=[32, 64, 64],
kernel_sizes=[8, 4, 3],
strides=[4, 2, 1],
)
common_cnn_output = common_cnn(torch.ones(input_shape))
common_mlp = MLP(
in_features=common_cnn_output.shape[-1],
activation_class=torch.nn.ReLU,
activate_last_layer=True,
out_features=512,
num_cells=[],
)
common_mlp_output = common_mlp(common_cnn_output)

# Define shared net as TensorDictModule
common_module = TensorDictModule(
module=torch.nn.Sequential(common_cnn, common_mlp),
in_keys=in_keys,
out_keys=["common_features"],
)

# Define one head for the policy
policy_net = MLP(
in_features=common_mlp_output.shape[-1],
out_features=num_outputs,
activation_class=torch.nn.ReLU,
num_cells=[],
)
policy_module = TensorDictModule(
module=policy_net,
in_keys=["common_features"],
out_keys=["logits"],
)

# Add probabilistic sampling of the actions
policy_module = ProbabilisticActor(
policy_module,
in_keys=["logits"],
spec=CompositeSpec(action=proof_environment.action_spec),
distribution_class=distribution_class,
distribution_kwargs=distribution_kwargs,
return_log_prob=True,
default_interaction_type=ExplorationType.RANDOM,
)

# Define another head for the value
value_net = MLP(
activation_class=torch.nn.ReLU,
in_features=common_mlp_output.shape[-1],
out_features=1,
num_cells=[],
)
value_module = ValueOperator(
value_net,
in_keys=["common_features"],
)

return common_module, policy_module, value_module


def make_ppo_models(env_name):

proof_environment = make_parallel_env(env_name, device="cpu")
common_module, policy_module, value_module = make_ppo_modules_pixels(
proof_environment
)

# Wrap modules in a single ActorCritic operator
actor_critic = ActorValueOperator(
common_operator=common_module,
policy_operator=policy_module,
value_operator=value_module,
)

with torch.no_grad():
td = proof_environment.rollout(max_steps=100, break_when_any_done=False)
td = actor_critic(td)
del td

actor = actor_critic.get_policy_operator()
critic = actor_critic.get_value_operator()
critic_head = actor_critic.get_value_head()

del proof_environment

return actor, critic, critic_head
114 changes: 114 additions & 0 deletions examples/ppo/utils_mujoco.py
@@ -0,0 +1,114 @@
import gym
import torch.nn
import torch.optim

from tensordict.nn import AddStateIndependentNormalScale, TensorDictModule
from torchrl.data import CompositeSpec
from torchrl.envs import (
ClipTransform,
DoubleToFloat,
ExplorationType,
RewardSum,
TransformedEnv,
VecNorm,
)
from torchrl.envs.libs.gym import GymWrapper
from torchrl.modules import MLP, ProbabilisticActor, TanhNormal, ValueOperator

# ====================================================================
# Environment utils
# --------------------------------------------------------------------


def make_env(env_name="HalfCheetah-v4", device="cpu"):
env = gym.make(env_name)
env = GymWrapper(env, device=device)
env = TransformedEnv(env)
env.append_transform(RewardSum())
env.append_transform(VecNorm(in_keys=["observation"]))
env.append_transform(ClipTransform(in_keys=["observation"], low=-10, high=10))
env.append_transform(DoubleToFloat(in_keys=["observation"]))
return env


# ====================================================================
# Model utils
# --------------------------------------------------------------------


def make_ppo_models_state(proof_environment):

# Define input shape
input_shape = proof_environment.observation_spec["observation"].shape

# Define policy output distribution class
num_outputs = proof_environment.action_spec.shape[-1]
distribution_class = TanhNormal
distribution_kwargs = {
"min": proof_environment.action_spec.space.minimum,
"max": proof_environment.action_spec.space.maximum,
"tanh_loc": False,
}

# Define policy architecture
policy_mlp = MLP(
in_features=input_shape[-1],
activation_class=torch.nn.Tanh,
out_features=num_outputs, # predict only loc
num_cells=[64, 64],
)

# Initialize policy weights
for layer in policy_mlp.modules():
if isinstance(layer, torch.nn.Linear):
torch.nn.init.orthogonal_(layer.weight, 1.0)
layer.bias.data.zero_()

# Add state-independent normal scale
policy_mlp = torch.nn.Sequential(
policy_mlp,
AddStateIndependentNormalScale(proof_environment.action_spec.shape[-1]),
)
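# AddStateIndependentNormalScale appends a learned scale that does not depend on
# the observation, so the MLP above only has to predict the mean (loc) of the
# TanhNormal policy.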

# Add probabilistic sampling of the actions
policy_module = ProbabilisticActor(
TensorDictModule(
module=policy_mlp,
in_keys=["observation"],
out_keys=["loc", "scale"],
),
in_keys=["loc", "scale"],
spec=CompositeSpec(action=proof_environment.action_spec),
distribution_class=distribution_class,
distribution_kwargs=distribution_kwargs,
return_log_prob=True,
default_interaction_type=ExplorationType.RANDOM,
)

# Define value architecture
value_mlp = MLP(
in_features=input_shape[-1],
activation_class=torch.nn.Tanh,
out_features=1,
num_cells=[64, 64],
)

# Initialize value weights
for layer in value_mlp.modules():
if isinstance(layer, torch.nn.Linear):
torch.nn.init.orthogonal_(layer.weight, 0.01)
layer.bias.data.zero_()

# Define value module
value_module = ValueOperator(
value_mlp,
in_keys=["observation"],
)

return policy_module, value_module


def make_ppo_models(env_name):
proof_environment = make_env(env_name, device="cpu")
actor, critic = make_ppo_models_state(proof_environment)
return actor, critic
116 changes: 116 additions & 0 deletions examples/ppo/utils_myo.py
@@ -0,0 +1,116 @@
import gym
import torch.nn
import torch.optim

from tensordict.nn import AddStateIndependentNormalScale, TensorDictModule
from torchrl.data import CompositeSpec
from torchrl.envs import (
CatTensors,
ClipTransform,
DoubleToFloat,
ExcludeTransform,
ExplorationType,
RewardSum,
TransformedEnv,
VecNorm,
)
from torchrl.envs.libs.gym import GymWrapper
from torchrl.envs.libs.robohive import RoboHiveEnv
from torchrl.modules import MLP, ProbabilisticActor, TanhNormal, ValueOperator

# ====================================================================
# Environment utils
# --------------------------------------------------------------------


def make_env(env_name="myoHandReachRandom-v0", device="cpu"):
env = RoboHiveEnv(env_name, include_info=False, device=device)
env = TransformedEnv(env)
env.append_transform(RewardSum())
env.append_transform(CatTensors(["qpos", "qvel", "tip_pos", "reach_err"], out_key="observation"))
env.append_transform(ExcludeTransform("time", "state", "rwd_dense", "rwd_dict", "visual_dict"))
env.append_transform(VecNorm(in_keys=["observation"]))
env.append_transform(ClipTransform(in_keys=["observation"], low=-10, high=10))
env.append_transform(DoubleToFloat(in_keys=["observation"]))
return env
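A minimal sketch of what this transform stack produces, assuming RoboHive and the myo suite are installed:

env = make_env("myoHandReachRandom-v0")
td = env.rollout(3)
# qpos, qvel, tip_pos and reach_err are concatenated into a single "observation"
# entry, which is then VecNorm-normalised, clipped to [-10, 10] and cast to
# float32; the keys listed in ExcludeTransform are dropped from the output.
print(td["observation"].shape, td["observation"].dtype)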


# ====================================================================
# Model utils
# --------------------------------------------------------------------


def make_ppo_models_state(proof_environment):

# Define input shape
input_shape = proof_environment.observation_spec["observation"].shape

# Define policy output distribution class
num_outputs = proof_environment.action_spec.shape[-1]
distribution_class = TanhNormal
distribution_kwargs = {
"min": proof_environment.action_spec.space.minimum,
"max": proof_environment.action_spec.space.maximum,
"tanh_loc": False,
}

# Define policy architecture
policy_mlp = MLP(
in_features=input_shape[-1],
activation_class=torch.nn.Tanh,
out_features=num_outputs, # predict only loc
num_cells=[64, 64],
)

# Initialize policy weights
for layer in policy_mlp.modules():
if isinstance(layer, torch.nn.Linear):
torch.nn.init.orthogonal_(layer.weight, 1.0)
layer.bias.data.zero_()

# Add state-independent normal scale
policy_mlp = torch.nn.Sequential(
policy_mlp,
AddStateIndependentNormalScale(proof_environment.action_spec.shape[-1]),
)

# Add probabilistic sampling of the actions
policy_module = ProbabilisticActor(
TensorDictModule(
module=policy_mlp,
in_keys=["observation"],
out_keys=["loc", "scale"],
),
in_keys=["loc", "scale"],
spec=CompositeSpec(action=proof_environment.action_spec),
distribution_class=distribution_class,
distribution_kwargs=distribution_kwargs,
return_log_prob=True,
default_interaction_type=ExplorationType.RANDOM,
)

# Define value architecture
value_mlp = MLP(
in_features=input_shape[-1],
activation_class=torch.nn.Tanh,
out_features=1,
num_cells=[64, 64],
)

# Initialize value weights
for layer in value_mlp.modules():
if isinstance(layer, torch.nn.Linear):
torch.nn.init.orthogonal_(layer.weight, 0.01)
layer.bias.data.zero_()

# Define value module
value_module = ValueOperator(
value_mlp,
in_keys=["observation"],
)

return policy_module, value_module


def make_ppo_models(env_name):
proof_environment = make_env(env_name, device="cpu")
actor, critic = make_ppo_models_state(proof_environment)
return actor, critic
33 changes: 20 additions & 13 deletions torchrl/envs/libs/robohive.py
@@ -82,6 +82,12 @@ class RoboHiveEnv(GymEnv):
else:
CURR_DIR = None

def __init__(self, env_name, include_info: bool = True, **kwargs):
self.include_info = include_info
kwargs["env_name"] = env_name
self._set_gym_args(kwargs)
super().__init__(**kwargs)

@classmethod
def register_envs(cls):

@@ -304,19 +310,20 @@ def read_obs(self, observation):
return super().read_obs(out)

def read_info(self, info, tensordict_out):
out = {}
for key, value in info.items():
if key in ("obs_dict", "done", "reward", *self._env.obs_keys, "act"):
continue
if isinstance(value, dict):
value = {key: _val for key, _val in value.items() if _val is not None}
value = make_tensordict(value, batch_size=[])
if value is not None:
out[key] = value
tensordict_out.update(out)
tensordict_out.update(
tensordict_out.apply(lambda x: x.reshape((1,)) if not x.shape else x)
)
if self.include_info:
out = {}
for key, value in info.items():
if key in ("obs_dict", "done", "reward", *self._env.obs_keys, "act"):
continue
if isinstance(value, dict):
value = {key: _val for key, _val in value.items() if _val is not None}
value = make_tensordict(value, batch_size=[])
if value is not None:
out[key] = value
tensordict_out.update(out)
tensordict_out.update(
tensordict_out.apply(lambda x: x.reshape((1,)) if not x.shape else x)
)
return tensordict_out

def to(self, *args, **kwargs):
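A minimal sketch of the new include_info flag, mirroring the call in utils_myo.py and assuming RoboHive is installed:

env = RoboHiveEnv("myoHandReachRandom-v0", include_info=False)
td = env.reset()
# With include_info=False, read_info returns tensordict_out untouched, so the
# raw RoboHive info entries (rwd_dict, visual_dict, ...) never reach the output
# tensordict.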