Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions airflow-core/src/airflow/api_fastapi/auth/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,15 +447,33 @@ def signing_arg(self) -> AllowedPrivateKeys | str:
assert self._secret_key
return self._secret_key

def generate(self, extras: dict[str, Any] | None = None, headers: dict[str, Any] | None = None) -> str:
def generate_workload_token(self, sub: str) -> str:
    """
    Issue a long-lived, workload-scoped token for executor queues.

    The lifetime is read from ``[execution_api] jwt_workload_token_expiration_time``
    (falling back to 86400 seconds) and passed to :meth:`generate` as ``valid_for``,
    so it is used in place of this generator's own ``valid_for``.

    :param sub: Value stored in the token's ``sub`` claim.
    :return: The token string produced by :meth:`generate`.
    """
    from airflow.configuration import conf

    lifetime_seconds = conf.getint("execution_api", "jwt_workload_token_expiration_time", fallback=86400)
    # The "scope" claim is what marks this token as a workload token.
    workload_claims = {"sub": sub, "scope": "workload"}
    return self.generate(extras=workload_claims, valid_for=lifetime_seconds)

def generate(
self,
extras: dict[str, Any] | None = None,
headers: dict[str, Any] | None = None,
valid_for: float | None = None,
) -> str:
"""Generate a signed JWT for the subject."""
now = int(datetime.now(tz=timezone.utc).timestamp())
effective_valid_for = valid_for if valid_for is not None else self.valid_for
claims = {
"jti": uuid.uuid4().hex,
"iss": self.issuer,
"aud": self.audience,
"nbf": now,
"exp": int(now + self.valid_for),
"exp": int(now + effective_valid_for),
"iat": now,
}

Expand Down
17 changes: 17 additions & 0 deletions airflow-core/src/airflow/api_fastapi/execution_api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,9 +311,13 @@ class InProcessExecutionAPI:
@cached_property
def app(self):
if not self._app:
import svcs

from airflow.api_fastapi.auth.tokens import JWTGenerator
from airflow.api_fastapi.common.dagbag import create_dag_bag
from airflow.api_fastapi.execution_api.app import create_task_execution_api_app
from airflow.api_fastapi.execution_api.datamodels.token import TIToken
from airflow.api_fastapi.execution_api.deps import _container
from airflow.api_fastapi.execution_api.routes.connections import has_connection_access
from airflow.api_fastapi.execution_api.routes.variables import has_variable_access
from airflow.api_fastapi.execution_api.routes.xcoms import has_xcom_access
Expand All @@ -332,10 +336,23 @@ async def always_allow(request: Request):
)
return TIToken(id=ti_id, claims={"scope": "execution"})

# Override _container (the svcs service locator behind DepContainer).
# The default _container reads request.app.state.svcs_registry, but
# Cadwyn's versioned sub-apps don't inherit the main app's state,
# so lookups raise ServiceNotFoundError. This registry provides
# services needed by routes called during dag.test().
registry = svcs.Registry()
registry.register_factory(JWTGenerator, _jwt_generator)

async def _in_process_container(request: Request):
async with svcs.Container(registry) as cont:
yield cont

self._app.dependency_overrides[_jwt_bearer] = always_allow
self._app.dependency_overrides[has_connection_access] = always_allow
self._app.dependency_overrides[has_variable_access] = always_allow
self._app.dependency_overrides[has_xcom_access] = always_allow
self._app.dependency_overrides[_container] = _in_process_container

return self._app

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import attrs
import structlog
from cadwyn import VersionedAPIRouter
from fastapi import Body, HTTPException, Query, Security, status
from fastapi import Body, HTTPException, Query, Response, Security, status
from pydantic import JsonValue
from sqlalchemy import and_, func, or_, tuple_, update
from sqlalchemy.engine import CursorResult
Expand All @@ -38,6 +38,7 @@
from structlog.contextvars import bind_contextvars

from airflow._shared.timezones import timezone
from airflow.api_fastapi.auth.tokens import JWTGenerator
from airflow.api_fastapi.common.dagbag import DagBagDep, get_latest_version_of_dag
from airflow.api_fastapi.common.db.common import SessionDep
from airflow.api_fastapi.common.types import UtcDateTime
Expand All @@ -59,6 +60,7 @@
TISuccessStatePayload,
TITerminalStatePayload,
)
from airflow.api_fastapi.execution_api.deps import DepContainer
from airflow.api_fastapi.execution_api.security import ExecutionAPIRoute, require_auth
from airflow.exceptions import TaskNotFound
from airflow.models.asset import AssetActive
Expand Down Expand Up @@ -92,18 +94,21 @@
@ti_id_router.patch(
"/{task_instance_id}/run",
status_code=status.HTTP_200_OK,
dependencies=[Security(require_auth, scopes=["token:execution", "token:workload"])],
responses={
status.HTTP_404_NOT_FOUND: {"description": "Task Instance not found"},
status.HTTP_409_CONFLICT: {"description": "The TI is already in the requested state"},
HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid payload for the state transition"},
},
response_model_exclude_unset=True,
)
def ti_run(
async def ti_run(
task_instance_id: UUID,
ti_run_payload: Annotated[TIEnterRunningPayload, Body()],
response: Response,
session: SessionDep,
dag_bag: DagBagDep,
services=DepContainer,
) -> TIRunContext:
"""
Run a TaskInstance.
Expand Down Expand Up @@ -282,6 +287,10 @@ def ti_run(
context.next_method = ti.next_method
context.next_kwargs = ti.next_kwargs

generator: JWTGenerator = await services.aget(JWTGenerator)
execution_token = generator.generate(extras={"sub": str(task_instance_id)})
response.headers["X-Execution-Token"] = execution_token

return context
except SQLAlchemyError:
log.exception("Error marking Task Instance state as running")
Expand Down
9 changes: 9 additions & 0 deletions airflow-core/src/airflow/config_templates/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2069,6 +2069,15 @@ execution_api:
type: integer
example: ~
default: "600"
jwt_workload_token_expiration_time:
description: |
Seconds until workload JWT tokens expire. These long-lived tokens are sent
with task workloads to executors and can only call the /run endpoint.
Set long enough to cover maximum expected queue wait time.
version_added: ~
type: integer
example: ~
default: "86400"
jwt_audience:
version_added: 3.0.0
description: |
Expand Down
2 changes: 1 addition & 1 deletion airflow-core/src/airflow/executors/workloads/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class BaseWorkloadSchema(BaseModel):

@staticmethod
def generate_token(sub_id: str, generator: JWTGenerator | None = None) -> str:
return generator.generate({"sub": sub_id}) if generator else ""
return generator.generate_workload_token(sub_id) if generator else ""


class BaseDagBundleWorkload(BaseWorkloadSchema, ABC):
Expand Down
43 changes: 43 additions & 0 deletions airflow-core/tests/unit/api_fastapi/auth/test_tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,49 @@ def test_secret_key_with_configured_kid():
assert header["kid"] == "my-custom-kid"


def test_generate_workload_token():
    """A workload token carries the ``workload`` scope and the configured 24h lifetime."""
    env_override = {"AIRFLOW__EXECUTION_API__JWT_WORKLOAD_TOKEN_EXPIRATION_TIME": "86400"}
    generator = JWTGenerator(secret_key="test-secret", audience="test", valid_for=60)

    with patch.dict("os.environ", env_override):
        token = generator.generate_workload_token(sub="ti-123")

    decoded = jwt.decode(token, "test-secret", algorithms=["HS512"], audience="test")
    assert decoded["sub"] == "ti-123"
    assert decoded["scope"] == "workload"
    # The lifetime must come from the workload setting, not the generator's 60s default.
    lifetime = decoded["exp"] - decoded["iat"]
    assert lifetime == 86400


def test_generate_with_custom_valid_for():
    """An explicit ``valid_for`` passed to generate() overrides the generator default."""
    generator = JWTGenerator(secret_key="test-secret", audience="test", valid_for=60)

    decoded = jwt.decode(
        generator.generate(extras={"sub": "user"}, valid_for=3600),
        "test-secret",
        algorithms=["HS512"],
        audience="test",
    )

    # 3600 is the per-call override; the generator itself was built with 60.
    lifetime = decoded["exp"] - decoded["iat"]
    assert lifetime == 3600


def test_workload_token_vs_regular_token_scope():
    """Only workload tokens carry a ``scope`` claim; regular tokens have none."""
    env_override = {"AIRFLOW__EXECUTION_API__JWT_WORKLOAD_TOKEN_EXPIRATION_TIME": "86400"}
    generator = JWTGenerator(secret_key="test-secret", audience="test", valid_for=60)

    # Regular token: generated and decoded first, no scope expected.
    regular_token = generator.generate(extras={"sub": "user"})
    regular_claims = jwt.decode(regular_token, "test-secret", algorithms=["HS512"], audience="test")
    assert "scope" not in regular_claims

    # Workload token: generated with the expiration setting supplied via the environment.
    with patch.dict("os.environ", env_override):
        workload_token = generator.generate_workload_token(sub="ti-123")

    workload_claims = jwt.decode(workload_token, "test-secret", algorithms=["HS512"], audience="test")
    assert workload_claims["scope"] == "workload"


@pytest.fixture
def jwt_generator(ed25519_private_key: Ed25519PrivateKey):
key = ed25519_private_key
Expand Down
9 changes: 9 additions & 0 deletions airflow-core/tests/unit/api_fastapi/execution_api/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,15 @@
# under the License.
from __future__ import annotations

from unittest.mock import MagicMock

import pytest
from fastapi import FastAPI, Request
from fastapi.testclient import TestClient

from airflow.api_fastapi.app import cached_app
from airflow.api_fastapi.auth.tokens import JWTGenerator
from airflow.api_fastapi.execution_api.app import lifespan
from airflow.api_fastapi.execution_api.datamodels.token import TIToken
from airflow.api_fastapi.execution_api.security import _jwt_bearer

Expand Down Expand Up @@ -53,6 +57,11 @@ async def mock_jwt_bearer(request: Request):
exec_app.dependency_overrides[_jwt_bearer] = mock_jwt_bearer

with TestClient(app, headers={"Authorization": "Bearer fake"}) as client:
mock_generator = MagicMock(spec=JWTGenerator)
mock_generator.generate.return_value = "mock-execution-token"
mock_generator.generate_workload_token.return_value = "mock-workload-token"
lifespan.registry.register_value(JWTGenerator, mock_generator)

yield client

exec_app.dependency_overrides.pop(_jwt_bearer, None)
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,36 @@ def test_ti_run_state_to_running(
)
assert response.status_code == 409

def test_ti_run_returns_execution_token(self, client, session, create_task_instance, time_machine):
    """A successful PATCH /run response carries the ``X-Execution-Token`` header."""
    frozen_now = timezone.parse("2024-10-31T12:00:00Z")
    time_machine.move_to(frozen_now, tick=False)

    queued_ti = create_task_instance(
        task_id="test_exec_token",
        state=State.QUEUED,
        dagrun_state=DagRunState.RUNNING,
        session=session,
        start_date=frozen_now,
        dag_id=str(uuid4()),
    )
    session.commit()

    run_url = f"/execution/task-instances/{queued_ti.id}/run"
    run_payload = {
        "state": "running",
        "hostname": "test-host",
        "unixname": "test-user",
        "pid": 100,
        "start_date": "2024-10-31T12:00:00Z",
    }
    response = client.patch(run_url, json=run_payload)

    assert response.status_code == 200
    # Check presence first so a missing header fails with a clear message.
    assert "X-Execution-Token" in response.headers
    assert response.headers["X-Execution-Token"] == "mock-execution-token"

def test_dynamic_task_mapping_with_parse_time_value(self, client, dag_maker):
"""Test that dynamic task mapping works correctly with parse-time values."""
with dag_maker("test_dynamic_task_mapping_with_parse_time_value", serialized=True):
Expand Down
1 change: 1 addition & 0 deletions airflow-core/tests/unit/jobs/test_scheduler_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,7 @@ def set_instance_attrs(self) -> Generator:
def mock_executors(self):
mock_jwt_generator = MagicMock(spec=JWTGenerator)
mock_jwt_generator.generate.return_value = "mock-token"
mock_jwt_generator.generate_workload_token.return_value = "mock-workload-token"

default_executor = mock.MagicMock(name="DefaultExecutor", slots_available=8, slots_occupied=0)
default_executor.name = ExecutorName(alias="default_exec", module_path="default.exec.module.path")
Expand Down
1 change: 1 addition & 0 deletions devel-common/src/tests_common/test_utils/mock_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def __init__(self, do_update=True, *args, **kwargs):
# Mock JWT generator for token generation
mock_jwt_generator = MagicMock()
mock_jwt_generator.generate.return_value = "mock-token"
mock_jwt_generator.generate_workload_token.return_value = "mock-workload-token"

self.jwt_generator = mock_jwt_generator

Expand Down
5 changes: 4 additions & 1 deletion task-sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -939,7 +939,10 @@ def __init__(self, *, base_url: str | None, dry_run: bool = False, token: str, *
)

def _update_auth(self, response: httpx.Response):
    """
    Replace the client's bearer token when the response carries a new one.

    ``X-Execution-Token`` is checked first and wins when both headers are present;
    ``Refreshed-API-Token`` is only consulted otherwise. A response with neither
    header (or with empty values) leaves ``self.auth`` untouched.

    :param response: The HTTP response whose headers are inspected.
    """
    if new_token := response.headers.get("X-Execution-Token"):
        log.debug("Received execution token, swapping auth")
        self.auth = BearerAuth(new_token)
    elif new_token := response.headers.get("Refreshed-API-Token"):
        log.debug("Execution API issued us a refreshed Task token")
        self.auth = BearerAuth(new_token)

Expand Down
30 changes: 30 additions & 0 deletions task-sdk/tests/task_sdk/api/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,36 @@ def test_token_renewal(self):
assert response.status_code == 200
assert response.request.headers["Authorization"] == "Bearer abc"

def test_execution_token_swap(self):
    """An ``X-Execution-Token`` response header becomes the client's bearer token."""
    swap_response = httpx.Response(200, json={"ok": "1"}, headers={"X-Execution-Token": "exec-token-123"})
    plain_response = httpx.Response(200, json={"ok": "2"})
    client = make_client_w_responses([swap_response, plain_response])

    # First call: the response delivers the execution token.
    first = client.get("/")
    assert first.status_code == 200

    assert client.auth is not None
    assert client.auth.token == "exec-token-123"

    # Second call: the request must already be sent with the swapped token.
    second = client.get("/")
    assert second.status_code == 200
    assert second.request.headers["Authorization"] == "Bearer exec-token-123"

def test_execution_token_takes_priority_over_refreshed_token(self):
    """When both token headers are present, ``X-Execution-Token`` is the one applied."""
    responses: list[httpx.Response] = [
        httpx.Response(
            200,
            json={"ok": "1"},
            headers={"X-Execution-Token": "exec-tok", "Refreshed-API-Token": "refresh-tok"},
        ),
    ]
    client = make_client_w_responses(responses)
    client.get("/")

    # Same guard as test_execution_token_swap: fail with a clear assertion (and
    # narrow the optional type) before dereferencing ``.token``.
    assert client.auth is not None
    assert client.auth.token == "exec-tok"

@pytest.mark.parametrize(
("status_code", "description"),
[
Expand Down
Loading