Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 40 additions & 17 deletions airflow-core/src/airflow/jobs/scheduler_job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,9 @@
TASK_STUCK_IN_QUEUED_RESCHEDULE_EVENT = "stuck in queued reschedule"
""":meta private:"""

_TRIGGER_TIMEOUT_BATCH_SIZE = 1000
"""Maximum number of task instances to lock per trigger-timeout batch."""


def _eager_load_dag_run_for_validation() -> tuple[LoaderOption, LoaderOption]:
"""
Expand Down Expand Up @@ -2878,25 +2881,45 @@ def check_trigger_timeouts(
self, max_retries: int = MAX_DB_RETRIES, session: Session = NEW_SESSION
) -> None:
"""Mark any "deferred" task as failed if the trigger or execution timeout has passed."""
for attempt in run_with_db_retries(max_retries, logger=self.log):
with attempt:
result = session.execute(
update(TI)
.where(
TI.state == TaskInstanceState.DEFERRED,
TI.trigger_timeout < timezone.utcnow(),
while True:
task_instance_ids = []
for attempt in run_with_db_retries(max_retries, logger=self.log):
with attempt:
now = timezone.utcnow()
candidates = (
select(TI.id)
.where(
TI.state == TaskInstanceState.DEFERRED,
TI.trigger_timeout < now,
)
.order_by(TI.id)
.limit(_TRIGGER_TIMEOUT_BATCH_SIZE)
)
.values(
state=TaskInstanceState.SCHEDULED,
next_method=TRIGGER_FAIL_REPR,
next_kwargs={"error": TriggerFailureReason.TRIGGER_TIMEOUT},
scheduled_dttm=timezone.utcnow(),
trigger_id=None,
task_instance_ids = list(
session.scalars(
with_row_locks(candidates, of=TI, session=session, skip_locked=True)
).all()
)
)
num_timed_out_tasks = getattr(result, "rowcount", 0)
if num_timed_out_tasks:
self.log.info("Timed out %i deferred tasks without fired triggers", num_timed_out_tasks)
if task_instance_ids:
result = session.execute(
update(TI)
.where(TI.id.in_(task_instance_ids))
.values(
state=TaskInstanceState.SCHEDULED,
next_method=TRIGGER_FAIL_REPR,
next_kwargs={"error": TriggerFailureReason.TRIGGER_TIMEOUT},
scheduled_dttm=now,
trigger_id=None,
)
.execution_options(synchronize_session=False)
)
num_timed_out_tasks = getattr(result, "rowcount", 0)
if num_timed_out_tasks:
self.log.info(
"Timed out %i deferred tasks without fired triggers", num_timed_out_tasks
)
if len(task_instance_ids) < _TRIGGER_TIMEOUT_BATCH_SIZE:
break

# [START find_and_purge_task_instances_without_heartbeats]
def _find_and_purge_task_instances_without_heartbeats(self) -> None:
Expand Down
40 changes: 31 additions & 9 deletions airflow-core/src/airflow/models/trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@

log = logging.getLogger(__name__)

_TRIGGER_ID_CLEANUP_BATCH_SIZE = 1000
"""Maximum number of task instances to lock per trigger-id cleanup batch."""


class TriggerFailureReason(str, Enum):
"""
Expand Down Expand Up @@ -226,16 +229,35 @@ def clean_unused(cls, session: Session = NEW_SESSION) -> None:
Triggers have a one-to-many relationship to task instances, so we need to clean those up first.
Afterward we can drop the triggers not referenced by anyone.
"""
# Update all task instances with trigger IDs that are not DEFERRED to remove them
for attempt in run_with_db_retries():
with attempt:
session.execute(
update(TaskInstance)
.where(
TaskInstance.state != TaskInstanceState.DEFERRED, TaskInstance.trigger_id.is_not(None)
# Clear task-instance trigger references in primary-key order to avoid locking the same rows in
# a different order than scheduler timeout handling.
while True:
task_instance_ids = []
for attempt in run_with_db_retries():
with attempt:
candidates = (
select(TaskInstance.id)
.where(
TaskInstance.state != TaskInstanceState.DEFERRED,
TaskInstance.trigger_id.is_not(None),
)
.order_by(TaskInstance.id)
.limit(_TRIGGER_ID_CLEANUP_BATCH_SIZE)
)
task_instance_ids = list(
session.scalars(
with_row_locks(candidates, of=TaskInstance, session=session, skip_locked=True)
).all()
)
.values(trigger_id=None)
)
if task_instance_ids:
session.execute(
update(TaskInstance)
.where(TaskInstance.id.in_(task_instance_ids))
.values(trigger_id=None)
.execution_options(synchronize_session=False)
)
if len(task_instance_ids) < _TRIGGER_ID_CLEANUP_BATCH_SIZE:
break

# Get all triggers that have no task instances, assets, or callbacks depending on them and delete them
ids = (
Expand Down
39 changes: 39 additions & 0 deletions airflow-core/tests/unit/jobs/test_scheduler_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -6831,6 +6831,45 @@ def test_timeout_triggers(self, dag_maker):
assert ti1.next_method == "__fail__"
assert ti2.state == State.DEFERRED

def test_timeout_triggers_processes_more_than_one_batch(self, dag_maker, monkeypatch):
"""Timed-out deferred task instances are all updated when they span multiple batches."""
import airflow.jobs.scheduler_job_runner as scheduler_job_runner_module

monkeypatch.setattr(scheduler_job_runner_module, "_TRIGGER_TIMEOUT_BATCH_SIZE", 2)

session = settings.Session()
with dag_maker(
dag_id="test_timeout_triggers_processes_more_than_one_batch",
start_date=DEFAULT_DATE,
schedule="@once",
max_active_runs=5,
session=session,
):
EmptyOperator(task_id="dummy1")

past = timezone.utcnow() - datetime.timedelta(seconds=60)
task_instances = []
for index in range(5):
dag_run = dag_maker.create_dagrun(
run_id=f"test_batch_{index}",
logical_date=DEFAULT_DATE + datetime.timedelta(seconds=index),
)
task_instance = dag_run.get_task_instance("dummy1", session)
task_instance.state = State.DEFERRED
task_instance.trigger_timeout = past
task_instances.append(task_instance)
session.flush()

scheduler_job = Job()
self.job_runner = SchedulerJobRunner(job=scheduler_job)

self.job_runner.check_trigger_timeouts(session=session)

for task_instance in task_instances:
session.refresh(task_instance)
assert task_instance.state == State.SCHEDULED
assert task_instance.next_method == "__fail__"

def test_retry_on_db_error_when_update_timeout_triggers(self, dag_maker, testing_dag_bundle, session):
"""
Tests that it will retry on DB error like deadlock when updating timeout triggers.
Expand Down
32 changes: 32 additions & 0 deletions airflow-core/tests/unit/models/test_trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,38 @@ def test_clean_unused(session, dag_maker):
assert {result.id for result in results} == {trigger1.id, trigger4.id, trigger5.id, trigger6.id}


def test_clean_unused_clears_trigger_ids_in_batches(session, dag_maker, monkeypatch):
"""Non-deferred task instances have trigger references cleared when they span multiple batches."""
import airflow.models.trigger as trigger_module

monkeypatch.setattr(trigger_module, "_TRIGGER_ID_CLEANUP_BATCH_SIZE", 2)

triggers = [
Trigger(classpath=f"airflow.triggers.testing.SuccessTrigger{index}", kwargs={}) for index in range(5)
]
session.add_all(triggers)
session.flush()

with dag_maker(session=session, dag_id="test_clean_unused_clears_trigger_ids_in_batches"):
for index in range(5):
EmptyOperator(task_id=f"fake{index}")

dag_run = dag_maker.create_dagrun(logical_date=timezone.utcnow())
task_instances = {task_instance.task_id: task_instance for task_instance in dag_run.task_instances}
for index, trigger in enumerate(triggers):
task_instance = task_instances[f"fake{index}"]
task_instance.state = State.SUCCESS
task_instance.trigger_id = trigger.id
session.flush()

Trigger.clean_unused(session=session)

for task_instance in task_instances.values():
session.refresh(task_instance)
assert task_instance.trigger_id is None
assert session.scalar(select(func.count()).select_from(Trigger)) == 0


@patch.object(TriggererCallback, "handle_event")
def test_submit_event(mock_callback_handle_event, session, create_task_instance):
"""
Expand Down
Loading