
Commit 6cf72ab

feat: Add bf16 support to cast converter (#3643)
1 parent 001fe31 commit 6cf72ab

File tree

4 files changed: +27 -2 lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 1 addition & 0 deletions

@@ -1034,6 +1034,7 @@ def validate_dtype(to_copy_node: Node) -> bool:
         torch.bool,
         torch.int8,
         torch.float16,
+        torch.bfloat16,
     }

     # Validate input node has convertible kwargs
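
For context, validate_dtype gates the aten._to_copy converter on an allow-list of target dtypes, which this hunk extends. A minimal sketch of that check, abridged to the dtypes visible in the diff (the real allow-list in aten_ops_converters.py is longer):

import torch
from torch.fx.node import Node

# Sketch, not the verbatim source: the converter is accepted only when the
# requested target dtype is in the allow-list this commit extends.
allowed_casts = {
    torch.bool,
    torch.int8,
    torch.float16,
    torch.bfloat16,  # newly accepted by this commit
}

def validate_dtype(to_copy_node: Node) -> bool:
    dtype = to_copy_node.kwargs.get("dtype")
    return dtype in allowed_casts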

py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py

Lines changed: 9 additions & 2 deletions

@@ -544,9 +544,16 @@ def pow(
     lhs_val: Union[TRTTensor, int, float],
     rhs_val: Union[TRTTensor, int, float],
 ) -> TRTTensor:
+
+    lhs_dtype = None
+    rhs_dtype = None
+    if isinstance(lhs_val, int):
+        lhs_dtype = torch.int32
+    if isinstance(rhs_val, int):
+        rhs_dtype = torch.int32
     # POW operation supports only float32 and int8 inputs
-    lhs_val = get_trt_tensor(ctx, lhs_val, name + "_lhs_val", trt.float32)
-    rhs_val = get_trt_tensor(ctx, rhs_val, name + "_rhs_val", trt.float32)
+    lhs_val = get_trt_tensor(ctx, lhs_val, name + "_lhs_val", lhs_dtype)
+    rhs_val = get_trt_tensor(ctx, rhs_val, name + "_rhs_val", rhs_dtype)
     out = convert_binary_elementwise(
         ctx, target, source_ir, name, trt.ElementWiseOperation.POW, lhs_val, rhs_val
     )
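
Note on the hunk above: Python int scalars are now materialized as int32 constants, and every other operand passes dtype=None so get_trt_tensor keeps the operand's own dtype instead of coercing it to float32, which is what previously destroyed a bf16 cast feeding pow. (The retained context comment about float32/int8 predates this change.) A standalone sketch of the new scalar-dtype selection; the helper name is illustrative, not from the source:

import torch

def scalar_pow_dtype(operand):
    # Illustrative helper mirroring the new logic in pow(): Python ints
    # become int32 constants; anything else returns None so the conversion
    # infers the dtype (e.g., bf16 stays bf16) rather than forcing float32.
    return torch.int32 if isinstance(operand, int) else None

assert scalar_pow_dtype(2) is torch.int32
assert scalar_pow_dtype(2.0) is None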

tests/py/dynamo/conversion/harness.py

Lines changed: 2 additions & 0 deletions

@@ -412,6 +412,7 @@ def run_test(
         propagate_shapes=False,
         int32_reqd=False,
         immutable_weights=True,
+        use_explicit_typing=False,
     ):
         # TODO: lan to remove this and set use_dynamo_tracer to True by default
         # once all the converter test files are moved to use_dynamo_tracer
@@ -422,6 +423,7 @@ def run_test(
             enabled_precisions={dtype._from(precision)},
             truncate_double=True,
             immutable_weights=immutable_weights,
+            use_explicit_typing=use_explicit_typing,
         )

         mod = self.generate_graph(
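
use_explicit_typing is threaded through the harness here because bf16 only round-trips when the TensorRT network is built with explicit (strong) typing; otherwise the builder may re-pick layer precisions. A hedged end-to-end sketch outside the test harness, assuming the same flag is exposed on torch_tensorrt.dynamo.compile (verify against your installed version):

import torch
import torch_tensorrt

class CastPow(torch.nn.Module):
    def forward(self, x):
        y = torch.ops.aten._to_copy.default(x, dtype=torch.bfloat16)
        return y**2

inputs = (torch.rand(1, 3, 10),)
exported = torch.export.export(CastPow(), inputs)
# Assumption: use_explicit_typing is the same knob the harness forwards.
# With explicit typing, TensorRT honors the bf16 dtypes recorded in the
# network instead of letting the builder choose layer precisions.
trt_module = torch_tensorrt.dynamo.compile(
    exported,
    inputs=list(inputs),
    use_explicit_typing=True,
)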

tests/py/dynamo/conversion/test_casts.py

Lines changed: 15 additions & 0 deletions

@@ -64,6 +64,21 @@ def forward(self, x):
             precision=torch.float,
         )

+    def test_to_copy_bfloat16(self):
+        class ToCopyBFloat16(nn.Module):
+            def forward(self, x):
+                y = torch.ops.aten._to_copy.default(x, dtype=torch.bfloat16)
+                y = y**2
+                return y
+
+        inputs = [torch.rand((1, 3, 10), dtype=torch.float32)]
+        self.run_test(
+            ToCopyBFloat16(),
+            inputs,
+            precision=torch.float,
+            use_explicit_typing=True,
+        )
+
     def test_to_copy_i64b(self):
         class ToCopy64Bit(nn.Module):
             def forward(self, x):
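
The new test exercises both converters this commit touches: the bf16 _to_copy and the pow that consumes its output. For reference, the eager-mode behavior the TRT output is checked against; a standalone sketch, not harness code:

import torch

x = torch.rand((1, 3, 10), dtype=torch.float32)
y = torch.ops.aten._to_copy.default(x, dtype=torch.bfloat16)
y = y**2  # an int exponent keeps the bf16 dtype in eager mode
print(y.dtype)  # torch.bfloat16 -- the compiled path should match this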
