Skip to content

Commit a548adf

Browse files
committed
Add prompt data and handle kwargs
1 parent fa4d363 commit a548adf

File tree

1 file changed

+18
-8
lines changed

1 file changed

+18
-8
lines changed

comfyui_to_python.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ class CodeGenerator:
166166
base_node_class_mappings (Dict): Base mappings of node classes.
167167
"""
168168

169-
def __init__(self, node_class_mappings: Dict, base_node_class_mappings: Dict):
169+
def __init__(self, node_class_mappings: Dict, base_node_class_mappings: Dict, prompt: Dict):
170170
"""Initialize the CodeGenerator with given node class mappings.
171171
172172
Args:
@@ -175,6 +175,7 @@ def __init__(self, node_class_mappings: Dict, base_node_class_mappings: Dict):
175175
"""
176176
self.node_class_mappings = node_class_mappings
177177
self.base_node_class_mappings = base_node_class_mappings
178+
self.prompt = prompt
178179

179180
def can_be_imported(self, import_name: str):
180181
if import_name in self.base_node_class_mappings.keys():
@@ -195,6 +196,7 @@ def generate_workflow(self, load_order: List, queue_size: int = 1) -> str:
195196
Returns:
196197
str: Generated execution code as a string.
197198
"""
199+
include_prompt_data = False
198200
# Create the necessary data structures to hold imports and generated code
199201
import_statements, executed_variables, arg_inputs, special_functions_code, code = set(['NODE_CLASS_MAPPINGS']), {}, [], [], []
200202
# This dictionary will store the names of the objects that we have already initialized
@@ -206,8 +208,9 @@ def generate_workflow(self, load_order: List, queue_size: int = 1) -> str:
206208
# Generate class definition and inputs from the data
207209
inputs, class_type = data['inputs'], data['class_type']
208210

211+
input_types = self.node_class_mappings[class_type].INPUT_TYPES()
209212
missing = []
210-
for i, input in enumerate(self.node_class_mappings[class_type].INPUT_TYPES().get("required", {}).keys()):
213+
for i, input in enumerate(input_types.get("required", {}).keys()):
211214
if input not in inputs:
212215
input_var = f"{input}{len(arg_inputs)+1}"
213216
arg_inputs.append((input_var, f"Argument {i}, input `{input}` for node \\\"{data['_meta'].get('title', class_type)}\\\" id {idx}"))
@@ -233,14 +236,18 @@ def generate_workflow(self, load_order: List, queue_size: int = 1) -> str:
233236

234237
# Get all possible parameters for class_def
235238
class_def_params = self.get_function_parameters(getattr(class_def, class_def.FUNCTION))
239+
no_params = class_def_params is None
236240

237241
# Remove any keyword arguments from **inputs if they are not in class_def_params
238-
inputs = {key: value for key, value in inputs.items() if key in class_def_params}
242+
inputs = {key: value for key, value in inputs.items() if no_params or key in class_def_params}
239243
for input, input_var, arg in missing:
240244
inputs[input] = {"variable_name": f"parse_arg(args." + input_var + ")"}
241245
# Deal with hidden variables
242-
if 'unique_id' in class_def_params:
246+
if no_params or 'unique_id' in class_def_params:
243247
inputs['unique_id'] = random.randint(1, 2**64)
248+
if no_params or 'prompt' in class_def_params:
249+
inputs["prompt"] = {"variable_name": "PROMPT_DATA"}
250+
include_prompt_data = True
244251

245252
# Create executed variable and generate code
246253
executed_variables[idx] = f'{self.clean_variable_name(class_type)}_{idx}'
@@ -261,7 +268,7 @@ def generate_workflow(self, load_order: List, queue_size: int = 1) -> str:
261268
code.append(self.create_function_call_code(initialized_objects[class_type], class_def.FUNCTION, executed_variables[idx], is_special_function, **inputs))
262269

263270
# Generate final code by combining imports and code, and wrap them in a main function
264-
final_code = self.assemble_python_code(import_statements, special_functions_code, arg_inputs, code, queue_size, custom_nodes)
271+
final_code = self.assemble_python_code(import_statements, special_functions_code, arg_inputs, code, queue_size, custom_nodes, include_prompt_data)
265272

266273
return final_code
267274

@@ -304,7 +311,7 @@ def format_arg(self, key: str, value: any) -> str:
304311
return f'{key}={value["variable_name"]}'
305312
return f'{key}={value}'
306313

307-
def assemble_python_code(self, import_statements: set, special_functions_code: List[str], arg_inputs: List[Tuple[str, str]], code: List[str], queue_size: int, custom_nodes=False) -> str:
314+
def assemble_python_code(self, import_statements: set, special_functions_code: List[str], arg_inputs: List[Tuple[str, str]], code: List[str], queue_size: int, custom_nodes=False, include_prompt_data=True) -> str:
308315
"""Generates the final code string.
309316
310317
Args:
@@ -349,6 +356,8 @@ def assemble_python_code(self, import_statements: set, special_functions_code: L
349356
# Define static import statements required for the script
350357
static_imports = ['import os', 'import random', 'import sys', 'import json', 'import argparse', 'import contextlib', 'from typing import Sequence, Mapping, Any, Union',
351358
'import torch'] + func_strings + argparse_code
359+
if include_prompt_data:
360+
static_imports.append(f'PROMPT_DATA = json.loads({repr(json.dumps(self.prompt))})')
352361
# Check if custom nodes should be included
353362
if custom_nodes:
354363
static_imports.append(f'\n{inspect.getsource(import_custom_nodes)}\n')
@@ -459,7 +468,8 @@ def get_function_parameters(self, func: Callable) -> List:
459468
signature = inspect.signature(func)
460469
parameters = {name: param.default if param.default != param.empty else None
461470
for name, param in signature.parameters.items()}
462-
return list(parameters.keys())
471+
catch_all = any(param.kind == inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values())
472+
return list(parameters.keys()) if not catch_all else None
463473

464474
def update_inputs(self, inputs: Dict, executed_variables: Dict) -> Dict:
465475
"""Update inputs based on the executed variables.
@@ -542,7 +552,7 @@ def execute(self):
542552
load_order = load_order_determiner.determine_load_order()
543553

544554
# Step 4: Generate the workflow code
545-
code_generator = CodeGenerator(self.node_class_mappings, self.base_node_class_mappings)
555+
code_generator = CodeGenerator(self.node_class_mappings, self.base_node_class_mappings, data)
546556
generated_code = code_generator.generate_workflow(load_order, queue_size=self.queue_size)
547557

548558
# Step 5: Write the generated code to a file

0 commit comments

Comments
 (0)