
Commit cdda5fc

Add Double DQN (#2148)
1 parent 08987a1 commit cdda5fc

4 files changed (+77, -6 lines)

README.md

Lines changed: 1 addition & 1 deletion

@@ -57,7 +57,7 @@ The table below summarizes the algorithms available in garage.
 | REINFORCE (a.k.a. VPG) | PyTorch, TensorFlow |
 | DDPG | PyTorch, TensorFlow |
 | DQN | PyTorch, TensorFlow |
-| DDQN | TensorFlow |
+| DDQN | PyTorch, TensorFlow |
 | ERWR | TensorFlow |
 | NPO | TensorFlow |
 | PPO | PyTorch, TensorFlow |

examples/torch/dqn_atari.py

Lines changed: 11 additions & 0 deletions

@@ -3,6 +3,8 @@
 
 Here it creates a gym environment CartPole, and trains a DQN with 50k steps.
 """
+import math
+
 import click
 import gym
 import numpy as np
@@ -54,11 +56,13 @@
 @click.option('--seed', default=24)
 @click.option('--n', type=int, default=psutil.cpu_count(logical=False))
 @click.option('--buffer_size', type=int, default=None)
+@click.option('--n_steps', type=float, default=None)
 @click.option('--max_episode_length', type=int, default=None)
 def main(env=None,
          seed=24,
          n=psutil.cpu_count(logical=False),
          buffer_size=None,
+         n_steps=None,
          max_episode_length=None):
     """Wrapper to setup the logging directory.
 
@@ -73,6 +77,9 @@ def main(env=None,
         buffer_size (int): size of the replay buffer in transitions. If None,
             defaults to hyperparams['buffer_size']. This is used by the
             integration tests.
+        n_steps (float): Total number of environment steps to run for,
+            not including evaluation. If this is not None, n_epochs will
+            be recalculated based on this value.
         max_episode_length (int): Max length of an episode. If None, defaults
             to the timelimit specific to the environment. Used by integration
             tests.
@@ -81,6 +88,10 @@ def main(env=None,
         env += 'NoFrameskip-v4'
     logdir = 'data/local/experiment/' + env
 
+    if n_steps is not None:
+        hyperparams['n_epochs'] = math.ceil(
+            int(n_steps) / (hyperparams['steps_per_epoch'] *
+                            hyperparams['sampler_batch_size']))
     if buffer_size is not None:
         hyperparams['buffer_size'] = buffer_size

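The --n_steps flag above converts a total environment-step budget into an epoch count for the trainer. A minimal sketch of that arithmetic, using hypothetical values for steps_per_epoch and sampler_batch_size (the real defaults live in the script's hyperparams dict and are not shown in this diff):

import math

# Hypothetical values for illustration; the script takes these from its
# module-level `hyperparams` dict.
hyperparams = {'steps_per_epoch': 20, 'sampler_batch_size': 500}

n_steps = 1e7  # e.g. passed as --n_steps=1e7
hyperparams['n_epochs'] = math.ceil(
    int(n_steps) / (hyperparams['steps_per_epoch'] *
                    hyperparams['sampler_batch_size']))
print(hyperparams['n_epochs'])  # 10,000,000 / (20 * 500) = 1000 epochs
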
src/garage/torch/algos/dqn.py

Lines changed: 16 additions & 4 deletions

@@ -32,6 +32,8 @@ class DQN(RLAlgorithm):
         n_train_steps (int): Training steps.
         eval_env (Environment): Evaluation environment. If None, a copy of the
             main environment is used for evaluation.
+        double_q (bool): Whether to use Double DQN.
+            See https://arxiv.org/abs/1509.06461.
         max_episode_length_eval (int or None): Maximum length of episodes used
             for off-policy evaluation. If `None`, defaults to
             `env_spec.max_episode_length`.
@@ -67,6 +69,7 @@ def __init__(
             replay_buffer,
             exploration_policy=None,
             eval_env=None,
+            double_q=True,
             qf_optimizer=torch.optim.Adam,
             *,  # Everything after this is numbers.
             steps_per_epoch=20,
@@ -100,6 +103,7 @@ def __init__(
         self._steps_per_epoch = steps_per_epoch
         self._n_train_steps = n_train_steps
         self._buffer_batch_size = buffer_batch_size
+        self._double_q = double_q
         self._discount = discount
         self._reward_scale = reward_scale
         self.max_episode_length = env_spec.max_episode_length
@@ -246,10 +250,18 @@ def _optimize_qf(self, timesteps):
         next_inputs = next_observations
         inputs = observations
         with torch.no_grad():
-            # discrete, outputs Qs for all possible actions
-            target_qvals = self._target_qf(next_inputs)
-            best_qvals, _ = torch.max(target_qvals, 1)
-            best_qvals = best_qvals.unsqueeze(1)
+            if self._double_q:
+                # Use online qf to get optimal actions
+                selected_actions = torch.argmax(self._qf(next_inputs), axis=1)
+                # use target qf to get Q values for those actions
+                selected_actions = selected_actions.long().unsqueeze(1)
+                best_qvals = torch.gather(self._target_qf(next_inputs),
+                                          dim=1,
+                                          index=selected_actions)
+            else:
+                target_qvals = self._target_qf(next_inputs)
+                best_qvals, _ = torch.max(target_qvals, 1)
+                best_qvals = best_qvals.unsqueeze(1)
 
         rewards_clipped = rewards
         if self._clip_reward is not None:

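The branch above is the core of the commit: with double_q enabled, the online Q-network chooses the greedy action for the next observation and the target network evaluates it, rather than letting the target network both choose and evaluate via its own max, which is the source of overestimation in vanilla DQN (https://arxiv.org/abs/1509.06461). A standalone sketch of the two target computations, using stand-in linear layers in place of garage's Q-function modules:

import torch

torch.manual_seed(0)
batch_size, obs_dim, n_actions = 4, 8, 6

# Stand-ins for self._qf (online network) and self._target_qf.
qf = torch.nn.Linear(obs_dim, n_actions)
target_qf = torch.nn.Linear(obs_dim, n_actions)
next_inputs = torch.randn(batch_size, obs_dim)

with torch.no_grad():
    # Vanilla DQN target: the target network selects and evaluates the action.
    best_qvals_dqn = target_qf(next_inputs).max(dim=1, keepdim=True).values

    # Double DQN target: the online network selects, the target network scores.
    selected_actions = torch.argmax(qf(next_inputs), dim=1, keepdim=True)
    best_qvals_ddqn = torch.gather(target_qf(next_inputs), dim=1,
                                   index=selected_actions)

print(best_qvals_dqn.shape, best_qvals_ddqn.shape)  # both (batch_size, 1)

Either choice of best_qvals then feeds the same Bellman backup, y = rewards + (1 - terminals) * discount * best_qvals, in the unchanged code below the branch.
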
tests/garage/torch/algos/test_dqn.py

Lines changed: 49 additions & 1 deletion

@@ -29,7 +29,6 @@ def setup():
     steps_per_epoch = 10
     sampler_batch_size = 512
     num_timesteps = 100 * steps_per_epoch * sampler_batch_size
-
     env = GymEnv('CartPole-v0')
 
     replay_buffer = PathBuffer(capacity_in_transitions=int(1e6))
@@ -50,6 +49,7 @@ def setup():
               replay_buffer=replay_buffer,
               steps_per_epoch=steps_per_epoch,
               qf_lr=5e-5,
+              double_q=False,
               discount=0.9,
               min_buffer_size=int(1e4),
               n_train_steps=500,
@@ -121,6 +121,54 @@ def test_dqn_loss(setup):
     assert (selected_qs == algo_selected_qs).all()
 
 
+def test_double_dqn_loss(setup):
+    algo, env, buff, _, batch_size = setup
+
+    algo._double_q = True
+    trainer = Trainer(snapshot_config)
+    trainer.setup(algo, env, sampler_cls=LocalSampler)
+
+    paths = trainer.obtain_episodes(0, batch_size=batch_size)
+    buff.add_episode_batch(paths)
+    timesteps = buff.sample_timesteps(algo._buffer_batch_size)
+    timesteps_copy = copy.deepcopy(timesteps)
+
+    observations = np_to_torch(timesteps.observations)
+    rewards = np_to_torch(timesteps.rewards).reshape(-1, 1)
+    actions = np_to_torch(timesteps.actions)
+    next_observations = np_to_torch(timesteps.next_observations)
+    terminals = np_to_torch(timesteps.terminals).reshape(-1, 1)
+
+    next_inputs = next_observations
+    inputs = observations
+    with torch.no_grad():
+        # double Q loss
+        selected_actions = torch.argmax(algo._qf(next_inputs), axis=1)
+        # use target qf to get Q values for those actions
+        selected_actions = selected_actions.long().unsqueeze(1)
+        best_qvals = torch.gather(algo._target_qf(next_inputs),
+                                  dim=1,
+                                  index=selected_actions)
+
+    rewards_clipped = rewards
+    y_target = (rewards_clipped +
+                (1.0 - terminals) * algo._discount * best_qvals)
+    y_target = y_target.squeeze(1)
+
+    # optimize qf
+    qvals = algo._qf(inputs)
+    selected_qs = torch.sum(qvals * actions, axis=1)
+    qval_loss = F.smooth_l1_loss(selected_qs, y_target)
+
+    algo_loss, algo_targets, algo_selected_qs = algo._optimize_qf(
+        timesteps_copy)
+    env.close()
+
+    assert (qval_loss.detach() == algo_loss).all()
+    assert (y_target == algo_targets).all()
+    assert (selected_qs == algo_selected_qs).all()
+
+
 def test_to_device(setup):
     algo, _, _, _, _ = setup
     algo._qf.to = MagicMock(name='to')

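For completeness, a hypothetical end-to-end construction with the new flag enabled. Only the double_q keyword is introduced by this commit; the q-function, policy, and exploration-policy class names below are assumptions borrowed from garage's existing PyTorch DQN example and are not shown in this diff:

from garage.envs import GymEnv
from garage.np.exploration_policies import EpsilonGreedyPolicy
from garage.replay_buffer import PathBuffer
from garage.torch.algos import DQN
from garage.torch.policies import DiscreteQFArgmaxPolicy    # assumed class name
from garage.torch.q_functions import DiscreteMLPQFunction   # assumed class name

env = GymEnv('CartPole-v0')
qf = DiscreteMLPQFunction(env_spec=env.spec, hidden_sizes=(8, 5))
policy = DiscreteQFArgmaxPolicy(env_spec=env.spec, qf=qf)
exploration_policy = EpsilonGreedyPolicy(env_spec=env.spec,
                                         policy=policy,
                                         total_timesteps=100 * 10 * 512,
                                         max_epsilon=1.0,
                                         min_epsilon=0.01,
                                         decay_ratio=0.1)
replay_buffer = PathBuffer(capacity_in_transitions=int(1e6))

algo = DQN(env_spec=env.spec,
           policy=policy,
           qf=qf,
           replay_buffer=replay_buffer,
           exploration_policy=exploration_policy,
           double_q=True,  # flag added by this commit; it defaults to True
           steps_per_epoch=10,
           n_train_steps=500,
           discount=0.9,
           min_buffer_size=int(1e4))

Passing double_q=False recovers the original single-estimator target, which is what the existing test fixture above does.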