Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: better graph validation errors #7614

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 39 additions & 46 deletions invokeai/app/services/shared/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ class Edge(BaseModel):
source: EdgeConnection = Field(description="The connection for the edge's from node and field")
destination: EdgeConnection = Field(description="The connection for the edge's to node and field")

def __str__(self):
return f"{self.source.node_id}.{self.source.field} -> {self.destination.node_id}.{self.destination.field}"


def get_output_field(node: BaseInvocation, field: str) -> Any:
node_type = type(node)
Expand Down Expand Up @@ -440,17 +443,19 @@ def validate_self(self) -> None:
self.get_node(edge.destination.node_id),
edge.destination.field,
):
raise InvalidEdgeError(
f"Invalid edge from {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
)
raise InvalidEdgeError(f"Edge source and target types do not match ({edge})")

# Validate all iterators & collectors
# TODO: may need to validate all iterators & collectors in subgraphs so edge connections in parent graphs will be available
for node in self.nodes.values():
if isinstance(node, IterateInvocation) and not self._is_iterator_connection_valid(node.id):
raise InvalidEdgeError(f"Invalid iterator node {node.id}")
if isinstance(node, CollectInvocation) and not self._is_collector_connection_valid(node.id):
raise InvalidEdgeError(f"Invalid collector node {node.id}")
if isinstance(node, IterateInvocation):
err = self._is_iterator_connection_valid(node.id)
if err is not None:
raise InvalidEdgeError(f"Invalid iterator node ({node.id}): {err}")
if isinstance(node, CollectInvocation):
err = self._is_collector_connection_valid(node.id)
if err is not None:
raise InvalidEdgeError(f"Invalid collector node ({node.id}): {err}")

return None

Expand Down Expand Up @@ -491,55 +496,44 @@ def _validate_edge(self, edge: Edge):
from_node = self.get_node(edge.source.node_id)
to_node = self.get_node(edge.destination.node_id)
except NodeNotFoundError:
raise InvalidEdgeError("One or both nodes don't exist: {edge.source.node_id} -> {edge.destination.node_id}")
raise InvalidEdgeError(f"One or both nodes don't exist ({edge})")

# Validate that an edge to this node+field doesn't already exist
input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field)
if len(input_edges) > 0 and not isinstance(to_node, CollectInvocation):
raise InvalidEdgeError(
f"Edge to node {edge.destination.node_id} field {edge.destination.field} already exists"
)
raise InvalidEdgeError(f"Edge already exists ({edge})")

# Validate that no cycles would be created
g = self.nx_graph_flat()
g.add_edge(edge.source.node_id, edge.destination.node_id)
if not nx.is_directed_acyclic_graph(g):
raise InvalidEdgeError(
f"Edge creates a cycle in the graph: {edge.source.node_id} -> {edge.destination.node_id}"
)
raise InvalidEdgeError(f"Edge creates a cycle in the graph ({edge})")

# Validate that the field types are compatible
if not are_connections_compatible(from_node, edge.source.field, to_node, edge.destination.field):
raise InvalidEdgeError(
f"Fields are incompatible: cannot connect {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
)
raise InvalidEdgeError(f"Field types are incompatible ({edge})")

# Validate if iterator output type matches iterator input type (if this edge results in both being set)
if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection":
if not self._is_iterator_connection_valid(edge.destination.node_id, new_input=edge.source):
raise InvalidEdgeError(
f"Iterator input type does not match iterator output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
)
err = self._is_iterator_connection_valid(edge.destination.node_id, new_input=edge.source)
if err is not None:
raise InvalidEdgeError(f"Iterator input type does not match iterator output type ({edge}): {err}")

# Validate if iterator input type matches output type (if this edge results in both being set)
if isinstance(from_node, IterateInvocation) and edge.source.field == "item":
if not self._is_iterator_connection_valid(edge.source.node_id, new_output=edge.destination):
raise InvalidEdgeError(
f"Iterator output type does not match iterator input type:, {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
)
err = self._is_iterator_connection_valid(edge.source.node_id, new_output=edge.destination)
if err is not None:
raise InvalidEdgeError(f"Iterator output type does not match iterator input type ({edge}): {err}")

# Validate if collector input type matches output type (if this edge results in both being set)
if isinstance(to_node, CollectInvocation) and edge.destination.field == "item":
if not self._is_collector_connection_valid(edge.destination.node_id, new_input=edge.source):
raise InvalidEdgeError(
f"Collector output type does not match collector input type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
)
err = self._is_collector_connection_valid(edge.destination.node_id, new_input=edge.source)
if err is not None:
raise InvalidEdgeError(f"Collector output type does not match collector input type ({edge}): {err}")

# Validate that we are not connecting collector to iterator (currently unsupported)
if isinstance(from_node, CollectInvocation) and isinstance(to_node, IterateInvocation):
raise InvalidEdgeError(
f"Cannot connect collector to iterator: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
)
raise InvalidEdgeError(f"Cannot connect collector to iterator ({edge})")

# Validate if collector output type matches input type (if this edge results in both being set) - skip if the destination field is not Any or list[Any]
if (
Expand All @@ -548,10 +542,9 @@ def _validate_edge(self, edge: Edge):
and not self._is_destination_field_list_of_Any(edge)
and not self._is_destination_field_Any(edge)
):
if not self._is_collector_connection_valid(edge.source.node_id, new_output=edge.destination):
raise InvalidEdgeError(
f"Collector input type does not match collector output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
)
err = self._is_collector_connection_valid(edge.source.node_id, new_output=edge.destination)
if err is not None:
raise InvalidEdgeError(f"Collector input type does not match collector output type ({edge}): {err}")

def has_node(self, node_id: str) -> bool:
"""Determines whether or not a node exists in the graph."""
Expand Down Expand Up @@ -634,7 +627,7 @@ def _is_iterator_connection_valid(
node_id: str,
new_input: Optional[EdgeConnection] = None,
new_output: Optional[EdgeConnection] = None,
) -> bool:
) -> str | None:
inputs = [e.source for e in self._get_input_edges(node_id, "collection")]
outputs = [e.destination for e in self._get_output_edges(node_id, "item")]

Expand All @@ -645,29 +638,29 @@ def _is_iterator_connection_valid(

# Only one input is allowed for iterators
if len(inputs) > 1:
return False
return "Iterator may only have one input edge"

# Get input and output fields (the fields linked to the iterator's input/output)
input_field = get_output_field(self.get_node(inputs[0].node_id), inputs[0].field)
output_fields = [get_input_field(self.get_node(e.node_id), e.field) for e in outputs]

# Input type must be a list
if get_origin(input_field) is not list:
return False
return "Iterator input must be a collection"

# Validate that all outputs match the input type
input_field_item_type = get_args(input_field)[0]
if not all((are_connection_types_compatible(input_field_item_type, f) for f in output_fields)):
return False
return "Iterator outputs must connect to an input with a matching type"

return True
return None

def _is_collector_connection_valid(
self,
node_id: str,
new_input: Optional[EdgeConnection] = None,
new_output: Optional[EdgeConnection] = None,
) -> bool:
) -> str | None:
inputs = [e.source for e in self._get_input_edges(node_id, "item")]
outputs = [e.destination for e in self._get_output_edges(node_id, "collection")]

Expand All @@ -692,23 +685,23 @@ def _is_collector_connection_valid(
type_tree.add_edges_from([e for e in itertools.permutations(input_field_types, 2) if issubclass(e[1], e[0])])
type_degrees = type_tree.in_degree(type_tree.nodes)
if sum((t[1] == 0 for t in type_degrees)) != 1: # type: ignore
return False # There is more than one root type
return "Collector input collection items must be of a single type"

# Get the input root type
input_root_type = next(t[0] for t in type_degrees if t[1] == 0) # type: ignore

# Verify that all outputs are lists
if not all(is_list_or_contains_list(f) for f in output_fields):
return False
return "Collector output must connect to a collection input"

# Verify that all outputs match the input type (are a base class or the same class)
if not all(
is_union_subtype(input_root_type, get_args(f)[0]) or issubclass(input_root_type, get_args(f)[0])
for f in output_fields
):
return False
return "Collector outputs must connect to a collection input with a matching type"

return True
return None

def nx_graph(self) -> nx.DiGraph:
"""Returns a NetworkX DiGraph representing the layout of this graph"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'
import { zPydanticValidationError } from 'features/system/store/zodSchemas';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { truncate, upperFirst } from 'lodash-es';
import { truncate } from 'lodash-es';
import { serializeError } from 'serialize-error';
import { queueApi } from 'services/api/endpoints/queue';
import type { JsonObject } from 'type-fest';
Expand Down Expand Up @@ -52,15 +52,12 @@ export const addBatchEnqueuedListener = (startAppListening: AppStartListening) =
const result = zPydanticValidationError.safeParse(response);
if (result.success) {
result.data.data.detail.map((e) => {
const description = truncate(e.msg.replace(/^(Value|Index|Key) error, /i, ''), { length: 256 });
toast({
id: 'QUEUE_BATCH_FAILED',
title: truncate(upperFirst(e.msg), { length: 128 }),
title: t('queue.batchFailedToQueue'),
status: 'error',
description: truncate(
`Path:
${e.loc.join('.')}`,
{ length: 128 }
),
description,
});
});
} else if (response.status !== 403) {
Expand Down
Loading