Fix DQN and C51
VanillaWhey committed Nov 29, 2024
Commit 96f6813 (parent 9b8e231)
Showing 4 changed files with 30 additions and 19 deletions.
cleanrl/architectures/dqn.py (12 additions, 1 deletion)
@@ -29,6 +29,11 @@ def __init__(self, env):
def forward(self, x):
return self.network(x / 255.0)

+def get_action_and_value(self, x):
+q_values = self.network(x / 255.0)
+action = torch.argmax(q_values, 1)
+return action, q_values[action], None, None


class QNetwork_C51(nn.Module):
def __init__(self, env, n_atoms=101, v_min=-100, v_max=100):
@@ -59,4 +64,10 @@ def get_action(self, x, action=None):
action = torch.argmax(q_values, 1)
return action, pmfs[torch.arange(len(x)), action]


+def get_action_and_value(self, x):
+logits = self.network(x / 255.0)
+# probability mass function for each action
+pmfs = torch.softmax(logits.view(len(x), self.n, self.n_atoms), dim=2)
+q_values = (pmfs * self.atoms).sum(2)
+action = torch.argmax(q_values, 1)
+return action, q_values[:, action], None, None
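The reason both networks gain a get_action_and_value method is the shared evaluation helper further down (cleanrl_utils/evals/generic_eval.py), which calls agent.get_action_and_value(obs) and keeps only the first return value as the greedy action. A minimal sketch of that contract; the no_grad wrapper and numpy conversion mirror the helper, everything else here is illustrative rather than part of the commit:

```python
import torch

@torch.no_grad()
def greedy_actions(net, obs, device):
    # Works for QNetwork and QNetwork_C51 alike: both now return
    # (action, value, None, None) from get_action_and_value.
    actions, _, _, _ = net.get_action_and_value(torch.Tensor(obs).to(device))
    return actions.cpu().numpy()
```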
cleanrl/c51_atari_oc.py (7 additions, 7 deletions)
@@ -101,8 +101,6 @@ class Args:
"""total timesteps of the experiments"""
learning_rate: float = 2.5e-4
"""the learning rate of the optimizer"""
-learning_rate: float = 2.5e-4
-"""the learning rate of the optimizer"""
num_envs: int = 1
"""the number of parallel game environments"""
n_atoms: int = 51
@@ -299,7 +297,8 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
# TRY NOT TO MODIFY: start the game
#import ipdb;ipdb.set_trace()
obs = envs.reset()
-for global_step in range(args.total_timesteps):
+pbar = tqdm(range(1, args.total_timesteps + 1), postfix=postfix)
+for global_step in pbar:
# ALGO LOGIC: put action logic here
epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_step)
if random.random() < epsilon:
@@ -394,23 +393,24 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
# Save the trained model to disk
model_path = f"{writer_dir}/{args.exp_name}.cleanrl_model"
model_data = {
"model_weights": agent.state_dict(),
"model_weights": q_network.state_dict(),
"args": vars(args),
}
torch.save(model_data, model_path)
logger.info(f"model saved to {model_path} in epoch {epoch}")
logger.info(f"model saved to {model_path} in epoch {global_step}")

# Log final model and performance with Weights and Biases if enabled
if args.track:
# Evaluate agent's performance
args.new_rf = ""
rewards = evaluate(
-agent, make_env, 10,
+q_network, make_env, 10,
env_id=args.env_id,
capture_video=args.capture_video,
run_dir=writer_dir,
feature_func=args.feature_func,
-window_size=args.buffer_window_size
+window_size=args.buffer_window_size,
+device=device
)

wandb.log({"FinalReward": np.mean(rewards)})
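Since the checkpoint saved here is just a dict with "model_weights" and "args", it can be read back directly with torch.load. A sketch of doing so, under the assumptions that the architecture is importable from cleanrl.architectures.dqn, that a vector env is built the same way as in training, and that the file path follows the f"{writer_dir}/{args.exp_name}.cleanrl_model" pattern above:

```python
import torch

from cleanrl.architectures.dqn import QNetwork_C51  # assumed import path

def load_c51_checkpoint(path, envs, device="cpu"):
    # `path` points at the .cleanrl_model file written by the training script.
    checkpoint = torch.load(path, map_location=device)
    saved_args = checkpoint["args"]  # dict produced by vars(args)
    q_network = QNetwork_C51(envs, n_atoms=saved_args["n_atoms"]).to(device)
    q_network.load_state_dict(checkpoint["model_weights"])
    q_network.eval()
    return q_network, saved_args
```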
cleanrl/dqn_atari_oc.py (9 additions, 10 deletions)
@@ -103,7 +103,7 @@ class Args:
"""the learning rate of the optimizer"""
num_envs: int = 1
"""the number of parallel game environments"""
-buffer_size: int = 1000000
+buffer_size: int = 1_000_000
"""the replay memory buffer size"""
gamma: float = 0.99
"""the discount factor gamma"""
@@ -119,7 +119,7 @@ class Args:
"""the ending epsilon for exploration"""
exploration_fraction: float = 0.10
"""the fraction of `total-timesteps` it takes from start-e to go end-e"""
-learning_starts: int = 80000
+learning_starts: int = 80_000
"""timestep to start learning"""
train_frequency: int = 4
"""the frequency of training"""
@@ -268,7 +268,6 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
optimizer = optim.Adam(q_network.parameters(), lr=args.learning_rate)
target_network = QNetwork(envs).to(device)
target_network.load_state_dict(q_network.state_dict())


rb = ReplayBuffer(
args.buffer_size,
@@ -278,8 +277,6 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
optimize_memory_usage=True,
handle_timeout_termination=False,
)
-start_time = time.time()


# Start training loop
global_step = 0
@@ -303,7 +300,8 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
done_in_episode = False

obs = envs.reset()
-for global_step in range(args.total_timesteps):
+pbar = tqdm(range(1, args.total_timesteps + 1), postfix=postfix)
+for global_step in pbar:

# ALGO LOGIC: put action logic here
epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_step)
@@ -382,23 +380,24 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
# Save the trained model to disk
model_path = f"{writer_dir}/{args.exp_name}.cleanrl_model"
model_data = {
"model_weights": agent.state_dict(),
"model_weights": q_network.state_dict(),
"args": vars(args),
}
torch.save(model_data, model_path)
logger.info(f"model saved to {model_path} in epoch {epoch}")
logger.info(f"model saved to {model_path} in epoch {global_step}")

# Log final model and performance with Weights and Biases if enabled
if args.track:
# Evaluate agent's performance
args.new_rf = ""
rewards = evaluate(
-agent, make_env, 10,
+q_network, make_env, 10,
env_id=args.env_id,
capture_video=args.capture_video,
run_dir=writer_dir,
feature_func=args.feature_func,
-window_size=args.buffer_window_size
+window_size=args.buffer_window_size,
+device=device
)

wandb.log({"FinalReward": np.mean(rewards)})
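Both training scripts also replace the bare range(args.total_timesteps) loop with a 1-indexed tqdm progress bar. A stripped-down sketch of that pattern; total_timesteps and the postfix dict below stand in for values the scripts define elsewhere:

```python
from tqdm import tqdm

total_timesteps = 10_000_000        # stand-in for args.total_timesteps
postfix = {"epsilon": 1.0}          # stand-in for the stats dict the scripts build

pbar = tqdm(range(1, total_timesteps + 1), postfix=postfix)
for global_step in pbar:
    # training step goes here; the bar can be refreshed with
    # pbar.set_postfix(postfix) as stats change
    pass
```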
cleanrl_utils/evals/generic_eval.py (2 additions, 1 deletion)
@@ -8,6 +8,7 @@ def evaluate(
agent,
make_env: Callable,
eval_episodes: int,
+device,
**env_kwargs
):
envs = gym.vector.SyncVectorEnv([make_env(idx=0, **env_kwargs)])
@@ -16,7 +17,7 @@ def evaluate(
obs, _ = envs.reset()
episodic_returns = []
while len(episodic_returns) < eval_episodes:
-actions, _, _, _ = agent.get_action_and_value(torch.Tensor(obs).to(agent.device))
+actions, _, _, _ = agent.get_action_and_value(torch.Tensor(obs).to(device))
next_obs, _, _, _, infos = envs.step(actions.cpu().numpy())
if "final_info" in infos:
for info in infos["final_info"]:
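The last hunk swaps agent.device for the device argument that the call sites above now pass explicitly. torch.nn.Module does not define a .device attribute, so the old expression only worked for models that happened to set one themselves. For comparison only, a common fallback (not used in this commit) infers the device from a parameter:

```python
import torch
import torch.nn as nn

def module_device(module: nn.Module) -> torch.device:
    # Infer where a model lives from its first parameter; passing `device`
    # explicitly, as this commit does, avoids relying on that convention.
    return next(module.parameters()).device
```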
