13 changes: 0 additions & 13 deletions comfy/checkpoint_pickle.py

This file was deleted.

25 changes: 11 additions & 14 deletions comfy/utils.py
@@ -20,7 +20,7 @@
 import torch
 import math
 import struct
-import comfy.checkpoint_pickle
+import comfy.memory_management
 import safetensors.torch
 import numpy as np
 from PIL import Image
@@ -38,26 +38,26 @@
 MMAP_TORCH_FILES = args.mmap_torch_files
 DISABLE_MMAP = args.disable_mmap
 
-ALWAYS_SAFE_LOAD = False
-if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated
+if True: # ckpt/pt file whitelist for safe loading of old sd files
     class ModelCheckpoint:
         pass
     ModelCheckpoint.__module__ = "pytorch_lightning.callbacks.model_checkpoint"
 
     def scalar(*args, **kwargs):
-        from numpy.core.multiarray import scalar as sc
-        return sc(*args, **kwargs)
+        return None
     scalar.__module__ = "numpy.core.multiarray"
 
     from numpy import dtype
     from numpy.dtypes import Float64DType
-    from _codecs import encode
 
+    def encode(*args, **kwargs): # no longer necessary on newer torch
+        return None
+    encode.__module__ = "_codecs"
+
     torch.serialization.add_safe_globals([ModelCheckpoint, scalar, dtype, Float64DType, encode])
-    ALWAYS_SAFE_LOAD = True
     logging.info("Checkpoint files will always be loaded safely.")
-else:
-    logging.warning("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended as older versions of pytorch are no longer supported.")
 
 
 # Current as of safetensors 0.7.0
 _TYPES = {
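Note (not part of the diff): the whitelist above works because, on PyTorch 2.4 and newer, `torch.load(weights_only=True)` only unpickles objects whose globals have been registered via `torch.serialization.add_safe_globals`. A minimal sketch of the same technique; the checkpoint path is illustrative, not a file from this repository:

```python
import torch

# Stand-in for the pytorch_lightning callback class referenced inside old .ckpt files;
# overriding __module__ makes it match the qualified name stored in the pickle.
class ModelCheckpoint:
    pass
ModelCheckpoint.__module__ = "pytorch_lightning.callbacks.model_checkpoint"

# Register the class so the weights-only unpickler accepts it.
torch.serialization.add_safe_globals([ModelCheckpoint])

# "old_model.ckpt" is an illustrative path; with the global whitelisted,
# the checkpoint loads without falling back to unrestricted pickle.
state = torch.load("old_model.ckpt", map_location="cpu", weights_only=True)
```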
@@ -140,11 +140,8 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
         if MMAP_TORCH_FILES:
             torch_args["mmap"] = True
 
-        if safe_load or ALWAYS_SAFE_LOAD:
-            pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
-        else:
-            logging.warning("WARNING: loading {} unsafely, upgrade your pytorch to 2.4 or newer to load this file safely.".format(ckpt))
-            pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
+        pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
+
         if "state_dict" in pl_sd:
             sd = pl_sd["state_dict"]
         else:
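For reference, a hedged usage sketch of the simplified path after this change: `load_torch_file` now always routes ckpt/pt files through `torch.load(..., weights_only=True)`, so the `safe_load` flag no longer selects an unsafe pickle fallback. The checkpoint path below is illustrative and assumes a ComfyUI checkout on the import path.

```python
import comfy.utils

# ckpt/pt files now always load through torch.load(weights_only=True);
# "models/checkpoints/example.ckpt" is an illustrative path, not part of this PR.
sd = comfy.utils.load_torch_file("models/checkpoints/example.ckpt", safe_load=True)
print(f"loaded {len(sd)} tensors")
```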