Skip to content

Commit 8aa8b82

Browse files
authored
[Accelerate] allow get_execution_device to be used when initializing a model (#303)
* allow get_execution_device to be used when initializing a model
Signed-off-by: Kyle Sayers <[email protected]>

* formatting
Signed-off-by: Kyle Sayers <[email protected]>

---------

Signed-off-by: Kyle Sayers <[email protected]>
1 parent 16e6435 commit 8aa8b82

File tree

2 files changed

+66
-17
lines changed

2 files changed

+66
-17
lines changed

src/compressed_tensors/utils/offload.py

Lines changed: 20 additions & 16 deletions
Original file line number · Diff line number · Diff line change
@@ -94,22 +94,6 @@ def is_module_offloaded(module: torch.nn.Module) -> bool:
9494
return has_offloaded_params(module)
9595

9696

97-
def get_execution_device(module: torch.nn.Module) -> torch.device:
98-
"""
99-
:param module: module to check
100-
:return: device module is loaded onto during forward pass
101-
"""
102-
if has_offloaded_params(module):
103-
return module._hf_hook.execution_device
104-
device = next(module.parameters()).device
105-
106-
# offload only gets set for leaf modules, fallback to checking for device type
107-
if device.type == "meta":
108-
return module._hf_hook.execution_device
109-
110-
return device
111-
112-
11397
def get_offloaded_device(module: torch.nn.Module) -> torch.device:
11498
"""
11599
:param module: module to check
@@ -158,6 +142,26 @@ def update_parameter_data(
158142
""" Candidates for Upstreaming """
159143

160144

145+
def get_execution_device(module: torch.nn.Module) -> torch.device:
146+
"""
147+
Get the device which inputs should be moved to before module execution
148+
149+
:param module: module to check, may be offloaded
150+
:return: onload device of module
151+
"""
152+
if has_offloaded_params(module):
153+
return module._hf_hook.execution_device
154+
155+
first_param = next(module.parameters(), None)
156+
if first_param is None:
157+
warnings.warn(
158+
f"Unable to infer execution device of {module}, falling back to CPU"
159+
)
160+
return torch.device("cpu")
161+
162+
return first_param.device
163+
164+
161165
def register_offload_parameter(
162166
module: torch.nn.Module,
163167
name: str,

tests/test_utils/test_offload.py

Lines changed: 46 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -17,12 +17,13 @@
1717
align_module_device,
1818
delete_offload_parameter,
1919
disable_hf_hook,
20+
get_execution_device,
2021
has_offloaded_params,
2122
register_offload_parameter,
2223
update_offload_parameter,
2324
)
2425
from compressed_tensors.utils.offload import offload_to_weights_map
25-
from tests.testing_utils import requires_accelerate
26+
from tests.testing_utils import requires_accelerate, requires_gpu
2627

2728

2829
class ExampleModule(torch.nn.Module):
@@ -55,8 +56,46 @@ def test_has_offloaded_params():
5556
assert has_offloaded_params(module)
5657

5758

59+
@requires_gpu
60+
@requires_accelerate()
61+
def test_get_execution_device():
62+
from accelerate import init_empty_weights
63+
from accelerate.big_modeling import attach_align_device_hook
64+
65+
# no offloading
66+
module = ExampleModule()
67+
assert get_execution_device(module) == torch.device("cpu")
68+
69+
# with offloading
70+
attach_align_device_hook(module, torch.device("cuda:0"))
71+
assert get_execution_device(module) == torch.device("cuda:0")
72+
73+
# in meta context
74+
with torch.device("meta"):
75+
module = ExampleModule()
76+
assert get_execution_device(module) == torch.device("meta")
77+
78+
# offloaded in meta context
79+
module = ExampleModule()
80+
attach_align_device_hook(module, torch.device("cuda:0"))
81+
with torch.device("meta"):
82+
assert get_execution_device(module) == torch.device("cuda:0")
83+
84+
# in empty weights context
85+
with init_empty_weights():
86+
module = ExampleModule()
87+
assert get_execution_device(module) == torch.device("meta")
88+
89+
# offloaded in empty weights context
90+
module = ExampleModule()
91+
attach_align_device_hook(module, torch.device("cuda:0"))
92+
with init_empty_weights():
93+
assert get_execution_device(module) == torch.device("cuda:0")
94+
95+
5896
@requires_accelerate()
5997
def test_register_offload_parameter():
98+
from accelerate import init_empty_weights
6099
from accelerate.hooks import attach_align_device_hook
61100

62101
module = ExampleModule()
@@ -94,6 +133,12 @@ def test_register_offload_parameter():
94133
assert module.f.device == torch.device("cpu")
95134
assert module._hf_hook.weights_map["f"].device == torch.device("cpu")
96135

136+
# parameters registered in the empty init context are still empty
137+
with init_empty_weights():
138+
module = ExampleModule()
139+
register_offload_parameter(module, "c", parameter)
140+
assert module.a.device == module.b.device == module.c.device == torch.device("meta")
141+
97142

98143
@requires_accelerate()
99144
def test_update_offload_parameter():

0 commit comments

Comments
 (0)