This repository was archived by the owner on Nov 3, 2023. It is now read-only.
Support mixed precision with CPU driver #118
Open
amogkam wants to merge 45 commits into ray-project:main from amogkam:cpu-head-mixed-precision
Changes from all commits (45 commits)
8432251 add annotations (amogkam)
685a309 Merge branch 'main' of github.com:ray-project/ray_lightning_accelerat… (amogkam)
3db5213 Bump pytorch-lightning from 1.4.7 to 1.5.2 (dependabot[bot])
adec375 update (amogkam)
0d7f3b2 fix test (amogkam)
21943d0 update readme (amogkam)
eb06731 more fixes (amogkam)
e5dd626 Merge branch 'main' of github.com:ray-project/ray_lightning_accelerat… (amogkam)
e451f58 Merge branch 'main' of github.com:ray-project/ray_lightning_accelerat… (amogkam)
58d8da0 move to post_dispatch (amogkam)
3936220 address comments (amogkam)
c124650 lint (amogkam)
84572e3 Merge branch 'main' of github.com:ray-project/ray_lightning_accelerat… (amogkam)
012a99e add back (amogkam)
9f64d1d fix (amogkam)
664dcee Merge branch 'main' of github.com:ray-project/ray_lightning_accelerat… (amogkam)
6ff7ad3 fix test (amogkam)
bcdefda Merge branch 'main' of github.com:ray-project/ray_lightning_accelerat… (amogkam)
796e041 Merge branch 'main' of github.com:ray-project/ray_lightning_accelerat… (amogkam)
64089d8 fix test (amogkam)
9ea707f Merge branch 'main' of github.com:ray-project/ray_lightning_accelerat… (amogkam)
6246f4b fix (amogkam)
dd20863 updated delayedgpuaccelerator (amogkam)
d1015e7 fix tests (amogkam)
a11a0cc fix tests (amogkam)
82334b1 fix gpu id (amogkam)
a8bba48 fix root device (amogkam)
53df245 unpin (amogkam)
ed884e1 fix (amogkam)
a4502de share devices (amogkam)
49b75e5 share devices (amogkam)
5efbf41 horovod delayed accelerator (amogkam)
bebc8f7 fix horovod root device (amogkam)
bcd1b4c 1.5-gpu (amogkam)
b9cd821 update (amogkam)
2c7f5a0 lint (amogkam)
f0a2b02 fix (amogkam)
0c7fd92 4 gpu (amogkam)
e21016e Merge branch '1.5-gpu' of github.com:amogkam/ray_lightning_accelerato… (amogkam)
b486597 wip (amogkam)
f5a657a update (amogkam)
79123a2 Merge branch 'main' of github.com:ray-project/ray_lightning_accelerat… (amogkam)
0047eb1 lint (amogkam)
31ccc44 fix (amogkam)
fef9132 fix (amogkam)
```diff
@@ -25,17 +25,23 @@ def ray_start_4_cpus_4_gpus():
     ray.shutdown()
 
 
-def train_func(dir, plugin, callbacks=None):
+def train_func(dir, plugin, callbacks=None, amp=False):
     def _inner_train(config):
         model = BoringModel()
         trainer = get_trainer(
             dir,
             callbacks=callbacks,
             plugins=[plugin],
             checkpoint_callback=False,
+            gpus=1 if amp else 0,
+            precision=16 if amp else 32,
             **config)
         trainer.fit(model)
 
+        if amp:
+            # Make sure PTL doesn't automatically replace with bf16
+            assert trainer.precision == 16
+
     return _inner_train
```

Review comment on the new `train_func` signature: "Should there be a test with …"

```diff
@@ -104,3 +110,19 @@ def test_checkpoint_horovod_gpu(tmpdir, ray_start_4_cpus_4_gpus):
     """Tests if Tune checkpointing works with HorovodRayAccelerator."""
     plugin = HorovodRayPlugin(num_workers=2, use_gpu=True)
     checkpoint_test(tmpdir, plugin)
+
+
+def tune_test_mixed_precision(dir, plugin):
+    tune.run(
+        train_func(dir, plugin),
+        resources_per_trial=get_tune_resources(
+            num_workers=plugin.num_workers, use_gpu=plugin.use_gpu),
+        num_samples=2)
+
+
+@pytest.mark.skipif(
+    torch.cuda.device_count() < 4, reason="test requires multi-GPU machine")
+def test_tune_mixed_precision_ddp_gpu(tmpdir, ray_start_4_cpus_4_gpus):
+    """Tests if Tune works with mixed precision."""
+    plugin = RayPlugin(num_workers=2, use_gpu=True)
+    tune_test_mixed_precision(tmpdir, plugin)
```
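The new `amp` flag only toggles two `Trainer` arguments: the GPU count and the numeric precision. A minimal standalone sketch of that selection logic, using a hypothetical `trainer_kwargs` helper that is not part of this PR:

```python
def trainer_kwargs(amp: bool) -> dict:
    # AMP (automatic mixed precision) requires a GPU and fp16;
    # otherwise train on CPU with full fp32 precision.
    return {"gpus": 1 if amp else 0, "precision": 16 if amp else 32}

print(trainer_kwargs(True))   # {'gpus': 1, 'precision': 16}
print(trainer_kwargs(False))  # {'gpus': 0, 'precision': 32}
```

The post-fit assertion in the diff guards against PyTorch Lightning silently substituting bf16 for the requested fp16 precision.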
Review comment:
I'm not quite following this logic: by removing `isinstance(current_accelerator, CPUAccelerator)`, what scenario does this solve? Wouldn't the problem case (`CPUAccelerator`) be changed to `DelayedGPUAccelerator` both before and after this PR?