File tree Expand file tree Collapse file tree 3 files changed +11
-12
lines changed Expand file tree Collapse file tree 3 files changed +11
-12
lines changed Original file line number Diff line number Diff line change 19
19
import numpy as np
20
20
import tensorrt as trt
21
21
import torch
22
- import torch_tensorrt.dynamo.conversion.impl as impl
23
22
from torch .fx .experimental .proxy_tensor import unset_fake_temporarily
24
23
from torch .fx .node import Argument , Target
25
24
from torch .fx .passes .shape_prop import TensorMetadata
25
+
26
+ import torch_tensorrt.dynamo.conversion.impl as impl
26
27
from torch_tensorrt import _enums
27
28
from torch_tensorrt .dynamo ._settings import CompilationSettings
28
29
from torch_tensorrt .dynamo ._SourceIR import SourceIR
@@ -152,9 +153,9 @@ def cast_trt_tensor(
152
153
) -> TRTTensor :
153
154
"""Given a TRT Tensor, convert that Tensor to the specified dtype
154
155
155
- Adds an Identity layer to the network which performs the conversion
156
- if the input's dtype is different from the cast type. Otherwise returns
157
- input unchanged
156
+ Converts the input tensor to the specified dtype.
157
+ If the input tensor already has the desired dtype, it is returned unchanged.
158
+ Otherwise, a Cast layer is added to the network to perform the conversion.
158
159
159
160
Args:
160
161
ctx (ConversionContext): A ConversionContext containing the TensorRT network
Original file line number Diff line number Diff line change 5
5
import numpy as np
6
6
import tensorrt as trt
7
7
from torch .fx .node import Target
8
+
8
9
from torch_tensorrt .dynamo ._SourceIR import SourceIR
9
10
from torch_tensorrt .dynamo .conversion import impl
10
11
from torch_tensorrt .dynamo .conversion ._ConversionContext import ConversionContext
13
14
flatten_dims ,
14
15
get_positive_dim ,
15
16
get_trt_tensor ,
17
+ has_dynamic_shape ,
18
+ prepend_ones ,
19
+ set_layer_name ,
16
20
)
17
21
from torch_tensorrt .dynamo .conversion .impl .cat import cat
18
22
from torch_tensorrt .dynamo .conversion .impl .elementwise import floor_divide
23
27
from torch_tensorrt .dynamo .conversion .impl .shape import shape as get_shape
24
28
from torch_tensorrt .dynamo .conversion .impl .slice .base import slice
25
29
from torch_tensorrt .dynamo .utils import DYNAMIC_DIM
26
- from torch_tensorrt.fx.converters.converter_utils import (
27
-     has_dynamic_shape,
28
-     prepend_ones,
29
-     set_layer_name,
30
- )
31
30
from torch_tensorrt .fx .types import Shape , TRTTensor
32
31
33
32
@@ -230,7 +229,7 @@ def expand(
230
229
# If the rank of the input tensor is less than the shape's rank, pad with ones
231
230
if initial_tensor_rank < shape_rank :
232
231
input_t = prepend_ones (
233
- ctx.net,
232
+ ctx,
234
233
input_t ,
235
234
name + "_expand_broadcast" ,
236
235
shape_rank - initial_tensor_rank ,
Original file line number Diff line number Diff line change @@ -909,7 +909,6 @@ def type_cast(
909
909
"""
910
910
Casts the input tensor to cast_type
911
911
"""
912
- layer_i = network.add_identity(input)
913
- layer_i.set_output_type(0, cast_type)
912
+ layer_i = network.add_cast(input, cast_type)
914
913
set_layer_name (layer_i , target , f"{ name } _dtype_change" )
915
914
return layer_i .get_output (0 )
You can’t perform that action at this time.
0 commit comments