1 change: 1 addition & 0 deletions comfy/cli_args.py
@@ -105,6 +105,7 @@ class LatentPreviewMethod(enum.Enum):
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threhold the cache remove large items to free RAM. Default 4GB")

attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
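For clarity, a small standalone illustration (not part of the PR) of how this argparse pattern behaves: omitting the flag leaves RAM pressure caching disabled, passing the flag bare enables it with the 4 GB default headroom, and passing a number sets a custom headroom.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0)

print(parser.parse_args([]).cache_ram)                    # 0   -> RAM pressure caching disabled
print(parser.parse_args(["--cache-ram"]).cache_ram)       # 4.0 -> enabled with the default 4 GB headroom
print(parser.parse_args(["--cache-ram", "8"]).cache_ram)  # 8.0 -> enabled with an 8 GB headroom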
3 changes: 3 additions & 0 deletions comfy/model_patcher.py
@@ -275,6 +275,9 @@ def model_size(self):
self.size = comfy.model_management.module_size(self.model)
return self.size

def get_ram_usage(self):
return self.model_size()

def loaded_size(self):
return self.model.model_loaded_weight_memory

14 changes: 14 additions & 0 deletions comfy/sd.py
@@ -143,6 +143,9 @@ def clone(self):
n.apply_hooks_to_conds = self.apply_hooks_to_conds
return n

def get_ram_usage(self):
return self.patcher.get_ram_usage()

def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
return self.patcher.add_patches(patches, strength_patch, strength_model)

@@ -293,6 +296,7 @@ def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None)
self.working_dtypes = [torch.bfloat16, torch.float32]
self.disable_offload = False
self.not_video = False
self.size = None

self.downscale_index_formula = None
self.upscale_index_formula = None
@@ -595,6 +599,16 @@ def estimate_memory(shape, dtype, num_layers = 16, kv_cache_multiplier = 2):

self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
self.model_size()

def model_size(self):
if self.size is not None:
return self.size
self.size = comfy.model_management.module_size(self.first_stage_model)
return self.size

def get_ram_usage(self):
return self.model_size()

def throw_exception_if_invalid(self):
if self.first_stage_model is None:
83 changes: 83 additions & 0 deletions comfy_execution/caching.py
@@ -1,4 +1,9 @@
import bisect
import gc
import itertools
import psutil
import time
import torch
from typing import Sequence, Mapping, Dict
from comfy_execution.graph import DynamicPrompt
from abc import ABC, abstractmethod
@@ -188,6 +193,9 @@ def clean_unused(self):
self._clean_cache()
self._clean_subcaches()

def poll(self, **kwargs):
pass

def _set_immediate(self, node_id, value):
assert self.initialized
cache_key = self.cache_key_set.get_data_key(node_id)
@@ -276,6 +284,9 @@ def all_node_ids(self):
def clean_unused(self):
pass

def poll(self, **kwargs):
pass

def get(self, node_id):
return None

@@ -336,3 +347,75 @@ async def ensure_subcache_for(self, node_id, children_ids):
self._mark_used(child_id)
self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
return self


#Iterating the cache for usage analysis can be expensive, so when a cleanup triggers, make sure
#to take a sizeable chunk out to give breathing space on high-node / low-RAM-per-node flows.

RAM_CACHE_HYSTERESIS = 1.1

#This is roughly in GB, but not exactly. It needs to be non-zero for the heuristic below,
#and as long as multi-GB models dwarf it, the OOM scoring approximation holds up OK.

RAM_CACHE_DEFAULT_RAM_USAGE = 0.1

#Exponential bias towards evicting older workflows so garbage will be taken out
#in constantly changing setups.

RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3

class RAMPressureCache(LRUCache):

def __init__(self, key_class):
super().__init__(key_class, 0)
self.timestamps = {}

def clean_unused(self):
self._clean_subcaches()

def set(self, node_id, value):
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
super().set(node_id, value)

def get(self, node_id):
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
return super().get(node_id)

def poll(self, ram_headroom):
def _ram_gb():
return psutil.virtual_memory().available / (1024**3)

if _ram_gb() > ram_headroom:
return
gc.collect()
if _ram_gb() > ram_headroom:
return

clean_list = []

for key, (outputs, _) in self.cache.items():
oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** (self.generation - self.used_generation[key])

ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE
def scan_list_for_ram_usage(outputs):
nonlocal ram_usage
for output in outputs:
if isinstance(output, list):
scan_list_for_ram_usage(output)
elif isinstance(output, torch.Tensor) and output.device.type == 'cpu':
#score Tensors at a 50% discount for RAM usage as they are likely to
#be high value intermediates
ram_usage += (output.numel() * output.element_size()) * 0.5
elif hasattr(output, "get_ram_usage"):
ram_usage += output.get_ram_usage()
scan_list_for_ram_usage(outputs)

oom_score *= ram_usage
#In the case where we have no information on the node ram usage at all,
#break OOM score ties on the last touch timestamp (pure LRU)
bisect.insort(clean_list, (oom_score, self.timestamps[key], key))

while _ram_gb() < ram_headroom * RAM_CACHE_HYSTERESIS and clean_list:
_, _, key = clean_list.pop()
del self.cache[key]
gc.collect()
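For intuition, a standalone sketch of the eviction scoring above (numbers are made up, and one cache generation is treated as roughly one workflow run, which is an assumption): an entry's OOM score is its estimated RAM usage scaled by 1.3 raised to the number of generations since it was last used, so large entries from stale workflows are evicted first, and the last-touch timestamp only breaks ties.

RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3

def oom_score(ram_usage_gb, generations_since_last_use):
    # Larger and staler entries score higher and are evicted first.
    return ram_usage_gb * RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** generations_since_last_use

entries = {
    "6 GB model used this run":    oom_score(6.0, 0),  # 6.0
    "4 GB latent from 3 runs ago": oom_score(4.0, 3),  # ~8.79 -> evicted before the fresh model
}
print(sorted(entries.items(), key=lambda item: item[1], reverse=True))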
7 changes: 6 additions & 1 deletion comfy_execution/graph.py
@@ -212,7 +212,12 @@ def cache_link(self, from_node_id, to_node_id):
def get_output_cache(self, from_node_id, to_node_id):
if not to_node_id in self.execution_cache:
return None
return self.execution_cache[to_node_id].get(from_node_id)
value = self.execution_cache[to_node_id].get(from_node_id)
if value is None:
return None
#Write back to the main cache on touch.
self.output_cache.set(from_node_id, value)
return value[0]
Collaborator:
Why is this now returning only value[0] rather than all results of .get?


def cache_update(self, node_id, value):
if node_id in self.execution_cache_listeners:
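For context on the write-back-on-touch above, a simplified standalone sketch (the class here is a stand-in, not the real cache): re-setting a value in the main cache when the execution side-cache is hit refreshes the entry's timestamp, so RAM-pressure eviction treats it as recently used.

import time

# Simplified stand-in for a cache whose set() records a last-touched timestamp,
# mirroring RAMPressureCache.set() updating self.timestamps.
class MainCache:
    def __init__(self):
        self.values, self.timestamps = {}, {}

    def set(self, key, value):
        self.values[key] = value
        self.timestamps[key] = time.time()  # refreshed on every write-back ("touch")

    def get(self, key):
        return self.values.get(key)

main_cache = MainCache()
main_cache.set("node-1", (["output"], None))

# A hit coming from a side cache is written back to the main cache,
# which refreshes the timestamp used to break eviction ties.
hit = main_cache.get("node-1")
main_cache.set("node-1", hit)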
63 changes: 34 additions & 29 deletions execution.py
@@ -21,6 +21,7 @@
NullCache,
HierarchicalCache,
LRUCache,
RAMPressureCache,
)
from comfy_execution.graph import (
DynamicPrompt,
@@ -92,45 +93,47 @@ class CacheType(Enum):
CLASSIC = 0
LRU = 1
NONE = 2
RAM_PRESSURE = 3


class CacheSet:
def __init__(self, cache_type=None, cache_size=None):
def __init__(self, cache_type=None, cache_args={}):
if cache_type == CacheType.NONE:
self.init_null_cache()
logging.info("Disabling intermediate node cache.")
elif cache_type == CacheType.RAM_PRESSURE:
cache_ram = cache_args.get("ram", 16.0)
self.init_ram_cache(cache_ram)
logging.info("Using RAM pressure cache.")
elif cache_type == CacheType.LRU:
if cache_size is None:
cache_size = 0
cache_size = cache_args.get("lru", 0)
self.init_lru_cache(cache_size)
logging.info("Using LRU cache")
else:
self.init_classic_cache()

self.all = [self.outputs, self.ui, self.objects]
self.all = [self.outputs, self.objects]

# Performs like the old cache -- dump data ASAP
def init_classic_cache(self):
self.outputs = HierarchicalCache(CacheKeySetInputSignature)
self.ui = HierarchicalCache(CacheKeySetInputSignature)
self.objects = HierarchicalCache(CacheKeySetID)

def init_lru_cache(self, cache_size):
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
self.objects = HierarchicalCache(CacheKeySetID)

def init_ram_cache(self, min_headroom):
self.outputs = RAMPressureCache(CacheKeySetInputSignature)
self.objects = HierarchicalCache(CacheKeySetID)

def init_null_cache(self):
self.outputs = NullCache()
#The UI cache is expected to be iterable at the end of each workflow
#so it must cache at least a full workflow. Use Heirachical
self.ui = HierarchicalCache(CacheKeySetInputSignature)
self.objects = NullCache()

def recursive_debug_dump(self):
result = {
"outputs": self.outputs.recursive_debug_dump(),
"ui": self.ui.recursive_debug_dump(),
}
return result

@@ -393,20 +396,23 @@ def format_value(x):
else:
return str(x)

async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes):
async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs):
unique_id = current_item
real_node_id = dynprompt.get_real_node_id(unique_id)
display_node_id = dynprompt.get_display_node_id(unique_id)
parent_node_id = dynprompt.get_parent_node_id(unique_id)
inputs = dynprompt.get_node(unique_id)['inputs']
class_type = dynprompt.get_node(unique_id)['class_type']
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
if caches.outputs.get(unique_id) is not None:
cached = caches.outputs.get(unique_id)
if cached is not None:
if server.client_id is not None:
cached_output = caches.ui.get(unique_id) or {}
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id)
cached_ui = cached[1] or {}
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_ui.get("output",None), "prompt_id": prompt_id }, server.client_id)
if cached[1] is not None:
ui_outputs[unique_id] = cached[1]
Collaborator:
I don't think we can safely replace the UI cache with the outputs of the normal cache. While in many cases the UI cache will contain the same data as the normal output cache, in many other cases it won't. The simplest example is use of the Preview Any node to view the text representation of a tensor.
There are also many custom nodes which will use UI output to display data (like the size of an image passed through a node) that is different than the actual contents of the node's inputs/outputs.

The easiest way to check whether the UI cache is functioning is:

  1. Create two workflows, A and B which have relatively low RAM/VRAM requirements and so can both be cached at the same time.
  2. Execute workflow A --> A should execute and you should see the results of A
  3. Execute workflow B --> B should execute and you should see the results of B
  4. Execute workflow A again --> A should not execute any nodes (since they're all cached), but you should still see the results of A on all output nodes.

Contributor Author:

So I might have been misleading in the git writeup: it's not replacing the UI cache with the output cache at the content level. The actual cache contents for outputs and UI remain completely separate, but the "output" cache data structure now stores both pieces as a two-element tuple. This is possible because the resident keys of the output cache are always a superset of the keys of the UI cache.

This is the key change where the outputs cache is populated with the two element tuple:

@@ -557,8 +554,8 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
             pending_subgraph_results[unique_id] = cached_outputs
             return (ExecutionResult.PENDING, None, None)
 
-        caches.outputs.set(unique_id, output_data)
-        execution_list.cache_update(unique_id, output_data)
+        execution_list.cache_update(unique_id, (output_data, ui_outputs.get(unique_id)))
+        caches.outputs.set(unique_id, (output_data, ui_outputs.get(unique_id)))

The getters use element [1] for the UI and element [0] for the outputs.

I did a quick test of your scenario in LRU caching mode. It works as described.

Screencast.from.2025-10-29.20-12-33.webm

Collaborator:
Ahh, I see, I totally missed that. This all seems functionally correct to me 👍 What do you think of using a NamedTuple instead of a raw tuple to make this storage a little more obvious (both when storing it and when accessing it instead of using [0] and [1])?

Collaborator:
I missed it too, +1 on the NamedTuple
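A minimal sketch of what the suggested NamedTuple could look like (the class and field names are assumptions for illustration, not something the PR defines):

from typing import Any, NamedTuple, Optional

class CacheEntry(NamedTuple):
    # Hypothetical replacement for the raw (output_data, ui_outputs) tuple discussed above.
    outputs: Any               # the node's output_data
    ui: Optional[dict] = None  # the UI payload, if the node produced one

entry = CacheEntry(outputs=[["some tensor"]], ui={"output": {"text": ["preview"]}})
print(entry.outputs)  # reads as entry.outputs / entry.ui instead of cached[0] / cached[1]
print(entry.ui)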

get_progress_state().finish_progress(unique_id)
execution_list.cache_update(unique_id, caches.outputs.get(unique_id))
execution_list.cache_update(unique_id, cached)
return (ExecutionResult.SUCCESS, None, None)

input_data_all = None
Expand Down Expand Up @@ -506,15 +512,15 @@ async def await_completion():
asyncio.create_task(await_completion())
return (ExecutionResult.PENDING, None, None)
if len(output_ui) > 0:
caches.ui.set(unique_id, {
ui_outputs[unique_id] = {
"meta": {
"node_id": unique_id,
"display_node": display_node_id,
"parent_node": parent_node_id,
"real_node_id": real_node_id,
},
"output": output_ui
})
}
if server.client_id is not None:
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
if has_subgraph:
Expand Down Expand Up @@ -557,8 +563,8 @@ async def await_completion():
pending_subgraph_results[unique_id] = cached_outputs
return (ExecutionResult.PENDING, None, None)

caches.outputs.set(unique_id, output_data)
execution_list.cache_update(unique_id, output_data)
execution_list.cache_update(unique_id, (output_data, ui_outputs.get(unique_id)))
caches.outputs.set(unique_id, (output_data, ui_outputs.get(unique_id)))

except comfy.model_management.InterruptProcessingException as iex:
logging.info("Processing interrupted")
@@ -603,14 +609,14 @@ async def await_completion():
return (ExecutionResult.SUCCESS, None, None)

class PromptExecutor:
def __init__(self, server, cache_type=False, cache_size=None):
self.cache_size = cache_size
def __init__(self, server, cache_type=False, cache_args=None):
self.cache_args = cache_args
self.cache_type = cache_type
self.server = server
self.reset()

def reset(self):
self.caches = CacheSet(cache_type=self.cache_type, cache_size=self.cache_size)
self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)
self.status_messages = []
self.success = True

@@ -685,6 +691,7 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=
broadcast=False)
pending_subgraph_results = {}
pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
ui_node_outputs = {}
executed = set()
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
current_outputs = self.caches.outputs.all_node_ids()
@@ -698,7 +705,7 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=
break

assert node_id is not None, "Node ID should not be None at this point"
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes)
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
self.success = result != ExecutionResult.FAILURE
if result == ExecutionResult.FAILURE:
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
@@ -707,18 +714,16 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=
execution_list.unstage_node_execution()
else: # result == ExecutionResult.SUCCESS:
execution_list.complete_node_execution()
self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
else:
# Only execute when the while-loop ends without break
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)

ui_outputs = {}
meta_outputs = {}
all_node_ids = self.caches.ui.all_node_ids()
for node_id in all_node_ids:
ui_info = self.caches.ui.get(node_id)
if ui_info is not None:
ui_outputs[node_id] = ui_info["output"]
meta_outputs[node_id] = ui_info["meta"]
for node_id, ui_info in ui_node_outputs.items():
ui_outputs[node_id] = ui_info["output"]
meta_outputs[node_id] = ui_info["meta"]
self.history_result = {
"outputs": ui_outputs,
"meta": meta_outputs,
4 changes: 3 additions & 1 deletion main.py
@@ -172,10 +172,12 @@ def prompt_worker(q, server_instance):
cache_type = execution.CacheType.CLASSIC
if args.cache_lru > 0:
cache_type = execution.CacheType.LRU
elif args.cache_ram > 0:
cache_type = execution.CacheType.RAM_PRESSURE
elif args.cache_none:
cache_type = execution.CacheType.NONE

e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_size=args.cache_lru)
e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={ "lru" : args.cache_lru, "ram" : args.cache_ram } )
last_gc_collect = 0
need_gc = False
gc_collect_interval = 10.0