Commit a5543c9

Update README.md with new API & pointing to NGC container

1 parent 65ffaef

File tree

1 file changed (+17 -17 lines changed)


README.md

Lines changed: 17 additions & 17 deletions
@@ -39,28 +39,26 @@ trt_mod.save("trt_torchscript_module.ts");
 import torch_tensorrt

 ...
-compile_settings = {
-    "inputs": [torch_tensorrt.Input(
-        min_shape=[1, 3, 224, 224],
-        opt_shape=[1, 3, 512, 512],
-        max_shape=[1, 3, 1024, 1024],
-        # For static size shape=[1, 3, 224, 224]
-        dtype=torch.half, # Datatype of input tensor. Allowed options torch.(float|half|int8|int32|bool)
-    )],
-    "enabled_precisions": {torch.half}, # Run with FP16
-}
-
-trt_ts_module = torch_tensorrt.compile(torch_script_module, compile_settings)
-
-input_data = input_data.half()
-result = trt_ts_module(input_data)
-torch.jit.save(trt_ts_module, "trt_torchscript_module.ts")
+
+trt_ts_module = torch_tensorrt.compile(torch_script_module,
+    inputs = [example_tensor, # Provide example tensor for input shape or...
+        torch_tensorrt.Input( # Specify input object with shape and dtype
+            min_shape=[1, 3, 224, 224],
+            opt_shape=[1, 3, 512, 512],
+            max_shape=[1, 3, 1024, 1024],
+            # For static size shape=[1, 3, 224, 224]
+            dtype=torch.half) # Datatype of input tensor. Allowed options torch.(float|half|int8|int32|bool)
+    ],
+    enabled_precisions = {torch.half}) # Run with FP16
+
+result = trt_ts_module(input_data) # run inference
+torch.jit.save(trt_ts_module, "trt_torchscript_module.ts") # save the TRT embedded Torchscript
 ```

 > Notes on running in lower precisions:
 > - Enabled lower precisions with compile_spec.enabled_precisions
 > - The module should be left in FP32 before compilation (FP16 can support half tensor models)
-> - In FP16 only input tensors by default should be FP16, other precisions use FP32. This can be overrided by setting Input::dtype
+> - The dtype of provided input tensors should match the module before compilation, regardless of `enabled_precisions`. This can be overridden by setting `Input::dtype`

 ## Platform Support
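The change above replaces the dict-based `compile_settings` argument with direct keyword arguments to `torch_tensorrt.compile`. A minimal sketch of that mapping, using a hypothetical `compile_sketch` stand-in rather than the real torch_tensorrt API (so it runs without TensorRT installed):

```python
# Hypothetical stand-in for torch_tensorrt.compile, used only to show that
# the old compile_settings dict keys map one-to-one onto the new keyword
# arguments. This is NOT the real torch_tensorrt API.
def compile_sketch(module, inputs=None, enabled_precisions=None):
    return {"module": module, "inputs": inputs,
            "enabled_precisions": enabled_precisions}

# Old style: settings collected in a dict and passed as one argument.
compile_settings = {
    "inputs": [[1, 3, 224, 224]],      # placeholder for torch_tensorrt.Input(...)
    "enabled_precisions": {"half"},    # placeholder for {torch.half}
}
old_style = compile_sketch("torch_script_module", **compile_settings)

# New style (as in the updated README): the same keys become keyword arguments.
new_style = compile_sketch("torch_script_module",
                           inputs=[[1, 3, 224, 224]],
                           enabled_precisions={"half"})

assert old_style == new_style  # the two call styles are equivalent
```

Because the old dict keys match the new parameter names, existing settings dicts can be unpacked directly into the new call with `**compile_settings`.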

@@ -72,6 +70,8 @@ torch.jit.save(trt_ts_module, "trt_torchscript_module.ts")
 | Windows / GPU | **Unofficial Support** |
 | Linux ppc64le / GPU | - |

+Torch-TensorRT will be included in NVIDIA NGC containers (https://ngc.nvidia.com/catalog/containers/nvidia:pytorch) starting in 21.11.
+
 > Note: Refer NVIDIA NGC container(https://ngc.nvidia.com/catalog/containers/nvidia:l4t-pytorch) for PyTorch libraries on JetPack.

 ### Dependencies
