@@ -289,12 +289,14 @@ def forward_hook(module, args, output):
289
289
290
290
class HpuModelAdapter :
291
291
292
- def __init__ (self , model , block_size , dtype , enforce_eager ):
292
+ def __init__ (self , model , vllm_config ):
293
293
self .model = model
294
294
self .prefill_use_fusedsdpa = os .getenv ('VLLM_PROMPT_USE_FUSEDSDPA' ,
295
295
'0' ).lower () in ['1' , 'true' ]
296
- self .block_size = block_size
297
- self .dtype = dtype
296
+ self .vllm_config = vllm_config
297
+ self .block_size = vllm_config .cache_config .block_size
298
+ self .dtype = vllm_config .model_config .dtype
299
+ enforce_eager = vllm_config .model_config .enforce_eager
298
300
if not htorch .utils .internal .is_lazy () and not enforce_eager :
299
301
self .model = torch .compile (self .model ,
300
302
backend = 'hpu_backend' ,
@@ -353,14 +355,20 @@ def forward(self, *args, **kwargs):
353
355
selected_token_indices = kwargs .pop ('selected_token_indices' )
354
356
if 'warmup_mode' in kwargs :
355
357
kwargs .pop ('warmup_mode' )
358
+ virtual_engine = 0
359
+ if 'virtual_engine' in kwargs :
360
+ virtual_engine = kwargs .pop ('virtual_engine' )
356
361
input_ids = kwargs ['input_ids' ]
357
362
kwargs ['attn_metadata' ] = self ._update_metadata (
358
363
kwargs ['attn_metadata' ], input_ids .size (0 ), input_ids .size (1 ),
359
364
input_ids .device , self .dtype )
360
365
LoraMask .setLoraMask (kwargs .pop ('lora_mask' ))
361
- hidden_states = self .model (* args , ** kwargs )
362
- hidden_states = hidden_states .view (- 1 , hidden_states .shape [- 1 ])
363
- hidden_states = hidden_states .index_select (0 , selected_token_indices )
366
+ with set_forward_context (kwargs ['attn_metadata' ], self .vllm_config ,
367
+ virtual_engine ):
368
+ hidden_states = self .model (* args , ** kwargs )
369
+ hidden_states = hidden_states .view (- 1 , hidden_states .shape [- 1 ])
370
+ hidden_states = hidden_states .index_select (0 ,
371
+ selected_token_indices )
364
372
return hidden_states
365
373
366
374
def compute_logits (self , * args , ** kwargs ):
@@ -660,10 +668,7 @@ def load_model(self) -> None:
660
668
661
669
with HabanaMemoryProfiler () as m_wrap :
662
670
self .model = _maybe_wrap_in_hpu_graph (
663
- self .model ,
664
- self .block_size ,
665
- dtype = self .model_config .dtype ,
666
- enforce_eager = self .enforce_eager )
671
+ self .model , vllm_config = self .vllm_config )
667
672
msg = f"Wrapping in HPU Graph took { m_wrap .get_summary_string ()} "
668
673
logger .info (msg )
669
674
@@ -1934,6 +1939,7 @@ def execute_model(
1934
1939
"attn_metadata" : self .trim_attn_metadata (attn_metadata ),
1935
1940
"intermediate_tensors" : intermediate_tensors ,
1936
1941
"lora_mask" : lora_mask ,
1942
+ "virtual_engine" : model_input .virtual_engine ,
1937
1943
** (model_input .multi_modal_kwargs or {}),
1938
1944
}
1939
1945
if htorch .utils .internal .is_lazy ():
@@ -1948,11 +1954,7 @@ def execute_model(
1948
1954
f"graphs{ 'T' if use_graphs else 'F' } " )
1949
1955
else :
1950
1956
model_event_name = 'model_executable'
1951
- with set_forward_context (
1952
- model_input .attn_metadata , self .vllm_config ,
1953
- model_input .virtual_engine ), \
1954
- self .profiler .record_event (
1955
- 'internal' , model_event_name ):
1957
+ with self .profiler .record_event ('internal' , model_event_name ):
1956
1958
hidden_states = self .model .forward (
1957
1959
** execute_model_kwargs ,
1958
1960
selected_token_indices = sampling_metadata .selected_token_indices
0 commit comments