Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 01cee78

Browse files
authoredMar 11, 2025··
Move chain, chain_linear & cross_downstream to Task SDK (#47639)
This functions are for DAG Authors to define relationship between multiple tasks in batch
1 parent b9ab634 commit 01cee78

File tree

18 files changed

+523
-491
lines changed

18 files changed

+523
-491
lines changed
 

‎airflow/example_dags/example_asset_with_watchers.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,9 @@
2121
from __future__ import annotations
2222

2323
from airflow.decorators import task
24-
from airflow.models.baseoperator import chain
2524
from airflow.models.dag import DAG
2625
from airflow.providers.standard.triggers.file import FileDeleteTrigger
27-
from airflow.sdk import Asset, AssetWatcher
26+
from airflow.sdk import Asset, AssetWatcher, chain
2827

2928
file_path = "/tmp/test"
3029

‎airflow/example_dags/example_bash_decorator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121

2222
from airflow.decorators import dag, task
2323
from airflow.exceptions import AirflowSkipException
24-
from airflow.models.baseoperator import chain
2524
from airflow.providers.standard.operators.empty import EmptyOperator
25+
from airflow.sdk import chain
2626
from airflow.utils.trigger_rule import TriggerRule
2727
from airflow.utils.weekday import WeekDay
2828

‎airflow/example_dags/example_complex.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@
2323

2424
import pendulum
2525

26-
from airflow.models.baseoperator import chain
2726
from airflow.models.dag import DAG
2827
from airflow.providers.standard.operators.bash import BashOperator
28+
from airflow.sdk import chain
2929

3030
with DAG(
3131
dag_id="example_complex",

‎airflow/example_dags/example_short_circuit_decorator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121
import pendulum
2222

2323
from airflow.decorators import dag, task
24-
from airflow.models.baseoperator import chain
2524
from airflow.providers.standard.operators.empty import EmptyOperator
25+
from airflow.sdk import chain
2626
from airflow.utils.trigger_rule import TriggerRule
2727

2828

‎airflow/example_dags/example_short_circuit_operator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,10 @@
2121

2222
import pendulum
2323

24-
from airflow.models.baseoperator import chain
2524
from airflow.models.dag import DAG
2625
from airflow.providers.standard.operators.empty import EmptyOperator
2726
from airflow.providers.standard.operators.python import ShortCircuitOperator
27+
from airflow.sdk import chain
2828
from airflow.utils.trigger_rule import TriggerRule
2929

3030
with DAG(

‎airflow/models/baseoperator.py

+6-267
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
import functools
2727
import logging
2828
import operator
29-
from collections.abc import Collection, Iterable, Iterator, Sequence
29+
from collections.abc import Collection, Iterable, Iterator
3030
from datetime import datetime, timedelta
3131
from functools import singledispatchmethod
3232
from types import FunctionType
@@ -54,14 +54,16 @@
5454
NotMapped,
5555
)
5656
from airflow.models.taskinstance import TaskInstance, clear_task_instances
57-
from airflow.models.taskmixin import DependencyMixin
5857
from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator as TaskSDKAbstractOperator
5958
from airflow.sdk.definitions.baseoperator import (
60-
get_merged_defaults as get_merged_defaults, # Re-export for compat
59+
# Re-export for compat
60+
chain as chain,
61+
chain_linear as chain_linear,
62+
cross_downstream as cross_downstream,
63+
get_merged_defaults as get_merged_defaults,
6164
)
6265
from airflow.sdk.definitions.context import Context
6366
from airflow.sdk.definitions.dag import BaseOperator as TaskSDKBaseOperator
64-
from airflow.sdk.definitions.edges import EdgeModifier as TaskSDKEdgeModifier
6567
from airflow.sdk.definitions.mappedoperator import MappedOperator
6668
from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup
6769
from airflow.serialization.enums import DagAttributeTypes
@@ -72,7 +74,6 @@
7274
from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep
7375
from airflow.utils import timezone
7476
from airflow.utils.context import context_get_outlet_events
75-
from airflow.utils.edgemodifier import EdgeModifier
7677
from airflow.utils.operator_helpers import ExecutionCallableRunner
7778
from airflow.utils.operator_resources import Resources
7879
from airflow.utils.session import NEW_SESSION, provide_session
@@ -811,265 +812,3 @@ def iter_mapped_task_group_lengths(group) -> Iterator[int]:
811812
group = group.parent_group
812813

813814
return functools.reduce(operator.mul, iter_mapped_task_group_lengths(group))
814-
815-
816-
def chain(*tasks: DependencyMixin | Sequence[DependencyMixin]) -> None:
817-
r"""
818-
Given a number of tasks, builds a dependency chain.
819-
820-
This function accepts values of BaseOperator (aka tasks), EdgeModifiers (aka Labels), XComArg, TaskGroups,
821-
or lists containing any mix of these types (or a mix in the same list). If you want to chain between two
822-
lists you must ensure they have the same length.
823-
824-
Using classic operators/sensors:
825-
826-
.. code-block:: python
827-
828-
chain(t1, [t2, t3], [t4, t5], t6)
829-
830-
is equivalent to::
831-
832-
/ -> t2 -> t4 \
833-
t1 -> t6
834-
\ -> t3 -> t5 /
835-
836-
.. code-block:: python
837-
838-
t1.set_downstream(t2)
839-
t1.set_downstream(t3)
840-
t2.set_downstream(t4)
841-
t3.set_downstream(t5)
842-
t4.set_downstream(t6)
843-
t5.set_downstream(t6)
844-
845-
Using task-decorated functions aka XComArgs:
846-
847-
.. code-block:: python
848-
849-
chain(x1(), [x2(), x3()], [x4(), x5()], x6())
850-
851-
is equivalent to::
852-
853-
/ -> x2 -> x4 \
854-
x1 -> x6
855-
\ -> x3 -> x5 /
856-
857-
.. code-block:: python
858-
859-
x1 = x1()
860-
x2 = x2()
861-
x3 = x3()
862-
x4 = x4()
863-
x5 = x5()
864-
x6 = x6()
865-
x1.set_downstream(x2)
866-
x1.set_downstream(x3)
867-
x2.set_downstream(x4)
868-
x3.set_downstream(x5)
869-
x4.set_downstream(x6)
870-
x5.set_downstream(x6)
871-
872-
Using TaskGroups:
873-
874-
.. code-block:: python
875-
876-
chain(t1, task_group1, task_group2, t2)
877-
878-
t1.set_downstream(task_group1)
879-
task_group1.set_downstream(task_group2)
880-
task_group2.set_downstream(t2)
881-
882-
883-
It is also possible to mix between classic operator/sensor, EdgeModifiers, XComArg, and TaskGroups:
884-
885-
.. code-block:: python
886-
887-
chain(t1, [Label("branch one"), Label("branch two")], [x1(), x2()], task_group1, x3())
888-
889-
is equivalent to::
890-
891-
/ "branch one" -> x1 \
892-
t1 -> task_group1 -> x3
893-
\ "branch two" -> x2 /
894-
895-
.. code-block:: python
896-
897-
x1 = x1()
898-
x2 = x2()
899-
x3 = x3()
900-
label1 = Label("branch one")
901-
label2 = Label("branch two")
902-
t1.set_downstream(label1)
903-
label1.set_downstream(x1)
904-
t2.set_downstream(label2)
905-
label2.set_downstream(x2)
906-
x1.set_downstream(task_group1)
907-
x2.set_downstream(task_group1)
908-
task_group1.set_downstream(x3)
909-
910-
# or
911-
912-
x1 = x1()
913-
x2 = x2()
914-
x3 = x3()
915-
t1.set_downstream(x1, edge_modifier=Label("branch one"))
916-
t1.set_downstream(x2, edge_modifier=Label("branch two"))
917-
x1.set_downstream(task_group1)
918-
x2.set_downstream(task_group1)
919-
task_group1.set_downstream(x3)
920-
921-
922-
:param tasks: Individual and/or list of tasks, EdgeModifiers, XComArgs, or TaskGroups to set dependencies
923-
"""
924-
for up_task, down_task in zip(tasks, tasks[1:]):
925-
if isinstance(up_task, DependencyMixin):
926-
up_task.set_downstream(down_task)
927-
continue
928-
if isinstance(down_task, DependencyMixin):
929-
down_task.set_upstream(up_task)
930-
continue
931-
if not isinstance(up_task, Sequence) or not isinstance(down_task, Sequence):
932-
raise TypeError(f"Chain not supported between instances of {type(up_task)} and {type(down_task)}")
933-
up_task_list = up_task
934-
down_task_list = down_task
935-
if len(up_task_list) != len(down_task_list):
936-
raise AirflowException(
937-
f"Chain not supported for different length Iterable. "
938-
f"Got {len(up_task_list)} and {len(down_task_list)}."
939-
)
940-
for up_t, down_t in zip(up_task_list, down_task_list):
941-
up_t.set_downstream(down_t)
942-
943-
944-
def cross_downstream(
945-
from_tasks: Sequence[DependencyMixin],
946-
to_tasks: DependencyMixin | Sequence[DependencyMixin],
947-
):
948-
r"""
949-
Set downstream dependencies for all tasks in from_tasks to all tasks in to_tasks.
950-
951-
Using classic operators/sensors:
952-
953-
.. code-block:: python
954-
955-
cross_downstream(from_tasks=[t1, t2, t3], to_tasks=[t4, t5, t6])
956-
957-
is equivalent to::
958-
959-
t1 ---> t4
960-
\ /
961-
t2 -X -> t5
962-
/ \
963-
t3 ---> t6
964-
965-
.. code-block:: python
966-
967-
t1.set_downstream(t4)
968-
t1.set_downstream(t5)
969-
t1.set_downstream(t6)
970-
t2.set_downstream(t4)
971-
t2.set_downstream(t5)
972-
t2.set_downstream(t6)
973-
t3.set_downstream(t4)
974-
t3.set_downstream(t5)
975-
t3.set_downstream(t6)
976-
977-
Using task-decorated functions aka XComArgs:
978-
979-
.. code-block:: python
980-
981-
cross_downstream(from_tasks=[x1(), x2(), x3()], to_tasks=[x4(), x5(), x6()])
982-
983-
is equivalent to::
984-
985-
x1 ---> x4
986-
\ /
987-
x2 -X -> x5
988-
/ \
989-
x3 ---> x6
990-
991-
.. code-block:: python
992-
993-
x1 = x1()
994-
x2 = x2()
995-
x3 = x3()
996-
x4 = x4()
997-
x5 = x5()
998-
x6 = x6()
999-
x1.set_downstream(x4)
1000-
x1.set_downstream(x5)
1001-
x1.set_downstream(x6)
1002-
x2.set_downstream(x4)
1003-
x2.set_downstream(x5)
1004-
x2.set_downstream(x6)
1005-
x3.set_downstream(x4)
1006-
x3.set_downstream(x5)
1007-
x3.set_downstream(x6)
1008-
1009-
It is also possible to mix between classic operator/sensor and XComArg tasks:
1010-
1011-
.. code-block:: python
1012-
1013-
cross_downstream(from_tasks=[t1, x2(), t3], to_tasks=[x1(), t2, x3()])
1014-
1015-
is equivalent to::
1016-
1017-
t1 ---> x1
1018-
\ /
1019-
x2 -X -> t2
1020-
/ \
1021-
t3 ---> x3
1022-
1023-
.. code-block:: python
1024-
1025-
x1 = x1()
1026-
x2 = x2()
1027-
x3 = x3()
1028-
t1.set_downstream(x1)
1029-
t1.set_downstream(t2)
1030-
t1.set_downstream(x3)
1031-
x2.set_downstream(x1)
1032-
x2.set_downstream(t2)
1033-
x2.set_downstream(x3)
1034-
t3.set_downstream(x1)
1035-
t3.set_downstream(t2)
1036-
t3.set_downstream(x3)
1037-
1038-
:param from_tasks: List of tasks or XComArgs to start from.
1039-
:param to_tasks: List of tasks or XComArgs to set as downstream dependencies.
1040-
"""
1041-
for task in from_tasks:
1042-
task.set_downstream(to_tasks)
1043-
1044-
1045-
def chain_linear(*elements: DependencyMixin | Sequence[DependencyMixin]):
1046-
"""
1047-
Simplify task dependency definition.
1048-
1049-
E.g.: suppose you want precedence like so::
1050-
1051-
╭─op2─╮ ╭─op4─╮
1052-
op1─┤ ├─├─op5─┤─op7
1053-
╰-op3─╯ ╰-op6─╯
1054-
1055-
Then you can accomplish like so::
1056-
1057-
chain_linear(op1, [op2, op3], [op4, op5, op6], op7)
1058-
1059-
:param elements: a list of operators / lists of operators
1060-
"""
1061-
if not elements:
1062-
raise ValueError("No tasks provided; nothing to do.")
1063-
prev_elem = None
1064-
deps_set = False
1065-
for curr_elem in elements:
1066-
if isinstance(curr_elem, (EdgeModifier, TaskSDKEdgeModifier)):
1067-
raise ValueError("Labels are not supported by chain_linear")
1068-
if prev_elem is not None:
1069-
for task in prev_elem:
1070-
task >> curr_elem
1071-
if not deps_set:
1072-
deps_set = True
1073-
prev_elem = [curr_elem] if isinstance(curr_elem, DependencyMixin) else curr_elem
1074-
if not deps_set:
1075-
raise ValueError("No dependencies were set. Did you forget to expand with `*`?")

‎airflow/utils/edgemodifier.py

+2-14
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,5 @@
1616
# under the License.
1717
from __future__ import annotations
1818

19-
from typing import TYPE_CHECKING
20-
21-
import airflow.sdk
22-
23-
if TYPE_CHECKING:
24-
from airflow.typing_compat import TypeAlias
25-
26-
EdgeModifier: TypeAlias = airflow.sdk.definitions.edges.EdgeModifier
27-
28-
29-
# Factory functions
30-
def Label(label: str):
31-
"""Create an EdgeModifier that sets a human-readable label on the edge."""
32-
return EdgeModifier(label=label)
19+
# Re-export for compat
20+
from airflow.sdk.definitions.edges import Label as Label

‎dev/perf/dags/elastic_dag.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@
2222
from datetime import datetime, timedelta
2323
from enum import Enum
2424

25-
from airflow.models.baseoperator import chain
2625
from airflow.models.dag import DAG
2726
from airflow.providers.standard.operators.bash import BashOperator
27+
from airflow.sdk import chain
2828

2929
# DAG File used in performance tests. Its shape can be configured by environment variables.
3030
RE_TIME_DELTA = re.compile(

‎docs/apache-airflow/core-concepts/dags.rst

+3-3
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ Or, you can also use the more explicit ``set_upstream`` and ``set_downstream`` m
115115

116116
There are also shortcuts to declaring more complex dependencies. If you want to make a list of tasks depend on another list of tasks, you can't use either of the approaches above, so you need to use ``cross_downstream``::
117117

118-
from airflow.models.baseoperator import cross_downstream
118+
from airflow.sdk import cross_downstream
119119

120120
# Replaces
121121
# [op1, op2] >> op3
@@ -124,7 +124,7 @@ There are also shortcuts to declaring more complex dependencies. If you want to
124124

125125
And if you want to chain together dependencies, you can use ``chain``::
126126

127-
from airflow.models.baseoperator import chain
127+
from airflow.sdk import chain
128128

129129
# Replaces op1 >> op2 >> op3 >> op4
130130
chain(op1, op2, op3, op4)
@@ -134,7 +134,7 @@ And if you want to chain together dependencies, you can use ``chain``::
134134

135135
Chain can also do *pairwise* dependencies for lists the same size (this is different from the *cross dependencies* created by ``cross_downstream``!)::
136136

137-
from airflow.models.baseoperator import chain
137+
from airflow.sdk import chain
138138

139139
# Replaces
140140
# op1 >> op2 >> op4 >> op6

‎newsfragments/aip-72.significant.rst

+32-2
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ As part of this change the following breaking changes have occurred:
3434

3535
.. code-block:: bash
3636
37-
airflow db drop-archived -t "_xcAm_archive"
37+
airflow db drop-archived -t "_xcom_archive"
3838
3939
- The ability to specify scheduling conditions for an operator via the ``deps`` class attribute has been removed.
4040

@@ -46,7 +46,37 @@ As part of this change the following breaking changes have occurred:
4646

4747
Any occurrences of imports from ``airflow.models.baseoperatorlink`` will need to be updated to ``airflow.sdk.definitions.baseoperatorlink``
4848

49-
- With the We have removed DAG level settings that control the UI behaviour.
49+
- ``chain``, ``chain_linear`` and ``cross_downstream`` have been moved to the task SDK.
50+
51+
Any occurrences of imports from ``airflow.models.baseoperator`` will need to be updated to ``airflow.sdk``
52+
53+
Old imports:
54+
55+
.. code-block:: python
56+
57+
from airflow.models.baseoperator import chain, chain_linear, cross_downstream
58+
59+
New imports:
60+
61+
.. code-block:: python
62+
63+
from airflow.sdk import chain, chain_linear, cross_downstream
64+
65+
- The ``Label`` class has been moved to the task SDK.
66+
67+
Old imports:
68+
69+
.. code-block:: python
70+
71+
from airflow.utils.edgemodifier import Label
72+
73+
New imports:
74+
75+
.. code-block:: python
76+
77+
from airflow.sdk import Label
78+
79+
- We have removed DAG level settings that control the UI behaviour.
5080
These are now as per-user settings controlled by the UI
5181

5282
- ``default_view``

‎performance/src/performance_dags/performance_dag/performance_dag.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,9 @@
5454
)
5555

5656
from airflow import DAG
57-
from airflow.models.baseoperator import chain
5857
from airflow.operators.bash import BashOperator
5958
from airflow.operators.python import PythonOperator
59+
from airflow.sdk import chain
6060
from airflow.utils.trigger_rule import TriggerRule
6161

6262
# DAG File used in performance tests. Its shape can be configured by environment variables.

‎task-sdk/src/airflow/sdk/__init__.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@
3838
"Variable",
3939
"XComArg",
4040
"asset",
41+
"chain",
42+
"chain_linear",
43+
"cross_downstream",
4144
"dag",
4245
"get_current_context",
4346
"get_parsing_context",
@@ -50,7 +53,7 @@
5053
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAll, AssetAny, AssetWatcher
5154
from airflow.sdk.definitions.assets.decorators import asset
5255
from airflow.sdk.definitions.assets.metadata import Metadata
53-
from airflow.sdk.definitions.baseoperator import BaseOperator
56+
from airflow.sdk.definitions.baseoperator import BaseOperator, chain, chain_linear, cross_downstream
5457
from airflow.sdk.definitions.baseoperatorlink import BaseOperatorLink
5558
from airflow.sdk.definitions.connection import Connection
5659
from airflow.sdk.definitions.context import Context, get_current_context, get_parsing_context
@@ -81,6 +84,9 @@
8184
"Variable": ".definitions.variable",
8285
"XComArg": ".definitions.xcom_arg",
8386
"asset": ".definitions.asset.decorators",
87+
"chain": ".definitions.baseoperator",
88+
"chain_linear": ".definitions.baseoperator",
89+
"cross_downstream": ".definitions.baseoperator",
8490
"dag": ".definitions.dag",
8591
"get_current_context": ".definitions.context",
8692
"get_parsing_context": ".definitions.context",

‎task-sdk/src/airflow/sdk/definitions/baseoperator.py

+264
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,13 @@
5050
DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING,
5151
DEFAULT_WEIGHT_RULE,
5252
AbstractOperator,
53+
DependencyMixin,
5354
TaskStateChangeCallback,
5455
)
5556
from airflow.sdk.definitions._internal.decorators import fixup_decorator_warning_stack
5657
from airflow.sdk.definitions._internal.node import validate_key
5758
from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet, validate_instance_args
59+
from airflow.sdk.definitions.edges import EdgeModifier
5860
from airflow.sdk.definitions.mappedoperator import OperatorPartial, validate_mapping_kwargs
5961
from airflow.sdk.definitions.param import ParamsDict
6062
from airflow.task.priority_strategy import (
@@ -1577,3 +1579,265 @@ def resume_execution(self, next_method: str, next_kwargs: dict[str, Any] | None,
15771579
# Grab the callable off the Operator/Task and add in any kwargs
15781580
execute_callable = getattr(self, next_method)
15791581
return execute_callable(context, **next_kwargs)
1582+
1583+
1584+
def chain(*tasks: DependencyMixin | Sequence[DependencyMixin]) -> None:
1585+
r"""
1586+
Given a number of tasks, builds a dependency chain.
1587+
1588+
This function accepts values of BaseOperator (aka tasks), EdgeModifiers (aka Labels), XComArg, TaskGroups,
1589+
or lists containing any mix of these types (or a mix in the same list). If you want to chain between two
1590+
lists you must ensure they have the same length.
1591+
1592+
Using classic operators/sensors:
1593+
1594+
.. code-block:: python
1595+
1596+
chain(t1, [t2, t3], [t4, t5], t6)
1597+
1598+
is equivalent to::
1599+
1600+
/ -> t2 -> t4 \
1601+
t1 -> t6
1602+
\ -> t3 -> t5 /
1603+
1604+
.. code-block:: python
1605+
1606+
t1.set_downstream(t2)
1607+
t1.set_downstream(t3)
1608+
t2.set_downstream(t4)
1609+
t3.set_downstream(t5)
1610+
t4.set_downstream(t6)
1611+
t5.set_downstream(t6)
1612+
1613+
Using task-decorated functions aka XComArgs:
1614+
1615+
.. code-block:: python
1616+
1617+
chain(x1(), [x2(), x3()], [x4(), x5()], x6())
1618+
1619+
is equivalent to::
1620+
1621+
/ -> x2 -> x4 \
1622+
x1 -> x6
1623+
\ -> x3 -> x5 /
1624+
1625+
.. code-block:: python
1626+
1627+
x1 = x1()
1628+
x2 = x2()
1629+
x3 = x3()
1630+
x4 = x4()
1631+
x5 = x5()
1632+
x6 = x6()
1633+
x1.set_downstream(x2)
1634+
x1.set_downstream(x3)
1635+
x2.set_downstream(x4)
1636+
x3.set_downstream(x5)
1637+
x4.set_downstream(x6)
1638+
x5.set_downstream(x6)
1639+
1640+
Using TaskGroups:
1641+
1642+
.. code-block:: python
1643+
1644+
chain(t1, task_group1, task_group2, t2)
1645+
1646+
t1.set_downstream(task_group1)
1647+
task_group1.set_downstream(task_group2)
1648+
task_group2.set_downstream(t2)
1649+
1650+
1651+
It is also possible to mix between classic operator/sensor, EdgeModifiers, XComArg, and TaskGroups:
1652+
1653+
.. code-block:: python
1654+
1655+
chain(t1, [Label("branch one"), Label("branch two")], [x1(), x2()], task_group1, x3())
1656+
1657+
is equivalent to::
1658+
1659+
/ "branch one" -> x1 \
1660+
t1 -> task_group1 -> x3
1661+
\ "branch two" -> x2 /
1662+
1663+
.. code-block:: python
1664+
1665+
x1 = x1()
1666+
x2 = x2()
1667+
x3 = x3()
1668+
label1 = Label("branch one")
1669+
label2 = Label("branch two")
1670+
t1.set_downstream(label1)
1671+
label1.set_downstream(x1)
1672+
t2.set_downstream(label2)
1673+
label2.set_downstream(x2)
1674+
x1.set_downstream(task_group1)
1675+
x2.set_downstream(task_group1)
1676+
task_group1.set_downstream(x3)
1677+
1678+
# or
1679+
1680+
x1 = x1()
1681+
x2 = x2()
1682+
x3 = x3()
1683+
t1.set_downstream(x1, edge_modifier=Label("branch one"))
1684+
t1.set_downstream(x2, edge_modifier=Label("branch two"))
1685+
x1.set_downstream(task_group1)
1686+
x2.set_downstream(task_group1)
1687+
task_group1.set_downstream(x3)
1688+
1689+
1690+
:param tasks: Individual and/or list of tasks, EdgeModifiers, XComArgs, or TaskGroups to set dependencies
1691+
"""
1692+
for up_task, down_task in zip(tasks, tasks[1:]):
1693+
if isinstance(up_task, DependencyMixin):
1694+
up_task.set_downstream(down_task)
1695+
continue
1696+
if isinstance(down_task, DependencyMixin):
1697+
down_task.set_upstream(up_task)
1698+
continue
1699+
if not isinstance(up_task, Sequence) or not isinstance(down_task, Sequence):
1700+
raise TypeError(f"Chain not supported between instances of {type(up_task)} and {type(down_task)}")
1701+
up_task_list = up_task
1702+
down_task_list = down_task
1703+
if len(up_task_list) != len(down_task_list):
1704+
raise ValueError(
1705+
f"Chain not supported for different length Iterable. "
1706+
f"Got {len(up_task_list)} and {len(down_task_list)}."
1707+
)
1708+
for up_t, down_t in zip(up_task_list, down_task_list):
1709+
up_t.set_downstream(down_t)
1710+
1711+
1712+
def cross_downstream(
1713+
from_tasks: Sequence[DependencyMixin],
1714+
to_tasks: DependencyMixin | Sequence[DependencyMixin],
1715+
):
1716+
r"""
1717+
Set downstream dependencies for all tasks in from_tasks to all tasks in to_tasks.
1718+
1719+
Using classic operators/sensors:
1720+
1721+
.. code-block:: python
1722+
1723+
cross_downstream(from_tasks=[t1, t2, t3], to_tasks=[t4, t5, t6])
1724+
1725+
is equivalent to::
1726+
1727+
t1 ---> t4
1728+
\ /
1729+
t2 -X -> t5
1730+
/ \
1731+
t3 ---> t6
1732+
1733+
.. code-block:: python
1734+
1735+
t1.set_downstream(t4)
1736+
t1.set_downstream(t5)
1737+
t1.set_downstream(t6)
1738+
t2.set_downstream(t4)
1739+
t2.set_downstream(t5)
1740+
t2.set_downstream(t6)
1741+
t3.set_downstream(t4)
1742+
t3.set_downstream(t5)
1743+
t3.set_downstream(t6)
1744+
1745+
Using task-decorated functions aka XComArgs:
1746+
1747+
.. code-block:: python
1748+
1749+
cross_downstream(from_tasks=[x1(), x2(), x3()], to_tasks=[x4(), x5(), x6()])
1750+
1751+
is equivalent to::
1752+
1753+
x1 ---> x4
1754+
\ /
1755+
x2 -X -> x5
1756+
/ \
1757+
x3 ---> x6
1758+
1759+
.. code-block:: python
1760+
1761+
x1 = x1()
1762+
x2 = x2()
1763+
x3 = x3()
1764+
x4 = x4()
1765+
x5 = x5()
1766+
x6 = x6()
1767+
x1.set_downstream(x4)
1768+
x1.set_downstream(x5)
1769+
x1.set_downstream(x6)
1770+
x2.set_downstream(x4)
1771+
x2.set_downstream(x5)
1772+
x2.set_downstream(x6)
1773+
x3.set_downstream(x4)
1774+
x3.set_downstream(x5)
1775+
x3.set_downstream(x6)
1776+
1777+
It is also possible to mix between classic operator/sensor and XComArg tasks:
1778+
1779+
.. code-block:: python
1780+
1781+
cross_downstream(from_tasks=[t1, x2(), t3], to_tasks=[x1(), t2, x3()])
1782+
1783+
is equivalent to::
1784+
1785+
t1 ---> x1
1786+
\ /
1787+
x2 -X -> t2
1788+
/ \
1789+
t3 ---> x3
1790+
1791+
.. code-block:: python
1792+
1793+
x1 = x1()
1794+
x2 = x2()
1795+
x3 = x3()
1796+
t1.set_downstream(x1)
1797+
t1.set_downstream(t2)
1798+
t1.set_downstream(x3)
1799+
x2.set_downstream(x1)
1800+
x2.set_downstream(t2)
1801+
x2.set_downstream(x3)
1802+
t3.set_downstream(x1)
1803+
t3.set_downstream(t2)
1804+
t3.set_downstream(x3)
1805+
1806+
:param from_tasks: List of tasks or XComArgs to start from.
1807+
:param to_tasks: List of tasks or XComArgs to set as downstream dependencies.
1808+
"""
1809+
for task in from_tasks:
1810+
task.set_downstream(to_tasks)
1811+
1812+
1813+
def chain_linear(*elements: DependencyMixin | Sequence[DependencyMixin]):
1814+
"""
1815+
Simplify task dependency definition.
1816+
1817+
E.g.: suppose you want precedence like so::
1818+
1819+
╭─op2─╮ ╭─op4─╮
1820+
op1─┤ ├─├─op5─┤─op7
1821+
╰-op3─╯ ╰-op6─╯
1822+
1823+
Then you can accomplish like so::
1824+
1825+
chain_linear(op1, [op2, op3], [op4, op5, op6], op7)
1826+
1827+
:param elements: a list of operators / lists of operators
1828+
"""
1829+
if not elements:
1830+
raise ValueError("No tasks provided; nothing to do.")
1831+
prev_elem = None
1832+
deps_set = False
1833+
for curr_elem in elements:
1834+
if isinstance(curr_elem, EdgeModifier):
1835+
raise ValueError("Labels are not supported by chain_linear")
1836+
if prev_elem is not None:
1837+
for task in prev_elem:
1838+
task >> curr_elem
1839+
if not deps_set:
1840+
deps_set = True
1841+
prev_elem = [curr_elem] if isinstance(curr_elem, DependencyMixin) else curr_elem
1842+
if not deps_set:
1843+
raise ValueError("No dependencies were set. Did you forget to expand with `*`?")

‎task-sdk/src/airflow/sdk/definitions/xcom_arg.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@
3535

3636
if TYPE_CHECKING:
3737
from airflow.sdk.definitions.baseoperator import BaseOperator
38+
from airflow.sdk.definitions.edges import EdgeModifier
3839
from airflow.sdk.types import Operator
39-
from airflow.utils.edgemodifier import EdgeModifier
4040

4141
# Callable objects contained by MapXComArg. We only accept callables from
4242
# the user, but deserialize them into strings in a serialized XComArg for

‎task-sdk/tests/task_sdk/definitions/test_baseoperator.py

+198-1
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,18 @@
2828
import pytest
2929
import structlog
3030

31-
from airflow.sdk.definitions.baseoperator import BaseOperator, BaseOperatorMeta, ExecutorSafeguard
31+
from airflow.decorators import task as task_decorator
32+
from airflow.sdk.definitions.baseoperator import (
33+
BaseOperator,
34+
BaseOperatorMeta,
35+
ExecutorSafeguard,
36+
chain,
37+
chain_linear,
38+
cross_downstream,
39+
)
3240
from airflow.sdk.definitions.dag import DAG
41+
from airflow.sdk.definitions.edges import Label
42+
from airflow.sdk.definitions.taskgroup import TaskGroup
3343
from airflow.sdk.definitions.template import literal
3444
from airflow.task.priority_strategy import _DownstreamPriorityWeightStrategy, _UpstreamPriorityWeightStrategy
3545

@@ -264,6 +274,193 @@ def test_upstream_is_set_when_template_field_is_xcomarg(self):
264274
assert op1.task_id in op2.upstream_task_ids
265275
assert op2.task_id in op1.downstream_task_ids
266276

277+
def test_cross_downstream(self):
278+
"""Test if all dependencies between tasks are all set correctly."""
279+
dag = DAG(dag_id="test_dag", schedule=None, start_date=datetime.now())
280+
start_tasks = [BaseOperator(task_id=f"t{i}", dag=dag) for i in range(1, 4)]
281+
end_tasks = [BaseOperator(task_id=f"t{i}", dag=dag) for i in range(4, 7)]
282+
cross_downstream(from_tasks=start_tasks, to_tasks=end_tasks)
283+
284+
for start_task in start_tasks:
285+
assert set(start_task.get_direct_relatives(upstream=False)) == set(end_tasks)
286+
287+
# Begin test for `XComArgs`
288+
xstart_tasks = [
289+
task_decorator(task_id=f"xcomarg_task{i}", python_callable=lambda: None, dag=dag)()
290+
for i in range(1, 4)
291+
]
292+
xend_tasks = [
293+
task_decorator(task_id=f"xcomarg_task{i}", python_callable=lambda: None, dag=dag)()
294+
for i in range(4, 7)
295+
]
296+
cross_downstream(from_tasks=xstart_tasks, to_tasks=xend_tasks)
297+
298+
for xstart_task in xstart_tasks:
299+
assert set(xstart_task.operator.get_direct_relatives(upstream=False)) == {
300+
xend_task.operator for xend_task in xend_tasks
301+
}
302+
303+
def test_chain(self):
304+
dag = DAG(dag_id="test_chain", schedule=None, start_date=datetime.now())
305+
306+
# Begin test for classic operators with `EdgeModifiers`
307+
[label1, label2] = [Label(label=f"label{i}") for i in range(1, 3)]
308+
[op1, op2, op3, op4, op5, op6] = [BaseOperator(task_id=f"t{i}", dag=dag) for i in range(1, 7)]
309+
chain(op1, [label1, label2], [op2, op3], [op4, op5], op6)
310+
311+
assert {op2, op3} == set(op1.get_direct_relatives(upstream=False))
312+
assert [op4] == op2.get_direct_relatives(upstream=False)
313+
assert [op5] == op3.get_direct_relatives(upstream=False)
314+
assert {op4, op5} == set(op6.get_direct_relatives(upstream=True))
315+
316+
assert dag.get_edge_info(upstream_task_id=op1.task_id, downstream_task_id=op2.task_id) == {
317+
"label": "label1"
318+
}
319+
assert dag.get_edge_info(upstream_task_id=op1.task_id, downstream_task_id=op3.task_id) == {
320+
"label": "label2"
321+
}
322+
323+
# Begin test for `XComArgs` with `EdgeModifiers`
324+
[xlabel1, xlabel2] = [Label(label=f"xcomarg_label{i}") for i in range(1, 3)]
325+
[xop1, xop2, xop3, xop4, xop5, xop6] = [
326+
task_decorator(task_id=f"xcomarg_task{i}", python_callable=lambda: None, dag=dag)()
327+
for i in range(1, 7)
328+
]
329+
chain(xop1, [xlabel1, xlabel2], [xop2, xop3], [xop4, xop5], xop6)
330+
331+
assert {xop2.operator, xop3.operator} == set(xop1.operator.get_direct_relatives(upstream=False))
332+
assert [xop4.operator] == xop2.operator.get_direct_relatives(upstream=False)
333+
assert [xop5.operator] == xop3.operator.get_direct_relatives(upstream=False)
334+
assert {xop4.operator, xop5.operator} == set(xop6.operator.get_direct_relatives(upstream=True))
335+
336+
assert dag.get_edge_info(
337+
upstream_task_id=xop1.operator.task_id, downstream_task_id=xop2.operator.task_id
338+
) == {"label": "xcomarg_label1"}
339+
assert dag.get_edge_info(
340+
upstream_task_id=xop1.operator.task_id, downstream_task_id=xop3.operator.task_id
341+
) == {"label": "xcomarg_label2"}
342+
343+
# Begin test for `TaskGroups`
344+
[tg1, tg2] = [TaskGroup(group_id=f"tg{i}", dag=dag) for i in range(1, 3)]
345+
[op1, op2] = [BaseOperator(task_id=f"task{i}", dag=dag) for i in range(1, 3)]
346+
[tgop1, tgop2] = [
347+
BaseOperator(task_id=f"task_group_task{i}", task_group=tg1, dag=dag) for i in range(1, 3)
348+
]
349+
[tgop3, tgop4] = [
350+
BaseOperator(task_id=f"task_group_task{i}", task_group=tg2, dag=dag) for i in range(1, 3)
351+
]
352+
chain(op1, tg1, tg2, op2)
353+
354+
assert {tgop1, tgop2} == set(op1.get_direct_relatives(upstream=False))
355+
assert {tgop3, tgop4} == set(tgop1.get_direct_relatives(upstream=False))
356+
assert {tgop3, tgop4} == set(tgop2.get_direct_relatives(upstream=False))
357+
assert [op2] == tgop3.get_direct_relatives(upstream=False)
358+
assert [op2] == tgop4.get_direct_relatives(upstream=False)
359+
360+
def test_chain_linear(self):
361+
dag = DAG(dag_id="test_chain_linear", schedule=None, start_date=datetime.now())
362+
363+
t1, t2, t3, t4, t5, t6, t7 = (BaseOperator(task_id=f"t{i}", dag=dag) for i in range(1, 8))
364+
chain_linear(t1, [t2, t3, t4], [t5, t6], t7)
365+
366+
assert set(t1.get_direct_relatives(upstream=False)) == {t2, t3, t4}
367+
assert set(t2.get_direct_relatives(upstream=False)) == {t5, t6}
368+
assert set(t3.get_direct_relatives(upstream=False)) == {t5, t6}
369+
assert set(t7.get_direct_relatives(upstream=True)) == {t5, t6}
370+
371+
t1, t2, t3, t4, t5, t6 = (
372+
task_decorator(task_id=f"xcomarg_task{i}", python_callable=lambda: None, dag=dag)()
373+
for i in range(1, 7)
374+
)
375+
chain_linear(t1, [t2, t3], [t4, t5], t6)
376+
377+
assert set(t1.operator.get_direct_relatives(upstream=False)) == {t2.operator, t3.operator}
378+
assert set(t2.operator.get_direct_relatives(upstream=False)) == {t4.operator, t5.operator}
379+
assert set(t3.operator.get_direct_relatives(upstream=False)) == {t4.operator, t5.operator}
380+
assert set(t6.operator.get_direct_relatives(upstream=True)) == {t4.operator, t5.operator}
381+
382+
# Begin test for `TaskGroups`
383+
tg1, tg2 = (TaskGroup(group_id=f"tg{i}", dag=dag) for i in range(1, 3))
384+
op1, op2 = (BaseOperator(task_id=f"task{i}", dag=dag) for i in range(1, 3))
385+
tgop1, tgop2 = (
386+
BaseOperator(task_id=f"task_group_task{i}", task_group=tg1, dag=dag) for i in range(1, 3)
387+
)
388+
tgop3, tgop4 = (
389+
BaseOperator(task_id=f"task_group_task{i}", task_group=tg2, dag=dag) for i in range(1, 3)
390+
)
391+
chain_linear(op1, tg1, tg2, op2)
392+
393+
assert set(op1.get_direct_relatives(upstream=False)) == {tgop1, tgop2}
394+
assert set(tgop1.get_direct_relatives(upstream=False)) == {tgop3, tgop4}
395+
assert set(tgop2.get_direct_relatives(upstream=False)) == {tgop3, tgop4}
396+
assert set(tgop3.get_direct_relatives(upstream=False)) == {op2}
397+
assert set(tgop4.get_direct_relatives(upstream=False)) == {op2}
398+
399+
t1, t2 = (BaseOperator(task_id=f"t-{i}", dag=dag) for i in range(1, 3))
400+
with pytest.raises(ValueError, match="Labels are not supported"):
401+
chain_linear(t1, Label("hi"), t2)
402+
403+
with pytest.raises(ValueError, match="nothing to do"):
404+
chain_linear()
405+
406+
with pytest.raises(ValueError, match="Did you forget to expand"):
407+
chain_linear(t1)
408+
409+
def test_chain_not_support_type(self):
410+
dag = DAG(dag_id="test_chain", schedule=None, start_date=datetime.now())
411+
[op1, op2] = [BaseOperator(task_id=f"t{i}", dag=dag) for i in range(1, 3)]
412+
with pytest.raises(TypeError):
413+
chain([op1, op2], 1)
414+
415+
# Begin test for `XComArgs`
416+
[xop1, xop2] = [
417+
task_decorator(task_id=f"xcomarg_task{i}", python_callable=lambda: None, dag=dag)()
418+
for i in range(1, 3)
419+
]
420+
421+
with pytest.raises(TypeError):
422+
chain([xop1, xop2], 1)
423+
424+
# Begin test for `EdgeModifiers`
425+
with pytest.raises(TypeError):
426+
chain([Label("labe1"), Label("label2")], 1)
427+
428+
# Begin test for `TaskGroups`
429+
[tg1, tg2] = [TaskGroup(group_id=f"tg{i}", dag=dag) for i in range(1, 3)]
430+
431+
with pytest.raises(TypeError):
432+
chain([tg1, tg2], 1)
433+
434+
def test_chain_different_length_iterable(self):
435+
dag = DAG(dag_id="test_chain", schedule=None, start_date=datetime.now())
436+
[label1, label2] = [Label(label=f"label{i}") for i in range(1, 3)]
437+
[op1, op2, op3, op4, op5] = [BaseOperator(task_id=f"t{i}", dag=dag) for i in range(1, 6)]
438+
439+
with pytest.raises(ValueError):
440+
chain([op1, op2], [op3, op4, op5])
441+
442+
with pytest.raises(ValueError):
443+
chain([op1, op2, op3], [label1, label2])
444+
445+
# Begin test for `XComArgs` with `EdgeModifiers`
446+
[label3, label4] = [Label(label=f"xcomarg_label{i}") for i in range(1, 3)]
447+
[xop1, xop2, xop3, xop4, xop5] = [
448+
task_decorator(task_id=f"xcomarg_task{i}", python_callable=lambda: None, dag=dag)()
449+
for i in range(1, 6)
450+
]
451+
452+
with pytest.raises(ValueError):
453+
chain([xop1, xop2], [xop3, xop4, xop5])
454+
455+
with pytest.raises(ValueError):
456+
chain([xop1, xop2, xop3], [label1, label2])
457+
458+
# Begin test for `TaskGroups`
459+
[tg1, tg2, tg3, tg4, tg5] = [TaskGroup(group_id=f"tg{i}", dag=dag) for i in range(1, 6)]
460+
461+
with pytest.raises(ValueError):
462+
chain([tg1, tg2], [tg3, tg4, tg5])
463+
267464
def test_set_xcomargs_dependencies_works_recursively(self):
268465
with DAG("xcomargs_test", schedule=None):
269466
op1 = BaseOperator(task_id="op1")

‎tests/models/test_baseoperator.py

+1-192
Original file line numberDiff line numberDiff line change
@@ -25,20 +25,16 @@
2525
import pytest
2626

2727
from airflow.decorators import task as task_decorator
28-
from airflow.exceptions import AirflowException, TaskDeferralTimeout
28+
from airflow.exceptions import TaskDeferralTimeout
2929
from airflow.models.baseoperator import (
3030
BaseOperator,
31-
chain,
32-
chain_linear,
33-
cross_downstream,
3431
)
3532
from airflow.models.dag import DAG
3633
from airflow.models.dagrun import DagRun
3734
from airflow.models.taskinstance import TaskInstance
3835
from airflow.models.trigger import TriggerFailureReason
3936
from airflow.providers.common.compat.lineage.entities import File
4037
from airflow.providers.common.sql.operators import sql
41-
from airflow.utils.edgemodifier import Label
4238
from airflow.utils.task_group import TaskGroup
4339
from airflow.utils.trigger_rule import TriggerRule
4440
from airflow.utils.types import DagRunType
@@ -91,89 +87,6 @@ def test_trigger_rule_validation(self):
9187
task_id="test_valid_trigger_rule", dag=non_fail_fast_dag, trigger_rule=TriggerRule.ALWAYS
9288
)
9389

94-
def test_cross_downstream(self):
95-
"""Test if all dependencies between tasks are all set correctly."""
96-
dag = DAG(dag_id="test_dag", schedule=None, start_date=datetime.now())
97-
start_tasks = [BaseOperator(task_id=f"t{i}", dag=dag) for i in range(1, 4)]
98-
end_tasks = [BaseOperator(task_id=f"t{i}", dag=dag) for i in range(4, 7)]
99-
cross_downstream(from_tasks=start_tasks, to_tasks=end_tasks)
100-
101-
for start_task in start_tasks:
102-
assert set(start_task.get_direct_relatives(upstream=False)) == set(end_tasks)
103-
104-
# Begin test for `XComArgs`
105-
xstart_tasks = [
106-
task_decorator(task_id=f"xcomarg_task{i}", python_callable=lambda: None, dag=dag)()
107-
for i in range(1, 4)
108-
]
109-
xend_tasks = [
110-
task_decorator(task_id=f"xcomarg_task{i}", python_callable=lambda: None, dag=dag)()
111-
for i in range(4, 7)
112-
]
113-
cross_downstream(from_tasks=xstart_tasks, to_tasks=xend_tasks)
114-
115-
for xstart_task in xstart_tasks:
116-
assert set(xstart_task.operator.get_direct_relatives(upstream=False)) == {
117-
xend_task.operator for xend_task in xend_tasks
118-
}
119-
120-
def test_chain(self):
121-
dag = DAG(dag_id="test_chain", schedule=None, start_date=datetime.now())
122-
123-
# Begin test for classic operators with `EdgeModifiers`
124-
[label1, label2] = [Label(label=f"label{i}") for i in range(1, 3)]
125-
[op1, op2, op3, op4, op5, op6] = [BaseOperator(task_id=f"t{i}", dag=dag) for i in range(1, 7)]
126-
chain(op1, [label1, label2], [op2, op3], [op4, op5], op6)
127-
128-
assert {op2, op3} == set(op1.get_direct_relatives(upstream=False))
129-
assert [op4] == op2.get_direct_relatives(upstream=False)
130-
assert [op5] == op3.get_direct_relatives(upstream=False)
131-
assert {op4, op5} == set(op6.get_direct_relatives(upstream=True))
132-
133-
assert dag.get_edge_info(upstream_task_id=op1.task_id, downstream_task_id=op2.task_id) == {
134-
"label": "label1"
135-
}
136-
assert dag.get_edge_info(upstream_task_id=op1.task_id, downstream_task_id=op3.task_id) == {
137-
"label": "label2"
138-
}
139-
140-
# Begin test for `XComArgs` with `EdgeModifiers`
141-
[xlabel1, xlabel2] = [Label(label=f"xcomarg_label{i}") for i in range(1, 3)]
142-
[xop1, xop2, xop3, xop4, xop5, xop6] = [
143-
task_decorator(task_id=f"xcomarg_task{i}", python_callable=lambda: None, dag=dag)()
144-
for i in range(1, 7)
145-
]
146-
chain(xop1, [xlabel1, xlabel2], [xop2, xop3], [xop4, xop5], xop6)
147-
148-
assert {xop2.operator, xop3.operator} == set(xop1.operator.get_direct_relatives(upstream=False))
149-
assert [xop4.operator] == xop2.operator.get_direct_relatives(upstream=False)
150-
assert [xop5.operator] == xop3.operator.get_direct_relatives(upstream=False)
151-
assert {xop4.operator, xop5.operator} == set(xop6.operator.get_direct_relatives(upstream=True))
152-
153-
assert dag.get_edge_info(
154-
upstream_task_id=xop1.operator.task_id, downstream_task_id=xop2.operator.task_id
155-
) == {"label": "xcomarg_label1"}
156-
assert dag.get_edge_info(
157-
upstream_task_id=xop1.operator.task_id, downstream_task_id=xop3.operator.task_id
158-
) == {"label": "xcomarg_label2"}
159-
160-
# Begin test for `TaskGroups`
161-
[tg1, tg2] = [TaskGroup(group_id=f"tg{i}", dag=dag) for i in range(1, 3)]
162-
[op1, op2] = [BaseOperator(task_id=f"task{i}", dag=dag) for i in range(1, 3)]
163-
[tgop1, tgop2] = [
164-
BaseOperator(task_id=f"task_group_task{i}", task_group=tg1, dag=dag) for i in range(1, 3)
165-
]
166-
[tgop3, tgop4] = [
167-
BaseOperator(task_id=f"task_group_task{i}", task_group=tg2, dag=dag) for i in range(1, 3)
168-
]
169-
chain(op1, tg1, tg2, op2)
170-
171-
assert {tgop1, tgop2} == set(op1.get_direct_relatives(upstream=False))
172-
assert {tgop3, tgop4} == set(tgop1.get_direct_relatives(upstream=False))
173-
assert {tgop3, tgop4} == set(tgop2.get_direct_relatives(upstream=False))
174-
assert [op2] == tgop3.get_direct_relatives(upstream=False)
175-
assert [op2] == tgop4.get_direct_relatives(upstream=False)
176-
17790
def test_baseoperator_raises_exception_when_task_id_plus_taskgroup_id_exceeds_250_chars(self):
17891
"""Test exception is raised when operator task id + taskgroup id > 250 chars."""
17992
dag = DAG(dag_id="foo", schedule=None, start_date=datetime.now())
@@ -201,110 +114,6 @@ def test_baseoperator_with_task_id_less_than_250_chars(self):
201114
except Exception as e:
202115
pytest.fail(f"Exception raised: {e}")
203116

204-
def test_chain_linear(self):
205-
dag = DAG(dag_id="test_chain_linear", schedule=None, start_date=datetime.now())
206-
207-
t1, t2, t3, t4, t5, t6, t7 = (BaseOperator(task_id=f"t{i}", dag=dag) for i in range(1, 8))
208-
chain_linear(t1, [t2, t3, t4], [t5, t6], t7)
209-
210-
assert set(t1.get_direct_relatives(upstream=False)) == {t2, t3, t4}
211-
assert set(t2.get_direct_relatives(upstream=False)) == {t5, t6}
212-
assert set(t3.get_direct_relatives(upstream=False)) == {t5, t6}
213-
assert set(t7.get_direct_relatives(upstream=True)) == {t5, t6}
214-
215-
t1, t2, t3, t4, t5, t6 = (
216-
task_decorator(task_id=f"xcomarg_task{i}", python_callable=lambda: None, dag=dag)()
217-
for i in range(1, 7)
218-
)
219-
chain_linear(t1, [t2, t3], [t4, t5], t6)
220-
221-
assert set(t1.operator.get_direct_relatives(upstream=False)) == {t2.operator, t3.operator}
222-
assert set(t2.operator.get_direct_relatives(upstream=False)) == {t4.operator, t5.operator}
223-
assert set(t3.operator.get_direct_relatives(upstream=False)) == {t4.operator, t5.operator}
224-
assert set(t6.operator.get_direct_relatives(upstream=True)) == {t4.operator, t5.operator}
225-
226-
# Begin test for `TaskGroups`
227-
tg1, tg2 = (TaskGroup(group_id=f"tg{i}", dag=dag) for i in range(1, 3))
228-
op1, op2 = (BaseOperator(task_id=f"task{i}", dag=dag) for i in range(1, 3))
229-
tgop1, tgop2 = (
230-
BaseOperator(task_id=f"task_group_task{i}", task_group=tg1, dag=dag) for i in range(1, 3)
231-
)
232-
tgop3, tgop4 = (
233-
BaseOperator(task_id=f"task_group_task{i}", task_group=tg2, dag=dag) for i in range(1, 3)
234-
)
235-
chain_linear(op1, tg1, tg2, op2)
236-
237-
assert set(op1.get_direct_relatives(upstream=False)) == {tgop1, tgop2}
238-
assert set(tgop1.get_direct_relatives(upstream=False)) == {tgop3, tgop4}
239-
assert set(tgop2.get_direct_relatives(upstream=False)) == {tgop3, tgop4}
240-
assert set(tgop3.get_direct_relatives(upstream=False)) == {op2}
241-
assert set(tgop4.get_direct_relatives(upstream=False)) == {op2}
242-
243-
t1, t2 = (BaseOperator(task_id=f"t-{i}", dag=dag) for i in range(1, 3))
244-
with pytest.raises(ValueError, match="Labels are not supported"):
245-
chain_linear(t1, Label("hi"), t2)
246-
247-
with pytest.raises(ValueError, match="nothing to do"):
248-
chain_linear()
249-
250-
with pytest.raises(ValueError, match="Did you forget to expand"):
251-
chain_linear(t1)
252-
253-
def test_chain_not_support_type(self):
254-
dag = DAG(dag_id="test_chain", schedule=None, start_date=datetime.now())
255-
[op1, op2] = [BaseOperator(task_id=f"t{i}", dag=dag) for i in range(1, 3)]
256-
with pytest.raises(TypeError):
257-
chain([op1, op2], 1)
258-
259-
# Begin test for `XComArgs`
260-
[xop1, xop2] = [
261-
task_decorator(task_id=f"xcomarg_task{i}", python_callable=lambda: None, dag=dag)()
262-
for i in range(1, 3)
263-
]
264-
265-
with pytest.raises(TypeError):
266-
chain([xop1, xop2], 1)
267-
268-
# Begin test for `EdgeModifiers`
269-
with pytest.raises(TypeError):
270-
chain([Label("labe1"), Label("label2")], 1)
271-
272-
# Begin test for `TaskGroups`
273-
[tg1, tg2] = [TaskGroup(group_id=f"tg{i}", dag=dag) for i in range(1, 3)]
274-
275-
with pytest.raises(TypeError):
276-
chain([tg1, tg2], 1)
277-
278-
def test_chain_different_length_iterable(self):
279-
dag = DAG(dag_id="test_chain", schedule=None, start_date=datetime.now())
280-
[label1, label2] = [Label(label=f"label{i}") for i in range(1, 3)]
281-
[op1, op2, op3, op4, op5] = [BaseOperator(task_id=f"t{i}", dag=dag) for i in range(1, 6)]
282-
283-
with pytest.raises(AirflowException):
284-
chain([op1, op2], [op3, op4, op5])
285-
286-
with pytest.raises(AirflowException):
287-
chain([op1, op2, op3], [label1, label2])
288-
289-
# Begin test for `XComArgs` with `EdgeModifiers`
290-
[label3, label4] = [Label(label=f"xcomarg_label{i}") for i in range(1, 3)]
291-
[xop1, xop2, xop3, xop4, xop5] = [
292-
task_decorator(task_id=f"xcomarg_task{i}", python_callable=lambda: None, dag=dag)()
293-
for i in range(1, 6)
294-
]
295-
296-
with pytest.raises(AirflowException):
297-
chain([xop1, xop2], [xop3, xop4, xop5])
298-
299-
with pytest.raises(AirflowException):
300-
chain([xop1, xop2, xop3], [label1, label2])
301-
302-
# Begin test for `TaskGroups`
303-
[tg1, tg2, tg3, tg4, tg5] = [TaskGroup(group_id=f"tg{i}", dag=dag) for i in range(1, 6)]
304-
305-
with pytest.raises(AirflowException):
306-
chain([tg1, tg2], [tg3, tg4, tg5])
307-
308117
def test_lineage_composition(self):
309118
"""
310119
Test composition with lineage

‎tests/system/example_empty.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@
1818

1919
from datetime import datetime
2020

21-
from airflow.models.baseoperator import chain
2221
from airflow.models.dag import DAG
2322
from airflow.providers.standard.operators.empty import EmptyOperator
23+
from airflow.sdk import chain
2424

2525
DAG_ID = "example_empty"
2626

‎tests/utils/test_task_group.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1458,7 +1458,7 @@ def test_add_to_another_group():
14581458

14591459

14601460
def test_task_group_edge_modifier_chain():
1461-
from airflow.models.baseoperator import chain
1461+
from airflow.sdk import chain
14621462
from airflow.utils.edgemodifier import Label
14631463

14641464
with DAG(dag_id="test", schedule=None, start_date=pendulum.DateTime(2022, 5, 20)) as dag:

0 commit comments

Comments
 (0)
Please sign in to comment.