Description
🐛 Describe the bug
Training Faster R-CNN on a large dataset (~1M images at 512x512 resolution) fails with a CUDA out-of-memory error inside the RPN.
These are the hyperparameters for the experiment (a rough sketch of how they map onto the model constructor is shown after the list):
rpn_pre_nms_top_n_train = 2000
rpn_pre_nms_top_n_test = 2000
rpn_post_nms_top_n_train = 2000
rpn_post_nms_top_n_test = 2000
rpn_nms_thresh = 0.7
rpn_fg_iou_thresh = 0.7
rpn_bg_iou_thresh = 0.3
rpn_batch_size_per_image = 256
rpn_positive_fraction = 0.5
rpn_score_thresh = 0
box_score_thresh = 0.05
box_nms_thresh = 0.1
box_detections_per_img = 100
box_fg_iou_thresh = 0.5
box_bg_iou_thresh = 0.5
box_batch_size_per_image = 512
box_positive_fraction = 0.25
batch_size = 4
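For reference, these keyword names line up with torchvision's FasterRCNN constructor, so the setup is roughly equivalent to the sketch below. The project actually uses its own builder (/workspace/models/detection/builder.py), and num_classes here is a placeholder, so treat this as an approximation rather than the exact training code; batch_size = 4 belongs to the DataLoader, not the model.

# Rough sketch only: the run uses a custom builder, so this torchvision
# constructor call is an approximation of the configuration listed above.
from torchvision.models.detection import fasterrcnn_resnet50_fpn

model = fasterrcnn_resnet50_fpn(
    weights=None,
    weights_backbone=None,
    num_classes=2,  # placeholder; the real number of classes is not stated here
    rpn_pre_nms_top_n_train=2000,
    rpn_pre_nms_top_n_test=2000,
    rpn_post_nms_top_n_train=2000,
    rpn_post_nms_top_n_test=2000,
    rpn_nms_thresh=0.7,
    rpn_fg_iou_thresh=0.7,
    rpn_bg_iou_thresh=0.3,
    rpn_batch_size_per_image=256,
    rpn_positive_fraction=0.5,
    rpn_score_thresh=0.0,
    box_score_thresh=0.05,
    box_nms_thresh=0.1,
    box_detections_per_img=100,
    box_fg_iou_thresh=0.5,
    box_bg_iou_thresh=0.5,
    box_batch_size_per_image=512,
    box_positive_fraction=0.25,
)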
File "/workspace/entrypoint.py", line 182, in <module>
main(args)
File "/workspace/entrypoint.py", line 152, in main
trainer.fit(
File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py", line 532, in fit
call._call_and_handle_interrupt(
File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/call.py", line 43, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py", line 571, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py", line 980, in _run
results = self._run_stage()
File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py", line 1023, in _run_stage
self.fit_loop.run()
File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/fit_loop.py", line 202, in run
self.advance()
File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/fit_loop.py", line 355, in advance
self.epoch_loop.run(self._data_fetcher)
File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/training_epoch_loop.py", line 133, in run
self.advance(data_fetcher)
File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/training_epoch_loop.py", line 219, in advance
batch_output = self.automatic_optimization.run(trainer.optimizers[0], kwargs)
File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/optimization/automatic.py", line 188, in run
self._optimizer_step(kwargs.get("batch_idx", 0), closure)
File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/optimization/automatic.py", line 266, in _optimizer_step
call._call_lightning_module_hook(
File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/call.py", line 146, in _call_lightning_module_hook
output = fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/core/module.py", line 1270, in optimizer_step
optimizer.step(closure=optimizer_closure)
File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/core/optimizer.py", line 161, in step
step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/strategies/strategy.py", line 231, in optimizer_step
return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 116, in optimizer_step
return optimizer.step(closure=closure, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/optim/lr_scheduler.py", line 69, in wrapper
return wrapped(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/optim/optimizer.py", line 280, in wrapper
out = func(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/optim/optimizer.py", line 33, in _use_grad
ret = func(self, *args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/optim/adam.py", line 121, in step
loss = closure()
File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 103, in _wrap_closure
closure_result = closure()
File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/optimization/automatic.py", line 142, in __call__
self._result = self.closure(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/optimization/automatic.py", line 128, in closure
step_output = self._step_fn()
File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/optimization/automatic.py", line 315, in _training_step
training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/call.py", line 294, in _call_strategy_hook
output = fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/strategies/strategy.py", line 380, in training_step
return self.model.training_step(*args, **kwargs)
File "/workspace/lightning_modules.py", line 96, in training_step
loss_dict = self(images, targets)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/workspace/lightning_modules.py", line 90, in forward
return self.model(images, targets)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1538, in _call_impl
result = forward_call(*args, **kwargs)
File "/workspace/models/detection/builder.py", line 183, in forward
proposals, proposal_losses = self.rpn(images, features, targets)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/workspace/models/detection/rpn.py", line 334, in forward
losses = self._return_loss(anchors, targets, objectness, pred_bbox_deltas)
File "/workspace/models/detection/rpn.py", line 343, in _return_loss
labels, matched_gt_boxes = self.assign_targets_to_anchors(anchors, targets)
File "/workspace/models/detection/rpn.py", line 193, in assign_targets_to_anchors
match_quality_matrix = self.box_similarity(gt_boxes, anchors_per_image)
File "/usr/local/lib/python3.10/dist-packages/torchvision/ops/boxes.py", line 271, in box_iou
inter, union = _box_inter_union(boxes1, boxes2)
File "/usr/local/lib/python3.10/dist-packages/torchvision/ops/boxes.py", line 250, in _box_inter_union
union = area1[:, None] + area2 - inter
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 1.42 GiB (GPU 0; 15.99 GiB total capacity; 12.84 GiB already allocated; 273.61 MiB free; 14.23 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
The _box_inter_union function in torchvision/ops/boxes.py seems to consume a large amount of memory in its tensor operations when len(boxes1) and len(boxes2) are both large, because the broadcasted intermediates (lt, rb, wh, inter, union) are all [N, M]-shaped. I altered the code as follows to work around the issue:
def _box_inter_union(boxes1: Tensor, boxes2: Tensor) -> Tuple[Tensor, Tensor]:
    # Drop-in replacement inside torchvision/ops/boxes.py
    # (torch, Tensor, Tuple and box_area are already imported in that file).
    # original
    # area1 = box_area(boxes1)
    # area2 = box_area(boxes2)
    # lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
    # rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]
    # wh = _upcast(rb - lt).clamp(min=0)  # [N,M,2]
    # inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]
    # union = area1[:, None] + area2 - inter
    # return inter, union
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)
    N, M = boxes1.size(0), boxes2.size(0)
    inter = torch.zeros((N, M), device=boxes1.device)
    union = torch.zeros((N, M), device=boxes1.device)
    # Loop over boxes1 one row at a time so the intermediates are only [M, 2]
    # instead of the original [N, M, 2] broadcast.
    for i in range(N):
        lt = torch.max(boxes1[i, :2], boxes2[:, :2])  # [M, 2]
        rb = torch.min(boxes1[i, 2:], boxes2[:, 2:])  # [M, 2]
        wh = (rb - lt).clamp(min=0)  # [M, 2]
        curr_inter = wh[:, 0] * wh[:, 1]  # [M]
        inter[i] = curr_inter
        union[i] = area1[i] + area2 - curr_inter
    return inter, union
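For anyone who wants to try this without editing the installed package, roughly the same idea can be applied as a monkey patch at startup. The sketch below is not my exact change: it is a chunked variant that keeps the broadcasted intermediates small by processing boxes1 in blocks (chunk_size = 256 is an arbitrary choice), and it relies on box_iou resolving _box_inter_union through the module namespace at call time.

from typing import Tuple

import torch
from torch import Tensor
import torchvision.ops.boxes as box_ops


def _box_inter_union_chunked(boxes1: Tensor, boxes2: Tensor, chunk_size: int = 256) -> Tuple[Tensor, Tensor]:
    area1 = box_ops.box_area(boxes1)
    area2 = box_ops.box_area(boxes2)
    inter = torch.empty((boxes1.size(0), boxes2.size(0)), dtype=area1.dtype, device=boxes1.device)
    union = torch.empty_like(inter)
    for start in range(0, boxes1.size(0), chunk_size):
        end = start + chunk_size
        b1 = boxes1[start:end]
        # Same broadcasted computation as the stock implementation, but only over
        # a [chunk, M, 2] slice at a time instead of the full [N, M, 2].
        lt = torch.max(b1[:, None, :2], boxes2[:, :2])
        rb = torch.min(b1[:, None, 2:], boxes2[:, 2:])
        wh = (rb - lt).clamp(min=0)
        inter[start:end] = wh[..., 0] * wh[..., 1]
        union[start:end] = area1[start:end, None] + area2 - inter[start:end]
    return inter, union


box_ops._box_inter_union = _box_inter_union_chunked  # rebind before the model is built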
Below is the GPU memory usage before and after the modification.
It is very odd that the memory usage still increases even after the modification, but before the modification the GPU memory usage kept increasing continuously. This is clearly not the expected behaviour. Can anyone help me figure out what is going on?
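For anyone reproducing this, per-step memory can be tracked with something like the sketch below (where exactly it is called, e.g. from on_train_batch_end in the LightningModule, is an assumption rather than my actual code):

import torch

def log_gpu_memory(step: int) -> None:
    # Report what is currently allocated, what the caching allocator has reserved,
    # and the peak allocation since the last reset, all in MiB.
    allocated = torch.cuda.memory_allocated() / 2**20
    reserved = torch.cuda.memory_reserved() / 2**20
    peak = torch.cuda.max_memory_allocated() / 2**20
    print(f"step {step}: allocated={allocated:.0f} MiB, reserved={reserved:.0f} MiB, peak={peak:.0f} MiB")
    torch.cuda.reset_peak_memory_stats()  # so the peak is tracked per step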
Versions
[pip3] numpy==1.25.2
[pip3] pytorch-lightning==2.0.8
[pip3] torch==2.0.1
[pip3] torchinfo==1.8.0
[pip3] torchmetrics==1.0.2
[pip3] torchvision==0.15.2
[pip3] triton==2.0.0