
Commit ffec97e

- added DroQ (dropout + layer norm on the critic, and delayed actor updates) to TQC (from pull request Stable-Baselines-Team#100)
1 parent e0335d7

2 files changed (+32 −14 lines)
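For context, a minimal usage sketch of the options this commit adds. The hyperparameter values and the "Pendulum-v1" environment are illustrative, not taken from the commit; routing dropout_rate/layer_norm through policy_kwargs relies on the standard SB3 mechanism of forwarding policy_kwargs to the policy constructor.

from sb3_contrib import TQC

# Illustrative values only: dropout_rate/layer_norm are forwarded to the critic
# via policy_kwargs, policy_delay controls how often the actor (and entropy
# coefficient) are updated relative to the critic.
model = TQC(
    "MlpPolicy",
    "Pendulum-v1",
    policy_delay=2,  # hypothetical value
    policy_kwargs=dict(dropout_rate=0.01, layer_norm=True),  # hypothetical values
)
model.learn(total_timesteps=10_000)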

sb3_contrib/tqc/policies.py

+15 −1
@@ -209,6 +209,8 @@ def __init__(
         n_quantiles: int = 25,
         n_critics: int = 2,
         share_features_extractor: bool = False,
+        dropout_rate: float = 0.0,
+        layer_norm: bool = False
     ):
         super().__init__(
             observation_space,
@@ -226,7 +228,14 @@ def __init__(
         self.quantiles_total = n_quantiles * n_critics
 
         for i in range(n_critics):
-            qf_net_list = create_mlp(features_dim + action_dim, n_quantiles, net_arch, activation_fn)
+            qf_net_list = create_mlp(
+                features_dim + action_dim,
+                n_quantiles,
+                net_arch,
+                activation_fn,
+                dropout_rate=dropout_rate,
+                layer_norm=layer_norm,
+            )
             qf_net = nn.Sequential(*qf_net_list)
             self.add_module(f"qf{i}", qf_net)
             self.q_networks.append(qf_net)
@@ -294,6 +303,9 @@ def __init__(
         n_quantiles: int = 25,
         n_critics: int = 2,
         share_features_extractor: bool = False,
+        # For the critic only
+        dropout_rate: float = 0.0,
+        layer_norm: bool = False,
     ):
         super().__init__(
             observation_space,
@@ -335,6 +347,8 @@ def __init__(
             "n_critics": n_critics,
             "net_arch": critic_arch,
             "share_features_extractor": share_features_extractor,
+            "dropout_rate": dropout_rate,
+            "layer_norm": layer_norm,
         }
         self.critic_kwargs.update(tqc_kwargs)
         self.share_features_extractor = share_features_extractor
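As an aside, here is a rough, hand-built equivalent of the quantile critic head that the patched create_mlp call is expected to produce when dropout_rate > 0 and layer_norm=True. The helper name and the exact module ordering (Linear → Dropout → LayerNorm → activation, as in the DroQ paper) are assumptions for illustration, not something shown in this diff.

import torch.nn as nn

def droq_critic_mlp(input_dim, n_quantiles, net_arch, dropout_rate=0.01, layer_norm=True):
    # Assumed layout: every hidden Linear layer is followed by Dropout and
    # LayerNorm before the activation; the output layer emits one value per quantile.
    layers = []
    last_dim = input_dim
    for hidden_dim in net_arch:
        layers.append(nn.Linear(last_dim, hidden_dim))
        if dropout_rate > 0:
            layers.append(nn.Dropout(p=dropout_rate))
        if layer_norm:
            layers.append(nn.LayerNorm(hidden_dim))
        layers.append(nn.ReLU())
        last_dim = hidden_dim
    layers.append(nn.Linear(last_dim, n_quantiles))
    return nn.Sequential(*layers)

# Illustrative dimensions: features_dim=256, action_dim=8, 25 quantiles, two hidden layers.
qf = droq_critic_mlp(input_dim=256 + 8, n_quantiles=25, net_arch=[256, 256])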

sb3_contrib/tqc/tqc.py

+17 −13
@@ -94,6 +94,7 @@ def __init__(
         replay_buffer_class: Optional[type[ReplayBuffer]] = None,
         replay_buffer_kwargs: Optional[dict[str, Any]] = None,
         optimize_memory_usage: bool = False,
+        policy_delay: int = 1,
         ent_coef: Union[str, float] = "auto",
         target_update_interval: int = 1,
         target_entropy: Union[str, float] = "auto",
@@ -145,6 +146,7 @@ def __init__(
         self.target_update_interval = target_update_interval
         self.ent_coef_optimizer: Optional[th.optim.Adam] = None
         self.top_quantiles_to_drop_per_net = top_quantiles_to_drop_per_net
+        self.policy_delay = policy_delay
 
         if _init_setup_model:
             self._setup_model()
@@ -190,7 +192,7 @@ def _create_aliases(self) -> None:
         self.critic = self.policy.critic
         self.critic_target = self.policy.critic_target
 
-    def train(self, gradient_steps: int, batch_size: int = 64) -> None:
+    def train(self, gradient_steps: int, batch_size: int = 64, train_freq: int = 1) -> None:
         # Switch to train mode (this affects batch norm / dropout)
         self.policy.set_training_mode(True)
         # Update optimizers learning rate
@@ -205,6 +207,8 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None:
         actor_losses, critic_losses = [], []
 
         for gradient_step in range(gradient_steps):
+            self._n_updates += 1
+            update_actor = self._n_updates % self.policy_delay == 0
             # Sample replay buffer
             replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)  # type: ignore[union-attr]
 
@@ -222,8 +226,9 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None:
                 # so we don't change it with other losses
                 # see https://github.com/rail-berkeley/softlearning/issues/60
                 ent_coef = th.exp(self.log_ent_coef.detach())
-                ent_coef_loss = -(self.log_ent_coef * (log_prob + self.target_entropy).detach()).mean()
-                ent_coef_losses.append(ent_coef_loss.item())
+                if update_actor:
+                    ent_coef_loss = -(self.log_ent_coef * (log_prob + self.target_entropy).detach()).mean()
+                    ent_coef_losses.append(ent_coef_loss.item())
             else:
                 ent_coef = self.ent_coef_tensor
 
@@ -265,24 +270,23 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None:
             critic_loss.backward()
             self.critic.optimizer.step()
 
-            # Compute actor loss
-            qf_pi = self.critic(replay_data.observations, actions_pi).mean(dim=2).mean(dim=1, keepdim=True)
-            actor_loss = (ent_coef * log_prob - qf_pi).mean()
-            actor_losses.append(actor_loss.item())
+            if update_actor:
+                qf_pi = self.critic(replay_data.observations, actions_pi).mean(dim=2).mean(dim=1, keepdim=True)
+                actor_loss = (ent_coef * log_prob - qf_pi).mean()
+                actor_losses.append(actor_loss.item())
+
+                # Optimize the actor
+                self.actor.optimizer.zero_grad()
+                actor_loss.backward()
+                self.actor.optimizer.step()
 
-            # Optimize the actor
-            self.actor.optimizer.zero_grad()
-            actor_loss.backward()
-            self.actor.optimizer.step()
 
             # Update target networks
             if gradient_step % self.target_update_interval == 0:
                 polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.tau)
                 # Copy running stats, see https://github.com/DLR-RM/stable-baselines3/issues/996
                 polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)
 
-        self._n_updates += gradient_steps
-
         self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
         self.logger.record("train/ent_coef", np.mean(ent_coefs))
         self.logger.record("train/actor_loss", np.mean(actor_losses))
