Skip to content

Commit 4882fb5

Browse files
committed
refactor(BA-1004): Sweeper
1 parent ea420f8 commit 4882fb5

File tree

9 files changed

+222
-254
lines changed

9 files changed

+222
-254
lines changed

src/ai/backend/manager/cleanup/kernel.py

-124
This file was deleted.

src/ai/backend/manager/cleanup/session.py

-126
This file was deleted.

src/ai/backend/manager/models/session.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from collections.abc import Iterable, Mapping, Sequence
88
from contextlib import asynccontextmanager as actxmgr
99
from dataclasses import dataclass, field
10-
from datetime import datetime, timedelta
10+
from datetime import datetime
1111
from decimal import Decimal
1212
from typing import (
1313
TYPE_CHECKING,
@@ -1312,8 +1312,10 @@ async def get_network_ref(self, db_sess: SASession) -> str | None:
13121312
return None
13131313

13141314
@classmethod
1315-
def get_status_elapsed_time(cls, status: SessionStatus, now: datetime) -> timedelta:
1316-
return now - cls.status_history[status.name].astext.cast(sa.types.DateTime(timezone=True))
1315+
def get_status_elapsed_time(
    cls, status: SessionStatus, until: datetime
) -> sa.sql.elements.BinaryExpression:
    """Build the expression ``until - <time recorded for status>``.

    The recorded time is read from ``cls.status_history`` under the key
    ``status.name``, taken as text and cast to a timezone-aware DateTime.
    The result is an SQLAlchemy expression (see the return annotation), not
    a Python ``timedelta``.
    """
    # Timestamp stored in the status history for the given status.
    recorded_at = cls.status_history[status.name].astext.cast(sa.types.DateTime(timezone=True))
    return until - recorded_at
13171319

13181320

13191321
class SessionLifecycleManager:

src/ai/backend/manager/server.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -92,10 +92,10 @@
9292
CleanupContext,
9393
WebRequestHandler,
9494
)
95-
from .cleanup import stale_kernel_collection_ctx, stale_session_collection_ctx
9695
from .config import LocalConfig, SharedConfig, volume_config_iv
9796
from .config import load as load_config
9897
from .exceptions import InvalidArgument
98+
from .sweeper import stale_kernel_collection_ctx, stale_session_collection_ctx
9999
from .types import DistributedLockFactory
100100

101101
VALID_VERSIONS: Final = frozenset([
File renamed without changes.
+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
import abc
2+
3+
4+
class AbstractSweeper(abc.ABC):
    """Abstract base class for sweepers.

    Concrete subclasses must implement the :meth:`sweep` coroutine.
    """

    @abc.abstractmethod
    async def sweep(self, *args) -> None:
        """Perform one sweep pass; must be overridden by subclasses.

        Accepts arbitrary positional arguments. The base implementation
        always raises ``NotImplementedError``.
        """
        raise NotImplementedError
+103
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
import asyncio
2+
import logging
3+
from collections import defaultdict
4+
from contextlib import asynccontextmanager as actxmgr
5+
from contextlib import suppress
6+
from typing import AsyncIterator, override
7+
8+
import aiotools
9+
import sqlalchemy as sa
10+
from sqlalchemy.orm import load_only, noload
11+
12+
from ai.backend.common.events import KernelLifecycleEventReason
13+
from ai.backend.common.validators import TimeDelta
14+
from ai.backend.logging import BraceStyleAdapter
15+
16+
from ..api.context import RootContext
17+
from ..config import session_hang_tolerance_iv
18+
from ..models import DEAD_KERNEL_STATUSES, DEAD_SESSION_STATUSES, KernelRow, SessionRow
19+
from .base import AbstractSweeper
20+
21+
log = BraceStyleAdapter(logging.getLogger(__spec__.name))
22+
23+
24+
class KernelSweeper(AbstractSweeper):
    """Destroy kernels that are still alive although their session is dead.

    A kernel is selected when its own status is not in
    ``DEAD_KERNEL_STATUSES`` while the status of the session it belongs to
    is in ``DEAD_SESSION_STATUSES``.
    """

    _root_ctx: RootContext

    def __init__(self, root_ctx: RootContext) -> None:
        self._root_ctx = root_ctx

    @override
    async def sweep(self, *args, **kwargs) -> None:
        """Run one sweep pass.

        Selected kernels are grouped by session and each group is handed to
        ``registry.destroy_session_lowlevel`` with the ``HANG_TIMEOUT``
        reason.  All groups are processed concurrently; a failure for one
        session is logged and does not prevent the others from being
        processed or abort the periodic sweep.

        ``*args`` / ``**kwargs`` are ignored.  They exist so that the bound
        method can be used directly as a timer callback, which is handed the
        interval value (positionally or by keyword).
        """
        query = (
            sa.select(KernelRow)
            .join(SessionRow, KernelRow.session_id == SessionRow.id)
            .where(KernelRow.status.not_in(DEAD_KERNEL_STATUSES))
            .where(SessionRow.status.in_(DEAD_SESSION_STATUSES))
            .options(
                noload("*"),
                load_only(
                    KernelRow.id,
                    KernelRow.session_id,
                    KernelRow.agent,
                    KernelRow.agent_addr,
                    KernelRow.container_id,
                ),
            )
        )

        async with self._root_ctx.db.begin_readonly() as conn:
            result = await conn.execute(query)
            kernels = result.fetchall()

        if not kernels:
            # Nothing to clean up in this pass.
            return

        # Group the minimal kernel descriptions by the owning session.
        kernels_per_session = defaultdict(list)
        for kernel in kernels:
            kernels_per_session[kernel.session_id].append({
                "id": kernel.id,
                "session_id": kernel.session_id,
                "agent": kernel.agent,
                "agent_addr": kernel.agent_addr,
                "container_id": kernel.container_id,
            })

        session_ids = list(kernels_per_session)
        # return_exceptions=True: collect every outcome instead of raising on
        # the first failure, so each session is reported individually.
        outcomes = await asyncio.gather(
            *(
                self._root_ctx.registry.destroy_session_lowlevel(
                    session_id,
                    kernels_per_session[session_id],
                    reason=KernelLifecycleEventReason.HANG_TIMEOUT,
                )
                for session_id in session_ids
            ),
            return_exceptions=True,
        )
        for session_id, outcome in zip(session_ids, outcomes):
            if isinstance(outcome, BaseException):
                log.error(
                    "KernelSweeper: failed to destroy stale kernels of session {}: {!r}",
                    session_id,
                    outcome,
                )
79+
80+
81+
@actxmgr
82+
async def stale_kernel_collection_ctx(root_ctx: RootContext) -> AsyncIterator[None]:
83+
session_hang_tolerance = session_hang_tolerance_iv.check(
84+
await root_ctx.shared_config.etcd.get_prefix_dict("config/session/hang-tolerance")
85+
)
86+
default_interval_sec = 60.0
87+
interval_sec = float("inf")
88+
threshold: TimeDelta
89+
for threshold in session_hang_tolerance["threshold"].values():
90+
interval_sec = min(interval_sec, threshold.seconds)
91+
if interval_sec == float("inf"):
92+
interval_sec = default_interval_sec
93+
task = aiotools.create_timer(
94+
KernelSweeper(root_ctx).sweep,
95+
interval=interval_sec,
96+
)
97+
98+
yield
99+
100+
if not task.done():
101+
task.cancel()
102+
with suppress(asyncio.CancelledError):
103+
await task

0 commit comments

Comments
 (0)