[kv_cache] integrated vlm code for benchmark (Stacked on #3527) #3652
@@ -0,0 +1,387 @@
""" | ||
.. _run_vlm: | ||
|
||
Running VLM inference with Torch-TensorRT | ||
========================================================== | ||
|
||
This script mirrors the style and structure of *run_llm.py*, illustrating a | ||
Torch-TensorRT (dynamo backend) workflow for Visual-Language Models (VLMs). | ||
""" | ||
|
||
import argparse | ||
import copy | ||
import os | ||
import sys | ||
from contextlib import nullcontext | ||
from typing import Tuple | ||
|
||
import requests | ||
import torch | ||
import torch_tensorrt | ||
from PIL import Image | ||
from torchtrt_ext import register_sdpa | ||
from transformers import AutoModel, AutoProcessor | ||
from utils import ( | ||
generate_mm, | ||
generate_mm_with_static_cache, | ||
record_stats, | ||
time_generate_mm, | ||
) | ||
|
||
# -----------------------------------------------------------------------------# | ||
# Global configuration | ||
# -----------------------------------------------------------------------------# | ||
DEVICE = torch.device("cuda:0") | ||
|
||
# Register SDPA as a standalone operator. Converter & lowering pass are defined | ||
sys.path.append(os.path.join(os.path.dirname(__file__), "..")) | ||
import transformers.models.qwen2.modeling_qwen2 as mq # noqa: E402 | ||
|
||
mq.ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = mq.ALL_ATTENTION_FUNCTIONS["sdpa"] | ||
|
||
# -----------------------------------------------------------------------------# | ||
# Model loading helpers | ||
# -----------------------------------------------------------------------------# | ||
|
||
|
||
def _load_eagle2(device: torch.device, torch_dtype: torch.dtype): | ||
""" | ||
Load Eagle2 model and processor. | ||
|
||
Returns | ||
------- | ||
tuple[torch.nn.Module, transformers.AutoProcessor, torch.nn.Embedding] | ||
The model, its processor and the language-model input embedding layer. | ||
""" | ||
model_id = "nvidia/Eagle2-2B" | ||
with torch.no_grad(): | ||
model = ( | ||
AutoModel.from_pretrained( | ||
model_id, trust_remote_code=True, torch_dtype=torch_dtype | ||
) | ||
.eval() | ||
.to(device) | ||
) | ||
|
||
processor = AutoProcessor.from_pretrained( | ||
model_id, trust_remote_code=True, use_fast=True | ||
) | ||
if hasattr(processor, "tokenizer"): | ||
processor.tokenizer.padding_side = "left" | ||
|
||
emb_layer = model.language_model.get_input_embeddings().to(torch_dtype).to(device) | ||
return model, processor, emb_layer | ||
|
||
|
||
def _load_model( | ||
model_name: str, device: torch.device, torch_dtype: torch.dtype | ||
) -> Tuple[torch.nn.Module, AutoProcessor, torch.nn.Embedding]: | ||
"""Dispatch helper for supported VLMs.""" | ||
if model_name.lower() == "eagle2": | ||
return _load_eagle2(device, torch_dtype) | ||
msg = f"Unsupported model: {model_name}" | ||
Review comment: Consider modifying the message to "Encountered model: {model_name} is not supported. Supported models include nvidia/Eagle2-2B, PaliGemma2" (and others, if they are supported).
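A possible shape for the suggested message, sketched with a hypothetical `SUPPORTED_MODELS` registry (not part of this PR) so the error text stays in sync with the dispatch logic; it reuses `_load_eagle2` from the file above:

```python
# Hypothetical registry: short name -> Hugging Face id (illustrative values only).
SUPPORTED_MODELS = {"eagle2": "nvidia/Eagle2-2B"}


def _load_model(model_name, device, torch_dtype):
    """Dispatch helper for supported VLMs."""
    if model_name.lower() == "eagle2":
        return _load_eagle2(device, torch_dtype)
    raise ValueError(
        f"Encountered model: {model_name} is not supported. "
        f"Supported models include: {', '.join(SUPPORTED_MODELS.values())}"
    )
```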
    raise ValueError(msg)


# -----------------------------------------------------------------------------#
# Torch-TensorRT compilation helpers
# -----------------------------------------------------------------------------#


class _LMNoCache(torch.nn.Module):
    """
    Thin wrapper that exposes a language model via ``inputs_embeds`` without KV-cache.
    """

    def __init__(self, lm):
        super().__init__()
        self.lm = lm

    def forward(self, inputs_embeds, position_ids):
        out = self.lm(inputs_embeds=inputs_embeds, position_ids=position_ids)
        return out.logits if hasattr(out, "logits") else out


def _compile_eagle2_lm(
    language_model: torch.nn.Module,
    input_embeds: torch.Tensor,
    args: argparse.Namespace,
) -> torch.nn.Module:
    """
    Compile the Eagle2 language model with Torch-TensorRT.

    The function follows the same precision-specific flag logic used in
    *run_llm.py* for consistency.
    """
    lm_wrap = _LMNoCache(language_model).to(DEVICE).eval()
    max_seq_len = input_embeds.shape[1] + args.num_tokens

    S = torch.export.Dim("seq", min=1, max=max_seq_len)
Review comment: Consider renaming S to something more descriptive.
    position_ids = torch.arange(input_embeds.shape[1]).unsqueeze(0).to(DEVICE)
    dyn_shapes = {"inputs_embeds": {1: S}, "position_ids": {1: S}}

    # Precision-specific flags --------------------------------------------------#
    use_fp32_acc = False
    use_explicit_typing = False
    if args.precision == "FP16":
        enabled_precisions = {torch.float32}
        use_fp32_acc = True
        use_explicit_typing = True
    elif args.precision == "BF16":
        enabled_precisions = {torch.bfloat16}
    else:  # FP32
        enabled_precisions = {torch.float32}

    with torch.inference_mode():
        exported = torch.export.export(
Review comment (minor): consider renaming exported.
            lm_wrap,
            (input_embeds, position_ids),
            dynamic_shapes=dyn_shapes,
            strict=False,
        )

    with torch_tensorrt.logging.debug() if args.debug else nullcontext():
        trt_mod = torch_tensorrt.dynamo.compile(
            exported,
            inputs=[input_embeds, position_ids],
            enabled_precisions=enabled_precisions,
            use_explicit_typing=use_explicit_typing,
            use_fp32_acc=use_fp32_acc,
            device=DEVICE,
            disable_tf32=True,
            use_python_runtime=True,
            debug=args.debug,
            offload_module_to_cpu=True,
            min_block_size=args.min_block_size,
        )
    return trt_mod


def compile_torchtrt(
    model: torch.nn.Module, args: argparse.Namespace
) -> torch.nn.Module:
    """
    Front-end dispatcher mirroring *run_llm.py*'s `compile_torchtrt`.

    Depending on the target VLM, delegates to the appropriate compile routine.
Review comment (on lines +164 to +166): Consider making this docstring more informative. It looks like this function only compiles the LLM part of the VLM. Please mention that.
    """
    torch_dtype = {
        "FP16": torch.float16,
        "BF16": torch.bfloat16,
    }.get(args.precision, torch.float32)

    example_embeds = torch.randn(
        1,
        2560,
Review comment: Consider declaring this in a map, e.g. image_tokens_map = {model_name: 2560}, or as a named constant such as EAGLE_IMG_TOKENS = 2560.
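A sketch of that suggestion, keyed by the CLI model name; the constant name and the per-model mapping are assumptions, only the Eagle2 value 2560 comes from the PR:

```python
# Assumed mapping: number of image/placeholder tokens per supported model.
IMAGE_TOKENS_MAP = {"eagle2": 2560}

example_embeds = torch.randn(
    1,
    IMAGE_TOKENS_MAP[args.model.lower()],
    model.language_model.config.hidden_size,
    dtype=torch_dtype,
    device=DEVICE,
)
```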
        model.language_model.config.hidden_size,
        dtype=torch_dtype,
        device=DEVICE,
    )

    if args.model.lower() == "eagle2":
        return _compile_eagle2_lm(model.language_model, example_embeds, args)

    msg = f"Unsupported model for compilation: {args.model}"
Review comment: Consider changing this message similarly to the comment above (by including the list of supported models).
    raise ValueError(msg)


# -----------------------------------------------------------------------------#
# Utility helpers
# -----------------------------------------------------------------------------#


def print_outputs(backend_name: str, gen_tokens: torch.Tensor, tokenizer):
    """Pretty-print generated text for comparison."""
    print(f"========= {backend_name} =========")
    print(
        f"{backend_name} model generated text: ",
        tokenizer.decode(gen_tokens[0], skip_special_tokens=True),
    )
    print("===================================")


# -----------------------------------------------------------------------------#
# Main driver
# -----------------------------------------------------------------------------#
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run VLM inference (PyTorch & TensorRT back-ends)"
    )
    parser.add_argument("--model", default="eagle2", help="VLM model name")
Review comment: Let the model names be exactly the same as the HF model names instead of using short forms, e.g. the default should be the full HF model id.
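The reviewer's example default is truncated in the thread; a sketch of the idea, assuming the full Hugging Face id used by the loader above (nvidia/Eagle2-2B) is what was meant:

```python
# Sketch: accept the full HF model id as the CLI value instead of a short form.
parser.add_argument(
    "--model",
    default="nvidia/Eagle2-2B",
    help="Hugging Face model id of the VLM (e.g. nvidia/Eagle2-2B)",
)
```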
    parser.add_argument("--prompt", default="Describe this image.", help="Prompt text")
    parser.add_argument(
        "--precision",
        default="FP16",
        choices=["FP16", "BF16", "FP32"],
        help="Computation precision",
    )
    parser.add_argument("--iterations", type=int, default=5, help="# iterations")
    parser.add_argument("--min_block_size", type=int, default=1, help="Min block size")
    parser.add_argument("--num_tokens", type=int, default=128, help="# new tokens")
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
    parser.add_argument("--isl", type=int, default=2048, help="Input seq length")
    parser.add_argument(
        "--enable_pytorch_run",
        action="store_true",
        help="Run the PyTorch baseline as well",
    )
    parser.add_argument(
        "--cache",
        default="",
        choices=["", "static_v1"],
        help="KV-cache variant to use",
    )
    parser.add_argument(
        "--debug", action="store_true", help="Enable Torch-TensorRT debug logs"
    )
    parser.add_argument(
        "--benchmark", action="store_true", help="Enable benchmarking mode"
    )

    args = parser.parse_args()

    # -------------------------------------------------------------------------#
    # 1. Model / processor / embeddings
    # -------------------------------------------------------------------------#
    dtype = {
        "FP16": torch.float16,
        "BF16": torch.bfloat16,
    }.get(args.precision, torch.float32)

    model, processor, emb_layer = _load_model(args.model, DEVICE, dtype)

    # -------------------------------------------------------------------------#
    # 2. Input construction (image + text prompt)
    # -------------------------------------------------------------------------#
    url = "https://cdn.pixabay.com/photo/2019/08/08/23/33/car-4393990_1280.jpg"
    image = Image.open(requests.get(url, stream=True).raw)
Review comment (on lines +256 to +257): Move this to a load-image function.
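A sketch of the suggested refactor with a hypothetical load_image helper (the name and timeout are illustrative):

```python
def load_image(url: str) -> Image.Image:
    """Fetch an image over HTTP and return it as a PIL image."""
    response = requests.get(url, stream=True, timeout=60)
    response.raise_for_status()
    return Image.open(response.raw)


image = load_image(
    "https://cdn.pixabay.com/photo/2019/08/08/23/33/car-4393990_1280.jpg"
)
```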
    if args.benchmark:
        prompt_len = args.isl - 1792 - 26
Review comment: What are 1792 and 26? Please specify them as variables and add comments indicating what they are.
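One way to address this, sketched with guessed meanings for the two constants (the PR does not say what 1792 and 26 are, so the names and comments below are assumptions for the author to confirm):

```python
# Assumed: tokens the processor emits for the (resized) benchmark image.
NUM_IMAGE_TOKENS = 1792
# Assumed: tokens added by the chat template and special tokens.
CHAT_TEMPLATE_OVERHEAD = 26

if args.benchmark:
    prompt_len = args.isl - NUM_IMAGE_TOKENS - CHAT_TEMPLATE_OVERHEAD
    prompt_txt = " ".join(["token"] * max(prompt_len, 0))
```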
        prompt_txt = " ".join(["token"] * max(prompt_len, 0))
    else:
        prompt_txt = args.prompt

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt_txt},
            ],
        }
    ]

    txt = [
        processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    ]
    img_in, vid_in = processor.process_vision_info(messages)
    inputs = processor(
        text=txt, images=img_in, videos=vid_in, return_tensors="pt", padding=True
    ).to(DEVICE)

    max_output_len = inputs["input_ids"].shape[1] + args.num_tokens

    # -------------------------------------------------------------------------#
    # 3. Optional: PyTorch baseline
    # -------------------------------------------------------------------------#
    pyt_gen_tokens = pyt_timings = pyt_stats = None
    if args.enable_pytorch_run:
        pyt_gen_tokens = generate_mm(
            model,
            inputs["pixel_values"],
            inputs["input_ids"],
            max_output_len,
            processor.tokenizer.eos_token_id,
            emb_layer,
        )
        if args.benchmark:
            pyt_timings = time_generate_mm(
                generate_mm,
                model,
                inputs["pixel_values"].clone(),
                inputs["input_ids"].clone(),
                max_output_len,
                processor.tokenizer.eos_token_id,
                emb_layer,
                iterations=args.iterations,
            )
            pyt_stats = record_stats(
                "PyTorch",
                pyt_timings,
                args.precision,
                batch_size=args.batch_size,
                compile_time_s=None,
            )

    # Register static cache lowering passes if requested
    if args.cache == "static_v1":
Review comment: Is static_v2 not working?
        import static_cache_v1  # noqa: F401

    # -------------------------------------------------------------------------#
    # 4. Torch-TensorRT compile & run
    # -------------------------------------------------------------------------#
    trt_lm = compile_torchtrt(model, args)
    trt_model = copy.deepcopy(model)
    trt_model.language_model = trt_lm

    emb_layer = emb_layer.to(DEVICE)

    if args.cache == "static_v1":
        trt_generate = generate_mm_with_static_cache
    else:
        trt_generate = generate_mm

    trt_gen_tokens = trt_generate(
        trt_model,
        inputs["pixel_values"],
        inputs["input_ids"],
        max_output_len,
        processor.tokenizer.eos_token_id,
        emb_layer,
        DEVICE if args.cache == "static_v1" else None,  # device arg only for static_v1
    )

    if args.benchmark:
        trt_timings = time_generate_mm(
            trt_generate,
            trt_model,
            inputs["pixel_values"].clone(),
            inputs["input_ids"].clone(),
            max_output_len,
            processor.tokenizer.eos_token_id,
            emb_layer,
            iterations=args.iterations,
            device=DEVICE if args.cache == "static_v1" else None,
        )
        trt_stats = record_stats(
            "TensorRT",
            trt_timings,
            args.precision,
            batch_size=args.batch_size,
            compile_time_s=None,
        )

    # -------------------------------------------------------------------------#
    # 5. Reporting
    # -------------------------------------------------------------------------#
    if not args.benchmark:
        if args.enable_pytorch_run:
            print_outputs("PyTorch", pyt_gen_tokens, processor.tokenizer)
        print_outputs("TensorRT", trt_gen_tokens, processor.tokenizer)

        if args.enable_pytorch_run:
            print(
                f"PyTorch and TensorRT outputs match: "
                f"{torch.equal(pyt_gen_tokens, trt_gen_tokens)}"
            )

    if args.benchmark:
        if args.enable_pytorch_run:
            print("========= PyTorch PERFORMANCE =========\n")
            print(pyt_stats)
            print("=====================\n")
        print("========= TensorRT PERFORMANCE =========\n")
        print(trt_stats)
Review comment: Can you try to modify the model config after it is initialized, instead of modifying the entries?
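Assuming this refers to the ALL_ATTENTION_FUNCTIONS patch near the top of the file, one possible alternative (a sketch, not verified against Eagle2's remote code) is to request SDPA attention through the standard Transformers loading path rather than rewriting the registry entry:

```python
# Sketch: ask Transformers for SDPA attention at load time instead of
# patching mq.ALL_ATTENTION_FUNCTIONS["flash_attention_2"].
model = (
    AutoModel.from_pretrained(
        "nvidia/Eagle2-2B",
        trust_remote_code=True,
        torch_dtype=torch.float16,
        attn_implementation="sdpa",  # supported by recent transformers releases
    )
    .eval()
    .to("cuda:0")
)
```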