Skip to content

Commit 1aa6813

Browse files
authored
[FT] 1. Fix the bug of TensorRT plugin of FasterTransformer encoder. (NVIDIA#640)
* [FT] 1. Fix the bug of TensorRT plugin of FasterTransformer encoder.
1 parent 280e75c commit 1aa6813

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

FasterTransformer/v2.1/fastertransformer/trt_plugin/bert_transformer_plugin.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -269,13 +269,13 @@ class TransformerPlugin: public IPluginV2
269269

270270
bool supportsFormat(nvinfer1::DataType type, PluginFormat format) const override
271271
{
272-
return type == nvinfer1::DataType::kFLOAT && format == PluginFormat::kNCHW;
272+
return type == TransformerTrtTraits<T>::DataType && format == PluginFormat::kNCHW;
273273
}
274274

275275
void configureWithFormat(const Dims* pInputDim, int nInputDim, const Dims* pOutputDim,
276276
int nOutputDim, nvinfer1::DataType dataType, nvinfer1::PluginFormat pluginFormat, int maxBatchSize) override
277277
{
278-
assert(dataType == nvinfer1::DataType::kFLOAT && pluginFormat == nvinfer1::PluginFormat::kNCHW);
278+
assert(dataType == TransformerTrtTraits<T>::DataType && pluginFormat == nvinfer1::PluginFormat::kNCHW);
279279
assert(nInputDim == 2);
280280
assert(pInputDim[0].nbDims == 2 && pInputDim[0].d[0] == seq_len_ && pInputDim[0].d[1] == hidden_dim_);
281281
assert(pInputDim[1].nbDims == 2 && pInputDim[1].d[0] == seq_len_ && pInputDim[1].d[1] == seq_len_);

FasterTransformer/v2.1/fastertransformer/trt_plugin/trt_model.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ class TRT_Transformer
104104

105105
builder->setMaxBatchSize(batch_size_);
106106
builder->setMaxWorkspaceSize(1 << 20);
107-
builder->setFp16Mode(false);
107+
builder->setFp16Mode(sizeof(T) == 2);
108108

109109
engine_ = builder->buildCudaEngine(*network);
110110
assert(engine_);

0 commit comments

Comments
 (0)