Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change step execution to be aware of pool granularity #27478

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ def get_steps_to_execute(
step_priority = 0

if not self._instance_concurrency_context.claim(
concurrency_key, step.key, step_priority
concurrency_key, step.key, step_priority, is_legacy_tag=not step.pool
):
continue

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing_extensions import Self

from dagster._core.instance import DagsterInstance
from dagster._core.instance.config import PoolGranularity
from dagster._core.storage.dagster_run import DagsterRun
from dagster._core.storage.tags import PRIORITY_TAG

Expand Down Expand Up @@ -40,7 +41,7 @@ def __init__(self, instance: DagsterInstance, dagster_run: DagsterRun):
self._pending_timeouts = defaultdict(float)
self._pending_claim_counts = defaultdict(int)
self._pending_claims = set()
self._default_limit = instance.global_op_concurrency_default_limit
self._pool_config = instance.event_log_storage.get_pool_config()
self._claims = set()
try:
self._run_priority = int(dagster_run.tags.get(PRIORITY_TAG, "0"))
Expand Down Expand Up @@ -86,15 +87,26 @@ def _sync_pools(self) -> None:
pool_limits = self._instance.event_log_storage.get_pool_limits()
self._pools = {pool.name: pool for pool in pool_limits}

def claim(self, concurrency_key: str, step_key: str, step_priority: int = 0):
def claim(
self,
concurrency_key: str,
step_key: str,
step_priority: int = 0,
is_legacy_tag: bool = False,
) -> bool:
if not self._instance.event_log_storage.supports_global_concurrency_limits:
return True

if self._pool_config.pool_granularity == PoolGranularity.RUN or (
self._pool_config.pool_granularity is None and not is_legacy_tag
):
# short-circuit claiming global op concurrency slot claiming
return True

default_limit = self._pool_config.default_pool_limit
pool_info = self.get_pool_info(concurrency_key)
if (pool_info is None and self._default_limit is not None) or (
pool_info is not None
and pool_info.from_default
and pool_info.limit != self._default_limit
if (pool_info is None and default_limit is not None) or (
pool_info is not None and pool_info.from_default and pool_info.limit != default_limit
):
self._instance.event_log_storage.initialize_concurrency_limit_to_default(
concurrency_key
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


@pytest.fixture()
def concurrency_instance():
def concurrency_instance_default_granularity():
with tempfile.TemporaryDirectory() as temp_dir:
with instance_for_test(
overrides={
Expand All @@ -21,6 +21,42 @@ def concurrency_instance():
yield instance


@pytest.fixture()
def concurrency_instance_run_granularity():
    """Yield a test DagsterInstance configured with run-level pool granularity.

    Uses a concurrency-enabled SQLite event log storage backed by a temporary
    directory that is cleaned up when the fixture is torn down.
    """
    with tempfile.TemporaryDirectory() as storage_dir, instance_for_test(
        overrides={
            "event_log_storage": {
                "module": "dagster.utils.test",
                "class": "ConcurrencyEnabledSqliteTestEventLogStorage",
                "config": {"base_dir": storage_dir},
            },
            "concurrency": {
                "pools": {"granularity": "run"},
            },
        }
    ) as instance:
        yield instance


@pytest.fixture()
def concurrency_instance_op_granularity():
    """Yield a test DagsterInstance configured with op-level pool granularity.

    Uses a concurrency-enabled SQLite event log storage backed by a temporary
    directory that is cleaned up when the fixture is torn down.
    """
    with tempfile.TemporaryDirectory() as storage_dir, instance_for_test(
        overrides={
            "event_log_storage": {
                "module": "dagster.utils.test",
                "class": "ConcurrencyEnabledSqliteTestEventLogStorage",
                "config": {"base_dir": storage_dir},
            },
            "concurrency": {
                "pools": {"granularity": "op"},
            },
        }
    ) as instance:
        yield instance


@pytest.fixture()
def concurrency_instance_with_default_one():
with tempfile.TemporaryDirectory() as temp_dir:
Expand All @@ -31,7 +67,9 @@ def concurrency_instance_with_default_one():
"class": "ConcurrencyEnabledSqliteTestEventLogStorage",
"config": {"base_dir": temp_dir},
},
"concurrency": {"default_op_concurrency_limit": 1},
"concurrency": {
"pools": {"granularity": "op", "default_limit": 1},
},
}
) as instance:
yield instance
Expand All @@ -47,6 +85,9 @@ def concurrency_custom_sleep_instance():
"class": "ConcurrencyEnabledSqliteTestEventLogStorage",
"config": {"base_dir": temp_dir, "sleep_interval": CUSTOM_SLEEP_INTERVAL},
},
"concurrency": {
"pools": {"granularity": "op"},
},
}
) as instance:
yield instance
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,9 @@ def test_active_concurrency(use_tags):
"class": "ConcurrencyEnabledSqliteTestEventLogStorage",
"config": {"base_dir": temp_dir},
},
"concurrency": {
"pools": {"granularity": "op"},
},
}
) as instance:
assert instance.event_log_storage.supports_global_concurrency_limits
Expand Down Expand Up @@ -248,7 +251,9 @@ def __init__(self, interval: float):
def global_concurrency_keys(self) -> set[str]:
return {"foo"}

def claim(self, concurrency_key: str, step_key: str, priority: int = 0):
def claim(
self, concurrency_key: str, step_key: str, priority: int = 0, is_legacy_tag: bool = False
):
self._pending_claims.add(step_key)
return False

Expand Down
Loading