Skip to content

Commit

Permalink
Fix multiple-model issues
Browse files Browse the repository at this point in the history
  • Loading branch information
muellerzr committed Sep 26, 2024
1 parent f1b7aac commit 1bb554c
Showing 1 changed file with 22 additions and 20 deletions.
42 changes: 22 additions & 20 deletions tests/deepspeed/test_deepspeed_multiple_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
slow,
)
from accelerate.test_utils.training import RegressionDataset
from accelerate.utils import patch_environment
from accelerate.utils.deepspeed import DummyOptim, DummyScheduler, get_active_deepspeed_plugin


Expand Down Expand Up @@ -145,26 +146,27 @@ def test_multiple_accelerators(self):
_ = Accelerator(deepspeed_plugin=ds_zero3)

def test_prepare_multiple_models_zero3_inference(self):
ds_plugins = self.get_ds_plugins(zero3_inference=True)
accelerator = Accelerator(deepspeed_plugin=ds_plugins)
# Using Zero-2 first
model1 = self.model_init()
optimizer = DummyOptim(model1.parameters())
scheduler = DummyScheduler(optimizer)

dataset = RegressionDataset()
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1)
model1, optimizer, scheduler, dataloader = accelerator.prepare(model1, optimizer, scheduler, dataloader)
accelerator.state.select_deepspeed_plugin("zero3")
model2 = self.model_init()
with self.assertLogs(level="WARNING") as captured:
model2 = accelerator.prepare(model2)
self.assertIn(
"A wrapped DeepSpeed engine reference is currently tied for this `Accelerator()` instance.",
captured.output[0],
)

assert accelerator.deepspeed_engine_wrapped.engine is model1
with patch_environment(**self.dist_env):
ds_plugins = self.get_ds_plugins(zero3_inference=True)
accelerator = Accelerator(deepspeed_plugin=ds_plugins)
# Using Zero-2 first
model1 = self.model_init()
optimizer = DummyOptim(model1.parameters())
scheduler = DummyScheduler(optimizer)

dataset = RegressionDataset()
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1)
model1, optimizer, scheduler, dataloader = accelerator.prepare(model1, optimizer, scheduler, dataloader)
accelerator.state.select_deepspeed_plugin("zero3")
model2 = self.model_init()
with self.assertLogs(level="WARNING") as captured:
model2 = accelerator.prepare(model2)
self.assertIn(
"A wrapped DeepSpeed engine reference is currently tied for this `Accelerator()` instance.",
captured.output[0],
)

assert accelerator.deepspeed_engine_wrapped.engine is model1

@require_huggingface_suite
@require_multi_device
Expand Down

0 comments on commit 1bb554c

Please sign in to comment.