Skip to content

Commit ea420f8

Browse files
committed
refactor(BA-1004): Move elapsed time function
1 parent 7b87326 commit ea420f8

File tree

4 files changed

+16
-19
lines changed

4 files changed

+16
-19
lines changed

src/ai/backend/manager/cleanup/BUILD

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
python_sources(name="src")

src/ai/backend/manager/cleanup/kernel.py

+5-7
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from collections import defaultdict
55
from contextlib import asynccontextmanager as actxmgr
66
from contextlib import suppress
7-
from typing import AsyncIterator, Any, Mapping, Protocol, Sequence
7+
from typing import Any, AsyncIterator, Mapping, Protocol, Sequence
88

99
import aiotools
1010
import sqlalchemy as sa
@@ -17,11 +17,9 @@
1717

1818
from ..api.context import RootContext
1919
from ..config import session_hang_tolerance_iv
20-
from ..models import KernelRow, SessionRow, DEAD_SESSION_STATUSES, DEAD_KERNEL_STATUSES
20+
from ..models import DEAD_KERNEL_STATUSES, DEAD_SESSION_STATUSES, KernelRow, SessionRow
2121
from ..models.utils import ExtendedAsyncSAEngine
2222

23-
__all__ = ("stale_kernel_collection_ctx",)
24-
2523
log = BraceStyleAdapter(logging.getLogger(__spec__.name))
2624

2725

@@ -103,10 +101,10 @@ async def stale_kernel_collection_ctx(root_ctx: RootContext) -> AsyncIterator[No
103101
await root_ctx.shared_config.etcd.get_prefix_dict("config/session/hang-tolerance")
104102
)
105103
default_interval_sec = 60.0
106-
interval_sec: float = float("inf")
104+
interval_sec = float("inf")
107105
threshold: TimeDelta
108106
for threshold in session_hang_tolerance["threshold"].values():
109-
interval_sec = min(interval_sec, threshold.total_seconds())
107+
interval_sec = min(interval_sec, threshold.seconds)
110108
if interval_sec == float("inf"):
111109
interval_sec = default_interval_sec
112110
task = aiotools.create_timer(
@@ -120,7 +118,7 @@ async def stale_kernel_collection_ctx(root_ctx: RootContext) -> AsyncIterator[No
120118

121119
yield
122120

123-
if not task.done:
121+
if not task.done():
124122
task.cancel()
125123
with suppress(asyncio.CancelledError):
126124
await task

src/ai/backend/manager/cleanup/session.py

+5-11
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from contextlib import asynccontextmanager as actxmgr
55
from contextlib import suppress
66
from datetime import datetime, timedelta
7-
from typing import Any, AsyncIterator, Mapping, Optional, Protocol, Sequence
7+
from typing import Any, AsyncIterator, Mapping, Optional, Protocol
88

99
import aiotools
1010
import sqlalchemy as sa
@@ -22,8 +22,6 @@
2222
from ..models.session import SessionStatus
2323
from ..models.utils import ExtendedAsyncSAEngine
2424

25-
__all__ = ("stale_session_collection_ctx",)
26-
2725
log = BraceStyleAdapter(logging.getLogger(__spec__.name))
2826

2927

@@ -38,7 +36,7 @@ async def destroy_session(
3836
) -> Mapping[str, Any]: ...
3937

4038

41-
def get_interval(
39+
def _get_interval(
4240
threshold: TimeDelta,
4341
*,
4442
max_interval: float = timedelta(hours=1).total_seconds(),
@@ -60,11 +58,7 @@ async def handle_stale_sessions(
6058
query = (
6159
sa.select(SessionRow)
6260
.where(SessionRow.status == status)
63-
.where(
64-
now
65-
- SessionRow.status_history[status.name].astext.cast(sa.types.DateTime(timezone=True))
66-
> threshold
67-
)
61+
.where(SessionRow.get_status_elapsed_time(status, now).total_seconds() > threshold.seconds)
6862
.options(
6963
noload("*"),
7064
load_only(SessionRow.id, SessionRow.name, SessionRow.access_key),
@@ -109,7 +103,7 @@ async def stale_session_collection_ctx(root_ctx: RootContext) -> AsyncIterator[N
109103
log.warning(f"Invalid session status for hang-threshold: '{raw_status}'")
110104
continue
111105

112-
interval = get_interval(threshold)
106+
interval = _get_interval(threshold)
113107
tasks.append(
114108
aiotools.create_timer(
115109
functools.partial(
@@ -126,7 +120,7 @@ async def stale_session_collection_ctx(root_ctx: RootContext) -> AsyncIterator[N
126120
yield
127121

128122
for task in tasks:
129-
if not task.done:
123+
if not task.done():
130124
task.cancel()
131125
with suppress(asyncio.CancelledError):
132126
await task

src/ai/backend/manager/models/session.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from collections.abc import Iterable, Mapping, Sequence
88
from contextlib import asynccontextmanager as actxmgr
99
from dataclasses import dataclass, field
10-
from datetime import datetime
10+
from datetime import datetime, timedelta
1111
from decimal import Decimal
1212
from typing import (
1313
TYPE_CHECKING,
@@ -1311,6 +1311,10 @@ async def get_network_ref(self, db_sess: SASession) -> str | None:
13111311
case _:
13121312
return None
13131313

1314+
@classmethod
1315+
def get_status_elapsed_time(cls, status: SessionStatus, now: datetime) -> timedelta:
1316+
return now - cls.status_history[status.name].astext.cast(sa.types.DateTime(timezone=True))
1317+
13141318

13151319
class SessionLifecycleManager:
13161320
status_set_key = "session_status_update"

0 commit comments

Comments
 (0)