
Commit 93e4c4a

Merge branch 'ncomly-torch_tensorrt_rebrand-patch-42370' into 'release/1.0'

Update README.md with new API & pointing to NGC container

See merge request adlsa/TRTorch!21

2 parents 79904bf + a5543c9

File tree

1 file changed (+17 −17 lines)


README.md

Lines changed: 17 additions & 17 deletions
````diff
@@ -55,29 +55,27 @@ trt_mod.save("trt_torchscript_module.ts");
 import torch_tensorrt

 ...
-compile_settings = {
-    "inputs": [torch_tensorrt.Input(
-        min_shape=[1, 3, 224, 224],
-        opt_shape=[1, 3, 512, 512],
-        max_shape=[1, 3, 1024, 1024],
-        # For static size shape=[1, 3, 224, 224]
-        dtype=torch.half, # Datatype of input tensor. Allowed options torch.(float|half|int8|int32|bool)
-    )],
-    "enabled_precisions": {torch.half}, # Run with FP16
-}
-
-trt_ts_module = torch_tensorrt.compile(torch_script_module, compile_settings)
-
-input_data = input_data.half()
-result = trt_ts_module(input_data)
-torch.jit.save(trt_ts_module, "trt_torchscript_module.ts")
+
+trt_ts_module = torch_tensorrt.compile(torch_script_module,
+    inputs = [example_tensor,     # Provide example tensor for input shape or...
+        torch_tensorrt.Input(     # Specify input object with shape and dtype
+            min_shape=[1, 3, 224, 224],
+            opt_shape=[1, 3, 512, 512],
+            max_shape=[1, 3, 1024, 1024],
+            # For static size shape=[1, 3, 224, 224]
+            dtype=torch.half)     # Datatype of input tensor. Allowed options torch.(float|half|int8|int32|bool)
+    ],
+    enabled_precisions = {torch.half})  # Run with FP16
+
+result = trt_ts_module(input_data)  # run inference
+torch.jit.save(trt_ts_module, "trt_torchscript_module.ts")  # save the TRT embedded TorchScript
 ```

 > Notes on running in lower precisions:
 >
 > - Enable lower precisions with compile_spec.enabled_precisions
 > - The module should be left in FP32 before compilation (FP16 can support half tensor models)
-> - In FP16, input tensors should by default be FP16; other precisions use FP32. This can be overridden by setting Input::dtype
+> - Provided input tensors' dtype should be the same as the module's before compilation, regardless of `enabled_precisions`. This can be overridden by setting `Input::dtype`

 ## Platform Support

@@ -89,6 +87,8 @@ torch.jit.save(trt_ts_module, "trt_torchscript_module.ts")
 | Windows / GPU       | **Unofficial Support** |
 | Linux ppc64le / GPU | -                      |

+Torch-TensorRT will be included in NVIDIA NGC containers (https://ngc.nvidia.com/catalog/containers/nvidia:pytorch) starting in 21.11.
+
 > Note: Refer to the NVIDIA NGC container (https://ngc.nvidia.com/catalog/containers/nvidia:l4t-pytorch) for PyTorch libraries on JetPack.

 ### Dependencies
````
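The `torch_tensorrt.Input` object in the new API above describes a dynamic shape range via `min_shape`, `opt_shape`, and `max_shape`. As a minimal pure-Python sketch of the invariant such a range must satisfy (this is illustration only, not Torch-TensorRT code; `valid_shape_range` is a hypothetical helper):

```python
# Hypothetical helper, for illustration: a dynamic shape range like the one
# passed to torch_tensorrt.Input is well-formed when all three shapes have the
# same rank and every dimension satisfies min <= opt <= max.
def valid_shape_range(min_shape, opt_shape, max_shape):
    if not (len(min_shape) == len(opt_shape) == len(max_shape)):
        return False
    return all(lo <= mid <= hi
               for lo, mid, hi in zip(min_shape, opt_shape, max_shape))

# Shapes from the diff above: batch 1, 3 channels, spatial sizes from
# 224x224 up to 1024x1024, optimized for 512x512.
print(valid_shape_range([1, 3, 224, 224], [1, 3, 512, 512], [1, 3, 1024, 1024]))  # True
print(valid_shape_range([1, 3, 512, 512], [1, 3, 224, 224], [1, 3, 1024, 1024]))  # False
```

If the model only ever sees one input size, the diff's comment notes you can collapse the range to a single static `shape=[1, 3, 224, 224]` instead.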
