86 changes: 62 additions & 24 deletions airflow-core/src/airflow/jobs/triggerer_job_runner.py
@@ -35,6 +35,9 @@
import anyio
import attrs
import structlog
from opentelemetry import trace
from opentelemetry.trace import Status, StatusCode
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from pydantic import BaseModel, Field, TypeAdapter
from sqlalchemy import func, select
from structlog.contextvars import bind_contextvars as bind_log_contextvars
@@ -87,6 +90,7 @@
from airflow.utils.session import provide_session

if TYPE_CHECKING:
from opentelemetry.util._decorator import _AgnosticContextManager
from sqlalchemy.orm import Session
from structlog.typing import FilteringBoundLogger, WrappedLogger

@@ -96,6 +100,33 @@
from airflow.sdk.types import RuntimeTaskInstanceProtocol as RuntimeTI

logger = logging.getLogger(__name__)
tracer = trace.get_tracer(__name__)


def _prepare_span(
ti: TaskInstanceDTO | None, trigger_id: int, name: str
) -> _AgnosticContextManager[trace.Span]:
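    """Create the span context manager for a trigger run, parented to the dag_run trace via ``ti.context_carrier`` when available."""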
parent_context = (
TraceContextTextMapPropagator().extract(ti.context_carrier) if ti and ti.context_carrier else None
)
span_name = f"trigger_run.{ti.task_id}" if ti else f"trigger_run.{trigger_id}"
if ti and ti.map_index >= 0:
span_name += f"_{ti.map_index}"
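    # Illustrative naming examples: a mapped "my_task" with map_index 3 yields "trigger_run.my_task_3";
    # a trigger without a task instance falls back to "trigger_run.<trigger_id>".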
attributes: dict[str, str | int] = {
"airflow.trigger.name": name,
}
if ti:
attributes = {
**attributes,
"airflow.dag_id": ti.dag_id,
"airflow.task_id": ti.task_id,
"airflow.dag_run.run_id": ti.run_id,
"airflow.task_instance.try_number": ti.try_number,
"airflow.task_instance.map_index": ti.map_index,
}

return tracer.start_as_current_span(span_name, attributes=attributes, context=parent_context)


__all__ = [
"TriggerRunner",
@@ -694,6 +725,7 @@ def update_triggers(self, requested_trigger_ids: set[int]):
ti=ser_ti, # type: ignore
)

ser_ti.context_carrier = new_trigger_orm.task_instance.dag_run.context_carrier
workload.ti = ser_ti
workload.timeout_after = new_trigger_orm.task_instance.trigger_timeout

@@ -1179,30 +1211,36 @@ async def run_trigger(self, trigger_id: int, trigger: BaseTrigger, timeout_after

name = self.triggers[trigger_id]["name"]
self.log.info("trigger %s starting", name)
try:
async for event in trigger.run():
await self.log.ainfo(
"Trigger fired event", name=self.triggers[trigger_id]["name"], result=event
)
self.triggers[trigger_id]["events"] += 1
self.events.append((trigger_id, event))
except asyncio.CancelledError:
# We get cancelled by the scheduler changing the task state. But if we do lets give a nice error
# message about it
if timeout := timeout_after:
timeout = timeout.replace(tzinfo=timezone.utc) if not timeout.tzinfo else timeout
if timeout < timezone.utcnow():
await self.log.aerror("Trigger cancelled due to timeout")
raise
finally:
# CancelledError will get injected when we're stopped - which is
# fine, the cleanup process will understand that, but we want to
# allow triggers a chance to cleanup, either in that case or if
# they exit cleanly. Exception from cleanup methods are ignored.
with suppress(Exception):
await trigger.cleanup()

await self.log.ainfo("trigger completed", name=name)

with _prepare_span(ti=trigger.task_instance, trigger_id=trigger_id, name=name) as span:
try:
async for event in trigger.run():
await self.log.ainfo(
event="Trigger fired event",
name=name,
result=event,
)
self.triggers[trigger_id]["events"] += 1
self.events.append((trigger_id, event))
span.set_status(Status(StatusCode.OK))
except asyncio.CancelledError as e:
                # We get cancelled by the scheduler changing the task state, but if we do, let's give a nice
                # error message about it.
if timeout := timeout_after:
timeout = timeout.replace(tzinfo=timezone.utc) if not timeout.tzinfo else timeout
if timeout < timezone.utcnow():
await self.log.aerror("Trigger cancelled due to timeout")
                span.set_status(Status(StatusCode.ERROR, description=str(e)))
raise
finally:
# CancelledError will get injected when we're stopped - which is
# fine, the cleanup process will understand that, but we want to
# allow triggers a chance to cleanup, either in that case or if
                # they exit cleanly. Exceptions from cleanup methods are ignored.
with suppress(Exception):
await trigger.cleanup()

await self.log.ainfo("trigger completed", name=name)

def get_trigger_by_classpath(self, classpath: str) -> type[BaseTrigger]:
"""
148 changes: 148 additions & 0 deletions airflow-core/tests/unit/jobs/test_triggerer_job.py
@@ -45,6 +45,7 @@
TriggerLoggingFactory,
TriggerRunner,
TriggerRunnerSupervisor,
_prepare_span,
messages,
)
from airflow.models import Connection, DagModel, DagRun, Trigger, Variable
@@ -347,6 +348,7 @@ def test_run_inline_trigger_canceled(self, session) -> None:
1: {"task": MagicMock(spec=asyncio.Task), "is_watcher": False, "name": "mock_name", "events": 0}
}
mock_trigger = MagicMock(spec=BaseTrigger)
mock_trigger.task_instance = None
mock_trigger.timeout_after = None
mock_trigger.run.side_effect = asyncio.CancelledError()

@@ -360,6 +362,7 @@ def test_run_inline_trigger_timeout(self, session, cap_structlog) -> None:
1: {"task": MagicMock(spec=asyncio.Task), "is_watcher": False, "name": "mock_name", "events": 0}
}
mock_trigger = MagicMock(spec=BaseTrigger)
mock_trigger.task_instance = None
mock_trigger.run.side_effect = asyncio.CancelledError()

with pytest.raises(asyncio.CancelledError):
@@ -500,6 +503,151 @@ async def asend_side_effect(msg):
assert len(second_call.events) == 2


class TestPrepareSpan:
"""Tests for _prepare_span which creates OTel spans for trigger execution."""

def _make_ti(self, **overrides):
from airflow.executors.workloads.task import TaskInstanceDTO

defaults = {
"id": "00000000-0000-0000-0000-000000000001",
"dag_version_id": "00000000-0000-0000-0000-000000000002",
"task_id": "my_task",
"dag_id": "my_dag",
"run_id": "run_1",
"try_number": 1,
"map_index": -1,
"pool_slots": 1,
"queue": "default",
"priority_weight": 1,
"context_carrier": None,
}
defaults.update(overrides)
return TaskInstanceDTO(**defaults)

def test_span_name_from_task_id(self):
"""Span name should be derived from task_id when TI is provided."""
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter

exporter = InMemorySpanExporter()
provider = TracerProvider()
provider.add_span_processor(SimpleSpanProcessor(exporter))
test_tracer = provider.get_tracer("test")

ti = self._make_ti()
with patch("airflow.jobs.triggerer_job_runner.tracer", test_tracer):
with _prepare_span(ti=ti, trigger_id=99, name="test_trigger"):
pass

spans = exporter.get_finished_spans()
assert len(spans) == 1
assert spans[0].name == "trigger_run.my_task"

def test_span_name_includes_map_index(self):
"""Span name should include map_index suffix for mapped tasks."""
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter

exporter = InMemorySpanExporter()
provider = TracerProvider()
provider.add_span_processor(SimpleSpanProcessor(exporter))
test_tracer = provider.get_tracer("test")

ti = self._make_ti(map_index=3)
with patch("airflow.jobs.triggerer_job_runner.tracer", test_tracer):
with _prepare_span(ti=ti, trigger_id=99, name="test_trigger"):
pass

spans = exporter.get_finished_spans()
assert len(spans) == 1
assert spans[0].name == "trigger_run.my_task_3"

def test_span_attributes(self):
"""Span should have correct airflow attributes."""
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter

exporter = InMemorySpanExporter()
provider = TracerProvider()
provider.add_span_processor(SimpleSpanProcessor(exporter))
test_tracer = provider.get_tracer("test")

ti = self._make_ti()
with patch("airflow.jobs.triggerer_job_runner.tracer", test_tracer):
with _prepare_span(ti=ti, trigger_id=99, name="my_dag/run_1/my_task/-1/1 (ID 5)"):
pass

spans = exporter.get_finished_spans()
assert len(spans) == 1
attrs = dict(spans[0].attributes)
assert attrs == {
"airflow.dag_id": "my_dag",
"airflow.task_id": "my_task",
"airflow.dag_run.run_id": "run_1",
"airflow.task_instance.try_number": 1,
"airflow.task_instance.map_index": -1,
"airflow.trigger.name": "my_dag/run_1/my_task/-1/1 (ID 5)",
}

def test_span_inherits_parent_trace_from_context_carrier(self):
"""When context_carrier is set, the span should be a child of that trace."""
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator

exporter = InMemorySpanExporter()
provider = TracerProvider()
provider.add_span_processor(SimpleSpanProcessor(exporter))
test_tracer = provider.get_tracer("test")

# Create a parent span and extract its context_carrier
parent_carrier: dict[str, str] = {}
with test_tracer.start_as_current_span("parent_dag_run") as parent_span:
TraceContextTextMapPropagator().inject(parent_carrier)
parent_trace_id = parent_span.get_span_context().trace_id

exporter.clear()

ti = self._make_ti(context_carrier=parent_carrier)
with patch("airflow.jobs.triggerer_job_runner.tracer", test_tracer):
with _prepare_span(ti=ti, trigger_id=99, name="test_trigger"):
pass

spans = exporter.get_finished_spans()
assert len(spans) == 1
assert spans[0].context.trace_id == parent_trace_id

def test_span_without_context_carrier_starts_new_trace(self):
"""When context_carrier is None, a new trace should be started."""
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter

exporter = InMemorySpanExporter()
provider = TracerProvider()
provider.add_span_processor(SimpleSpanProcessor(exporter))
test_tracer = provider.get_tracer("test")

# Create a span with a known trace_id to ensure the trigger span is different
with test_tracer.start_as_current_span("other_trace") as other_span:
other_trace_id = other_span.get_span_context().trace_id
exporter.clear()

ti = self._make_ti(context_carrier=None)
with patch("airflow.jobs.triggerer_job_runner.tracer", test_tracer):
with _prepare_span(ti=ti, trigger_id=99, name="test_trigger"):
pass

spans = exporter.get_finished_spans()
assert len(spans) == 1
assert spans[0].context.trace_id != other_trace_id


@pytest.mark.asyncio
@pytest.mark.usefixtures("testing_dag_bundle")
async def test_trigger_create_race_condition_38599(session, supervisor_builder):