Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .codespellignorelines
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
The platform supports **C**reate, **R**ead, **U**pdate, and **D**elete operations on most resources.
<pre><code>Code block\ndoes not\nrespect\nnewlines\n</code></pre>
"trough",
assert "task_instance_id" in route.dependant.path_param_names, (
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,7 @@ Adding a new Execution API feature touches multiple packages. All of these must
- Triggerer handler: `airflow-core/src/airflow/jobs/triggerer_job_runner.py`
- Task SDK generated models: `task-sdk/src/airflow/sdk/api/datamodels/_generated.py`
- Full versioning guide: [`contributing-docs/19_execution_api_versioning.rst`](../../../../contributing-docs/19_execution_api_versioning.rst)

## Token Scope Infrastructure

Token types (`"execution"`, `"workload"`), route-level enforcement via `ExecutionAPIRoute` + `require_auth`, and the `ti:self` path-parameter validation are documented in the module docstring of `airflow-core/src/airflow/api_fastapi/execution_api/security.py`.
26 changes: 19 additions & 7 deletions airflow-core/src/airflow/api_fastapi/execution_api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,15 @@ def replace_any_of_with_one_of(spec):
if prop.get("type") == "string" and (const := prop.pop("const", None)):
prop["enum"] = [const]

# Remove internal x-airflow-* extension fields from OpenAPI spec
# These are used for runtime validation but shouldn't be exposed in the public API
for path_item in openapi_schema.get("paths", {}).values():
for operation in path_item.values():
if isinstance(operation, dict):
keys_to_remove = [key for key in operation.keys() if key.startswith("x-airflow-")]
for key in keys_to_remove:
del operation[key]

return openapi_schema


Expand Down Expand Up @@ -304,23 +313,26 @@ def app(self):
if not self._app:
from airflow.api_fastapi.common.dagbag import create_dag_bag
from airflow.api_fastapi.execution_api.app import create_task_execution_api_app
from airflow.api_fastapi.execution_api.deps import (
JWTBearerDep,
JWTBearerTIPathDep,
)
from airflow.api_fastapi.execution_api.datamodels.token import TIToken
from airflow.api_fastapi.execution_api.routes.connections import has_connection_access
from airflow.api_fastapi.execution_api.routes.variables import has_variable_access
from airflow.api_fastapi.execution_api.routes.xcoms import has_xcom_access
from airflow.api_fastapi.execution_api.security import _jwt_bearer

self._app = create_task_execution_api_app()

# Set up dag_bag in app state for dependency injection
self._app.state.dag_bag = create_dag_bag()

async def always_allow(): ...
async def always_allow(request: Request):
from uuid import UUID

ti_id = UUID(
request.path_params.get("task_instance_id", "00000000-0000-0000-0000-000000000000")
)
return TIToken(id=ti_id, claims={"scope": "execution"})

self._app.dependency_overrides[JWTBearerDep.dependency] = always_allow
self._app.dependency_overrides[JWTBearerTIPathDep.dependency] = always_allow
self._app.dependency_overrides[_jwt_bearer] = always_allow
self._app.dependency_overrides[has_connection_access] = always_allow
self._app.dependency_overrides[has_variable_access] = always_allow
self._app.dependency_overrides[has_xcom_access] = always_allow
Expand Down
90 changes: 1 addition & 89 deletions airflow-core/src/airflow/api_fastapi/execution_api/deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,8 @@
# Disable future annotations in this file to work around https://github.com/fastapi/fastapi/issues/13056
# ruff: noqa: I002

from typing import Any

import structlog
import svcs
from fastapi import Depends, HTTPException, Request, status
from fastapi.security import HTTPBearer
from sqlalchemy import select

from airflow.api_fastapi.auth.tokens import JWTValidator
from airflow.api_fastapi.common.db.common import AsyncSessionDep
from airflow.api_fastapi.execution_api.datamodels.token import TIToken
from airflow.configuration import conf
from airflow.models import DagModel, TaskInstance
from airflow.models.dagbundle import DagBundleModel
from airflow.models.team import Team

log = structlog.get_logger(logger_name=__name__)
from fastapi import Depends, Request


# See https://github.com/fastapi/fastapi/issues/13056
Expand All @@ -44,76 +29,3 @@ async def _container(request: Request):


DepContainer: svcs.Container = Depends(_container)


class JWTBearer(HTTPBearer):
    """
    A FastAPI security dependency that validates JWT tokens for the Execution API.

    This will validate that the tokens are signed and that the ``sub`` is a UUID, but nothing deeper than that.

    The dependency result will be a ``TIToken`` object containing the ``id`` UUID (from the ``sub``) and other
    validated claims.

    :param path_param_name: Optional name of a path parameter; when set, the token's ``sub`` claim
        must equal the value of that path parameter on the incoming request.
    :param required_claims: Optional extra claim validators, passed through to the JWT validator.
    """

    def __init__(
        self,
        path_param_name: str | None = None,
        required_claims: dict[str, Any] | None = None,
    ):
        # auto_error=False: missing credentials are reported by our own 401 in ``__call__``.
        super().__init__(auto_error=False)
        self.path_param_name = path_param_name
        self.required_claims = required_claims or {}

    async def __call__(  # type: ignore[override]
        self,
        request: Request,
        services=DepContainer,
    ) -> TIToken | None:
        """
        Validate the bearer token on ``request`` and return it as a ``TIToken``.

        :raises HTTPException: 401 when no credentials are supplied; 403 when validation fails
            for any reason (including a missing path parameter).
        """
        creds = await super().__call__(request)
        if not creds:
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing auth token")

        validator: JWTValidator = await services.aget(JWTValidator)

        try:
            validators: dict[str, Any] = self.required_claims
            if self.path_param_name:
                # Example: Validate "task_instance_id" component of the path matches the one in the token
                expected_sub = request.path_params[self.path_param_name]
                validators = {
                    **self.required_claims,
                    "sub": {"essential": True, "value": expected_sub},
                }
            claims = await validator.avalidated_claims(creds.credentials, validators)
            return TIToken(id=claims["sub"], claims=claims)
        except Exception as err:
            # NOTE(review): this logs the raw bearer token -- confirm that log redaction covers it.
            log.warning(
                "Failed to validate JWT",
                exc_info=True,
                token=creds.credentials,
            )
            # Chain the cause so the original validation error stays attached to the 403.
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN, detail=f"Invalid auth token: {err}"
            ) from err


JWTBearerDep: TIToken = Depends(JWTBearer())

# This checks that the UUID in the url matches the one in the token for us.
JWTBearerTIPathDep = Depends(JWTBearer(path_param_name="task_instance_id"))


async def get_team_name_dep(session: AsyncSessionDep, token=JWTBearerDep) -> str | None:
    """
    Look up the name of the team that owns the task instance identified by the token.

    Returns ``None`` without querying when the ``[core] multi_team`` option is disabled;
    otherwise returns whatever scalar the team-name query yields.
    """
    multi_team_enabled = conf.getboolean("core", "multi_team")
    if multi_team_enabled:
        # TaskInstance -> DagModel -> DagBundleModel -> teams, filtered on the token's id.
        ti_to_dag = select(Team.name).select_from(TaskInstance).join(
            DagModel, DagModel.dag_id == TaskInstance.dag_id
        )
        dag_to_team = ti_to_dag.join(DagBundleModel, DagBundleModel.name == DagModel.bundle_name).join(
            DagBundleModel.teams
        )
        team_name_query = dag_to_team.where(TaskInstance.id == token.id)
        return await session.scalar(team_name_query)
    return None
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,8 @@
from __future__ import annotations

from cadwyn import VersionedAPIRouter
from fastapi import APIRouter
from fastapi import APIRouter, Security

from airflow.api_fastapi.execution_api.deps import JWTBearerDep
from airflow.api_fastapi.execution_api.routes import (
asset_events,
assets,
Expand All @@ -32,12 +31,13 @@
variables,
xcoms,
)
from airflow.api_fastapi.execution_api.security import require_auth

execution_api_router = APIRouter()
execution_api_router.include_router(health.router, prefix="/health", tags=["Health"])

# _Every_ single endpoint under here must be authenticated. Some do further checks on top of these
authenticated_router = VersionedAPIRouter(dependencies=[JWTBearerDep]) # type: ignore[list-item]
authenticated_router = VersionedAPIRouter(dependencies=[Security(require_auth)]) # type: ignore[list-item]

authenticated_router.include_router(assets.router, prefix="/assets", tags=["Assets"])
authenticated_router.include_router(asset_events.router, prefix="/asset-events", tags=["Asset Events"])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@
from fastapi import APIRouter, Depends, HTTPException, Path, status

from airflow.api_fastapi.execution_api.datamodels.connection import ConnectionResponse
from airflow.api_fastapi.execution_api.deps import JWTBearerDep, get_team_name_dep
from airflow.api_fastapi.execution_api.security import CurrentTIToken, get_team_name_dep
from airflow.exceptions import AirflowNotFoundException
from airflow.models.connection import Connection


async def has_connection_access(
connection_id: str = Path(),
token=JWTBearerDep,
token=CurrentTIToken,
) -> bool:
"""Check if the task has access to the connection."""
# TODO: Placeholder for actual implementation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import attrs
import structlog
from cadwyn import VersionedAPIRouter
from fastapi import Body, HTTPException, Query, status
from fastapi import Body, HTTPException, Query, Security, status
from pydantic import JsonValue
from sqlalchemy import func, or_, tuple_, update
from sqlalchemy.engine import CursorResult
Expand Down Expand Up @@ -59,7 +59,7 @@
TISuccessStatePayload,
TITerminalStatePayload,
)
from airflow.api_fastapi.execution_api.deps import JWTBearerTIPathDep
from airflow.api_fastapi.execution_api.security import ExecutionAPIRoute, require_auth
from airflow.exceptions import TaskNotFound
from airflow.models.asset import AssetActive
from airflow.models.dag import DagModel
Expand All @@ -78,10 +78,10 @@
router = VersionedAPIRouter()

ti_id_router = VersionedAPIRouter(
route_class=ExecutionAPIRoute,
dependencies=[
# This checks that the UUID in the url matches the one in the token for us.
JWTBearerTIPathDep
]
Security(require_auth, scopes=["ti:self"]),
],
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@
VariablePostBody,
VariableResponse,
)
from airflow.api_fastapi.execution_api.deps import JWTBearerDep, get_team_name_dep
from airflow.api_fastapi.execution_api.security import CurrentTIToken, get_team_name_dep
from airflow.models.variable import Variable


async def has_variable_access(
request: Request,
variable_key: str = Path(),
token=JWTBearerDep,
token=CurrentTIToken,
):
"""Check if the task has access to the variable."""
write = request.method not in {"GET", "HEAD", "OPTIONS"}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
XComSequenceIndexResponse,
XComSequenceSliceResponse,
)
from airflow.api_fastapi.execution_api.deps import JWTBearerDep
from airflow.api_fastapi.execution_api.security import CurrentTIToken
from airflow.models.taskmap import TaskMap
from airflow.models.xcom import XComModel
from airflow.utils.db import get_query_count
Expand All @@ -44,7 +44,7 @@ async def has_xcom_access(
task_id: str,
xcom_key: Annotated[str, Path(alias="key", min_length=1)],
request: Request,
token=JWTBearerDep,
token=CurrentTIToken,
) -> bool:
"""Check if the task has access to the XCom."""
# TODO: Placeholder for actual implementation
Expand Down
Loading
Loading