This repository was archived by the owner on Nov 3, 2023. It is now read-only.

Support mixed precision with CPU driver #118

Open · wants to merge 45 commits into main

Commits (45)
8432251  add annotations (amogkam, Sep 9, 2021)
685a309  Merge branch 'main' of github.com:ray-project/ray_lightning_accelerat… (amogkam, Sep 23, 2021)
3db5213  Bump pytorch-lightning from 1.4.7 to 1.5.2 (dependabot[bot], Nov 22, 2021)
adec375  update (amogkam, Jan 6, 2022)
0d7f3b2  fix test (amogkam, Jan 6, 2022)
21943d0  update readme (amogkam, Jan 6, 2022)
eb06731  more fixes (amogkam, Jan 6, 2022)
e5dd626  Merge branch 'main' of github.com:ray-project/ray_lightning_accelerat… (amogkam, Jan 6, 2022)
e451f58  Merge branch 'main' of github.com:ray-project/ray_lightning_accelerat… (amogkam, Jan 6, 2022)
58d8da0  move to post_dispatch (amogkam, Jan 19, 2022)
3936220  address comments (amogkam, Jan 19, 2022)
c124650  lint (amogkam, Jan 19, 2022)
84572e3  Merge branch 'main' of github.com:ray-project/ray_lightning_accelerat… (amogkam, Jan 19, 2022)
012a99e  add back (amogkam, Jan 19, 2022)
9f64d1d  fix (amogkam, Jan 19, 2022)
664dcee  Merge branch 'main' of github.com:ray-project/ray_lightning_accelerat… (amogkam, Jan 20, 2022)
6ff7ad3  fix test (amogkam, Jan 20, 2022)
bcdefda  Merge branch 'main' of github.com:ray-project/ray_lightning_accelerat… (amogkam, Jan 21, 2022)
796e041  Merge branch 'main' of github.com:ray-project/ray_lightning_accelerat… (amogkam, Jan 22, 2022)
64089d8  fix test (amogkam, Jan 22, 2022)
9ea707f  Merge branch 'main' of github.com:ray-project/ray_lightning_accelerat… (amogkam, Jan 22, 2022)
6246f4b  fix (amogkam, Jan 22, 2022)
dd20863  updated delayedgpuaccelerator (amogkam, Jan 22, 2022)
d1015e7  fix tests (amogkam, Jan 22, 2022)
a11a0cc  fix tests (amogkam, Jan 22, 2022)
82334b1  fix gpu id (amogkam, Jan 22, 2022)
a8bba48  fix root device (amogkam, Jan 22, 2022)
53df245  unpin (amogkam, Jan 22, 2022)
ed884e1  fix (amogkam, Jan 22, 2022)
a4502de  share devices (amogkam, Jan 22, 2022)
49b75e5  share devices (amogkam, Jan 22, 2022)
5efbf41  horovod delayed accelerator (amogkam, Jan 22, 2022)
bebc8f7  fix horovod root device (amogkam, Jan 22, 2022)
bcd1b4c  1.5-gpu (amogkam, Jan 22, 2022)
b9cd821  update (amogkam, Jan 22, 2022)
2c7f5a0  lint (amogkam, Jan 22, 2022)
f0a2b02  fix (amogkam, Jan 25, 2022)
0c7fd92  4 gpu (amogkam, Jan 25, 2022)
e21016e  Merge branch '1.5-gpu' of github.com:amogkam/ray_lightning_accelerato… (amogkam, Jan 25, 2022)
b486597  wip (amogkam, Jan 25, 2022)
f5a657a  update (amogkam, Jan 27, 2022)
79123a2  Merge branch 'main' of github.com:ray-project/ray_lightning_accelerat… (amogkam, Jan 27, 2022)
0047eb1  lint (amogkam, Jan 27, 2022)
31ccc44  fix (amogkam, Jan 27, 2022)
fef9132  fix (amogkam, Jan 27, 2022)
Files changed
ray_lightning/ray_ddp.py (20 changes: 2 additions & 18 deletions)

@@ -9,7 +9,6 @@
 import torch

 import pytorch_lightning as pl
-from pytorch_lightning.accelerators import CPUAccelerator
 from pytorch_lightning.plugins import DDPSpawnPlugin
 from pytorch_lightning import _logger as log, LightningModule
 from pytorch_lightning.trainer.states import TrainerFn
@@ -23,7 +22,7 @@

 from ray_lightning.session import init_session
 from ray_lightning.util import process_results, to_state_stream, \
-    load_state_stream
+    load_state_stream, swap_accelerator
 from ray_lightning.tune import TUNE_INSTALLED, is_session_enabled


@@ -173,22 +172,7 @@ def setup(self):
         ray.get([w.execute.remote(self.init_hook) for w in self.workers])

     def setup_environment(self) -> None:
-        # Swap out the accelerator if necessary.
-        # This is needed to support CPU head with GPU workers or Ray Client.
-        current_accelerator = self.lightning_module.trainer.accelerator
-        if self.use_gpu and isinstance(current_accelerator, CPUAccelerator):
-            from weakref import proxy
-            from ray_lightning.util import DelayedGPUAccelerator
-            precision_plugin = current_accelerator.precision_plugin
-            new_accelerator = DelayedGPUAccelerator(
-                precision_plugin=precision_plugin, training_type_plugin=self)
-            self.lightning_module.trainer._accelerator_connector \
-                ._training_type_plugin = \
-                proxy(new_accelerator.training_type_plugin)
-            self.lightning_module.trainer._accelerator_connector \
-                ._precision_plugin = proxy(new_accelerator.precision_plugin)
-            self.lightning_module.trainer._accelerator_connector.accelerator \
-                = new_accelerator
+        swap_accelerator(self)

     def _setup_env_vars(self):
         # Get rank 0 worker address and port for DDP connection.

Contributor comment on the removed `isinstance(current_accelerator, CPUAccelerator)` check:
I'm not quite following this logic - by removing isinstance(current_accelerator, CPUAccelerator), what scenario does this solve? Wouldn't the problem case (CPUAccelerator) be changed to DelayedGPUAccelerator both before and after this PR?
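For context, a minimal usage sketch of what this change targets, patterned on the test_train_mixed_precision test added in ray_lightning/tests/test_ddp_gpu.py below. Only RayPlugin(num_workers=..., use_gpu=True) and the gpus=1 / precision=16 flags come from this PR's tests; TinyModel and the remaining Trainer arguments are illustrative stand-ins.

# Sketch only: a Trainer built on the driver (which may be a CPU-only head
# node or a Ray Client session) requests 16-bit precision; the accelerator
# swap above defers GPU/AMP setup to the Ray workers that hold the GPUs.
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl
from ray_lightning import RayPlugin


class TinyModel(pl.LightningModule):
    """Minimal stand-in for the BoringModel used in the PR's tests."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch[0]).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)


# import ray; ray.init(address="auto")  # optionally connect to a running cluster
plugin = RayPlugin(num_workers=2, use_gpu=True)  # GPU workers in the Ray cluster
trainer = pl.Trainer(
    plugins=[plugin],
    gpus=1,        # mirrors the test below; the GPUs live on the workers
    precision=16,  # mixed precision, applied once training starts on the workers
    max_epochs=1,
)
trainer.fit(TinyModel())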
ray_lightning/ray_horovod.py (20 changes: 2 additions & 18 deletions)

@@ -1,6 +1,5 @@
 import torch
 import pytorch_lightning as pl
-from pytorch_lightning.accelerators import CPUAccelerator
 from pytorch_lightning.plugins import HorovodPlugin
 from pytorch_lightning.utilities import rank_zero_only

@@ -11,7 +10,7 @@

 from ray_lightning.session import init_session
 from ray_lightning.util import process_results, Unavailable, to_state_stream, \
-    load_state_stream
+    load_state_stream, swap_accelerator
 from ray_lightning.tune import TUNE_INSTALLED, is_session_enabled

 try:
@@ -126,22 +125,7 @@ def setup(self):
         self.executor.start(executable_cls=get_executable_cls())

     def setup_environment(self) -> None:
-        # Swap out the accelerator if necessary.
-        # This is needed to support CPU head with GPU workers or Ray Client.
-        current_accelerator = self.lightning_module.trainer.accelerator
-        if self.use_gpu and isinstance(current_accelerator, CPUAccelerator):
-            from weakref import proxy
-            from ray_lightning.util import DelayedGPUAccelerator
-            precision_plugin = current_accelerator.precision_plugin
-            new_accelerator = DelayedGPUAccelerator(
-                precision_plugin=precision_plugin, training_type_plugin=self)
-            self.lightning_module.trainer._accelerator_connector \
-                ._training_type_plugin = \
-                proxy(new_accelerator.training_type_plugin)
-            self.lightning_module.trainer._accelerator_connector \
-                ._precision_plugin = proxy(new_accelerator.precision_plugin)
-            self.lightning_module.trainer._accelerator_connector.accelerator \
-                = new_accelerator
+        swap_accelerator(self)

     def pre_dispatch(self):
         """All pre-dispatch logic should be done in train_remote instead."""
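The Horovod path gets the same treatment. A hedged sketch of the equivalent driver-side setup, assuming HorovodRayPlugin is exposed at the package top level as in this repository's README; the Trainer flags again mirror the PR's GPU tests:

# Sketch only: identical to the DDP example above except for the plugin class;
# num_workers and use_gpu match the Horovod tests in this PR.
import pytorch_lightning as pl
from ray_lightning import HorovodRayPlugin

plugin = HorovodRayPlugin(num_workers=2, use_gpu=True)
trainer = pl.Trainer(plugins=[plugin], gpus=1, precision=16, max_epochs=1)
# trainer.fit(model)  # any LightningModule, e.g. TinyModel from the sketch above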
ray_lightning/tests/test_ddp.py (1 change: 1 addition & 0 deletions)

@@ -42,6 +42,7 @@ def start_ray_client_server_2_cpus():
     ray.init(num_cpus=2)
     with ray_start_client_server() as client:
         yield client
+    ray.shutdown()


 @pytest.fixture
ray_lightning/tests/test_ddp_gpu.py (12 changes: 12 additions & 0 deletions)

@@ -43,6 +43,18 @@ def test_train(tmpdir, ray_start_2_gpus, num_workers):
     train_test(trainer, model)


+@pytest.mark.skipif(
+    torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+def test_train_mixed_precision(tmpdir, ray_start_2_gpus):
+    """Tests if training works with mixed precision."""
+    model = BoringModel()
+    plugin = RayPlugin(num_workers=2, use_gpu=True)
+    trainer = get_trainer(tmpdir, plugins=[plugin], gpus=1, precision=16)
+    # Make sure PTL doesn't automatically replace with bf16.
+    assert trainer.precision == 16
+    train_test(trainer, model)
+
+
 @pytest.mark.skipif(
     torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
 @pytest.mark.parametrize("num_workers", [1, 2])
ray_lightning/tests/test_tune.py (24 changes: 23 additions & 1 deletion)

@@ -25,17 +25,23 @@ def ray_start_4_cpus_4_gpus():
     ray.shutdown()


-def train_func(dir, plugin, callbacks=None):
+def train_func(dir, plugin, callbacks=None, amp=False):
    def _inner_train(config):
        model = BoringModel()
        trainer = get_trainer(
            dir,
            callbacks=callbacks,
            plugins=[plugin],
            checkpoint_callback=False,
+           gpus=1 if amp else 0,
+           precision=16 if amp else 32,
            **config)
        trainer.fit(model)
+
+       if amp:
+           # Make sure PTL doesn't automatically replace with bf16
+           assert trainer.precision == 16

    return _inner_train

Contributor comment on the new `amp` parameter of `train_func`:
Should there be a test with amp=True?

@@ -104,3 +110,19 @@ def test_checkpoint_horovod_gpu(tmpdir, ray_start_4_cpus_4_gpus):
     """Tests if Tune checkpointing works with HorovodRayAccelerator."""
     plugin = HorovodRayPlugin(num_workers=2, use_gpu=True)
     checkpoint_test(tmpdir, plugin)
+
+
+def tune_test_mixed_precision(dir, plugin):
+    tune.run(
+        train_func(dir, plugin),
+        resources_per_trial=get_tune_resources(
+            num_workers=plugin.num_workers, use_gpu=plugin.use_gpu),
+        num_samples=2)
+
+
+@pytest.mark.skipif(
+    torch.cuda.device_count() < 4, reason="test requires multi-GPU machine")
+def test_tune_mixed_precision_ddp_gpu(tmpdir, ray_start_4_cpus_4_gpus):
+    """Tests if Tune works with mixed precision."""
+    plugin = RayPlugin(num_workers=2, use_gpu=True)
+    tune_test_mixed_precision(tmpdir, plugin)
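Regarding the reviewer's question above ("Should there be a test with amp=True?"): as written, tune_test_mixed_precision calls train_func(dir, plugin) with the default amp=False. A hedged sketch, not part of this diff, of how the helper would presumably exercise the amp path using the same train_func and get_tune_resources helpers from this file:

# Sketch only (not part of the PR): pass amp=True through train_func so the
# gpus=1 / precision=16 branch inside _inner_train is actually taken.
def tune_test_mixed_precision_amp(dir, plugin):
    tune.run(
        train_func(dir, plugin, amp=True),
        resources_per_trial=get_tune_resources(
            num_workers=plugin.num_workers, use_gpu=plugin.use_gpu),
        num_samples=2)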
ray_lightning/util.py (22 changes: 22 additions & 0 deletions)

@@ -6,6 +6,7 @@
 from pytorch_lightning import Trainer

 import ray
+from pytorch_lightning.plugins import TrainingTypePlugin


 class DelayedGPUAccelerator(GPUAccelerator):
@@ -37,6 +38,27 @@ def on_train_start(self) -> None:
         super(DelayedGPUAccelerator, self).on_train_start()


+def swap_accelerator(plugin: TrainingTypePlugin):
+    # Swap out the accelerator if necessary.
+    # This is needed to support CPU head with GPU workers or Ray Client.
+    # This is also needed to support GPU-only optimizations like mixed
+    # precision when using CPU head with GPU workers or Ray Client.
+    current_accelerator = plugin.lightning_module.trainer.accelerator
+
+    if plugin.use_gpu:
+        from weakref import proxy
+        precision_plugin = current_accelerator.precision_plugin
+        new_accelerator = DelayedGPUAccelerator(
+            precision_plugin=precision_plugin, training_type_plugin=plugin)
+        plugin.lightning_module.trainer._accelerator_connector \
+            ._training_type_plugin = \
+            proxy(new_accelerator.training_type_plugin)
+        plugin.lightning_module.trainer._accelerator_connector \
+            ._precision_plugin = proxy(new_accelerator.precision_plugin)
+        plugin.lightning_module.trainer._accelerator_connector.accelerator \
+            = new_accelerator
+
+
 class Unavailable:
     """No object should be instance of this class"""
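For orientation, a hedged sketch of the call pattern this refactor establishes: both Ray plugins now delegate to swap_accelerator from their setup_environment instead of duplicating the weakref wiring, and a future plugin would presumably do the same. The MyRayPlugin class below is purely illustrative and not part of this repository.

# Sketch only: any TrainingTypePlugin in this project that exposes a `use_gpu`
# flag and a `lightning_module` can reuse the helper; swap_accelerator then
# replaces the driver-selected accelerator with a DelayedGPUAccelerator while
# keeping the user's precision plugin (e.g. native AMP for precision=16).
from pytorch_lightning.plugins import DDPSpawnPlugin

from ray_lightning.util import swap_accelerator


class MyRayPlugin(DDPSpawnPlugin):  # hypothetical plugin for illustration
    use_gpu = True

    def setup_environment(self) -> None:
        swap_accelerator(self)
        # ... remaining Ray-specific environment setup would go here ...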