
feat: Add TensorRT Edge-LLM AttentionPlugin backend support #4108

Open
chohk88 wants to merge 1 commit into main from attn-plugin-workflow

Conversation


@chohk88 chohk88 commented Mar 3, 2026

Add plugin backend as an alternative to the default SDPA lowering for LLM inference, providing ~1.7x-3.3x speedup over SDPA and ~8x-11x over PyTorch eager execution.

Supported Models: Llama 3.x (3.1/3.2), Qwen 2.5, Qwen 3

Changes:

  • examples/dynamo/attention_plugin_example.py: Standalone plugin demo with correctness validation against PyTorch SDPA
  • examples/dynamo/end_to_end_llm_generation_example.py: End-to-end LLM generation example with plugin integration and benchmarks
  • tools/llm/plugin_utils.py: Model-agnostic plugin utilities including op registration (tensorrt_edge_llm::xqa_attn), TensorRT converter, PluginAttention module, LLMPluginWrapper, compilation and generation
  • tools/llm/run_llm.py: Add --backend plugin/sdpa selection with plugin workflow integration
  • tools/llm/README.md: Plugin backend documentation with build guide, usage examples, and performance summary

Plugin library built from TensorRT-Edge-LLM 0.4.0: https://github.com/chohk88/TensorRT-Edge-LLM/tree/feature/torch-tensorrt-python-runtime

Description

Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change.

Fixes # (issue)

Type of change

Please delete options that are not relevant and/or add your own.

  • New feature (non-breaking change which adds functionality)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that relevant reviewers are notified

@chohk88 chohk88 requested review from narendasan and zewenli98 March 3, 2026 13:54
@chohk88 chohk88 self-assigned this Mar 3, 2026
@meta-cla meta-cla bot added the cla signed label Mar 3, 2026
- **Source build (slow)**: `pip install flash-attn --no-build-isolation -v` (fallback if pre-built wheels fail)

Make sure to add `MAX_JOBS=8`, otherwise you might take people's systems down.
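For example (a sketch of the suggested invocation; `8` is just a reasonable cap, tune it per machine):

```shell
# Cap parallel compile jobs so a flash-attn source build does not
# saturate every core and exhaust memory on shared machines.
MAX_JOBS=8 pip install flash-attn --no-build-isolation -v
```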


@narendasan narendasan left a comment


Overall, I think it's close. @zewenli98 should take a pass, but we can merge nearly as is. Next, I want to think about how we might create lowering passes that insert the placeholder ops programmatically. Evan is about to disable decomposition by default for SDPA, so we can basically dynamically insert a pass that keys on those ops.
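A lowering pass that keys on SDPA nodes could look roughly like this — a minimal FX sketch, not this PR's implementation; `xqa_attn_placeholder` is a hypothetical stand-in for the registered plugin op:

```python
import torch
import torch.fx as fx
import torch.nn.functional as F


# Hypothetical stand-in for the tensorrt_edge_llm::xqa_attn plugin op.
def xqa_attn_placeholder(q, k, v):
    # Reference semantics: plain SDPA, so the rewrite preserves numerics.
    return F.scaled_dot_product_attention(q, k, v)


def insert_plugin_ops(gm: fx.GraphModule) -> fx.GraphModule:
    """Key on SDPA call_function nodes and swap in the placeholder op."""
    for node in list(gm.graph.nodes):
        if node.op == "call_function" and node.target is F.scaled_dot_product_attention:
            with gm.graph.inserting_after(node):
                repl = gm.graph.call_function(
                    xqa_attn_placeholder, node.args, node.kwargs
                )
            node.replace_all_uses_with(repl)
            gm.graph.erase_node(node)
    gm.recompile()
    return gm


def attn(q, k, v):
    return F.scaled_dot_product_attention(q, k, v)


gm = insert_plugin_ops(fx.symbolic_trace(attn))
q = k = v = torch.randn(2, 4, 8, 16)
out = gm(q, k, v)
```

With decomposition disabled, the un-decomposed SDPA nodes survive into the graph, so a pass like this can retarget them before the converter runs.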

trt_timings.append(elapsed_ms / 1000.0)
else:
# SDPA backend (default)
if args.cache == "static_v1":

We have a few threads here: backend, cache, and, with @zewenli98's PR, core Attention. Can we merge these settings so it's easy to understand when you get TRT-Edge-LLM, when you get native IAttention, and when you get the static KV cache?


@chohk88 I implemented the converters for some attention variants in #4104. Can you take a look at how to integrate?

# -----------------------------------------------------------------------------


@dynamo_tensorrt_converter(

Let's put all the converters for our edgellm ops in their own file.


@narendasan narendasan left a comment


Do we have lowering passes to insert the tensorrt edge llm ops in place of pytorch ops?

Comment on lines +697 to +698
enabled_precisions={torch.float32},
use_explicit_typing=True,

When `use_explicit_typing` is true, `enabled_precisions` should be removed.
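In other words, something along these lines — a sketch assuming surrounding variables `exported_program` and `inputs`, not the exact call from this PR:

```python
# With strong typing enabled, layer precisions follow the graph's dtypes,
# so enabled_precisions is dropped rather than passed alongside it.
trt_module = torch_tensorrt.dynamo.compile(
    exported_program,
    inputs=inputs,
    use_explicit_typing=True,
    # enabled_precisions={torch.float32},  # removed: redundant with explicit typing
)
```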

device=device,
disable_tf32=True,
min_block_size=1,
debug=debug,

`debug` is deprecated. Please use `with torch_tensorrt.dynamo.Debugger(...)` instead.

@@ -7,7 +7,9 @@ This directory provides utilities and scripts for compiling, optimizing, and ben
- **Model Support:** Works with popular LLMs such as Llama-3, Qwen2.5, etc.
- **VLM Support:** Supports Visual Language Models like Qwen2.5-VL and Eagle2.
- **Precision Modes:** Supports FP16, BF16, and FP32.
- **Quantization:** Supports FP8 and NVFP4 quantization formats for reduced memory usage and improved inference speed.

should we keep quant?



3 participants