Commit ed83e0a

Remove transpose_input from fbgemm configs

Summary: This is actually not needed, since people can manually transpose the weights beforehand.

Test Plan:
```
python test/dtypes/test_fbgemm_fp8.py -k test_bmm
python test/dtypes/test_fbgemm_int4.py -k test_bmm
```

Reviewers:

Subscribers:

Tasks:

Tags:

stack-info: PR: #2422, branch: jerryzh168/stack/2
1 parent 027648f commit ed83e0a

5 files changed: +4 −17 lines changed
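With `transpose_input` gone, a caller who quantizes a bmm weight is expected to transpose it before calling `quantize_`. Below is a minimal sketch of that workflow, mirroring the updated tests in this commit (requires CUDA and fbgemm-gpu-genai, like the tests); the toy `M` module and the `input_dtype` field are assumptions not shown in this diff:

```python
import torch

from torchao.quantization.quant_api import FbgemmConfig, quantize_


# Toy bmm module, assumed to match the M used in the tests below.
class M(torch.nn.Module):
    def __init__(self, weight: torch.Tensor):
        super().__init__()
        self.weight = torch.nn.Parameter(weight, requires_grad=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.bmm(x, self.weight)


weight = torch.randn(10, 128, 256, dtype=torch.bfloat16, device="cuda")
m = M(weight).eval()

# The config no longer transposes for you: turn the (B, K, N) bmm weight
# into (B, N, K) up front, exactly as the updated tests do.
m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous())

bmm_config = FbgemmConfig(
    input_dtype=torch.bfloat16,  # assumption: field not shown in this diff
    weight_dtype=torch.int4,
    output_dtype=torch.bfloat16,
    block_size=[1, 1, 128],
)
quantize_(m, bmm_config, filter_fn=lambda mod, fqn: True)

out = m(torch.randn(10, 32, 128, dtype=torch.bfloat16, device="cuda"))
```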

test/dtypes/test_fbgemm_fp8.py

Lines changed: 2 additions & 0 deletions

@@ -128,6 +128,8 @@ def forward(self, x):
         weight = torch.randn(10, 128, 256, dtype=dtype, device=device)
         m = M(weight).eval()
         original = m(input)
+        # we need to transpose the weight first for bmm
+        m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous())
         quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True)
         quantized = m(input)
         self.assertTrue(compute_error(original, quantized) > 20)

test/dtypes/test_fbgemm_int4.py

Lines changed: 2 additions & 1 deletion

@@ -39,7 +39,6 @@ def setUp(self):
             weight_dtype=torch.int4,
             output_dtype=torch.bfloat16,
             block_size=[1, 1, 128],
-            transpose_input=True,
         )
         self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []

@@ -134,6 +133,8 @@ def forward(self, x):
         weight = torch.randn(10, 128, 256, dtype=dtype, device=device)
         m = M(weight).eval()
         original = m(input)
+        # we need to transpose the weight first for bmm
+        m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous())
         quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True)
         quantized = m(input)
         self.assertTrue(compute_error(original, quantized) > 18)

torchao/dtypes/fbgemm_fp8_tensor.py

Lines changed: 0 additions & 7 deletions

@@ -90,7 +90,6 @@ def from_float(
         cls,
         w: torch.Tensor,
         activation_scale_ub: Optional[float] = None,
-        transpose_input: bool = False,
     ):
         if activation_scale_ub is None:
             activation_scale_ub = 1200.0

@@ -100,12 +99,6 @@ def from_float(
             dtype=torch.float,
             device=w.device,
         )
-        if transpose_input:
-            if w.ndim == 3:
-                w = w.transpose(-1, -2)
-            else:
-                w = w.t()
-
         wq, w_scale = torch.ops.triton.quantize_fp8_row(w)
         # wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
         dtype = w.dtype

torchao/dtypes/fbgemm_int4_tensor.py

Lines changed: 0 additions & 7 deletions

@@ -93,20 +93,13 @@ def from_float(
         cls,
         w: torch.Tensor,
         block_size: List[int],
-        transpose_input: bool = False,
     ):
         assert len(block_size) == w.ndim, (
             f"Expecting the length of block_size to be equal to the dimension of the weight, got {block_size=} and {w.ndim=}"
         )
         if int4_row_quantize_zp is None:
             raise ImportError("Requires fbgemm-gpu-genai >= 1.2.0")

-        if transpose_input:
-            if w.ndim == 3:
-                w = w.transpose(-1, -2)
-            else:
-                w = w.t()
-
         group_size = block_size[-1]
         original_shape = w.shape
torchao/quantization/quant_api.py

Lines changed: 0 additions & 2 deletions

@@ -2000,7 +2000,6 @@ class FbgemmConfig(AOBaseConfig):
     output_dtype: torch.dtype
     block_size: Optional[List[int]] = None
     activation_scale_ub: Optional[float] = None
-    transpose_input: bool = False
     preshuffle: bool = False


@@ -2032,7 +2031,6 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module:
     weight = to_fbgemm_int4(
         module.weight,
         config.block_size,
-        config.transpose_input,
     )
     module.weight = torch.nn.Parameter(weight, requires_grad=False)
     module.extra_repr = types.MethodType(_linear_extra_repr, module)
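
For code that previously passed `transpose_input=True`, the removed branch is easy to reproduce on the caller side before quantization. A small illustrative sketch follows; the helper name is ours, not a library API, and `.contiguous()` is added to match the updated tests:

```python
import torch


def transpose_weight_for_fbgemm(w: torch.Tensor) -> torch.Tensor:
    # Mirrors the branch removed from both from_float implementations:
    # batched (3D) weights swap their last two dims, 2D weights get a plain transpose.
    if w.ndim == 3:
        return w.transpose(-1, -2).contiguous()
    return w.t().contiguous()


# A (B, K, N) bmm weight becomes (B, N, K) before it is handed to quantize_.
w = torch.randn(10, 128, 256, dtype=torch.bfloat16)
assert transpose_weight_for_fbgemm(w).shape == (10, 256, 128)
```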
