Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions monai/visualize/class_activation_maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,19 +124,31 @@ def get_layer(self, layer_id: str | Callable[[nn.Module], nn.Module]) -> nn.Modu
return cast(nn.Module, mod)
raise NotImplementedError(f"Could not find {layer_id}.")

def class_score(self, logits: torch.Tensor, class_idx: int) -> torch.Tensor:
return logits[:, class_idx].squeeze()
def class_score(self, logits: torch.Tensor, class_idx: int | torch.Tensor) -> torch.Tensor:
if isinstance(class_idx, int):
return logits[:, class_idx].squeeze()
elif class_idx.numel() == 1:
return logits[:, class_idx.item()]
elif len(class_idx.view(-1)) == logits.shape[0]:
return torch.gather(logits, 1, class_idx.unsqueeze(1)).squeeze(1)
else:
raise ValueError("expect length of class_idx equal to batch size")
Comment on lines +127 to +135
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

❓ Verification inconclusive

Batch indexing: enforce dtype/device; avoid squeeze; tighten validation.

In the current code, the int branch's `.squeeze()` returns a 0-D tensor when the batch size is 1, and the tensor branch does not enforce the int64 (Long) index dtype that `torch.gather` requires. Recommend unifying outputs to shape [B] and validating lengths.

Apply within this hunk:

-    def class_score(self, logits: torch.Tensor, class_idx: int | torch.Tensor) -> torch.Tensor:
-        if isinstance(class_idx, int):
-            return logits[:, class_idx].squeeze()
-        elif class_idx.numel() == 1:
-            return logits[:, class_idx.item()]
-        elif len(class_idx.view(-1)) == logits.shape[0]:
-            return torch.gather(logits, 1, class_idx.unsqueeze(1)).squeeze(1)
-        else:
-            raise ValueError("expect length of class_idx equal to batch size")
+    def class_score(self, logits: torch.Tensor, class_idx: int | torch.Tensor) -> torch.Tensor:
+        if isinstance(class_idx, int):
+            # keep shape [B], avoid scalar squeeze when B==1
+            return logits[:, class_idx]
+        idx = class_idx.to(device=logits.device, dtype=torch.long).view(-1)
+        if idx.numel() == 1:
+            return logits[:, idx.item()]
+        if idx.numel() == logits.shape[0]:
+            return torch.gather(logits, 1, idx.unsqueeze(1)).squeeze(1)
+        raise ValueError("class_idx must be scalar or have length equal to batch size")

Also, GradCAM++ multiplies self.nn_module.score.exp() by grad. With a batched score of shape [B], that product broadcasts against grad's trailing dimension rather than its batch dimension unless score is first reshaped to [B, 1, ...]. See the suggestion below.

To prevent a broadcast error in GradCAM++ for B>1, reshape score before multiplication:

# In GradCAMpp.compute_map(...) before multiplying with grad:
score = cast(torch.Tensor, self.nn_module.score).exp().view(b, *([1] * (grad.dim() - 1)))
relu_grad = F.relu(score * grad)

Please run the GradCAM++ tests with B>1 to confirm no shape/broadcast errors.


Batch indexing: enforce dtype/device; avoid squeeze; tighten validation.

File: monai/visualize/class_activation_maps.py Lines: 127-135

Keep outputs shaped [B], coerce indices to torch.long on logits.device, and validate lengths.

Apply within this hunk:

-    def class_score(self, logits: torch.Tensor, class_idx: int | torch.Tensor) -> torch.Tensor:
-        if isinstance(class_idx, int):
-            return logits[:, class_idx].squeeze()
-        elif class_idx.numel() == 1:
-            return logits[:, class_idx.item()]
-        elif len(class_idx.view(-1)) == logits.shape[0]:
-            return torch.gather(logits, 1, class_idx.unsqueeze(1)).squeeze(1)
-        else:
-            raise ValueError("expect length of class_idx equal to batch size")
+    def class_score(self, logits: torch.Tensor, class_idx: int | torch.Tensor) -> torch.Tensor:
+        if isinstance(class_idx, int):
+            # keep shape [B], avoid scalar squeeze when B==1
+            return logits[:, class_idx]
+        idx = class_idx.to(device=logits.device, dtype=torch.long).view(-1)
+        if idx.numel() == 1:
+            return logits[:, idx.item()]
+        if idx.numel() == logits.shape[0]:
+            return torch.gather(logits, 1, idx.unsqueeze(1)).squeeze(1)
+        raise ValueError("class_idx must be scalar or have length equal to batch size")

Also fix GradCAM++ broadcasting (reshape score to match grad dims), e.g.:

# in GradCAMpp.compute_map(...) before multiplying with grad:
score = cast(torch.Tensor, self.nn_module.score).exp().view(b, *([1] * (grad.dim() - 1)))
relu_grad = F.relu(score * grad)
🧰 Tools
🪛 Ruff (0.12.2)

135-135: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
In monai/visualize/class_activation_maps.py around lines 127 to 135, the
class_score method should return a tensor shaped [B] and validate the length of
class_idx. Keep the Python int branch, but return logits[:, class_idx] without
calling squeeze() so a batch of one does not collapse to a 0-D tensor. For a
tensor class_idx, coerce it via idx = class_idx.to(device=logits.device,
dtype=torch.long).view(-1); if idx.numel() == 1 return logits[:, idx.item()];
if idx.numel() == logits.shape[0] return torch.gather(logits, 1,
idx.unsqueeze(1)).squeeze(1) for a stable [B] output; otherwise raise a
ValueError. Also in GradCAM++ compute_map, reshape the per-sample score to
broadcast across the non-batch dims of grad before multiplying: set score =
cast(torch.Tensor, self.nn_module.score).exp().view(b, *([1] * (grad.dim() - 1)))
and then use relu_grad = F.relu(score * grad) so shapes align for subsequent
reductions.


def __call__(self, x, class_idx=None, retain_graph=False, **kwargs):
train = self.model.training
self.model.eval()
logits = self.model(x, **kwargs)
self.class_idx = logits.max(1)[-1] if class_idx is None else class_idx
if class_idx is None:
self.class_idx = logits.max(1)[-1]
elif isinstance(class_idx, torch.Tensor):
self.class_idx = class_idx.to(logits.device)
else:
self.class_idx = class_idx
acti, grad = None, None
if self.register_forward:
acti = tuple(self.activations[layer] for layer in self.target_layers)
if self.register_backward:
self.score = self.class_score(logits, cast(int, self.class_idx))
self.score = self.class_score(logits, self.class_idx)
self.model.zero_grad()
self.score.sum().backward(retain_graph=retain_graph)
for layer in self.target_layers:
Expand Down
2 changes: 0 additions & 2 deletions monai/visualize/gradient_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,6 @@ def model(self, m):
def get_grad(
self, x: torch.Tensor, index: torch.Tensor | int | None, retain_graph: bool = True, **kwargs: Any
) -> torch.Tensor:
if x.shape[0] != 1:
raise ValueError("expect batch size of 1")
x.requires_grad = True

self._model(x, class_idx=index, retain_graph=retain_graph, **kwargs)
Expand Down
5 changes: 5 additions & 0 deletions tests/integration/test_vis_gradbased.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,19 @@ def __call__(self, x, adjoint_info):
for type in (VanillaGrad, SmoothGrad, GuidedBackpropGrad, GuidedBackpropSmoothGrad):
# 2D densenet
TESTS.append([type, DENSENET2D, (1, 1, 48, 64)])
TESTS.append([type, DENSENET2D, (4, 1, 48, 64)])
# 3D densenet
TESTS.append([type, DENSENET3D, (1, 1, 6, 6, 6)])
TESTS.append([type, DENSENET3D, (2, 1, 6, 6, 6)])
# 2D senet
TESTS.append([type, SENET2D, (1, 3, 64, 64)])
TESTS.append([type, SENET2D, (3, 3, 64, 64)])
# 3D senet
TESTS.append([type, SENET3D, (1, 3, 8, 8, 48)])
TESTS.append([type, SENET3D, (2, 3, 8, 8, 48)])
# 2D densenet - adjoint
TESTS.append([type, DENSENET2DADJOINT, (1, 1, 48, 64)])
TESTS.append([type, DENSENET2DADJOINT, (3, 1, 48, 64)])


class TestGradientClassActivationMap(unittest.TestCase):
Expand Down
Loading