2 changes: 1 addition & 1 deletion .circleci/config.yml
@@ -375,7 +375,7 @@ jobs:
image: ubuntu-2004-cuda-11.4:202110-01
resource_class: gpu.nvidia.medium
environment:
image_name: "pytorch/manylinux-cuda117"
image_name: "nvidia/cudagl:11.4.0-base"
TAR_OPTIONS: --no-same-owner
PYTHON_VERSION: << parameters.python_version >>
CU_VERSION: << parameters.cu_version >>
2 changes: 2 additions & 0 deletions .circleci/unittest/linux_examples/scripts/install.sh
@@ -4,6 +4,8 @@ unset PYTORCH_VERSION
# For unittest, nightly PyTorch is used as the following section,
# so no need to set PYTORCH_VERSION.
# In fact, keeping PYTORCH_VERSION forces us to hardcode PyTorch version in config.
apt-get update && apt-get install -y git wget gcc g++
#apt-get update && apt-get install -y git wget freeglut3 freeglut3-dev

set -e

93 changes: 69 additions & 24 deletions .circleci/unittest/linux_examples/scripts/run_test.sh
@@ -8,6 +8,8 @@

set -e

apt-get update && apt-get remove swig -y && apt-get install -y git gcc patchelf libosmesa6-dev libgl1-mesa-glx libglfw3 swig3.0 wget freeglut3 freeglut3-dev

eval "$(./conda/bin/conda shell.bash hook)"
conda activate ./env

@@ -27,19 +29,75 @@ export MKL_THREADING_LAYER=GNU
python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test.py -v --durations 20
python .circleci/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test_deps.py -v --durations 20

# ========================================================================================
# DDPG
# ----
#
# Modalities:
# ^^^^^^^^^^^
#
# pixels on/off
# Batched on/off
#
# With batched environments
python .circleci/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \
total_frames=48 \
init_random_frames=10 \
batch_size=10 \
frames_per_batch=16 \
num_workers=4 \
env_per_collector=2 \
collector_devices=cuda:0 \
optim_steps_per_batch=1 \
record_video=True \
record_frames=4 \
buffer_size=120
collector.total_frames=48 \
collector.init_random_frames=10 \
collector.frames_per_batch=16 \
collector.num_collectors=4 \
collector.collector_devices=cuda:0 \
env.num_envs=2 \
optim.batch_size=10 \
optim.optim_steps_per_batch=1 \
recorder.video=True \
recorder.frames=4 \
replay_buffer.capacity=120 \
env.from_pixels=False \
logger.backend=csv
python .circleci/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \
collector.total_frames=48 \
collector.init_random_frames=10 \
collector.frames_per_batch=16 \
collector.num_collectors=4 \
collector.collector_devices=cuda:0 \
env.num_envs=2 \
optim.batch_size=10 \
optim.optim_steps_per_batch=1 \
recorder.video=True \
recorder.frames=4 \
replay_buffer.capacity=120 \
env.from_pixels=True \
logger.backend=csv
# With single envs
python .circleci/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \
collector.total_frames=48 \
collector.init_random_frames=10 \
collector.frames_per_batch=16 \
collector.num_collectors=4 \
collector.collector_devices=cuda:0 \
env.num_envs=1 \
optim.batch_size=10 \
optim.optim_steps_per_batch=1 \
recorder.video=True \
recorder.frames=4 \
replay_buffer.capacity=120 \
env.from_pixels=False \
logger.backend=csv
python .circleci/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \
collector.total_frames=48 \
collector.init_random_frames=10 \
collector.frames_per_batch=16 \
collector.num_collectors=4 \
collector.collector_devices=cuda:0 \
env.num_envs=1 \
optim.batch_size=10 \
optim.optim_steps_per_batch=1 \
recorder.video=True \
recorder.frames=4 \
replay_buffer.capacity=120 \
env.from_pixels=True \
logger.backend=csv

python .circleci/unittest/helpers/coverage_run_parallel.py examples/a2c/a2c.py \
total_frames=48 \
batch_size=10 \
@@ -112,19 +170,6 @@ python .circleci/unittest/helpers/coverage_run_parallel.py examples/dreamer/drea
buffer_size=120 \
rssm_hidden_dim=17

# With single envs
python .circleci/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \
total_frames=48 \
init_random_frames=10 \
batch_size=10 \
frames_per_batch=16 \
num_workers=2 \
env_per_collector=1 \
collector_devices=cuda:0 \
optim_steps_per_batch=1 \
record_video=True \
record_frames=4 \
buffer_size=120
python .circleci/unittest/helpers/coverage_run_parallel.py examples/a2c/a2c.py \
total_frames=48 \
batch_size=10 \
42 changes: 4 additions & 38 deletions .circleci/unittest/linux_examples/scripts/setup_env.sh
@@ -9,6 +9,8 @@ set -e

this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
# Avoid error: "fatal: unsafe repository"
apt-get update && apt-get install -y git wget gcc g++

git config --global --add safe.directory '*'
root_dir="$(git rev-parse --show-toplevel)"
conda_dir="${root_dir}/conda"
@@ -71,48 +73,12 @@ conda env config vars set MUJOCO_PY_MUJOCO_PATH=$root_dir/.mujoco/mujoco210 \
MUJOCO_GL=$PRIVATE_MUJOCO_GL \
PYOPENGL_PLATFORM=$PRIVATE_MUJOCO_GL

# Software rendering requires GLX and OSMesa.
if [ $PRIVATE_MUJOCO_GL == 'egl' ] || [ $PRIVATE_MUJOCO_GL == 'osmesa' ] ; then
yum makecache
yum install -y glfw
yum install -y glew
yum install -y mesa-libGL
yum install -y mesa-libGL-devel
yum install -y mesa-libOSMesa-devel
yum -y install egl-utils
yum -y install freeglut
fi

pip install pip --upgrade

conda env update --file "${this_dir}/environment.yml" --prune

conda deactivate
conda activate "${env_dir}"

if [[ $OSTYPE != 'darwin'* ]]; then
# install ale-py: manylinux names are broken for CentOS so we need to manually download and
# rename them
PY_VERSION=$(python --version)
if [[ $PY_VERSION == *"3.7"* ]]; then
wget https://files.pythonhosted.org/packages/ab/fd/6615982d9460df7f476cad265af1378057eee9daaa8e0026de4cedbaffbd/ale_py-0.8.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
pip install ale_py-0.8.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
rm ale_py-0.8.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
elif [[ $PY_VERSION == *"3.8"* ]]; then
wget https://files.pythonhosted.org/packages/0f/8a/feed20571a697588bc4bfef05d6a487429c84f31406a52f8af295a0346a2/ale_py-0.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
pip install ale_py-0.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
rm ale_py-0.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
elif [[ $PY_VERSION == *"3.9"* ]]; then
wget https://files.pythonhosted.org/packages/a0/98/4316c1cedd9934f9a91b6e27a9be126043b4445594b40cfa391c8de2e5e8/ale_py-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
pip install ale_py-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
rm ale_py-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
elif [[ $PY_VERSION == *"3.10"* ]]; then
wget https://files.pythonhosted.org/packages/60/1b/3adde7f44f79fcc50d0a00a0643255e48024c4c3977359747d149dc43500/ale_py-0.8.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
mv ale_py-0.8.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
pip install ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
rm ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
fi
pip install "gymnasium[atari,accept-rom-license]"
else
pip install "gymnasium[atari,accept-rom-license]"
fi
pip install ale-py
pip install "gymnasium[atari,accept-rom-license]"
101 changes: 65 additions & 36 deletions examples/ddpg/config.yaml
@@ -1,36 +1,65 @@
env_name: HalfCheetah-v4
env_task: ""
env_library: gym
async_collection: 1
record_video: 0
normalize_rewards_online: 1
normalize_rewards_online_scale: 5
frame_skip: 1
frames_per_batch: 1024
optim_steps_per_batch: 128
batch_size: 256
total_frames: 1000000
prb: 1
lr: 3e-4
ou_exploration: 1
multi_step: 1
init_random_frames: 25000
activation: elu
gSDE: 0
from_pixels: 0
#collector_devices: [cuda:1,cuda:1,cuda:1,cuda:1]
collector_devices: [cpu,cpu,cpu,cpu]
env_per_collector: 8
num_workers: 32
lr_scheduler: ""
value_network_update_interval: 200
record_interval: 10
max_frames_per_traj: -1
weight_decay: 0.0
annealing_frames: 1000000
init_env_steps: 10000
record_frames: 10000
loss_function: smooth_l1
batch_transform: 1
buffer_prefetch: 64
norm_stats: 1
# task and env
env:
env_name: HalfCheetah-v4
env_task: ""
env_library: gym
normalize_rewards_online: 1
normalize_rewards_online_scale: 5
frame_skip: 1
norm_stats: 1
num_envs: 4
n_samples_stats: 1000
noop: 1
reward_scaling:
from_pixels: False

# collector
collector:
async_collection: 1
frames_per_batch: 1024
total_frames: 1000000
multi_step: 3 # 0 to disable
init_random_frames: 25000
collector_devices: cpu # [cpu,cpu,cpu,cpu]
num_collectors: 4
max_frames_per_traj: -1

# eval
recorder:
video: True
interval: 10000 # record interval in frames
frames: 10000

# logger
logger:
backend: wandb
exp_name: ddpg_cheetah_gym

# Buffer
replay_buffer:
prb: 1
buffer_prefetch: 64
capacity: 1_000_000

# Optim
optim:
device: cpu
lr: 3e-4
weight_decay: 0.0
batch_size: 256
lr_scheduler: ""
value_network_update_interval: 200
optim_steps_per_batch: 8

# Policy and model
model:
ou_exploration: 1
annealing_frames: 1000000
noisy: False
activation: elu

# loss
loss:
loss_function: smooth_l1
gamma: 0.99
tau: 0.05
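
The restructured config above groups every option under a named section (env, collector, recorder, logger, replay_buffer, optim, model, loss). As a minimal sketch of how such a nested Hydra config is consumed (assuming config.yaml sits next to the script, as in this example; the sketch itself is not part of the PR), the dotted command-line overrides exercised in run_test.sh map directly onto these groups:

# Minimal sketch: reading the nested DDPG config with Hydra.
# Assumes config.yaml (shown above) sits next to the script.
import hydra
from omegaconf import DictConfig, OmegaConf

@hydra.main(config_path=".", config_name="config")
def main(cfg: DictConfig) -> None:
    print(cfg.collector.total_frames)  # 1000000 by default
    print(cfg.env.from_pixels)         # False by default
    print(OmegaConf.to_yaml(cfg))      # full resolved config

if __name__ == "__main__":
    main()

Any leaf can then be overridden with the dotted syntax used by the CI script, for example collector.total_frames=48 env.from_pixels=True logger.backend=csv.
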
274 changes: 91 additions & 183 deletions examples/ddpg/ddpg.py
@@ -2,196 +2,104 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""DDPG Example.
import dataclasses
This is a self-contained example of a DDPG training script.
It works across Gym and DM-control over a variety of tasks.
Both state and pixel-based environments are supported.
The helper functions are coded in the utils.py associated with this script.
"""

import hydra
import torch.cuda
from hydra.core.config_store import ConfigStore
from torchrl.envs import EnvCreator, ParallelEnv
from torchrl.envs.transforms import RewardScaling, TransformedEnv
from torchrl.envs.utils import set_exploration_mode
from torchrl.modules import OrnsteinUhlenbeckProcessWrapper
from torchrl.record import VideoRecorder
from torchrl.record.loggers import generate_exp_name, get_logger
from torchrl.trainers.helpers.collectors import (
make_collector_offpolicy,
OffPolicyCollectorConfig,
import tqdm
from torchrl.trainers.helpers.envs import correct_for_frame_skip

from utils import (
get_stats,
make_collector,
make_ddpg_model,
make_logger,
make_loss,
make_optim,
make_policy,
make_recorder,
make_replay_buffer,
)
from torchrl.trainers.helpers.envs import (
correct_for_frame_skip,
EnvConfig,
initialize_observation_norm_transforms,
parallel_env_constructor,
retrieve_observation_norms_state_dict,
transformed_env_constructor,
)
from torchrl.trainers.helpers.logger import LoggerConfig
from torchrl.trainers.helpers.losses import LossConfig, make_ddpg_loss
from torchrl.trainers.helpers.models import DDPGModelConfig, make_ddpg_actor
from torchrl.trainers.helpers.replay_buffer import make_replay_buffer, ReplayArgsConfig
from torchrl.trainers.helpers.trainers import make_trainer, TrainerConfig

config_fields = [
(config_field.name, config_field.type, config_field)
for config_cls in (
TrainerConfig,
OffPolicyCollectorConfig,
EnvConfig,
LossConfig,
DDPGModelConfig,
LoggerConfig,
ReplayArgsConfig,
)
for config_field in dataclasses.fields(config_cls)
]
Config = dataclasses.make_dataclass(cls_name="Config", fields=config_fields)
cs = ConfigStore.instance()
cs.store(name="config", node=Config)

DEFAULT_REWARD_SCALING = {
"Hopper-v1": 5,
"Walker2d-v1": 5,
"HalfCheetah-v1": 5,
"cheetah": 5,
"Ant-v2": 5,
"Humanoid-v2": 20,
"humanoid": 100,
}


@hydra.main(version_base=None, config_path=".", config_name="config")


@hydra.main(config_path=".", config_name="config")
def main(cfg: "DictConfig"): # noqa: F821

cfg = correct_for_frame_skip(cfg)

if not isinstance(cfg.reward_scaling, float):
cfg.reward_scaling = DEFAULT_REWARD_SCALING.get(cfg.env_name, 5.0)

device = (
torch.device("cpu")
if torch.cuda.device_count() == 0
else torch.device("cuda:0")
)

exp_name = generate_exp_name("DDPG", cfg.exp_name)
logger = get_logger(
logger_type=cfg.logger, logger_name="ddpg_logging", experiment_name=exp_name
)
video_tag = exp_name if cfg.record_video else ""

key, init_env_steps, stats = None, None, None
if not cfg.vecnorm and cfg.norm_stats:
if not hasattr(cfg, "init_env_steps"):
raise AttributeError("init_env_steps missing from arguments.")
key = ("next", "pixels") if cfg.from_pixels else ("next", "observation_vector")
init_env_steps = cfg.init_env_steps
stats = {"loc": None, "scale": None}
elif cfg.from_pixels:
stats = {"loc": 0.5, "scale": 0.5}

proof_env = transformed_env_constructor(
cfg=cfg,
stats=stats,
use_env_creator=False,
)()
initialize_observation_norm_transforms(
proof_environment=proof_env, num_iter=init_env_steps, key=key
)
_, obs_norm_state_dict = retrieve_observation_norms_state_dict(proof_env)[0]

model = make_ddpg_actor(
proof_env,
cfg=cfg,
device=device,
)
loss_module, target_net_updater = make_ddpg_loss(model, cfg)

actor_model_explore = model[0]
if cfg.ou_exploration:
if cfg.gSDE:
raise RuntimeError("gSDE and ou_exploration are incompatible")
actor_model_explore = OrnsteinUhlenbeckProcessWrapper(
actor_model_explore,
annealing_num_steps=cfg.annealing_frames,
sigma=cfg.ou_sigma,
theta=cfg.ou_theta,
).to(device)
if device == torch.device("cpu"):
# mostly for debugging
actor_model_explore.share_memory()

if cfg.gSDE:
with torch.no_grad(), set_exploration_mode("random"):
# get dimensions to build the parallel env
proof_td = actor_model_explore(proof_env.reset().to(device))
action_dim_gsde, state_dim_gsde = proof_td.get("_eps_gSDE").shape[-2:]
del proof_td
else:
action_dim_gsde, state_dim_gsde = None, None

proof_env.close()

create_env_fn = parallel_env_constructor(
cfg=cfg,
obs_norm_state_dict=obs_norm_state_dict,
action_dim_gsde=action_dim_gsde,
state_dim_gsde=state_dim_gsde,
)

collector = make_collector_offpolicy(
make_env=create_env_fn,
actor_model_explore=actor_model_explore,
cfg=cfg,
# make_env_kwargs=[
# {"device": device} if device >= 0 else {}
# for device in args.env_rendering_devices
# ],
)

replay_buffer = make_replay_buffer(device, cfg)

recorder = transformed_env_constructor(
cfg,
video_tag=video_tag,
norm_obs_only=True,
obs_norm_state_dict=obs_norm_state_dict,
logger=logger,
use_env_creator=False,
)()
if isinstance(create_env_fn, ParallelEnv):
raise NotImplementedError("This behaviour is deprecated")
elif isinstance(create_env_fn, EnvCreator):
recorder.transform[1:].load_state_dict(create_env_fn().transform.state_dict())
elif isinstance(create_env_fn, TransformedEnv):
recorder.transform = create_env_fn.transform.clone()
else:
raise NotImplementedError(f"Unsupported env type {type(create_env_fn)}")
if logger is not None and video_tag:
recorder.insert_transform(0, VideoRecorder(logger=logger, tag=video_tag))

# reset reward scaling
for t in recorder.transform:
if isinstance(t, RewardScaling):
t.scale.fill_(1.0)
t.loc.fill_(0.0)

trainer = make_trainer(
collector,
loss_module,
recorder,
target_net_updater,
actor_model_explore,
replay_buffer,
logger,
cfg,
)

final_seed = collector.set_seed(cfg.seed)
print(f"init seed: {cfg.seed}, final seed: {final_seed}")

trainer.train()
return (logger.log_dir, trainer._log_dict)
model_device = cfg.optim.device

state_dict = get_stats(cfg.env)
logger = make_logger(cfg.logger)
replay_buffer = make_replay_buffer(cfg.replay_buffer)

actor_network, value_network = make_ddpg_model(cfg)
actor_network = actor_network.to(model_device)
value_network = value_network.to(model_device)

policy = make_policy(cfg.model, actor_network)
collector = make_collector(cfg, state_dict=state_dict, policy=policy)
loss, target_net_updater = make_loss(cfg.loss, actor_network, value_network)
optim = make_optim(cfg.optim, actor_network, value_network)
recorder = make_recorder(cfg, logger, policy)

optim_steps_per_batch = cfg.optim.optim_steps_per_batch
batch_size = cfg.optim.batch_size
init_random_frames = cfg.collector.init_random_frames
record_interval = cfg.recorder.interval

pbar = tqdm.tqdm(total=cfg.collector.total_frames)
collected_frames = 0

r0 = None
l0 = None
for data in collector:
frames_in_batch = data.numel()
collected_frames += frames_in_batch
pbar.update(data.numel())
# extend replay buffer
replay_buffer.extend(data.view(-1))
if collected_frames >= init_random_frames:
for _ in range(optim_steps_per_batch):
# sample
sample = replay_buffer.sample(batch_size)
# loss
loss_vals = loss(sample)
# backprop
loss_val = sum(
val for key, val in loss_vals.items() if key.startswith("loss")
)
loss_val.backward()
optim.step()
optim.zero_grad()
target_net_updater.step()
if r0 is None:
r0 = data["reward"].mean().item()
if l0 is None:
l0 = loss_val.item()

for key, value in loss_vals.items():
logger.log_scalar(key, value.item(), collected_frames)
logger.log_scalar(
"reward_training", data["reward"].mean().item(), collected_frames
)

pbar.set_description(
f"loss: {loss_val.item(): 4.4f} (init: {l0: 4.4f}), reward: {data['reward'].mean(): 4.4f} (init={r0: 4.4f})"
)
collector.update_policy_weights_()
if (
collected_frames - frames_in_batch
) // record_interval < collected_frames // record_interval:
recorder()


if __name__ == "__main__":
457 changes: 457 additions & 0 deletions examples/ddpg/utils.py

Large diffs are not rendered by default.
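
Since examples/ddpg/utils.py is not rendered here, the sketch below is only a guess at what one of its helpers, make_replay_buffer, might look like, based on how ddpg.py calls it (it receives cfg.replay_buffer and the returned buffer is used via .extend() and .sample(batch_size)) and on the capacity, prb and buffer_prefetch fields of the new config. The actual file may differ.

# Hypothetical sketch of make_replay_buffer; utils.py itself is not shown in this diff.
from torchrl.data import TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer
from torchrl.data.replay_buffers.storages import LazyMemmapStorage

def make_replay_buffer(cfg_rb):
    storage = LazyMemmapStorage(cfg_rb.capacity)
    if cfg_rb.prb:
        # Prioritized buffer, mirroring the prb flag in config.yaml.
        # alpha/beta values are illustrative, not taken from the PR.
        return TensorDictPrioritizedReplayBuffer(
            alpha=0.7,
            beta=0.5,
            storage=storage,
            prefetch=cfg_rb.buffer_prefetch,
        )
    return TensorDictReplayBuffer(storage=storage, prefetch=cfg_rb.buffer_prefetch)
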

26 changes: 12 additions & 14 deletions torchrl/trainers/helpers/envs.py
@@ -55,20 +55,18 @@ def correct_for_frame_skip(cfg: "DictConfig") -> "DictConfig": # noqa: F821
"""
# Adapt all frame counts wrt frame_skip
if cfg.frame_skip != 1:
fields = [
"max_frames_per_traj",
"total_frames",
"frames_per_batch",
"record_frames",
"annealing_frames",
"init_random_frames",
"init_env_steps",
"noops",
]
for field in fields:
if hasattr(cfg, field):
setattr(cfg, field, getattr(cfg, field) // cfg.frame_skip)

frame_skip = cfg.env.frame_skip

if frame_skip != 1:
cfg.collector.max_frames_per_traj //= frame_skip
cfg.collector.total_frames //= frame_skip
cfg.collector.frames_per_batch //= frame_skip
cfg.collector.init_random_frames //= frame_skip
cfg.collector.init_env_steps //= frame_skip
cfg.recorder.record_frames //= frame_skip
cfg.model.annealing_frames //= frame_skip
cfg.env.noops //= frame_skip
return cfg
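
For illustration, the rewritten helper simply divides every frame-count field of the nested config by env.frame_skip, so all budgets end up expressed in environment steps rather than rendered frames. A toy sketch with hypothetical numbers (not a call into the real config, which has to define every field the helper touches):

# Toy illustration of the frame-skip correction; numbers are hypothetical.
frame_skip = 2
total_frames = 1_000_000
frames_per_batch = 1024

total_frames //= frame_skip      # 500_000 environment steps
frames_per_batch //= frame_skip  # 512 environment steps per batch
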


2 changes: 1 addition & 1 deletion torchrl/trainers/trainers.py
@@ -1161,7 +1161,7 @@ def __init__(
self.log_pbar = log_pbar

@torch.inference_mode()
def __call__(self, batch: TensorDictBase) -> Dict:
def __call__(self, batch: TensorDictBase = None) -> Dict:
out = None
if self._count % self.record_interval == 0:
with set_exploration_mode(self.exploration_mode):
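
Making batch optional here lines up with the new training loop in ddpg.py, which invokes the recorder with no arguments once per record interval; a minimal usage sketch (names taken from the loop above):

# Sketch: calling the recorder without a batch, as the new ddpg.py training loop does.
if (collected_frames - frames_in_batch) // record_interval < collected_frames // record_interval:
    recorder()  # batch now defaults to None
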