|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +from aiohttp import web |
| 4 | + |
| 5 | +from typing import TYPE_CHECKING, TypedDict |
| 6 | +if TYPE_CHECKING: |
| 7 | + from comfy_api.latest._io_public import NodeReplace |
| 8 | + |
| 9 | +from comfy_execution.graph_utils import is_link |
| 10 | +import nodes |
| 11 | + |
| 12 | +class NodeStruct(TypedDict): |
| 13 | + inputs: dict[str, str | int | float | bool | tuple[str, int]] |
| 14 | + class_type: str |
| 15 | + _meta: dict[str, str] |
| 16 | + |
| 17 | +def copy_node_struct(node_struct: NodeStruct, empty_inputs: bool = False) -> NodeStruct: |
| 18 | + new_node_struct = node_struct.copy() |
| 19 | + if empty_inputs: |
| 20 | + new_node_struct["inputs"] = {} |
| 21 | + else: |
| 22 | + new_node_struct["inputs"] = node_struct["inputs"].copy() |
| 23 | + new_node_struct["_meta"] = node_struct["_meta"].copy() |
| 24 | + return new_node_struct |
| 25 | + |
| 26 | + |
| 27 | +class NodeReplaceManager: |
| 28 | + """Manages node replacement registrations.""" |
| 29 | + |
| 30 | + def __init__(self): |
| 31 | + self._replacements: dict[str, list[NodeReplace]] = {} |
| 32 | + |
| 33 | + def register(self, node_replace: NodeReplace): |
| 34 | + """Register a node replacement mapping.""" |
| 35 | + self._replacements.setdefault(node_replace.old_node_id, []).append(node_replace) |
| 36 | + |
| 37 | + def get_replacement(self, old_node_id: str) -> list[NodeReplace] | None: |
| 38 | + """Get replacements for an old node ID.""" |
| 39 | + return self._replacements.get(old_node_id) |
| 40 | + |
| 41 | + def has_replacement(self, old_node_id: str) -> bool: |
| 42 | + """Check if a replacement exists for an old node ID.""" |
| 43 | + return old_node_id in self._replacements |
| 44 | + |
| 45 | + def apply_replacements(self, prompt: dict[str, NodeStruct]): |
| 46 | + connections: dict[str, list[tuple[str, str, int]]] = {} |
| 47 | + need_replacement: set[str] = set() |
| 48 | + for node_number, node_struct in prompt.items(): |
| 49 | + class_type = node_struct["class_type"] |
| 50 | + # need replacement if not in NODE_CLASS_MAPPINGS and has replacement |
| 51 | + if class_type not in nodes.NODE_CLASS_MAPPINGS.keys() and self.has_replacement(class_type): |
| 52 | + need_replacement.add(node_number) |
| 53 | + # keep track of connections |
| 54 | + for input_id, input_value in node_struct["inputs"].items(): |
| 55 | + if is_link(input_value): |
| 56 | + conn_number = input_value[0] |
| 57 | + connections.setdefault(conn_number, []).append((node_number, input_id, input_value[1])) |
| 58 | + for node_number in need_replacement: |
| 59 | + node_struct = prompt[node_number] |
| 60 | + class_type = node_struct["class_type"] |
| 61 | + replacements = self.get_replacement(class_type) |
| 62 | + if replacements is None: |
| 63 | + continue |
| 64 | + # just use the first replacement |
| 65 | + replacement = replacements[0] |
| 66 | + new_node_id = replacement.new_node_id |
| 67 | + # if replacement is not a valid node, skip trying to replace it as will only cause confusion |
| 68 | + if new_node_id not in nodes.NODE_CLASS_MAPPINGS.keys(): |
| 69 | + continue |
| 70 | + # first, replace node id (class_type) |
| 71 | + new_node_struct = copy_node_struct(node_struct, empty_inputs=True) |
| 72 | + new_node_struct["class_type"] = new_node_id |
| 73 | + # TODO: consider replacing display_name in _meta as well for error reporting purposes; would need to query node schema |
| 74 | + # second, replace inputs |
| 75 | + if replacement.input_mapping is not None: |
| 76 | + for input_map in replacement.input_mapping: |
| 77 | + if "set_value" in input_map: |
| 78 | + new_node_struct["inputs"][input_map["new_id"]] = input_map["set_value"] |
| 79 | + elif "old_id" in input_map: |
| 80 | + new_node_struct["inputs"][input_map["new_id"]] = node_struct["inputs"][input_map["old_id"]] |
| 81 | + # finalize input replacement |
| 82 | + prompt[node_number] = new_node_struct |
| 83 | + # third, replace outputs |
| 84 | + if replacement.output_mapping is not None: |
| 85 | + # re-mapping outputs requires changing the input values of nodes that receive connections from this one |
| 86 | + if node_number in connections: |
| 87 | + for conns in connections[node_number]: |
| 88 | + conn_node_number, conn_input_id, old_output_idx = conns |
| 89 | + for output_map in replacement.output_mapping: |
| 90 | + if output_map["old_idx"] == old_output_idx: |
| 91 | + new_output_idx = output_map["new_idx"] |
| 92 | + previous_input = prompt[conn_node_number]["inputs"][conn_input_id] |
| 93 | + previous_input[1] = new_output_idx |
| 94 | + |
| 95 | + def as_dict(self): |
| 96 | + """Serialize all replacements to dict.""" |
| 97 | + return { |
| 98 | + k: [v.as_dict() for v in v_list] |
| 99 | + for k, v_list in self._replacements.items() |
| 100 | + } |
| 101 | + |
| 102 | + def add_routes(self, routes): |
| 103 | + @routes.get("/node_replacements") |
| 104 | + async def get_node_replacements(request): |
| 105 | + return web.json_response(self.as_dict()) |
0 commit comments