@@ -94,6 +94,7 @@ def __init__(
         replay_buffer_class: Optional[type[ReplayBuffer]] = None,
         replay_buffer_kwargs: Optional[dict[str, Any]] = None,
         optimize_memory_usage: bool = False,
+        policy_delay: int = 1,
         ent_coef: Union[str, float] = "auto",
         target_update_interval: int = 1,
         target_entropy: Union[str, float] = "auto",
@@ -145,6 +146,7 @@ def __init__(
         self.target_update_interval = target_update_interval
         self.ent_coef_optimizer: Optional[th.optim.Adam] = None
         self.top_quantiles_to_drop_per_net = top_quantiles_to_drop_per_net
+        self.policy_delay = policy_delay

         if _init_setup_model:
             self._setup_model()
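The two hunks above add the new policy_delay argument and store it on the algorithm. A minimal usage sketch, assuming the patch targets sb3_contrib's TQC; the environment id and policy_delay=2 below are illustrative values, not part of the patch:

from sb3_contrib import TQC

# policy_delay=2: the critic is updated on every gradient step, the actor and
# entropy coefficient only on every second step (policy_delay=1 keeps the old behaviour).
model = TQC("MlpPolicy", "Pendulum-v1", policy_delay=2, verbose=1)
model.learn(total_timesteps=5_000)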
@@ -190,7 +192,7 @@ def _create_aliases(self) -> None:
         self.critic = self.policy.critic
         self.critic_target = self.policy.critic_target

-    def train(self, gradient_steps: int, batch_size: int = 64) -> None:
+    def train(self, gradient_steps: int, batch_size: int = 64, train_freq: int = 1) -> None:
         # Switch to train mode (this affects batch norm / dropout)
         self.policy.set_training_mode(True)
         # Update optimizers learning rate
@@ -205,6 +207,8 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None:
         actor_losses, critic_losses = [], []

         for gradient_step in range(gradient_steps):
+            self._n_updates += 1
+            update_actor = self._n_updates % self.policy_delay == 0
             # Sample replay buffer
             replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)  # type: ignore[union-attr]

@@ -222,8 +226,9 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None:
                 # so we don't change it with other losses
                 # see https://github.com/rail-berkeley/softlearning/issues/60
                 ent_coef = th.exp(self.log_ent_coef.detach())
-                ent_coef_loss = -(self.log_ent_coef * (log_prob + self.target_entropy).detach()).mean()
-                ent_coef_losses.append(ent_coef_loss.item())
+                if update_actor:
+                    ent_coef_loss = -(self.log_ent_coef * (log_prob + self.target_entropy).detach()).mean()
+                    ent_coef_losses.append(ent_coef_loss.item())
             else:
                 ent_coef = self.ent_coef_tensor

@@ -265,24 +270,23 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None:
             critic_loss.backward()
             self.critic.optimizer.step()

-            # Compute actor loss
-            qf_pi = self.critic(replay_data.observations, actions_pi).mean(dim=2).mean(dim=1, keepdim=True)
-            actor_loss = (ent_coef * log_prob - qf_pi).mean()
-            actor_losses.append(actor_loss.item())
+            if update_actor:
+                qf_pi = self.critic(replay_data.observations, actions_pi).mean(dim=2).mean(dim=1, keepdim=True)
+                actor_loss = (ent_coef * log_prob - qf_pi).mean()
+                actor_losses.append(actor_loss.item())
+
+                # Optimize the actor
+                self.actor.optimizer.zero_grad()
+                actor_loss.backward()
+                self.actor.optimizer.step()

-            # Optimize the actor
-            self.actor.optimizer.zero_grad()
-            actor_loss.backward()
-            self.actor.optimizer.step()

             # Update target networks
             if gradient_step % self.target_update_interval == 0:
                 polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.tau)
                 # Copy running stats, see https://github.com/DLR-RM/stable-baselines3/issues/996
                 polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)

-        self._n_updates += gradient_steps
-
         self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
         self.logger.record("train/ent_coef", np.mean(ent_coefs))
         self.logger.record("train/actor_loss", np.mean(actor_losses))
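Taken together, the train() changes implement a TD3-style delayed policy update: self._n_updates is now incremented once per gradient step, the critic is optimized on every step, and the actor and entropy-coefficient losses are only computed and back-propagated when the update counter is a multiple of policy_delay. A standalone sketch of that schedule, with illustrative values only:

# Illustrate which gradient steps trigger an actor update for a given delay.
policy_delay = 2  # illustrative; the patch defaults to 1 (update every step)
n_updates = 0
for gradient_step in range(6):
    n_updates += 1
    update_actor = n_updates % policy_delay == 0
    print(f"step {n_updates}: critic updated, actor updated: {update_actor}")
# -> the actor (and entropy coefficient) are updated on steps 2, 4 and 6; the critic on every step.

With the default policy_delay=1 the modulo check is always true, so the previous update schedule is preserved.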