Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from sqlalchemy import and_, func, or_, tuple_, update
from sqlalchemy.engine import CursorResult
from sqlalchemy.exc import NoResultFound, SQLAlchemyError
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import joinedload, lazyload
from sqlalchemy.sql import select
from structlog.contextvars import bind_contextvars

Expand Down Expand Up @@ -425,7 +425,7 @@ def ti_update_state(
"Error updating Task Instance state. Setting the task to failed.",
payload=ti_patch_payload,
)
ti = session.get(TI, task_instance_id, with_for_update=True)
ti = session.get(TI, task_instance_id, options=[lazyload(TI.dag_run)], with_for_update=True)
if session.bind is not None:
query = TI.duration_expression_update(timezone.utcnow(), query, session.bind)
query = query.values(state=(updated_state := TaskInstanceState.FAILED))
Expand Down Expand Up @@ -528,7 +528,7 @@ def _create_ti_state_update_query_and_update_state(
dag_id: str,
) -> tuple[Update, TaskInstanceState]:
if isinstance(ti_patch_payload, (TITerminalStatePayload, TIRetryStatePayload, TISuccessStatePayload)):
ti = session.get(TI, task_instance_id, with_for_update=True)
ti = session.get(TI, task_instance_id, options=[lazyload(TI.dag_run)], with_for_update=True)
updated_state = TaskInstanceState(ti_patch_payload.state.value)
if session.bind is not None:
query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
Expand Down
57 changes: 40 additions & 17 deletions airflow-core/src/airflow/jobs/scheduler_job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@
TASK_STUCK_IN_QUEUED_RESCHEDULE_EVENT = "stuck in queued reschedule"
""":meta private:"""

_TRIGGER_TIMEOUT_BATCH_SIZE = 1000
"""Maximum number of task instances to lock per trigger-timeout batch."""


def _eager_load_dag_run_for_validation() -> tuple[LoaderOption, LoaderOption]:
"""
Expand Down Expand Up @@ -2883,25 +2886,45 @@ def check_trigger_timeouts(
self, max_retries: int = MAX_DB_RETRIES, session: Session = NEW_SESSION
) -> None:
"""Mark any "deferred" task as failed if the trigger or execution timeout has passed."""
for attempt in run_with_db_retries(max_retries, logger=self.log):
with attempt:
result = session.execute(
update(TI)
.where(
TI.state == TaskInstanceState.DEFERRED,
TI.trigger_timeout < timezone.utcnow(),
while True:
task_instance_ids = []
for attempt in run_with_db_retries(max_retries, logger=self.log):
with attempt:
now = timezone.utcnow()
candidates = (
select(TI.id)
.where(
TI.state == TaskInstanceState.DEFERRED,
TI.trigger_timeout < now,
)
.order_by(TI.id)
.limit(_TRIGGER_TIMEOUT_BATCH_SIZE)
)
.values(
state=TaskInstanceState.SCHEDULED,
next_method=TRIGGER_FAIL_REPR,
next_kwargs={"error": TriggerFailureReason.TRIGGER_TIMEOUT},
scheduled_dttm=timezone.utcnow(),
trigger_id=None,
task_instance_ids = list(
session.scalars(
with_row_locks(candidates, of=TI, session=session, skip_locked=True)
).all()
)
)
num_timed_out_tasks = getattr(result, "rowcount", 0)
if num_timed_out_tasks:
self.log.info("Timed out %i deferred tasks without fired triggers", num_timed_out_tasks)
if task_instance_ids:
result = session.execute(
update(TI)
.where(TI.id.in_(task_instance_ids))
.values(
state=TaskInstanceState.SCHEDULED,
next_method=TRIGGER_FAIL_REPR,
next_kwargs={"error": TriggerFailureReason.TRIGGER_TIMEOUT},
scheduled_dttm=now,
trigger_id=None,
)
.execution_options(synchronize_session=False)
)
num_timed_out_tasks = getattr(result, "rowcount", 0)
if num_timed_out_tasks:
self.log.info(
"Timed out %i deferred tasks without fired triggers", num_timed_out_tasks
)
if len(task_instance_ids) < _TRIGGER_TIMEOUT_BATCH_SIZE:
break

# [START find_and_purge_task_instances_without_heartbeats]
def _find_and_purge_task_instances_without_heartbeats(self) -> None:
Expand Down
40 changes: 31 additions & 9 deletions airflow-core/src/airflow/models/trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@

log = logging.getLogger(__name__)

_TRIGGER_ID_CLEANUP_BATCH_SIZE = 1000
"""Maximum number of task instances to lock per trigger-id cleanup batch."""


class TriggerFailureReason(str, Enum):
"""
Expand Down Expand Up @@ -227,16 +230,35 @@ def clean_unused(cls, session: Session = NEW_SESSION) -> None:
Triggers have a one-to-many relationship to task instances, so we need to clean those up first.
Afterward we can drop the triggers not referenced by anyone.
"""
# Update all task instances with trigger IDs that are not DEFERRED to remove them
for attempt in run_with_db_retries():
with attempt:
session.execute(
update(TaskInstance)
.where(
TaskInstance.state != TaskInstanceState.DEFERRED, TaskInstance.trigger_id.is_not(None)
# Clear task-instance trigger references in primary-key order to avoid locking the same rows in
# a different order than scheduler timeout handling.
while True:
task_instance_ids = []
for attempt in run_with_db_retries():
with attempt:
candidates = (
select(TaskInstance.id)
.where(
TaskInstance.state != TaskInstanceState.DEFERRED,
TaskInstance.trigger_id.is_not(None),
)
.order_by(TaskInstance.id)
.limit(_TRIGGER_ID_CLEANUP_BATCH_SIZE)
)
task_instance_ids = list(
session.scalars(
with_row_locks(candidates, of=TaskInstance, session=session, skip_locked=True)
).all()
)
.values(trigger_id=None)
)
if task_instance_ids:
session.execute(
update(TaskInstance)
.where(TaskInstance.id.in_(task_instance_ids))
.values(trigger_id=None)
.execution_options(synchronize_session=False)
)
if len(task_instance_ids) < _TRIGGER_ID_CLEANUP_BATCH_SIZE:
break

# Get all triggers that have no task instances, assets, or callbacks depending on them and delete them
ids = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1423,6 +1423,48 @@ def test_ti_update_state_database_error(self, client, session, create_task_insta
assert response.status_code == 500
assert response.json()["detail"] == "Database error occurred"

def test_ti_update_state_terminal_does_not_lock_dag_run(self, client, session, create_task_instance):
"""
Regression guard: session.get(TI, pk, with_for_update=True) must use
options=[lazyload(TI.dag_run)] to avoid inadvertently locking dag_run.

TaskInstance.dag_run has lazy="joined", so without the lazyload override the
ORM emits FOR UPDATE on both task_instance and dag_run. The scheduler holds
a dag_run lock while bulk-updating task_instance rows in
_verify_integrity_if_dag_changed, producing a lock-order inversion deadlock.
"""
from sqlalchemy.orm import lazyload

ti = create_task_instance(
task_id="test_ti_update_state_no_dag_run_lock",
state=State.RUNNING,
start_date=DEFAULT_START_DATE,
)
session.commit()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If possible , could we use this

def create_session(scoped: bool = True) -> Generator[SASession, None, None]:

with create_session so we don't need to commit explicitly ?

WDYT ?

Copy link
Copy Markdown
Author

@Usuychik Usuychik May 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can be wrong, cos don`t know so deep airflow stack, but from observation:
Why session.commit() is required here:
The HTTP client that hits the route handler opens its own separate DB session. The create_task_instance data must be committed before the client request so the route's session can see it. This is why every single test in
TestTIUpdateState does:
ti = create_task_instance(...)
session.commit() # ← makes the TI visible to the route's session
response = client.patch(...)

Copy link
Copy Markdown
Contributor

@Prab-27 Prab-27 May 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks !! will take a look and get back to you

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes ! we need this session.commit() here ! Thanks !!


captured_for_update_calls: list[dict] = []
real_get = Session.get

def spy_get(self, entity, ident, **kwargs):
if kwargs.get("with_for_update"):
captured_for_update_calls.append({"entity": entity, "options": kwargs.get("options") or []})
return real_get(self, entity, ident, **kwargs)

with mock.patch.object(Session, "get", spy_get):
response = client.patch(
f"/execution/task-instances/{ti.id}/state",
json={"state": State.SUCCESS, "end_date": DEFAULT_END_DATE.isoformat()},
)
assert response.status_code == 204

ti_for_update_calls = [c for c in captured_for_update_calls if c["entity"] is TaskInstance]
assert ti_for_update_calls, "Expected at least one session.get(TaskInstance, ..., with_for_update=True)"
for call in ti_for_update_calls:
assert any(isinstance(opt, lazyload) for opt in call["options"]), (
"session.get(TaskInstance, ..., with_for_update=True) must pass "
"options=[lazyload(TI.dag_run)] to prevent inadvertent dag_run row lock"
)

@pytest.mark.parametrize("queues_enabled", [False, True])
def test_ti_update_state_to_deferred(
self, client, session, create_task_instance, time_machine, queues_enabled: bool
Expand Down
39 changes: 39 additions & 0 deletions airflow-core/tests/unit/jobs/test_scheduler_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -7123,6 +7123,45 @@ def test_timeout_triggers(self, dag_maker):
assert ti1.next_method == "__fail__"
assert ti2.state == State.DEFERRED

def test_timeout_triggers_processes_more_than_one_batch(self, dag_maker, monkeypatch):
"""Timed-out deferred task instances are all updated when they span multiple batches."""
import airflow.jobs.scheduler_job_runner as scheduler_job_runner_module

monkeypatch.setattr(scheduler_job_runner_module, "_TRIGGER_TIMEOUT_BATCH_SIZE", 2)

session = settings.Session()
with dag_maker(
dag_id="test_timeout_triggers_processes_more_than_one_batch",
start_date=DEFAULT_DATE,
schedule="@once",
max_active_runs=5,
session=session,
):
EmptyOperator(task_id="dummy1")

past = timezone.utcnow() - datetime.timedelta(seconds=60)
task_instances = []
for index in range(5):
dag_run = dag_maker.create_dagrun(
run_id=f"test_batch_{index}",
logical_date=DEFAULT_DATE + datetime.timedelta(seconds=index),
)
task_instance = dag_run.get_task_instance("dummy1", session)
task_instance.state = State.DEFERRED
task_instance.trigger_timeout = past
task_instances.append(task_instance)
session.flush()

scheduler_job = Job()
self.job_runner = SchedulerJobRunner(job=scheduler_job)

self.job_runner.check_trigger_timeouts(session=session)

for task_instance in task_instances:
session.refresh(task_instance)
assert task_instance.state == State.SCHEDULED
assert task_instance.next_method == "__fail__"

def test_retry_on_db_error_when_update_timeout_triggers(self, dag_maker, testing_dag_bundle, session):
"""
Tests that it will retry on DB error like deadlock when updating timeout triggers.
Expand Down
32 changes: 32 additions & 0 deletions airflow-core/tests/unit/models/test_trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,38 @@ def test_clean_unused(session, dag_maker):
assert {result.id for result in results} == {trigger1.id, trigger4.id, trigger5.id, trigger6.id}


def test_clean_unused_clears_trigger_ids_in_batches(session, dag_maker, monkeypatch):
"""Non-deferred task instances have trigger references cleared when they span multiple batches."""
import airflow.models.trigger as trigger_module

monkeypatch.setattr(trigger_module, "_TRIGGER_ID_CLEANUP_BATCH_SIZE", 2)

triggers = [
Trigger(classpath=f"airflow.triggers.testing.SuccessTrigger{index}", kwargs={}) for index in range(5)
]
session.add_all(triggers)
session.flush()

with dag_maker(session=session, dag_id="test_clean_unused_clears_trigger_ids_in_batches"):
for index in range(5):
EmptyOperator(task_id=f"fake{index}")

dag_run = dag_maker.create_dagrun(logical_date=timezone.utcnow())
task_instances = {task_instance.task_id: task_instance for task_instance in dag_run.task_instances}
for index, trigger in enumerate(triggers):
task_instance = task_instances[f"fake{index}"]
task_instance.state = State.SUCCESS
task_instance.trigger_id = trigger.id
session.flush()

Trigger.clean_unused(session=session)

for task_instance in task_instances.values():
session.refresh(task_instance)
assert task_instance.trigger_id is None
assert session.scalar(select(func.count()).select_from(Trigger)) == 0


@patch.object(TriggererCallback, "handle_event")
def test_submit_event(mock_callback_handle_event, session, create_task_instance):
"""
Expand Down
Loading