Skip to content

Commit acc9103

Browse files
andrewor14 authored and danielvegamyhre committed
Fix NVFP4 QAT backward typo (#3478)
**Summary:** Fix the NVFP4 QAT backward pass by replacing the `to_dtype` calls with `dequantize`. This was broken in #3169.

**Test Plan:**
```
python test/quantization/test_qat.py -k nvfp4_training
```
1 parent 286c2d8 commit acc9103

File tree

2 files changed

+36
-2
lines changed

2 files changed

+36
-2
lines changed

test/quantization/test_qat.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2149,6 +2149,40 @@ def test_qat_nvfp4(self, use_per_tensor_scale: bool):
21492149
sqnr = compute_error(out, baseline_out).item()
21502150
self.assertGreaterEqual(sqnr, float("inf"))
21512151

2152+
@unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
@parametrize("use_per_tensor_scale", [True, False])
def test_qat_nvfp4_training(self, use_per_tensor_scale: bool):
    """
    Run a short training loop on a module prepared for NVFP4 QAT.

    After every optimizer step, checks that `linear1.weight` has a
    gradient, that the gradient is not all zeros, and that the weight
    differs from its value before the step.
    """
    from torchao.prototype.mx_formats import NVFP4DynamicActivationNVFP4WeightConfig

    torch.manual_seed(self.SEED)
    model = M().cuda()
    nvfp4_config = NVFP4DynamicActivationNVFP4WeightConfig(
        use_dynamic_per_tensor_scale=use_per_tensor_scale
    )
    qat_config = QATConfig(nvfp4_config, step="prepare")
    quantize_(model, qat_config)

    # Simulate training
    total_steps = 10
    sgd = torch.optim.SGD(
        model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5
    )
    criterion = torch.nn.CrossEntropyLoss()
    for _ in range(total_steps):
        # Keep the original call order: inputs are produced before the
        # random target so the RNG stream is consumed identically.
        inputs = model.example_inputs("cuda")
        weight_before = copy.deepcopy(model.linear1.weight)
        sgd.zero_grad()
        target = torch.randn(1, 512).float().cuda()
        loss = criterion(model(*inputs), target)
        loss.backward()
        sgd.step()

        # Assert that weights have valid gradients and are being updated
        weight_after = model.linear1.weight
        self.assertIsNotNone(weight_after.grad)
        self.assertNotEqual(torch.count_nonzero(weight_after.grad), 0)
        self.assertFalse(torch.equal(weight_after, weight_before))
2185+
21522186
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
21532187
@unittest.skipIf(
21542188
not _is_fbgemm_gpu_genai_available(), "Requires fbgemm-gpu-genai >= 1.2.0"

torchao/prototype/qat/nvfp4.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,8 @@ def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
9191
_input, weight = ctx.saved_tensors
9292
assert isinstance(_input, NVFP4Tensor)
9393
assert isinstance(weight, NVFP4Tensor)
94-
_input = _input.to_dtype(_input._orig_dtype)
95-
weight = weight.to_dtype(weight._orig_dtype)
94+
_input = _input.dequantize(_input._orig_dtype)
95+
weight = weight.dequantize(weight._orig_dtype)
9696
grad_input = torch.mm(grad_output, weight)
9797
grad_weight = torch.mm(grad_output.t(), _input)
9898
return grad_input, grad_weight, None, None, None

0 commit comments

Comments
 (0)