mmengine/optim/optimizer/amp_optimizer_wrapper.py (4 changes: 3 additions & 1 deletion)
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from contextlib import contextmanager
+from functools import partial
from typing import Union

import torch
@@ -17,7 +18,8 @@
elif is_mlu_available():
    from torch.mlu.amp import GradScaler
else:
-    from torch.cuda.amp import GradScaler
+    from torch.amp import GradScaler as amp_GradScaler
+    GradScaler = partial(amp_GradScaler, device='cuda')


Comment on lines +22 to 24
Copilot AI Oct 26, 2025

[nitpick] Creating a module-level variable GradScaler through partial assignment makes the code less maintainable and harder to understand. Consider either: (1) using amp_GradScaler('cuda', ...) directly at call sites, or (2) creating a proper wrapper function with a docstring explaining the device binding.

Suggested change
-GradScaler = partial(amp_GradScaler, device='cuda')
+def get_grad_scaler(*args, **kwargs):
+    """Create a torch.amp.GradScaler instance bound to device='cuda'.
+    Args:
+        *args: Positional arguments passed to torch.amp.GradScaler.
+        **kwargs: Keyword arguments passed to torch.amp.GradScaler.
+    Returns:
+        amp_GradScaler: An instance of torch.amp.GradScaler with device='cuda'.
+    """
+    return amp_GradScaler(*args, device='cuda', **kwargs)

@OPTIM_WRAPPERS.register_module()
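Note: a minimal sketch of the pattern introduced above, assuming a recent PyTorch release in which torch.amp.GradScaler accepts a device argument and the CUDA-specific torch.cuda.amp.GradScaler is deprecated in its favor. Binding device='cuda' with functools.partial keeps existing call sites that construct GradScaler(...) working unchanged:

from functools import partial

from torch.amp import GradScaler as amp_GradScaler

# Bind device='cuda' once; call sites keep constructing GradScaler(...) exactly
# as they did with torch.cuda.amp.GradScaler.
GradScaler = partial(amp_GradScaler, device='cuda')

scaler = GradScaler(init_scale=512.0, growth_factor=2.0)
# partial() returns a plain callable, but calling it still produces a real
# torch.amp.GradScaler instance:
assert isinstance(scaler, amp_GradScaler)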
tests/test_optim/test_optimizer/test_optimizer_wrapper.py (11 changes: 7 additions & 4 deletions)
@@ -1,5 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
+from functools import partial
+
import unittest
from unittest import TestCase
from unittest.mock import MagicMock
Comment on lines +3 to 7
Copilot AI Oct 26, 2025

The blank line after from functools import partial creates inconsistent import grouping. Move the functools import to be with other standard library imports (os, unittest) before the blank line that separates standard library imports from third-party imports.

Suggested change
-from functools import partial
-
import unittest
from unittest import TestCase
from unittest.mock import MagicMock
+from functools import partial

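As a rough sketch of the grouping the comment asks for (assuming plain PEP 8 grouping; the exact ordering inside each group is governed by the project's isort configuration, which is not shown here):

# Standard library imports form one group with no internal blank line:
import os
import unittest
from functools import partial
from unittest import TestCase
from unittest.mock import MagicMock

# A single blank line then separates them from third-party imports:
import torch.distributed as torch_dist
import torch.nn as nn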
@@ -8,7 +10,8 @@
import torch.distributed as torch_dist
import torch.nn as nn
from parameterized import parameterized
-from torch.cuda.amp import GradScaler
+from torch.amp import GradScaler as amp_GradScaler
+GradScaler = partial(amp_GradScaler, device='cuda')
Copilot AI Oct 26, 2025

[nitpick] Creating a module-level variable GradScaler through partial assignment makes the code less maintainable and harder to understand. Consider either: (1) using amp_GradScaler('cuda', ...) directly at call sites, or (2) creating a proper wrapper function with a docstring explaining the device binding.

Suggested change
-GradScaler = partial(amp_GradScaler, device='cuda')
+def get_cuda_grad_scaler(*args, **kwargs):
+    """Return a torch.amp.GradScaler instance bound to the 'cuda' device.
+    Args:
+        *args: Positional arguments for torch.amp.GradScaler.
+        **kwargs: Keyword arguments for torch.amp.GradScaler.
+    Returns:
+        amp_GradScaler: An instance of GradScaler with device='cuda'.
+    """
+    return amp_GradScaler(*args, device='cuda', **kwargs)

from torch.nn.parallel.distributed import DistributedDataParallel
from torch.optim import SGD, Adam, Optimizer

@@ -423,13 +426,13 @@ def setUp(self) -> None:
    def test_init(self):
        # Test with default arguments.
        amp_optim_wrapper = AmpOptimWrapper(optimizer=self.optimizer)
-        self.assertIsInstance(amp_optim_wrapper.loss_scaler, GradScaler)
+        self.assertIsInstance(amp_optim_wrapper.loss_scaler, amp_GradScaler)

        # Test with dynamic.
        amp_optim_wrapper = AmpOptimWrapper(
            'dynamic', optimizer=self.optimizer)
        self.assertIsNone(amp_optim_wrapper._scale_update_param)
-        self.assertIsInstance(amp_optim_wrapper.loss_scaler, GradScaler)
+        self.assertIsInstance(amp_optim_wrapper.loss_scaler, amp_GradScaler)

        # Test with dtype float16
        amp_optim_wrapper = AmpOptimWrapper(
@@ -444,7 +447,7 @@
        # Test with dict loss_scale.
        amp_optim_wrapper = AmpOptimWrapper(
            dict(init_scale=1, growth_factor=2), optimizer=self.optimizer)
-        self.assertIsInstance(amp_optim_wrapper.loss_scaler, GradScaler)
+        self.assertIsInstance(amp_optim_wrapper.loss_scaler, amp_GradScaler)
        self.assertIsNone(amp_optim_wrapper._scale_update_param)
        with self.assertRaisesRegex(TypeError,
                                    'loss_scale must be of type float'):
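A note on why these assertions now target amp_GradScaler: after the import change, the module-level GradScaler in this test file is a functools.partial object rather than a class, and isinstance() does not accept a partial as its classinfo argument. A small sketch of the distinction, using the same names as the imports above:

from functools import partial

from torch.amp import GradScaler as amp_GradScaler

GradScaler = partial(amp_GradScaler, device='cuda')
scaler = GradScaler()

isinstance(scaler, amp_GradScaler)  # True: scaler is a real GradScaler instance
# isinstance(scaler, GradScaler)    # TypeError: the second argument must be a
#                                   # type, and GradScaler is now a functools.partial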