Commit d352b4c

update deepspeed
1 parent b89ca6e commit d352b4c

3 files changed: +18 -5 lines changed

src/lightning/fabric/accelerators/accelerator.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -49,7 +49,7 @@ def get_parallel_devices(devices: Any) -> Any:
     @staticmethod
     @abstractmethod
     def get_device_type() -> Any:
-        """Get the device for the current Accelerator."""
+        """Get the device_type for the current Accelerator."""
 
     @staticmethod
     @abstractmethod
```
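For reference, a concrete Accelerator subclass is expected to return the name of a torch backend namespace from this hook. A minimal, hypothetical sketch (the `DemoCUDAAccelerator` class below is illustrative only, not Lightning's actual implementation):

```python
from abc import ABC, abstractmethod
from typing import Any


class Accelerator(ABC):
    """Trimmed stand-in for lightning.fabric.accelerators.Accelerator."""

    @staticmethod
    @abstractmethod
    def get_device_type() -> Any:
        """Get the device_type for the current Accelerator."""


class DemoCUDAAccelerator(Accelerator):
    """Hypothetical concrete accelerator, shown only to illustrate the contract."""

    @staticmethod
    def get_device_type() -> str:
        # The returned string is expected to name a torch namespace,
        # e.g. torch.cuda, so it can be resolved via getattr(torch, ...).
        return "cuda"


assert DemoCUDAAccelerator.get_device_type() == "cuda"
```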

src/lightning/fabric/strategies/deepspeed.py

Lines changed: 8 additions & 4 deletions

```diff
@@ -299,6 +299,12 @@ def __init__(
 
         self._deepspeed_engine: Optional[DeepSpeedEngine] = None
 
+        if isinstance(self.accelerator, Accelerator):
+            self.device_type = self.accelerator.get_device_type()
+        else:
+            self.device_type = "cuda"
+        self.torch_lib = getattr(torch, self.device_type)
+
     @property
     def zero_stage_3(self) -> bool:
         assert isinstance(self.config, dict)
@@ -511,10 +517,8 @@ def load_checkpoint(
 
         optimzer_state_requested = any(isinstance(item, (Optimizer, DeepSpeedOptimizer)) for item in state.values())
 
-        if isinstance(self.accelerator, Accelerator) and self.accelerator.get_device_type() != "cpu":
-            getattr(torch, self.root_device.type).empty_cache()
-        else:
-            torch.cuda.empty_cache()
+        if hasattr(torch, self.device_type) and callable(self.torch_lib.empty_cache):
+            self.torch_lib.empty_cache()
 
         _, client_state = engine.load_checkpoint(
             path,
```
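The fabric-side change swaps the hard-coded `torch.cuda.empty_cache()` for a lookup through the accelerator-reported backend name, so non-CUDA backends can clear their caching allocators too. A standalone sketch of that dispatch pattern, using only public `torch` APIs (the `empty_device_cache` helper name is invented for illustration):

```python
import torch


def empty_device_cache(device_type: str = "cuda") -> None:
    """Clear the caching allocator of the named backend, if it exposes one.

    Backends such as torch.cuda, torch.mps, and torch.xpu provide an
    empty_cache() function; a backend without one simply falls through.
    """
    torch_lib = getattr(torch, device_type, None)
    if torch_lib is not None and callable(getattr(torch_lib, "empty_cache", None)):
        torch_lib.empty_cache()


# torch.cuda.empty_cache() is a no-op while CUDA is uninitialized,
# so this is safe to call even on a CPU-only machine.
empty_device_cache("cuda")
```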

src/lightning/pytorch/strategies/deepspeed.py

Lines changed: 9 additions & 0 deletions

```diff
@@ -319,6 +319,12 @@ def __init__(
         self.hysteresis = hysteresis
         self.min_loss_scale = min_loss_scale
 
+        try:
+            self.device_type = self.accelerator.get_device_type()
+        except Exception:
+            self.device_type = "cuda"
+        self.torch_lib = getattr(torch, self.device_type)
+
     @override
     def setup_environment(self) -> None:
         from deepspeed.runtime.utils import get_accelerator
@@ -672,6 +678,9 @@ def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]:
 
         is_fitting = self.lightning_module.trainer.state.fn == TrainerFn.FITTING
 
+        if hasattr(torch, self.device_type) and callable(self.torch_lib.empty_cache):
+            self.torch_lib.empty_cache()
+
         _, client_state = self.deepspeed_engine.load_checkpoint(
             checkpoint_path,
             load_optimizer_states=is_fitting,
```
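On the Trainer-side strategy the backend handle is resolved once in `__init__`, with a broad `try/except` falling back to the historical `"cuda"` default when the accelerator is missing or does not implement `get_device_type()`. A hedged sketch of just that resolution step (the `resolve_torch_backend` helper is invented for illustration):

```python
from types import ModuleType
from typing import Optional, Tuple

import torch


def resolve_torch_backend(accelerator: Optional[object]) -> Tuple[str, ModuleType]:
    """Mirror of the __init__ fallback above, as a hypothetical free function."""
    try:
        # Raises AttributeError if accelerator is None or lacks the hook.
        device_type = accelerator.get_device_type()
    except Exception:
        device_type = "cuda"  # keep the pre-existing CUDA default
    return device_type, getattr(torch, device_type)


# Without an accelerator the fallback lands on torch.cuda.
device_type, torch_lib = resolve_torch_backend(None)
assert device_type == "cuda" and torch_lib is torch.cuda
```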
