@@ -117,6 +117,9 @@ def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize,
117117 for i in range (iters ):
118118 A1 = torch .randn (1024 , 1024 , device = device , dtype = dtype )
119119 C , S = F .quantize_blockwise (A1 , blocksize = blocksize , nested = nested )
120+ if i == 0 :
121+ d = S .as_dict ()
122+ S = F .QuantState .from_dict (d , device = torch .device (device ))
120123 A2 = F .dequantize_blockwise (C , S )
121124 diff = torch .abs (A1 - A2 ).float ()
122125 reldiff = diff / torch .abs (A1 .float () + 1e-8 )
@@ -133,6 +136,9 @@ def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize,
133136 for i in range (iters ):
134137 A1 = torch .rand (1024 , 1024 , device = device , dtype = dtype )
135138 C , S = F .quantize_blockwise (A1 , blocksize = blocksize , nested = nested , code = code )
139+ if i == 0 :
140+ d = S .as_dict ()
141+ S = F .QuantState .from_dict (d , device = torch .device (device ))
136142 A2 = F .dequantize_blockwise (C , S )
137143 diff = torch .abs (A1 - A2 ).float ()
138144 reldiff = diff / torch .abs (A1 .float () + 1e-8 )
@@ -242,6 +248,9 @@ def test_fp8_quant(self, device):
242248 for i in range (100 ):
243249 A1 = torch .randn (1024 , 1024 , device = device )
244250 C , SC = F .quantize_blockwise (A1 , code = code )
251+ if i == 0 :
252+ d = SC .as_dict ()
253+ SC = F .QuantState .from_dict (d , device = torch .device (device ))
245254 A2 = F .dequantize_blockwise (C , SC )
246255 diff = torch .abs (A1 - A2 )
247256 reldiff = diff / torch .abs (A1 + 1e-8 )
@@ -1115,6 +1124,8 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize):
11151124
11161125 A1 = torch .randn (1024 , 1024 , device = device , dtype = dtype )
11171126 qa , SA = F .quantize_4bit (A1 , blocksize = blocksize , quant_type = quant_type )
1127+ d = SA .as_dict ()
1128+ SA = F .QuantState .from_dict (d , device = torch .device (device ))
11181129 A2 = F .dequantize_4bit (qa , SA , blocksize = blocksize , quant_type = quant_type )
11191130
11201131 err = (A1 - A2 ).abs ().float ()
0 commit comments