Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions burr/core/parallelism.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,35 @@ def _stable_app_id_hash(app_id: str, child_key: str) -> str:
return hashlib.sha256(f"{app_id}:{child_key}".encode()).hexdigest()


def _salt_task_app_id(task: "SubGraphTask", sequence_id: Optional[int]) -> "SubGraphTask":
"""Salts the sub-application ID with the parent's sequence_id so that repeated
invocations of the same parallel action within a parent application yield distinct
sub-application IDs.

Without this, sub-app IDs collide across invocations and a cascaded
``state_initializer`` (e.g. from ``initialize_from(...)`` on the parent) will
silently hydrate the prior call's persisted state instead of running the action.
See https://github.com/apache/burr/issues/761.

``sequence_id`` is the parent application's per-step counter, which is incremented
on every action execution -- making it the right discriminator for "which
invocation of this parallel action are we in".

BREAKING (vs versions without this salting): sub-application IDs for any
``TaskBasedParallelAction`` (``MapStates``, ``MapActions``, ``MapActionsAndStates``)
have changed. Sub-app state persisted under the old ID scheme is orphaned --
on the first resume after upgrade, the sub-actions re-execute fresh rather
than load the old persisted result. This is the *fix* for #761; the old
scheme would have silently returned stale data instead.
"""
if sequence_id is None:
return task
task.application_id = hashlib.sha256(
f"{task.application_id}:{sequence_id}".encode()
).hexdigest()
return task


class TaskBasedParallelAction(SingleStepAction):
"""The base class for actions that run a set of tasks in parallel and reduce the results.
This is more power-user mode -- if you need fine-grained control over the set of tasks
Expand Down Expand Up @@ -269,6 +298,11 @@ def _run_and_update():
delete=[item for item in state.keys() if item.startswith("__")]
)
task_generator = self.tasks(state_without_internals, context, run_kwargs)
# Salt sub-app IDs with the parent sequence_id so repeated invocations
# don't collide and silently replay prior persisted state (#761).
task_generator = (
_salt_task_app_id(task, context.sequence_id) for task in task_generator
Copy link
Copy Markdown
Contributor

@skrawcz skrawcz May 15, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to enable configuring this? Is there a use case where you want the prior behavior?

)

def execute_task(task):
return task.run(run_kwargs["__context"])
Expand Down Expand Up @@ -296,6 +330,9 @@ async def state_generator():
This way we run through all of the task generators. These correspond to the task generation capabilities above (the map*/task generation stuff)
"""
all_tasks = await async_utils.arealize(task_generator)
# Salt sub-app IDs with the parent sequence_id so repeated invocations
# don't collide and silently replay prior persisted state (#761).
all_tasks = [_salt_task_app_id(task, context.sequence_id) for task in all_tasks]
coroutines = [item.arun(context) for item in all_tasks]
results = await asyncio.gather(*coroutines)
# TODO -- yield in order...
Expand Down
77 changes: 76 additions & 1 deletion tests/core/test_parallelism.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,17 @@
MapStates,
RunnableGraph,
SubGraphTask,
SubgraphType,
TaskBasedParallelAction,
_cascade_adapter,
map_reduce_action,
)
from burr.core.persistence import BaseStateLoader, BaseStateSaver, PersistedStateData
from burr.core.persistence import (
BaseStateLoader,
BaseStateSaver,
InMemoryPersister,
PersistedStateData,
)
from burr.tracking.base import SyncTrackingClient
from burr.visibility import ActionSpan

Expand Down Expand Up @@ -1227,3 +1233,72 @@ def reads(self) -> list[str]:
assert task.state_initializer is not None
assert task.tracker is not None
assert task.state_persister is task.state_initializer # This ensures they're the same


def test_map_states_reexecutes_on_repeated_invocations_with_initializer():
    """Regression test for https://github.com/apache/burr/issues/761.

    Sub-application IDs used to be deterministic in ``(parent_app_id, i, j)``
    alone. When the parent was built with ``initialize_from(...)``, the
    cascaded initializer hydrated sub-applications by ID -- so a second
    invocation of the same parallel action collided with the first and silently
    replayed the previously persisted state instead of running the action.

    Here the same fan-out action runs three times; each invocation must produce
    fresh results, and the mapped action must actually execute every time.
    """
    tally = {"n": 0}

    @action(reads=[], writes=["x"])
    def bump(state: State) -> State:
        # Monotonically increasing marker: distinguishes a real execution
        # from a replay of persisted state.
        tally["n"] += 1
        return state.update(x=tally["n"])

    @action(reads=[], writes=[])
    def bounce(state: State) -> State:
        return state

    class FanOut(MapStates):
        def action(self, state: State, inputs: Dict[str, Any]) -> SubgraphType:
            return bump

        def states(
            self, state: State, context: ApplicationContext, inputs: Dict[str, Any]
        ) -> Generator[State, None, None]:
            # Fan out the same parent state three ways.
            yield from (state,) * 3

        def reduce(self, state: State, results: Generator[State, None, None]) -> State:
            collected = [sub["x"] for sub in results]
            return state.update(xs=collected)

        @property
        def reads(self) -> list[str]:
            return []

        @property
        def writes(self) -> list[str]:
            return ["xs"]

    persister = InMemoryPersister()
    app = (
        ApplicationBuilder()
        .with_actions(fan=FanOut(), back=bounce)
        .with_transitions(("fan", "back"), ("back", "fan"))
        .with_state_persister(persister)
        .initialize_from(
            persister,
            resume_at_next_action=True,
            default_state={},
            default_entrypoint="fan",
        )
        .build()
    )
    observed = []
    for _ in range(3):
        app.run(halt_after=["fan"])
        observed.append(list(app.state["xs"]))
    # Three invocations of three sub-tasks each => nine real executions with
    # distinct counter values. If the bug regresses, every invocation replays
    # the first one's xs and the action never runs again.
    assert observed[0] != observed[1]
    assert observed[1] != observed[2]
    assert tally["n"] == 9
Loading