@@ -166,7 +166,7 @@ class CodeGenerator:
166
166
base_node_class_mappings (Dict): Base mappings of node classes.
167
167
"""
168
168
169
- def __init__ (self , node_class_mappings : Dict , base_node_class_mappings : Dict ):
169
+ def __init__ (self , node_class_mappings : Dict , base_node_class_mappings : Dict , prompt : Dict ):
170
170
"""Initialize the CodeGenerator with given node class mappings.
171
171
172
172
Args:
@@ -175,6 +175,7 @@ def __init__(self, node_class_mappings: Dict, base_node_class_mappings: Dict):
175
175
"""
176
176
self .node_class_mappings = node_class_mappings
177
177
self .base_node_class_mappings = base_node_class_mappings
178
+ self .prompt = prompt
178
179
179
180
def can_be_imported (self , import_name : str ):
180
181
if import_name in self .base_node_class_mappings .keys ():
@@ -195,6 +196,7 @@ def generate_workflow(self, load_order: List, queue_size: int = 1) -> str:
195
196
Returns:
196
197
str: Generated execution code as a string.
197
198
"""
199
+ include_prompt_data = False
198
200
# Create the necessary data structures to hold imports and generated code
199
201
import_statements , executed_variables , arg_inputs , special_functions_code , code = set (['NODE_CLASS_MAPPINGS' ]), {}, [], [], []
200
202
# This dictionary will store the names of the objects that we have already initialized
@@ -206,8 +208,9 @@ def generate_workflow(self, load_order: List, queue_size: int = 1) -> str:
206
208
# Generate class definition and inputs from the data
207
209
inputs , class_type = data ['inputs' ], data ['class_type' ]
208
210
211
+ input_types = self .node_class_mappings [class_type ].INPUT_TYPES ()
209
212
missing = []
210
- for i , input in enumerate (self . node_class_mappings [ class_type ]. INPUT_TYPES () .get ("required" , {}).keys ()):
213
+ for i , input in enumerate (input_types .get ("required" , {}).keys ()):
211
214
if input not in inputs :
212
215
input_var = f"{ input } { len (arg_inputs )+ 1 } "
213
216
arg_inputs .append ((input_var , f"Argument { i } , input `{ input } ` for node \\ \" { data ['_meta' ].get ('title' , class_type )} \\ \" id { idx } " ))
@@ -233,14 +236,18 @@ def generate_workflow(self, load_order: List, queue_size: int = 1) -> str:
233
236
234
237
# Get all possible parameters for class_def
235
238
class_def_params = self .get_function_parameters (getattr (class_def , class_def .FUNCTION ))
239
+ no_params = class_def_params is None
236
240
237
241
# Remove any keyword arguments from **inputs if they are not in class_def_params
238
- inputs = {key : value for key , value in inputs .items () if key in class_def_params }
242
+ inputs = {key : value for key , value in inputs .items () if no_params or key in class_def_params }
239
243
for input , input_var , arg in missing :
240
244
inputs [input ] = {"variable_name" : f"parse_arg(args." + input_var + ")" }
241
245
# Deal with hidden variables
242
- if 'unique_id' in class_def_params :
246
+ if no_params or 'unique_id' in class_def_params :
243
247
inputs ['unique_id' ] = random .randint (1 , 2 ** 64 )
248
+ if no_params or 'prompt' in class_def_params :
249
+ inputs ["prompt" ] = {"variable_name" : "PROMPT_DATA" }
250
+ include_prompt_data = True
244
251
245
252
# Create executed variable and generate code
246
253
executed_variables [idx ] = f'{ self .clean_variable_name (class_type )} _{ idx } '
@@ -261,7 +268,7 @@ def generate_workflow(self, load_order: List, queue_size: int = 1) -> str:
261
268
code .append (self .create_function_call_code (initialized_objects [class_type ], class_def .FUNCTION , executed_variables [idx ], is_special_function , ** inputs ))
262
269
263
270
# Generate final code by combining imports and code, and wrap them in a main function
264
- final_code = self .assemble_python_code (import_statements , special_functions_code , arg_inputs , code , queue_size , custom_nodes )
271
+ final_code = self .assemble_python_code (import_statements , special_functions_code , arg_inputs , code , queue_size , custom_nodes , include_prompt_data )
265
272
266
273
return final_code
267
274
@@ -304,7 +311,7 @@ def format_arg(self, key: str, value: any) -> str:
304
311
return f'{ key } ={ value ["variable_name" ]} '
305
312
return f'{ key } ={ value } '
306
313
307
- def assemble_python_code (self , import_statements : set , special_functions_code : List [str ], arg_inputs : List [Tuple [str , str ]], code : List [str ], queue_size : int , custom_nodes = False ) -> str :
314
+ def assemble_python_code (self , import_statements : set , special_functions_code : List [str ], arg_inputs : List [Tuple [str , str ]], code : List [str ], queue_size : int , custom_nodes = False , include_prompt_data = True ) -> str :
308
315
"""Generates the final code string.
309
316
310
317
Args:
@@ -349,6 +356,8 @@ def assemble_python_code(self, import_statements: set, special_functions_code: L
349
356
# Define static import statements required for the script
350
357
static_imports = ['import os' , 'import random' , 'import sys' , 'import json' , 'import argparse' , 'import contextlib' , 'from typing import Sequence, Mapping, Any, Union' ,
351
358
'import torch' ] + func_strings + argparse_code
359
+ if include_prompt_data :
360
+ static_imports .append (f'PROMPT_DATA = json.loads({ repr (json .dumps (self .prompt ))} )' )
352
361
# Check if custom nodes should be included
353
362
if custom_nodes :
354
363
static_imports .append (f'\n { inspect .getsource (import_custom_nodes )} \n ' )
@@ -459,7 +468,8 @@ def get_function_parameters(self, func: Callable) -> List:
459
468
signature = inspect .signature (func )
460
469
parameters = {name : param .default if param .default != param .empty else None
461
470
for name , param in signature .parameters .items ()}
462
- return list (parameters .keys ())
471
+ catch_all = any (param .kind == inspect .Parameter .VAR_KEYWORD for param in signature .parameters .values ())
472
+ return list (parameters .keys ()) if not catch_all else None
463
473
464
474
def update_inputs (self , inputs : Dict , executed_variables : Dict ) -> Dict :
465
475
"""Update inputs based on the executed variables.
@@ -542,7 +552,7 @@ def execute(self):
542
552
load_order = load_order_determiner .determine_load_order ()
543
553
544
554
# Step 4: Generate the workflow code
545
- code_generator = CodeGenerator (self .node_class_mappings , self .base_node_class_mappings )
555
+ code_generator = CodeGenerator (self .node_class_mappings , self .base_node_class_mappings , data )
546
556
generated_code = code_generator .generate_workflow (load_order , queue_size = self .queue_size )
547
557
548
558
# Step 5: Write the generated code to a file
0 commit comments