@@ -2,6 +2,8 @@
 
 import torch
 
+from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
+
 
 class CachedModelOnlyFullLoad:
     """A wrapper around a PyTorch model to handle full loads and unloads between the CPU and the compute device.
@@ -76,7 +78,15 @@ def full_load_to_vram(self) -> int:
             for k, v in self._cpu_state_dict.items():
                 new_state_dict[k] = v.to(self._compute_device, copy=True)
             self._model.load_state_dict(new_state_dict, assign=True)
-        self._model.to(self._compute_device)
+
+        check_for_gguf = hasattr(self._model, "state_dict") and self._model.state_dict().get("img_in.weight")
+        if isinstance(check_for_gguf, GGMLTensor):
+            old_value = torch.__future__.get_overwrite_module_params_on_conversion()
+            torch.__future__.set_overwrite_module_params_on_conversion(True)
+            self._model.to(self._compute_device)
+            torch.__future__.set_overwrite_module_params_on_conversion(old_value)
+        else:
+            self._model.to(self._compute_device)
 
         self._is_in_vram = True
         return self._total_bytes
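
Both hunks wrap the device move in the same toggle: torch.__future__.set_overwrite_module_params_on_conversion(True) makes nn.Module.to() replace parameter tensors outright instead of converting them in place, which is what GGML-quantized tensors require. As a rough sketch (not part of this diff; the helper name is illustrative), the get/set pair could be expressed as a context manager that restores the previous setting even if the move fails:

from contextlib import contextmanager

import torch


@contextmanager
def overwrite_params_on_conversion(enabled: bool = True):
    # Temporarily switch torch's behaviour so that nn.Module.to() replaces
    # parameter tensors rather than mutating them in place, then restore the
    # previous setting on exit (including on error).
    old_value = torch.__future__.get_overwrite_module_params_on_conversion()
    torch.__future__.set_overwrite_module_params_on_conversion(enabled)
    try:
        yield
    finally:
        torch.__future__.set_overwrite_module_params_on_conversion(old_value)


# Example (hypothetical usage, mirroring the diff):
# with overwrite_params_on_conversion():
#     model.to(compute_device)
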
@@ -92,7 +102,15 @@ def full_unload_from_vram(self) -> int:
 
         if self._cpu_state_dict is not None:
             self._model.load_state_dict(self._cpu_state_dict, assign=True)
-        self._model.to(self._offload_device)
+
+        check_for_gguf = hasattr(self._model, "state_dict") and self._model.state_dict().get("img_in.weight")
+        if isinstance(check_for_gguf, GGMLTensor):
+            old_value = torch.__future__.get_overwrite_module_params_on_conversion()
+            torch.__future__.set_overwrite_module_params_on_conversion(True)
+            self._model.to(self._offload_device)
+            torch.__future__.set_overwrite_module_params_on_conversion(old_value)
+        else:
+            self._model.to(self._offload_device)
 
         self._is_in_vram = False
         return self._total_bytes
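
The GGUF probe itself is a single-key heuristic: it looks up "img_in.weight" in the state dict and checks whether that tensor is a GGMLTensor. A more general form of the same check (illustrative only, not part of this PR) would scan the whole state dict:

import torch

from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor


def has_ggml_tensors(model: torch.nn.Module) -> bool:
    # True if any tensor in the model's state dict is GGML-quantized.
    # Generalizes the "img_in.weight" probe used in the hunks above.
    if not hasattr(model, "state_dict"):
        return False
    return any(isinstance(t, GGMLTensor) for t in model.state_dict().values())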