Add RAM Pressure cache mode #10454
base: master
Changes from all commits:
- 9e0cc78
- 708f002
- d34fb4d
- 0c95f22
- f3f526f
```diff
@@ -21,6 +21,7 @@
     NullCache,
     HierarchicalCache,
     LRUCache,
+    RAMPressureCache,
 )
 from comfy_execution.graph import (
     DynamicPrompt,
```
```diff
@@ -92,45 +93,47 @@ class CacheType(Enum):
     CLASSIC = 0
     LRU = 1
     NONE = 2
+    RAM_PRESSURE = 3


 class CacheSet:
-    def __init__(self, cache_type=None, cache_size=None):
+    def __init__(self, cache_type=None, cache_args={}):
         if cache_type == CacheType.NONE:
             self.init_null_cache()
             logging.info("Disabling intermediate node cache.")
+        elif cache_type == CacheType.RAM_PRESSURE:
+            cache_ram = cache_args.get("ram", 16.0)
+            self.init_ram_cache(cache_ram)
+            logging.info("Using RAM pressure cache.")
         elif cache_type == CacheType.LRU:
-            if cache_size is None:
-                cache_size = 0
+            cache_size = cache_args.get("lru", 0)
             self.init_lru_cache(cache_size)
             logging.info("Using LRU cache")
         else:
             self.init_classic_cache()

-        self.all = [self.outputs, self.ui, self.objects]
+        self.all = [self.outputs, self.objects]

     # Performs like the old cache -- dump data ASAP
     def init_classic_cache(self):
         self.outputs = HierarchicalCache(CacheKeySetInputSignature)
-        self.ui = HierarchicalCache(CacheKeySetInputSignature)
         self.objects = HierarchicalCache(CacheKeySetID)

     def init_lru_cache(self, cache_size):
         self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
-        self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
         self.objects = HierarchicalCache(CacheKeySetID)

+    def init_ram_cache(self, min_headroom):
+        self.outputs = RAMPressureCache(CacheKeySetInputSignature)
+        self.objects = HierarchicalCache(CacheKeySetID)
+
     def init_null_cache(self):
         self.outputs = NullCache()
-        #The UI cache is expected to be iterable at the end of each workflow
-        #so it must cache at least a full workflow. Use Heirachical
-        self.ui = HierarchicalCache(CacheKeySetInputSignature)
         self.objects = NullCache()

     def recursive_debug_dump(self):
         result = {
             "outputs": self.outputs.recursive_debug_dump(),
-            "ui": self.ui.recursive_debug_dump(),
         }
         return result

```
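For readers skimming the hunk, here is a minimal sketch of how the new constructor contract is meant to be used. The `"ram"` and `"lru"` keys come from the diff above; the concrete values and the import path are illustrative assumptions, not part of the PR.

```python
# Sketch only: exercising CacheSet with the new cache_args dict (values are examples).
from execution import CacheSet, CacheType  # import path assumed from the file being patched

# RAM pressure mode: try to keep roughly 8 GiB of system RAM free (the default is 16.0).
ram_caches = CacheSet(cache_type=CacheType.RAM_PRESSURE, cache_args={"ram": 8.0})

# LRU mode now reads its size from the same dict instead of a dedicated cache_size argument.
lru_caches = CacheSet(cache_type=CacheType.LRU, cache_args={"lru": 100})
```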
```diff
@@ -393,20 +396,23 @@ def format_value(x):
     else:
         return str(x)

-async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes):
+async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs):
     unique_id = current_item
     real_node_id = dynprompt.get_real_node_id(unique_id)
     display_node_id = dynprompt.get_display_node_id(unique_id)
     parent_node_id = dynprompt.get_parent_node_id(unique_id)
     inputs = dynprompt.get_node(unique_id)['inputs']
     class_type = dynprompt.get_node(unique_id)['class_type']
     class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
-    if caches.outputs.get(unique_id) is not None:
+    cached = caches.outputs.get(unique_id)
+    if cached is not None:
         if server.client_id is not None:
-            cached_output = caches.ui.get(unique_id) or {}
-            server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id)
+            cached_ui = cached[1] or {}
+            server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_ui.get("output",None), "prompt_id": prompt_id }, server.client_id)
+        if cached[1] is not None:
+            ui_outputs[unique_id] = cached[1]
```
I don't think we can safely replace the UI cache with the outputs of the normal cache. While in many cases the UI cache will contain the same data as the normal output cache, in many other cases it won't. The simplest example is use of the

The easiest way to check whether the UI cache is functioning is:

I might have misled you in the git writeup: it's not replacing the UI cache with the output cache at the content level. The actual cache contents for outputs and UI remain completely separate, but the "output" cache data structure now stores both pieces as a two-element tuple. This is possible because the resident keys of the output cache are always a superset of the keys of the UI cache. This is the key change where the outputs cache is populated with the two-element tuple:

The getters use element [1] for the UI and element [0] for the outputs. I did a quick test of your scenario in LRU caching mode. It works as described.

Screencast.from.2025-10-29.20-12-33.webm

Ahh, I see, I totally missed that. This all seems functionally correct to me 👍 What do you think of using a `NamedTuple`?

I missed it too, +1 on the NamedTuple
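A rough sketch of the `NamedTuple` idea discussed above (the class and field names are hypothetical, not something this PR adds):

```python
from typing import Any, NamedTuple, Optional

class CachedExecution(NamedTuple):
    """Hypothetical named form of the two-element tuple stored in the outputs cache."""
    outputs: Any               # what the diff reads as cached[0] / element [0]
    ui: Optional[dict] = None  # what the diff reads as cached[1] / element [1]

# Call sites could then write entry.outputs / entry.ui instead of indexing the tuple.
```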
```diff
         get_progress_state().finish_progress(unique_id)
-        execution_list.cache_update(unique_id, caches.outputs.get(unique_id))
+        execution_list.cache_update(unique_id, cached)
         return (ExecutionResult.SUCCESS, None, None)

     input_data_all = None
```
```diff
@@ -506,15 +512,15 @@ async def await_completion():
             asyncio.create_task(await_completion())
             return (ExecutionResult.PENDING, None, None)
         if len(output_ui) > 0:
-            caches.ui.set(unique_id, {
+            ui_outputs[unique_id] = {
                 "meta": {
                     "node_id": unique_id,
                     "display_node": display_node_id,
                     "parent_node": parent_node_id,
                     "real_node_id": real_node_id,
                 },
                 "output": output_ui
-            })
+            }
             if server.client_id is not None:
                 server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
         if has_subgraph:
```
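For intuition, an entry recorded this way might look roughly like the following. The `"meta"`/`"output"` structure is taken from the hunk above; the node id and the image payload are made-up example values.

```python
# Illustrative ui_outputs entry for a node that produced one image (example values only).
ui_outputs["12"] = {
    "meta": {
        "node_id": "12",
        "display_node": "12",
        "parent_node": None,
        "real_node_id": "12",
    },
    "output": {"images": [{"filename": "ComfyUI_00001_.png", "subfolder": "", "type": "output"}]},
}
```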
```diff
@@ -557,8 +563,8 @@ async def await_completion():
             pending_subgraph_results[unique_id] = cached_outputs
             return (ExecutionResult.PENDING, None, None)

-        caches.outputs.set(unique_id, output_data)
-        execution_list.cache_update(unique_id, output_data)
+        execution_list.cache_update(unique_id, (output_data, ui_outputs.get(unique_id)))
+        caches.outputs.set(unique_id, (output_data, ui_outputs.get(unique_id)))
     except comfy.model_management.InterruptProcessingException as iex:
         logging.info("Processing interrupted")

```
```diff
@@ -603,14 +609,14 @@ async def await_completion():
     return (ExecutionResult.SUCCESS, None, None)


 class PromptExecutor:
-    def __init__(self, server, cache_type=False, cache_size=None):
-        self.cache_size = cache_size
+    def __init__(self, server, cache_type=False, cache_args=None):
+        self.cache_args = cache_args
         self.cache_type = cache_type
         self.server = server
         self.reset()

     def reset(self):
-        self.caches = CacheSet(cache_type=self.cache_type, cache_size=self.cache_size)
+        self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)
         self.status_messages = []
         self.success = True
```
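A sketch of what a caller would now pass (the values are illustrative, and the CLI wiring that builds this dict is not part of this hunk):

```python
# Sketch only: the executor now takes a cache_args dict instead of a single cache_size.
# `server` stands in for the running server instance that ComfyUI normally passes here.
executor = PromptExecutor(server, cache_type=CacheType.RAM_PRESSURE, cache_args={"ram": 16.0})
# reset() rebuilds self.caches from the stored cache_type/cache_args, so the same
# cache configuration is reused whenever the executor is reset.
```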
```diff
@@ -685,6 +691,7 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=
                          broadcast=False)
         pending_subgraph_results = {}
         pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
+        ui_node_outputs = {}
         executed = set()
         execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
         current_outputs = self.caches.outputs.all_node_ids()
```
```diff
@@ -698,7 +705,7 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=
                     break

             assert node_id is not None, "Node ID should not be None at this point"
-            result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes)
+            result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
             self.success = result != ExecutionResult.FAILURE
             if result == ExecutionResult.FAILURE:
                 self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
```
```diff
@@ -707,18 +714,16 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=
                     execution_list.unstage_node_execution()
                 else: # result == ExecutionResult.SUCCESS:
                     execution_list.complete_node_execution()
+                    self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
             else:
                 # Only execute when the while-loop ends without break
                 self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)

             ui_outputs = {}
             meta_outputs = {}
-            all_node_ids = self.caches.ui.all_node_ids()
-            for node_id in all_node_ids:
-                ui_info = self.caches.ui.get(node_id)
-                if ui_info is not None:
-                    ui_outputs[node_id] = ui_info["output"]
-                    meta_outputs[node_id] = ui_info["meta"]
+            for node_id, ui_info in ui_node_outputs.items():
+                ui_outputs[node_id] = ui_info["output"]
+                meta_outputs[node_id] = ui_info["meta"]
             self.history_result = {
                 "outputs": ui_outputs,
                 "meta": meta_outputs,
```
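The `poll(ram_headroom=...)` call above is the per-node pressure check. The PR's actual `RAMPressureCache` implementation is not shown in this diff; the snippet below is only a hedged illustration of what "evict until N GiB of RAM are free" can look like, using `psutil` and an LRU-ordered dict.

```python
# Hypothetical illustration of a RAM-headroom eviction pass; not the PR's actual code.
from collections import OrderedDict
import psutil

def poll_ram_pressure(entries: "OrderedDict[str, object]", ram_headroom_gib: float) -> None:
    """Evict least-recently-used entries until available RAM exceeds the headroom target."""
    gib = 1024 ** 3
    while entries and psutil.virtual_memory().available < ram_headroom_gib * gib:
        entries.popitem(last=False)  # drop the oldest (least recently used) entry
```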
Why is this now returning only `value[0]` rather than all results of `.get`?