Llama4 VLM Continuous Batching Support #510

Open · wants to merge 2 commits into base: main
6 changes: 5 additions & 1 deletion QEfficient/generation/text_generation_inference.py
@@ -444,7 +444,11 @@ def __init__(
self._set_tokenizer_params() # set tokenizer params
# Skip inputs/outputs
self._session.skip_buffers(
[x for x in self._session.input_names + self._session.output_names if x.startswith("past_")]
[
x
for x in self._session.input_names + self._session.output_names
if x.startswith("past_") or x.endswith("_RetainedState")
]
)

def _set_tokenizer_params(self):
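Note on this hunk: with the retained-state KV outputs kept on device, the session no longer needs host-side buffers for them, so the skip list now matches names ending in `_RetainedState` in addition to the `past_*` inputs. A minimal sketch of the filter with illustrative buffer names (not the real session object):

```python
# Illustrative input/output names; the real lists come from the inference session.
input_names = ["input_ids", "position_ids", "batch_index", "past_key.0", "past_value.0"]
output_names = ["logits", "past_key.0_RetainedState", "past_value.0_RetainedState"]

skipped = [
    name
    for name in input_names + output_names
    if name.startswith("past_") or name.endswith("_RetainedState")
]
print(skipped)
# ['past_key.0', 'past_value.0', 'past_key.0_RetainedState', 'past_value.0_RetainedState']
```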
29 changes: 23 additions & 6 deletions QEfficient/transformers/cache_utils.py
@@ -443,6 +443,7 @@ def update(

else:
position_ids = cache_kwargs.get("position_ids")
batch_index = cache_kwargs.get("batch_index", None)  # Fetch the batch index from the kwargs, if provided
is_sliding_layer = torch.tensor(bool(self.is_sliding[layer_idx]))

# Update the position_ids to handle the sliding window
@@ -460,10 +461,22 @@
valid_mask = (kv_position_ids != -1).unsqueeze(1).unsqueeze(-1)
key_states = torch.where(valid_mask == 1, key_states, torch.zeros_like(key_states))
value_states = torch.where(valid_mask == 1, value_states, torch.zeros_like(value_states))
self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states)
self.value_cache[layer_idx] = CtxScatterFunc.apply(
self.value_cache[layer_idx], kv_position_ids, value_states
)
if batch_index is not None:
invalid_scatter_index = torch.iinfo(torch.int32).max
scatter_position_ids = torch.where(position_ids < 0, invalid_scatter_index, position_ids)

self.key_cache[layer_idx] = CtxScatterFuncCB.apply(
self.key_cache[layer_idx], batch_index, scatter_position_ids, key_states
)

self.value_cache[layer_idx] = CtxScatterFuncCB.apply(
self.value_cache[layer_idx], batch_index, scatter_position_ids, value_states
)
else:
self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], position_ids, key_states)
self.value_cache[layer_idx] = CtxScatterFunc.apply(
self.value_cache[layer_idx], position_ids, value_states
)
k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]

# Original Gather
@@ -483,8 +496,12 @@
final_indices = torch.where(
(is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), rolling_indices, ctx_indices
)
k_out = CtxGatherFunc.apply(k_out, final_indices)
v_out = CtxGatherFunc.apply(v_out, final_indices)
if batch_index is not None:
k_out = CtxGatherFuncCB.apply(k_out, batch_index, final_indices)
v_out = CtxGatherFuncCB.apply(v_out, batch_index, final_indices)
else:
k_out = CtxGatherFunc.apply(k_out, final_indices)
v_out = CtxGatherFunc.apply(v_out, final_indices)
ctx_v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
v_out = torch.where((is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), v_out, ctx_v_out)
return k_out, v_out
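Note on this hunk: when `batch_index` is present, the scatter and gather ops address the KV cache by the request's slot in the full (continuous-batching) cache rather than by its row in the current micro-batch; invalid positions are redirected to `int32` max so they land outside the written region. A conceptual sketch of the batched scatter, assuming a cache of shape `[full_batch_size, num_heads, ctx_len, head_dim]` — this mirrors the intent of `CtxScatterFuncCB`, not its actual ONNX-exportable implementation:

```python
import torch

def scatter_kv_cb(cache, batch_index, position_ids, new_states):
    # cache:        [full_batch_size, num_heads, ctx_len, head_dim]
    # batch_index:  [bs, 1]       KV-cache slot of each active request
    # position_ids: [bs, seq_len] target positions (invalid ones are skipped here)
    # new_states:   [bs, num_heads, seq_len, head_dim]
    bs, _, seq_len, _ = new_states.shape
    for b in range(bs):
        slot = batch_index[b, 0]
        for s in range(seq_len):
            pos = int(position_ids[b, s])
            if 0 <= pos < cache.shape[2]:
                cache[slot, :, pos, :] = new_states[b, :, s, :]
    return cache

# Two active requests writing into slots 3 and 0 of an 8-slot cache.
cache = torch.zeros(8, 2, 16, 4)
cache = scatter_kv_cb(
    cache,
    batch_index=torch.tensor([[3], [0]]),
    position_ids=torch.tensor([[0, 1], [5, 6]]),
    new_states=torch.ones(2, 2, 2, 4),
)
```

The gather side (`CtxGatherFuncCB`) is the inverse: it reads the context entries for each request from the same slots before attention.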
95 changes: 67 additions & 28 deletions QEfficient/transformers/models/llama4/modeling_llama4.py
@@ -826,7 +826,15 @@ def __init__(self, model):
self.language_model = self.model.language_model
self.config = self.model.config

def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values):
def forward(
self,
input_ids,
vision_embeds,
position_ids,
image_idx,
past_key_values,
batch_index: Optional[torch.LongTensor] = None,
):
inputs_embeds = self.model.language_model.get_input_embeddings()(input_ids)
selected = input_ids == self.model.config.image_token_index
indices1 = selected.to(torch.int64).cumsum(1) - 1
@@ -836,7 +844,11 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_va
image_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds)
inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_embeds)
outputs = self.model.language_model(
inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True
inputs_embeds=inputs_embeds,
position_ids=position_ids,
past_key_values=past_key_values,
batch_index=batch_index,
use_cache=True,
)
next_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0)
image_idx = torch.where(image_idx < next_idx, next_idx, image_idx)
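Note on this hunk: the language-model wrapper only threads the optional `batch_index` through to `language_model(...)`, from where it reaches the cache update via `cache_kwargs` (see the cache_utils.py change above). A hedged example of what a decode-step input set could look like with continuous batching — shapes and values are illustrative, not taken from this diff:

```python
import torch

# Two active requests, each contributing one new token per decode step.
decode_inputs = {
    "input_ids": torch.tensor([[42], [7]]),      # [bs, 1]
    "position_ids": torch.tensor([[17], [3]]),   # next position per request
    "batch_index": torch.tensor([[5], [1]]),     # KV-cache slot per request
    "image_idx": torch.zeros((1, 1), dtype=torch.int64),
}
# vision_embeds omitted here for brevity.
```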
@@ -883,6 +895,9 @@ def get_specializations(
ctx_len: int,
img_size: int,
kv_offload: bool = False,
continuous_batching: bool = False,
kv_cache_batch_size: Optional[int] = None,
full_batch_size: Optional[int] = None,
**compiler_options,
):
max_num_tiles = compiler_options.pop("max_num_tiles", None)
@@ -939,28 +954,42 @@
"img_size": img_size,
}
]
lang = [
{
"batch_size": batch_size,
"seq_len": prefill_seq_len,
"ctx_len": ctx_len,
"max_num_tiles": max_num_tiles,
"img_size": img_size,
"vision_size": vision_size,
"chunk_length": prefill_seq_len,
"chunk_ctx_len": chunk_ctx_len,
},
{
"batch_size": batch_size,
"seq_len": "1",
"ctx_len": ctx_len,
"max_num_tiles": max_num_tiles,
"img_size": img_size,
"vision_size": vision_size,
"chunk_length": prefill_seq_len,
"chunk_ctx_len": chunk_ctx_len,
},
]

lang_prefill = {
"batch_size": 1 if continuous_batching else batch_size,
"seq_len": prefill_seq_len,
"ctx_len": ctx_len,
"max_num_tiles": max_num_tiles,
"img_size": img_size,
"vision_size": vision_size,
"chunk_length": prefill_seq_len,
"chunk_ctx_len": chunk_ctx_len,
}
if continuous_batching:
lang_prefill["full_batch_size"] = kv_cache_batch_size
else:
lang_prefill["batch_size"] = kv_cache_batch_size
if full_batch_size:
lang_prefill["full_batch_exec_size"] = full_batch_size

lang_decode = {
"batch_size": full_batch_size if continuous_batching else batch_size,
"seq_len": 1,
"ctx_len": ctx_len,
"max_num_tiles": max_num_tiles,
"img_size": img_size,
"vision_size": vision_size,
"chunk_length": prefill_seq_len,
"chunk_ctx_len": chunk_ctx_len,
}
if continuous_batching:
lang_decode["full_batch_size"] = kv_cache_batch_size
else:
lang_decode["batch_size"] = kv_cache_batch_size

lang = []
lang.append(lang_prefill)
lang.append(lang_decode)

specializations = {}

@@ -969,18 +998,22 @@
specializations["lang"] = lang
return specializations, compiler_options
else:
lang[0].pop("vision_size")
lang[1].pop("vision_size")
return lang, compiler_options
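Note on this hunk: under continuous batching the language model gets two specializations — a prefill one with `batch_size=1` (prefill runs one request at a time) and a decode one with `batch_size=full_batch_size` — and both carry `full_batch_size` as the KV-cache batch dimension. A hedged example of the resulting dicts for `full_batch_size=4`, `prefill_seq_len=128`, `ctx_len=4096`; values are illustrative, and the vision-related keys (`max_num_tiles`, `img_size`, `vision_size`, `chunk_length`, `chunk_ctx_len`) are omitted:

```python
# Sketch of the two language specializations produced with continuous batching.
lang_prefill = {
    "batch_size": 1,             # one request prefilled at a time
    "seq_len": 128,
    "ctx_len": 4096,
    "full_batch_size": 4,        # KV-cache batch dimension
    "full_batch_exec_size": 4,   # added when full_batch_size is passed in
}
lang_decode = {
    "batch_size": 4,             # all active slots decode together
    "seq_len": 1,
    "ctx_len": 4096,
    "full_batch_size": 4,
}
```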

def get_onnx_dynamic_axes(self, kv_offload: bool = False):
def get_onnx_dynamic_axes(self, kv_offload: bool = False, continuous_batching: bool = False):
# Define dynamic axes
vision_dynamic_axes = {}
lang_dynamic_axes = {}
lang_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"}
lang_dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"}
lang_dynamic_axes["vision_embeds"] = {0: "vision_size"}
if continuous_batching:
lang_dynamic_axes["batch_index"] = {0: "batch_size"}
vision_dynamic_axes["pixel_values"] = {0: "max_num_tiles", 2: "img_size", 3: "img_size"}

pkv_dynamic_axes = {0: "batch_size"}
pkv_dynamic_axes = {0: "full_batch_size" if continuous_batching else "batch_size"}
for i in range(self.language_model.config.num_hidden_layers):
# switch between chunk_ctx_len and ctx_len for RoPE and NoPE layers.
if int((i + 1) % 4 != 0):
@@ -1043,7 +1076,7 @@ def get_dummy_pkv_cache(self, config, batch_size, seq_len):
past_key_values.append(pkv)
return past_key_values

def get_dummy_inputs(self, kv_offload: bool = False):
def get_dummy_inputs(self, kv_offload: bool = False, continuous_batching: bool = False):
if vis_cfg := getattr(self.config, "vision_config", None):
img_size = getattr(vis_cfg, "image_size", 336)
else:
@@ -1088,10 +1121,14 @@ def get_dummy_inputs(self, kv_offload: bool = False):
.repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1)
)
lang_inputs["image_idx"] = torch.zeros((inputs_shapes["image_idx"]), dtype=torch.int64)

bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS

# Add data for KV
past_key_values = self.get_dummy_pkv_cache(
config=self.language_model.config,
batch_size=constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
batch_size=fbs if continuous_batching else bs,
seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN,
)

@@ -1100,6 +1137,8 @@
for kv in ["key", "value"]:
lang_inputs["past_key_values"][i].append(torch.zeros(past_key_values[0][0].shape, dtype=torch.float32))

if continuous_batching:
lang_inputs["batch_index"] = torch.arange(bs).view(bs, 1)
inputs = {}
if kv_offload:
inputs["vision"] = vision_inputs
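Note on these hunks: for ONNX export under continuous batching, the dummy KV tensors are built with `full_batch_size` rows (the `fbs` constant) while `input_ids`/`position_ids` keep the regular example batch size, a `[bs, 1]` `batch_index` input is added so tracing sees the extra input, and the matching dynamic axis for the past KV becomes `full_batch_size`. A small sketch of the extra export inputs — the concrete constant values (1 and 4) and tensor dimensions are assumptions:

```python
import torch

bs, fbs = 1, 4  # assumed values of the ONNX-export example batch sizes

# Which KV-cache slot each example request writes to during tracing.
batch_index = torch.arange(bs).view(bs, 1)            # shape [1, 1]

# Each past_key/past_value dummy has fbs rows instead of bs (illustrative dims).
num_kv_heads, ctx_len, head_dim = 8, 128, 64
past_key = torch.zeros(fbs, num_kv_heads, ctx_len, head_dim)

# The corresponding dynamic axis is exported as {0: "full_batch_size"}.
```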
62 changes: 49 additions & 13 deletions QEfficient/transformers/models/modeling_auto.py
@@ -579,6 +579,7 @@ class _QEffAutoModelForImageTextToTextDualQPC:
def __init__(
self,
model: nn.Module,
continuous_batching,
**kwargs,
):
if kwargs.pop("full_batch_size", None):
@@ -588,6 +589,7 @@ def __init__(
self.model.pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None)
self.vision_model = QEffVisionEncoderForTextImageToTextModel(model)
self.lang_model = QEffCausalLMForTextImageToTextModel(model)
self.continuous_batching = continuous_batching
self.input_shapes, self.output_names = None, None

@property
@@ -627,8 +629,8 @@ def export(
export_dir: Optional[str] = None,
**kwargs,
) -> str:
inputs = self.model.get_dummy_inputs(kv_offload=True)
dynamic_axes = self.model.get_onnx_dynamic_axes(kv_offload=True)
inputs = self.model.get_dummy_inputs(kv_offload=True, continuous_batching=self.continuous_batching)
dynamic_axes = self.model.get_onnx_dynamic_axes(kv_offload=True, continuous_batching=self.continuous_batching)
output_names = self.model.get_output_names(kv_offload=True)
self.vision_model.export(
inputs["vision"],
Expand All @@ -637,6 +639,9 @@ def export(
export_dir,
)

self.lang_model.export(inputs["lang"], output_names["lang"], dynamic_axes["lang"], export_dir)
return self.onnx_path

@@ -661,14 +666,20 @@ def compile(
skip_lang: Optional[bool] = False,
**compiler_options,
) -> str:
if any(param is not None for param in [full_batch_size, kv_cache_batch_size, num_speculative_tokens]):
if skip_lang and skip_vision:
raise ValueError("Expected at least one of 'skip_lang' or 'skip_vision' to be False")

if self.continuous_batching and full_batch_size is None:
raise TypeError("`full_batch_size` is required when `continuous_batching=True`.")

if kv_cache_batch_size and not full_batch_size:
raise ValueError(
f"Expected 'full_batch_size', 'kv_cache_batch_size', 'num_speculative_tokens' to be None but got: "
f"full_batch_size={full_batch_size}, kv_cache_batch_size={kv_cache_batch_size}, num_speculative_tokens={num_speculative_tokens}, "
"KV caching requires continuous batching. Please set `full_batch_size` and "
"enable `continuous_batching=True` in `from_pretrained`."
)

if skip_lang and skip_vision:
raise ValueError("Expected at least one of 'skip_lang' or 'skip_vision' to be False")
# Infer kv_cache_batch_size if not provided
kv_cache_batch_size = kv_cache_batch_size or full_batch_size or batch_size

output_names = self.model.get_output_names(kv_offload=True)

@@ -678,6 +689,9 @@
ctx_len=ctx_len,
img_size=img_size,
kv_offload=True,
continuous_batching=self.continuous_batching,
kv_cache_batch_size=kv_cache_batch_size,
full_batch_size=full_batch_size,
**compiler_options,
)
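Note on these hunks: `compile()` now requires `full_batch_size` whenever the model was loaded with `continuous_batching=True`, rejects `kv_cache_batch_size` without a `full_batch_size`, and otherwise infers the KV-cache batch as `kv_cache_batch_size or full_batch_size or batch_size`. A minimal sketch of that resolution order (the helper name is made up for illustration):

```python
def resolve_kv_cache_batch_size(batch_size=1, full_batch_size=None, kv_cache_batch_size=None):
    # Explicit value wins, then full_batch_size, then the regular batch_size.
    return kv_cache_batch_size or full_batch_size or batch_size

assert resolve_kv_cache_batch_size() == 1
assert resolve_kv_cache_batch_size(full_batch_size=4) == 4
assert resolve_kv_cache_batch_size(full_batch_size=4, kv_cache_batch_size=8) == 8
```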

@@ -746,6 +760,8 @@ def compile(
def generate(
self,
inputs: torch.Tensor,
tokenizer: Union[PreTrainedTokenizerFast, PreTrainedTokenizer] = None,
prompts: List[str] = None,
streamer: Optional[TextStreamer] = None,
device_ids: List[int] = None,
runtime_ai100: bool = True,
@@ -763,6 +779,14 @@
"""
if not runtime_ai100:
raise NotImplementedError("PyTorch execution is not supported yet for this model!")
if tokenizer and prompts:
return QEfficient.cloud_ai_100_exec_kv(
tokenizer,
self.lang_model.qpc_path,
prompt=prompts,
device_id=device_ids,
generation_len=generation_len,
)

return self.kv_offload_generate(
inputs=inputs, device_ids=device_ids, streamer=streamer, generation_len=generation_len
@@ -1304,15 +1328,21 @@ class QEFFAutoModelForImageTextToText:

_hf_auto_class = AutoModelForImageTextToText

def __new__(self, model: nn.Module, kv_offload: Optional[bool] = True, **kwargs):
def __new__(self, model: nn.Module, kv_offload: Optional[bool] = True, continuous_batching: bool = False, **kwargs):
if kv_offload:
return _QEffAutoModelForImageTextToTextDualQPC(model, **kwargs)
return _QEffAutoModelForImageTextToTextDualQPC(model, continuous_batching, **kwargs)
else:
return _QEFFAutoModelForImageTextToTextSingleQPC(model, **kwargs)

@classmethod
@with_replaced_quantizers
def from_pretrained(cls, pretrained_model_name_or_path: str, kv_offload: Optional[bool] = None, **kwargs):
def from_pretrained(
cls,
pretrained_model_name_or_path: str,
kv_offload: Optional[bool] = None,
continuous_batching: bool = False,
**kwargs,
):
"""Used to load models supported by transformers.AutoModelForImageTextToText for Cloud AI 100.

Args:
@@ -1329,12 +1359,18 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, kv_offload: Optiona
if kwargs.get("low_cpu_mem_usage", None):
logger.warning("Updating low_cpu_mem_usage=False")

if kwargs.pop("continuous_batching", None):
NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.")
if continuous_batching and not kv_offload:
raise NotImplementedError("Continuous batching is not supported with kv_offload=False.")

kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
return cls(model, kv_offload=kv_offload, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
return cls(
model,
kv_offload=kv_offload,
continuous_batching=continuous_batching,
pretrained_model_name_or_path=pretrained_model_name_or_path,
**kwargs,
)


MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP = {"InternVLChatModel": QEFFAutoModelForImageTextToText}
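End-to-end, the new flag is threaded from `from_pretrained` through export/compile to `generate`. A hedged usage sketch for this PR — the model id, batch sizes, and most compile arguments are placeholders based on the library's usual API, not values taken from this diff:

```python
from transformers import AutoTokenizer
from QEfficient import QEFFAutoModelForImageTextToText

model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Continuous batching is only wired up for the dual-QPC path (kv_offload=True).
model = QEFFAutoModelForImageTextToText.from_pretrained(
    model_id,
    kv_offload=True,
    continuous_batching=True,
)

model.compile(
    prefill_seq_len=128,
    ctx_len=4096,
    img_size=336,          # assumed vision input size
    full_batch_size=4,     # required when continuous_batching=True
    num_devices=4,
)

# Text-only prompts take the new cloud_ai_100_exec_kv path in generate().
outputs = model.generate(
    inputs=None,
    tokenizer=tokenizer,
    prompts=["Describe continuous batching in one sentence."] * 4,
)
```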