Add a to() overload for GGMLTensor, so calling to() on the model moves the quantized data #7949

Status: Open. Wants to merge 13 commits into base: main.
@@ -2,6 +2,8 @@

import torch

+from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
+

class CachedModelOnlyFullLoad:
    """A wrapper around a PyTorch model to handle full loads and unloads between the CPU and the compute device.
@@ -76,7 +78,15 @@ def full_load_to_vram(self) -> int:
            for k, v in self._cpu_state_dict.items():
                new_state_dict[k] = v.to(self._compute_device, copy=True)
            self._model.load_state_dict(new_state_dict, assign=True)
-        self._model.to(self._compute_device)
+
+        check_for_gguf = hasattr(self._model, "state_dict") and self._model.state_dict().get("img_in.weight")
+        if isinstance(check_for_gguf, GGMLTensor):
+            old_value = torch.__future__.get_overwrite_module_params_on_conversion()
+            torch.__future__.set_overwrite_module_params_on_conversion(True)
+            self._model.to(self._compute_device)
+            torch.__future__.set_overwrite_module_params_on_conversion(old_value)
+        else:
+            self._model.to(self._compute_device)

        self._is_in_vram = True
        return self._total_bytes
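
Review note: both hunks wrap Module.to() in the same toggle of torch.__future__.set_overwrite_module_params_on_conversion(). With the flag off, Module._apply() copies converted weights into param.data in place, which drops tensor subclasses such as GGMLTensor; with it on, PyTorch replaces each parameter with the tensor returned by the conversion, so the quantized data survives the device move. A minimal sketch of how the repeated toggle could be factored out, assuming a hypothetical helper name (overwrite_params_on_conversion is not part of this PR):

from contextlib import contextmanager
from typing import Iterator

import torch


@contextmanager
def overwrite_params_on_conversion(enabled: bool = True) -> Iterator[None]:
    # Temporarily flip the global torch.__future__ flag; restore the previous
    # value even if the device move raises.
    old_value = torch.__future__.get_overwrite_module_params_on_conversion()
    torch.__future__.set_overwrite_module_params_on_conversion(enabled)
    try:
        yield
    finally:
        torch.__future__.set_overwrite_module_params_on_conversion(old_value)

With such a helper, each branch above would collapse to: with overwrite_params_on_conversion(): self._model.to(self._compute_device).
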
@@ -92,7 +102,15 @@ def full_unload_from_vram(self) -> int:

        if self._cpu_state_dict is not None:
            self._model.load_state_dict(self._cpu_state_dict, assign=True)
-        self._model.to(self._offload_device)
+
+        check_for_gguf = hasattr(self._model, "state_dict") and self._model.state_dict().get("img_in.weight")
+        if isinstance(check_for_gguf, GGMLTensor):
+            old_value = torch.__future__.get_overwrite_module_params_on_conversion()
+            torch.__future__.set_overwrite_module_params_on_conversion(True)
+            self._model.to(self._offload_device)
+            torch.__future__.set_overwrite_module_params_on_conversion(old_value)
+        else:
+            self._model.to(self._offload_device)

        self._is_in_vram = False
        return self._total_bytes
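
Review note: the GGUF check probes a single FLUX-specific key ("img_in.weight"), so a GGUF-quantized model without that key would fall through to the plain .to() path. A more general detection, sketched here as a suggestion rather than the PR's approach (_has_ggml_tensors is a hypothetical name):

import torch

from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor


def _has_ggml_tensors(model: torch.nn.Module) -> bool:
    # Hypothetical helper: treat the model as GGUF-quantized if any state-dict
    # entry is a GGMLTensor, instead of relying on one model-specific key name.
    return any(isinstance(t, GGMLTensor) for t in model.state_dict().values())
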