Skip to content

Commit 078da31

Browse files
[HPU][Bugfix] set_forward_context and CI test execution (vllm-project#12014)
Signed-off-by: Konrad Zawora <[email protected]>
1 parent 1a40125 commit 078da31

File tree

3 files changed

+23
-18
lines changed

3 files changed

+23
-18
lines changed

.buildkite/run-hpu-test.sh

+5-2
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,12 @@ set -ex
 docker build -t hpu-test-env -f Dockerfile.hpu .
 
 # Setup cleanup
+EXITCODE=1
 remove_docker_container() { docker rm -f hpu-test || true; }
-trap remove_docker_container EXIT
+remove_docker_container_and_exit() { remove_docker_container; exit $EXITCODE; }
+trap remove_docker_container_and_exit EXIT
 remove_docker_container
 
 # Run the image and launch offline inference
-docker run --runtime=habana --name=hpu-test --network=host -e HABANA_VISIBLE_DEVICES=all -e VLLM_SKIP_WARMUP=true --entrypoint="" hpu-test-env python3 examples/offline_inference/basic.py
+docker run --runtime=habana --name=hpu-test --network=host -e HABANA_VISIBLE_DEVICES=all -e VLLM_SKIP_WARMUP=true --entrypoint="" hpu-test-env python3 examples/offline_inference/basic.py
+EXITCODE=$?

Dockerfile.hpu

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
-FROM vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest
+FROM vault.habana.ai/gaudi-docker/1.19.1/ubuntu22.04/habanalabs/pytorch-installer-2.5.1:latest
 
 COPY ./ /workspace/vllm

vllm/worker/hpu_model_runner.py

+17-15
Original file line numberDiff line numberDiff line change
@@ -289,12 +289,14 @@ def forward_hook(module, args, output):
 
 class HpuModelAdapter:
 
-    def __init__(self, model, block_size, dtype, enforce_eager):
+    def __init__(self, model, vllm_config):
         self.model = model
         self.prefill_use_fusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA',
                                                '0').lower() in ['1', 'true']
-        self.block_size = block_size
-        self.dtype = dtype
+        self.vllm_config = vllm_config
+        self.block_size = vllm_config.cache_config.block_size
+        self.dtype = vllm_config.model_config.dtype
+        enforce_eager = vllm_config.model_config.enforce_eager
         if not htorch.utils.internal.is_lazy() and not enforce_eager:
             self.model = torch.compile(self.model,
                                        backend='hpu_backend',
@@ -353,14 +355,20 @@ def forward(self, *args, **kwargs):
         selected_token_indices = kwargs.pop('selected_token_indices')
         if 'warmup_mode' in kwargs:
             kwargs.pop('warmup_mode')
+        virtual_engine = 0
+        if 'virtual_engine' in kwargs:
+            virtual_engine = kwargs.pop('virtual_engine')
         input_ids = kwargs['input_ids']
         kwargs['attn_metadata'] = self._update_metadata(
             kwargs['attn_metadata'], input_ids.size(0), input_ids.size(1),
             input_ids.device, self.dtype)
         LoraMask.setLoraMask(kwargs.pop('lora_mask'))
-        hidden_states = self.model(*args, **kwargs)
-        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
-        hidden_states = hidden_states.index_select(0, selected_token_indices)
+        with set_forward_context(kwargs['attn_metadata'], self.vllm_config,
+                                 virtual_engine):
+            hidden_states = self.model(*args, **kwargs)
+            hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+            hidden_states = hidden_states.index_select(0,
+                                                       selected_token_indices)
         return hidden_states
 
     def compute_logits(self, *args, **kwargs):
@@ -660,10 +668,7 @@ def load_model(self) -> None:
 
         with HabanaMemoryProfiler() as m_wrap:
             self.model = _maybe_wrap_in_hpu_graph(
-                self.model,
-                self.block_size,
-                dtype=self.model_config.dtype,
-                enforce_eager=self.enforce_eager)
+                self.model, vllm_config=self.vllm_config)
         msg = f"Wrapping in HPU Graph took {m_wrap.get_summary_string()}"
         logger.info(msg)
@@ -1934,6 +1939,7 @@ def execute_model(
             "attn_metadata": self.trim_attn_metadata(attn_metadata),
             "intermediate_tensors": intermediate_tensors,
             "lora_mask": lora_mask,
+            "virtual_engine": model_input.virtual_engine,
             **(model_input.multi_modal_kwargs or {}),
         }
         if htorch.utils.internal.is_lazy():
@@ -1948,11 +1954,7 @@ def execute_model(
                 f"graphs{'T' if use_graphs else 'F'}")
         else:
             model_event_name = 'model_executable'
-        with set_forward_context(
-                model_input.attn_metadata, self.vllm_config,
-                model_input.virtual_engine), \
-                self.profiler.record_event(
-                    'internal', model_event_name):
+        with self.profiler.record_event('internal', model_event_name):
             hidden_states = self.model.forward(
                 **execute_model_kwargs,
                 selected_token_indices=sampling_metadata.selected_token_indices

0 commit comments

Comments
 (0)