Commit ed656a1

Authored by Vincent Moens
[BugFix] Fix missing min/max alpha clamps in losses (#2684)
1 parent f672c70 commit ed656a1

File tree

4 files changed: +5 -5 lines changed

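All four changes are the same one-line fix: the log-alpha clamp previously ran only when min_log_alpha was set, so a loss configured with only max_log_alpha silently skipped the clamp and alpha could grow past its intended cap. A minimal standalone sketch of the before/after behavior (plain torch, not TorchRL code; note that Tensor.clamp_ accepts None for either bound):

import torch

log_alpha = torch.tensor([5.0])
min_log_alpha, max_log_alpha = None, 2.0  # only an upper bound is configured

# Old guard: skipped entirely because min_log_alpha is None, so
# alpha = exp(5.0) ~ 148.4 escapes the intended cap of exp(2.0) ~ 7.39.
if min_log_alpha is not None:
    log_alpha.data.clamp_(min_log_alpha, max_log_alpha)

# Fixed guard: the clamp runs if either bound is set; clamp_ treats a
# None bound as unconstrained on that side.
if min_log_alpha is not None or max_log_alpha is not None:
    log_alpha.data.clamp_(min_log_alpha, max_log_alpha)

print(log_alpha.exp())  # tensor([7.3891]): alpha now respects max_log_alpha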

torchrl/objectives/cql.py

Lines changed: 1 addition & 1 deletion
@@ -892,7 +892,7 @@ def alpha_loss(self, tensordict: TensorDictBase) -> Tensor:
 
     @property
     def _alpha(self):
-        if self.min_log_alpha is not None:
+        if self.min_log_alpha is not None or self.max_log_alpha is not None:
             self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha)
         alpha = self.log_alpha.data.exp()
         return alpha

torchrl/objectives/crossq.py

Lines changed: 1 addition & 1 deletion
@@ -677,7 +677,7 @@ def alpha_loss(self, log_prob: Tensor) -> Tensor:
 
     @property
     def _alpha(self):
-        if self.min_log_alpha is not None:
+        if self.min_log_alpha is not None or self.max_log_alpha is not None:
             self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha)
         with torch.no_grad():
             alpha = self.log_alpha.exp()

torchrl/objectives/decision_transformer.py

Lines changed: 1 addition & 1 deletion
@@ -171,7 +171,7 @@ def _forward_value_estimator_keys(self, **kwargs):
 
     @property
    def alpha(self):
-        if self.min_log_alpha is not None:
+        if self.min_log_alpha is not None or self.max_log_alpha is not None:
             self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha)
         with torch.no_grad():
             alpha = self.log_alpha.exp()

torchrl/objectives/sac.py

Lines changed: 2 additions & 2 deletions
@@ -846,7 +846,7 @@ def _alpha_loss(self, log_prob: Tensor) -> Tensor:
 
     @property
     def _alpha(self):
-        if self.min_log_alpha is not None:
+        if self.min_log_alpha is not None or self.max_log_alpha is not None:
             self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha)
         with torch.no_grad():
             alpha = self.log_alpha.exp()
@@ -1374,7 +1374,7 @@ def _alpha_loss(self, log_prob: Tensor) -> Tensor:
 
     @property
     def _alpha(self):
-        if self.min_log_alpha is not None:
+        if self.min_log_alpha is not None or self.max_log_alpha is not None:
             self.log_alpha.data = self.log_alpha.data.clamp(
                 self.min_log_alpha, self.max_log_alpha
             )
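The second sac.py hunk uses the out-of-place clamp and reassigns log_alpha.data rather than clamping in place; the guard fix is the same. A quick standalone check (plain torch, not TorchRL code) that the two spellings agree:

import torch

a = torch.tensor([3.0])
b = a.clone()

a.data.clamp_(None, 2.0)          # in-place, as in cql.py / crossq.py / first sac.py hunk
b.data = b.data.clamp(None, 2.0)  # out-of-place reassignment, as in the second sac.py hunk

assert torch.equal(a, b)  # both yield tensor([2.])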
