Closed
39 commits
28970c0
refactor: simplify docstrings in workflow schemas
srijanpatel Feb 22, 2025
1d03a9e
refactor: improve docstring formatting in base node classes
srijanpatel Feb 22, 2025
c87116a
Merge remote-tracking branch 'origin/main' into refactor/base-node-si…
srijanpatel Feb 23, 2025
25731c7
refactor: enhance BaseNode config initialization with flexible kwargs…
srijanpatel Feb 24, 2025
221933b
feat: add NodeMetaclass for dynamic config initialization in BaseNode
srijanpatel Feb 24, 2025
a474348
refactor: improve node factory type handling and configuration
srijanpatel Feb 24, 2025
98455fe
refactor: simplify BaseNode configuration and remove NodeMetaclass
srijanpatel Feb 24, 2025
e5b87cd
refactor: redesign BaseNode with Pydantic model
srijanpatel Feb 25, 2025
5e1d1d2
refactor: improve BaseNode initialization and type handling
srijanpatel Feb 25, 2025
1ca1ca8
refactor: update ExampleNode with simplified BaseNode implementation
srijanpatel Feb 25, 2025
93438dd
refactor: improve BaseNode and pydantic_utils with default input/outp…
srijanpatel Feb 25, 2025
1d77fb8
refactor: generalize BaseNode input and output model types
srijanpatel Feb 25, 2025
7c54a5f
refactor: update ExampleNode to use explicit input and output models
srijanpatel Feb 25, 2025
5378232
refactor: enhance BaseNode with flexible input and output model confi…
srijanpatel Feb 25, 2025
fc5726d
feat: add input and output models to ExampleNode
srijanpatel Feb 25, 2025
28de0df
refactor: only allow model classes as input and output models
srijanpatel Feb 26, 2025
91fb42a
feat: add node_function decorator for dynamic node creation
srijanpatel Mar 1, 2025
8fcabd0
refactor: rename BaseNode to Tool and update related terminology
srijanpatel Mar 1, 2025
e7f7977
feat: add decorator-based tool and enhance example tool with message …
srijanpatel Mar 1, 2025
244a2f6
refactor: modernize SingleLLMCallNode with Tool base class and improv…
srijanpatel Mar 1, 2025
d0f1c8c
refactor: modernize LLMFunctionCallNode and RetrieverNode with Tool b…
srijanpatel Mar 1, 2025
5cf6f50
refactor: modernize CoalesceNode, MergeNode, and RouterNode with Tool…
srijanpatel Mar 1, 2025
2e7458a
refactor: modernize PythonFuncNode with Tool base class and dynamic o…
srijanpatel Mar 1, 2025
32c25b6
refactor: update NodeRegistry to use Tool base class and simplify nod…
srijanpatel Mar 1, 2025
525cd7c
refactor: modernize FirecrawlCrawl and FirecrawlScrape nodes with Too…
srijanpatel Mar 1, 2025
ecc2036
refactor: modernize SlackNotifyNode with Tool base class and simplifi…
srijanpatel Mar 1, 2025
01ef9be
refactor: update NodeFactory to use Tool base class and type casting
srijanpatel Mar 1, 2025
b181352
refactor: update WorkflowExecutor to use Tool base class and type hints
srijanpatel Mar 1, 2025
66a7f6f
refactor: modernize BaseLoopSubworkflowNode with Tool base class and …
srijanpatel Mar 1, 2025
df0d4eb
refactor: modernize BaseSubworkflowNode with Tool base class and dyna…
srijanpatel Mar 1, 2025
b6506ba
refactor: modernize InputNode and OutputNode with Tool base class and…
srijanpatel Mar 1, 2025
12ad7f9
refactor: modernize ForLoopNode with Tool base class and simplified c…
srijanpatel Mar 1, 2025
ea005d9
Merge remote-tracking branch 'origin/main' into refactor/base-node-si…
srijanpatel Mar 4, 2025
ffa544f
refactor: refactor HI node with updated base Tool class
srijanpatel Mar 4, 2025
7d3a308
Merge remote-tracking branch 'origin/main' into refactor/base-node-si…
srijanpatel Mar 12, 2025
b4bb740
fix: update user message in ChatInput validation and add context to m…
srijanpatel Mar 12, 2025
941889b
docs: improve docstring formatting and clarity in tool decorator
srijanpatel Mar 13, 2025
eb86bb0
Merge remote-tracking branch 'origin/main' into refactor/base-node-si…
srijanpatel Mar 13, 2025
29cb8aa
Merge remote-tracking branch 'origin/main' into refactor/base-node-si…
srijanpatel Mar 13, 2025
103 changes: 55 additions & 48 deletions backend/pyspur/execution/workflow_executor.py
@@ -2,25 +2,26 @@
 import traceback
 from collections import defaultdict, deque
 from datetime import datetime
-from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Set, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Set, Tuple, TypeVar, Union
 
-from pydantic import ValidationError
+from pydantic import BaseModel, ValidationError
 
 from ..models.run_model import RunModel, RunStatus
 from ..models.task_model import TaskStatus
 from ..models.user_session_model import MessageModel, SessionModel
 from ..models.workflow_model import WorkflowModel
-from ..nodes.base import BaseNode, BaseNodeOutput
+from ..nodes.base import Tool
 from ..nodes.factory import NodeFactory
 from ..nodes.logic.human_intervention import PauseException
 from ..schemas.workflow_schemas import (
     SpurType,
     WorkflowDefinitionSchema,
     WorkflowNodeSchema,
 )
-from .task_recorder import TaskRecorder
+from .task_recorder import TaskRecorder, TaskStatus
 from .workflow_execution_context import WorkflowExecutionContext
 
+# Define a type variable for the output of a node
+T = TypeVar("T", bound=BaseModel)
 if TYPE_CHECKING:
     from .task_recorder import TaskRecorder
 
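The hunk above introduces `T = TypeVar("T", bound=BaseModel)`, a type variable constrained to Pydantic models. Below is a minimal sketch of how such a bound TypeVar is typically used, so callers get back the concrete model type they pass in; the `validate_output` helper and its name are hypothetical, not part of this file:

```python
from typing import Optional, Type, TypeVar

from pydantic import BaseModel

T = TypeVar("T", bound=BaseModel)  # same pattern as in the diff


def validate_output(model_cls: Type[T], raw: Optional[dict]) -> Optional[T]:
    """Validate a raw output dict against a concrete Pydantic model.

    Because T is bound to BaseModel, the return type is the caller's
    concrete model class, not a bare BaseModel.
    """
    if raw is None:
        return None
    return model_cls.model_validate(raw)
```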
@@ -59,22 +60,23 @@ def __init__(
         self.task_recorder = None
         self.context = context
         self._node_dict: Dict[str, WorkflowNodeSchema] = {}
-        self.node_instances: Dict[str, BaseNode] = {}
+        self.node_instances: Dict[str, Tool] = {}
         self._dependencies: Dict[str, Set[str]] = {}
-        self._node_tasks: Dict[str, asyncio.Task[Optional[BaseNodeOutput]]] = {}
-        self._outputs: Dict[str, Optional[BaseNodeOutput]] = {}
+        self._node_tasks: Dict[str, asyncio.Task[Optional[BaseModel]]] = {}
+        self._initial_inputs: Dict[str, Dict[str, Any]] = {}
+        self._outputs: Dict[str, Optional[BaseModel]] = {}
         self._failed_nodes: Set[str] = set()
         self._resumed_node_ids: Set[str] = set(resumed_node_ids or [])
         self._build_node_dict()
         self._build_dependencies()
 
     @property
-    def outputs(self) -> Dict[str, Optional[BaseNodeOutput]]:
+    def outputs(self) -> Dict[str, Optional[BaseModel]]:
         """Get the current outputs of the workflow execution."""
         return self._outputs
 
     @outputs.setter
-    def outputs(self, value: Dict[str, Optional[BaseNodeOutput]]):
+    def outputs(self, value: Dict[str, Optional[BaseModel]]):
         """Set the outputs of the workflow execution."""
         self._outputs = value
 
@@ -157,9 +159,7 @@ def _get_source_handles(self) -> Dict[Tuple[str, str], str]:
                 source_handles[(link.source_id, link.target_id)] = link.source_handle
         return source_handles
 
-    def _get_async_task_for_node_execution(
-        self, node_id: str
-    ) -> asyncio.Task[Optional[BaseNodeOutput]]:
+    def _get_async_task_for_node_execution(self, node_id: str) -> asyncio.Task[Optional[BaseModel]]:
         if node_id in self._node_tasks:
             return self._node_tasks[node_id]
         # Start task for the node
@@ -295,9 +295,7 @@ def _store_message_history(
             self.context.db_session.add(assistant_msg)
             self.context.db_session.commit()
 
-    def _mark_node_as_paused(
-        self, node_id: str, pause_output: Optional[BaseNodeOutput] = None
-    ) -> None:
+    def _mark_node_as_paused(self, node_id: str, pause_output: Optional[BaseModel] = None) -> None:
         """Mark a node as paused and store its output."""
         # Store the output
         self._outputs[node_id] = pause_output
@@ -423,7 +421,7 @@ def _fix_canceled_tasks_after_pause(self, paused_node_id: str) -> None:
                 )
                 self.context.db_session.commit()
 
-    async def _execute_node(self, node_id: str) -> Optional[BaseNodeOutput]:  # noqa: C901
+    async def _execute_node(self, node_id: str) -> Optional[BaseModel]:  # noqa: C901
         node = self._node_dict[node_id]
         node_input = {}
         try:
@@ -461,7 +459,7 @@ async def _execute_node(self, node_id: str) -> Optional[BaseNodeOutput]:  # noqa: C901
             dependency_ids = self._dependencies.get(node_id, set())
 
             # Wait for dependencies
-            predecessor_outputs: List[Optional[BaseNodeOutput]] = []
+            predecessor_outputs: List[Optional[BaseModel]] = []
             if dependency_ids:
                 try:
                     predecessor_outputs = await asyncio.gather(
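The `asyncio.gather` call above awaits every upstream task before the node runs. Here is a self-contained sketch of the same fan-in pattern; `wait_for_dependencies` is an illustrative helper, the executor inlines this logic:

```python
import asyncio
from typing import Dict, List, Optional, Set

from pydantic import BaseModel


async def wait_for_dependencies(
    node_tasks: Dict[str, asyncio.Task[Optional[BaseModel]]],
    dependency_ids: Set[str],
) -> List[Optional[BaseModel]]:
    # gather awaits every upstream task concurrently; if any dependency
    # task raises, the exception propagates to the dependent node here.
    return list(await asyncio.gather(*(node_tasks[dep] for dep in dependency_ids)))
```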
@@ -606,13 +604,21 @@ async def _execute_node(self, node_id: str) -> Optional[BaseNodeOutput]:  # noqa: C901
                 self._outputs[node_id] = None
                 raise UnconnectedNodeError(f"Node {node_id} has no input")
 
-            node_instance = NodeFactory.create_node(
+            # Create the node instance
+            node_instance: Tool = NodeFactory.create_node(
                 node_name=node.title,
                 node_type_name=node.node_type,
                 config=node.config,
             )
             self.node_instances[node_id] = node_instance
 
+            # Update task recorder
+            if self.task_recorder:
+                self.task_recorder.update_task(
+                    node_id=node_id,
+                    status=TaskStatus.RUNNING,
+                    subworkflow=getattr(node_instance, "subworkflow", None),
+                )
             # Set workflow definition in node context if available
             if hasattr(node_instance, "context"):
                 node_instance.context = WorkflowExecutionContext(
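The new RUNNING-status update reads `subworkflow` through `getattr` with a `None` default, so `Tool` subclasses that never define the attribute are handled uniformly. A tiny illustration of that defensive-attribute pattern; both classes are invented for the example:

```python
class PlainTool:
    """A tool with no subworkflow attribute."""


class SubworkflowTool:
    """A tool that carries a nested workflow definition."""

    subworkflow = {"nodes": []}


for tool in (PlainTool(), SubworkflowTool()):
    # getattr with a default avoids AttributeError for PlainTool
    print(getattr(tool, "subworkflow", None))
# prints: None, then {'nodes': []}
```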
@@ -624,27 +630,27 @@
                     workflow_definition=self.workflow.model_dump(),
                 )
 
-            try:
-                output = await node_instance(node_input)
+            # Execute node
+            output: Optional[BaseModel] = await node_instance(node_input)
 
-                # Update task recorder
-                if self.task_recorder:
-                    self.task_recorder.update_task(
-                        node_id=node_id,
-                        status=TaskStatus.COMPLETED,
-                        outputs=self._serialize_output(output),
-                        end_time=datetime.now(),
-                        subworkflow=node_instance.subworkflow,
-                        subworkflow_output=node_instance.subworkflow_output,
-                    )
+            # Update task recorder
+            if self.task_recorder:
+                output_dict: Dict[str, Any] = {}
+                if hasattr(output, "model_dump"):
+                    output_dict = output.model_dump()
 
-                # Store output
-                self._outputs[node_id] = output
-                return output
-            except PauseException as e:
-                self._handle_pause_exception(node_id, e)
-                # Return None to prevent downstream execution
-                return None
+                self.task_recorder.update_task(
+                    node_id=node_id,
+                    status=TaskStatus.COMPLETED,
+                    outputs=output_dict,
+                    end_time=datetime.now(),
+                    subworkflow=getattr(node_instance, "subworkflow", None),
+                    subworkflow_output=getattr(node_instance, "subworkflow_output", None),
+                )
+        except PauseException as e:
+            self._handle_pause_exception(node_id, e)
+            # Return None to prevent downstream execution
+            return None
 
         except UpstreamFailureError as e:
             self._failed_nodes.add(node_id)
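The rewritten COMPLETED branch serializes output by duck typing: `model_dump()` is called only when present, and anything else (such as a `None` output from a paused node) falls back to an empty dict. A runnable sketch of the same guard; `ExampleOutput` and `to_record` are invented for illustration:

```python
from datetime import datetime
from typing import Any, Dict, Optional

from pydantic import BaseModel


class ExampleOutput(BaseModel):
    answer: str
    finished_at: datetime


def to_record(output: Optional[BaseModel]) -> Dict[str, Any]:
    output_dict: Dict[str, Any] = {}
    # Same hasattr guard as the diff: None has no model_dump attribute,
    # so it falls through to the empty dict.
    if hasattr(output, "model_dump"):
        output_dict = output.model_dump()
    return output_dict


print(to_record(ExampleOutput(answer="42", finished_at=datetime.now())))
print(to_record(None))  # -> {}
```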
@@ -705,7 +711,7 @@ async def _execute_node(self, node_id: str) -> Optional[BaseNodeOutput]:  # noqa: C901
             )
             raise e
 
-    def _serialize_output(self, output: Optional[BaseNodeOutput]) -> Optional[Dict[str, Any]]:
+    def _serialize_output(self, output: Optional[BaseModel]) -> Optional[Dict[str, Any]]:
         """Serialize node outputs, handling datetime objects."""
         if output is None:
             return None
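`_serialize_output`'s body is collapsed in this view; per its docstring it must handle `datetime` objects. One plausible way to do that in Pydantic v2 is `model_dump(mode="json")`, sketched below as an assumption rather than the actual implementation:

```python
from datetime import datetime
from typing import Any, Dict, Optional

from pydantic import BaseModel


def serialize_output(output: Optional[BaseModel]) -> Optional[Dict[str, Any]]:
    """Serialize node outputs, handling datetime objects."""
    if output is None:
        return None
    # mode="json" coerces datetimes (and other non-JSON types) to
    # JSON-safe values such as ISO 8601 strings.
    return output.model_dump(mode="json")


class Out(BaseModel):
    at: datetime


print(serialize_output(Out(at=datetime(2025, 3, 1))))  # {'at': '2025-03-01T00:00:00'}
```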
@@ -731,17 +737,18 @@ async def _execute_workflow(  # noqa: C901
         input: Dict[str, Any] = {},
         node_ids: List[str] = [],
         precomputed_outputs: Dict[str, Dict[str, Any] | List[Dict[str, Any]]] = {},
-    ) -> Dict[str, BaseNodeOutput]:
+    ) -> Dict[str, BaseModel]:
         # Handle precomputed outputs first
         if precomputed_outputs:
             for node_id, output in precomputed_outputs.items():
                 try:
                     if isinstance(output, dict):
-                        self._outputs[node_id] = NodeFactory.create_node(
+                        node_instance: Tool = NodeFactory.create_node(
                             node_name=self._node_dict[node_id].title,
                             node_type_name=self._node_dict[node_id].node_type,
                             config=self._node_dict[node_id].config,
-                        ).output_model.model_validate(output)
+                        )
+                        self._outputs[node_id] = node_instance.output_model.model_validate(output)
                     else:
                         # If output is a list of dicts, do not validate the output
                         # these are outputs of loop nodes,
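Precomputed dict outputs are re-hydrated through the node's declared `output_model` via `model_validate`, so a cached result that no longer matches the schema raises instead of flowing downstream. A standalone example of that validation step; `SearchOutput` is a hypothetical model:

```python
from pydantic import BaseModel, ValidationError


class SearchOutput(BaseModel):
    results: list[str]


# A cached dict that matches the schema validates cleanly...
restored = SearchOutput.model_validate({"results": ["a", "b"]})

# ...while a mismatched one raises ValidationError instead of
# silently flowing into downstream nodes.
try:
    SearchOutput.model_validate({"results": "not-a-list"})
except ValidationError as err:
    print(err.error_count(), "validation error")
```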
@@ -774,7 +781,7 @@
             )
             self._initial_inputs[input_node.id] = input
             # also update outputs for input node
-            input_node_obj = NodeFactory.create_node(
+            input_node_obj: Tool = NodeFactory.create_node(
                 node_name=input_node.title,
                 node_type_name=input_node.node_type,
                 config=input_node.config,
@@ -873,7 +880,7 @@ async def run(
         input: Dict[str, Any] = {},
         node_ids: List[str] = [],
         precomputed_outputs: Dict[str, Dict[str, Any] | List[Dict[str, Any]]] = {},
-    ) -> Dict[str, BaseNodeOutput]:
+    ) -> Dict[str, BaseModel]:
         # For chatbot workflows, extract and inject message history
         if self.workflow.spur_type == SpurType.CHATBOT:
             session_id = input.get("session_id")
@@ -921,7 +928,7 @@ async def __call__(
         input: Dict[str, Any] = {},
         node_ids: List[str] = [],
         precomputed_outputs: Dict[str, Dict[str, Any] | List[Dict[str, Any]]] = {},
-    ) -> Dict[str, BaseNodeOutput]:
+    ) -> Dict[str, BaseModel]:
         """Execute the workflow with the given input data.
 
         input: input for the input node of the workflow. Dict[<field_name>: <value>]
@@ -933,10 +940,10 @@ async def __call__(
 
     async def run_batch(
         self, input_iterator: Iterator[Dict[str, Any]], batch_size: int = 100
-    ) -> List[Dict[str, BaseNodeOutput]]:
+    ) -> List[Dict[str, BaseModel]]:
         """Run the workflow on a batch of inputs."""
-        results: List[Dict[str, BaseNodeOutput]] = []
-        batch_tasks: List[asyncio.Task[Dict[str, BaseNodeOutput]]] = []
+        results: List[Dict[str, BaseModel]] = []
+        batch_tasks: List[asyncio.Task[Dict[str, BaseModel]]] = []
         for input in input_iterator:
             batch_tasks.append(asyncio.create_task(self.run(input)))
             if len(batch_tasks) == batch_size:
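`run_batch` schedules up to `batch_size` runs as tasks, drains them with `gather`, then continues; the diff view cuts off before any flush of a trailing partial batch. A standalone sketch of the pattern, with the final flush included as an assumption:

```python
import asyncio
from typing import Any, Awaitable, Callable, Dict, Iterator, List


async def run_in_batches(
    run_one: Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]]],
    inputs: Iterator[Dict[str, Any]],
    batch_size: int = 100,
) -> List[Dict[str, Any]]:
    results: List[Dict[str, Any]] = []
    batch: List[asyncio.Task[Dict[str, Any]]] = []
    for item in inputs:
        batch.append(asyncio.create_task(run_one(item)))
        if len(batch) == batch_size:
            # Drain the full batch before scheduling more work, capping
            # concurrency at batch_size workflow runs.
            results.extend(await asyncio.gather(*batch))
            batch = []
    if batch:  # assumed final flush of a partial batch
        results.extend(await asyncio.gather(*batch))
    return results
```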