
Commit 684688a

binliunls, yiheng-wang-nv, and KumoLiu authored
Optimize VISTA3D (#8123)
Fixes #8122.

### Description

As shown in [this PR](Project-MONAI/model-zoo#671), the memory allocation and the mask-embedding for-loop are the bottlenecks behind VISTA3D's slow inference. This PR therefore fixes both: the decoder now clones its input feature map only when the copy is actually needed, and the per-sample for-loop is replaced with a single batched tensor multiplication (see the sketch below).

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

Signed-off-by: binliu <[email protected]>
Co-authored-by: Yiheng Wang <[email protected]>
Co-authored-by: YunLiu <[email protected]>
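To illustrate why the loop was the slow path, here is a minimal, hypothetical timing sketch (the shapes are made up; `emb` and `src` stand in for `class_embedding` and the image feature map). The loop pays Python-level overhead plus `b` separate matmul dispatches and `b` intermediate allocations, while the fused version issues one broadcasted matmul:

```python
# Hypothetical micro-benchmark of the change (assumed shapes, CPU timing).
import time
import torch

b, c, h, w, d, n = 8, 32, 24, 24, 24, 16  # batch, channels, spatial dims, classes
src = torch.randn(b, c, h, w, d)          # stand-in for the image features
emb = torch.randn(n, 1, c)                # stand-in for class_embedding

def loop_version():                       # old: one matmul per batch element
    masks = []
    for i in range(b):
        mask = emb @ src[[i]].view(1, c, h * w * d)
        masks.append(mask.view(-1, 1, h, w, d))
    return torch.cat(masks, 1)

def fused_version():                      # new: single broadcasted matmul
    out = emb.squeeze() @ src.view(b, c, h * w * d)
    return out.view(b, -1, h, w, d).transpose(0, 1)

for fn in (loop_version, fused_version):
    start = time.perf_counter()
    for _ in range(10):
        fn()
    print(fn.__name__, f"{time.perf_counter() - start:.3f}s")
```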
1 parent 052dbb4 commit 684688a

File tree

2 files changed: +6 −6 lines changed


monai/networks/nets/segresnet_ds.py (+3 −1)
```diff
@@ -508,8 +508,10 @@ def forward( # type: ignore

         outputs: list[torch.Tensor] = []
         outputs_auto: list[torch.Tensor] = []
-        x_ = x.clone()
+        x_ = x
         if with_point:
+            if with_label:
+                x_ = x.clone()
             i = 0
             for level in self.up_layers:
                 x = level["upsample"](x)
```
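The segresnet_ds.py change is purely an allocation fix: the old code cloned the full decoder feature map on every call, even when nothing ever read the copy. A minimal sketch of the deferred-clone pattern (hypothetical `run_branches` helper with stand-in branch bodies, not the real `SegResNetDS2.forward`):

```python
import torch

def run_branches(x: torch.Tensor, with_point: bool, with_label: bool) -> torch.Tensor:
    x_ = x                    # plain alias: costs nothing
    if with_point:
        if with_label:
            x_ = x.clone()    # copy only when the label branch must see pristine features
        x.add_(1.0)           # stand-in for the point branch's in-place work on x
    return x_ * 2 if with_label else x  # stand-in for the label branch
```

With `with_label=False`, no second feature map is ever allocated, which is where the memory and latency saving comes from.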

monai/networks/nets/vista3d.py (+3 −5)
```diff
@@ -639,12 +639,10 @@ def forward(self, src: torch.Tensor, class_vector: torch.Tensor):
         if self.use_mlp:
             class_embedding = self.mlp(class_embedding)
         # [b,1,feat] @ [1,feat,dim], batch dimension become class_embedding batch dimension.
-        masks = []
-        for i in range(b):
-            mask = class_embedding @ src[[i]].view(1, c, h * w * d)
-            masks.append(mask.view(-1, 1, h, w, d))
+        masks_embedding = class_embedding.squeeze() @ src.view(b, c, h * w * d)
+        masks_embedding = masks_embedding.view(b, -1, h, w, d).transpose(0, 1)

-        return torch.cat(masks, 1), class_embedding
+        return masks_embedding, class_embedding


 class TwoWayTransformer(nn.Module):
```
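A quick way to convince yourself that the fused multiplication reproduces the removed loop, as a sketch with assumed toy shapes (not part of the commit):

```python
import torch

b, c, h, w, d, n = 3, 4, 5, 5, 5, 7        # batch, channels, spatial dims, classes (assumed)
src = torch.randn(b, c, h, w, d)
class_embedding = torch.randn(n, 1, c)

# old path: one matmul per batch element, concatenated along dim 1
old = torch.cat(
    [(class_embedding @ src[[i]].view(1, c, h * w * d)).view(-1, 1, h, w, d) for i in range(b)],
    dim=1,
)

# new path: single broadcasted matmul, then restore the (class, batch) ordering
new = (class_embedding.squeeze() @ src.view(b, c, h * w * d)).view(b, -1, h, w, d).transpose(0, 1)

assert old.shape == new.shape == (n, b, h, w, d)
assert torch.allclose(old, new, atol=1e-5)
```

The final `transpose(0, 1)` restores the class-first layout that `torch.cat(masks, 1)` used to produce.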
