Skip to content

feat: Add bf16 support to cast converter #3643

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -1034,6 +1034,7 @@ def validate_dtype(to_copy_node: Node) -> bool:
torch.bool,
torch.int8,
torch.float16,
torch.bfloat16,
}

# Validate input node has convertible kwargs
Expand Down
11 changes: 9 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,9 +544,16 @@ def pow(
lhs_val: Union[TRTTensor, int, float],
rhs_val: Union[TRTTensor, int, float],
) -> TRTTensor:

lhs_dtype = None
rhs_dtype = None
if isinstance(lhs_val, int):
lhs_dtype = torch.int32
if isinstance(rhs_val, int):
rhs_dtype = torch.int32
# POW operation supports only float32 and int8 inputs
lhs_val = get_trt_tensor(ctx, lhs_val, name + "_lhs_val", trt.float32)
rhs_val = get_trt_tensor(ctx, rhs_val, name + "_rhs_val", trt.float32)
lhs_val = get_trt_tensor(ctx, lhs_val, name + "_lhs_val", lhs_dtype)
rhs_val = get_trt_tensor(ctx, rhs_val, name + "_rhs_val", rhs_dtype)
out = convert_binary_elementwise(
ctx, target, source_ir, name, trt.ElementWiseOperation.POW, lhs_val, rhs_val
)
Expand Down
2 changes: 2 additions & 0 deletions tests/py/dynamo/conversion/harness.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,7 @@ def run_test(
propagate_shapes=False,
int32_reqd=False,
immutable_weights=True,
use_explicit_typing=False,
):
# TODO: lan to remove this and set use_dynamo_traccer to True by default
# once all the converter test files are moved to use_dynamo_tracer
Expand All @@ -422,6 +423,7 @@ def run_test(
enabled_precisions={dtype._from(precision)},
truncate_double=True,
immutable_weights=immutable_weights,
use_explicit_typing=use_explicit_typing,
)

mod = self.generate_graph(
Expand Down
15 changes: 15 additions & 0 deletions tests/py/dynamo/conversion/test_casts.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,21 @@ def forward(self, x):
precision=torch.float,
)

def test_to_copy_bfloat16(self):
class ToCopyBFloat16(nn.Module):
def forward(self, x):
y = torch.ops.aten._to_copy.default(x, dtype=torch.bfloat16)
y = y**2
return y

inputs = [torch.rand((1, 3, 10), dtype=torch.float32)]
self.run_test(
ToCopyBFloat16(),
inputs,
precision=torch.float,
use_explicit_typing=True,
)

def test_to_copy_i64b(self):
class ToCopy64Bit(nn.Module):
def forward(self, x):
Expand Down
Loading