
Conversation

yunhaoli24

Transformerblock

This PR improves the initialization logic of TransformerBlock by adding conditional layer creation for cross-attention.
Previously, the cross-attention components (norm_cross_attn and cross_attn) were always initialized, even when with_cross_attention=False. This introduced unnecessary memory overhead and additional parameters that were never used during forward passes.

With this change, the cross-attention layers are initialized only when with_cross_attention=True.
This results in cleaner module definitions, reduces memory usage, and avoids confusion about unused layers in models that do not require cross-attention.

Fixes # .

Description

Before

...
class TransformerBlock(nn.Module):

    def __init__(
        self,
        ...
    ) -> None:
        ...
        self.norm2 = nn.LayerNorm(hidden_size)
        self.with_cross_attention = with_cross_attention

        self.norm_cross_attn = nn.LayerNorm(hidden_size)
        self.cross_attn = CrossAttentionBlock(
            hidden_size=hidden_size,
            num_heads=num_heads,
            dropout_rate=dropout_rate,
            qkv_bias=qkv_bias,
            causal=False,
            use_flash_attention=use_flash_attention,
        )

After

...
class TransformerBlock(nn.Module):

    def __init__(
        self,
        ...
    ) -> None:
        ...
        self.norm2 = nn.LayerNorm(hidden_size)
        self.with_cross_attention = with_cross_attention

        if with_cross_attention:
            self.norm_cross_attn = nn.LayerNorm(hidden_size)
            self.cross_attn = CrossAttentionBlock(
                hidden_size=hidden_size,
                num_heads=num_heads,
                dropout_rate=dropout_rate,
                qkv_bias=qkv_bias,
                causal=False,
                use_flash_attention=use_flash_attention,
            )
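For reference, the forward path already gates on the same flag, so runtime behavior is unchanged when cross-attention is disabled. Below is a minimal sketch of that guard; the self.attn, self.norm1, and self.mlp names are assumed from the usual pre-norm residual layout, while the cross-attention line matches the review diff further down.

    def forward(self, x, context=None):
        # pre-norm self-attention residual (assumed structure)
        x = x + self.attn(self.norm1(x))
        # cross-attention residual, only taken when the flag is set
        if self.with_cross_attention:
            x = x + self.cross_attn(self.norm_cross_attn(x), context=context)
        # feed-forward residual using the norm2 defined above
        x = x + self.mlp(self.norm2(x))
        return x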

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • Breaking change (fix or new feature that would cause existing functionality to change).
  • New tests added to cover the changes.
  • Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • In-line docstrings updated.
  • Documentation updated, tested make html command in the docs/ folder.


coderabbitai bot commented Aug 22, 2025

Walkthrough

The TransformerBlock now instantiates cross-attention components only when with_cross_attention is True. Specifically, norm_cross_attn and cross_attn are created conditionally; previously they were always initialized. The forward path already checks with_cross_attention, so runtime behavior is unchanged when the flag is False, but those attributes may be absent on instances created with with_cross_attention=False.
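To make the attribute-absence point concrete, here is a small usage sketch; the constructor arguments mirror the reviewer's test snippet below, and the hasattr checks are purely illustrative.

    import torch
    from monai.networks.blocks.transformerblock import TransformerBlock

    blk_sa = TransformerBlock(hidden_size=64, mlp_dim=128, num_heads=4, with_cross_attention=False)
    blk_ca = TransformerBlock(hidden_size=64, mlp_dim=128, num_heads=4, with_cross_attention=True)

    # After this change the self-attention-only block no longer carries the cross-attention submodules.
    print(hasattr(blk_sa, "cross_attn"), hasattr(blk_ca, "cross_attn"))  # expected: False True

    # Forward behavior is unchanged when the flag is False.
    x = torch.randn(2, 16, 64)
    print(blk_sa(x).shape)  # expected: torch.Size([2, 16, 64])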

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes


📜 Recent review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Knowledge Base: Disabled due to Reviews > Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between bad0028 and e71c4f9.

📒 Files selected for processing (1)
  • monai/networks/blocks/transformerblock.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • monai/networks/blocks/transformerblock.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
  • GitHub Check: min-dep-os (windows-latest)
  • GitHub Check: min-dep-pytorch (2.7.1)
  • GitHub Check: min-dep-os (ubuntu-latest)
  • GitHub Check: min-dep-pytorch (2.6.0)
  • GitHub Check: min-dep-os (macOS-latest)
  • GitHub Check: min-dep-pytorch (2.5.1)
  • GitHub Check: min-dep-pytorch (2.8.0)
  • GitHub Check: min-dep-py3 (3.12)
  • GitHub Check: min-dep-py3 (3.9)
  • GitHub Check: min-dep-py3 (3.11)
  • GitHub Check: min-dep-py3 (3.10)
  • GitHub Check: build-docs
  • GitHub Check: flake8-py3 (codeformat)
  • GitHub Check: packaging
  • GitHub Check: quick-py3 (macOS-latest)
  • GitHub Check: flake8-py3 (mypy)
  • GitHub Check: flake8-py3 (pytype)
  • GitHub Check: quick-py3 (windows-latest)
  • GitHub Check: quick-py3 (ubuntu-latest)


coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (3)
monai/networks/blocks/transformerblock.py (3)

82-91: Define optional attrs up-front; guard forward by presence (TorchScript/type-safety).

Conditional attribute creation can trip TorchScript and static analyzers. Always declare the attributes (as Optional) in __init__ and gate forward by presence. This also avoids surprises if someone toggles with_cross_attention post-init.

Apply:

         self.norm2 = nn.LayerNorm(hidden_size)
         self.with_cross_attention = with_cross_attention
-        if with_cross_attention:
-            self.norm_cross_attn = nn.LayerNorm(hidden_size)
-            self.cross_attn = CrossAttentionBlock(
+        # Define optional modules for TorchScript/type-checkers; instantiate only if enabled.
+        self.norm_cross_attn: nn.LayerNorm | None = None
+        self.cross_attn: CrossAttentionBlock | None = None
+        if with_cross_attention:
+            self.norm_cross_attn = nn.LayerNorm(hidden_size)
+            self.cross_attn = CrossAttentionBlock(
                 hidden_size=hidden_size,
                 num_heads=num_heads,
                 dropout_rate=dropout_rate,
                 qkv_bias=qkv_bias,
                 causal=False,
                 use_flash_attention=use_flash_attention,
             )
@@
-        if self.with_cross_attention:
-            x = x + self.cross_attn(self.norm_cross_attn(x), context=context)
+        if self.cross_attn is not None and self.norm_cross_attn is not None:
+            x = x + self.cross_attn(self.norm_cross_attn(x), context=context)

Also applies to: 97-99
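As a self-contained illustration of the suggested pattern (a toy module, not MONAI code): declaring the optional submodule up-front gives every instance the same attribute set, the None check doubles as the runtime guard, and the module stays scriptable. This is the same Optional-submodule idiom used by torchvision's ResNet downsample branch.

    from typing import Optional

    import torch
    import torch.nn as nn

    class ToyBlock(nn.Module):
        def __init__(self, dim: int, with_extra: bool = False) -> None:
            super().__init__()
            # Declared on every instance; only instantiated when enabled.
            self.extra: Optional[nn.Linear] = nn.Linear(dim, dim) if with_extra else None

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            if self.extra is not None:  # presence check doubles as the runtime guard
                x = x + self.extra(x)
            return x

    # Both configurations expose the same attribute surface and script cleanly.
    scripted = torch.jit.script(ToyBlock(8, with_extra=True))
    print(scripted(torch.randn(2, 8)).shape)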


44-55: Docstring: document with_cross_attention (+ missing args).

with_cross_attention is a new behavioral knob—please document it (and the existing but undocumented causal, sequence_length). Keeps API clear.

Apply:

         Args:
             hidden_size (int): dimension of hidden layer.
             mlp_dim (int): dimension of feedforward layer.
             num_heads (int): number of attention heads.
             dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
             qkv_bias(bool, optional): apply bias term for the qkv linear layer. Defaults to False.
             save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
+            causal (bool, optional): if True, apply causal masking for self-attention in this block. Defaults to False.
+            sequence_length (int | None, optional): sequence length hint for attention optimizations. Defaults to None.
+            with_cross_attention (bool, optional): if True, instantiate cross-attention (and its LayerNorm); otherwise
+                they are omitted to save memory/params. Defaults to False.
             use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
                 (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
             include_fc: whether to include the final linear layer. Default to True.
             use_combined_linear: whether to use a single linear layer for qkv projection, default to True.

93-100: Tests: cover both paths and state_dict surface.

Add unit tests to ensure behavior and parameter surface are as intended.

Example (pytest-style):

import torch
from monai.networks.blocks.transformerblock import TransformerBlock

def _param_count(m): return sum(p.numel() for p in m.parameters())

def test_transformerblock_cross_attn_param_surface():
    a = TransformerBlock(64, 128, 4, with_cross_attention=False)
    b = TransformerBlock(64, 128, 4, with_cross_attention=True)
    assert _param_count(b) > _param_count(a)
    sd = a.state_dict()
    assert not any(k.startswith(("cross_attn.", "norm_cross_attn.")) for k in sd)

def test_transformerblock_forward_shapes():
    x = torch.randn(2, 16, 64)
    ctx = torch.randn(2, 12, 64)
    a = TransformerBlock(64, 128, 4, with_cross_attention=False)
    b = TransformerBlock(64, 128, 4, with_cross_attention=True)
    y_a = a(x)
    y_b = b(x, context=ctx)
    assert y_a.shape == x.shape
    assert y_b.shape == x.shape

Happy to open a follow-up PR with tests if helpful.

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Knowledge Base: Disabled due to Reviews > Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between 1e6c661 and bad0028.

📒 Files selected for processing (1)
  • monai/networks/blocks/transformerblock.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definitions and describe each variable, return value, and raised exception in the appropriate section of the Google style of docstrings. Examine code for logical errors or inconsistencies, and suggest what may be changed to address these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • monai/networks/blocks/transformerblock.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
  • GitHub Check: flake8-py3 (pytype)
  • GitHub Check: build-docs
  • GitHub Check: packaging
  • GitHub Check: min-dep-pytorch (2.8.0)
  • GitHub Check: flake8-py3 (mypy)
  • GitHub Check: quick-py3 (macOS-latest)
  • GitHub Check: quick-py3 (ubuntu-latest)
  • GitHub Check: flake8-py3 (codeformat)
  • GitHub Check: min-dep-os (ubuntu-latest)
  • GitHub Check: min-dep-py3 (3.11)
  • GitHub Check: min-dep-py3 (3.10)
  • GitHub Check: quick-py3 (windows-latest)
  • GitHub Check: min-dep-pytorch (2.6.0)
  • GitHub Check: min-dep-py3 (3.12)
  • GitHub Check: min-dep-pytorch (2.5.1)
  • GitHub Check: min-dep-pytorch (2.7.1)
  • GitHub Check: min-dep-os (macOS-latest)
  • GitHub Check: min-dep-os (windows-latest)
  • GitHub Check: min-dep-py3 (3.9)
🔇 Additional comments (1)
monai/networks/blocks/transformerblock.py (1)

82-91: Nice memory/param savings.

Conditionally instantiating cross-attn avoids unused params and reduces memory for the common SA-only case. Clean change.

…nents

Initialize cross attention layers only when with_cross_attention is True to avoid unnecessary computation and memory usage

Signed-off-by: li.yunhao <[email protected]>