@@ -2149,6 +2149,40 @@ def test_qat_nvfp4(self, use_per_tensor_scale: bool):
         sqnr = compute_error(out, baseline_out).item()
         self.assertGreaterEqual(sqnr, float("inf"))
 
+    @unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
+    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @parametrize("use_per_tensor_scale", [True, False])
+    def test_qat_nvfp4_training(self, use_per_tensor_scale: bool):
+        from torchao.prototype.mx_formats import NVFP4DynamicActivationNVFP4WeightConfig
+
+        torch.manual_seed(self.SEED)
+        m = M().cuda()
+        base_config = NVFP4DynamicActivationNVFP4WeightConfig(
+            use_dynamic_per_tensor_scale=use_per_tensor_scale
+        )
+        quantize_(m, QATConfig(base_config, step="prepare"))
+
+        # Simulate training
+        num_steps = 10
+        optimizer = torch.optim.SGD(
+            m.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5
+        )
+        loss_fn = torch.nn.CrossEntropyLoss()
+        for i in range(num_steps):
+            example_inputs = m.example_inputs("cuda")
+            prev_weight = copy.deepcopy(m.linear1.weight)
+            optimizer.zero_grad()
+            target = torch.randn(1, 512).float().cuda()
+            out = m(*example_inputs)
+            loss = loss_fn(out, target)
+            loss.backward()
+            optimizer.step()
+            # Assert that weights have valid gradients and are being updated
+            new_weight = m.linear1.weight
+            self.assertIsNotNone(new_weight.grad)
+            self.assertNotEqual(torch.count_nonzero(new_weight.grad), 0)
+            self.assertFalse(torch.equal(new_weight, prev_weight))
+
     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
     @unittest.skipIf(
         not _is_fbgemm_gpu_genai_available(), "Requires fbgemm-gpu-genai >= 1.2.0"
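For context, a minimal sketch of the QAT flow this new test exercises: prepare the model with fake NVFP4 quantization, run a normal training loop, then convert for inference. The toy model, the import paths for quantize_ and QATConfig, and the final "convert" step are assumptions based on torchao's documented QAT workflow and are not part of this diff.

import torch
from torchao.quantization import quantize_
from torchao.quantization.qat import QATConfig
from torchao.prototype.mx_formats import NVFP4DynamicActivationNVFP4WeightConfig

# Toy model and config; names and sizes are illustrative only.
model = torch.nn.Sequential(torch.nn.Linear(512, 512)).cuda()
base_config = NVFP4DynamicActivationNVFP4WeightConfig(use_dynamic_per_tensor_scale=True)

# Prepare: insert fake quantization so the forward pass sees NVFP4 rounding
# while weights and gradients stay in high precision for the optimizer.
quantize_(model, QATConfig(base_config, step="prepare"))

# ... run the usual forward / backward / optimizer.step() loop, as the test above does ...

# Convert: swap fake quantization for real NVFP4 quantization for inference
# (assumed follow-up step; not exercised by the test in this diff).
quantize_(model, QATConfig(base_config, step="convert"))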