@@ -2148,6 +2148,40 @@ def test_qat_nvfp4(self, use_per_tensor_scale: bool):
         sqnr = compute_error(out, baseline_out).item()
         self.assertGreaterEqual(sqnr, float("inf"))
 
+    @unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
+    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @parametrize("use_per_tensor_scale", [True, False])
+    def test_qat_nvfp4_training(self, use_per_tensor_scale: bool):
+        from torchao.prototype.mx_formats import NVFP4DynamicActivationNVFP4WeightConfig
+
+        torch.manual_seed(self.SEED)
+        m = M().cuda()
+        base_config = NVFP4DynamicActivationNVFP4WeightConfig(
+            use_dynamic_per_tensor_scale=use_per_tensor_scale
+        )
+        quantize_(m, QATConfig(base_config, step="prepare"))
+
+        # Simulate training
+        num_steps = 10
+        optimizer = torch.optim.SGD(
+            m.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5
+        )
+        loss_fn = torch.nn.CrossEntropyLoss()
+        for i in range(num_steps):
+            example_inputs = m.example_inputs("cuda")
+            prev_weight = copy.deepcopy(m.linear1.weight)
+            optimizer.zero_grad()
+            target = torch.randn(1, 512).float().cuda()
+            out = m(*example_inputs)
+            loss = loss_fn(out, target)
+            loss.backward()
+            optimizer.step()
+            # Assert that weights have valid gradients and are being updated
+            new_weight = m.linear1.weight
+            self.assertIsNotNone(new_weight.grad)
+            self.assertNotEqual(torch.count_nonzero(new_weight.grad), 0)
+            self.assertFalse(torch.equal(new_weight, prev_weight))
+
     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
     @unittest.skipIf(
         not _is_fbgemm_gpu_genai_available(), "Requires fbgemm-gpu-genai >= 1.2.0"