Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 28 additions & 19 deletions airflow-core/src/airflow/api_fastapi/core_api/datamodels/xcom.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,28 @@ class XComCollectionResponse(BaseModel):
total_entries: int


def _check_forbidden_xcom_keys(value: Any) -> Any:
"""Recursively reject forbidden deserialization keys in user-provided XCom data."""
from airflow._shared.serialization import FORBIDDEN_XCOM_KEYS

def _walk(obj: Any, path: str = "value") -> None:
if isinstance(obj, dict):
found = FORBIDDEN_XCOM_KEYS & obj.keys()
if found:
raise ValueError(
f"XCom {path} contains reserved serialization keys: {', '.join(sorted(found))}. "
f"These keys are reserved for internal use."
)
for k, v in obj.items():
_walk(v, f"{path}.{k}")
elif isinstance(obj, (list, tuple)):
for i, item in enumerate(obj):
_walk(item, f"{path}[{i}]")

_walk(value)
return value


class XComCreateBody(StrictBaseModel):
"""Payload serializer for creating an XCom entry."""

Expand All @@ -93,29 +115,16 @@ class XComCreateBody(StrictBaseModel):
@field_validator("value")
@classmethod
def _check_forbidden_keys(cls, value: Any) -> Any:
"""Recursively check for forbidden deserialization keys in user-provided XCom data."""
from airflow._shared.serialization import FORBIDDEN_XCOM_KEYS

def _walk_forbidden_keys(obj: Any, path: str = "value") -> None:
if isinstance(obj, dict):
found = FORBIDDEN_XCOM_KEYS & obj.keys()
if found:
raise ValueError(
f"XCom {path} contains reserved serialization keys: {', '.join(sorted(found))}. "
f"These keys are reserved for internal use."
)
for k, v in obj.items():
_walk_forbidden_keys(v, f"{path}.{k}")
elif isinstance(obj, (list, tuple)):
for i, item in enumerate(obj):
_walk_forbidden_keys(item, f"{path}[{i}]")

_walk_forbidden_keys(value)
return value
return _check_forbidden_xcom_keys(value)


class XComUpdateBody(StrictBaseModel):
"""Payload serializer for updating an XCom entry."""

value: Any
map_index: int = -1

@field_validator("value")
@classmethod
def _check_forbidden_keys(cls, value: Any) -> Any:
return _check_forbidden_xcom_keys(value)
Original file line number Diff line number Diff line change
Expand Up @@ -864,3 +864,23 @@ def test_patch_xcom_entry_with_slash_key(self, test_client, session):
assert response.json()["key"] == slash_key
assert response.json()["value"] == json.dumps(new_value)
check_last_log(session, dag_id=TEST_DAG_ID, event="update_xcom_entry", logical_date=None)

@pytest.mark.parametrize(
("key", "value"),
[
("__classname__", {"__classname__": "airflow.sdk.definitions.connection.Connection"}),
("__type", {"__type": "airflow.sdk.definitions.connection.Connection", "__var": {}}),
("__data__", {"nested": {"__data__": "malicious"}}),
],
)
def test_patch_xcom_entry_blocks_forbidden_keys(self, test_client, key, value):
"""Test that XCom update blocks deserialization metadata keys."""
self._create_xcom(TEST_XCOM_KEY, TEST_XCOM_VALUE)
response = test_client.patch(
f"/dags/{TEST_DAG_ID}/dagRuns/{run_id}/taskInstances/{TEST_TASK_ID}/xcomEntries/{TEST_XCOM_KEY}",
json={"value": value, "map_index": -1},
)
assert response.status_code == 422
detail = str(response.json()["detail"])
assert "reserved serialization keys" in detail
assert key in detail
Loading