Skip to content

Commit 024a9f5

Browse files
committed
fix so that run_blocks can work with inputs in the state
1 parent 005195c commit 024a9f5

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

src/diffusers/pipelines/custom_pipeline_builder.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2331,17 +2331,17 @@ def run_blocks(self, state: PipelineState = None, **kwargs):
23312331
if name in input_params:
23322332
state.add_intermediate(name, input_params.pop(name))
23332333

2334-
# Add inputs to state, using defaults if not provided
2334+
# Add inputs to state, using defaults if not provided in the kwargs or the state
2335+
# if same input already in the state, will override it if provided in the kwargs
23352336
for name, default in default_params.items():
23362337
if name in input_params:
23372338
state.add_input(name, input_params.pop(name))
2338-
else:
2339+
elif name not in state.inputs:
23392340
state.add_input(name, default)
23402341

23412342
# Warn about unexpected inputs
23422343
if len(input_params) > 0:
23432344
logger.warning(f"Unexpected input '{input_params.keys()}' provided. This input will be ignored.")
2344-
23452345
# Run the pipeline
23462346
with torch.no_grad():
23472347
for block in self.pipeline_blocks:

0 commit comments

Comments
 (0)