Skip to content

Commit 85487a8

Browse files
authored
Enable disabling of dag bundle versioning (apache#47473)
This allows folks to opt in to running on the latest bundle version for each task - basically how Airflow 2 operates - even if they have a versioned bundle under the hood.
Approach:
* a config option, something like use_bundle_versioning (implemented in this commit as `[dag_processor] disable_bundle_versioning`)
* a DAG kwarg that uses that config option as its default
* conditionally add bundle_version to the dagrun
1 parent 1b643a4 commit 85487a8

File tree

7 files changed

+96
-3
lines changed

7 files changed

+96
-3
lines changed

airflow/config_templates/config.yml

+8
Original file line numberDiff line numberDiff line change
@@ -2671,6 +2671,14 @@ dag_processor:
26712671
type: integer
26722672
example: ~
26732673
default: "30"
2674+
disable_bundle_versioning:
2675+
description: |
2676+
Always run tasks with the latest code. If set to True, the bundle version will not
2677+
be stored on the dag run, and therefore the latest code will always be used.
2678+
version_added: ~
2679+
type: boolean
2680+
example: ~
2681+
default: "False"
26742682
bundle_refresh_check_interval:
26752683
description: |
26762684
How often the DAG processor should check if any DAG bundles are ready for a refresh, either by hitting

airflow/models/dag.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,11 @@ def _create_orm_dagrun(
261261
triggered_by: DagRunTriggeredByType,
262262
session: Session = NEW_SESSION,
263263
) -> DagRun:
264+
bundle_version = None
265+
if not dag.disable_bundle_versioning:
266+
bundle_version = session.scalar(
267+
select(DagModel.bundle_version).where(DagModel.dag_id == dag.dag_id),
268+
)
264269
run = DagRun(
265270
dag_id=dag.dag_id,
266271
run_id=run_id,
@@ -274,7 +279,7 @@ def _create_orm_dagrun(
274279
data_interval=data_interval,
275280
triggered_by=triggered_by,
276281
backfill_id=backfill_id,
277-
bundle_version=session.scalar(select(DagModel.bundle_version).where(DagModel.dag_id == dag.dag_id)),
282+
bundle_version=bundle_version,
278283
)
279284
# Load defaults into the following two fields to ensure result can be serialized detached
280285
run.log_template_id = int(session.scalar(select(func.max(LogTemplate.__table__.c.id))))

airflow/serialization/schema.json

+2-1
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,8 @@
206206
{ "$ref": "#/definitions/task_group" }
207207
]},
208208
"edge_info": { "$ref": "#/definitions/edge_info" },
209-
"dag_dependencies": { "$ref": "#/definitions/dag_dependencies" }
209+
"dag_dependencies": { "$ref": "#/definitions/dag_dependencies" },
210+
"disable_bundle_versioning": {"type": "boolean"}
210211
},
211212
"required": [
212213
"dag_id",

task-sdk/src/airflow/sdk/definitions/dag.py

+7
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,7 @@ def __rich_repr__(self):
439439

440440
has_on_success_callback: bool = attrs.field(init=False)
441441
has_on_failure_callback: bool = attrs.field(init=False)
442+
disable_bundle_versioning: bool = attrs.field(init=True)
442443

443444
def __attrs_post_init__(self):
444445
from airflow.utils import timezone
@@ -510,6 +511,12 @@ def _default_timetable(instance: DAG):
510511
else:
511512
return _create_timetable(schedule, instance.timezone)
512513

514+
@disable_bundle_versioning.default
515+
def _disable_bundle_versioning_default(self):
516+
from airflow.configuration import conf as airflow_conf
517+
518+
return airflow_conf.getboolean("dag_processor", "disable_bundle_versioning")
519+
513520
@timezone.default
514521
def _extract_tz(instance):
515522
import pendulum

task-sdk/tests/task_sdk/definitions/test_dag.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ def test_continuous_schedule_linmits_max_active_runs(self):
282282
dag = DAG("continuous", start_date=DEFAULT_DATE, schedule="@continuous", max_active_runs=25)
283283

284284

285-
# Test some of the arg valiadtion. This is not all the validations we perform, just some of them.
285+
# Test some of the arg validation. This is not all the validations we perform, just some of them.
286286
@pytest.mark.parametrize(
287287
["attr", "value"],
288288
[

tests/models/test_dag.py

+38
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@
9898
from tests_common.test_utils.asserts import assert_queries_count
9999
from tests_common.test_utils.db import (
100100
clear_db_assets,
101+
clear_db_dag_bundles,
101102
clear_db_dags,
102103
clear_db_runs,
103104
clear_db_serialized_dags,
@@ -120,9 +121,11 @@
120121
def clear_dags():
121122
clear_db_dags()
122123
clear_db_serialized_dags()
124+
clear_db_dag_bundles()
123125
yield
124126
clear_db_dags()
125127
clear_db_serialized_dags()
128+
clear_db_dag_bundles()
126129

127130

128131
@pytest.fixture
@@ -3493,3 +3496,38 @@ def test_validate_setup_teardown_trigger_rule(self):
34933496
Exception, match="Setup tasks must be followed with trigger rule ALL_SUCCESS."
34943497
):
34953498
dag.validate_setup_teardown()
3499+
3500+
3501+
@pytest.mark.parametrize(
3502+
"disable, bundle_version, expected",
3503+
[
3504+
(True, "some-version", None),
3505+
(False, "some-version", "some-version"),
3506+
],
3507+
)
3508+
def test_disable_bundle_versioning(disable, bundle_version, expected, dag_maker, session, clear_dags):
3509+
"""When bundle versioning is disabled for a dag, the dag run should not have a bundle version."""
3510+
3511+
def hello():
3512+
print("hello")
3513+
3514+
with dag_maker(disable_bundle_versioning=disable, session=session, serialized=True) as dag:
3515+
PythonOperator(task_id="hi", python_callable=hello)
3516+
3517+
assert dag.disable_bundle_versioning is disable
3518+
3519+
# the dag *always* has a bundle version
3520+
dag_model = session.scalar(select(DagModel).where(DagModel.dag_id == dag.dag_id))
3521+
dag_model.bundle_version = bundle_version
3522+
session.commit()
3523+
3524+
dr = dag.create_dagrun(
3525+
run_id="abcoercuhcrh",
3526+
run_after=pendulum.now(),
3527+
run_type="manual",
3528+
triggered_by=DagRunTriggeredByType.TEST,
3529+
state=None,
3530+
)
3531+
3532+
# but it only gets stamped on the dag run when bundle versioning is not disabled
3533+
assert dr.bundle_version == expected

tests/serialization/test_dag_serialization.py

+34
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@
8686
from airflow.utils.task_group import TaskGroup
8787
from airflow.utils.xcom import XCOM_RETURN_KEY
8888

89+
from tests_common.test_utils.config import conf_vars
8990
from tests_common.test_utils.mock_operators import (
9091
AirflowLink2,
9192
CustomOperator,
@@ -154,6 +155,7 @@
154155
},
155156
"is_paused_upon_creation": False,
156157
"dag_id": "simple_dag",
158+
"disable_bundle_versioning": False,
157159
"doc_md": "### DAG Tutorial Documentation",
158160
"fileloc": None,
159161
"_processor_dags_folder": f"{repo_root}/tests/dags",
@@ -2068,6 +2070,38 @@ def test_dag_on_failure_callback_roundtrip(self, passed_failure_callback, expect
20682070

20692071
assert deserialized_dag.has_on_failure_callback is expected_value
20702072

2073+
@pytest.mark.parametrize(
2074+
"dag_arg, conf_arg, expected",
2075+
[
2076+
(True, "True", True),
2077+
(True, "False", True),
2078+
(False, "True", False),
2079+
(False, "False", False),
2080+
(None, "True", True),
2081+
(None, "False", False),
2082+
],
2083+
)
2084+
def test_dag_disable_bundle_versioning_roundtrip(self, dag_arg, conf_arg, expected):
2085+
"""
2086+
Test that disable_bundle_versioning survives a serialization round trip: an explicit True or False
2087+
passed to the DAG is preserved on the de-serialized DAG, regardless of the config option.
2088+
2089+
When None is passed instead, the de-serialized value should follow the
2090+
``[dag_processor] disable_bundle_versioning`` config option.
2091+
"""
2092+
with conf_vars({("dag_processor", "disable_bundle_versioning"): conf_arg}):
2093+
kwargs = {}
2094+
kwargs["disable_bundle_versioning"] = dag_arg
2095+
dag = DAG(
2096+
dag_id="test_dag_disable_bundle_versioning_roundtrip",
2097+
schedule=None,
2098+
**kwargs,
2099+
)
2100+
BaseOperator(task_id="simple_task", dag=dag, start_date=datetime(2019, 8, 1))
2101+
serialized_dag = SerializedDAG.to_dict(dag)
2102+
deserialized_dag = SerializedDAG.from_dict(serialized_dag)
2103+
assert deserialized_dag.disable_bundle_versioning is expected
2104+
20712105
@pytest.mark.parametrize(
20722106
"object_to_serialized, expected_output",
20732107
[

0 commit comments

Comments
 (0)