Commit 9860c56

fix rebase issue
1 parent 02d045b

File tree

test/float8/test_compile.py
torchao/quantization/quant_primitives.py

2 files changed: 8 additions & 10 deletions

test/float8/test_compile.py

Lines changed: 4 additions & 6 deletions
@@ -37,10 +37,6 @@
     hp_tensor_to_float8_dynamic,
 )
 from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig, ScaledMMConfig
-from torchao.quantization.quant_primitives import (
-    dequantize_affine_float8,
-    quantize_affine_float8,
-)
 from torchao.testing.float8.test_utils import get_test_float8_linear_config

@@ -412,14 +408,16 @@ def test_dynamic_scale_numeric_parity(
     ],
 )
 def test_quantize_dequantize_fp8_inductor(float8_dtype, hp_dtype):
+    quantize_affine_float8 = torch.ops.torchao.quantize_affine_float8
+    dequantize_affine_float8 = torch.ops.torchao.dequantize_affine_float8
     input = torch.randn(10, 10)
     with torch.no_grad():
         torch._dynamo.reset()
         expected_scale = torch.tensor(2.0)
         expected_quantized = quantize_affine_float8(
             input,
             expected_scale,
-            float8_dtype,
+            float8_dtype=float8_dtype,
         )
         expected_dequantized = dequantize_affine_float8(
             expected_quantized,
@@ -430,7 +428,7 @@ def test_quantize_dequantize_fp8_inductor(float8_dtype, hp_dtype):
             torch.compile(quantize_affine_float8),
             input,
             expected_scale,
-            float8_dtype,
+            float8_dtype=float8_dtype,
         )
         torch.testing.FileCheck().check(
             "torch.ops.torchao.quantize_affine_float8.default"

torchao/quantization/quant_primitives.py

Lines changed: 4 additions & 4 deletions
@@ -2274,7 +2274,7 @@ def _expand_scale_to_tensor_shape(
 def _quantize_affine_float8(
     tensor: torch.Tensor,
     scale: torch.Tensor,
-    float8_dtype: torch.dtype,
+    float8_dtype: torch.dtype = torch.float8_e4m3fn,
 ) -> torch.Tensor:
     """
     Quantizes the high precision floating point tensor to a float8 tensor, using the given scaling factor.
@@ -2295,7 +2295,7 @@ def _quantize_affine_float8(
 def _quantize_affine_float8_meta(
     tensor: torch.Tensor,
     scale: torch.Tensor,
-    float8_dtype: torch.dtype,
+    float8_dtype: torch.dtype = torch.float8_e4m3fn,
 ) -> torch.Tensor:
     return torch.empty_like(tensor, dtype=float8_dtype)

@@ -2304,7 +2304,7 @@ def _quantize_affine_float8_meta(
 def _dequantize_affine_float8(
     tensor: torch.Tensor,
     scale: torch.Tensor,
-    output_dtype: torch.dtype,
+    output_dtype: torch.dtype = torch.float32,
 ) -> torch.Tensor:
     """
     Dequantizes the float8 tensor to high precision tensor.
@@ -2322,6 +2322,6 @@ def _dequantize_affine_float8(
 def _dequantize_affine_float8_meta(
     tensor: torch.Tensor,
     scale: torch.Tensor,
-    output_dtype: torch.dtype,
+    output_dtype: torch.dtype = torch.float32,
 ) -> torch.Tensor:
     return torch.empty_like(tensor, dtype=output_dtype)
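With the defaults in place, callers can omit the dtype arguments entirely; the _meta variants exist so torch.compile and fake tensors can infer output shape and dtype without running the kernel. A rough round-trip sketch under the new defaults, assuming the ops registered under torch.ops.torchao forward to these implementations:

import torch
import torchao  # assumption: importing torchao registers the custom ops

x = torch.randn(4, 4)
scale = torch.tensor(2.0)

# float8_dtype now defaults to torch.float8_e4m3fn ...
q = torch.ops.torchao.quantize_affine_float8(x, scale)
assert q.dtype == torch.float8_e4m3fn

# ... and output_dtype defaults to torch.float32.
dq = torch.ops.torchao.dequantize_affine_float8(q, scale)
assert dq.dtype == torch.float32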

0 commit comments