Commit bc4a72f

[Algorithm] Simpler IQL example (#998)
Authored by BY571 and vmoens
Co-authored-by: Vincent Moens <[email protected]>
1 parent 0906206 commit bc4a72f

File tree

19 files changed: 1115 additions & 598 deletions


.github/unittest/linux_examples/scripts/run_test.sh

Lines changed: 33 additions & 10 deletions
@@ -45,9 +45,15 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/decision_trans
   optim.updates_per_episode=3 \
   optim.warmup_steps=10 \
   optim.device=cuda:0 \
-  logger.backend= \
-  env.backend=gymnasium \
-  env.name=HalfCheetah-v4
+  logger.backend=
+python .github/unittest/helpers/coverage_run_parallel.py examples/iql/iql_offline.py \
+  optim.gradient_steps=55 \
+  optim.device=cuda:0 \
+  logger.backend=
+python .github/unittest/helpers/coverage_run_parallel.py examples/cql/cql_offline.py \
+  optim.gradient_steps=55 \
+  optim.device=cuda:0 \
+  logger.backend=

 # ==================================================================================== #
 # ================================ Gymnasium ========================================= #
@@ -115,7 +121,6 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/cql/discrete_c
   collector.frames_per_batch=16 \
   collector.env_per_collector=2 \
   collector.device=cuda:0 \
-  optim.optim_steps_per_batch=1 \
   replay_buffer.size=120 \
   logger.backend=
 python .github/unittest/helpers/coverage_run_parallel.py examples/redq/redq.py \
@@ -174,11 +179,20 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/td3/td3.py \
   logger.backend=
 python .github/unittest/helpers/coverage_run_parallel.py examples/iql/iql_online.py \
   collector.total_frames=48 \
-  buffer.batch_size=10 \
+  optim.batch_size=10 \
   collector.frames_per_batch=16 \
-  collector.env_per_collector=2 \
+  env.train_num_envs=2 \
+  optim.device=cuda:0 \
   collector.device=cuda:0 \
-  network.device=cuda:0 \
+  logger.mode=offline \
+  logger.backend=
+python .github/unittest/helpers/coverage_run_parallel.py examples/cql/cql_online.py \
+  collector.total_frames=48 \
+  optim.batch_size=10 \
+  collector.frames_per_batch=16 \
+  env.train_num_envs=2 \
+  collector.device=cuda:0 \
+  optim.device=cuda:0 \
   logger.mode=offline \
   logger.backend=

@@ -248,12 +262,21 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/sac/sac.py \
   logger.backend=
 python .github/unittest/helpers/coverage_run_parallel.py examples/iql/iql_online.py \
   collector.total_frames=48 \
+  optim.batch_size=10 \
   collector.frames_per_batch=16 \
-  collector.env_per_collector=1 \
+  env.train_num_envs=1 \
+  logger.mode=offline \
+  optim.device=cuda:0 \
   collector.device=cuda:0 \
-  network.device=cuda:0 \
-  buffer.batch_size=10 \
+  logger.backend=
+python .github/unittest/helpers/coverage_run_parallel.py examples/cql/cql_online.py \
+  collector.total_frames=48 \
+  optim.batch_size=10 \
+  collector.frames_per_batch=16 \
+  collector.env_per_collector=1 \
   logger.mode=offline \
+  optim.device=cuda:0 \
+  collector.device=cuda:0 \
   logger.backend=
 python .github/unittest/helpers/coverage_run_parallel.py examples/td3/td3.py \
   collector.total_frames=48 \
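
Note on the overrides above: the example scripts are Hydra applications (see the @hydra.main decorator in the Python diff below), so every section.key=value token passed by the CI is a dot-list override that replaces the matching entry of the example's YAML config before the run starts. This is also how the renamed keys (optim.batch_size, env.train_num_envs, optim.device) are exercised. Below is a minimal sketch of how such a dot-list resolves against a base config, using OmegaConf directly; the base values are illustrative placeholders, not the examples' actual defaults:

# Minimal sketch of how the CI's dot-list overrides resolve (illustrative values only).
from omegaconf import OmegaConf

base = OmegaConf.create(
    {"optim": {"gradient_steps": 50_000, "device": "cpu"}, "logger": {"backend": "wandb"}}
)
# Equivalent to passing 'optim.gradient_steps=55 optim.device=cuda:0 logger.backend=' on the command line.
overrides = OmegaConf.from_dotlist(
    ["optim.gradient_steps=55", "optim.device=cuda:0", "logger.backend="]
)
cfg = OmegaConf.merge(base, overrides)
assert cfg.optim.gradient_steps == 55 and cfg.optim.device == "cuda:0"
# An empty value such as 'logger.backend=' comes through as null, which the scripts treat as "no logger".
print(OmegaConf.to_yaml(cfg))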

examples/cql/cql_offline.py

Lines changed: 78 additions & 34 deletions
@@ -10,6 +10,8 @@

 """

+import time
+
 import hydra
 import numpy as np
 import torch
@@ -18,16 +20,18 @@
 from torchrl.record.loggers import generate_exp_name, get_logger

 from utils import (
+    log_metrics,
+    make_continuous_cql_optimizer,
+    make_continuous_loss,
     make_cql_model,
-    make_cql_optimizer,
     make_environment,
-    make_loss,
     make_offline_replay_buffer,
 )


 @hydra.main(config_path=".", config_name="offline_config", version_base="1.1")
 def main(cfg: "DictConfig"):  # noqa: F821
+    # Create logger
     exp_name = generate_exp_name("CQL-offline", cfg.env.exp_name)
     logger = None
     if cfg.logger.backend:
@@ -37,49 +41,96 @@ def main(cfg: "DictConfig"):  # noqa: F821
             experiment_name=exp_name,
             wandb_kwargs={"mode": cfg.logger.mode, "config": cfg},
         )
-
+    # Set seeds
     torch.manual_seed(cfg.env.seed)
     np.random.seed(cfg.env.seed)
     device = torch.device(cfg.optim.device)

-    # Make Env
+    # Create env
     train_env, eval_env = make_environment(cfg, cfg.logger.eval_envs)

-    # Make Buffer
+    # Create replay buffer
     replay_buffer = make_offline_replay_buffer(cfg.replay_buffer)

-    # Make Model
+    # Create agent
     model = make_cql_model(cfg, train_env, eval_env, device)

-    # Make Loss
-    loss_module, target_net_updater = make_loss(cfg.loss, model)
+    # Create loss
+    loss_module, target_net_updater = make_continuous_loss(cfg.loss, model)

-    # Make Optimizer
-    optimizer = make_cql_optimizer(cfg.optim, loss_module)
+    # Create Optimizer
+    (
+        policy_optim,
+        critic_optim,
+        alpha_optim,
+        alpha_prime_optim,
+    ) = make_continuous_cql_optimizer(cfg, loss_module)

     pbar = tqdm.tqdm(total=cfg.optim.gradient_steps)

-    r0 = None
-    l0 = None
-
     gradient_steps = cfg.optim.gradient_steps
+    policy_eval_start = cfg.optim.policy_eval_start
     evaluation_interval = cfg.logger.eval_iter
     eval_steps = cfg.logger.eval_steps

+    # Training loop
+    start_time = time.time()
     for i in range(gradient_steps):
         pbar.update(i)
+        # sample data
         data = replay_buffer.sample()
-        # loss
-        loss_vals = loss_module(data)
-        # backprop
-        actor_loss = loss_vals["loss_actor"]
+        # compute loss
+        loss_vals = loss_module(data.clone().to(device))
+
+        # official cql implementation uses behavior cloning loss for first few updating steps as it helps for some tasks
+        if i >= policy_eval_start:
+            actor_loss = loss_vals["loss_actor"]
+        else:
+            actor_loss = loss_vals["loss_actor_bc"]
         q_loss = loss_vals["loss_qvalue"]
-        value_loss = loss_vals["loss_value"]
-        loss_val = actor_loss + q_loss + value_loss
+        cql_loss = loss_vals["loss_cql"]
+
+        q_loss = q_loss + cql_loss
+
+        alpha_loss = loss_vals["loss_alpha"]
+        alpha_prime_loss = loss_vals["loss_alpha_prime"]
+
+        # update model
+        alpha_loss = loss_vals["loss_alpha"]
+        alpha_prime_loss = loss_vals["loss_alpha_prime"]
+
+        alpha_optim.zero_grad()
+        alpha_loss.backward()
+        alpha_optim.step()

-        optimizer.zero_grad()
-        loss_val.backward()
-        optimizer.step()
+        policy_optim.zero_grad()
+        actor_loss.backward()
+        policy_optim.step()
+
+        if alpha_prime_optim is not None:
+            alpha_prime_optim.zero_grad()
+            alpha_prime_loss.backward(retain_graph=True)
+            alpha_prime_optim.step()
+
+        critic_optim.zero_grad()
+        # TODO: we have the option to compute losses independently retain is not needed?
+        q_loss.backward(retain_graph=False)
+        critic_optim.step()
+
+        loss = actor_loss + q_loss + alpha_loss + alpha_prime_loss
+
+        # log metrics
+        to_log = {
+            "loss": loss.item(),
+            "loss_actor_bc": loss_vals["loss_actor_bc"].item(),
+            "loss_actor": loss_vals["loss_actor"].item(),
+            "loss_qvalue": q_loss.item(),
+            "loss_cql": cql_loss.item(),
+            "loss_alpha": alpha_loss.item(),
+            "loss_alpha_prime": alpha_prime_loss.item(),
+        }
+
+        # update qnet_target params
         target_net_updater.step()

         # evaluation
@@ -88,20 +139,13 @@ def main(cfg: "DictConfig"):  # noqa: F821
             eval_td = eval_env.rollout(
                 max_steps=eval_steps, policy=model[0], auto_cast_to_device=True
             )
+            eval_reward = eval_td["next", "reward"].sum(1).mean().item()
+            to_log["evaluation_reward"] = eval_reward

-        if r0 is None:
-            r0 = eval_td["next", "reward"].sum(1).mean().item()
-        if l0 is None:
-            l0 = loss_val.item()
-
-        for key, value in loss_vals.items():
-            logger.log_scalar(key, value.item(), i)
-        eval_reward = eval_td["next", "reward"].sum(1).mean().item()
-        logger.log_scalar("evaluation_reward", eval_reward, i)
+        log_metrics(logger, to_log, i)

-        pbar.set_description(
-            f"loss: {loss_val.item(): 4.4f} (init: {l0: 4.4f}), evaluation_reward: {eval_reward: 4.4f} (init={r0: 4.4f})"
-        )
+    pbar.close()
+    print(f"Training time: {time.time() - start_time}")


 if __name__ == "__main__":
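
The four optimizers unpacked in the new "Create Optimizer" block come from make_continuous_cql_optimizer in the example's utils.py, which is not part of this diff. Below is a rough, illustrative sketch of the pattern it implies (one Adam optimizer per trainable component, so each loss term in the loop can step only its own parameters); the function name, signature and default learning rates are assumptions, not the actual helper:

# Illustrative sketch only, not the utils.py helper from this commit.
# Separate optimizers let actor_loss, q_loss, alpha_loss and alpha_prime_loss
# each update only the parameters they belong to.
from typing import Iterable, Tuple

import torch
from torch import nn


def make_per_component_optimizers(
    actor_params: Iterable[nn.Parameter],
    critic_params: Iterable[nn.Parameter],
    log_alpha: nn.Parameter,
    log_alpha_prime: nn.Parameter,
    actor_lr: float = 3e-4,
    critic_lr: float = 3e-4,
    alpha_lr: float = 3e-4,
) -> Tuple[torch.optim.Adam, torch.optim.Adam, torch.optim.Adam, torch.optim.Adam]:
    # One optimizer per component, mirroring the four-way unpacking in the training script.
    policy_optim = torch.optim.Adam(actor_params, lr=actor_lr)
    critic_optim = torch.optim.Adam(critic_params, lr=critic_lr)
    alpha_optim = torch.optim.Adam([log_alpha], lr=alpha_lr)
    alpha_prime_optim = torch.optim.Adam([log_alpha_prime], lr=alpha_lr)
    return policy_optim, critic_optim, alpha_optim, alpha_prime_optim

Splitting the optimizers this way is what allows the behavior-cloning warm-up on the actor and the "alpha_prime_optim is not None" guard in the loop without touching the other components.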

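Similarly, log_metrics is imported from utils.py and not shown in this excerpt. Judging only from its call site above, log_metrics(logger, to_log, i), and the logger.log_scalar(key, value.item(), i) calls it replaces, a plausible sketch (the exact signature is an assumption) is:

# Plausible sketch only: a thin wrapper around logger.log_scalar, inferred from the
# call sites in this diff rather than copied from utils.py.
def log_metrics(logger, metrics: dict, step: int) -> None:
    if logger is None:
        # With logger.backend= (empty) the script never creates a logger, so skip logging.
        return
    for key, value in metrics.items():
        logger.log_scalar(key, value, step)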