|
17 | 17 | align_module_device,
|
18 | 18 | delete_offload_parameter,
|
19 | 19 | disable_hf_hook,
|
| 20 | + get_execution_device, |
20 | 21 | has_offloaded_params,
|
21 | 22 | register_offload_parameter,
|
22 | 23 | update_offload_parameter,
|
23 | 24 | )
|
24 | 25 | from compressed_tensors.utils.offload import offload_to_weights_map
|
25 |
| -from tests.testing_utils import requires_accelerate |
| 26 | +from tests.testing_utils import requires_accelerate, requires_gpu |
26 | 27 |
|
27 | 28 |
|
28 | 29 | class ExampleModule(torch.nn.Module):
|
@@ -55,8 +56,46 @@ def test_has_offloaded_params():
|
55 | 56 | assert has_offloaded_params(module)
|
56 | 57 |
|
57 | 58 |
|
@requires_gpu
@requires_accelerate()
def test_get_execution_device():
    """Exercise get_execution_device across plain, hook-offloaded, meta-context,
    and empty-weights-context modules."""
    from accelerate import init_empty_weights
    from accelerate.big_modeling import attach_align_device_hook

    cpu = torch.device("cpu")
    cuda0 = torch.device("cuda:0")
    meta = torch.device("meta")

    # a plain module executes on cpu
    plain = ExampleModule()
    assert get_execution_device(plain) == cpu

    # attaching an align-device hook makes cuda:0 the execution device
    attach_align_device_hook(plain, cuda0)
    assert get_execution_device(plain) == cuda0

    # a module built under the meta device context reports meta
    with torch.device("meta"):
        meta_built = ExampleModule()
        assert get_execution_device(meta_built) == meta

    # the hook's device wins over an enclosing meta context
    offloaded = ExampleModule()
    attach_align_device_hook(offloaded, cuda0)
    with torch.device("meta"):
        assert get_execution_device(offloaded) == cuda0

    # a module built under init_empty_weights reports meta
    with init_empty_weights():
        empty_built = ExampleModule()
        assert get_execution_device(empty_built) == meta

    # the hook's device wins over an enclosing init_empty_weights context
    hooked = ExampleModule()
    attach_align_device_hook(hooked, cuda0)
    with init_empty_weights():
        assert get_execution_device(hooked) == cuda0

| 95 | + |
58 | 96 | @requires_accelerate()
|
59 | 97 | def test_register_offload_parameter():
|
| 98 | + from accelerate import init_empty_weights |
60 | 99 | from accelerate.hooks import attach_align_device_hook
|
61 | 100 |
|
62 | 101 | module = ExampleModule()
|
@@ -94,6 +133,12 @@ def test_register_offload_parameter():
|
94 | 133 | assert module.f.device == torch.device("cpu")
|
95 | 134 | assert module._hf_hook.weights_map["f"].device == torch.device("cpu")
|
96 | 135 |
|
| 136 | + # parameters registered in the empty init context are still empty |
| 137 | + with init_empty_weights(): |
| 138 | + module = ExampleModule() |
| 139 | + register_offload_parameter(module, "c", parameter) |
| 140 | + assert module.a.device == module.b.device == module.c.device == torch.device("meta") |
| 141 | + |
97 | 142 |
|
98 | 143 | @requires_accelerate()
|
99 | 144 | def test_update_offload_parameter():
|
|
0 commit comments