@@ -527,12 +527,9 @@ def _log_weight(
527
527
self .actor_network
528
528
) if self .functional else contextlib .nullcontext ():
529
529
dist = self .actor_network .get_dist (tensordict )
530
- if isinstance (dist , CompositeDistribution ):
531
- is_composite = True
532
- else :
533
- is_composite = False
534
530
535
- # current log_prob of actions
531
+ is_composite = isinstance (dist , CompositeDistribution )
532
+
536
533
if is_composite :
537
534
action = tensordict .select (
538
535
* (
@@ -562,25 +559,26 @@ def _log_weight(
562
559
log_prob = dist .log_prob (action )
563
560
if is_composite :
564
561
with set_composite_lp_aggregate (False ):
562
+ if log_prob .batch_size != adv_shape :
563
+ log_prob .batch_size = adv_shape
565
564
if not is_tensor_collection (prev_log_prob ):
566
- # this isn't great, in general multihead actions should have a composite log-prob too
565
+ # this isn't great: in general, multi-head actions should have a composite log-prob too
567
566
warnings .warn (
568
567
"You are using a composite distribution, yet your log-probability is a tensor. "
569
568
"Make sure you have called tensordict.nn.set_composite_lp_aggregate(False).set() at "
570
569
"the beginning of your script to get a proper composite log-prob." ,
571
570
category = UserWarning ,
572
571
)
573
- if log_prob .batch_size != adv_shape :
574
- log_prob .batch_size = adv_shape
575
- if (
576
- is_composite
577
- and not is_tensor_collection (prev_log_prob )
578
- and is_tensor_collection (log_prob )
579
- ):
580
- log_prob = _sum_td_features (log_prob )
581
- log_prob .view_as (prev_log_prob )
572
+
573
+ if is_tensor_collection (log_prob ):
574
+ log_prob = _sum_td_features (log_prob )
575
+ log_prob .view_as (prev_log_prob )
582
576
583
577
log_weight = (log_prob - prev_log_prob ).unsqueeze (- 1 )
578
+ if is_tensor_collection (log_weight ):
579
+ log_weight = _sum_td_features (log_weight )
580
+ log_weight = log_weight .view (adv_shape ).unsqueeze (- 1 )
581
+
584
582
kl_approx = (prev_log_prob - log_prob ).unsqueeze (- 1 )
585
583
if is_tensor_collection (kl_approx ):
586
584
kl_approx = _sum_td_features (kl_approx )
@@ -691,9 +689,6 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
691
689
log_weight , dist , kl_approx = self ._log_weight (
692
690
tensordict , adv_shape = advantage .shape [:- 1 ]
693
691
)
694
- if is_tensor_collection (log_weight ):
695
- log_weight = _sum_td_features (log_weight )
696
- log_weight = log_weight .view (advantage .shape )
697
692
neg_loss = log_weight .exp () * advantage
698
693
td_out = TensorDict ({"loss_objective" : - neg_loss })
699
694
td_out .set ("kl_approx" , kl_approx .detach ().mean ()) # for logging
@@ -987,8 +982,6 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
987
982
# to different, unrelated trajectories, which is not standard. Still, it can give an idea of the weights'
988
983
# dispersion.
989
984
lw = log_weight .squeeze ()
990
- if not isinstance (lw , torch .Tensor ):
991
- lw = _sum_td_features (lw )
992
985
ess = (2 * lw .logsumexp (0 ) - (2 * lw ).logsumexp (0 )).exp ()
993
986
batch = log_weight .shape [0 ]
994
987
@@ -1000,8 +993,6 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
1000
993
gain2 = ratio * advantage
1001
994
1002
995
gain = torch .stack ([gain1 , gain2 ], - 1 ).min (dim = - 1 ).values
1003
- if is_tensor_collection (gain ):
1004
- gain = _sum_td_features (gain )
1005
996
td_out = TensorDict ({"loss_objective" : - gain })
1006
997
td_out .set ("clip_fraction" , clip_fraction )
1007
998
td_out .set ("kl_approx" , kl_approx .detach ().mean ()) # for logging
@@ -1291,8 +1282,6 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
1291
1282
tensordict_copy , adv_shape = advantage .shape [:- 1 ]
1292
1283
)
1293
1284
neg_loss = log_weight .exp () * advantage
1294
- if is_tensor_collection (neg_loss ):
1295
- neg_loss = _sum_td_features (neg_loss )
1296
1285
1297
1286
with self .actor_network_params .to_module (
1298
1287
self .actor_network
0 commit comments