|
20 | 20 | import torch |
21 | 21 | import math |
22 | 22 | import struct |
23 | | -import comfy.checkpoint_pickle |
| 23 | +import comfy.memory_management |
24 | 24 | import safetensors.torch |
25 | 25 | import numpy as np |
26 | 26 | from PIL import Image |
|
38 | 38 | MMAP_TORCH_FILES = args.mmap_torch_files |
39 | 39 | DISABLE_MMAP = args.disable_mmap |
40 | 40 |
|
41 | | -ALWAYS_SAFE_LOAD = False |
42 | | -if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated |
| 41 | + |
| 42 | +if True: # ckpt/pt file whitelist for safe loading of old sd files |
43 | 43 | class ModelCheckpoint: |
44 | 44 | pass |
45 | 45 | ModelCheckpoint.__module__ = "pytorch_lightning.callbacks.model_checkpoint" |
46 | 46 |
|
47 | 47 | def scalar(*args, **kwargs): |
48 | | - from numpy.core.multiarray import scalar as sc |
49 | | - return sc(*args, **kwargs) |
| 48 | + return None |
50 | 49 | scalar.__module__ = "numpy.core.multiarray" |
51 | 50 |
|
52 | 51 | from numpy import dtype |
53 | 52 | from numpy.dtypes import Float64DType |
54 | | - from _codecs import encode |
| 53 | + |
| 54 | + def encode(*args, **kwargs): # no longer necessary on newer torch |
| 55 | + return None |
| 56 | + encode.__module__ = "_codecs" |
55 | 57 |
|
56 | 58 | torch.serialization.add_safe_globals([ModelCheckpoint, scalar, dtype, Float64DType, encode]) |
57 | | - ALWAYS_SAFE_LOAD = True |
58 | 59 | logging.info("Checkpoint files will always be loaded safely.") |
59 | | -else: |
60 | | - logging.warning("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended as older versions of pytorch are no longer supported.") |
| 60 | + |
61 | 61 |
|
62 | 62 | # Current as of safetensors 0.7.0 |
63 | 63 | _TYPES = { |
@@ -140,11 +140,8 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False): |
140 | 140 | if MMAP_TORCH_FILES: |
141 | 141 | torch_args["mmap"] = True |
142 | 142 |
|
143 | | - if safe_load or ALWAYS_SAFE_LOAD: |
144 | | - pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args) |
145 | | - else: |
146 | | - logging.warning("WARNING: loading {} unsafely, upgrade your pytorch to 2.4 or newer to load this file safely.".format(ckpt)) |
147 | | - pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle) |
| 143 | + pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args) |
| 144 | + |
148 | 145 | if "state_dict" in pl_sd: |
149 | 146 | sd = pl_sd["state_dict"] |
150 | 147 | else: |
|
0 commit comments