
Commit f534f12

Rand converter - evaluator (#2580)
1 parent 8daebf6 commit f534f12

3 files changed: +243 −4 lines changed


py/torch_tensorrt/dynamo/conversion/ops_evaluators.py

Lines changed: 71 additions & 0 deletions
@@ -47,3 +47,74 @@ def aten_ops_arange_start_step(
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
     return np.arange(*args)
+
+
+def rand_validator(rand_node: Node) -> bool:
+    dtype = rand_node.kwargs.get("dtype", None)
+    layout = rand_node.kwargs.get("layout", None)
+    if dtype is not None:
+        _LOGGER.debug(
+            f"Currently we don't support specifying output dtype, got {dtype}."
+        )
+        return False
+    if layout is not None:
+        _LOGGER.debug(f"Currently we don't support specifying layout, got {layout}.")
+        return False
+    return True
+
+
+@dynamo_tensorrt_converter(
+    torch.ops.aten.rand.default, capability_validator=rand_validator
+)
+def aten_ops_rand(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return np.random.rand(*args[0])
+
+
+@dynamo_tensorrt_converter(
+    torch.ops.aten.randn.default, capability_validator=rand_validator
+)
+def aten_ops_randn(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return np.random.randn(*args[0])
+
+
+def randperm_validator(randperm_node: Node) -> bool:
+    dtype = randperm_node.kwargs.get("dtype", None)
+    layout = randperm_node.kwargs.get("layout", None)
+    input = randperm_node.args[0]
+    if not isinstance(input, int):
+        _LOGGER.error(f"Input should be of type int.")
+        return False
+    if dtype is not None:
+        _LOGGER.debug(
+            f"Currently we don't support specifying output dtype, got {dtype}."
+        )
+        return False
+    if layout is not None:
+        _LOGGER.debug(f"Currently we don't support specifying layout, got {layout}.")
+        return False
+    return True
+
+
+@dynamo_tensorrt_converter(
+    torch.ops.aten.randperm.default, capability_validator=randperm_validator
+)
+def aten_ops_randperm(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return np.random.permutation(args[0])
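
Illustrative note (not part of the commit): because aten.rand, aten.randn, and aten.randperm receive plain Python sizes rather than tensor inputs, the converters above can evaluate them eagerly with np.random while the engine is built, and the validators reject calls that specify dtype or layout. The untested sketch below shows the kind of graph these evaluators cover end to end; the module, shapes, and the torch_tensorrt.compile(..., ir="dynamo") invocation are assumptions for illustration, and the generated "random" values end up baked into the engine as constants.

import torch
import torch_tensorrt


class RandModule(torch.nn.Module):
    def forward(self, x):
        # Sizes are plain Python ints/lists, so the evaluators above can
        # compute these ops with NumPy at conversion time.
        r = torch.ops.aten.rand.default([2, 3])
        n = torch.ops.aten.randn.default([2, 3])
        p = torch.ops.aten.randperm.default(4)
        return x + r + n, p


# Assumed compile invocation for illustration; requires a CUDA device.
trt_mod = torch_tensorrt.compile(
    RandModule().eval().cuda(),
    ir="dynamo",
    inputs=[torch_tensorrt.Input(shape=(2, 3), dtype=torch.float32)],
)
out, perm = trt_mod(torch.ones(2, 3, device="cuda"))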

tests/py/dynamo/conversion/harness.py

Lines changed: 37 additions & 4 deletions
@@ -138,9 +138,7 @@ def run_test_custom_compare_results(
         if len(expected_ops):
             self.assert_has_op(mod, expected_ops)
 
-        interpreter_result = interpreter.run(
-            precision=torch.half if fp16_mode else torch.float
-        )
+        interpreter_result = interpreter.run()
         trt_mod = PythonTorchTensorRTModule(
             interpreter_result.engine,
             interpreter_result.input_names,
@@ -149,7 +147,6 @@ def run_test_custom_compare_results(
         res_trt = trt_mod(*cuda_inputs).cpu()
         res_cpu = mod(*cuda_inputs).cpu()
         assert len(res_trt) == len(res_cpu)
-        assert len(res_cpu) == len(comparators)
         for output_trt, output_cpu, comparator in zip(
             res_trt, res_cpu, comparators
         ):
@@ -270,6 +267,42 @@ def run_test(
             check_dtype,
         )
 
+    def run_test_compare_tensor_attributes_only(
+        self,
+        mod,
+        inputs,
+        expected_ops,
+        comparators: List[Tuple[Callable, List]],
+        precision=torch.float,
+        output_dtypes=None,
+        use_dynamo_tracer=False,
+        enable_passes=False,
+    ):
+        mod.eval()
+        mod = self.generate_graph(
+            mod,
+            inputs,
+            use_dynamo_tracer=use_dynamo_tracer,
+            enable_passes=enable_passes,
+        )
+        # Previous instance of the interpreter auto-casted 64-bit inputs
+        # We replicate this behavior here
+        compilation_settings = CompilationSettings(
+            enabled_precisions={dtype._from(precision)},
+            truncate_long_and_double=True,
+            debug=True,
+        )
+
+        interp = TRTInterpreter(
+            mod,
+            Input.from_tensors(inputs),
+            output_dtypes=output_dtypes,
+            compilation_settings=compilation_settings,
+        )
+        super().run_test_custom_compare_results(
+            mod, inputs, expected_ops, interp, comparators
+        )
+
     def run_test_with_dynamic_shape(
         self,
         mod,
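
The new run_test_compare_tensor_attributes_only helper builds and runs the TRT engine like run_test, but checks outputs only through caller-supplied comparators rather than by value, which is what non-deterministic ops need. Below is a minimal sketch of the comparator contract as the tests in the next file use it; the names are illustrative, and the exact call signature is inferred from run_test_custom_compare_results and those tests rather than stated in this diff.

import torch


def same_shape_and_dtype(trt_out, torch_out, check_dtype):
    # Each comparator is a (callable, extra_args) pair; the callable receives
    # one TRT output, the matching Torch output, and the extra args, and
    # returns True when the pair is acceptable.
    return trt_out.shape == torch_out.shape and (
        trt_out.dtype == torch_out.dtype if check_dtype else True
    )


comparators = [(same_shape_and_dtype, [True])]

# Quick self-check of the comparator on two CPU tensors.
assert same_shape_and_dtype(torch.zeros(2, 3), torch.ones(2, 3), True)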
Lines changed: 135 additions & 0 deletions
@@ -0,0 +1,135 @@
+import torch
+import torch.nn as nn
+import torch_tensorrt
+from parameterized import parameterized
+from torch.testing._internal.common_utils import TestCase, run_tests
+
+from .harness import DispatchTestCase
+
+rand_ops = [
+    (
+        "rand_one_dimension",
+        (lambda shape: torch.ops.aten.rand(shape)),
+        [1],
+    ),
+    (
+        "rand_two_dimension",
+        (lambda shape: torch.ops.aten.rand(shape)),
+        [1, 2],
+    ),
+    (
+        "rand_three_dimension",
+        (lambda shape: torch.ops.aten.rand(shape)),
+        [2, 3, 4],
+    ),
+    (
+        "randn_one_dimension",
+        (lambda shape: torch.ops.aten.randn(shape)),
+        [1],
+    ),
+    (
+        "randn_two_dimension",
+        (lambda shape: torch.ops.aten.randn(shape)),
+        [2, 3],
+    ),
+    (
+        "randn_three_dimension",
+        (lambda shape: torch.ops.aten.randn(shape)),
+        [2, 3, 4],
+    ),
+]
+
+
+rand_perm_ops = [
+    (
+        "randperm_one_case",
+        (lambda x: torch.ops.aten.randperm(x)),
+        [1],
+    ),
+    (
+        "randperm_two_case",
+        (lambda x: torch.ops.aten.randperm(x)),
+        [150],
+    ),
+    (
+        "randperm_three_case",
+        (lambda x: torch.ops.aten.randperm(x)),
+        [1500],
+    ),
+]
+
+
+class TestRandConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            (
+                rand_op[0],
+                rand_op[1],
+                rand_op[2],
+            )
+            for rand_op in rand_ops
+        ]
+    )
+    def test_rand(self, name, op, shape_or_input):
+        class TestModule(nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                shape_or_input[0] = x.shape[0]
+                return op(shape_or_input)
+
+        rand_model = TestModule()
+
+        inputs = [torch.randint(1, 3, shape_or_input, dtype=torch.int32)]
+        comparator_shape = lambda x, y, check_dtype: x.shape == y.shape and (
+            x.dtype == y.dtype if check_dtype else True
+        )
+        expected_ops = []
+        self.run_test_compare_tensor_attributes_only(
+            rand_model,
+            inputs,
+            expected_ops,
+            [(comparator_shape, [True])],
+            use_dynamo_tracer=True,
+        )
+
+    @parameterized.expand(
+        [
+            (
+                rand_op[0],
+                rand_op[1],
+                rand_op[2],
+            )
+            for rand_op in rand_perm_ops
+        ]
+    )
+    def test_rand(self, name, op, shape_or_input):
+        class TestModule(nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                shape_or_input[0] = x.shape[0]
+                return op(shape_or_input[0])
+
+        rand_model = TestModule()
+        # cannot use self.run_test() since it expects input in form of tensor
+
+        inputs = [torch.randint(1, 3, shape_or_input, dtype=torch.int32)]
+        comparator_shape = lambda x, y, check_dtype: x.shape == y.shape and (
+            x.dtype == y.dtype if check_dtype else True
+        )
+        expected_ops = []
+        # TRT-TRT returns int32 while torch returns int64
+        self.run_test_compare_tensor_attributes_only(
+            rand_model,
+            inputs,
+            expected_ops,
+            [(comparator_shape, [False])],
+            use_dynamo_tracer=True,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
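
For the randperm cases the comparator is invoked with check_dtype=False because, as the in-test comment notes, TensorRT produces int32 output while torch.randperm returns int64. A small, self-contained illustration of why the dtype check must be relaxed (stand-in tensors, not actual engine output):

import torch

comparator_shape = lambda x, y, check_dtype: x.shape == y.shape and (
    x.dtype == y.dtype if check_dtype else True
)

trt_like = torch.arange(4, dtype=torch.int32)  # stand-in for the TRT result
torch_out = torch.randperm(4)                  # int64 on the Torch side

assert comparator_shape(trt_like, torch_out, check_dtype=False)     # shapes match
assert not comparator_shape(trt_like, torch_out, check_dtype=True)  # dtypes differ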

0 commit comments
