Skip to content

Commit e1ede29

Browse files
Remove unsafe pickle loading code that was used on pytorch older than 2.4 (Comfy-Org#12473)
ComfyUI hasn't started on pytorch 2.4 since last month.
1 parent df1e5e8 commit e1ede29

File tree

2 files changed

+11
-27
lines changed

2 files changed

+11
-27
lines changed

comfy/checkpoint_pickle.py

Lines changed: 0 additions & 13 deletions
This file was deleted.

comfy/utils.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import torch
2121
import math
2222
import struct
23-
import comfy.checkpoint_pickle
23+
import comfy.memory_management
2424
import safetensors.torch
2525
import numpy as np
2626
from PIL import Image
@@ -38,26 +38,26 @@
3838
MMAP_TORCH_FILES = args.mmap_torch_files
3939
DISABLE_MMAP = args.disable_mmap
4040

41-
ALWAYS_SAFE_LOAD = False
42-
if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated
41+
42+
if True: # ckpt/pt file whitelist for safe loading of old sd files
4343
class ModelCheckpoint:
4444
pass
4545
ModelCheckpoint.__module__ = "pytorch_lightning.callbacks.model_checkpoint"
4646

4747
def scalar(*args, **kwargs):
48-
from numpy.core.multiarray import scalar as sc
49-
return sc(*args, **kwargs)
48+
return None
5049
scalar.__module__ = "numpy.core.multiarray"
5150

5251
from numpy import dtype
5352
from numpy.dtypes import Float64DType
54-
from _codecs import encode
53+
54+
def encode(*args, **kwargs): # no longer necessary on newer torch
55+
return None
56+
encode.__module__ = "_codecs"
5557

5658
torch.serialization.add_safe_globals([ModelCheckpoint, scalar, dtype, Float64DType, encode])
57-
ALWAYS_SAFE_LOAD = True
5859
logging.info("Checkpoint files will always be loaded safely.")
59-
else:
60-
logging.warning("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended as older versions of pytorch are no longer supported.")
60+
6161

6262
# Current as of safetensors 0.7.0
6363
_TYPES = {
@@ -140,11 +140,8 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
140140
if MMAP_TORCH_FILES:
141141
torch_args["mmap"] = True
142142

143-
if safe_load or ALWAYS_SAFE_LOAD:
144-
pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
145-
else:
146-
logging.warning("WARNING: loading {} unsafely, upgrade your pytorch to 2.4 or newer to load this file safely.".format(ckpt))
147-
pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
143+
pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
144+
148145
if "state_dict" in pl_sd:
149146
sd = pl_sd["state_dict"]
150147
else:

0 commit comments

Comments
 (0)