Skip to content

Commit 596ed68

Browse files
authored
Node Replacement API (Comfy-Org#12014)
1 parent ce4a1ab commit 596ed68

File tree

8 files changed

+291
-2
lines changed

8 files changed

+291
-2
lines changed

app/node_replace_manager.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
from __future__ import annotations
2+
3+
from aiohttp import web
4+
5+
from typing import TYPE_CHECKING, TypedDict
6+
if TYPE_CHECKING:
7+
from comfy_api.latest._io_public import NodeReplace
8+
9+
from comfy_execution.graph_utils import is_link
10+
import nodes
11+
12+
class NodeStruct(TypedDict):
13+
inputs: dict[str, str | int | float | bool | tuple[str, int]]
14+
class_type: str
15+
_meta: dict[str, str]
16+
17+
def copy_node_struct(node_struct: NodeStruct, empty_inputs: bool = False) -> NodeStruct:
18+
new_node_struct = node_struct.copy()
19+
if empty_inputs:
20+
new_node_struct["inputs"] = {}
21+
else:
22+
new_node_struct["inputs"] = node_struct["inputs"].copy()
23+
new_node_struct["_meta"] = node_struct["_meta"].copy()
24+
return new_node_struct
25+
26+
27+
class NodeReplaceManager:
28+
"""Manages node replacement registrations."""
29+
30+
def __init__(self):
31+
self._replacements: dict[str, list[NodeReplace]] = {}
32+
33+
def register(self, node_replace: NodeReplace):
34+
"""Register a node replacement mapping."""
35+
self._replacements.setdefault(node_replace.old_node_id, []).append(node_replace)
36+
37+
def get_replacement(self, old_node_id: str) -> list[NodeReplace] | None:
38+
"""Get replacements for an old node ID."""
39+
return self._replacements.get(old_node_id)
40+
41+
def has_replacement(self, old_node_id: str) -> bool:
42+
"""Check if a replacement exists for an old node ID."""
43+
return old_node_id in self._replacements
44+
45+
def apply_replacements(self, prompt: dict[str, NodeStruct]):
46+
connections: dict[str, list[tuple[str, str, int]]] = {}
47+
need_replacement: set[str] = set()
48+
for node_number, node_struct in prompt.items():
49+
class_type = node_struct["class_type"]
50+
# need replacement if not in NODE_CLASS_MAPPINGS and has replacement
51+
if class_type not in nodes.NODE_CLASS_MAPPINGS.keys() and self.has_replacement(class_type):
52+
need_replacement.add(node_number)
53+
# keep track of connections
54+
for input_id, input_value in node_struct["inputs"].items():
55+
if is_link(input_value):
56+
conn_number = input_value[0]
57+
connections.setdefault(conn_number, []).append((node_number, input_id, input_value[1]))
58+
for node_number in need_replacement:
59+
node_struct = prompt[node_number]
60+
class_type = node_struct["class_type"]
61+
replacements = self.get_replacement(class_type)
62+
if replacements is None:
63+
continue
64+
# just use the first replacement
65+
replacement = replacements[0]
66+
new_node_id = replacement.new_node_id
67+
# if replacement is not a valid node, skip trying to replace it as will only cause confusion
68+
if new_node_id not in nodes.NODE_CLASS_MAPPINGS.keys():
69+
continue
70+
# first, replace node id (class_type)
71+
new_node_struct = copy_node_struct(node_struct, empty_inputs=True)
72+
new_node_struct["class_type"] = new_node_id
73+
# TODO: consider replacing display_name in _meta as well for error reporting purposes; would need to query node schema
74+
# second, replace inputs
75+
if replacement.input_mapping is not None:
76+
for input_map in replacement.input_mapping:
77+
if "set_value" in input_map:
78+
new_node_struct["inputs"][input_map["new_id"]] = input_map["set_value"]
79+
elif "old_id" in input_map:
80+
new_node_struct["inputs"][input_map["new_id"]] = node_struct["inputs"][input_map["old_id"]]
81+
# finalize input replacement
82+
prompt[node_number] = new_node_struct
83+
# third, replace outputs
84+
if replacement.output_mapping is not None:
85+
# re-mapping outputs requires changing the input values of nodes that receive connections from this one
86+
if node_number in connections:
87+
for conns in connections[node_number]:
88+
conn_node_number, conn_input_id, old_output_idx = conns
89+
for output_map in replacement.output_mapping:
90+
if output_map["old_idx"] == old_output_idx:
91+
new_output_idx = output_map["new_idx"]
92+
previous_input = prompt[conn_node_number]["inputs"][conn_input_id]
93+
previous_input[1] = new_output_idx
94+
95+
def as_dict(self):
96+
"""Serialize all replacements to dict."""
97+
return {
98+
k: [v.as_dict() for v in v_list]
99+
for k, v_list in self._replacements.items()
100+
}
101+
102+
def add_routes(self, routes):
103+
@routes.get("/node_replacements")
104+
async def get_node_replacements(request):
105+
return web.json_response(self.as_dict())

comfy_api/feature_flags.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
"supports_preview_metadata": True,
1515
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
1616
"extension": {"manager": {"supports_v4": True}},
17+
"node_replacements": True,
1718
}
1819

1920

comfy_api/latest/__init__.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,17 @@ class ComfyAPI_latest(ComfyAPIBase):
2121
VERSION = "latest"
2222
STABLE = False
2323

24+
def __init__(self):
25+
super().__init__()
26+
self.node_replacement = self.NodeReplacement()
27+
self.execution = self.Execution()
28+
29+
class NodeReplacement(ProxiedSingleton):
30+
async def register(self, node_replace: io.NodeReplace) -> None:
31+
"""Register a node replacement mapping."""
32+
from server import PromptServer
33+
PromptServer.instance.node_replace_manager.register(node_replace)
34+
2435
class Execution(ProxiedSingleton):
2536
async def set_progress(
2637
self,
@@ -73,8 +84,6 @@ async def set_progress(
7384
image=to_display,
7485
)
7586

76-
execution: Execution
77-
7887
class ComfyExtension(ABC):
7988
async def on_load(self) -> None:
8089
"""

comfy_api/latest/_io.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2030,6 +2030,68 @@ def as_dict(self) -> dict:
20302030
...
20312031

20322032

2033+
class InputMapOldId(TypedDict):
2034+
"""Map an old node input to a new node input by ID."""
2035+
new_id: str
2036+
old_id: str
2037+
2038+
class InputMapSetValue(TypedDict):
2039+
"""Set a specific value for a new node input."""
2040+
new_id: str
2041+
set_value: Any
2042+
2043+
InputMap = InputMapOldId | InputMapSetValue
2044+
"""
2045+
Input mapping for node replacement. Type is inferred by dictionary keys:
2046+
- {"new_id": str, "old_id": str} - maps old input to new input
2047+
- {"new_id": str, "set_value": Any} - sets a specific value for new input
2048+
"""
2049+
2050+
class OutputMap(TypedDict):
2051+
"""Map outputs of node replacement via indexes."""
2052+
new_idx: int
2053+
old_idx: int
2054+
2055+
class NodeReplace:
2056+
"""
2057+
Defines a possible node replacement, mapping inputs and outputs of the old node to the new node.
2058+
2059+
Also supports assigning specific values to the input widgets of the new node.
2060+
2061+
Args:
2062+
new_node_id: The class name of the new replacement node.
2063+
old_node_id: The class name of the deprecated node.
2064+
old_widget_ids: Ordered list of input IDs for widgets that may not have an input slot
2065+
connected. The workflow JSON stores widget values by their relative position index,
2066+
not by ID. This list maps those positional indexes to input IDs, enabling the
2067+
replacement system to correctly identify widget values during node migration.
2068+
input_mapping: List of input mappings from old node to new node.
2069+
output_mapping: List of output mappings from old node to new node.
2070+
"""
2071+
def __init__(self,
2072+
new_node_id: str,
2073+
old_node_id: str,
2074+
old_widget_ids: list[str] | None=None,
2075+
input_mapping: list[InputMap] | None=None,
2076+
output_mapping: list[OutputMap] | None=None,
2077+
):
2078+
self.new_node_id = new_node_id
2079+
self.old_node_id = old_node_id
2080+
self.old_widget_ids = old_widget_ids
2081+
self.input_mapping = input_mapping
2082+
self.output_mapping = output_mapping
2083+
2084+
def as_dict(self):
2085+
"""Create serializable representation of the node replacement."""
2086+
return {
2087+
"new_node_id": self.new_node_id,
2088+
"old_node_id": self.old_node_id,
2089+
"old_widget_ids": self.old_widget_ids,
2090+
"input_mapping": list(self.input_mapping) if self.input_mapping else None,
2091+
"output_mapping": list(self.output_mapping) if self.output_mapping else None,
2092+
}
2093+
2094+
20332095
__all__ = [
20342096
"FolderType",
20352097
"UploadType",
@@ -2121,4 +2183,5 @@ def as_dict(self) -> dict:
21212183
"ImageCompare",
21222184
"PriceBadgeDepends",
21232185
"PriceBadge",
2186+
"NodeReplace",
21242187
]

comfy_extras/nodes_post_processing.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -655,6 +655,7 @@ def execute(cls, inputs: io.Autogrow.Type) -> io.NodeOutput:
655655
batched = batch_masks(values)
656656
return io.NodeOutput(batched)
657657

658+
658659
class PostProcessingExtension(ComfyExtension):
659660
@override
660661
async def get_node_list(self) -> list[type[io.ComfyNode]]:

comfy_extras/nodes_replacements.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
from comfy_api.latest import ComfyExtension, io, ComfyAPI
2+
3+
api = ComfyAPI()
4+
5+
6+
async def register_replacements():
7+
"""Register all built-in node replacements."""
8+
await register_replacements_longeredge()
9+
await register_replacements_batchimages()
10+
await register_replacements_upscaleimage()
11+
await register_replacements_controlnet()
12+
await register_replacements_load3d()
13+
await register_replacements_preview3d()
14+
await register_replacements_svdimg2vid()
15+
await register_replacements_conditioningavg()
16+
17+
async def register_replacements_longeredge():
18+
# No dynamic inputs here
19+
await api.node_replacement.register(io.NodeReplace(
20+
new_node_id="ImageScaleToMaxDimension",
21+
old_node_id="ResizeImagesByLongerEdge",
22+
old_widget_ids=["longer_edge"],
23+
input_mapping=[
24+
{"new_id": "image", "old_id": "images"},
25+
{"new_id": "largest_size", "old_id": "longer_edge"},
26+
{"new_id": "upscale_method", "set_value": "lanczos"},
27+
],
28+
# just to test the frontend output_mapping code, does nothing really here
29+
output_mapping=[{"new_idx": 0, "old_idx": 0}],
30+
))
31+
32+
async def register_replacements_batchimages():
33+
# BatchImages node uses Autogrow
34+
await api.node_replacement.register(io.NodeReplace(
35+
new_node_id="BatchImagesNode",
36+
old_node_id="ImageBatch",
37+
input_mapping=[
38+
{"new_id": "images.image0", "old_id": "image1"},
39+
{"new_id": "images.image1", "old_id": "image2"},
40+
],
41+
))
42+
43+
async def register_replacements_upscaleimage():
44+
# ResizeImageMaskNode uses DynamicCombo
45+
await api.node_replacement.register(io.NodeReplace(
46+
new_node_id="ResizeImageMaskNode",
47+
old_node_id="ImageScaleBy",
48+
old_widget_ids=["upscale_method", "scale_by"],
49+
input_mapping=[
50+
{"new_id": "input", "old_id": "image"},
51+
{"new_id": "resize_type", "set_value": "scale by multiplier"},
52+
{"new_id": "resize_type.multiplier", "old_id": "scale_by"},
53+
{"new_id": "scale_method", "old_id": "upscale_method"},
54+
],
55+
))
56+
57+
async def register_replacements_controlnet():
58+
# T2IAdapterLoader → ControlNetLoader
59+
await api.node_replacement.register(io.NodeReplace(
60+
new_node_id="ControlNetLoader",
61+
old_node_id="T2IAdapterLoader",
62+
input_mapping=[
63+
{"new_id": "control_net_name", "old_id": "t2i_adapter_name"},
64+
],
65+
))
66+
67+
async def register_replacements_load3d():
68+
# Load3DAnimation merged into Load3D
69+
await api.node_replacement.register(io.NodeReplace(
70+
new_node_id="Load3D",
71+
old_node_id="Load3DAnimation",
72+
))
73+
74+
async def register_replacements_preview3d():
75+
# Preview3DAnimation merged into Preview3D
76+
await api.node_replacement.register(io.NodeReplace(
77+
new_node_id="Preview3D",
78+
old_node_id="Preview3DAnimation",
79+
))
80+
81+
async def register_replacements_svdimg2vid():
82+
# Typo fix: SDV → SVD
83+
await api.node_replacement.register(io.NodeReplace(
84+
new_node_id="SVD_img2vid_Conditioning",
85+
old_node_id="SDV_img2vid_Conditioning",
86+
))
87+
88+
async def register_replacements_conditioningavg():
89+
# Typo fix: trailing space in node name
90+
await api.node_replacement.register(io.NodeReplace(
91+
new_node_id="ConditioningAverage",
92+
old_node_id="ConditioningAverage ",
93+
))
94+
95+
class NodeReplacementsExtension(ComfyExtension):
96+
async def on_load(self) -> None:
97+
await register_replacements()
98+
99+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
100+
return []
101+
102+
async def comfy_entrypoint() -> NodeReplacementsExtension:
103+
return NodeReplacementsExtension()

nodes.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2264,6 +2264,7 @@ async def load_custom_node(module_path: str, ignore=set(), module_parent="custom
22642264
if not isinstance(extension, ComfyExtension):
22652265
logging.warning(f"comfy_entrypoint in {module_path} did not return a ComfyExtension, skipping.")
22662266
return False
2267+
await extension.on_load()
22672268
node_list = await extension.get_node_list()
22682269
if not isinstance(node_list, list):
22692270
logging.warning(f"comfy_entrypoint in {module_path} did not return a list of nodes, skipping.")
@@ -2435,6 +2436,7 @@ async def init_builtin_extra_nodes():
24352436
"nodes_lora_debug.py",
24362437
"nodes_color.py",
24372438
"nodes_toolkit.py",
2439+
"nodes_replacements.py",
24382440
]
24392441

24402442
import_failed = []

server.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from app.model_manager import ModelFileManager
4141
from app.custom_node_manager import CustomNodeManager
4242
from app.subgraph_manager import SubgraphManager
43+
from app.node_replace_manager import NodeReplaceManager
4344
from typing import Optional, Union
4445
from api_server.routes.internal.internal_routes import InternalRoutes
4546
from protocol import BinaryEventTypes
@@ -204,6 +205,7 @@ def __init__(self, loop):
204205
self.model_file_manager = ModelFileManager()
205206
self.custom_node_manager = CustomNodeManager()
206207
self.subgraph_manager = SubgraphManager()
208+
self.node_replace_manager = NodeReplaceManager()
207209
self.internal_routes = InternalRoutes(self)
208210
self.supports = ["custom_nodes_from_web"]
209211
self.prompt_queue = execution.PromptQueue(self)
@@ -887,6 +889,8 @@ async def post_prompt(request):
887889
if "partial_execution_targets" in json_data:
888890
partial_execution_targets = json_data["partial_execution_targets"]
889891

892+
self.node_replace_manager.apply_replacements(prompt)
893+
890894
valid = await execution.validate_prompt(prompt_id, prompt, partial_execution_targets)
891895
extra_data = {}
892896
if "extra_data" in json_data:
@@ -995,6 +999,7 @@ def add_routes(self):
995999
self.model_file_manager.add_routes(self.routes)
9961000
self.custom_node_manager.add_routes(self.routes, self.app, nodes.LOADED_MODULE_DIRS.items())
9971001
self.subgraph_manager.add_routes(self.routes, nodes.LOADED_MODULE_DIRS.items())
1002+
self.node_replace_manager.add_routes(self.routes)
9981003
self.app.add_subapp('/internal', self.internal_routes.get_app())
9991004

10001005
# Prefix every route with /api for easier matching for delegation.

0 commit comments

Comments
 (0)