diff --git a/test/test_tensordictmodules.py b/test/test_tensordictmodules.py index c2df40be012..7e0fef99786 100644 --- a/test/test_tensordictmodules.py +++ b/test/test_tensordictmodules.py @@ -9,12 +9,7 @@ import torch from mocking_classes import DiscreteActionVecMockEnv from tensordict import pad, TensorDict, unravel_key_list -from tensordict.nn import ( - InteractionType, - make_functional, - TensorDictModule, - TensorDictSequential, -) +from tensordict.nn import InteractionType, TensorDictModule, TensorDictSequential from torch import nn from torchrl.data.tensor_specs import ( BoundedTensorSpec, @@ -255,15 +250,33 @@ def test_stateful_probabilistic(self, safe, spec_type, lazy, exp_mode, out_keys) elif safe and spec_type == "bounded": assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() + +class TestTDSequence: + # Temporarily disabling this test until 473 is merged in tensordict + # def test_in_key_warning(self): + # with pytest.warns(UserWarning, match='key "_" is for ignoring output'): + # tensordict_module = SafeModule( + # nn.Linear(3, 4), in_keys=["_"], out_keys=["out1"] + # ) + # with pytest.warns(UserWarning, match='key "_" is for ignoring output'): + # tensordict_module = SafeModule( + # nn.Linear(3, 4), in_keys=["_", "key2"], out_keys=["out1"] + # ) + @pytest.mark.parametrize("safe", [True, False]) @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - def test_functional(self, safe, spec_type): + @pytest.mark.parametrize("lazy", [True, False]) + def test_stateful(self, safe, spec_type, lazy): torch.manual_seed(0) param_multiplier = 1 - - net = nn.Linear(3, 4 * param_multiplier) - - params = make_functional(net) + if lazy: + net1 = nn.LazyLinear(4) + dummy_net = nn.LazyLinear(4) + net2 = nn.LazyLinear(4 * param_multiplier) + else: + net1 = nn.Linear(3, 4) + dummy_net = nn.Linear(4, 4) + net2 = nn.Linear(4, 4 * param_multiplier) if spec_type is None: spec = None @@ -272,31 +285,51 @@ def test_functional(self, safe, spec_type): elif spec_type == "unbounded": spec = UnboundedContinuousTensorSpec(4) + kwargs = {} + if safe and spec is None: - with pytest.raises( - RuntimeError, - match="is not a valid configuration as the tensor specs are not " - "specified", - ): - tensordict_module = SafeModule( - spec=spec, - module=net, - in_keys=["in"], - out_keys=["out"], - safe=safe, - ) - return + pytest.skip("safe and spec is None is checked elsewhere") else: - tensordict_module = SafeModule( - spec=spec, - module=net, + tdmodule1 = SafeModule( + net1, + spec=None, in_keys=["in"], + out_keys=["hidden"], + safe=False, + ) + dummy_tdmodule = SafeModule( + dummy_net, + spec=None, + in_keys=["hidden"], + out_keys=["hidden"], + safe=False, + ) + tdmodule2 = SafeModule( + spec=spec, + module=net2, + in_keys=["hidden"], out_keys=["out"], - safe=safe, + safe=False, + **kwargs, ) + tdmodule = SafeSequential(tdmodule1, dummy_tdmodule, tdmodule2) + + assert hasattr(tdmodule, "__setitem__") + assert len(tdmodule) == 3 + tdmodule[1] = tdmodule2 + assert len(tdmodule) == 3 + + assert hasattr(tdmodule, "__delitem__") + assert len(tdmodule) == 3 + del tdmodule[2] + assert len(tdmodule) == 2 + + assert hasattr(tdmodule, "__getitem__") + assert tdmodule[0] is tdmodule1 + assert tdmodule[1] is tdmodule2 td = TensorDict({"in": torch.randn(3, 3)}, [3]) - tensordict_module(td, params=TensorDict({"module": params}, [])) + tdmodule(td) assert td.shape == torch.Size([3]) assert td.get("out").shape == torch.Size([3, 4]) @@ -308,16 +341,19 @@ def test_functional(self, safe, spec_type): 
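For reference, the stateful pattern that the rewritten tests above exercise reduces to the following minimal sketch (SafeModule and SafeSequential are the torchrl.modules classes used throughout this file; no params argument is passed at call time):

import torch
from tensordict import TensorDict
from torch import nn
from torchrl.modules import SafeModule, SafeSequential

# two stateful modules chained through TensorDict keys: "in" -> "hidden" -> "out"
module1 = SafeModule(nn.Linear(3, 4), in_keys=["in"], out_keys=["hidden"])
module2 = SafeModule(nn.Linear(4, 4), in_keys=["hidden"], out_keys=["out"])
seq = SafeSequential(module1, module2)

td = TensorDict({"in": torch.randn(3, 3)}, batch_size=[3])
seq(td)  # updates td in place; no functional params required
assert td["out"].shape == torch.Size([3, 4])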
@pytest.mark.parametrize("safe", [True, False]) @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - def test_functional_probabilistic(self, safe, spec_type): + @pytest.mark.parametrize("lazy", [True, False]) + def test_stateful_probabilistic(self, safe, spec_type, lazy): torch.manual_seed(0) param_multiplier = 2 - - tdnet = SafeModule( - module=NormalParamWrapper(nn.Linear(3, 4 * param_multiplier)), - spec=None, - in_keys=["in"], - out_keys=["loc", "scale"], - ) + if lazy: + net1 = nn.LazyLinear(4) + dummy_net = nn.LazyLinear(4) + net2 = nn.LazyLinear(4 * param_multiplier) + else: + net1 = nn.Linear(3, 4) + dummy_net = nn.Linear(4, 4) + net2 = nn.Linear(4, 4 * param_multiplier) + net2 = NormalParamWrapper(net2) if spec_type is None: spec = None @@ -331,1075 +367,128 @@ def test_functional_probabilistic(self, safe, spec_type): kwargs = {"distribution_class": TanhNormal} if safe and spec is None: - with pytest.raises( - RuntimeError, - match="is not a valid configuration as the tensor specs are not " - "specified", - ): - prob_module = SafeProbabilisticModule( - in_keys=["loc", "scale"], - out_keys=["out"], - spec=spec, - safe=safe, - **kwargs, - ) - return + pytest.skip("safe and spec is None is checked elsewhere") else: + tdmodule1 = SafeModule( + net1, + in_keys=["in"], + out_keys=["hidden"], + spec=None, + safe=False, + ) + dummy_tdmodule = SafeModule( + dummy_net, + in_keys=["hidden"], + out_keys=["hidden"], + spec=None, + safe=False, + ) + tdmodule2 = SafeModule( + module=net2, + in_keys=["hidden"], + out_keys=["loc", "scale"], + spec=None, + safe=False, + ) + prob_module = SafeProbabilisticModule( + spec=spec, in_keys=["loc", "scale"], out_keys=["out"], - spec=spec, - safe=safe, + safe=False, **kwargs, ) + tdmodule = SafeProbabilisticTensorDictSequential( + tdmodule1, dummy_tdmodule, tdmodule2, prob_module + ) + + assert hasattr(tdmodule, "__setitem__") + assert len(tdmodule) == 4 + tdmodule[1] = tdmodule2 + tdmodule[2] = prob_module + assert len(tdmodule) == 4 - tensordict_module = SafeProbabilisticTensorDictSequential(tdnet, prob_module) - params = make_functional(tensordict_module) + assert hasattr(tdmodule, "__delitem__") + assert len(tdmodule) == 4 + del tdmodule[3] + assert len(tdmodule) == 3 + + assert hasattr(tdmodule, "__getitem__") + assert tdmodule[0] is tdmodule1 + assert tdmodule[1] is tdmodule2 + assert tdmodule[2] is prob_module td = TensorDict({"in": torch.randn(3, 3)}, [3]) - tensordict_module(td, params=params) + tdmodule(td) assert td.shape == torch.Size([3]) assert td.get("out").shape == torch.Size([3, 4]) + dist = tdmodule.get_dist(td) + assert dist.rsample().shape[: td.ndimension()] == td.shape + # test bounds if not safe and spec_type == "bounded": assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() elif safe and spec_type == "bounded": assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - def test_functional_with_buffer(self, safe, spec_type): - torch.manual_seed(0) - param_multiplier = 1 - - net = nn.BatchNorm1d(32 * param_multiplier) - params = make_functional(net) - - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = BoundedTensorSpec(-0.1, 0.1, 32) - elif spec_type == "unbounded": - spec = UnboundedContinuousTensorSpec(32) - - if safe and spec is None: - with pytest.raises( - RuntimeError, - match="is not a valid configuration as the tensor specs are not " - "specified", - 
): - tdmodule = SafeModule( - spec=spec, - module=net, - in_keys=["in"], - out_keys=["out"], - safe=safe, - ) - return - else: - tdmodule = SafeModule( - spec=spec, - module=net, - in_keys=["in"], - out_keys=["out"], - safe=safe, - ) - - td = TensorDict({"in": torch.randn(3, 32 * param_multiplier)}, [3]) - tdmodule(td, params=TensorDict({"module": params}, [])) - assert td.shape == torch.Size([3]) - assert td.get("out").shape == torch.Size([3, 32]) + def test_submodule_sequence(self): + td_module_1 = SafeModule( + nn.Linear(3, 2), + in_keys=["in"], + out_keys=["hidden"], + ) + td_module_2 = SafeModule( + nn.Linear(2, 4), + in_keys=["hidden"], + out_keys=["out"], + ) + td_module = SafeSequential(td_module_1, td_module_2) - # test bounds - if not safe and spec_type == "bounded": - assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() + td_1 = TensorDict({"in": torch.randn(5, 3)}, [5]) + sub_seq_1 = td_module.select_subsequence(out_keys=["hidden"]) + sub_seq_1(td_1) + assert "hidden" in td_1.keys() + assert "out" not in td_1.keys() + td_2 = TensorDict({"hidden": torch.randn(5, 2)}, [5]) + sub_seq_2 = td_module.select_subsequence(in_keys=["hidden"]) + sub_seq_2(td_2) + assert "out" in td_2.keys() + assert td_2.get("out").shape == torch.Size([5, 4]) - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - def test_functional_with_buffer_probabilistic(self, safe, spec_type): + @pytest.mark.parametrize("stack", [True, False]) + def test_sequential_partial(self, stack): torch.manual_seed(0) param_multiplier = 2 - tdnet = SafeModule( - module=NormalParamWrapper(nn.BatchNorm1d(32 * param_multiplier)), - spec=None, - in_keys=["in"], - out_keys=["loc", "scale"], - ) + net1 = nn.Linear(3, 4) - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = BoundedTensorSpec(-0.1, 0.1, 32) - elif spec_type == "unbounded": - spec = UnboundedContinuousTensorSpec(32) - else: - raise NotImplementedError + net2 = nn.Linear(4, 4 * param_multiplier) + net2 = NormalParamWrapper(net2) + net2 = SafeModule(net2, in_keys=["b"], out_keys=["loc", "scale"]) - kwargs = {"distribution_class": TanhNormal} + net3 = nn.Linear(4, 4 * param_multiplier) + net3 = NormalParamWrapper(net3) + net3 = SafeModule(net3, in_keys=["c"], out_keys=["loc", "scale"]) - if safe and spec is None: - with pytest.raises( - RuntimeError, - match="is not a valid configuration as the tensor specs are not " - "specified", - ): - prob_module = SafeProbabilisticModule( - in_keys=["loc", "scale"], - out_keys=["out"], - spec=spec, - safe=safe, - **kwargs, - ) + spec = BoundedTensorSpec(-0.1, 0.1, 4) - return - else: - prob_module = SafeProbabilisticModule( + kwargs = {"distribution_class": TanhNormal} + + tdmodule1 = SafeModule( + net1, + in_keys=["a"], + out_keys=["hidden"], + spec=None, + safe=False, + ) + tdmodule2 = SafeProbabilisticTensorDictSequential( + net2, + SafeProbabilisticModule( in_keys=["loc", "scale"], out_keys=["out"], spec=spec, - safe=safe, - **kwargs, - ) - - tdmodule = SafeProbabilisticTensorDictSequential(tdnet, prob_module) - params = make_functional(tdmodule) - - td = TensorDict({"in": torch.randn(3, 32 * param_multiplier)}, [3]) - tdmodule(td, params=params) - assert td.shape == torch.Size([3]) - assert td.get("out").shape == torch.Size([3, 32]) - - # test bounds - if not safe and spec_type == "bounded": - assert ((td.get("out") > 0.1) | 
(td.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() - - @pytest.mark.skipif( - not _has_functorch, reason="vmap can only be used with functorch" - ) - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - def test_vmap(self, safe, spec_type): - torch.manual_seed(0) - param_multiplier = 1 - - net = nn.Linear(3, 4 * param_multiplier) - - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = BoundedTensorSpec(-0.1, 0.1, 4) - elif spec_type == "unbounded": - spec = UnboundedContinuousTensorSpec(4) - - if safe and spec is None: - with pytest.raises( - RuntimeError, - match="is not a valid configuration as the tensor specs are not " - "specified", - ): - tdmodule = SafeModule( - spec=spec, - module=net, - in_keys=["in"], - out_keys=["out"], - safe=safe, - ) - return - else: - tdmodule = SafeModule( - spec=spec, - module=net, - in_keys=["in"], - out_keys=["out"], - safe=safe, - ) - - params = make_functional(tdmodule) - - # vmap = True - params = params.expand(10) - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - if safe and spec_type == "bounded": - with pytest.raises( - RuntimeError, match="vmap cannot be used with safe=True" - ): - td_out = vmap(tdmodule, (None, 0))(td, params) - return - else: - td_out = vmap(tdmodule, (None, 0))(td, params) - assert td_out is not td - assert td_out.shape == torch.Size([10, 3]) - assert td_out.get("out").shape == torch.Size([10, 3, 4]) - # test bounds - if not safe and spec_type == "bounded": - assert ((td_out.get("out") > 0.1) | (td_out.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() - - # vmap = (0, 0) - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - td_repeat = td.expand(10, *td.batch_size) - td_out = vmap(tdmodule, (0, 0))(td_repeat, params) - assert td_out is not td - assert td_out.shape == torch.Size([10, 3]) - assert td_out.get("out").shape == torch.Size([10, 3, 4]) - # test bounds - if not safe and spec_type == "bounded": - assert ((td_out.get("out") > 0.1) | (td_out.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() - - @pytest.mark.skipif( - not _has_functorch, reason="vmap can only be used with functorch" - ) - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - def test_vmap_probabilistic(self, safe, spec_type): - torch.manual_seed(0) - param_multiplier = 2 - - net = NormalParamWrapper(nn.Linear(3, 4 * param_multiplier)) - tdnet = SafeModule( - module=net, in_keys=["in"], out_keys=["loc", "scale"], spec=None - ) - - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = BoundedTensorSpec(-0.1, 0.1, 4) - elif spec_type == "unbounded": - spec = UnboundedContinuousTensorSpec(4) - else: - raise NotImplementedError - - kwargs = {"distribution_class": TanhNormal} - - if safe and spec is None: - with pytest.raises( - RuntimeError, - match="is not a valid configuration as the tensor specs are not " - "specified", - ): - prob_module = SafeProbabilisticModule( - in_keys=["loc", "scale"], - out_keys=["out"], - spec=spec, - safe=safe, - **kwargs, - ) - return - else: - prob_module = SafeProbabilisticModule( - in_keys=["loc", "scale"], - out_keys=["out"], - spec=spec, - safe=safe, - **kwargs, - ) - - tdmodule = 
SafeProbabilisticTensorDictSequential(tdnet, prob_module) - params = make_functional(tdmodule) - - # vmap = True - params = params.expand(10) - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - if safe and spec_type == "bounded": - with pytest.raises( - RuntimeError, match="vmap cannot be used with safe=True" - ): - td_out = vmap(tdmodule, (None, 0))(td, params) - return - else: - td_out = vmap(tdmodule, (None, 0))(td, params) - assert td_out is not td - assert td_out.shape == torch.Size([10, 3]) - assert td_out.get("out").shape == torch.Size([10, 3, 4]) - # test bounds - if not safe and spec_type == "bounded": - assert ((td_out.get("out") > 0.1) | (td_out.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() - - # vmap = (0, 0) - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - td_repeat = td.expand(10, *td.batch_size) - td_out = vmap(tdmodule, (0, 0))(td_repeat, params) - assert td_out is not td - assert td_out.shape == torch.Size([10, 3]) - assert td_out.get("out").shape == torch.Size([10, 3, 4]) - # test bounds - if not safe and spec_type == "bounded": - assert ((td_out.get("out") > 0.1) | (td_out.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() - - -class TestTDSequence: - # Temporarily disabling this test until 473 is merged in tensordict - # def test_in_key_warning(self): - # with pytest.warns(UserWarning, match='key "_" is for ignoring output'): - # tensordict_module = SafeModule( - # nn.Linear(3, 4), in_keys=["_"], out_keys=["out1"] - # ) - # with pytest.warns(UserWarning, match='key "_" is for ignoring output'): - # tensordict_module = SafeModule( - # nn.Linear(3, 4), in_keys=["_", "key2"], out_keys=["out1"] - # ) - - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - @pytest.mark.parametrize("lazy", [True, False]) - def test_stateful(self, safe, spec_type, lazy): - torch.manual_seed(0) - param_multiplier = 1 - if lazy: - net1 = nn.LazyLinear(4) - dummy_net = nn.LazyLinear(4) - net2 = nn.LazyLinear(4 * param_multiplier) - else: - net1 = nn.Linear(3, 4) - dummy_net = nn.Linear(4, 4) - net2 = nn.Linear(4, 4 * param_multiplier) - - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = BoundedTensorSpec(-0.1, 0.1, 4) - elif spec_type == "unbounded": - spec = UnboundedContinuousTensorSpec(4) - - kwargs = {} - - if safe and spec is None: - pytest.skip("safe and spec is None is checked elsewhere") - else: - tdmodule1 = SafeModule( - net1, - spec=None, - in_keys=["in"], - out_keys=["hidden"], - safe=False, - ) - dummy_tdmodule = SafeModule( - dummy_net, - spec=None, - in_keys=["hidden"], - out_keys=["hidden"], - safe=False, - ) - tdmodule2 = SafeModule( - spec=spec, - module=net2, - in_keys=["hidden"], - out_keys=["out"], - safe=False, - **kwargs, - ) - tdmodule = SafeSequential(tdmodule1, dummy_tdmodule, tdmodule2) - - assert hasattr(tdmodule, "__setitem__") - assert len(tdmodule) == 3 - tdmodule[1] = tdmodule2 - assert len(tdmodule) == 3 - - assert hasattr(tdmodule, "__delitem__") - assert len(tdmodule) == 3 - del tdmodule[2] - assert len(tdmodule) == 2 - - assert hasattr(tdmodule, "__getitem__") - assert tdmodule[0] is tdmodule1 - assert tdmodule[1] is tdmodule2 - - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - tdmodule(td) - assert td.shape == torch.Size([3]) - assert td.get("out").shape == torch.Size([3, 4]) - - # 
test bounds - if not safe and spec_type == "bounded": - assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() - - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - @pytest.mark.parametrize("lazy", [True, False]) - def test_stateful_probabilistic(self, safe, spec_type, lazy): - torch.manual_seed(0) - param_multiplier = 2 - if lazy: - net1 = nn.LazyLinear(4) - dummy_net = nn.LazyLinear(4) - net2 = nn.LazyLinear(4 * param_multiplier) - else: - net1 = nn.Linear(3, 4) - dummy_net = nn.Linear(4, 4) - net2 = nn.Linear(4, 4 * param_multiplier) - net2 = NormalParamWrapper(net2) - - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = BoundedTensorSpec(-0.1, 0.1, 4) - elif spec_type == "unbounded": - spec = UnboundedContinuousTensorSpec(4) - else: - raise NotImplementedError - - kwargs = {"distribution_class": TanhNormal} - - if safe and spec is None: - pytest.skip("safe and spec is None is checked elsewhere") - else: - tdmodule1 = SafeModule( - net1, - in_keys=["in"], - out_keys=["hidden"], - spec=None, - safe=False, - ) - dummy_tdmodule = SafeModule( - dummy_net, - in_keys=["hidden"], - out_keys=["hidden"], - spec=None, - safe=False, - ) - tdmodule2 = SafeModule( - module=net2, - in_keys=["hidden"], - out_keys=["loc", "scale"], - spec=None, - safe=False, - ) - - prob_module = SafeProbabilisticModule( - spec=spec, - in_keys=["loc", "scale"], - out_keys=["out"], - safe=False, - **kwargs, - ) - tdmodule = SafeProbabilisticTensorDictSequential( - tdmodule1, dummy_tdmodule, tdmodule2, prob_module - ) - - assert hasattr(tdmodule, "__setitem__") - assert len(tdmodule) == 4 - tdmodule[1] = tdmodule2 - tdmodule[2] = prob_module - assert len(tdmodule) == 4 - - assert hasattr(tdmodule, "__delitem__") - assert len(tdmodule) == 4 - del tdmodule[3] - assert len(tdmodule) == 3 - - assert hasattr(tdmodule, "__getitem__") - assert tdmodule[0] is tdmodule1 - assert tdmodule[1] is tdmodule2 - assert tdmodule[2] is prob_module - - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - tdmodule(td) - assert td.shape == torch.Size([3]) - assert td.get("out").shape == torch.Size([3, 4]) - - dist = tdmodule.get_dist(td) - assert dist.rsample().shape[: td.ndimension()] == td.shape - - # test bounds - if not safe and spec_type == "bounded": - assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() - - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - def test_functional(self, safe, spec_type): - torch.manual_seed(0) - param_multiplier = 1 - - net1 = nn.Linear(3, 4) - dummy_net = nn.Linear(4, 4) - net2 = nn.Linear(4, 4 * param_multiplier) - - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = BoundedTensorSpec(-0.1, 0.1, 4) - elif spec_type == "unbounded": - spec = UnboundedContinuousTensorSpec(4) - - if safe and spec is None: - pytest.skip("safe and spec is None is checked elsewhere") - else: - tdmodule1 = SafeModule( - net1, spec=None, in_keys=["in"], out_keys=["hidden"], safe=False - ) - dummy_tdmodule = SafeModule( - dummy_net, - spec=None, - in_keys=["hidden"], - out_keys=["hidden"], - safe=False, - ) - tdmodule2 = SafeModule( - net2, - spec=spec, - in_keys=["hidden"], - out_keys=["out"], - safe=safe, - ) - tdmodule = 
SafeSequential(tdmodule1, dummy_tdmodule, tdmodule2) - - params = make_functional(tdmodule) - - assert hasattr(tdmodule, "__setitem__") - assert len(tdmodule) == 3 - tdmodule[1] = tdmodule2 - with params.unlock_(): - params["module", "1"] = params["module", "2"] - assert len(tdmodule) == 3 - - assert hasattr(tdmodule, "__delitem__") - assert len(tdmodule) == 3 - del tdmodule[2] - with params.unlock_(): - del params["module", "2"] - assert len(tdmodule) == 2 - - assert hasattr(tdmodule, "__getitem__") - assert tdmodule[0] is tdmodule1 - assert tdmodule[1] is tdmodule2 - - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - tdmodule(td, params) - assert td.shape == torch.Size([3]) - assert td.get("out").shape == torch.Size([3, 4]) - - # test bounds - if not safe and spec_type == "bounded": - assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() - - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - def test_functional_probabilistic(self, safe, spec_type): - torch.manual_seed(0) - param_multiplier = 2 - - net1 = nn.Linear(3, 4) - dummy_net = nn.Linear(4, 4) - net2 = nn.Linear(4, 4 * param_multiplier) - net2 = NormalParamWrapper(net2) - - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = BoundedTensorSpec(-0.1, 0.1, 4) - elif spec_type == "unbounded": - spec = UnboundedContinuousTensorSpec(4) - else: - raise NotImplementedError - - kwargs = {"distribution_class": TanhNormal} - - if safe and spec is None: - pytest.skip("safe and spec is None is checked elsewhere") - else: - tdmodule1 = SafeModule( - net1, spec=None, in_keys=["in"], out_keys=["hidden"], safe=False - ) - dummy_tdmodule = SafeModule( - dummy_net, - spec=None, - in_keys=["hidden"], - out_keys=["hidden"], - safe=False, - ) - tdmodule2 = SafeModule( - module=net2, in_keys=["hidden"], out_keys=["loc", "scale"] - ) - - prob_module = SafeProbabilisticModule( - in_keys=["loc", "scale"], - out_keys=["out"], - spec=spec, - safe=safe, - **kwargs, - ) - tdmodule = SafeProbabilisticTensorDictSequential( - tdmodule1, dummy_tdmodule, tdmodule2, prob_module - ) - - params = make_functional(tdmodule, funs_to_decorate=["forward", "get_dist"]) - - assert hasattr(tdmodule, "__setitem__") - assert len(tdmodule) == 4 - tdmodule[1] = tdmodule2 - tdmodule[2] = prob_module - with params.unlock_(): - params["module", "1"] = params["module", "2"] - params["module", "2"] = params["module", "3"] - assert len(tdmodule) == 4 - - assert hasattr(tdmodule, "__delitem__") - assert len(tdmodule) == 4 - del tdmodule[3] - with params.unlock_(): - del params["module", "3"] - assert len(tdmodule) == 3 - - assert hasattr(tdmodule, "__getitem__") - assert tdmodule[0] is tdmodule1 - assert tdmodule[1] is tdmodule2 - assert tdmodule[2] is prob_module - - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - tdmodule(td, params=params) - assert td.shape == torch.Size([3]) - assert td.get("out").shape == torch.Size([3, 4]) - - dist = tdmodule.get_dist(td, params=params) - assert dist.rsample().shape[: td.ndimension()] == td.shape - - # test bounds - if not safe and spec_type == "bounded": - assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() - - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) 
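All of the deleted test_functional* variants relied on make_functional(module) plus a params argument at call time. The replacement idiom in tensordict, which the updated tutorials at the end of this diff adopt, extracts parameters with TensorDict.from_module and re-binds them with to_module; a minimal sketch:

import torch
from tensordict import TensorDict
from torch import nn

net = nn.Linear(3, 4)
params = TensorDict.from_module(net)  # parameters and buffers as a TensorDict
with params.to_module(net):  # temporarily binds params to net, restored on exit
    out = net(torch.randn(5, 3))
assert out.shape == torch.Size([5, 4])

Binding a detached copy instead (params.detach().to_module(net)) runs the call without tracking gradients through the parameters, which is the pattern the DDPG tutorial below uses for the value network.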
- def test_functional_with_buffer(self, safe, spec_type): - torch.manual_seed(0) - param_multiplier = 1 - - net1 = nn.Sequential(nn.Linear(7, 7), nn.BatchNorm1d(7)) - dummy_net = nn.Sequential(nn.Linear(7, 7), nn.BatchNorm1d(7)) - net2 = nn.Sequential( - nn.Linear(7, 7 * param_multiplier), nn.BatchNorm1d(7 * param_multiplier) - ) - - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = BoundedTensorSpec(-0.1, 0.1, 7) - elif spec_type == "unbounded": - spec = UnboundedContinuousTensorSpec(7) - - if safe and spec is None: - pytest.skip("safe and spec is None is checked elsewhere") - else: - tdmodule1 = SafeModule( - net1, spec=None, in_keys=["in"], out_keys=["hidden"], safe=False - ) - dummy_tdmodule = SafeModule( - dummy_net, - spec=None, - in_keys=["hidden"], - out_keys=["hidden"], - safe=False, - ) - tdmodule2 = SafeModule( - net2, - spec=spec, - in_keys=["hidden"], - out_keys=["out"], - safe=safe, - ) - tdmodule = SafeSequential(tdmodule1, dummy_tdmodule, tdmodule2) - - params = make_functional(tdmodule) - - assert hasattr(tdmodule, "__setitem__") - assert len(tdmodule) == 3 - tdmodule[1] = tdmodule2 - with params.unlock_(): - params["module", "1"] = params["module", "2"] - assert len(tdmodule) == 3 - - assert hasattr(tdmodule, "__delitem__") - assert len(tdmodule) == 3 - del tdmodule[2] - with params.unlock_(): - del params["module", "2"] - assert len(tdmodule) == 2 - - assert hasattr(tdmodule, "__getitem__") - assert tdmodule[0] is tdmodule1 - assert tdmodule[1] is tdmodule2 - - td = TensorDict({"in": torch.randn(3, 7)}, [3]) - tdmodule(td, params=params) - - assert td.shape == torch.Size([3]) - assert td.get("out").shape == torch.Size([3, 7]) - - # test bounds - if not safe and spec_type == "bounded": - assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() - - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - def test_functional_with_buffer_probabilistic(self, safe, spec_type): - torch.manual_seed(0) - param_multiplier = 2 - - net1 = nn.Sequential(nn.Linear(7, 7), nn.BatchNorm1d(7)) - dummy_net = nn.Sequential(nn.Linear(7, 7), nn.BatchNorm1d(7)) - net2 = nn.Sequential( - nn.Linear(7, 7 * param_multiplier), nn.BatchNorm1d(7 * param_multiplier) - ) - net2 = NormalParamWrapper(net2) - - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = BoundedTensorSpec(-0.1, 0.1, 7) - elif spec_type == "unbounded": - spec = UnboundedContinuousTensorSpec(7) - else: - raise NotImplementedError - - kwargs = {"distribution_class": TanhNormal} - - if safe and spec is None: - pytest.skip("safe and spec is None is checked elsewhere") - else: - tdmodule1 = SafeModule( - net1, in_keys=["in"], out_keys=["hidden"], spec=None, safe=False - ) - dummy_tdmodule = SafeModule( - dummy_net, - in_keys=["hidden"], - out_keys=["hidden"], - spec=None, - safe=False, - ) - tdmodule2 = SafeModule( - net2, - in_keys=["hidden"], - out_keys=["loc", "scale"], - spec=None, - safe=False, - ) - - prob_module = SafeProbabilisticModule( - in_keys=["loc", "scale"], - out_keys=["out"], - spec=spec, - safe=safe, - **kwargs, - ) - tdmodule = SafeProbabilisticTensorDictSequential( - tdmodule1, dummy_tdmodule, tdmodule2, prob_module - ) - - params = make_functional(tdmodule, ["forward", "get_dist"]) - - assert hasattr(tdmodule, "__setitem__") - assert len(tdmodule) == 4 - tdmodule[1] = tdmodule2 - tdmodule[2] = 
prob_module - with params.unlock_(): - params["module", "1"] = params["module", "2"] - params["module", "2"] = params["module", "3"] - assert len(tdmodule) == 4 - - assert hasattr(tdmodule, "__delitem__") - assert len(tdmodule) == 4 - del tdmodule[3] - with params.unlock_(): - del params["module", "3"] - assert len(tdmodule) == 3 - - assert hasattr(tdmodule, "__getitem__") - assert tdmodule[0] is tdmodule1 - assert tdmodule[1] is tdmodule2 - assert tdmodule[2] is prob_module - - td = TensorDict({"in": torch.randn(3, 7)}, [3]) - tdmodule(td, params=params) - - dist = tdmodule.get_dist(td, params=params) - assert dist.rsample().shape[: td.ndimension()] == td.shape - - assert td.shape == torch.Size([3]) - assert td.get("out").shape == torch.Size([3, 7]) - - # test bounds - if not safe and spec_type == "bounded": - assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() - - @pytest.mark.skipif( - not _has_functorch, reason="vmap can only be used with functorch" - ) - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - def test_vmap(self, safe, spec_type): - torch.manual_seed(0) - param_multiplier = 1 - - net1 = nn.Linear(3, 4) - dummy_net = nn.Linear(4, 4) - net2 = nn.Linear(4, 4 * param_multiplier) - - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = BoundedTensorSpec(-0.1, 0.1, 4) - elif spec_type == "unbounded": - spec = UnboundedContinuousTensorSpec(4) - - if safe and spec is None: - pytest.skip("safe and spec is None is checked elsewhere") - else: - tdmodule1 = SafeModule( - net1, - spec=None, - in_keys=["in"], - out_keys=["hidden"], - safe=False, - ) - dummy_tdmodule = SafeModule( - dummy_net, - spec=None, - in_keys=["hidden"], - out_keys=["hidden"], - safe=False, - ) - tdmodule2 = SafeModule( - net2, - spec=spec, - in_keys=["hidden"], - out_keys=["out"], - safe=safe, - ) - tdmodule = SafeSequential(tdmodule1, dummy_tdmodule, tdmodule2) - - params = make_functional(tdmodule) - - assert hasattr(tdmodule, "__setitem__") - assert len(tdmodule) == 3 - tdmodule[1] = tdmodule2 - with params.unlock_(): - params["module", "1"] = params["module", "2"] - assert len(tdmodule) == 3 - - assert hasattr(tdmodule, "__delitem__") - assert len(tdmodule) == 3 - del tdmodule[2] - with params.unlock_(): - del params["module", "2"] - assert len(tdmodule) == 2 - - assert hasattr(tdmodule, "__getitem__") - assert tdmodule[0] is tdmodule1 - assert tdmodule[1] is tdmodule2 - - # vmap = True - params = params.expand(10) - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - if safe and spec_type == "bounded": - with pytest.raises( - RuntimeError, match="vmap cannot be used with safe=True" - ): - td_out = vmap(tdmodule, (None, 0))(td, params) - return - else: - td_out = vmap(tdmodule, (None, 0))(td, params) - - assert td_out is not td - assert td_out.shape == torch.Size([10, 3]) - assert td_out.get("out").shape == torch.Size([10, 3, 4]) - # test bounds - if not safe and spec_type == "bounded": - assert ((td_out.get("out") > 0.1) | (td_out.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() - - # vmap = (0, 0) - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - td_repeat = td.expand(10, *td.batch_size) - td_out = vmap(tdmodule, (0, 0))(td_repeat, params) - assert td_out is not td_repeat - assert td_out.shape == torch.Size([10, 3]) 
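The vmap-over-params capability covered by the deleted test_vmap* functions survives through the same from_module/to_module idiom; mirroring the torchrl_demo rewrite at the end of this diff, the old vmap(tdmodule, (None, 0))(td, params) call becomes roughly:

import torch
from tensordict import TensorDict
from torch import nn, vmap

net = nn.Linear(3, 4)
params = TensorDict.from_module(net).expand(10)  # batch of 10 parameter sets

def call(params, x):
    # each vmapped call re-binds one slice of the batched parameters
    with params.to_module(net):
        return net(x)

out = vmap(call, (0, None))(params, torch.randn(3, 3))
assert out.shape == torch.Size([10, 3, 4])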
- assert td_out.get("out").shape == torch.Size([10, 3, 4]) - # test bounds - if not safe and spec_type == "bounded": - assert ((td_out.get("out") > 0.1) | (td_out.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() - - @pytest.mark.skipif( - not _has_functorch, reason="vmap can only be used with functorch" - ) - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - def test_vmap_probabilistic(self, safe, spec_type): - torch.manual_seed(0) - param_multiplier = 2 - - net1 = nn.Linear(3, 4) - - net2 = nn.Linear(4, 4 * param_multiplier) - net2 = NormalParamWrapper(net2) - - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = BoundedTensorSpec(-0.1, 0.1, 4) - elif spec_type == "unbounded": - spec = UnboundedContinuousTensorSpec(4) - else: - raise NotImplementedError - - kwargs = {"distribution_class": TanhNormal} - - if safe and spec is None: - pytest.skip("safe and spec is None is checked elsewhere") - else: - tdmodule1 = SafeModule( - net1, - spec=None, - in_keys=["in"], - out_keys=["hidden"], - safe=False, - ) - tdmodule2 = SafeModule(net2, in_keys=["hidden"], out_keys=["loc", "scale"]) - prob_module = SafeProbabilisticModule( - in_keys=["loc", "scale"], - out_keys=["out"], - spec=spec, - safe=safe, - **kwargs, - ) - tdmodule = SafeProbabilisticTensorDictSequential( - tdmodule1, tdmodule2, prob_module - ) - - params = make_functional(tdmodule) - - # vmap = True - params = params.expand(10) - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - if safe and spec_type == "bounded": - with pytest.raises( - RuntimeError, match="vmap cannot be used with safe=True" - ): - td_out = vmap(tdmodule, (None, 0))(td, params) - return - else: - td_out = vmap(tdmodule, (None, 0))(td, params) - assert td_out is not td - assert td_out.shape == torch.Size([10, 3]) - assert td_out.get("out").shape == torch.Size([10, 3, 4]) - # test bounds - if not safe and spec_type == "bounded": - assert ((td_out.get("out") > 0.1) | (td_out.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() - - # vmap = (0, 0) - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - td_repeat = td.expand(10, *td.batch_size) - td_out = vmap(tdmodule, (0, 0))(td_repeat, params) - assert td_out is not td_repeat - assert td_out.shape == torch.Size([10, 3]) - assert td_out.get("out").shape == torch.Size([10, 3, 4]) - # test bounds - if not safe and spec_type == "bounded": - assert ((td_out.get("out") > 0.1) | (td_out.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() - - @pytest.mark.parametrize("functional", [True, False]) - def test_submodule_sequence(self, functional): - td_module_1 = SafeModule( - nn.Linear(3, 2), - in_keys=["in"], - out_keys=["hidden"], - ) - td_module_2 = SafeModule( - nn.Linear(2, 4), - in_keys=["hidden"], - out_keys=["out"], - ) - td_module = SafeSequential(td_module_1, td_module_2) - - if functional: - td_1 = TensorDict({"in": torch.randn(5, 3)}, [5]) - sub_seq_1 = td_module.select_subsequence(out_keys=["hidden"]) - params = make_functional(sub_seq_1) - sub_seq_1(td_1, params=params) - assert "hidden" in td_1.keys() - assert "out" not in td_1.keys() - td_2 = TensorDict({"hidden": torch.randn(5, 2)}, [5]) - sub_seq_2 = td_module.select_subsequence(in_keys=["hidden"]) - params = 
make_functional(sub_seq_2) - sub_seq_2(td_2, params=params) - assert "out" in td_2.keys() - assert td_2.get("out").shape == torch.Size([5, 4]) - else: - td_1 = TensorDict({"in": torch.randn(5, 3)}, [5]) - sub_seq_1 = td_module.select_subsequence(out_keys=["hidden"]) - sub_seq_1(td_1) - assert "hidden" in td_1.keys() - assert "out" not in td_1.keys() - td_2 = TensorDict({"hidden": torch.randn(5, 2)}, [5]) - sub_seq_2 = td_module.select_subsequence(in_keys=["hidden"]) - sub_seq_2(td_2) - assert "out" in td_2.keys() - assert td_2.get("out").shape == torch.Size([5, 4]) - - @pytest.mark.parametrize("stack", [True, False]) - @pytest.mark.parametrize("functional", [True, False]) - def test_sequential_partial(self, stack, functional): - torch.manual_seed(0) - param_multiplier = 2 - - net1 = nn.Linear(3, 4) - - net2 = nn.Linear(4, 4 * param_multiplier) - net2 = NormalParamWrapper(net2) - net2 = SafeModule(net2, in_keys=["b"], out_keys=["loc", "scale"]) - - net3 = nn.Linear(4, 4 * param_multiplier) - net3 = NormalParamWrapper(net3) - net3 = SafeModule(net3, in_keys=["c"], out_keys=["loc", "scale"]) - - spec = BoundedTensorSpec(-0.1, 0.1, 4) - - kwargs = {"distribution_class": TanhNormal} - - tdmodule1 = SafeModule( - net1, - in_keys=["a"], - out_keys=["hidden"], - spec=None, - safe=False, - ) - tdmodule2 = SafeProbabilisticTensorDictSequential( - net2, - SafeProbabilisticModule( - in_keys=["loc", "scale"], - out_keys=["out"], - spec=spec, - safe=True, + safe=True, **kwargs, ), ) @@ -1417,11 +506,6 @@ def test_sequential_partial(self, stack, functional): tdmodule1, tdmodule2, tdmodule3, partial_tolerant=True ) - if functional: - params = make_functional(tdmodule) - else: - params = None - if stack: td = torch.stack( [ @@ -1430,10 +514,7 @@ def test_sequential_partial(self, stack, functional): ], 0, ) - if functional: - tdmodule(td, params=params) - else: - tdmodule(td) + tdmodule(td) assert "loc" in td.keys() assert "scale" in td.keys() assert "out" in td.keys() @@ -1444,10 +525,7 @@ def test_sequential_partial(self, stack, functional): assert "b" in td[0].keys() else: td = TensorDict({"a": torch.randn(3), "b": torch.randn(4)}, []) - if functional: - tdmodule(td, params=params) - else: - tdmodule(td) + tdmodule(td) assert "loc" in td.keys() assert "scale" in td.keys() assert "out" in td.keys() diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 71b15d1dfae..d0d9d9826f2 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -1115,3 +1115,207 @@ def _repr_by_depth(key): return (0, key) else: return (len(key) - 1, ".".join(key)) + + +def _make_compatible_policy(policy, observation_spec, env=None, fast_wrap=False): + if policy is None: + if env is None: + raise ValueError( + "env must be provided to _get_policy_and_device if policy is None" + ) + policy = RandomPolicy(env.input_spec["full_action_spec"]) + # make sure policy is an nn.Module + policy = _NonParametricPolicyWrapper(policy) + if not _policy_is_tensordict_compatible(policy): + # policy is a nn.Module that doesn't operate on tensordicts directly + # so we attempt to auto-wrap policy with TensorDictModule + if observation_spec is None: + raise ValueError( + "Unable to read observation_spec from the environment. This is " + "required to check compatibility of the environment and policy " + "since the policy is a nn.Module that operates on tensors " + "rather than a TensorDictModule or a nn.Module that accepts a " + "TensorDict as input and defines in_keys and out_keys." 
+ ) + + try: + sig = policy.forward.__signature__ + except AttributeError: + sig = inspect.signature(policy.forward) + # we check if all the mandatory params are there + params = list(sig.parameters.keys()) + if ( + set(sig.parameters) == {"tensordict"} + or set(sig.parameters) == {"td"} + or ( + len(params) == 1 + and is_tensor_collection(sig.parameters[params[0]].annotation) + ) + ): + return policy + if fast_wrap: + in_keys = list(observation_spec.keys()) + out_keys = list(env.action_keys) + return TensorDictModule(policy, in_keys=in_keys, out_keys=out_keys) + + required_kwargs = { + str(k) for k, p in sig.parameters.items() if p.default is inspect._empty + } + next_observation = { + key: value for key, value in observation_spec.rand().items() + } + if not required_kwargs.difference(set(next_observation)): + in_keys = [str(k) for k in sig.parameters if k in next_observation] + if env is None: + out_keys = ["action"] + else: + out_keys = list(env.action_keys) + for p in policy.parameters(): + policy_device = p.device + break + else: + policy_device = None + if policy_device: + next_observation = tree_map( + lambda x: x.to(policy_device), next_observation + ) + + output = policy(**next_observation) + + if isinstance(output, tuple): + out_keys.extend(f"output{i + 1}" for i in range(len(output) - 1)) + + policy = TensorDictModule(policy, in_keys=in_keys, out_keys=out_keys) + else: + raise TypeError( + f"""Arguments to policy.forward are incompatible with entries in + env.observation_spec (got incongruent signatures: fun signature is {set(sig.parameters)} vs specs {set(next_observation)}). + If you want TorchRL to automatically wrap your policy with a TensorDictModule + then the arguments to policy.forward must correspond one-to-one with entries + in env.observation_spec. + For more complex behaviour and more control you can consider writing your + own TensorDictModule. + Check the collector documentation to know more about accepted policies. + """ + ) + return policy + + +def _policy_is_tensordict_compatible(policy: nn.Module): + if isinstance(policy, _NonParametricPolicyWrapper) and isinstance( + policy.policy, RandomPolicy + ): + return True + + if isinstance(policy, TensorDictModuleBase): + return True + + sig = inspect.signature(policy.forward) + + if ( + len(sig.parameters) == 1 + and hasattr(policy, "in_keys") + and hasattr(policy, "out_keys") + ): + raise RuntimeError( + "Passing a policy that is not a tensordict.nn.TensorDictModuleBase subclass but has in_keys and out_keys " + "is deprecated. Users should inherit from this class (which " + "has very few restrictions) to make the experience smoother. " + "Simply change your policy from `class Policy(nn.Module)` to `Policy(tensordict.nn.TensorDictModuleBase)` " + "and this error should disappear.", + ) + elif not hasattr(policy, "in_keys") and not hasattr(policy, "out_keys"): + # if it's not a TensorDictModule, and in_keys and out_keys are not defined then + # we assume no TensorDict compatibility and will try to wrap it. + return False + + # if in_keys or out_keys were defined but policy is not a TensorDictModule or + # accepts multiple arguments then it's likely the user is trying to do something + # that will have undetermined behaviour, we raise an error + raise TypeError( + "Received a policy that defines in_keys or out_keys and also expects multiple " + "arguments to policy.forward. 
If the policy is compatible with TensorDict, it " + "should take a single argument of type TensorDict to policy.forward and define " + "both in_keys and out_keys. Alternatively, policy.forward can accept " + "arbitrarily many tensor inputs and leave in_keys and out_keys undefined and " + "TorchRL will attempt to automatically wrap the policy with a TensorDictModule." + ) + + +class RandomPolicy: + """A random policy for data collectors. + + This is a wrapper around the action_spec.rand method. + + Args: + action_spec: TensorSpec object describing the action specs + + Examples: + >>> from tensordict import TensorDict + >>> from torchrl.data.tensor_specs import BoundedTensorSpec + >>> action_spec = BoundedTensorSpec(-torch.ones(3), torch.ones(3)) + >>> actor = RandomPolicy(action_spec=action_spec) + >>> td = actor(TensorDict({}, batch_size=[])) # selects a random action in the cube [-1; 1] + """ + + def __init__(self, action_spec: TensorSpec, action_key: NestedKey = "action"): + super().__init__() + self.action_spec = action_spec.clone() + self.action_key = action_key + + def __call__(self, td: TensorDictBase) -> TensorDictBase: + if isinstance(self.action_spec, CompositeSpec): + return td.update(self.action_spec.rand()) + else: + return td.set(self.action_key, self.action_spec.rand()) + + +class _PolicyMetaClass(abc.ABCMeta): + def __call__(cls, *args, **kwargs): + # no kwargs + if isinstance(args[0], nn.Module): + return args[0] + return super().__call__(*args) + + +class _NonParametricPolicyWrapper(nn.Module, metaclass=_PolicyMetaClass): + """A wrapper for non-parametric policies.""" + + def __init__(self, policy): + super().__init__() + self.policy = policy + + @property + def forward(self): + forward = self.__dict__.get("_forward", None) + if forward is None: + + @functools.wraps(self.policy) + def forward(*input, **kwargs): + return self.policy.__call__(*input, **kwargs) + + self.__dict__["_forward"] = forward + return forward + + def __getattr__(self, attr: str) -> Any: + if attr in self.__dir__(): + return self.__getattribute__( + attr + ) # make sure that appropriate exceptions are raised + + elif attr.startswith("__"): + raise AttributeError( + "passing built-in private methods is " + f"not permitted with type {type(self)}. " + f"Got attribute {attr}." + ) + + elif "policy" in self.__dir__(): + policy = self.__getattribute__("policy") + return getattr(policy, attr) + try: + super().__getattr__(attr) + except Exception: + raise AttributeError( + f"policy not set in {self.__class__.__name__}, cannot access {attr}." 
+ ) diff --git a/tutorials/sphinx-tutorials/coding_ddpg.py b/tutorials/sphinx-tutorials/coding_ddpg.py index 252b4fd2146..4a818474985 100644 --- a/tutorials/sphinx-tutorials/coding_ddpg.py +++ b/tutorials/sphinx-tutorials/coding_ddpg.py @@ -297,12 +297,11 @@ def _loss_actor( ) -> torch.Tensor: td_copy = tensordict.select(*self.actor_in_keys) # Get an action from the actor network: since we made it functional, we need to pass the params - td_copy = self.actor_network(td_copy, params=self.actor_network_params) + with self.actor_network_params.to_module(self.actor_network): + td_copy = self.actor_network(td_copy) # get the value associated with that action - td_copy = self.value_network( - td_copy, - params=self.value_network_params.detach(), - ) + with self.value_network_params.detach().to_module(self.value_network): + td_copy = self.value_network(td_copy) return -td_copy.get("state_action_value") @@ -324,7 +323,8 @@ def _loss_value( td_copy = tensordict.clone() # V(s, a) - self.value_network(td_copy, params=self.value_network_params) + with self.value_network_params.to_module(self.value_network): + self.value_network(td_copy) pred_val = td_copy.get("state_action_value").squeeze(-1) # we manually reconstruct the parameters of the actor-critic, where the first @@ -339,9 +339,8 @@ def _loss_value( batch_size=self.target_actor_network_params.batch_size, device=self.target_actor_network_params.device, ) - target_value = self.value_estimator.value_estimate( - tensordict, target_params=target_params - ).squeeze(-1) + with target_params.to_module(self.value_estimator): + target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1) # Computes the value loss: L2, L1 or smooth L1 depending on `self.loss_function` loss_value = distance_loss(pred_val, target_value, loss_function=self.loss_function) diff --git a/tutorials/sphinx-tutorials/torchrl_demo.py b/tutorials/sphinx-tutorials/torchrl_demo.py index ce3f0bb4b98..25213503e19 100644 --- a/tutorials/sphinx-tutorials/torchrl_demo.py +++ b/tutorials/sphinx-tutorials/torchrl_demo.py @@ -543,21 +543,30 @@ # Functional Programming (Ensembling / Meta-RL) # ---------------------------------------------- -from tensordict.nn import make_functional +from tensordict import TensorDict -params = make_functional(sequence) -len(list(sequence.parameters())) # functional modules have no parameters +params = TensorDict.from_module(sequence) +print("extracted params", params) ############################################################################### +# functional call using tensordict: -sequence(tensordict, params) +with params.to_module(sequence): + sequence(tensordict) ############################################################################### - +# Using vectorized map for model ensembling from torch import vmap params_expand = params.expand(4) -tensordict_exp = vmap(sequence, (None, 0))(tensordict, params_expand) + + +def exec_sequence(params, data): + with params.to_module(sequence): + return sequence(data) + + +tensordict_exp = vmap(exec_sequence, (0, None))(params_expand, tensordict) print(tensordict_exp) ###############################################################################
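Taken together, _policy_is_tensordict_compatible and _make_compatible_policy let a collector accept a plain nn.Module whose forward reads observation tensors. Conceptually, the auto-wrap produces something like the sketch below (the Policy class and its 4-dimensional observation are hypothetical; in_keys are matched against observation_spec entries and out_keys against the env's action keys, as in _make_compatible_policy above):

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from torch import nn

class Policy(nn.Module):  # hypothetical tensor-in/tensor-out policy
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, observation):
        return self.linear(observation)

# equivalent of the wrapping performed inside _make_compatible_policy
wrapped = TensorDictModule(Policy(), in_keys=["observation"], out_keys=["action"])
td = wrapped(TensorDict({"observation": torch.randn(4)}, batch_size=[]))
assert td["action"].shape == torch.Size([2])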