Skip to content

Commit 882dc79

Browse files
louisfaury (Louis Faury)
authored and committed
[BugFix] PPOs with composite distribution (#2791)
Co-authored-by: Louis Faury <[email protected]> (cherry picked from commit edfa25d)
1 parent 2ebcb2e commit 882dc79

File tree

1 file changed

+13
-24
lines changed

1 file changed

+13
-24
lines changed

torchrl/objectives/ppo.py

Lines changed: 13 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -527,12 +527,9 @@ def _log_weight(
527527
self.actor_network
528528
) if self.functional else contextlib.nullcontext():
529529
dist = self.actor_network.get_dist(tensordict)
530-
if isinstance(dist, CompositeDistribution):
531-
is_composite = True
532-
else:
533-
is_composite = False
534530

535-
# current log_prob of actions
531+
is_composite = isinstance(dist, CompositeDistribution)
532+
536533
if is_composite:
537534
action = tensordict.select(
538535
*(
@@ -562,25 +559,26 @@ def _log_weight(
562559
log_prob = dist.log_prob(action)
563560
if is_composite:
564561
with set_composite_lp_aggregate(False):
562+
if log_prob.batch_size != adv_shape:
563+
log_prob.batch_size = adv_shape
565564
if not is_tensor_collection(prev_log_prob):
566-
# this isn't great, in general multihead actions should have a composite log-prob too
565+
# this isn't great: in general, multi-head actions should have a composite log-prob too
567566
warnings.warn(
568567
"You are using a composite distribution, yet your log-probability is a tensor. "
569568
"Make sure you have called tensordict.nn.set_composite_lp_aggregate(False).set() at "
570569
"the beginning of your script to get a proper composite log-prob.",
571570
category=UserWarning,
572571
)
573-
if log_prob.batch_size != adv_shape:
574-
log_prob.batch_size = adv_shape
575-
if (
576-
is_composite
577-
and not is_tensor_collection(prev_log_prob)
578-
and is_tensor_collection(log_prob)
579-
):
580-
log_prob = _sum_td_features(log_prob)
581-
log_prob.view_as(prev_log_prob)
572+
573+
if is_tensor_collection(log_prob):
574+
log_prob = _sum_td_features(log_prob)
575+
log_prob.view_as(prev_log_prob)
582576

583577
log_weight = (log_prob - prev_log_prob).unsqueeze(-1)
578+
if is_tensor_collection(log_weight):
579+
log_weight = _sum_td_features(log_weight)
580+
log_weight = log_weight.view(adv_shape).unsqueeze(-1)
581+
584582
kl_approx = (prev_log_prob - log_prob).unsqueeze(-1)
585583
if is_tensor_collection(kl_approx):
586584
kl_approx = _sum_td_features(kl_approx)
@@ -691,9 +689,6 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
691689
log_weight, dist, kl_approx = self._log_weight(
692690
tensordict, adv_shape=advantage.shape[:-1]
693691
)
694-
if is_tensor_collection(log_weight):
695-
log_weight = _sum_td_features(log_weight)
696-
log_weight = log_weight.view(advantage.shape)
697692
neg_loss = log_weight.exp() * advantage
698693
td_out = TensorDict({"loss_objective": -neg_loss})
699694
td_out.set("kl_approx", kl_approx.detach().mean()) # for logging
@@ -987,8 +982,6 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
987982
# to different, unrelated trajectories, which is not standard. Still, it can give an idea of the weights'
988983
# dispersion.
989984
lw = log_weight.squeeze()
990-
if not isinstance(lw, torch.Tensor):
991-
lw = _sum_td_features(lw)
992985
ess = (2 * lw.logsumexp(0) - (2 * lw).logsumexp(0)).exp()
993986
batch = log_weight.shape[0]
994987

@@ -1000,8 +993,6 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
1000993
gain2 = ratio * advantage
1001994

1002995
gain = torch.stack([gain1, gain2], -1).min(dim=-1).values
1003-
if is_tensor_collection(gain):
1004-
gain = _sum_td_features(gain)
1005996
td_out = TensorDict({"loss_objective": -gain})
1006997
td_out.set("clip_fraction", clip_fraction)
1007998
td_out.set("kl_approx", kl_approx.detach().mean()) # for logging
@@ -1291,8 +1282,6 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
12911282
tensordict_copy, adv_shape=advantage.shape[:-1]
12921283
)
12931284
neg_loss = log_weight.exp() * advantage
1294-
if is_tensor_collection(neg_loss):
1295-
neg_loss = _sum_td_features(neg_loss)
12961285

12971286
with self.actor_network_params.to_module(
12981287
self.actor_network

0 commit comments

Comments (0)