
Faster R-CNN training OOM at ops/boxes.py _box_inter_union function #7959

Open
@davidgill97


🐛 Describe the bug

Training Faster R-CNN on a large dataset (~1M images at 512x512 resolution) fails with a CUDA OOM error in the RPN.
These are the hyperparameters for the experiment:

rpn_pre_nms_top_n_train = 2000
rpn_pre_nms_top_n_test = 2000
rpn_post_nms_top_n_train = 2000
rpn_post_nms_top_n_test = 2000
rpn_nms_thresh = 0.7
rpn_fg_iou_thresh = 0.7
rpn_bg_iou_thresh = 0.3
rpn_batch_size_per_image = 256
rpn_positive_fraction = 0.5
rpn_score_thresh = 0
box_score_thresh = 0.05
box_nms_thresh = 0.1
box_detections_per_img = 100
box_fg_iou_thresh = 0.5
box_bg_iou_thresh = 0.5
box_batch_size_per_image = 512
box_positive_fraction = 0.25
batch_size = 4

  File "/workspace/entrypoint.py", line 182, in <module>
    main(args)
  File "/workspace/entrypoint.py", line 152, in main
    trainer.fit(
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py", line 532, in fit
    call._call_and_handle_interrupt(
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/call.py", line 43, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py", line 571, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py", line 980, in _run
    results = self._run_stage()
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py", line 1023, in _run_stage
    self.fit_loop.run()
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/fit_loop.py", line 202, in run
    self.advance()
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/fit_loop.py", line 355, in advance
    self.epoch_loop.run(self._data_fetcher)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/training_epoch_loop.py", line 133, in run
    self.advance(data_fetcher)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/training_epoch_loop.py", line 219, in advance
    batch_output = self.automatic_optimization.run(trainer.optimizers[0], kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/optimization/automatic.py", line 188, in run
    self._optimizer_step(kwargs.get("batch_idx", 0), closure)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/optimization/automatic.py", line 266, in _optimizer_step
    call._call_lightning_module_hook(
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/call.py", line 146, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/core/module.py", line 1270, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/core/optimizer.py", line 161, in step
    step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/strategies/strategy.py", line 231, in optimizer_step
    return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 116, in optimizer_step
    return optimizer.step(closure=closure, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/optim/lr_scheduler.py", line 69, in wrapper
    return wrapped(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/optim/optimizer.py", line 280, in wrapper
    out = func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/optim/optimizer.py", line 33, in _use_grad
    ret = func(self, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/optim/adam.py", line 121, in step
    loss = closure()
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 103, in _wrap_closure
    closure_result = closure()
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/optimization/automatic.py", line 142, in __call__
    self._result = self.closure(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/optimization/automatic.py", line 128, in closure
    step_output = self._step_fn()
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/optimization/automatic.py", line 315, in _training_step
    training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/call.py", line 294, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/strategies/strategy.py", line 380, in training_step
    return self.model.training_step(*args, **kwargs)
  File "/workspace/lightning_modules.py", line 96, in training_step
    loss_dict = self(images, targets)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/lightning_modules.py", line 90, in forward
    return self.model(images, targets)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1538, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/workspace/models/detection/builder.py", line 183, in forward
    proposals, proposal_losses = self.rpn(images, features, targets)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/models/detection/rpn.py", line 334, in forward
    losses = self._return_loss(anchors, targets, objectness, pred_bbox_deltas)
  File "/workspace/models/detection/rpn.py", line 343, in _return_loss
    labels, matched_gt_boxes = self.assign_targets_to_anchors(anchors, targets)
  File "/workspace/models/detection/rpn.py", line 193, in assign_targets_to_anchors
    match_quality_matrix = self.box_similarity(gt_boxes, anchors_per_image)
  File "/usr/local/lib/python3.10/dist-packages/torchvision/ops/boxes.py", line 271, in box_iou
    inter, union = _box_inter_union(boxes1, boxes2)
  File "/usr/local/lib/python3.10/dist-packages/torchvision/ops/boxes.py", line 250, in _box_inter_union
    union = area1[:, None] + area2 - inter
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 1.42 GiB (GPU 0; 15.99 GiB total capacity; 12.84 GiB already allocated; 273.61 MiB free; 14.23 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
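As a rough sanity check, assume the failing 1.42 GiB allocation is a single float32 [N, M] tensor (N ground-truth boxes × M anchors, as in box_similarity(gt_boxes, anchors_per_image)). The hypothetical helpers below turn that size into an element count and estimate how much the locals of the original vectorized _box_inter_union hold at once. They are back-of-the-envelope arithmetic, not a profile of the actual run.

```python
GIB = 1024 ** 3  # the CUDA error message reports sizes in GiB (2**30 bytes)


def implied_elements(alloc_gib: float, bytes_per_elem: int = 4) -> int:
    """Elements in one tensor whose allocation is `alloc_gib` GiB (float32 by default)."""
    return int(alloc_gib * GIB / bytes_per_elem)


def dense_intermediate_bytes(n: int, m: int, bytes_per_elem: int = 4) -> int:
    """Bytes held by the locals of the original vectorized _box_inter_union.

    lt, rb and wh are [N, M, 2]; inter and union are [N, M]. Temporaries such as
    `rb - lt` and `area1[:, None] + area2` are transient and come on top of this.
    """
    pair = n * m * 2 * bytes_per_elem  # one [N, M, 2] tensor
    single = n * m * bytes_per_elem  # one [N, M] tensor
    return 3 * pair + 2 * single


elements = implied_elements(1.42)
print(f"1.42 GiB of float32 ~= {elements:,} elements (N * M)")
print(f"lt + rb + wh + inter + union ~= {dense_intermediate_bytes(1, elements) / GIB:.2f} GiB")
```

In other words, every [N, M] tensor of this size costs about 1.4 GiB, and the vectorized version keeps roughly eight such units alive at the same time.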

The _box_inter_union function in torchvision/ops/boxes.py appears to consume a large amount of memory when len(boxes1) and len(boxes2) are both large, because it materializes several [N, M] and [N, M, 2] intermediate tensors at once. I altered the code as follows to work around the issue:

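from typing import Tuple

import torch
from torch import Tensor
from torchvision.ops import box_area


def _box_inter_union(boxes1: Tensor, boxes2: Tensor) -> Tuple[Tensor, Tensor]:
    # original
    # area1 = box_area(boxes1)
    # area2 = box_area(boxes2)

    # lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
    # rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

    # wh = _upcast(rb - lt).clamp(min=0)  # [N,M,2]
    # inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]

    # union = area1[:, None] + area2 - inter

    # return inter, union

    # modified: handle one box of boxes1 per iteration, so the [N, M, 2]
    # intermediates (lt, rb, wh) are never materialized in full
    area1 = box_area(boxes1)  # box_area upcasts low-precision inputs internally
    area2 = box_area(boxes2)

    N, M = boxes1.size(0), boxes2.size(0)

    # match the (upcast) area dtype instead of defaulting to float32
    inter = torch.zeros((N, M), dtype=area1.dtype, device=boxes1.device)
    union = torch.zeros((N, M), dtype=area1.dtype, device=boxes1.device)

    for i in range(N):
        lt = torch.max(boxes1[i, :2], boxes2[:, :2])  # [M, 2]
        rb = torch.min(boxes1[i, 2:], boxes2[:, 2:])  # [M, 2]
        # cast before multiplying, as the original _upcast did, to avoid fp16 overflow
        wh = (rb - lt).clamp(min=0).to(area1.dtype)  # [M, 2]
        curr_inter = wh[:, 0] * wh[:, 1]  # [M]

        inter[i] = curr_inter
        union[i] = area1[i] + area2 - curr_inter

    return inter, union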
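A middle ground between the fully vectorized original and the one-row-at-a-time loop above is to process boxes1 in chunks: the [chunk, M, 2] intermediates stay bounded while the number of kernel launches drops by roughly a factor of chunk_size. The sketch below is a swapped-in technique, not torchvision's API; box_iou_chunked and its chunk_size default of 256 are hypothetical. It returns the IoU matrix directly (so it would stand in for box_iou rather than _box_inter_union), and the output is still a dense [N, M] tensor because the RPN matcher needs the full match-quality matrix.

```python
import torch
from torch import Tensor


def box_iou_chunked(boxes1: Tensor, boxes2: Tensor, chunk_size: int = 256) -> Tensor:
    """IoU between boxes1 [N, 4] and boxes2 [M, 4], both in (x1, y1, x2, y2) format.

    Only `chunk_size` rows of boxes1 are expanded against boxes2 at a time, so the
    largest intermediates are [chunk_size, M, 2] instead of [N, M, 2].
    """
    # compute in at least float32 so fp16 products cannot overflow
    dtype = torch.promote_types(torch.promote_types(boxes1.dtype, boxes2.dtype), torch.float32)
    boxes1 = boxes1.to(dtype)
    boxes2 = boxes2.to(dtype)

    area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])  # [N]
    area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])  # [M]

    out = torch.empty((boxes1.size(0), boxes2.size(0)), dtype=dtype, device=boxes1.device)
    for start in range(0, boxes1.size(0), chunk_size):
        stop = start + chunk_size
        b1 = boxes1[start:stop]  # [C, 4]
        lt = torch.max(b1[:, None, :2], boxes2[:, :2])  # [C, M, 2]
        rb = torch.min(b1[:, None, 2:], boxes2[:, 2:])  # [C, M, 2]
        wh = (rb - lt).clamp(min=0)  # [C, M, 2]
        inter = wh[..., 0] * wh[..., 1]  # [C, M]
        union = area1[start:stop, None] + area2 - inter  # [C, M]
        out[start:stop] = inter / union
    return out
```

Hypothetical usage inside the RPN would be box_iou_chunked(gt_boxes, anchors_per_image) in place of self.box_similarity(gt_boxes, anchors_per_image). Larger chunks trade memory for speed.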

Below is the GPU memory usage before and after the modification.
[Figure: GPU memory usage before modification]
[Figure: GPU memory usage after modification]

Oddly, memory usage even increases after the modification; before the modification, however, GPU memory usage keeps increasing over time. This is clearly not the expected behaviour. Can anyone help me figure out what is going on?
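One way to narrow down the steady growth, sketched here as an assumption rather than a confirmed diagnosis, is to log the CUDA caching allocator's statistics every few training steps. If "allocated" keeps climbing, something is keeping tensors alive across steps (one common cause is accumulating loss tensors that still carry their autograd graph instead of calling .item() or .detach()). If only "reserved" climbs while "allocated" stays flat, the allocator is holding freed blocks, which is the fragmentation case that the max_split_size_mb hint in the error message targets. The helper name cuda_memory_report and the value 128 are hypothetical choices.

```python
import os

# Must be set before CUDA is initialized; 128 MiB is an arbitrary example value.
# setdefault leaves any value already exported in the environment untouched.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:128")

import torch

MIB = 1024 ** 2


def cuda_memory_report(device=None) -> dict:
    """Current allocator statistics in MiB, or {} when CUDA is unavailable."""
    if not torch.cuda.is_available():
        return {}
    return {
        "allocated_mib": torch.cuda.memory_allocated(device) / MIB,
        "reserved_mib": torch.cuda.memory_reserved(device) / MIB,
        "peak_allocated_mib": torch.cuda.max_memory_allocated(device) / MIB,
    }
```

A possible usage is to call cuda_memory_report() at the end of training_step every N batches, log the result, and then call torch.cuda.reset_peak_memory_stats() so that each logged peak covers only the most recent window.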

Versions

[pip3] numpy==1.25.2
[pip3] pytorch-lightning==2.0.8
[pip3] torch==2.0.1
[pip3] torchinfo==1.8.0
[pip3] torchmetrics==1.0.2
[pip3] torchvision==0.15.2
[pip3] triton==2.0.0
