class OffloadingConnectorScheduler:
"""Implementation of Scheduler side methods"""
def __init__(self, spec: OffloadingSpec):
assert len(spec.gpu_block_size) == 1
self.gpu_block_size = spec.gpu_block_size[0]
self.offloaded_block_size = self.gpu_block_size * spec.block_size_factor
self.block_size_factor = spec.block_size_factor
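        # For illustration (hypothetical sizes): with gpu_block_size=16 and
        # block_size_factor=4, each offloaded block covers 64 tokens, i.e.
        # 4 consecutive GPU blocks.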
self.manager: OffloadingManager = spec.get_manager()
self._requests: dict[ReqId, Request] = {}
# list of GPU block IDs per request
self._request_block_ids: dict[ReqId, list[int]] = {}
# requests to load for the current scheduler step
self._reqs_to_load: dict[ReqId, TransferSpec] = {}
# request blocks are stored in order
# index of next block (of size offloaded_block_size) to offload
self._next_stored_block_idx: dict[ReqId, int] = {}
# if GPU prefix caching is enabled,
# track loaded blocks to avoid redundant loads
self._blocks_being_loaded: set[BlockHash] | None = (
set() if spec.vllm_config.cache_config.enable_prefix_caching else None
)
        # request ID -> set of block hashes being stored/loaded
self._reqs_being_stored = defaultdict[ReqId, set[BlockHash]](set)
self._reqs_being_loaded = defaultdict[ReqId, set[BlockHash]](set)
def _get_block_hashes(
self,
req: Request,
start_idx: int = 0,
end_idx: int | None = None,
) -> Iterable[BlockHash]:
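        """
        Yield one hash per offloaded block of the request: the hash of the
        last GPU block within each offloaded block, for offloaded block
        indices [start_idx, end_idx) (or to the end if end_idx is None).

        For illustration (hypothetical values): with block_size_factor=4 and
        start_idx=1, this yields req.block_hashes[7], [11], [15], ...
        """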
return islice(
req.block_hashes,
self.block_size_factor * start_idx + self.block_size_factor - 1,
            self.block_size_factor * end_idx if end_idx is not None else None,
self.block_size_factor,
)
def get_num_new_matched_tokens(
self, request: Request, num_computed_tokens: int
) -> tuple[int | None, bool]:
"""
Get number of new tokens that can be loaded beyond the
num_computed_tokens.
Args:
request (Request): the request object.
num_computed_tokens (int): the number of locally
computed tokens for this request
Returns:
A tuple with the following elements:
- The number of tokens that can be loaded beyond what is
already computed.
If None, it means that the connector needs more time to
determine the number of matched tokens, and the scheduler
should query for this request again later.
- `True` if tokens will be loaded asynchronously
(between scheduler steps).
"""
num_blocks = request.num_tokens // self.offloaded_block_size
assert len(request.block_hashes) // self.block_size_factor == num_blocks
block_hashes = self._get_block_hashes(request)
self.manager.touch(block_hashes)
full_block_tokens = self.offloaded_block_size * num_blocks
if full_block_tokens - num_computed_tokens < self.offloaded_block_size:
            # fewer tokens than one full offloaded block can be loaded; skip
return 0, False
start_block_idx = num_computed_tokens // self.offloaded_block_size
hits = self.manager.lookup(
self._get_block_hashes(request, start_idx=start_block_idx)
)
if hits is None:
# indicates a lookup that should be tried later
return None, False
if hits == 0:
return 0, False
num_hit_tokens = (
self.offloaded_block_size * (start_block_idx + hits) - num_computed_tokens
)
logger.debug(
"Request %s hit %s offloaded tokens after %s GPU hit tokens",
request.request_id,
num_hit_tokens,
num_computed_tokens,
)
if num_hit_tokens < self.offloaded_block_size:
return 0, False
if self._blocks_being_loaded:
block_hashes = self._get_block_hashes(
request, start_idx=start_block_idx, end_idx=start_block_idx + hits
)
if any(
block_hash in self._blocks_being_loaded for block_hash in block_hashes
):
# hit blocks are being loaded, delay request
logger.debug(
"Delaying request %s since some of its blocks are already"
" being loaded",
request.request_id,
)
return None, False
return num_hit_tokens, True
def update_state_after_alloc(
self, request: Request, blocks: KVCacheBlocks, num_external_tokens: int
):
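        """
        Track the request and, when offloaded tokens were matched
        (num_external_tokens > 0), prepare a load transfer of the matched
        offloaded blocks into the newly allocated GPU blocks.
        """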
self._requests[request.request_id] = request
# the block ids are updated in _get_reqs_to_store
self._request_block_ids[request.request_id] = []
if num_external_tokens == 0:
return
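        # Worked example (hypothetical values): with gpu_block_size=16 and
        # block_size_factor=4 (offloaded_block_size=64), 4 computed GPU
        # blocks (64 tokens) and num_external_tokens=128:
        # full_block_tokens=192, start_block_idx=1, num_blocks=3, so
        # offloaded blocks 1-2 are loaded into the 8 pending GPU blocks
        # block_ids[4:12].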
block_groups = blocks.get_block_ids()
block_ids = block_groups[0]
num_computed_gpu_blocks = sum(
block.block_hash is not None for block in blocks.blocks[0]
)
num_computed_tokens = num_computed_gpu_blocks * self.gpu_block_size
full_block_tokens = num_computed_tokens + num_external_tokens
assert full_block_tokens % self.offloaded_block_size == 0
num_pending_gpu_blocks = len(block_ids) - num_computed_gpu_blocks
assert num_external_tokens == num_pending_gpu_blocks * self.gpu_block_size
start_block_idx = num_computed_tokens // self.offloaded_block_size
num_blocks = full_block_tokens // self.offloaded_block_size
assert len(request.block_hashes) // self.block_size_factor >= num_blocks
block_hashes = self._get_block_hashes(
request, start_idx=start_block_idx, end_idx=num_blocks
)
src_spec = self.manager.prepare_load(block_hashes)
dst_spec = GPULoadStoreSpec(
block_ids[num_computed_gpu_blocks:],
group_sizes=(num_pending_gpu_blocks,),
block_indices=(num_computed_gpu_blocks,),
)
block_hashes = self._get_block_hashes(
request, start_idx=start_block_idx, end_idx=num_blocks
)
self._reqs_to_load[request.request_id] = (src_spec, dst_spec)
req_blocks_being_loaded = self._reqs_being_loaded[request.request_id]
req_blocks_being_loaded.update(block_hashes)
self._next_stored_block_idx[request.request_id] = num_blocks
if self._blocks_being_loaded is not None:
self._blocks_being_loaded.update(req_blocks_being_loaded)
def _get_reqs_to_store(self, scheduler_output: SchedulerOutput):
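        """
        Build GPU -> offloading store transfers for every offloaded block
        that became full during this scheduler step, covering both new and
        cached requests.
        """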
reqs_to_store: dict[ReqId, TransferSpec] = {}
# iterate over both new and cached requests
for req_id, new_block_id_groups, preempted in yield_req_data(scheduler_output):
if preempted:
self._request_block_ids[req_id] = []
if new_block_id_groups:
new_block_ids = new_block_id_groups[0]
self._request_block_ids[req_id] += new_block_ids
block_ids = self._request_block_ids[req_id]
req = self._requests[req_id]
new_tokens = scheduler_output.num_scheduled_tokens[req_id]
expected_tokens = req.num_computed_tokens + new_tokens
            # with async scheduling, the request may not yet contain all
            # of the expected tokens
total_tokens = min(expected_tokens, req.num_tokens)
num_blocks = total_tokens // self.offloaded_block_size
start_block_idx = self._next_stored_block_idx.get(req_id, 0)
num_new_blocks = num_blocks - start_block_idx
if num_new_blocks <= 0:
continue
num_gpu_blocks = num_blocks * self.block_size_factor
assert len(req.block_hashes) >= num_gpu_blocks
new_block_hashes = self._get_block_hashes(
req, start_idx=start_block_idx, end_idx=num_blocks
)
store_output = self.manager.prepare_store(new_block_hashes)
if store_output is None:
logger.warning(
"Request %s: cannot store %s blocks", req_id, num_new_blocks
)
continue
self._next_stored_block_idx[req_id] = num_blocks
if not store_output.block_hashes_to_store:
continue
block_hashes_to_store = set(store_output.block_hashes_to_store)
block_hashes = self._get_block_hashes(req, end_idx=num_blocks)
self.manager.touch(block_hashes)
new_block_hashes = self._get_block_hashes(
req, start_idx=start_block_idx, end_idx=num_blocks
)
dst_spec = store_output.store_spec
src_block_ids: list[int] = []
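            # For illustration (hypothetical values): with block_size_factor=4
            # and start_block_idx=2, a new offloaded block at idx=0 (absolute
            # offloaded block 2) that needs storing maps to gpu_block_idx=8,
            # so block_ids[8:12] are appended below.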
for idx, blk_hash in enumerate(new_block_hashes):
if blk_hash not in block_hashes_to_store:
continue
offloaded_block_idx = start_block_idx + idx
gpu_block_idx = offloaded_block_idx * self.block_size_factor
for i in range(self.block_size_factor):
src_block_ids.append(block_ids[gpu_block_idx + i])
src_spec = GPULoadStoreSpec(
src_block_ids, group_sizes=(len(src_block_ids),)
)
reqs_to_store[req_id] = (src_spec, dst_spec)
self._reqs_being_stored[req_id] |= block_hashes_to_store
logger.debug(
"Request %s offloading %s blocks starting from block #%d",
req_id,
len(block_hashes_to_store),
start_block_idx,
)
return reqs_to_store
def build_connector_meta(
self, scheduler_output: SchedulerOutput
) -> KVConnectorMetadata:
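        """
        Pack the pending load, store and flush transfers for this scheduler
        step into the metadata consumed by the worker-side connector.
        """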
meta = OffloadingConnectorMetadata(
reqs_to_load=self._reqs_to_load,
reqs_to_store=self._get_reqs_to_store(scheduler_output),
reqs_to_flush=scheduler_output.preempted_req_ids,
)
self._reqs_to_load = {}
# NOTE (orozery): we should move this logic to update_connector_output
# once KVConnectorOutput allows us to report completed transfers
for req_id in scheduler_output.preempted_req_ids or ():
block_hashes = self._reqs_being_stored.get(req_id)
if block_hashes:
self.manager.complete_store(block_hashes)
block_hashes.clear()
return meta
def update_connector_output(self, connector_output: KVConnectorOutput):
"""
Update KVConnector state from worker-side connectors output.
Args:
connector_output (KVConnectorOutput): the worker-side
connectors output.
"""
for req_id in connector_output.finished_sending or []:
block_hashes = self._reqs_being_stored.pop(req_id, None)
if block_hashes:
self.manager.complete_store(block_hashes)
for req_id in connector_output.finished_recving or []:
block_hashes = self._reqs_being_loaded.pop(req_id, None)
if block_hashes:
if self._blocks_being_loaded:
self._blocks_being_loaded.difference_update(block_hashes)
self.manager.complete_load(block_hashes)
def request_finished(
self,
request: Request,
block_ids: list[int],
) -> tuple[bool, dict[str, Any] | None]:
"""
Called when a request has finished, before its blocks are freed.
Returns:
True if the request is being saved/sent asynchronously and blocks
should not be freed until the request_id is returned from
get_finished().
Optional KVTransferParams to be included in the request outputs
returned by the engine.
"""
req_id = request.request_id
self._requests.pop(req_id, None)
self._request_block_ids.pop(req_id, None)
# TODO(orozery): possibly kickoff offload for last block
# which may have been deferred due to async scheduling
self._next_stored_block_idx.pop(req_id, None)
request_being_stored = req_id in self._reqs_being_stored
return request_being_stored, None
def take_events(self) -> Iterable[KVCacheEvent]:
"""Take the KV cache events from the connector.
Returns:
            An iterable of KV cache events.
"""
for event in self.manager.take_events():
if event.removed:
yield BlockRemoved(block_hashes=event.block_hashes, medium=event.medium)
else:
yield BlockStored(
block_hashes=event.block_hashes,
parent_block_hash=None,
token_ids=[],
lora_id=None,
block_size=event.block_size,
medium=event.medium,
lora_name=None,
)
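
# Illustrative sketch (not part of this class): a rough order in which the
# scheduler side is expected to drive these hooks each step, using
# hypothetical `connector`, `request`, `blocks`, `scheduler_output` and
# `kv_connector_output` objects:
#
#   num_external_tokens, load_async = connector.get_num_new_matched_tokens(
#       request, num_computed_tokens)
#   # ... the scheduler allocates KV cache blocks for the request ...
#   connector.update_state_after_alloc(request, blocks, num_external_tokens)
#   meta = connector.build_connector_meta(scheduler_output)
#   # ... workers execute the step and report finished transfers ...
#   connector.update_connector_output(kv_connector_output)
#   # ... when the request completes ...
#   connector.request_finished(request, block_ids)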