From 8dc96cfdc9054a11fba712dac756d59311802deb Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 12 Feb 2024 14:59:32 +0000 Subject: [PATCH] [Minor] Remove warnings in test_cost (#1902) --- test/conftest.py | 7 +- test/test_cost.py | 147 ++++++++++++++++++--- torchrl/objectives/a2c.py | 25 ++-- torchrl/objectives/decision_transformer.py | 1 - torchrl/objectives/reinforce.py | 24 ++-- torchrl/objectives/sac.py | 4 +- 6 files changed, 165 insertions(+), 43 deletions(-) diff --git a/test/conftest.py b/test/conftest.py index 5ce980a4080..2dcd369003a 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -53,7 +53,7 @@ def fin(): request.addfinalizer(fin) -@pytest.fixture(scope="session", autouse=True) +@pytest.fixture(autouse=True) def set_warnings() -> None: warnings.filterwarnings( "ignore", @@ -65,6 +65,11 @@ def set_warnings() -> None: category=UserWarning, message=r"Couldn't cast the policy onto the desired device on remote process", ) + warnings.filterwarnings( + "ignore", + category=UserWarning, + message=r"Skipping device Apple Paravirtual device", + ) warnings.filterwarnings( "ignore", category=DeprecationWarning, diff --git a/test/test_cost.py b/test/test_cost.py index f92a58707bd..2d8bfa91351 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -156,6 +156,10 @@ ) +# Capture all warnings +pytestmark = pytest.mark.filterwarnings("error") + + class _check_td_steady: def __init__(self, td): self.td_clone = td.clone() @@ -501,6 +505,11 @@ def test_dqn(self, delay_value, double_dqn, device, action_spec_type, td_est): else contextlib.nullcontext() ), _check_td_steady(td): loss = loss_fn(td) + + if delay_value: + # remove warning + SoftUpdate(loss_fn, eps=0.5) + assert loss_fn.tensor_keys.priority in td.keys() sum([item for _, item in loss.items()]).backward() @@ -562,6 +571,10 @@ def test_dqn_batcher(self, n, delay_value, device, action_spec_type, gamma=0.9): loss_ms = loss_fn(ms_td) assert loss_fn.tensor_keys.priority in ms_td.keys() + if delay_value: + # remove warning + SoftUpdate(loss_fn, eps=0.5) + with torch.no_grad(): loss = loss_fn(td) if n == 0: @@ -601,7 +614,7 @@ def test_dqn_tensordict_keys(self, td_est): torch.manual_seed(self.seed) action_spec_type = "one_hot" actor = self._create_mock_actor(action_spec_type=action_spec_type) - loss_fn = DQNLoss(actor) + loss_fn = DQNLoss(actor, delay_value=True) default_keys = { "advantage": "advantage", @@ -617,7 +630,7 @@ def test_dqn_tensordict_keys(self, td_est): self.tensordict_keys_test(loss_fn, default_keys=default_keys) - loss_fn = DQNLoss(actor) + loss_fn = DQNLoss(actor, delay_value=True) key_mapping = { "advantage": ("advantage", "advantage_2"), "value_target": ("value_target", ("value_target", "nested")), @@ -630,7 +643,7 @@ def test_dqn_tensordict_keys(self, td_est): actor = self._create_mock_actor( action_spec_type=action_spec_type, action_value_key="chosen_action_value_2" ) - loss_fn = DQNLoss(actor) + loss_fn = DQNLoss(actor, delay_value=True) key_mapping = { "value": ("value", "chosen_action_value_2"), } @@ -657,11 +670,14 @@ def test_dqn_tensordict_run(self, action_spec_type, td_est): action_value_key=tensor_keys["action_value"], ) - loss_fn = DQNLoss(actor, loss_function="l2") + loss_fn = DQNLoss(actor, loss_function="l2", delay_value=True) loss_fn.set_keys(**tensor_keys) if td_est is not None: loss_fn.make_value_estimator(td_est) + + SoftUpdate(loss_fn, eps=0.5) + with _check_td_steady(td): _ = loss_fn(td) assert loss_fn.tensor_keys.priority in td.keys() @@ -707,6 +723,10 @@ def test_distributional_dqn( 
sum([item for _, item in loss.items()]).backward() assert torch.nn.utils.clip_grad.clip_grad_norm_(actor.parameters(), 1.0) > 0.0 + if delay_value: + # remove warning + SoftUpdate(loss_fn, eps=0.5) + # Check param update effect on targets target_value = loss_fn.target_value_network_params.clone() for p in loss_fn.parameters(): @@ -744,7 +764,7 @@ def test_dqn_notensordict( module=module, in_keys=[observation_key], ) - dqn_loss = DQNLoss(actor) + dqn_loss = DQNLoss(actor, delay_value=True) dqn_loss.set_keys(reward=reward_key, done=done_key, terminated=terminated_key) # define data observation = torch.randn(n_obs) @@ -762,6 +782,8 @@ def test_dqn_notensordict( "action": action, } td = TensorDict(kwargs, []).unflatten_keys("_") + # Disable warning + SoftUpdate(dqn_loss, eps=0.5) loss_val = dqn_loss(**kwargs) loss_val_td = dqn_loss(td) torch.testing.assert_close(loss_val_td.get("loss"), loss_val) @@ -775,7 +797,7 @@ def test_distributional_dqn_tensordict_keys(self): action_spec_type=action_spec_type, atoms=atoms ) - loss_fn = DistributionalDQNLoss(actor, gamma=gamma) + loss_fn = DistributionalDQNLoss(actor, gamma=gamma, delay_value=True) default_keys = { "priority": "td_error", @@ -810,11 +832,14 @@ def test_distributional_dqn_tensordict_run(self, action_spec_type, td_est): action_key=tensor_keys["action"], action_value_key=tensor_keys["action_value"], ) - loss_fn = DistributionalDQNLoss(actor, gamma=0.9) + loss_fn = DistributionalDQNLoss(actor, gamma=0.9, delay_value=True) loss_fn.set_keys(**tensor_keys) loss_fn.make_value_estimator(td_est) + # remove warnings + SoftUpdate(loss_fn, eps=0.5) + with _check_td_steady(td): _ = loss_fn(td) assert loss_fn.tensor_keys.priority in td.keys() @@ -984,6 +1009,10 @@ def test_qmixer(self, delay_value, device, action_spec_type, td_est): sum([item for _, item in loss.items()]).backward() assert torch.nn.utils.clip_grad.clip_grad_norm_(actor.parameters(), 1.0) > 0.0 + if delay_value: + # remove warning + SoftUpdate(loss_fn, eps=0.5) + # Check param update effect on targets target_value = loss_fn.target_local_value_network_params.clone() for p in loss_fn.parameters(): @@ -1051,6 +1080,11 @@ def test_qmix_batcher(self, n, delay_value, device, action_spec_type, gamma=0.9) else contextlib.nullcontext() ), _check_td_steady(ms_td): loss_ms = loss_fn(ms_td) + + if delay_value: + # remove warning + SoftUpdate(loss_fn, eps=0.5) + assert loss_fn.tensor_keys.priority in ms_td.keys() with torch.no_grad(): @@ -1105,7 +1139,7 @@ def test_qmix_tensordict_keys(self, td_est): action_spec_type = "one_hot" actor = self._create_mock_actor(action_spec_type=action_spec_type) mixer = self._create_mock_mixer() - loss_fn = QMixerLoss(actor, mixer) + loss_fn = QMixerLoss(actor, mixer, delay_value=True) default_keys = { "advantage": "advantage", @@ -1122,7 +1156,7 @@ def test_qmix_tensordict_keys(self, td_est): self.tensordict_keys_test(loss_fn, default_keys=default_keys) - loss_fn = QMixerLoss(actor, mixer) + loss_fn = QMixerLoss(actor, mixer, delay_value=True) key_mapping = { "advantage": ("advantage", "advantage_2"), "value_target": ("value_target", ("value_target", "nested")), @@ -1138,7 +1172,7 @@ def test_qmix_tensordict_keys(self, td_est): mixer = self._create_mock_mixer( global_chosen_action_value_key=("some", "nested") ) - loss_fn = QMixerLoss(actor, mixer) + loss_fn = QMixerLoss(actor, mixer, delay_value=True) key_mapping = { "global_value": ("value", ("some", "nested")), } @@ -1173,9 +1207,9 @@ def test_qmix_tensordict_run(self, action_spec_type, td_est): 
action_value_key=tensor_keys["action_value"], ) - loss_fn = QMixerLoss(actor, mixer, loss_function="l2") + loss_fn = QMixerLoss(actor, mixer, loss_function="l2", delay_value=True) loss_fn.set_keys(**tensor_keys) - + SoftUpdate(loss_fn, eps=0.5) if td_est is not None: loss_fn.make_value_estimator(td_est) with _check_td_steady(td): @@ -1231,7 +1265,9 @@ def test_mixer_keys( ) td = actor(td) - loss = QMixerLoss(actor, mixer) + loss = QMixerLoss(actor, mixer, delay_value=True) + + SoftUpdate(loss, eps=0.5) # Wthout etting the keys if mixer_local_chosen_action_value_key != ("agents", "chosen_action_value"): @@ -1245,7 +1281,10 @@ def test_mixer_keys( else: loss(td) - loss = QMixerLoss(actor, mixer) + loss = QMixerLoss(actor, mixer, delay_value=True) + + SoftUpdate(loss, eps=0.5) + # When setting the key loss.set_keys(global_value=mixer_global_chosen_action_value_key) if mixer_local_chosen_action_value_key != ("agents", "chosen_action_value"): @@ -1466,6 +1505,10 @@ def test_ddpg(self, delay_actor, delay_value, device, td_est): ): loss = loss_fn(td) + if delay_value: + # remove warning + SoftUpdate(loss_fn, eps=0.5) + assert all( (p.grad is None) or (p.grad == 0).all() for p in loss_fn.value_network_params.values(True, True) @@ -1582,6 +1625,9 @@ def test_ddpg_separate_losses( with pytest.warns(UserWarning, match="No target network updater has been"): loss = loss_fn(td) + # remove warning + SoftUpdate(loss_fn, eps=0.5) + assert all( (p.grad is None) or (p.grad == 0).all() for p in loss_fn.value_network_params.values(True, True) @@ -1702,6 +1748,11 @@ def test_ddpg_batcher(self, n, delay_actor, delay_value, device, gamma=0.9): else contextlib.nullcontext() ), _check_td_steady(ms_td): loss_ms = loss_fn(ms_td) + + if delay_value: + # remove warning + SoftUpdate(loss_fn, eps=0.5) + with torch.no_grad(): loss = loss_fn(td) if n == 0: @@ -2304,10 +2355,14 @@ def test_td3_batcher( loss_ms = loss_fn(ms_td) assert loss_fn.tensor_keys.priority in ms_td.keys() + if delay_qvalue or delay_actor: + SoftUpdate(loss_fn, eps=0.5) + with torch.no_grad(): torch.manual_seed(0) # log-prob is computed with a random action np.random.seed(0) loss = loss_fn(td) + if n == 0: assert_allclose_td(td, ms_td.select(*list(td.keys(True, True)))) _loss = sum([item for _, item in loss.items()]) @@ -3291,6 +3346,9 @@ def test_sac_notensordict( loss_val = loss(**kwargs) torch.manual_seed(self.seed) + + SoftUpdate(loss, eps=0.5) + loss_val_td = loss(td) if version == 1: @@ -3538,6 +3596,7 @@ def test_discrete_sac( target_entropy_weight=target_entropy_weight, target_entropy=target_entropy, loss_function="l2", + action_space="one-hot", **kwargs, ) if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): @@ -3648,6 +3707,7 @@ def test_discrete_sac_state_dict( target_entropy_weight=target_entropy_weight, target_entropy=target_entropy, loss_function="l2", + action_space="one-hot", **kwargs, ) sd = loss_fn.state_dict() @@ -3659,6 +3719,7 @@ def test_discrete_sac_state_dict( target_entropy_weight=target_entropy_weight, target_entropy=target_entropy, loss_function="l2", + action_space="one-hot", **kwargs, ) loss_fn2.load_state_dict(sd) @@ -3696,6 +3757,7 @@ def test_discrete_sac_batcher( loss_function="l2", target_entropy_weight=target_entropy_weight, target_entropy=target_entropy, + action_space="one-hot", **kwargs, ) @@ -3712,6 +3774,8 @@ def test_discrete_sac_batcher( loss_ms = loss_fn(ms_td) assert loss_fn.tensor_keys.priority in ms_td.keys() + SoftUpdate(loss_fn, eps=0.5) + with torch.no_grad(): torch.manual_seed(0) # log-prob is 
computed with a random action np.random.seed(0) @@ -3800,6 +3864,7 @@ def test_discrete_sac_tensordict_keys(self, td_est): qvalue_network=qvalue, num_actions=actor.spec["action"].space.n, loss_function="l2", + action_space="one-hot", ) default_keys = { @@ -3822,6 +3887,7 @@ def test_discrete_sac_tensordict_keys(self, td_est): qvalue_network=qvalue, num_actions=actor.spec["action"].space.n, loss_function="l2", + action_space="one-hot", ) key_mapping = { @@ -3860,6 +3926,7 @@ def test_discrete_sac_notensordict( actor_network=actor, qvalue_network=qvalue, num_actions=actor.spec[action_key].space.n, + action_space="one-hot", ) loss.set_keys( action=action_key, @@ -4370,6 +4437,8 @@ def test_redq_deprecated_separate_losses(self, separate_losses): ): loss = loss_fn(td) + SoftUpdate(loss_fn, eps=0.5) + # check that losses are independent for k in loss.keys(): if not k.startswith("loss"): @@ -5408,6 +5477,9 @@ def test_dcql(self, delay_value, device, action_spec_type, td_est): loss = loss_fn(td) assert loss_fn.tensor_keys.priority in td.keys(True) + if delay_value: + SoftUpdate(loss_fn, eps=0.5) + sum([item for key, item in loss.items() if key.startswith("loss")]).backward() assert torch.nn.utils.clip_grad.clip_grad_norm_(actor.parameters(), 1.0) > 0.0 @@ -5467,6 +5539,9 @@ def test_dcql_batcher(self, n, delay_value, device, action_spec_type, gamma=0.9) loss_ms = loss_fn(ms_td) assert loss_fn.tensor_keys.priority in ms_td.keys() + if delay_value: + SoftUpdate(loss_fn, eps=0.5) + with torch.no_grad(): loss = loss_fn(td) if n == 0: @@ -5510,7 +5585,7 @@ def test_dcql_tensordict_keys(self, td_est): torch.manual_seed(self.seed) action_spec_type = "one_hot" actor = self._create_mock_actor(action_spec_type=action_spec_type) - loss_fn = DQNLoss(actor) + loss_fn = DQNLoss(actor, delay_value=True) default_keys = { "value_target": "value_target", @@ -5566,6 +5641,8 @@ def test_dcql_tensordict_run(self, action_spec_type, td_est): loss_fn = DiscreteCQLLoss(actor, loss_function="l2") loss_fn.set_keys(**tensor_keys) + SoftUpdate(loss_fn, eps=0.5) + if td_est is not None: loss_fn.make_value_estimator(td_est) with _check_td_steady(td): @@ -5590,6 +5667,9 @@ def test_dcql_notensordict( in_keys=[observation_key], ) loss = DiscreteCQLLoss(actor) + + SoftUpdate(loss, eps=0.5) + loss.set_keys(reward=reward_key, done=done_key, terminated=terminated_key) # define data observation = torch.randn(n_obs) @@ -8783,6 +8863,9 @@ def test_iql_batcher( loss_ms = loss_fn(ms_td) assert loss_fn.tensor_keys.priority in ms_td.keys() + # Remove warnings + SoftUpdate(loss_fn, eps=0.5) + with torch.no_grad(): torch.manual_seed(0) # log-prob is computed with a random action np.random.seed(0) @@ -9206,6 +9289,7 @@ def test_discrete_iql( temperature=temperature, expectile=expectile, loss_function="l2", + action_space="one-hot", ) if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): @@ -9328,6 +9412,7 @@ def test_discrete_iql_state_dict( temperature=temperature, expectile=expectile, loss_function="l2", + action_space="one-hot", ) sd = loss_fn.state_dict() loss_fn2 = DiscreteIQLLoss( @@ -9338,6 +9423,7 @@ def test_discrete_iql_state_dict( temperature=temperature, expectile=expectile, loss_function="l2", + action_space="one-hot", ) loss_fn2.load_state_dict(sd) @@ -9351,6 +9437,7 @@ def test_discrete_iql_separate_losses(self, separate_losses): value_network=value, loss_function="l2", separate_losses=separate_losses, + action_space="one-hot", ) with pytest.warns(UserWarning, match="No target 
network updater has been"): loss = loss_fn(td) @@ -9529,6 +9616,7 @@ def test_discrete_iql_batcher( temperature=temperature, expectile=expectile, loss_function="l2", + action_space="one-hot", ) ms = MultiStep(gamma=gamma, n_steps=n).to(device) @@ -9544,6 +9632,8 @@ def test_discrete_iql_batcher( loss_ms = loss_fn(ms_td) assert loss_fn.tensor_keys.priority in ms_td.keys() + SoftUpdate(loss_fn, eps=0.5) + with torch.no_grad(): torch.manual_seed(0) # log-prob is computed with a random action np.random.seed(0) @@ -9615,6 +9705,7 @@ def test_discrete_iql_tensordict_keys(self, td_est): qvalue_network=qvalue, value_network=value, loss_function="l2", + action_space="one-hot", ) default_keys = { @@ -9640,6 +9731,7 @@ def test_discrete_iql_tensordict_keys(self, td_est): qvalue_network=qvalue, value_network=value, loss_function="l2", + action_space="one-hot", ) key_mapping = { @@ -9675,7 +9767,10 @@ def test_discrete_iql_notensordict( value = self._create_mock_value(observation_key=observation_key) loss = DiscreteIQLLoss( - actor_network=actor, qvalue_network=qvalue, value_network=value + actor_network=actor, + qvalue_network=qvalue, + value_network=value, + action_space="one-hot", ) loss.set_keys( action=action_key, @@ -9744,6 +9839,10 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: out_keys=["action"], ) loss = MyLoss(actor_module) + + if create_target_params: + SoftUpdate(loss, eps=0.5) + if cast is not None: loss.to(cast) for name in ("weight", "bias"): @@ -9873,11 +9972,13 @@ def __init__(self, delay_module=True): self.convert_to_functional( module1, "module1", create_target_params=delay_module ) + module2 = torch.nn.BatchNorm2d(10).eval() self.module2 = module2 - iterator_params = self.target_module1_params.values( - include_nested=True, leaves_only=True - ) + tparam = self._modules.get("target_module1_params", None) + if tparam is None: + tparam = self._modules.get("module1_params").data + iterator_params = tparam.values(include_nested=True, leaves_only=True) for target in iterator_params: if target.dtype is not torch.int64: target.data.normal_() @@ -12285,10 +12386,14 @@ def test_single_call(self, has_target, value_key, single_call, detach_next=True) def test_instantiate_with_different_keys(): - loss_1 = DQNLoss(value_network=nn.Linear(3, 3), action_space="one_hot") + loss_1 = DQNLoss( + value_network=nn.Linear(3, 3), action_space="one_hot", delay_value=True + ) loss_1.set_keys(reward="a") assert loss_1.tensor_keys.reward == "a" - loss_2 = DQNLoss(value_network=nn.Linear(3, 3), action_space="one_hot") + loss_2 = DQNLoss( + value_network=nn.Linear(3, 3), action_space="one_hot", delay_value=True + ) loss_2.set_keys(reward="b") assert loss_1.tensor_keys.reward == "a" diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index 6edcda5c800..8fcbd5a6699 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -255,7 +255,8 @@ def __init__( if functional: self.convert_to_functional( - actor_network, "actor_network", funs_to_decorate=["forward", "get_dist"] + actor_network, + "actor_network", ) else: self.actor_network = actor_network @@ -350,7 +351,7 @@ def in_keys(self): *[("next", key) for key in self.actor_network.in_keys], ] if self.critic_coef: - keys.extend(self.critic.in_keys) + keys.extend(self.critic_network.in_keys) return list(set(keys)) @property @@ -414,11 +415,11 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor: # TODO: if the advantage is gathered by forward, this introduces an # overhead that we could easily reduce. 
target_return = tensordict.get(self.tensor_keys.value_target) - tensordict_select = tensordict.select(*self.critic.in_keys) + tensordict_select = tensordict.select(*self.critic_network.in_keys) with self.critic_network_params.to_module( - self.critic + self.critic_network ) if self.functional else contextlib.nullcontext(): - state_value = self.critic( + state_value = self.critic_network( tensordict_select, ).get(self.tensor_keys.value) loss_value = distance_loss( @@ -477,13 +478,19 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams if hasattr(self, "gamma"): hp["gamma"] = self.gamma if value_type == ValueEstimators.TD1: - self._value_estimator = TD1Estimator(value_network=self.critic, **hp) + self._value_estimator = TD1Estimator( + value_network=self.critic_network, **hp + ) elif value_type == ValueEstimators.TD0: - self._value_estimator = TD0Estimator(value_network=self.critic, **hp) + self._value_estimator = TD0Estimator( + value_network=self.critic_network, **hp + ) elif value_type == ValueEstimators.GAE: - self._value_estimator = GAE(value_network=self.critic, **hp) + self._value_estimator = GAE(value_network=self.critic_network, **hp) elif value_type == ValueEstimators.TDLambda: - self._value_estimator = TDLambdaEstimator(value_network=self.critic, **hp) + self._value_estimator = TDLambdaEstimator( + value_network=self.critic_network, **hp + ) elif value_type == ValueEstimators.VTrace: # VTrace currently does not support functional call on the actor if self.functional: diff --git a/torchrl/objectives/decision_transformer.py b/torchrl/objectives/decision_transformer.py index a24aa4a1271..954bd0b9a42 100644 --- a/torchrl/objectives/decision_transformer.py +++ b/torchrl/objectives/decision_transformer.py @@ -275,7 +275,6 @@ def __init__( actor_network, "actor_network", create_target_params=False, - funs_to_decorate=["forward"], ) self.loss_function = loss_function diff --git a/torchrl/objectives/reinforce.py b/torchrl/objectives/reinforce.py index 4613810d0d3..9738b922c5d 100644 --- a/torchrl/objectives/reinforce.py +++ b/torchrl/objectives/reinforce.py @@ -351,7 +351,7 @@ def _set_in_keys(self): ("next", self.tensor_keys.terminated), *self.actor_network.in_keys, *[("next", key) for key in self.actor_network.in_keys], - *self.critic.in_keys, + *self.critic_network.in_keys, ] self._in_keys = list(set(keys)) @@ -398,11 +398,13 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor: try: target_return = tensordict.get(self.tensor_keys.value_target) - tensordict_select = tensordict.select(*self.critic.in_keys) + tensordict_select = tensordict.select(*self.critic_network.in_keys) with self.critic_network_params.to_module( - self.critic + self.critic_network ) if self.functional else contextlib.nullcontext(): - state_value = self.critic(tensordict_select).get(self.tensor_keys.value) + state_value = self.critic_network(tensordict_select).get( + self.tensor_keys.value + ) loss_value = distance_loss( target_return, state_value, @@ -427,13 +429,19 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams hp["gamma"] = self.gamma hp.update(hyperparams) if value_type == ValueEstimators.TD1: - self._value_estimator = TD1Estimator(value_network=self.critic, **hp) + self._value_estimator = TD1Estimator( + value_network=self.critic_network, **hp + ) elif value_type == ValueEstimators.TD0: - self._value_estimator = TD0Estimator(value_network=self.critic, **hp) + 
self._value_estimator = TD0Estimator( + value_network=self.critic_network, **hp + ) elif value_type == ValueEstimators.GAE: - self._value_estimator = GAE(value_network=self.critic, **hp) + self._value_estimator = GAE(value_network=self.critic_network, **hp) elif value_type == ValueEstimators.TDLambda: - self._value_estimator = TDLambdaEstimator(value_network=self.critic, **hp) + self._value_estimator = TDLambdaEstimator( + value_network=self.critic_network, **hp + ) elif value_type == ValueEstimators.VTrace: # VTrace currently does not support functional call on the actor if self.functional: diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index 053da9e53d2..5b722fd05f3 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -292,7 +292,6 @@ def __init__( actor_network, "actor_network", create_target_params=self.delay_actor, - funs_to_decorate=["forward", "get_dist"], ) if separate_losses: # we want to make sure there are no duplicates in the params: the @@ -980,7 +979,6 @@ def __init__( actor_network, "actor_network", create_target_params=self.delay_actor, - funs_to_decorate=["forward", "get_dist"], ) if separate_losses: # we want to make sure there are no duplicates in the params: the @@ -1036,7 +1034,7 @@ def __init__( if action_space is None: warnings.warn( "action_space was not specified. DiscreteSACLoss will default to 'one-hot'." - "This behaviour will be deprecated soon and a space will have to be passed." + "This behaviour will be deprecated soon and a space will have to be passed. " "Check the DiscreteSACLoss documentation to see how to pass the action space. " ) action_space = "one-hot"
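
Note on the warning policy the patch sets up: pytestmark = pytest.mark.filterwarnings("error") promotes every warning raised inside test_cost.py to an error, while the set_warnings fixture in conftest.py (now function-scoped rather than session-scoped, so it re-applies inside each test's warning context) whitelists a few known benign warnings. A minimal, self-contained sketch of how the two pieces interact follows; the filter message is taken from the patch, the test function is illustrative:

    # Sketch of the warning policy used by the test module (illustrative, not part of the patch).
    import warnings

    import pytest

    # Any warning that is not explicitly filtered now fails the test.
    pytestmark = pytest.mark.filterwarnings("error")


    @pytest.fixture(autouse=True)
    def set_warnings() -> None:
        # Whitelist a known benign warning, as conftest.py does in the patch.
        # Because this filter is installed after the "error" filter, it takes precedence.
        warnings.filterwarnings(
            "ignore",
            category=UserWarning,
            message=r"Skipping device Apple Paravirtual device",
        )


    def test_no_stray_warnings():
        # This warning matches the ignore filter above, so the test passes;
        # any other UserWarning raised here would be escalated to an error.
        warnings.warn("Skipping device Apple Paravirtual device", UserWarning)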
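
Most of the per-test changes follow a single pattern: whenever a loss module is built with target parameters (delay_value=True, delay_actor=True, and so on), a target-network updater such as SoftUpdate is registered right away, so the "No target network updater has been associated with this loss module" UserWarning is never emitted and therefore never escalated to an error. Outside the test suite the same pattern looks roughly like the sketch below; the toy network, spec and keys are illustrative assumptions, not code from the patch:

    # Minimal sketch: registering a SoftUpdate so a loss with target params does not warn.
    from torch import nn
    from torchrl.data import OneHotDiscreteTensorSpec
    from torchrl.modules import QValueActor
    from torchrl.objectives import DQNLoss, SoftUpdate

    value_net = nn.Linear(4, 2)  # 4 observation features, 2 discrete actions
    actor = QValueActor(
        value_net, in_keys=["observation"], spec=OneHotDiscreteTensorSpec(2)
    )

    # delay_value=True creates target parameters for the value network; calling the
    # loss before an updater is registered triggers the "no target network updater"
    # warning that the tests above silence.
    loss = DQNLoss(actor, delay_value=True)
    updater = SoftUpdate(loss, eps=0.5)  # target <- eps * target + (1 - eps) * online

    # ... in a training loop, refresh the target parameters after each optimizer step:
    updater.step()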