DuelingDQN #127
Draft
qgallouedec wants to merge 16 commits into master from feat/dueling-dqn
Commits (16), all by qgallouedec:
61a24de  DuelingDQN
888a06b  dueling_dqn.rst
0d19cfe  Update changelog
0a30312  Add example in example.rst
726d4b9  add dueling to index.rst
400e636  Add policy_aliases
b2ee629  test-cnn
823760b  typo
bf44d99  simplification
372da3e  typo
d93f00b  Rm policy_kwargs as error from copying from DRDQN
bd8755a  Update README
9b23663  Update setup.py
7a34a77  Add dueling to the list of algorithm
989f31f  ignore mypy error
738f09f  Merge branch 'master' into feat/dueling-dqn
New file: dueling_dqn.rst (+153 lines)
.. _dueling_dqn:

.. automodule:: sb3_contrib.dueling_dqn


Dueling-DQN
===========

`Dueling DQN <https://arxiv.org/abs/1511.06581>`_ builds on `Deep Q-Network (DQN) <https://arxiv.org/abs/1312.5602>`_
and #TODO:
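For context, the dueling architecture of the paper splits the Q-network head into a state-value stream V(s) and an action-advantage stream A(s, a), recombined as Q(s, a) = V(s) + A(s, a) - mean_a' A(s, a'). Below is a minimal sketch of such a head; the module, names, and layer sizes are illustrative assumptions, not this PR's actual implementation:

.. code-block:: python

    import torch as th
    from torch import nn

    class DuelingHead(nn.Module):
        """Illustrative dueling head on top of a shared feature extractor."""

        def __init__(self, features_dim: int, n_actions: int, hidden_dim: int = 64):
            super().__init__()
            # Separate value and advantage streams
            self.value = nn.Sequential(
                nn.Linear(features_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 1)
            )
            self.advantage = nn.Sequential(
                nn.Linear(features_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, n_actions)
            )

        def forward(self, features: th.Tensor) -> th.Tensor:
            value = self.value(features)          # shape (batch, 1)
            advantage = self.advantage(features)  # shape (batch, n_actions)
            # Mean-subtracted aggregation keeps V and A identifiable (Eq. 9 in the paper)
            return value + advantage - advantage.mean(dim=1, keepdim=True)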
.. rubric:: Available Policies

.. autosummary::
    :nosignatures:

    MlpPolicy
    CnnPolicy
    MultiInputPolicy


Notes
-----

- Original paper: https://arxiv.org/abs/1511.06581


Can I use?
----------

- Recurrent policies: ❌
- Multi processing: ✔️
- Gym spaces:

============= ====== ===========
Space         Action Observation
============= ====== ===========
Discrete      ✔️      ✔️
Box           ❌      ✔️
MultiDiscrete ❌      ✔️
MultiBinary   ❌      ✔️
Dict          ❌      ✔️
============= ====== ===========
Example
-------

.. code-block:: python

    import gym

    from sb3_contrib import DuelingDQN

    env = gym.make("CartPole-v1")

    model = DuelingDQN("MlpPolicy", env, verbose=1)
    model.learn(total_timesteps=10000, log_interval=4)
    model.save("dueling_dqn_cartpole")

    del model  # remove to demonstrate saving and loading

    model = DuelingDQN.load("dueling_dqn_cartpole")

    obs = env.reset()
    while True:
        action, _states = model.predict(obs, deterministic=True)
        obs, reward, done, info = env.step(action)
        env.render()
        if done:
            obs = env.reset()
Results
-------

Results on Atari environments (10M steps, Pong and Breakout) and classic control tasks, using 3 and 5 seeds.

The complete learning curves are available in the `associated PR <https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/126>`_. #TODO:


.. note::

    The DuelingDQN implementation was validated against #TODO: validate the results


============ =========== ===========
Environments DuelingDQN  DQN
============ =========== ===========
Breakout     ~300
Pong         ~20
CartPole     500 +/- 0
MountainCar  -107 +/- 4
LunarLander  195 +/- 28
Acrobot      -74 +/- 2
============ =========== ===========

#TODO: Fill the table
How to replicate the results?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Clone the RL-Zoo fork and check out the branch ``feat/dueling-dqn``:

.. code-block:: bash

    git clone https://github.com/DLR-RM/rl-baselines3-zoo/
    cd rl-baselines3-zoo/
    git checkout feat/dueling-dqn  #TODO: create this branch

Run the benchmark (replace ``$ENV_ID`` by the envs mentioned above):

.. code-block:: bash

    python train.py --algo dueling_dqn --env $ENV_ID --eval-episodes 10 --eval-freq 10000  #TODO: check if that command line works

Plot the results:

.. code-block:: bash

    python scripts/all_plots.py -a dueling_dqn -e Breakout Pong -f logs/ -o logs/dueling_dqn_results  #TODO: check if that command line works
    python scripts/plot_from_file.py -i logs/dueling_dqn_results.pkl -latex -l "Dueling DQN"  #TODO: check if that command line works
Parameters
----------

.. autoclass:: DuelingDQN
    :members:
    :inherited-members:


.. _dueling_dqn_policies:

Dueling DQN Policies
--------------------

.. autoclass:: MlpPolicy
    :members:
    :inherited-members:

.. autoclass:: sb3_contrib.dueling_dqn.policies.DuelingDQNPolicy
    :members:
    :noindex:

.. autoclass:: CnnPolicy
    :members:

.. autoclass:: MultiInputPolicy
    :members:
New file: sb3_contrib/dueling_dqn/__init__.py (+4 lines)
from sb3_contrib.dueling_dqn.dueling_dqn import DuelingDQN
from sb3_contrib.dueling_dqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy

__all__ = ["DuelingDQN", "CnnPolicy", "MlpPolicy", "MultiInputPolicy"]
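This re-export lets users import the algorithm from the subpackage directly. The top-level ``from sb3_contrib import DuelingDQN`` used in the docs example presumably relies on a matching re-export in sb3_contrib's own ``__init__.py``, which is not shown in this view:

    # assuming the branch is installed, this resolves via the file above
    from sb3_contrib.dueling_dqn import DuelingDQN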
New file: sb3_contrib/dueling_dqn/dueling_dqn.py (+132 lines)
from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union

import torch as th
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.dqn.dqn import DQN

from sb3_contrib.dueling_dqn.policies import CnnPolicy, DuelingDQNPolicy, MlpPolicy, MultiInputPolicy

SelfDuelingDQN = TypeVar("SelfDuelingDQN", bound="DuelingDQN")


class DuelingDQN(DQN):
    """
    Dueling Deep Q-Network (Dueling DQN)

    Paper: https://arxiv.org/abs/1511.06581

    :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
    :param env: The environment to learn from (if registered in Gym, can be str)
    :param learning_rate: The learning rate, it can be a function
        of the current progress remaining (from 1 to 0)
    :param buffer_size: size of the replay buffer
    :param learning_starts: how many steps of the model to collect transitions for before learning starts
    :param batch_size: Minibatch size for each gradient update
    :param tau: the soft update coefficient ("Polyak update", between 0 and 1), default 1 for hard update
    :param gamma: the discount factor
    :param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit
        like ``(5, "step")`` or ``(2, "episode")``.
    :param gradient_steps: How many gradient steps to do after each rollout (see ``train_freq``)
        Set to ``-1`` means to do as many gradient steps as steps done in the environment
        during the rollout.
    :param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``).
        If ``None``, it will be automatically selected.
    :param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation.
    :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
        at a cost of more complexity.
        See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
    :param target_update_interval: update the target network every ``target_update_interval``
        environment steps.
    :param exploration_fraction: fraction of entire training period over which the exploration rate is reduced
    :param exploration_initial_eps: initial value of random action probability
    :param exploration_final_eps: final value of random action probability
    :param max_grad_norm: The maximum value for the gradient clipping
    :param tensorboard_log: the log location for tensorboard (if None, no logging)
    :param policy_kwargs: additional arguments to be passed to the policy on creation
    :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
        debug messages
    :param seed: Seed for the pseudo random generators
    :param device: Device (cpu, cuda, ...) on which the code should be run.
        Setting it to auto, the code will be run on the GPU if possible.
    :param _init_setup_model: Whether or not to build the network at the creation of the instance
    """

    policy_aliases: Dict[str, Type[BasePolicy]] = {
        "MlpPolicy": MlpPolicy,
        "CnnPolicy": CnnPolicy,
        "MultiInputPolicy": MultiInputPolicy,
    }

    def __init__(
        self,
        policy: Union[str, Type[DuelingDQNPolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 0.0001,
        buffer_size: int = 1000000,
        learning_starts: int = 50000,
        batch_size: int = 32,
        tau: float = 1,
        gamma: float = 0.99,
        train_freq: Union[int, Tuple[int, str]] = 4,
        gradient_steps: int = 1,
        replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
        optimize_memory_usage: bool = False,
        target_update_interval: int = 10000,
        exploration_fraction: float = 0.1,
        exploration_initial_eps: float = 1,
        exploration_final_eps: float = 0.05,
        max_grad_norm: float = 10,
        tensorboard_log: Optional[str] = None,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[th.device, str] = "auto",
        _init_setup_model: bool = True,
    ):
        super().__init__(
            policy,
            env,
            learning_rate,
            buffer_size,
            learning_starts,
            batch_size,
            tau,
            gamma,
            train_freq,
            gradient_steps,
            replay_buffer_class,
            replay_buffer_kwargs,
            optimize_memory_usage,
            target_update_interval,
            exploration_fraction,
            exploration_initial_eps,
            exploration_final_eps,
            max_grad_norm,
            tensorboard_log,
            policy_kwargs,
            verbose,
            seed,
            device,
            _init_setup_model,
        )

    def learn(
        self: SelfDuelingDQN,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 4,
        tb_log_name: str = "DuelingDQN",
        reset_num_timesteps: bool = True,
        progress_bar: bool = False,
    ) -> SelfDuelingDQN:
        return super().learn(
            total_timesteps,
            callback,
            log_interval,
            tb_log_name,
            reset_num_timesteps,
            progress_bar,
        )
Review discussion:

- "if init is the same as DQN, I guess you can drop it."
- "The only (but necessary) diff is the policy type hint."
- "I see..."
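A minimal sketch of the simplification being discussed, assuming ``DQN``'s constructor signature is reused unchanged (illustrative only; the trade-off, per the reply, is losing the ``Type[DuelingDQNPolicy]`` hint on ``policy``):

    from typing import Dict, Type

    from stable_baselines3.common.policies import BasePolicy
    from stable_baselines3.dqn.dqn import DQN

    from sb3_contrib.dueling_dqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy


    class DuelingDQN(DQN):
        """Dueling DQN reusing DQN's __init__ as-is; only the policy aliases differ."""

        policy_aliases: Dict[str, Type[BasePolicy]] = {
            "MlpPolicy": MlpPolicy,
            "CnnPolicy": CnnPolicy,
            "MultiInputPolicy": MultiInputPolicy,
        }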