Skip to content
6 changes: 6 additions & 0 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5714,6 +5714,12 @@ async def remove_worker(
f"Removing worker {ws.address!r} caused the cluster to lose scattered "
f"data, which can't be recovered: {lost_keys} ({stimulus_id=})"
)
if not expected and processing_keys:
logger.warning(
f"Worker {ws.address!r} dropped unexpectedly. "
f"Interrupting {len(processing_keys)} processing tasks: "
f"{processing_keys} ({stimulus_id=})"
)

event_msg = {
"action": "remove-worker",
Expand Down
26 changes: 24 additions & 2 deletions distributed/stealing.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def __init__(self, scheduler: Scheduler):
self.metrics = {
"request_count_total": defaultdict(int),
"request_cost_total": defaultdict(int),
"reject_count_margin_total": defaultdict(int),
}
self._request_counter = 0
self.scheduler.stream_handlers["steal-response"] = self.move_task_confirm
Expand Down Expand Up @@ -486,10 +487,20 @@ def balance(self) -> None:
comm_cost_thief = self.scheduler.get_comm_cost(ts, thief)
comm_cost_victim = self.scheduler.get_comm_cost(ts, victim)
compute = self.scheduler._get_prefix_duration(ts.prefix)
if (

# Require at least 50% ROI on the network transfer cost to prevent thrashing
margin = comm_cost_thief * 0.5

would_steal_without_margin = (
occ_thief + comm_cost_thief + compute
<= occ_victim - (comm_cost_victim + compute) / 2
):
)
would_steal_with_margin = (
occ_thief + comm_cost_thief + compute + margin
<= occ_victim - (comm_cost_victim + compute) / 2
)

if would_steal_with_margin:
self.move_task_request(ts, victim, thief)
cost = compute + comm_cost_victim
log.append(
Expand Down Expand Up @@ -520,6 +531,17 @@ def balance(self) -> None:
# for removing ts from stealable. If we made sure to
# properly clean up, we would not need this
stealable.discard(ts)
elif would_steal_without_margin:
self.metrics["reject_count_margin_total"][level] += 1
logger.debug(
"Work-stealing margin heuristic rejected steal of task %s "
"(thief=%s, victim=%s, level=%d, margin=%.4f)",
ts.key,
thief.address,
victim.address,
level,
margin,
)
self.scheduler.check_idle_saturated(
victim, occ=combined_occupancy(victim)
)
Expand Down
57 changes: 55 additions & 2 deletions distributed/tests/test_steal.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from collections.abc import Callable, Coroutine, Iterable, Mapping, Sequence
from operator import mul
from time import sleep
from unittest.mock import patch

import pytest
from tlz import merge, sliding_window
Expand Down Expand Up @@ -1448,8 +1449,8 @@ def func(*args):
"cost, ntasks, expect_steal",
[
pytest.param(10, 10, False, id="not enough work to steal"),
pytest.param(10, 12, True, id="enough work to steal"),
pytest.param(20, 12, False, id="not enough work for increased cost"),
pytest.param(10, 17, True, id="enough work to steal"),
pytest.param(20, 17, False, id="not enough work for increased cost"),
],
)
def test_balance_expensive_tasks(cost, ntasks, expect_steal):
Expand Down Expand Up @@ -2010,6 +2011,58 @@ def block(i: int, in_event: Event, block_event: Event) -> int:
await block_event.set()


@gen_cluster(
    client=True,
    # Two single-threaded workers: one to saturate (victim) and one idle (thief).
    nthreads=[("127.0.0.1", 1)] * 2,
    config={
        # Fast stealing interval so the extension is configured, though the
        # periodic callback is stopped below and balance() is called manually.
        "distributed.scheduler.work-stealing-interval": "100ms",
        # Pre-seed the task duration so the scheduler has a known compute
        # estimate for `slowidentity` before any task has actually run.
        "distributed.scheduler.default-task-durations": {"slowidentity": 0.021},
        **NO_AMM,
    },
)
async def test_reject_count_margin_metric(c, s, a, b):
    """
    Verify that the margin heuristic increments reject_count_margin_total
    when a steal is suppressed that old logic would have permitted.
    """
    steal = s.extensions["stealing"]
    # Stop the periodic balancing callback so no steal happens before we have
    # patched get_comm_cost; we trigger balance() exactly once ourselves.
    await steal.stop()

    # Pin all 21 tasks to worker A (allow_other_workers=True keeps them
    # stealable), making A saturated and B idle.
    futures = c.map(
        slowidentity,
        range(21),
        workers=a.address,
        allow_other_workers=True,
        delay=0.021,
    )

    # Wait until the scheduler knows about every task...
    while len(s.tasks) < 21:
        await asyncio.sleep(0.01)

    # ...and until all of them have actually arrived on worker A.
    while len(a.state.tasks) < 21:
        await asyncio.sleep(0.01)

    # Refresh the scheduler's idle/saturated classification so the
    # assertions below see up-to-date occupancy state.
    for ws in s.workers.values():
        s.check_idle_saturated(ws)

    a_ws = s.workers[a.address]
    b_ws = s.workers[b.address]
    # Preconditions for stealing: A must be saturated and B idle, otherwise
    # balance() would not consider any task movement at all.
    assert a_ws in s.saturated, (
        f"Worker A not saturated: occupancy={a_ws.occupancy:.3f}, "
        f"nthreads={a_ws.nthreads}, processing={len(a_ws.processing)}"
    )
    assert (
        b_ws in s.idle.values()
    ), f"Worker B not idle: processing={len(b_ws.processing)}"

    # Force a comm cost for the thief (B) of 0.3 and 0.0 for the victim.
    # With compute ~0.021 this is chosen so the old (margin-free) inequality
    # would permit the steal while the +50%-of-comm-cost margin rejects it,
    # exercising the `elif would_steal_without_margin` branch.
    with patch.object(
        s, "get_comm_cost", side_effect=lambda ts, ws: 0.3 if ws == b_ws else 0.0
    ):
        steal.balance()

    # At least one rejection must have been recorded by the margin heuristic.
    assert sum(steal.metrics["reject_count_margin_total"].values()) >= 1


@gen_cluster(
nthreads=[("", 1)],
client=True,
Expand Down
2 changes: 2 additions & 0 deletions distributed/tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3008,6 +3008,8 @@ async def test_log_remove_worker(c, s, a, b):
"(stimulus_id='ungraceful')",
f"Removing worker '{b.address}' caused the cluster to lose scattered "
"data, which can't be recovered: {'z'} (stimulus_id='ungraceful')",
f"Worker {b.address!r} dropped unexpectedly. Interrupting 1 "
"processing tasks: {'y'} (stimulus_id='ungraceful')",
"Lost all workers",
]

Expand Down
Loading