Skip to content

Commit 3b71e07

Browse files
committed
Fix NVFP4 backward typo
**Summary:** Fix `to_dtype` -> `to` **Test Plan:** ``` python test/quantization/test_qat.py -k nvfp4_training ```
1 parent 31192f2 commit 3b71e07

File tree

2 files changed

+36
-2
lines changed

2 files changed

+36
-2
lines changed

test/quantization/test_qat.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2148,6 +2148,40 @@ def test_qat_nvfp4(self, use_per_tensor_scale: bool):
21482148
sqnr = compute_error(out, baseline_out).item()
21492149
self.assertGreaterEqual(sqnr, float("inf"))
21502150

2151+
@unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
2152+
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
2153+
@parametrize("use_per_tensor_scale", [True, False])
2154+
def test_qat_nvfp4_training(self, use_per_tensor_scale: bool):
2155+
from torchao.prototype.mx_formats import NVFP4DynamicActivationNVFP4WeightConfig
2156+
2157+
torch.manual_seed(self.SEED)
2158+
m = M().cuda()
2159+
base_config = NVFP4DynamicActivationNVFP4WeightConfig(
2160+
use_dynamic_per_tensor_scale=use_per_tensor_scale
2161+
)
2162+
quantize_(m, QATConfig(base_config, step="prepare"))
2163+
2164+
# Simulate training
2165+
num_steps = 10
2166+
optimizer = torch.optim.SGD(
2167+
m.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5
2168+
)
2169+
loss_fn = torch.nn.CrossEntropyLoss()
2170+
for i in range(num_steps):
2171+
example_inputs = m.example_inputs("cuda")
2172+
prev_weight = copy.deepcopy(m.linear1.weight)
2173+
optimizer.zero_grad()
2174+
target = torch.randn(1, 512).float().cuda()
2175+
out = m(*example_inputs)
2176+
loss = loss_fn(out, target)
2177+
loss.backward()
2178+
optimizer.step()
2179+
# Assert that weights have valid gradients and are being updated
2180+
new_weight = m.linear1.weight
2181+
self.assertIsNotNone(new_weight.grad)
2182+
self.assertNotEqual(torch.count_nonzero(new_weight.grad), 0)
2183+
self.assertFalse(torch.equal(new_weight, prev_weight))
2184+
21512185
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
21522186
@unittest.skipIf(
21532187
not _is_fbgemm_gpu_genai_available(), "Requires fbgemm-gpu-genai >= 1.2.0"

torchao/prototype/qat/nvfp4.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,8 @@ def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
9191
_input, weight = ctx.saved_tensors
9292
assert isinstance(_input, NVFP4Tensor)
9393
assert isinstance(weight, NVFP4Tensor)
94-
_input = _input.to_dtype(_input._orig_dtype)
95-
weight = weight.to_dtype(weight._orig_dtype)
94+
_input = _input.to(_input._orig_dtype)
95+
weight = weight.to(weight._orig_dtype)
9696
grad_input = torch.mm(grad_output, weight)
9797
grad_weight = torch.mm(grad_output.t(), _input)
9898
return grad_input, grad_weight, None, None, None

0 commit comments

Comments
 (0)