Skip to content

[kv_cache] integrated vlm code for benchmark (Stacked on #3527) #3652

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: kv_cache
Choose a base branch
from

Conversation

chohk88
Copy link
Collaborator

@chohk88 chohk88 commented Jul 3, 2025

Description

Base branch: kv_cache (PR #3527 )

  1. Integrated VLM benchmark framework
    • Currently supports Eagle2
    • Planned support: Paligemma, Qwen 2.5-VL, etc.
  2. Added custom token-generation function** for multi-modal (MM) models

Type of change

Please delete options that are not relevant and/or add your own.

  • New feature (non-breaking change which adds functionality)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR in so that relevant reviewers are notified

@chohk88 chohk88 requested a review from peri044 July 3, 2025 01:46
@chohk88 chohk88 self-assigned this Jul 3, 2025
Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/tools/llm/utils.py	2025-07-03 01:46:24.189295+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/llm/utils.py	2025-07-03 01:46:57.661981+00:00
@@ -318,14 +318,16 @@
    generated = 0

    while generated < osl:
        cur_embeds = seq_embeds  # full seq first step or cache off
        position_ids = (
-                torch.arange(cur_embeds.shape[1]).unsqueeze(0).to(cur_embeds.device)
-            )
+            torch.arange(cur_embeds.shape[1]).unsqueeze(0).to(cur_embeds.device)
+        )
        with torch.no_grad():
-            logits = model.language_model(inputs_embeds=cur_embeds, position_ids=position_ids)
+            logits = model.language_model(
+                inputs_embeds=cur_embeds, position_ids=position_ids
+            )
            if hasattr(logits, "logits"):
                logits = logits.logits

        next_tok = torch.argmax(logits[:, -1, :], dim=-1)  # (B,)
        # append token & embed
@@ -381,13 +383,11 @@
        mask = seq_tokens.view(B * N) == model.image_token_index
        flat[mask] = vit_embeds.reshape(-1, C).to(flat.dtype)[: mask.sum()]
        seq_embeds = flat.view(B, N, C)

    # ───────────────────── KV-cache initialization ─────────────────────
-    kv_cache = get_zeroed_static_cache_inputs(
-        model.language_model
-    )
+    kv_cache = get_zeroed_static_cache_inputs(model.language_model)
    start_idx = 0  # First token index
    end_idx = seq_embeds.size(1)  # Prompt length
    generated = 0
    max_total_len = max_output_seq_length
    output_tokens = seq_tokens.clone()
@@ -607,13 +607,11 @@
        mask = seq_tokens.view(B * N) == model.image_token_index
        flat[mask] = vit_embeds.reshape(-1, C).to(flat.dtype)[: mask.sum()]
        seq_embeds = flat.view(B, N, C)

    # ───────────────────── KV-cache initialization ─────────────────────
-    kv_cache = get_zeroed_static_cache_inputs(
-        model.language_model
-    )
+    kv_cache = get_zeroed_static_cache_inputs(model.language_model)
    start_idx = 0  # First token index
    end_idx = seq_embeds.size(1)  # Prompt length
    generated = 0
    max_total_len = end_idx + max_new_tokens
    output_tokens = seq_tokens.clone()

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/tools/llm/utils.py	2025-07-03 01:46:23.684507+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/llm/utils.py	2025-07-03 01:46:57.646835+00:00
@@ -318,14 +318,16 @@
    generated = 0

    while generated < osl:
        cur_embeds = seq_embeds  # full seq first step or cache off
        position_ids = (
-                torch.arange(cur_embeds.shape[1]).unsqueeze(0).to(cur_embeds.device)
-            )
+            torch.arange(cur_embeds.shape[1]).unsqueeze(0).to(cur_embeds.device)
+        )
        with torch.no_grad():
-            logits = model.language_model(inputs_embeds=cur_embeds, position_ids=position_ids)
+            logits = model.language_model(
+                inputs_embeds=cur_embeds, position_ids=position_ids
+            )
            if hasattr(logits, "logits"):
                logits = logits.logits

        next_tok = torch.argmax(logits[:, -1, :], dim=-1)  # (B,)
        # append token & embed
@@ -381,13 +383,11 @@
        mask = seq_tokens.view(B * N) == model.image_token_index
        flat[mask] = vit_embeds.reshape(-1, C).to(flat.dtype)[: mask.sum()]
        seq_embeds = flat.view(B, N, C)

    # ───────────────────── KV-cache initialization ─────────────────────
-    kv_cache = get_zeroed_static_cache_inputs(
-        model.language_model
-    )
+    kv_cache = get_zeroed_static_cache_inputs(model.language_model)
    start_idx = 0  # First token index
    end_idx = seq_embeds.size(1)  # Prompt length
    generated = 0
    max_total_len = max_output_seq_length
    output_tokens = seq_tokens.clone()
@@ -607,13 +607,11 @@
        mask = seq_tokens.view(B * N) == model.image_token_index
        flat[mask] = vit_embeds.reshape(-1, C).to(flat.dtype)[: mask.sum()]
        seq_embeds = flat.view(B, N, C)

    # ───────────────────── KV-cache initialization ─────────────────────
-    kv_cache = get_zeroed_static_cache_inputs(
-        model.language_model
-    )
+    kv_cache = get_zeroed_static_cache_inputs(model.language_model)
    start_idx = 0  # First token index
    end_idx = seq_embeds.size(1)  # Prompt length
    generated = 0
    max_total_len = end_idx + max_new_tokens
    output_tokens = seq_tokens.clone()

Copy link
Collaborator

@peri044 peri044 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added initial set of review comments. I'll try to run this example and add more comments later. Where is the vision model being compiled here ?

sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
import transformers.models.qwen2.modeling_qwen2 as mq # noqa: E402

mq.ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = mq.ALL_ATTENTION_FUNCTIONS["sdpa"]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you try to modify the model config after it is initialized instead of modifying the entries ?

model.config._attn_implementation = "sdpa"

"""Dispatch helper for supported VLMs."""
if model_name.lower() == "eagle2":
return _load_eagle2(device, torch_dtype)
msg = f"Unsupported model: {model_name}"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider modifying the message to Encountered model: {model_name} is not supported. Supported models include nvidia/Eagle2-2B, PaliGemma2 (and others if they are supported)

lm_wrap = _LMNoCache(language_model).to(DEVICE).eval()
max_seq_len = input_embeds.shape[1] + args.num_tokens

S = torch.export.Dim("seq", min=1, max=max_seq_len)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider renaming S to seq_len

enabled_precisions = {torch.float32}

with torch.inference_mode():
exported = torch.export.export(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minor comment: consider renaming exported to exported_program here


example_embeds = torch.randn(
1,
2560,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider declaring this in a map eg: image_tokens_map = {model_name: 2560}EAGLE_IMG_TOKENS = 2560.

Comment on lines +164 to +166
Front-end dispatcher mirroring *run_llm.py*’s `compile_torchtrt`.

Depending on the target VLM, delegates to the appropriate compile routine.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider making this docstring more informative. It looks like this function only compiles the LLM part of the VLM. Please mention that.

parser = argparse.ArgumentParser(
description="Run VLM inference (PyTorch & TensorRT back-ends)"
)
parser.add_argument("--model", default="eagle2", help="VLM model name")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let the model names be exactly same as HF model names instead of using short forms. Eg: the default should be nvidia/Eagle2-2B

Comment on lines +256 to +257
url = "https://cdn.pixabay.com/photo/2019/08/08/23/33/car-4393990_1280.jpg"
image = Image.open(requests.get(url, stream=True).raw)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

move this to a load image function

image = Image.open(requests.get(url, stream=True).raw)

if args.benchmark:
prompt_len = args.isl - 1792 - 26
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is 1792 and 26 ? Please specify them as variables and add comments indicating what they are

)

# Register static cache lowering passes if requested
if args.cache == "static_v1":
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is static_v2 not working ?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants