Skip to content

Commit

Permalink
Merge pull request #49 from tjni/update-tests
Browse files Browse the repository at this point in the history
Update Pregel tests.
  • Loading branch information
tjni authored Jan 20, 2025
2 parents 3127af8 + b360b89 commit a44e78e
Show file tree
Hide file tree
Showing 4 changed files with 559 additions and 51 deletions.
261 changes: 261 additions & 0 deletions langgraph-tests/tests/test_pregel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2398,3 +2398,264 @@ def second_node(state: State):
# Verify the error was recorded in checkpoint
failed_checkpoint = next(c for c in history if c.tasks and c.tasks[0].error)
assert "RuntimeError('Simulated failure')" in failed_checkpoint.tasks[0].error


@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_multiple_subgraphs(
request: pytest.FixtureRequest, checkpointer_name: str
) -> None:
checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}")

class State(TypedDict):
a: int
b: int

class Output(TypedDict):
result: int

# Define the subgraphs
def add(state):
return {"result": state["a"] + state["b"]}

add_subgraph = (
StateGraph(State, output=Output).add_node(add).add_edge(START, "add").compile()
)

def multiply(state):
return {"result": state["a"] * state["b"]}

multiply_subgraph = (
StateGraph(State, output=Output)
.add_node(multiply)
.add_edge(START, "multiply")
.compile()
)

# Test calling the same subgraph multiple times
def call_same_subgraph(state):
result = add_subgraph.invoke(state)
another_result = add_subgraph.invoke({"a": result["result"], "b": 10})
return another_result

parent_call_same_subgraph = (
StateGraph(State, output=Output)
.add_node(call_same_subgraph)
.add_edge(START, "call_same_subgraph")
.compile(checkpointer=checkpointer)
)
config = {"configurable": {"thread_id": "1"}}
assert parent_call_same_subgraph.invoke({"a": 2, "b": 3}, config) == {"result": 15}

# Test calling multiple subgraphs
class Output(TypedDict):
add_result: int
multiply_result: int

def call_multiple_subgraphs(state):
add_result = add_subgraph.invoke(state)
multiply_result = multiply_subgraph.invoke(state)
return {
"add_result": add_result["result"],
"multiply_result": multiply_result["result"],
}

parent_call_multiple_subgraphs = (
StateGraph(State, output=Output)
.add_node(call_multiple_subgraphs)
.add_edge(START, "call_multiple_subgraphs")
.compile(checkpointer=checkpointer)
)
config = {"configurable": {"thread_id": "2"}}
assert parent_call_multiple_subgraphs.invoke({"a": 2, "b": 3}, config) == {
"add_result": 5,
"multiply_result": 6,
}


@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_multiple_subgraphs_functional(
request: pytest.FixtureRequest, checkpointer_name: str
) -> None:
checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}")

# Define addition subgraph
@entrypoint()
def add(inputs):
a, b = inputs
return a + b

# Define multiplication subgraph using tasks
@task
def multiply_task(a, b):
return a * b

@entrypoint()
def multiply(inputs):
return multiply_task(*inputs).result()

# Test calling the same subgraph multiple times
@task
def call_same_subgraph(a, b):
result = add.invoke([a, b])
another_result = add.invoke([result, 10])
return another_result

@entrypoint(checkpointer=checkpointer)
def parent_call_same_subgraph(inputs):
return call_same_subgraph(*inputs).result()

config = {"configurable": {"thread_id": "1"}}
assert parent_call_same_subgraph.invoke([2, 3], config) == 15

# Test calling multiple subgraphs
@task
def call_multiple_subgraphs(a, b):
add_result = add.invoke([a, b])
multiply_result = multiply.invoke([a, b])
return [add_result, multiply_result]

@entrypoint(checkpointer=checkpointer)
def parent_call_multiple_subgraphs(inputs):
return call_multiple_subgraphs(*inputs).result()

config = {"configurable": {"thread_id": "2"}}
assert parent_call_multiple_subgraphs.invoke([2, 3], config) == [5, 6]


@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_multiple_subgraphs_mixed(
request: pytest.FixtureRequest, checkpointer_name: str
) -> None:
checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}")

class State(TypedDict):
a: int
b: int

class Output(TypedDict):
result: int

# Define the subgraphs
def add(state):
return {"result": state["a"] + state["b"]}

add_subgraph = (
StateGraph(State, output=Output).add_node(add).add_edge(START, "add").compile()
)

def multiply(state):
return {"result": state["a"] * state["b"]}

multiply_subgraph = (
StateGraph(State, output=Output)
.add_node(multiply)
.add_edge(START, "multiply")
.compile()
)

# Test calling the same subgraph multiple times
@task
def call_same_subgraph(a, b):
result = add_subgraph.invoke({"a": a, "b": b})["result"]
another_result = add_subgraph.invoke({"a": result, "b": 10})["result"]
return another_result

@entrypoint(checkpointer=checkpointer)
def parent_call_same_subgraph(inputs):
return call_same_subgraph(*inputs).result()

config = {"configurable": {"thread_id": "1"}}
assert parent_call_same_subgraph.invoke([2, 3], config) == 15

# Test calling multiple subgraphs
@task
def call_multiple_subgraphs(a, b):
add_result = add_subgraph.invoke({"a": a, "b": b})["result"]
multiply_result = multiply_subgraph.invoke({"a": a, "b": b})["result"]
return [add_result, multiply_result]

@entrypoint(checkpointer=checkpointer)
def parent_call_multiple_subgraphs(inputs):
return call_multiple_subgraphs(*inputs).result()

config = {"configurable": {"thread_id": "2"}}
assert parent_call_multiple_subgraphs.invoke([2, 3], config) == [5, 6]


@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_multiple_subgraphs_mixed_checkpointer(
request: pytest.FixtureRequest, checkpointer_name: str
) -> None:
checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}")

class SubgraphState(TypedDict):
sub_counter: Annotated[int, operator.add]

def subgraph_node(state):
return {"sub_counter": 2}

sub_graph_1 = (
StateGraph(SubgraphState)
.add_node(subgraph_node)
.add_edge(START, "subgraph_node")
.compile(checkpointer=True)
)

class OtherSubgraphState(TypedDict):
other_sub_counter: Annotated[int, operator.add]

def other_subgraph_node(state):
return {"other_sub_counter": 3}

sub_graph_2 = (
StateGraph(OtherSubgraphState)
.add_node(other_subgraph_node)
.add_edge(START, "other_subgraph_node")
.compile()
)

class ParentState(TypedDict):
parent_counter: int

def parent_node(state):
result = sub_graph_1.invoke({"sub_counter": state["parent_counter"]})
other_result = sub_graph_2.invoke({"other_sub_counter": result["sub_counter"]})
return {"parent_counter": other_result["other_sub_counter"]}

parent_graph = (
StateGraph(ParentState)
.add_node(parent_node)
.add_edge(START, "parent_node")
.compile(checkpointer=checkpointer)
)

config = {"configurable": {"thread_id": "1"}}
assert parent_graph.invoke({"parent_counter": 0}, config) == {"parent_counter": 5}
assert parent_graph.invoke({"parent_counter": 0}, config) == {"parent_counter": 7}
config = {"configurable": {"thread_id": "2"}}
assert [
c
for c in parent_graph.stream(
{"parent_counter": 0}, config, subgraphs=True, stream_mode="updates"
)
] == [
(("parent_node",), {"subgraph_node": {"sub_counter": 2}}),
(
(AnyStr("parent_node:"), "1"),
{"other_subgraph_node": {"other_sub_counter": 3}},
),
((), {"parent_node": {"parent_counter": 5}}),
]
assert [
c
for c in parent_graph.stream(
{"parent_counter": 0}, config, subgraphs=True, stream_mode="updates"
)
] == [
(("parent_node",), {"subgraph_node": {"sub_counter": 2}}),
(
(AnyStr("parent_node:"), "1"),
{"other_subgraph_node": {"other_sub_counter": 3}},
),
((), {"parent_node": {"parent_counter": 7}}),
]
Loading

0 comments on commit a44e78e

Please sign in to comment.