Skip to content

Commit 499a425

Browse files
committed
Send kv events from worker side to scheduler side
This is required when worker-side operations such as CPU offloading generate KV cache events. This commit enables these events to be passed to the scheduler side so that they can be published by the engine. Signed-off-by: Martin Hickey <[email protected]>
1 parent 155ad56 commit 499a425

File tree

4 files changed

+81
-2
lines changed

4 files changed

+81
-2
lines changed

vllm/distributed/kv_transfer/kv_connector/v1/base.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
if TYPE_CHECKING:
5050
from vllm.attention.backends.abstract import AttentionMetadata
5151
from vllm.config import VllmConfig
52-
from vllm.distributed.kv_events import KVCacheEvent
52+
from vllm.distributed.kv_events import KVCacheEvent, KVEventBatch
5353
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
5454
KVConnectorPromMetrics,
5555
KVConnectorStats,
@@ -343,6 +343,12 @@ def get_kv_connector_stats(self) -> Optional["KVConnectorStats"]:
343343
"""
344344
return None
345345

346+
def get_kv_connector_kv_cache_events(self) -> Optional["KVEventBatch"]:
    """
    Return the KV cache events this connector gathered since the last
    interval.

    This default implementation reports nothing.

    Returns:
        Always ``None`` here.
    """
    return None
351+
346352
def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None:
347353
"""
348354
Get the KVConnector handshake metadata for this connector.

vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,24 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3-
from typing import TYPE_CHECKING, Any
3+
import time
4+
from collections.abc import Iterable
5+
from typing import TYPE_CHECKING, Any, Optional
46

57
import torch
68
from lmcache.integration.vllm.vllm_v1_adapter import (
79
LMCacheConnectorV1Impl as LMCacheConnectorLatestImpl,
810
)
911

1012
from vllm.config import VllmConfig
13+
from vllm.distributed.kv_events import BlockStored, KVCacheEvent, KVEventBatch
1114
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
1215
KVConnectorBase_V1,
1316
KVConnectorMetadata,
1417
KVConnectorRole,
1518
)
1619
from vllm.logger import init_logger
1720
from vllm.v1.core.sched.output import SchedulerOutput
21+
from vllm.v1.outputs import KVConnectorOutput
1822

1923
if TYPE_CHECKING:
2024
from vllm.attention.backends.abstract import AttentionMetadata
@@ -54,6 +58,8 @@ def __init__(
5458

5559
self._lmcache_engine = cls(vllm_config, role, self)
5660

61+
self._kv_events: list[KVCacheEvent] = []
62+
5763
# ==============================
5864
# Worker-side methods
5965
# ==============================
@@ -136,6 +142,30 @@ def get_finished(
136142
"""
137143
return self._lmcache_engine.get_finished(finished_req_ids)
138144

145+
def get_kv_connector_kv_cache_events(self) -> Optional["KVEventBatch"]:
    """
    Get the KV connector kv cache events collected during the last interval.

    Every event reported by the LMCache engine is re-wrapped as a
    ``BlockStored`` event, copying its block hashes, parent block hash,
    token ids, LoRA id, block size and medium.

    Returns:
        A ``KVEventBatch`` stamped with the current time and holding one
        ``BlockStored`` per engine event, or ``None`` when the engine
        reported no events.
    """
    events = self._lmcache_engine.get_kv_events()
    if not events:
        return None

    stored_blocks: list[KVCacheEvent] = [
        BlockStored(
            block_hashes=event.block_hashes,
            parent_block_hash=event.parent_block_hash,
            token_ids=event.token_ids,
            lora_id=event.lora_id,
            block_size=event.block_size,
            medium=event.medium,
        )
        for event in events
    ]
    # A truthy iterable may still yield nothing (e.g. an exhausted
    # iterator); report "no events" rather than an empty batch.
    if not stored_blocks:
        return None

    return KVEventBatch(ts=time.time(), events=stored_blocks)
168+
139169
# ==============================
140170
# Scheduler-side methods
141171
# ==============================
@@ -183,6 +213,25 @@ def build_connector_meta(
183213
"""
184214
return self._lmcache_engine.build_connector_meta(scheduler_output)
185215

216+
def update_connector_output(self, connector_output: KVConnectorOutput):
    """
    Update KVConnector state from worker-side connectors output.

    KV cache events carried by the output are appended to the events this
    connector is holding for ``take_events()``.

    Args:
        connector_output (KVConnectorOutput): the worker-side
            connectors output.
    """
    kv_events = connector_output.kv_cache_events
    # isinstance() also rejects None, so a missing batch needs no
    # separate check.
    if not isinstance(kv_events, KVEventBatch) or not kv_events.events:
        return

    # Append rather than rebind: rebinding would drop events that have
    # not been taken yet, and would make this connector share (and later
    # clear) the list owned by the worker's batch.
    self._kv_events.extend(kv_events.events)
234+
186235
def request_finished(
187236
self,
188237
request: "Request",
@@ -199,3 +248,14 @@ def request_finished(
199248
returned by the engine.
200249
"""
201250
return self._lmcache_engine.request_finished(request, block_ids)
251+
252+
def take_events(self) -> Iterable["KVCacheEvent"]:
    """
    Take the KV cache events from the connector.

    The pending list is detached from the connector and replaced with a
    fresh empty one before anything is yielded. The list object that was
    stored is never emptied in place, so other holders of that same list
    are unaffected. Because this is a generator, the hand-over happens
    when iteration starts, not when the method is called.

    Yields:
        New KV cache events since the last call.
    """
    pending, self._kv_events = self._kv_events, []
    if pending:
        yield from pending

vllm/v1/outputs.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,11 @@
88
import torch
99

1010
if TYPE_CHECKING:
11+
from vllm.distributed.kv_events import KVEventBatch
1112
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
1213
else:
1314
KVConnectorStats = object
15+
KVEventBatch = object
1416

1517

1618
class LogprobsLists(NamedTuple):
@@ -109,6 +111,7 @@ class KVConnectorOutput:
109111
finished_sending: set[str] | None = None
110112
finished_recving: set[str] | None = None
111113
kv_connector_stats: KVConnectorStats | None = None
114+
kv_cache_events: KVEventBatch | None = None
112115
# IDs of externally computed KV blocks that failed to load.
113116
# Requests referencing these blocks should be rescheduled to recompute them
114117
invalid_block_ids: set[int] = field(default_factory=set)
@@ -124,6 +127,7 @@ def is_empty(self):
124127
not self.finished_sending
125128
and not self.finished_recving
126129
and not self.kv_connector_stats
130+
and not self.kv_cache_events
127131
and not self.invalid_block_ids
128132
)
129133

vllm/v1/worker/kv_connector_model_runner_mixin.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,10 +135,19 @@ def _get_kv_connector_output(
135135
output.kv_connector_stats = (
136136
KVConnectorModelRunnerMixin.get_kv_connector_stats()
137137
)
138+
output.kv_cache_events = (
139+
KVConnectorModelRunnerMixin.get_kv_connector_kv_cache_events()
140+
)
138141
kv_connector.clear_connector_metadata()
139142

140143
@staticmethod
def get_kv_connector_stats() -> KVConnectorStats | None:
    """
    Fetch the stats reported by the active KV transfer group.

    Returns:
        Whatever the transfer group's ``get_kv_connector_stats()`` returns,
        or ``None`` when no KV transfer group exists.
    """
    if not has_kv_transfer_group():
        return None
    return get_kv_transfer_group().get_kv_connector_stats()
148+
149+
@staticmethod
def get_kv_connector_kv_cache_events() -> KVConnectorStats | None:
    """
    Fetch the KV cache events reported by the active KV transfer group.

    NOTE(review): the declared return type names ``KVConnectorStats``, but
    the value comes from the connector's
    ``get_kv_connector_kv_cache_events()``, which is annotated elsewhere as
    returning an optional ``KVEventBatch``. The annotation looks like a
    copy-paste slip; confirm and correct it together with this module's
    imports.

    Returns:
        Whatever the transfer group's
        ``get_kv_connector_kv_cache_events()`` returns, or ``None`` when no
        KV transfer group exists.
    """
    if not has_kv_transfer_group():
        return None
    return get_kv_transfer_group().get_kv_connector_kv_cache_events()

0 commit comments

Comments
 (0)