
Commit 0bcda12

Fixed fp16 quantization error
1 parent 18b6455 commit 0bcda12

File tree

1 file changed: +2 −3 lines changed


py/torch_tensorrt/dynamo/conversion/impl/quantize.py (+2 −3)
@@ -66,7 +66,7 @@ def quantize(
         if not isinstance(amax, trt.ITensor):
             amax = to_torch(amax, None)
             scale = torch.divide(amax, max_bound)
-            scale = get_trt_tensor(ctx, scale, name + "_scale")
+            scale = get_trt_tensor(ctx, scale, name + "_scale", dtype=torch.float32)
         else:
             scale = impl.elementwise.div(
                 ctx,
@@ -76,7 +76,7 @@ def quantize(
                 amax,
                 max_bound,
             )
-            scale = get_trt_tensor(ctx, scale, name + "_scale")
+            scale = get_trt_tensor(ctx, scale, name + "_scale", dtype=torch.float32)
 
         # Add Q node
         if num_bits == 8 and exponent_bits == 0:
@@ -96,7 +96,6 @@ def quantize(
             q_output, scale, output_type=input_tensor.dtype
         )
         set_layer_name(dequantize_layer, target, name + "_dequantize", source_ir)
-        dequantize_layer.precision = dtype
 
         dq_output = dequantize_layer.get_output(0)

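Note (not part of the commit): the quantization scale here is derived as amax / max_bound. When the model runs in fp16, amax may itself be a half-precision value, and a half-precision scale can lose accuracy in the Q/DQ layers; the change forces the scale tensor handed to TensorRT to float32. Dropping the explicit dequantize_layer.precision assignment presumably lets TensorRT choose the layer precision rather than pinning it to the (possibly fp16) model dtype. A minimal standalone PyTorch sketch of the scale computation, with hypothetical values:

import torch

# Hypothetical calibrated activation maximum, stored in fp16 as it would be
# when the model itself runs in half precision (illustrative value only).
amax = torch.tensor(6.0, dtype=torch.float16)
max_bound = 127  # symmetric int8 bound

# Keeping the result in float32 mirrors the fix: the Q/DQ scale is
# materialized as a full-precision tensor regardless of the model dtype.
scale = torch.divide(amax, max_bound).to(torch.float32)
print(scale.dtype)  # torch.float32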