Skip to content

Commit 109e5c2

Browse files
committed
addressing review comments, adding example to docs, adding docstring, restructuring test and adding function for real-imag extract by removing hardcode
1 parent a90f651 commit 109e5c2

File tree

6 files changed

+127
-35
lines changed

6 files changed

+127
-35
lines changed

examples/distributed_inference/rotary_embedding.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,14 @@
1+
"""
2+
.. _rotary_embedding:
3+
4+
Rotary Embedding Implementation for Tensor Parallel Attention
5+
============================================================
6+
7+
This module provides an implementation of rotary positional embeddings (RoPE) for transformer models
8+
with support for tensor parallel distributed inference. Rotary embeddings are used to encode positional
9+
information in transformer attention mechanisms.
10+
"""
11+
112
import time
213

314
import tensorrt as trt
@@ -49,7 +60,7 @@ def rotary_embedding(xq, xk, dim, freqs_cis=None):
4960
Returns:
5061
tuple: Tuple containing the rotated query and key tensors.
5162
"""
52-
63+
freqs_cis = freqs_cis[None, :, None, :]
5364
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
5465
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
5566

examples/distributed_inference/tensor_parallel_initialize_dist.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
1+
"""
2+
.. _tensor_parallel_initialize_dist:
3+
Tensor Parallel Initialize Distributed Environment
4+
==================================================
5+
6+
This module provides functions to initialize and clean up the distributed environment for tensor parallel distributed inference.
7+
"""
8+
19
import logging
210
import os
311
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

examples/distributed_inference/tensor_parallel_rotary_embedding.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,14 @@
1+
"""
2+
.. _tensor_parallel_rotary_embedding:
3+
Tensor Parallel Rotary Embedding Example
4+
=======================================
5+
6+
This example demonstrates how to use Torch-TensorRT with tensor parallel distributed inference
7+
for models that use rotary positional embeddings (RoPE). It lowers the complex
8+
operations in attention models with rotary embeddings across multiple GPUs.
9+
10+
"""
11+
112
import logging
213
import os
314
import time
@@ -17,6 +28,7 @@
1728

1829
"""
1930
This example covers the rotary embedding in Llama3 model and is derived from https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning
31+
Command to run with single GPU: mpirun -n 1 --allow-run-as-root python tensor_parallel_rotary_embedding.pyx
2032
"""
2133

2234
BATCH = 2
@@ -35,7 +47,7 @@
3547

3648
logger.info("Torch-tensorrt compilation for rotary embedding")
3749

38-
model = torch.compile(model, backend="torch_tensorrt", options={"debug": True})
50+
model = torch.compile(model, backend="torch_tensorrt")
3951

4052
try:
4153
for i in range(15):

examples/distributed_inference/tensor_parallel_simple_example.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,24 @@
1+
"""
2+
.. _tensor_parallel_simple_example:
3+
4+
Torch Parallel Distributed example for simple model
5+
=========================================
6+
7+
Below example shows how to use Torch-TensorRT backend for distributed inference with tensor parallelism.
8+
9+
This example demonstrates:
10+
- Setting up distributed environment for tensor parallelism
11+
- Model sharding across multiple GPUs
12+
- Compilation with Torch-TensorRT
13+
- Distributed inference execution
14+
15+
Usage
16+
-----
17+
.. code-block:: bash
18+
19+
mpirun -n 2 --allow-run-as-root python tensor_parallel_simple_example.py
20+
"""
21+
122
import time
223

324
import tensorrt as trt
@@ -21,7 +42,7 @@
2142
)
2243

2344
"""
24-
This example copies some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py
45+
This example takes some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py
2546
"""
2647

2748

py/torch_tensorrt/dynamo/lowering/passes/complex_graph_rewrite.py

Lines changed: 66 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
import logging
2-
import operator
3-
from typing import Callable, List, Optional, Set, Tuple
2+
from typing import Callable, List, Set, Tuple
43

54
import torch
65
from torch._subclasses.fake_tensor import FakeTensorMode
76
from torch.fx import GraphModule, Node
8-
from torch.fx.subgraph_rewriter import Match
7+
from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
98
from torch_tensorrt.dynamo._settings import CompilationSettings
109
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
1110
clean_up_graph_after_modifications,
@@ -25,7 +24,7 @@ def __init__(
2524
self.subgraph_nodes = subgraph_nodes
2625
self.input_nodes = input_nodes
2726

28-
def __repr__(self):
27+
def __repr__(self) -> str:
2928
return (
3029
f"ComplexOpSubGraphInfo(anchor_nodes={[n.name for n in self.anchor_nodes]}, "
3130
f"subgraph={[n.name for n in self.subgraph_nodes]}, "
@@ -34,7 +33,7 @@ def __repr__(self):
3433

3534

3635
class ComplexOpDetector:
37-
def __init__(self):
36+
def __init__(self) -> None:
3837
pass
3938

4039
def is_complex_dtype(self, node: Node) -> bool:
@@ -106,16 +105,18 @@ def find_complex_op_subgraphs(
106105

107106

108107
class ComplexGraphRewriter:
109-
def __init__(self, gm: GraphModule, truncate_double: bool = False):
108+
def __init__(self, gm: GraphModule, truncate_double: bool = False) -> None:
110109
self.gm = gm
111110
self.truncate_double = truncate_double
112111

113-
def extract_shape_dtype_device(self, input_node):
112+
def extract_shape_dtype_device(
113+
self, input_node: Node
114+
) -> Tuple[Tuple[int, ...], torch.dtype, torch.device]:
114115
if input_node.op == "placeholder":
115116
tensor_val = input_node.meta["val"]
116117

117118
elif input_node.op == "get_attr":
118-
tensor_val = self.get_attr_tensor(input_node.target)
119+
tensor_val = self.get_attr_tensor(input_node.target) # type: ignore
119120

120121
else:
121122
raise ValueError(f"Unsupported node type: {input_node.op}")
@@ -134,7 +135,7 @@ def extract_shape_dtype_device(self, input_node):
134135

135136
return new_node_shape, new_node_dtype, device
136137

137-
def get_attr_tensor(self, target):
138+
def get_attr_tensor(self, target): # type: ignore
138139
# Check if target is param or buffer
139140
if target in dict(self.gm.named_parameters()):
140141
return self.gm.get_parameter(target)
@@ -145,7 +146,7 @@ def get_attr_tensor(self, target):
145146
f"Attribute {target} not found in gm parameters or buffers."
146147
)
147148

148-
def replace_input_node(self, input_node):
149+
def replace_input_node(self, input_node: Node) -> None:
149150
modified = False
150151
logger.debug(f"Replacing input node: {input_node.name}")
151152
new_shape, new_dtype, device = self.extract_shape_dtype_device(input_node)
@@ -160,10 +161,8 @@ def replace_input_node(self, input_node):
160161

161162
elif input_node.op == "get_attr":
162163
new_attr_name = input_node.target + "_reshaped"
163-
from torch._subclasses.fake_tensor import unset_fake_temporarily
164-
165164
with unset_fake_temporarily():
166-
original_tensor = self.get_attr_tensor(input_node.target)
165+
original_tensor = self.get_attr_tensor(input_node.target) # type: ignore
167166
stacked_tensor = torch.stack(
168167
[original_tensor.real, original_tensor.imag], dim=-1
169168
)
@@ -181,7 +180,7 @@ def replace_input_node(self, input_node):
181180
self.gm.graph.erase_node(input_node)
182181
clean_up_graph_after_modifications(self.gm)
183182

184-
def rewrite_subgraph_nodes(self, subgraphs):
183+
def rewrite_subgraph_nodes(self, subgraphs: List[ComplexSubGraphInfo]) -> None:
185184
modified = False
186185
for subgraph in subgraphs:
187186
for input_node in subgraph.input_nodes:
@@ -196,11 +195,20 @@ def rewrite_subgraph_nodes(self, subgraphs):
196195
elif node.target == torch.ops.aten.mul.Tensor:
197196
# this is complex mul where inputs = a+ib and output = c+id.
198197
# complex mul returns (ac - bd) + (ad + bc)i
199-
# which is then view_as_real as (ac-bd), ad+bc stacked along the last dimension with last dimension size 2
198+
# which is then view_as_real as (ac-bd), (ad+bc) stacked along the last dimension with last dimension size 2
199+
x_placeholder_or_func = (
200+
True if node.args[0].op != "get_attr" else False
201+
)
202+
y_placeholder_or_func = (
203+
True if node.args[1].op != "get_attr" else False
204+
)
205+
200206
replaced_nodes = []
201-
original_mul, replacement = complex_mul_replacement()
207+
original_mul, replacement = complex_mul_replacement(
208+
x_placeholder_or_func, y_placeholder_or_func
209+
)
202210

203-
def match_complex_mul(
211+
def match_complex_mul( # type: ignore[no-untyped-def]
204212
match: torch.fx.subgraph_rewriter.Match,
205213
original_graph,
206214
pattern_graph,
@@ -233,7 +241,7 @@ def match_complex_mul(
233241
self.gm.graph.lint()
234242
self.gm.recompile()
235243

236-
def propagate_metadata(self):
244+
def propagate_metadata(self) -> None:
237245
fake_inputs = []
238246
from torch._subclasses.fake_tensor import FakeTensorMode
239247
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
@@ -260,7 +268,34 @@ def propagate_metadata(self):
260268
).propagate(*fake_inputs)
261269

262270

263-
def complex_mul_replacement() -> Tuple[
271+
def extract_real_imag(input, placeholder_or_func: bool = True): # type: ignore
272+
"""Extract real and imaginary parts from a tensor.
273+
This function handles different tensor types based on whether they are placeholder/function
274+
tensors or get_attr tensors. For placeholder/function tensors, it uses select operations,
275+
while for get_attr tensors, it uses indexing.
276+
Args:
277+
input: Input tensor to extract real and imaginary parts from
278+
placeholder_or_func: Boolean flag indicating if the input is a placeholder/function tensor (True)
279+
or a get_attr tensor (False). Defaults to True.
280+
Returns:
281+
Tuple of (real_part, imaginary_part) where both parts have the same type as the input
282+
Note:
283+
- When placeholder_or_func=True: Uses torch.ops.aten.select.int operations
284+
- When placeholder_or_func=False: Uses tensor indexing [..., 0] and [..., 1]
285+
"""
286+
if placeholder_or_func:
287+
# For ITensor, use select operations
288+
real_part = torch.ops.aten.select.int(input, -1, 0)
289+
imag_part = torch.ops.aten.select.int(input, -1, 1)
290+
return real_part, imag_part
291+
else:
292+
# For get_attr, use indexing
293+
return input[..., 0], input[..., 1]
294+
295+
296+
def complex_mul_replacement(
297+
x_placeholder_or_func: bool = True, y_placeholder_or_func: bool = True
298+
) -> Tuple[
264299
Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
265300
Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
266301
]:
@@ -280,9 +315,8 @@ def original_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
280315

281316
# Replacement function: manual complex multiplication on real/imag stacked tensors
282317
def replacement(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
283-
x_real = torch.ops.aten.select.int(x, -1, 0)
284-
x_imag = torch.ops.aten.select.int(x, -1, 1) # x is reshape tensor
285-
y_real, y_imag = y[..., 0], y[..., 1] # y is frozen param
318+
x_real, x_imag = extract_real_imag(x, x_placeholder_or_func)
319+
y_real, y_imag = extract_real_imag(y, y_placeholder_or_func)
286320

287321
real_part1 = torch.ops.aten.mul.Tensor(x_real, y_real)
288322
real_part2 = torch.ops.aten.mul.Tensor(x_imag, y_imag)
@@ -304,10 +338,18 @@ def replacement(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
304338

305339

306340
# This lowering pass is used to detect and rewrite complex subgraphs in the graph
307-
# This lowering pass works for complex tensor in mul which are parameter or buffers in the graph
308341
def complex_graph_detection(
309342
gm: GraphModule, settings: CompilationSettings
310-
) -> List[ComplexSubGraphInfo]:
343+
) -> GraphModule:
344+
"""Detect and rewrite complex subgraphs in the graph.
345+
This lowering pass is used to detect and rewrite complex subgraphs in the graph.
346+
This lowering pass works for complex tensor in mul which are parameter or buffers in the graph.
347+
Args:
348+
gm: The GraphModule to process
349+
settings: Compilation settings
350+
Returns:
351+
The modified GraphModule with complex subgraphs rewritten
352+
"""
311353
complex_op_detector = ComplexOpDetector()
312354
complex_subgraphs = complex_op_detector.find_complex_op_subgraphs(
313355
gm, anchor_target=torch.ops.aten.view_as_real.default

tests/py/dynamo/lowering/test_aten_lowering_passes.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -237,12 +237,6 @@ def forward(self, input, mat1, mat2):
237237
torch._dynamo.reset()
238238

239239

240-
def rotary_embedding(x, dim, freqs_cis=None):
241-
x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
242-
x_out_flatten = torch.view_as_real(x_ * freqs_cis)
243-
return x_out_flatten.type_as(x)
244-
245-
246240
class TestComplexSubgraph(TestCase):
247241
def test_complex_subgraph(self):
248242
BATCH = 1
@@ -263,6 +257,11 @@ def __init__(self):
263257
persistent=True,
264258
)
265259

260+
def rotary_embedding(self, x, dim, freqs_cis=None):
261+
x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
262+
x_out_flatten = torch.view_as_real(x_ * freqs_cis)
263+
return x_out_flatten.type_as(x)
264+
266265
def _freqs_ex_tensor(self):
267266
real = torch.tensor([[[[1.0000]], [[2.0000]]]], device="cuda")
268267
imag = torch.tensor([[[[0.0000]], [[3.0000]]]], device="cuda")
@@ -273,14 +272,13 @@ def _freqs_ex_tensor(self):
273272
def forward(self, x):
274273
q = self.wq(x)
275274
freqs_cis = self._freqs_ex_tensor().to(q.device)
276-
q_out = rotary_embedding(q, self.dim, freqs_cis=freqs_cis)
275+
q_out = self.rotary_embedding(q, self.dim, freqs_cis=freqs_cis)
277276
return q_out
278277

279278
inputs = [torch.randn(BATCH, SEQ_LEN, HEADS, DIM).cuda()]
280279
model = RotaryAttention()
281280
model = model.cuda()
282281

283-
fx_graph = torch.fx.symbolic_trace(RotaryAttention().cuda())
284282
expected_ops = {torch.ops.aten.mul.Tensor}
285283
unexpected_ops = {
286284
torch.ops.aten.view_as_complex.default,

0 commit comments

Comments
 (0)