fix(transformerblock): conditionally initialize cross attention components #8545
base: dev
Conversation
Walkthrough: The TransformerBlock now instantiates cross-attention components only when `with_cross_attention` is True.

Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~10 minutes
Actionable comments posted: 1
🧹 Nitpick comments (3)
monai/networks/blocks/transformerblock.py (3)
82-91: Define optional attrs up-front; guard forward by presence (TorchScript/type-safety).

Conditional attribute creation can trip TorchScript and static analyzers. Always declare the attrs (as Optional) in `__init__` and gate forward by presence. This also avoids surprises if someone toggles `with_cross_attention` post-init.

Apply:
```diff
         self.norm2 = nn.LayerNorm(hidden_size)
         self.with_cross_attention = with_cross_attention
-        if with_cross_attention:
-            self.norm_cross_attn = nn.LayerNorm(hidden_size)
-            self.cross_attn = CrossAttentionBlock(
+        # Define optional modules for TorchScript/type-checkers; instantiate only if enabled.
+        self.norm_cross_attn: nn.LayerNorm | None = None
+        self.cross_attn: CrossAttentionBlock | None = None
+        if with_cross_attention:
+            self.norm_cross_attn = nn.LayerNorm(hidden_size)
+            self.cross_attn = CrossAttentionBlock(
                 hidden_size=hidden_size,
                 num_heads=num_heads,
                 dropout_rate=dropout_rate,
                 qkv_bias=qkv_bias,
                 causal=False,
                 use_flash_attention=use_flash_attention,
             )
@@
-        if self.with_cross_attention:
-            x = x + self.cross_attn(self.norm_cross_attn(x), context=context)
+        if self.cross_attn is not None and self.norm_cross_attn is not None:
+            x = x + self.cross_attn(self.norm_cross_attn(x), context=context)
```

Also applies to: 97-99
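For context, here is a self-contained sketch of the pattern this suggestion recommends, using a hypothetical `GuardedBlock` rather than the MONAI class: the optional submodule is always declared (possibly `None`), instantiated only on request, and the forward pass is gated by a presence check.

```python
from typing import Optional

import torch
import torch.nn as nn


class GuardedBlock(nn.Module):
    """Hypothetical module illustrating the always-declared optional-submodule pattern."""

    def __init__(self, hidden_size: int, with_extra: bool = False):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)
        # Always declared, so the attribute (and its type) exists for TorchScript/type checkers.
        self.extra: Optional[nn.Linear] = nn.Linear(hidden_size, hidden_size) if with_extra else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x)
        if self.extra is not None:  # presence check instead of re-reading a config flag
            x = x + self.extra(x)
        return x


x = torch.randn(2, 8)
print(GuardedBlock(8, with_extra=False)(x).shape)  # torch.Size([2, 8])
print(GuardedBlock(8, with_extra=True)(x).shape)   # torch.Size([2, 8])
```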
44-55: Docstring: document `with_cross_attention` (+ missing args).

`with_cross_attention` is a new behavioral knob; please document it (and the existing but undocumented `causal` and `sequence_length`). Keeps the API clear.

Apply:
```diff
         Args:
             hidden_size (int): dimension of hidden layer.
             mlp_dim (int): dimension of feedforward layer.
             num_heads (int): number of attention heads.
             dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
             qkv_bias(bool, optional): apply bias term for the qkv linear layer. Defaults to False.
             save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
+            causal (bool, optional): if True, apply causal masking for self-attention in this block. Defaults to False.
+            sequence_length (int | None, optional): sequence length hint for attention optimizations. Defaults to None.
+            with_cross_attention (bool, optional): if True, instantiate cross-attention (and its LayerNorm); otherwise
+                they are omitted to save memory/params. Defaults to False.
             use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
                 (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
             include_fc: whether to include the final linear layer. Default to True.
             use_combined_linear: whether to use a single linear layer for qkv projection, default to True.
```
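As a usage illustration of the arguments above (a sketch only; the keyword names follow this PR's docstring and tests, the values and tensor shapes are arbitrary, and defaults may differ in other MONAI versions):

```python
import torch

from monai.networks.blocks.transformerblock import TransformerBlock

# Self-attention-only block: no cross-attention parameters are created.
sa_block = TransformerBlock(hidden_size=64, mlp_dim=128, num_heads=4, dropout_rate=0.1, qkv_bias=True)

# Block with cross-attention enabled; `context` feeds the cross-attention layer.
ca_block = TransformerBlock(hidden_size=64, mlp_dim=128, num_heads=4, with_cross_attention=True)

x = torch.randn(2, 16, 64)        # (batch, tokens, hidden_size)
context = torch.randn(2, 12, 64)  # cross-attention context of a different length

print(sa_block(x).shape)                   # torch.Size([2, 16, 64])
print(ca_block(x, context=context).shape)  # torch.Size([2, 16, 64])
```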
93-100: Tests: cover both paths and state_dict surface.

Add unit tests to ensure behavior and parameter surface are as intended. Example (pytest-style):
```python
import torch

from monai.networks.blocks.transformerblock import TransformerBlock


def _param_count(m):
    return sum(p.numel() for p in m.parameters())


def test_transformerblock_cross_attn_param_surface():
    a = TransformerBlock(64, 128, 4, with_cross_attention=False)
    b = TransformerBlock(64, 128, 4, with_cross_attention=True)
    assert _param_count(b) > _param_count(a)
    sd = a.state_dict()
    assert not any(k.startswith(("cross_attn.", "norm_cross_attn.")) for k in sd)


def test_transformerblock_forward_shapes():
    x = torch.randn(2, 16, 64)
    ctx = torch.randn(2, 12, 64)
    a = TransformerBlock(64, 128, 4, with_cross_attention=False)
    b = TransformerBlock(64, 128, 4, with_cross_attention=True)
    y_a = a(x)
    y_b = b(x, context=ctx)
    assert y_a.shape == x.shape
    assert y_b.shape == x.shape
```

Happy to open a follow-up PR with tests if helpful.
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Knowledge Base: Disabled due to Reviews > Disable Knowledge Base setting
📒 Files selected for processing (1)
monai/networks/blocks/transformerblock.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides and are sensible and informative with regard to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful with regard to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definitions and describe each variable, return value, and raised exception in the appropriate section of the Google style of docstrings. Examine code for logical errors or inconsistencies, and suggest what may be changed to address these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
monai/networks/blocks/transformerblock.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: build-docs
- GitHub Check: packaging
- GitHub Check: min-dep-pytorch (2.8.0)
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: quick-py3 (ubuntu-latest)
- GitHub Check: flake8-py3 (codeformat)
- GitHub Check: min-dep-os (ubuntu-latest)
- GitHub Check: min-dep-py3 (3.11)
- GitHub Check: min-dep-py3 (3.10)
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: min-dep-pytorch (2.6.0)
- GitHub Check: min-dep-py3 (3.12)
- GitHub Check: min-dep-pytorch (2.5.1)
- GitHub Check: min-dep-pytorch (2.7.1)
- GitHub Check: min-dep-os (macOS-latest)
- GitHub Check: min-dep-os (windows-latest)
- GitHub Check: min-dep-py3 (3.9)
🔇 Additional comments (1)
monai/networks/blocks/transformerblock.py (1)
82-91: Nice memory/param savings.

Conditionally instantiating cross-attn avoids unused params and reduces memory for the common SA-only case. Clean change.
fix(transformerblock): conditionally initialize cross attention components

Initialize cross attention layers only when with_cross_attention is True to avoid unnecessary computation and memory usage.

Signed-off-by: li.yunhao <[email protected]>
TransformerBlock
This PR improves the initialization logic of TransformerBlock by adding conditional layer creation for cross-attention.
Previously, the cross-attention components (norm_cross_attn and cross_attn) were always initialized, even when with_cross_attention=False. This introduced unnecessary memory overhead and additional parameters that were never used during forward passes.
With this change, the cross-attention layers are initialized only when with_cross_attention=True.
This ensures cleaner module definitions, reduced memory usage, and avoids confusion about unused layers in models that do not require cross-attention.
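One quick way to observe the effect described above (a sketch, assuming the constructor arguments used elsewhere in this PR; exact parameter sizes depend on the MONAI version):

```python
import torch

from monai.networks.blocks.transformerblock import TransformerBlock


def param_bytes(m: torch.nn.Module) -> int:
    # Total size of all parameters in bytes.
    return sum(p.numel() * p.element_size() for p in m.parameters())


sa_only = TransformerBlock(hidden_size=64, mlp_dim=128, num_heads=4, with_cross_attention=False)
with_ca = TransformerBlock(hidden_size=64, mlp_dim=128, num_heads=4, with_cross_attention=True)

print(param_bytes(sa_only) < param_bytes(with_ca))  # True: the SA-only block is strictly smaller
# After this change, no cross-attention keys appear in the SA-only state_dict.
print([k for k in sa_only.state_dict() if k.startswith(("cross_attn.", "norm_cross_attn."))])  # []
```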
Fixes # .
Description
Before
After
Types of changes
- `./runtests.sh -f -u --net --coverage`
- `./runtests.sh --quick --unittests --disttests`
- `make html` command in the `docs/` folder.