Commit b9e17e8

dongyang0122, milesial, and pre-commit-ci[bot] authored
Improve GPU utilization of dints network (#6050)
Improve GPU utilization of dints network

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: Alexandre Milesi <[email protected]>
Co-authored-by: milesial <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 3eef61e commit b9e17e8

File tree

1 file changed: +11 -12 lines changed


monai/networks/nets/dints.py

Lines changed: 11 additions & 12 deletions
@@ -40,7 +40,7 @@
 class CellInterface(torch.nn.Module):
     """interface for torchscriptable Cell"""
 
-    def forward(self, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:  # type: ignore
+    def forward(self, x: torch.Tensor, weight) -> torch.Tensor:  # type: ignore
         pass
 

@@ -170,7 +170,7 @@ def __init__(self, c: int, ops: dict, arch_code_c=None):
             if arch_c > 0:
                 self.ops.append(ops[op_name](c))
 
-    def forward(self, x: torch.Tensor, weight: torch.Tensor):
+    def forward(self, x: torch.Tensor, weight: torch.Tensor | None = None):
         """
         Args:
             x: input tensor.
@@ -179,9 +179,10 @@ def forward(self, x: torch.Tensor, weight: torch.Tensor):
             out: weighted average of the operation results.
         """
         out = 0.0
-        weight = weight.to(x)
+        if weight is not None:
+            weight = weight.to(x)
         for idx, _op in enumerate(self.ops):
-            out = out + _op(x) * weight[idx]
+            out = (out + _op(x)) if weight is None else out + _op(x) * weight[idx]
         return out
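The hunk above is the core of the optimization: when no architecture weighting is requested, `MixedOp` skips both the `weight.to(x)` transfer and the per-op scalar multiply. A minimal standalone sketch of the pattern (the `TinyMixedOp` class is a hypothetical stand-in, not MONAI's `MixedOp`):

```python
import torch
from torch import nn


class TinyMixedOp(nn.Module):
    """Hypothetical stand-in for MixedOp: sums candidate op outputs,
    optionally weighted by architecture parameters."""

    def __init__(self) -> None:
        super().__init__()
        self.ops = nn.ModuleList(nn.Conv2d(4, 4, 3, padding=1) for _ in range(3))

    def forward(self, x: torch.Tensor, weight: torch.Tensor | None = None) -> torch.Tensor:
        out = 0.0
        if weight is not None:
            weight = weight.to(x)  # cast/move once, only when weighting is actually used
        for idx, op in enumerate(self.ops):
            # when weight is None, avoid allocating a ones tensor and multiplying by it
            out = (out + op(x)) if weight is None else out + op(x) * weight[idx]
        return out


x = torch.randn(1, 4, 8, 8)
mixed = TinyMixedOp()
y_deploy = mixed(x)                 # deployment path: no weight tensor at all
y_search = mixed(x, torch.rand(3))  # search path: explicit architecture weights
```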

@@ -297,7 +298,7 @@ def __init__(
 
         self.op = MixedOp(c, self.OPS, arch_code_c)
 
-    def forward(self, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, weight: torch.Tensor | None) -> torch.Tensor:
         """
         Args:
             x: input tensor
@@ -669,15 +670,13 @@ def forward(self, x: list[torch.Tensor]) -> list[torch.Tensor]:
             x: input tensor.
         """
         # generate path activation probability
-        inputs, outputs = x, [torch.tensor(0.0).to(x[0])] * self.num_depths
+        inputs = x
         for blk_idx in range(self.num_blocks):
-            outputs = [torch.tensor(0.0).to(x[0])] * self.num_depths
+            outputs = [torch.tensor(0.0, dtype=x[0].dtype, device=x[0].device)] * self.num_depths
             for res_idx, activation in enumerate(self.arch_code_a[blk_idx].data):
                 if activation:
                     mod: CellInterface = self.cell_tree[str((blk_idx, res_idx))]
-                    _out = mod.forward(
-                        x=inputs[self.arch_code2in[res_idx]], weight=torch.ones_like(self.arch_code_c[blk_idx, res_idx])
-                    )
+                    _out = mod.forward(x=inputs[self.arch_code2in[res_idx]], weight=None)
                     outputs[self.arch_code2out[res_idx]] = outputs[self.arch_code2out[res_idx]] + _out
             inputs = outputs
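This hunk replaces `torch.tensor(0.0).to(x[0])`, which builds a CPU scalar and then copies it to the input's device, with a constructor call that allocates the scalar on the target device directly. A small timing sketch (my own illustration, not from the PR) to observe the difference on a CUDA machine:

```python
import torch

if torch.cuda.is_available():
    x = torch.randn(1, device="cuda")
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    for _ in range(1000):
        _ = torch.tensor(0.0).to(x)  # CPU allocation + host-to-device copy each time
    end.record()
    torch.cuda.synchronize()
    print(f"CPU-then-copy: {start.elapsed_time(end):.2f} ms")

    start.record()
    for _ in range(1000):
        _ = torch.tensor(0.0, dtype=x.dtype, device=x.device)  # direct device allocation
    end.record()
    torch.cuda.synchronize()
    print(f"direct on GPU: {start.elapsed_time(end):.2f} ms")
```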

@@ -885,13 +884,13 @@ def get_ram_cost_usage(self, in_size, full: bool = False):
         sizes = []
         for res_idx in range(self.num_depths):
             sizes.append(batch_size * self.filter_nums[res_idx] * (image_size // (2**res_idx)).prod())
-        sizes = torch.tensor(sizes).to(torch.float32).to(self.device) / (2 ** (int(self.use_downsample)))
+        sizes = torch.tensor(sizes, dtype=torch.float32, device=self.device) / (2 ** (int(self.use_downsample)))
         probs_a, arch_code_prob_a = self.get_prob_a(child=False)
         cell_prob = F.softmax(self.log_alpha_c, dim=-1)
         if full:
             arch_code_prob_a = arch_code_prob_a.detach()
             arch_code_prob_a.fill_(1)
-        ram_cost = torch.from_numpy(self.ram_cost).to(torch.float32).to(self.device)
+        ram_cost = torch.from_numpy(self.ram_cost).to(dtype=torch.float32, device=self.device)
         usage = 0.0
         for blk_idx in range(self.num_blocks):
             # node activation for input
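Same idea in `get_ram_cost_usage`: chaining `.to(torch.float32).to(self.device)` materializes an intermediate float32 tensor on the CPU before copying it over, whereas a single `.to(dtype=..., device=...)` casts and moves in one step. A standalone sketch of the equivalence (the `ram_cost` array here is a made-up placeholder for `self.ram_cost`):

```python
import numpy as np
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
ram_cost = np.ones((8,), dtype=np.float64)  # hypothetical cost table

# Old pattern: two .to() calls -> an extra intermediate tensor and an extra copy.
two_step = torch.from_numpy(ram_cost).to(torch.float32).to(device)

# New pattern: one .to() call performs the cast and the device move together.
one_step = torch.from_numpy(ram_cost).to(dtype=torch.float32, device=device)

assert torch.equal(two_step.cpu(), one_step.cpu())
```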

0 commit comments