1
1
import logging
2
- import operator
3
- from typing import Callable , List , Optional , Set , Tuple
2
+ from typing import Callable , List , Set , Tuple
4
3
5
4
import torch
6
5
from torch ._subclasses .fake_tensor import FakeTensorMode
7
6
from torch .fx import GraphModule , Node
8
- from torch .fx .subgraph_rewriter import Match
7
+ from torch .fx .experimental . proxy_tensor import unset_fake_temporarily
9
8
from torch_tensorrt .dynamo ._settings import CompilationSettings
10
9
from torch_tensorrt .dynamo .lowering .passes .pass_utils import (
11
10
clean_up_graph_after_modifications ,
@@ -25,7 +24,7 @@ def __init__(
25
24
self .subgraph_nodes = subgraph_nodes
26
25
self .input_nodes = input_nodes
27
26
28
- def __repr__ (self ):
27
+ def __repr__ (self ) -> str :
29
28
return (
30
29
f"ComplexOpSubGraphInfo(anchor_nodes={ [n .name for n in self .anchor_nodes ]} , "
31
30
f"subgraph={ [n .name for n in self .subgraph_nodes ]} , "
@@ -34,7 +33,7 @@ def __repr__(self):
34
33
35
34
36
35
class ComplexOpDetector :
37
- def __init__ (self ):
36
+ def __init__ (self ) -> None :
38
37
pass
39
38
40
39
def is_complex_dtype (self , node : Node ) -> bool :
@@ -106,16 +105,18 @@ def find_complex_op_subgraphs(
106
105
107
106
108
107
class ComplexGraphRewriter :
109
- def __init__ (self , gm : GraphModule , truncate_double : bool = False ):
108
+ def __init__ (self , gm : GraphModule , truncate_double : bool = False ) -> None :
110
109
self .gm = gm
111
110
self .truncate_double = truncate_double
112
111
113
- def extract_shape_dtype_device (self , input_node ):
112
+ def extract_shape_dtype_device (
113
+ self , input_node : Node
114
+ ) -> Tuple [Tuple [int , ...], torch .dtype , torch .device ]:
114
115
if input_node .op == "placeholder" :
115
116
tensor_val = input_node .meta ["val" ]
116
117
117
118
elif input_node .op == "get_attr" :
118
- tensor_val = self .get_attr_tensor (input_node .target )
119
+ tensor_val = self .get_attr_tensor (input_node .target ) # type: ignore
119
120
120
121
else :
121
122
raise ValueError (f"Unsupported node type: { input_node .op } " )
@@ -134,7 +135,7 @@ def extract_shape_dtype_device(self, input_node):
134
135
135
136
return new_node_shape , new_node_dtype , device
136
137
137
- def get_attr_tensor (self , target ):
138
+ def get_attr_tensor (self , target ): # type: ignore
138
139
# Check if target is param or buffer
139
140
if target in dict (self .gm .named_parameters ()):
140
141
return self .gm .get_parameter (target )
@@ -145,7 +146,7 @@ def get_attr_tensor(self, target):
145
146
f"Attribute { target } not found in gm parameters or buffers."
146
147
)
147
148
148
- def replace_input_node (self , input_node ) :
149
+ def replace_input_node (self , input_node : Node ) -> None :
149
150
modified = False
150
151
logger .debug (f"Replacing input node: { input_node .name } " )
151
152
new_shape , new_dtype , device = self .extract_shape_dtype_device (input_node )
@@ -160,10 +161,8 @@ def replace_input_node(self, input_node):
160
161
161
162
elif input_node .op == "get_attr" :
162
163
new_attr_name = input_node .target + "_reshaped"
163
- from torch ._subclasses .fake_tensor import unset_fake_temporarily
164
-
165
164
with unset_fake_temporarily ():
166
- original_tensor = self .get_attr_tensor (input_node .target )
165
+ original_tensor = self .get_attr_tensor (input_node .target ) # type: ignore
167
166
stacked_tensor = torch .stack (
168
167
[original_tensor .real , original_tensor .imag ], dim = - 1
169
168
)
@@ -181,7 +180,7 @@ def replace_input_node(self, input_node):
181
180
self .gm .graph .erase_node (input_node )
182
181
clean_up_graph_after_modifications (self .gm )
183
182
184
- def rewrite_subgraph_nodes (self , subgraphs ) :
183
+ def rewrite_subgraph_nodes (self , subgraphs : List [ ComplexSubGraphInfo ]) -> None :
185
184
modified = False
186
185
for subgraph in subgraphs :
187
186
for input_node in subgraph .input_nodes :
@@ -196,11 +195,20 @@ def rewrite_subgraph_nodes(self, subgraphs):
196
195
elif node .target == torch .ops .aten .mul .Tensor :
197
196
# this is complex mul where inputs = a+ib and output = c+id.
198
197
# complex mul returns (ac - bd) + (ad + bc)i
199
- # which is then view_as_real as (ac-bd), ad+bc stacked along the last dimension with last dimension size 2
198
+ # which is then view_as_real as (ac-bd), (ad+bc) stacked along the last dimension with last dimension size 2
199
+ x_placeholder_or_func = (
200
+ True if node .args [0 ].op != "get_attr" else False
201
+ )
202
+ y_placeholder_or_func = (
203
+ True if node .args [1 ].op != "get_attr" else False
204
+ )
205
+
200
206
replaced_nodes = []
201
- original_mul , replacement = complex_mul_replacement ()
207
+ original_mul , replacement = complex_mul_replacement (
208
+ x_placeholder_or_func , y_placeholder_or_func
209
+ )
202
210
203
- def match_complex_mul (
211
+ def match_complex_mul ( # type: ignore[no-untyped-def]
204
212
match : torch .fx .subgraph_rewriter .Match ,
205
213
original_graph ,
206
214
pattern_graph ,
@@ -233,7 +241,7 @@ def match_complex_mul(
233
241
self .gm .graph .lint ()
234
242
self .gm .recompile ()
235
243
236
- def propagate_metadata (self ):
244
+ def propagate_metadata (self ) -> None :
237
245
fake_inputs = []
238
246
from torch ._subclasses .fake_tensor import FakeTensorMode
239
247
from torch .fx .passes .fake_tensor_prop import FakeTensorProp
@@ -260,7 +268,34 @@ def propagate_metadata(self):
260
268
).propagate (* fake_inputs )
261
269
262
270
263
- def complex_mul_replacement () -> Tuple [
271
+ def extract_real_imag (input , placeholder_or_func : bool = True ): # type: ignore
272
+ """Extract real and imaginary parts from a tensor.
273
+ This function handles different tensor types based on whether they are placeholder/function
274
+ tensors or get_attr tensors. For placeholder/function tensors, it uses select operations,
275
+ while for get_attr tensors, it uses indexing.
276
+ Args:
277
+ input: Input tensor to extract real and imaginary parts from
278
+ placeholder_or_func: Boolean flag indicating if the input is a placeholder/function tensor (True)
279
+ or a get_attr tensor (False). Defaults to True.
280
+ Returns:
281
+ Tuple of (real_part, imaginary_part) where both parts have the same type as the input
282
+ Note:
283
+ - When placeholder_or_func=True: Uses torch.ops.aten.select.int operations
284
+ - When placeholder_or_func=False: Uses tensor indexing [..., 0] and [..., 1]
285
+ """
286
+ if placeholder_or_func :
287
+ # For ITensor, use select operations
288
+ real_part = torch .ops .aten .select .int (input , - 1 , 0 )
289
+ imag_part = torch .ops .aten .select .int (input , - 1 , 1 )
290
+ return real_part , imag_part
291
+ else :
292
+ # For get_attr, use indexing
293
+ return input [..., 0 ], input [..., 1 ]
294
+
295
+
296
+ def complex_mul_replacement (
297
+ x_placeholder_or_func : bool = True , y_placeholder_or_func : bool = True
298
+ ) -> Tuple [
264
299
Callable [[torch .Tensor , torch .Tensor ], torch .Tensor ],
265
300
Callable [[torch .Tensor , torch .Tensor ], torch .Tensor ],
266
301
]:
@@ -280,9 +315,8 @@ def original_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
280
315
281
316
# Replacement function: manual complex multiplication on real/imag stacked tensors
282
317
def replacement (x : torch .Tensor , y : torch .Tensor ) -> torch .Tensor :
283
- x_real = torch .ops .aten .select .int (x , - 1 , 0 )
284
- x_imag = torch .ops .aten .select .int (x , - 1 , 1 ) # x is reshape tensor
285
- y_real , y_imag = y [..., 0 ], y [..., 1 ] # y is frozen param
318
+ x_real , x_imag = extract_real_imag (x , x_placeholder_or_func )
319
+ y_real , y_imag = extract_real_imag (y , y_placeholder_or_func )
286
320
287
321
real_part1 = torch .ops .aten .mul .Tensor (x_real , y_real )
288
322
real_part2 = torch .ops .aten .mul .Tensor (x_imag , y_imag )
@@ -304,10 +338,18 @@ def replacement(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
304
338
305
339
306
340
# This lowering pass is used to detect and rewrite complex subgraphs in the graph
307
- # This lowering pass works for complex tensor in mul which are parameter or buffers in the graph
308
341
def complex_graph_detection (
309
342
gm : GraphModule , settings : CompilationSettings
310
- ) -> List [ComplexSubGraphInfo ]:
343
+ ) -> GraphModule :
344
+ """Detect and rewrite complex subgraphs in the graph.
345
+ This lowering pass is used to detect and rewrite complex subgraphs in the graph.
346
+ This lowering pass works for complex tensor in mul which are parameter or buffers in the graph.
347
+ Args:
348
+ gm: The GraphModule to process
349
+ settings: Compilation settings
350
+ Returns:
351
+ The modified GraphModule with complex subgraphs rewritten
352
+ """
311
353
complex_op_detector = ComplexOpDetector ()
312
354
complex_subgraphs = complex_op_detector .find_complex_op_subgraphs (
313
355
gm , anchor_target = torch .ops .aten .view_as_real .default
0 commit comments