
Commit f4e5b68

Merge pull request #8 from codewithdark-git/fix-move-to-device-attribute-error
Fix: Handle nn.Module in move_to_device
2 parents 54e44e6 + 07d1028 commit f4e5b68
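
For context, here is a minimal, illustrative sketch (not part of the commit) of the failure mode the branch name fix-move-to-device-attribute-error points at: torch.nn.Module objects have no .device attribute, so the previous tensor-only code path broke when handed a module instead of a tensor.

# Illustrative only (not from the commit): nn.Module has no .device attribute,
# which is what the previous tensor-only code path tripped over.
import torch.nn as nn

layer = nn.Linear(2, 2)
try:
    _ = layer.device  # raises AttributeError: 'Linear' object has no attribute 'device'
except AttributeError as err:
    print(err)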

File tree

2 files changed: +67 -5 lines changed

quantllm/quant/quantization_engine.py

Lines changed: 11 additions & 5 deletions
@@ -15,19 +15,25 @@ def get_device_map(model: PreTrainedModel) -> Dict[str, torch.device]:
     return device_map
 
 def move_to_device(
-    tensor: torch.Tensor,
+    tensor: Union[torch.Tensor, torch.nn.Module],
     device: torch.device,
     force_copy: bool = False
-) -> torch.Tensor:
-    """Safely move tensor to device with proper error handling."""
+) -> Union[torch.Tensor, torch.nn.Module]:
+    """Safely move tensor or module to device with proper error handling."""
     try:
+        if isinstance(tensor, torch.nn.Module):
+            return tensor.to(device)
+        # Existing logic for torch.Tensor
         if force_copy:
             return tensor.to(device, copy=True)
-        if tensor.device == device:
+        if tensor.device == device:  # type: ignore[union-attr]
             return tensor
         return tensor.to(device)
     except Exception as e:
-        raise RuntimeError(f"Failed to move tensor to {device}: {str(e)}")
+        # It's good practice to indicate which tensor/module failed if possible,
+        # but tensor name isn't available here.
+        type_str = "module" if isinstance(tensor, torch.nn.Module) else "tensor"
+        raise RuntimeError(f"Failed to move {type_str} to {device}: {str(e)}")
 
 class DeviceManager:
     """Manage device placement and synchronization."""

quantllm/quant/tests/test_api.py

Lines changed: 56 additions & 0 deletions
@@ -197,5 +197,61 @@ def test_gguf_convert_to_gguf_gpu_layersCpu(self): # Quantizer on GPU, layers on CPU
     def test_gguf_convert_to_gguf_gpu_layersDevice(self): # Quantizer on GPU, layers on GPU
         self._run_gguf_conversion_test(quantizer_device=self.device, gguf_cpu_offload=False)
 
+# New test class for move_to_device
+from quantllm.quant.quantization_engine import move_to_device
+import torch.nn as nn
+
+class TestMoveToDevice(unittest.TestCase):
+    def test_move_tensor_and_module(self):
+        """Test move_to_device with both torch.Tensor and torch.nn.Module."""
+        target_device_str = "cuda" if torch.cuda.is_available() else "cpu"
+        target_device = torch.device(target_device_str)
+
+        # 1. Create a simple torch.Tensor
+        my_tensor = torch.randn(2, 3, device="cpu")  # Start on CPU
+
+        # 2. Create a simple torch.nn.Module
+        my_module = nn.Linear(10, 10).to("cpu")  # Start on CPU
+
+        # 4. Call move_to_device for the tensor and the module
+        moved_tensor = move_to_device(my_tensor, target_device)
+        moved_module = move_to_device(my_module, target_device)
+
+        # 5. Assert that the tensor is on the target device
+        self.assertEqual(moved_tensor.device, target_device, "Tensor not moved to target device.")
+
+        # 6. Assert that the module is on the target device
+        self.assertIsInstance(moved_module, nn.Module, "move_to_device did not return a Module.")
+
+        # Check device of a parameter
+        if list(moved_module.parameters()):  # Check if module has parameters
+            self.assertEqual(
+                next(moved_module.parameters()).device,
+                target_device,
+                "Module's parameters not moved to target device."
+            )
+        else:  # Handle modules with no parameters (e.g. nn.ReLU()) if needed for future tests
+            # For a simple Linear layer, this else block shouldn't be hit.
+            # If testing with modules without parameters, one might check an attribute
+            # or skip device check if not applicable. For nn.Linear, parameters exist.
+            pass
+
+        # Test with force_copy=True for tensors
+        another_tensor = torch.randn(2, 3, device=target_device)
+        copied_tensor = move_to_device(another_tensor, target_device, force_copy=True)
+        self.assertEqual(copied_tensor.device, target_device)
+        if target_device_str == "cpu":  # On CPU, to() without copy=True might return same object if already on device
+            pass  # Data pointer check is more complex and not strictly necessary for device check
+        else:  # On CUDA, .to(device) typically creates a new tensor unless it's already there.
+            # force_copy=True should ensure it's a different object.
+            if another_tensor.is_cuda and copied_tensor.is_cuda:  # Both on CUDA
+                self.assertNotEqual(another_tensor.data_ptr(), copied_tensor.data_ptr(), "force_copy=True did not create a new tensor copy on CUDA.")
+
+        # Test moving a module already on the target device
+        module_on_target = nn.Linear(5, 5).to(target_device)
+        moved_module_again = move_to_device(module_on_target, target_device)
+        self.assertEqual(next(moved_module_again.parameters()).device, target_device)
+
+
 if __name__ == '__main__':
     unittest.main()
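
A sketch of how only the new test class could be run in isolation (not from the commit; it assumes the test module is importable as quantllm.quant.tests.test_api, the path shown in the diff):

# Hypothetical runner for just the new TestMoveToDevice class.
import unittest

suite = unittest.defaultTestLoader.loadTestsFromName(
    "quantllm.quant.tests.test_api.TestMoveToDevice"
)
unittest.TextTestRunner(verbosity=2).run(suite)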
