You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: py/torch_tensorrt/dynamo/_TorchTensorRTModule.py
+9-13
Original file line number
Diff line number
Diff line change
@@ -1,6 +1,5 @@
1
1
importlogging
2
-
fromoperatorimporttruediv
3
-
fromtypingimportAny, List, Sequence, Tuple
2
+
fromtypingimportAny, List, Tuple
4
3
5
4
importtorch
6
5
fromtorch_tensorrtimport_C
@@ -9,8 +8,8 @@
9
8
logger=logging.getLogger(__name__)
10
9
11
10
12
-
classTRTModuleNext(torch.nn.Module):
13
-
"""TRTModuleNext is a PyTorch module which encompasses an arbitrary TensorRT Engine.
11
+
classTorchTensorRTModule(torch.nn.Module):
12
+
"""TorchTensorRTModule is a PyTorch module which encompasses an arbitrary TensorRT Engine.
14
13
15
14
This module is backed by the Torch-TensorRT runtime and is fully compatibile with both
16
15
FX / Python deployments (just ``import torch_tensorrt`` as part of the application) as
@@ -20,7 +19,7 @@ class TRTModuleNext(torch.nn.Module):
20
19
The forward function is simpily forward(*args: torch.Tensor) -> Tuple[torch.Tensor] where
21
20
the internal implementation is ``return Tuple(torch.ops.tensorrt.execute_engine(list(inputs), self.engine))``
22
21
23
-
> Note: TRTModuleNext only supports engines built with explict batch
22
+
> Note: TorchTensorRTModule only supports engines built with explict batch
24
23
25
24
Attributes:
26
25
name (str): Name of module (for easier debugging)
@@ -37,7 +36,7 @@ def __init__(
37
36
output_binding_names: List[str] = [],
38
37
target_device: Device=Device._current_device(),
39
38
):
40
-
"""__init__ method for torch_tensorrt.TRTModuleNext
39
+
"""__init__ method for torch_tensorrt.TorchTensorRTModule
41
40
42
41
Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs
43
42
a PyTorch ``torch.nn.Module`` around it.
@@ -70,10 +69,7 @@ def __init__(
70
69
)
71
70
72
71
"""
73
-
logger.warning(
74
-
"TRTModuleNext should be considered experimental stability, APIs are subject to change. Note: TRTModuleNext only supports engines built with explict batch"
75
-
)
76
-
super(TRTModuleNext, self).__init__()
72
+
super(TorchTensorRTModule, self).__init__()
77
73
78
74
ifnotisinstance(serialized_engine, bytearray):
79
75
ValueError("Expected serialized engine as bytearray")
0 commit comments