Skip to content

Commit

Permalink
Make the run execution concurrency context aware of pool granularity
Browse files Browse the repository at this point in the history
  • Loading branch information
prha committed Feb 1, 2025
1 parent 4090bc9 commit 3650bb9
Show file tree
Hide file tree
Showing 5 changed files with 207 additions and 69 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ def get_steps_to_execute(
step_priority = 0

if not self._instance_concurrency_context.claim(
concurrency_key, step.key, step_priority
concurrency_key, step.key, step_priority, is_legacy_tag=not step.pool
):
continue

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing_extensions import Self

from dagster._core.instance import DagsterInstance
from dagster._core.instance.config import PoolGranularity
from dagster._core.storage.dagster_run import DagsterRun
from dagster._core.storage.tags import PRIORITY_TAG

Expand Down Expand Up @@ -40,7 +41,7 @@ def __init__(self, instance: DagsterInstance, dagster_run: DagsterRun):
self._pending_timeouts = defaultdict(float)
self._pending_claim_counts = defaultdict(int)
self._pending_claims = set()
self._default_limit = instance.global_op_concurrency_default_limit
self._pool_config = instance.event_log_storage.get_pool_config()
self._claims = set()
try:
self._run_priority = int(dagster_run.tags.get(PRIORITY_TAG, "0"))
Expand Down Expand Up @@ -86,15 +87,26 @@ def _sync_pools(self) -> None:
pool_limits = self._instance.event_log_storage.get_pool_limits()
self._pools = {pool.name: pool for pool in pool_limits}

def claim(self, concurrency_key: str, step_key: str, step_priority: int = 0):
def claim(
self,
concurrency_key: str,
step_key: str,
step_priority: int = 0,
is_legacy_tag: bool = False,
) -> bool:
if not self._instance.event_log_storage.supports_global_concurrency_limits:
return True

if self._pool_config.pool_granularity == PoolGranularity.RUN or (
self._pool_config.pool_granularity is None and not is_legacy_tag
):
# short-circuit claiming global op concurrency slot claiming
return True

default_limit = self._pool_config.default_pool_limit
pool_info = self.get_pool_info(concurrency_key)
if (pool_info is None and self._default_limit is not None) or (
pool_info is not None
and pool_info.from_default
and pool_info.limit != self._default_limit
if (pool_info is None and default_limit is not None) or (
pool_info is not None and pool_info.from_default and pool_info.limit != default_limit
):
self._instance.event_log_storage.initialize_concurrency_limit_to_default(
concurrency_key
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


@pytest.fixture()
def concurrency_instance():
def concurrency_instance_default_granularity():
with tempfile.TemporaryDirectory() as temp_dir:
with instance_for_test(
overrides={
Expand All @@ -21,6 +21,42 @@ def concurrency_instance():
yield instance


@pytest.fixture()
def concurrency_instance_run_granularity():
    """Yield a test instance whose concurrency pools use run-level granularity.

    Uses a concurrency-enabled SQLite event log storage backed by a
    throwaway temp directory that is cleaned up when the fixture exits.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        overrides = {
            "event_log_storage": {
                "module": "dagster.utils.test",
                "class": "ConcurrencyEnabledSqliteTestEventLogStorage",
                "config": {"base_dir": temp_dir},
            },
            "concurrency": {
                "pools": {"granularity": "run"},
            },
        }
        with instance_for_test(overrides=overrides) as instance:
            yield instance


@pytest.fixture()
def concurrency_instance_op_granularity():
    """Yield a test instance whose concurrency pools use op-level granularity.

    Uses a concurrency-enabled SQLite event log storage backed by a
    throwaway temp directory that is cleaned up when the fixture exits.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        overrides = {
            "event_log_storage": {
                "module": "dagster.utils.test",
                "class": "ConcurrencyEnabledSqliteTestEventLogStorage",
                "config": {"base_dir": temp_dir},
            },
            "concurrency": {
                "pools": {"granularity": "op"},
            },
        }
        with instance_for_test(overrides=overrides) as instance:
            yield instance


@pytest.fixture()
def concurrency_instance_with_default_one():
with tempfile.TemporaryDirectory() as temp_dir:
Expand All @@ -31,7 +67,9 @@ def concurrency_instance_with_default_one():
"class": "ConcurrencyEnabledSqliteTestEventLogStorage",
"config": {"base_dir": temp_dir},
},
"concurrency": {"default_op_concurrency_limit": 1},
"concurrency": {
"pools": {"granularity": "op", "default_limit": 1},
},
}
) as instance:
yield instance
Expand All @@ -47,6 +85,9 @@ def concurrency_custom_sleep_instance():
"class": "ConcurrencyEnabledSqliteTestEventLogStorage",
"config": {"base_dir": temp_dir, "sleep_interval": CUSTOM_SLEEP_INTERVAL},
},
"concurrency": {
"pools": {"granularity": "op"},
},
}
) as instance:
yield instance
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,9 @@ def test_active_concurrency(use_tags):
"class": "ConcurrencyEnabledSqliteTestEventLogStorage",
"config": {"base_dir": temp_dir},
},
"concurrency": {
"pools": {"granularity": "op"},
},
}
) as instance:
assert instance.event_log_storage.supports_global_concurrency_limits
Expand Down Expand Up @@ -248,7 +251,9 @@ def __init__(self, interval: float):
def global_concurrency_keys(self) -> set[str]:
    """Return the fixed set of concurrency keys exposed by this test stub."""
    keys: set[str] = {"foo"}
    return keys

def claim(self, concurrency_key: str, step_key: str, priority: int = 0):
def claim(
    self, concurrency_key: str, step_key: str, priority: int = 0, is_legacy_tag: bool = False
):
    """Record *step_key* as a pending claim and report that no slot was granted.

    This stub never grants a claim; it only tracks which steps attempted one
    so tests can inspect ``self._pending_claims``.
    """
    pending = self._pending_claims
    pending.add(step_key)
    return False

Expand Down
Loading

0 comments on commit 3650bb9

Please sign in to comment.