Skip to content

Remove transpose_input from fbgemm configs #2422

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: jerryzh168/stack/1
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions test/dtypes/test_fbgemm_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ def forward(self, x):
weight = torch.randn(10, 128, 256, dtype=dtype, device=device)
m = M(weight).eval()
original = m(input)
# we need to transpose the weight first for bmm
m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous())
quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True)
quantized = m(input)
self.assertTrue(compute_error(original, quantized) > 20)
Expand Down
3 changes: 2 additions & 1 deletion test/dtypes/test_fbgemm_int4.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ def setUp(self):
weight_dtype=torch.int4,
output_dtype=torch.bfloat16,
block_size=[1, 1, 128],
transpose_input=True,
)
self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []

Expand Down Expand Up @@ -134,6 +133,8 @@ def forward(self, x):
weight = torch.randn(10, 128, 256, dtype=dtype, device=device)
m = M(weight).eval()
original = m(input)
# we need to transpose the weight first for bmm
m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous())
quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True)
quantized = m(input)
self.assertTrue(compute_error(original, quantized) > 18)
Expand Down
7 changes: 0 additions & 7 deletions torchao/dtypes/fbgemm_fp8_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ def from_float(
cls,
w: torch.Tensor,
activation_scale_ub: Optional[float] = None,
transpose_input: bool = False,
):
if activation_scale_ub is None:
activation_scale_ub = 1200.0
Expand All @@ -100,12 +99,6 @@ def from_float(
dtype=torch.float,
device=w.device,
)
if transpose_input:
if w.ndim == 3:
w = w.transpose(-1, -2)
else:
w = w.t()

wq, w_scale = torch.ops.triton.quantize_fp8_row(w)
# wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
dtype = w.dtype
Expand Down
7 changes: 0 additions & 7 deletions torchao/dtypes/fbgemm_int4_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,20 +93,13 @@ def from_float(
cls,
w: torch.Tensor,
block_size: List[int],
transpose_input: bool = False,
):
assert len(block_size) == w.ndim, (
f"Expecting the length of block_size to be equal to the dimension of the weight, got {block_size=} and {w.ndim=}"
)
if int4_row_quantize_zp is None:
raise ImportError("Requires fbgemm-gpu-genai >= 1.2.0")

if transpose_input:
if w.ndim == 3:
w = w.transpose(-1, -2)
else:
w = w.t()

group_size = block_size[-1]
original_shape = w.shape

Expand Down
2 changes: 0 additions & 2 deletions torchao/quantization/quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2046,7 +2046,6 @@ class FbgemmConfig(AOBaseConfig):
output_dtype: torch.dtype
block_size: Optional[List[int]] = None
activation_scale_ub: Optional[float] = None
transpose_input: bool = False
preshuffle: bool = False


Expand Down Expand Up @@ -2080,7 +2079,6 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module:
weight = to_fbgemm_int4(
module.weight,
config.block_size,
config.transpose_input,
)
module.weight = torch.nn.Parameter(weight, requires_grad=False)
module.extra_repr = types.MethodType(_linear_extra_repr, module)
Expand Down
Loading