Skip to content

Commit

Permalink
[Minor] Cleanup
Browse files: browse the repository at this point in the history
  • Loading branch information
vmoens committed Feb 28, 2024
1 parent 7adccce commit d56730a
Show file tree
Hide file tree
Showing 5 changed files with 5 additions and 22 deletions.
2 changes: 1 addition & 1 deletion sota-check/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ export MUJOCO_GL=egl

conda create -n rl-sota-bench python=3.10 -y
conda install anaconda::libglu -y
pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu121
pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121
pip3 install "gymnasium[accept-rom-license,atari,mujoco]" vmas tqdm wandb pygame moviepy imageio submitit hydra-core transformers

cd /path/to/tensordict
Expand Down
3 changes: 2 additions & 1 deletion test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -2549,7 +2549,8 @@ def _step(self, tensordict):
"reward": action.sum().unsqueeze(0),
**self.full_done_spec.zero(),
"observation": obs,
}
},
batch_size=[],
)

torch.manual_seed(0)
Expand Down
17 changes: 0 additions & 17 deletions torchrl/envs/batched_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,9 +433,6 @@ def _check_for_empty_spec(specs: CompositeSpec):
def map_device(key, value, device_map=device_map):
return value.to(device_map[key])

# self._env_tensordict.named_apply(
# map_device, nested_keys=True, filter_empty=True
# )
self._env_tensordict.named_apply(
map_device,
nested_keys=True,
Expand Down Expand Up @@ -809,11 +806,6 @@ def select_and_clone(name, tensor):
if name in selected_output_keys:
return tensor.clone()

# out = self.shared_tensordict_parent.named_apply(
# select_and_clone,
# nested_keys=True,
# filter_empty=True,
# )
out = self.shared_tensordict_parent.named_apply(
select_and_clone,
nested_keys=True,
Expand Down Expand Up @@ -1208,14 +1200,12 @@ def step_and_maybe_reset(
if x.device != device
else x.clone(),
device=device,
# filter_empty=True,
)
tensordict_ = tensordict_._fast_apply(
lambda x: x.to(device, non_blocking=self.non_blocking)
if x.device != device
else x.clone(),
device=device,
# filter_empty=True,
)
else:
next_td = next_td.clone().clear_device_()
Expand Down Expand Up @@ -1271,7 +1261,6 @@ def select_and_clone(name, tensor):
out = next_td.named_apply(
select_and_clone,
nested_keys=True,
# filter_empty=True,
)
if out.device != device:
if device is None:
Expand Down Expand Up @@ -1357,7 +1346,6 @@ def select_and_clone(name, tensor):
out = self.shared_tensordict_parent.named_apply(
select_and_clone,
nested_keys=True,
# filter_empty=True,
)
del out["next"]

Expand Down Expand Up @@ -1495,7 +1483,6 @@ def _run_worker_pipe_shared_mem(
def look_for_cuda(tensor, has_cuda=has_cuda):
has_cuda[0] = has_cuda[0] or tensor.is_cuda

# shared_tensordict.apply(look_for_cuda, filter_empty=True)
shared_tensordict.apply(look_for_cuda)
has_cuda = has_cuda[0]
else:
Expand Down Expand Up @@ -1685,9 +1672,5 @@ def look_for_cuda(tensor, has_cuda=has_cuda):
child_pipe.send(("_".join([cmd, "done"]), None))


def _filter_empty(tensordict):
return tensordict.select(*tensordict.keys(True, True))


# Create an alias for possible imports
_BatchedEnv = BatchedEnvBase
1 change: 0 additions & 1 deletion torchrl/objectives/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,6 @@ def _compare_and_expand(param):
return param._apply_nest(
_compare_and_expand,
batch_size=[expand_dim, *param.shape],
filter_empty=False,
call_on_nested=True,
)
if not isinstance(param, nn.Parameter):
Expand Down
4 changes: 2 additions & 2 deletions torchrl/objectives/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor:
entropy = dist.entropy()
except NotImplementedError:
x = dist.rsample((self.samples_mc_entropy,))
entropy = -dist.log_prob(x)
entropy = -dist.log_prob(x).mean(0)
return entropy.unsqueeze(-1)

def _log_weight(
Expand Down Expand Up @@ -1036,7 +1036,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
td_out.set("loss_entropy", -self.entropy_coef * entropy.mean())

if self.critic_coef:
loss_critic = self.loss_critic(tensordict)
loss_critic = self.loss_critic(tensordict_copy)
td_out.set("loss_critic", loss_critic.mean())

return td_out
Expand Down

0 comments on commit d56730a

Please sign in to comment.