Commit d9e267b

Commit message:
Update
[ghstack-poisoned]

2 parents: 967ea76 + 6922733

File tree: 49 files changed, +5376 / -1613 lines


.github/workflows/float8nocompile_test.yaml

Lines changed: 0 additions & 53 deletions
This file was deleted.

.github/workflows/torchao_experimental_test.yml

Lines changed: 4 additions & 9 deletions
@@ -37,10 +37,8 @@ jobs:
           # of torch and torchao, which we do not want to use
           pip install executorch
           pip install torch==2.7.0.dev20250311 --index-url "https://download.pytorch.org/whl/nightly/cpu" --force-reinstall
-          pip install numpy
-          pip install pytest
-          pip install parameterized
-          USE_CPP=1 TOCHAO_BUILD_KLEIDIAI=1 pip install .
+          pip install -r dev-requirements.txt
+          USE_CPP=1 TORCHAO_BUILD_KLEIDIAI=1 pip install .
       - name: Run python tests
         run: |
           conda activate venv
@@ -99,11 +97,8 @@ jobs:
           python -c "import torch; print(torch.__version__)"
       - name: Install requirements
         run: |
-          pip install cmake
-          pip install parameterized
-          pip install pyyaml
-          pip install numpy
-          pip install importlib-metadata
+          pip install -r dev-requirements.txt
+          pip install pyyaml importlib-metadata
       - name: Print pip freeze
         run: |
           pip freeze

dev-requirements.txt

Lines changed: 3 additions & 0 deletions
@@ -26,6 +26,9 @@ importlib_metadata
 # Custom CUDA Extensions
 ninja

+# CPU kernels
+cmake<4.0.0,>=3.19.0
+
 # Linting
 ruff==0.6.8
 pre-commit

scripts/clean_release_notes.py

Lines changed: 1 addition & 1 deletion
@@ -223,7 +223,7 @@ def format_commit(commit_line: str) -> str:
     After: * Commit title (https://github.com/pytorch/ao/pull/123)
     """
     # Remove author, put PR link in parentheses
-    commit_line = re.sub(" by @.* in (.*)", r" (\\g<1>)", commit_line)
+    commit_line = re.sub(" by @.* in (.*)", r" (\g<1>)", commit_line)
     # Capitalize first letter
     commit_line = commit_line.lstrip("* ")
     commit_line = "* " + commit_line[0].upper() + commit_line[1:]
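Why this one-character change matters: in a raw string, "\\g<1>" is a literal backslash followed by the text "g<1>", so re.sub copies it into the output verbatim instead of substituting capture group 1; r" (\g<1>)" is the actual group reference. A minimal sketch of the difference (the sample commit line below is hypothetical):

import re

# Hypothetical release-notes line in the format format_commit() processes
line = "* Fix a bug by @someuser in https://github.com/pytorch/ao/pull/123"

# Old replacement: r" (\\g<1>)" contains a literal backslash, so the PR link is dropped
print(re.sub(" by @.* in (.*)", r" (\\g<1>)", line))
# -> * Fix a bug (\g<1>)

# New replacement: r" (\g<1>)" refers to capture group 1, keeping the PR link
print(re.sub(" by @.* in (.*)", r" (\g<1>)", line))
# -> * Fix a bug (https://github.com/pytorch/ao/pull/123)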

test/quantization/test_galore_quant.py

Lines changed: 2 additions & 0 deletions
@@ -38,6 +38,7 @@


 @pytest.mark.skip("skipping for now, see comments below")
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
 @pytest.mark.parametrize(
     "dim1,dim2,dtype,signed,blocksize",
     TEST_CONFIGS,
@@ -89,6 +90,7 @@ def test_galore_quantize_blockwise(dim1, dim2, dtype, signed, blocksize):
     TEST_CONFIGS,
 )
 @skip_if_rocm("ROCm enablement in progress")
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
 def test_galore_dequant_blockwise(dim1, dim2, dtype, signed, blocksize):
     g = torch.randn(dim1, dim2, device="cuda", dtype=dtype) * 0.01

test/quantization/test_qat.py

Lines changed: 71 additions & 0 deletions
@@ -133,6 +133,18 @@ def forward(self, x):
         return x


+class M4(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear = torch.nn.Linear(512, 256, bias=False).to(torch.float)
+
+    def example_inputs(self):
+        return (torch.randn(1, 512).to(torch.float),)
+
+    def forward(self, x):
+        return self.linear(x)
+
+
 class ModelWithLinearBias(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -1389,6 +1401,65 @@ def test_qat_linear_bias(self):
         example_inputs = m.example_inputs()
         m(*example_inputs)

+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
+    def test_fake_quantize_per_token_vs_convert(self):
+        """
+        Test that the following produce the exact same numerics:
+          1. FakeQuantizer with asymmetric per_token config
+          2. torchao.quantization.utils.per_token_dynamic_quant
+        """
+        from torchao.quantization.utils import per_token_dynamic_quant
+
+        torch.manual_seed(self.SEED)
+        x = torch.randn(1, 235, 2048)
+        config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
+        fake_quantizer = FakeQuantizer(config)
+        fake_quantizer_out = fake_quantizer(x)
+        baseline_out = per_token_dynamic_quant(x)
+        torch.testing.assert_close(fake_quantizer_out, baseline_out, atol=0, rtol=0)
+
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
+    def test_qat_8da4w_prepare_vs_convert(self):
+        """
+        Test that the prepare and convert steps of Int8DynActInt4QATQuantizer produces
+        numerics that match exactly over N trials.
+        """
+        from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
+        from torchao.quantization.utils import compute_error
+
+        num_trials = 1000
+        group_size = 16
+        non_inf_sqnr = []
+
+        for seed in range(self.SEED, self.SEED + num_trials):
+            torch.manual_seed(seed)
+            m = M4()
+            torch.manual_seed(seed)
+            x = m.example_inputs()
+
+            quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
+            prepared = quantizer.prepare(m)
+            prepared_out = prepared(*x)
+            converted = quantizer.convert(prepared)
+            converted_out = converted(*x)
+            sqnr = compute_error(prepared_out, converted_out).item()
+            if sqnr != float("inf"):
+                non_inf_sqnr.append(sqnr)
+
+        avg_sqnr = (
+            sum(non_inf_sqnr) / len(non_inf_sqnr) if len(non_inf_sqnr) > 0 else -1
+        )
+        fail_message = "%s/%s trials did not match exactly, average sqnr = %s" % (
+            len(non_inf_sqnr),
+            num_trials,
+            avg_sqnr,
+        )
+        self.assertEqual(len(non_inf_sqnr), 0, fail_message)
+

 if __name__ == "__main__":
     unittest.main()
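For context on the second new test: compute_error reports an SQNR (signal-to-quantization-noise ratio), and an infinite value means the prepared and converted outputs match bit-exactly, which is what the test asserts across all trials. A minimal sketch of the usual dB definition, for illustration only (not torchao's exact implementation):

import torch

def sqnr_db(reference: torch.Tensor, candidate: torch.Tensor) -> float:
    # 20 * log10(||signal|| / ||noise||); a zero-norm error yields inf,
    # i.e. the two tensors match exactly
    noise = reference - candidate
    return (20 * torch.log10(reference.norm() / noise.norm())).item()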

torchao/_executorch_ops.py

Lines changed: 2 additions & 0 deletions
@@ -5,6 +5,8 @@
 # LICENSE file in the root directory of this source tree.
 import torch

+# TODO: delete these ops
+

 def _quantized_decomposed_quantize_per_channel_group_wrapper(*args, **kwargs):
     """

torchao/csrc/cuda/fp6_llm/fp6_linear.cu

Lines changed: 38 additions & 14 deletions
@@ -21,6 +21,7 @@
 //
 // MODIFICATION NOTE (2024-09-25): added SM75 support (https://github.com/pytorch/ao/pull/942):
 // - Modified the TilingConfig parameters for SM75 to deal with smaller shared memory
+// - Added proper architecture check at both host and device level
 //


@@ -98,7 +99,24 @@ void fpx_linear_kernel(cudaStream_t stream,
     static_assert(std::is_same<InputDataType, half>::value || std::is_same<InputDataType, __nv_bfloat16>::value, "Type must be 'half' or '__nv_bfloat16'");
     assert(M_Global % 256 == 0);
     assert(K_Global % 64 == 0);
-    assert(N_Global>0);
+    assert(N_Global > 0);
+
+    // Check GPU Compute Capability before proceeding
+    int device, major, minor;
+    CHECK_CUDA(cudaGetDevice(&device));
+    CHECK_CUDA(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device));
+    CHECK_CUDA(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device));
+
+    // Early exit with error for unsupported architectures
+    if ((major < 7) || (major == 7 && minor < 5)) {
+        TORCH_CHECK(false, "Quant-LLM Error: This kernel requires GPU with SM75 (Turing) or higher architecture. "
+                           "Your current device has SM", major, minor, " which is not supported.");
+    }
+
+    const bool is_sm75_gpu = (major == 7) && (minor == 5);
+    if (is_sm75_gpu && std::is_same<InputDataType, __nv_bfloat16>::value) {
+        TORCH_CHECK(false, "Quant-LLM Error: BFloat16 inputs are not supported on SM75 (Turing) GPUs.");
+    }

     // Work around to support more N shapes:
     size_t N_PowerOf2;
@@ -109,17 +127,6 @@ void fpx_linear_kernel(cudaStream_t stream,
     if(N_Global>64 && N_Global<=128) N_PowerOf2 = 128;
     if(N_Global>128) N_PowerOf2 = ((N_Global-1)/128+1) * 128;

-    // Check GPU Compute Capability
-    int device, major, minor;
-    CHECK_CUDA(cudaGetDevice(&device));
-    CHECK_CUDA(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device));
-    CHECK_CUDA(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device));
-    const bool is_sm75_gpu = (major == 7) && (minor == 5);
-    if (is_sm75_gpu && std::is_same<InputDataType, __nv_bfloat16>::value)
-        TORCH_CHECK(false, "Bfloat16 inputs are not supported for SM75");
-    if ((major < 7) || (major == 7 && minor < 5))
-        TORCH_CHECK(false, "FP6LLM_API Error: FP6LLM requires GPU with SM75 or higher!\n");
-
     if (is_sm75_gpu && (N_PowerOf2 == 64 || N_PowerOf2 == 128 || N_PowerOf2 % 128 == 0)) {
         // For SM75 and N >= 64, we use a different TilingConfig to deal with smaller shared memory.
         if (Split_K == 1) {
@@ -136,7 +143,7 @@ void fpx_linear_kernel(cudaStream_t stream,
             case 64:  Kernel_Ex<TilingConfig<4, 1, 8>, InputDataType, InputDataType, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
             case 128: Kernel_Ex<TilingConfig<4, 1, 8>, InputDataType, InputDataType, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
             default:  if (N_PowerOf2 % 128 != 0) {
-                          TORCH_CHECK(false, "FP6LLM_API Error: Unsupported N dimension ", N_PowerOf2);
+                          TORCH_CHECK(false, "Quant-LLM Error: Unsupported N dimension ", N_PowerOf2);
                      }
                      Kernel_Ex<TilingConfig<4, 1, 8>, InputDataType, InputDataType, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
         }
@@ -149,7 +156,7 @@ void fpx_linear_kernel(cudaStream_t stream,
             case 64:  Kernel_Ex<TilingConfig<4, 1, 8>, InputDataType, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
             case 128: Kernel_Ex<TilingConfig<4, 1, 8>, InputDataType, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
             default:  if (N_PowerOf2 % 128 != 0) {
-                          TORCH_CHECK(false, "FP6LLM_API Error: Unsupported N dimension ", N_PowerOf2);
+                          TORCH_CHECK(false, "Quant-LLM Error: Unsupported N dimension ", N_PowerOf2);
                      }
                      Kernel_Ex<TilingConfig<4, 1, 8>, InputDataType, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
         }
@@ -210,6 +217,23 @@ torch::Tensor fp_eXmY_linear_forward_cuda(
     torch::Tensor _scales,
     int64_t splitK=1)
 {
+    // Check GPU Compute Capability before proceeding
+    int device, major, minor;
+    CHECK_CUDA(cudaGetDevice(&device));
+    CHECK_CUDA(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device));
+    CHECK_CUDA(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device));
+
+    // Early exit with error for unsupported architectures
+    if ((major < 7) || (major == 7 && minor < 5)) {
+        TORCH_CHECK(false, "Quant-LLM Error: This kernel requires GPU with SM75 (Turing) or higher architecture. "
+                           "Your current device has SM", major, minor, " which is not supported.");
+    }
+
+    const bool is_sm75_gpu = (major == 7) && (minor == 5);
+    if (is_sm75_gpu && _in_feats.scalar_type() == at::ScalarType::BFloat16) {
+        TORCH_CHECK(false, "Quant-LLM Error: BFloat16 inputs are not supported on SM75 (Turing) GPUs.");
+    }
+
     const int64_t NBITS = 1 + EXPONENT + MANTISSA;
     int num_in_feats = _in_feats.size(0);
     int num_in_channels = _in_feats.size(1);
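Because the capability check now raises from the host-side entry point, callers that might run on pre-Turing GPUs can probe the device first. A hedged sketch using PyTorch's public API (the helper name is ours, not part of torchao):

import torch

def device_supports_quant_llm(device: int = 0) -> bool:
    # The Quant-LLM kernel requires SM75 (Turing) or newer;
    # on SM75 itself, bfloat16 inputs are additionally rejected
    if not torch.cuda.is_available():
        return False
    major, minor = torch.cuda.get_device_capability(device)
    return (major, minor) >= (7, 5)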

torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh

Lines changed: 14 additions & 5 deletions
@@ -51,17 +51,14 @@
  * B: col major, FP16
  * C: col major, FP16
  */
-template<typename TilingConfig, typename InputDataType, typename OutputDataType, int EXPONENT, int MANTISSA>
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
+template<typename TilingConfig, typename InputDataType, typename OutputDataType, int EXPONENT, int MANTISSA>
 __global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales,
                                   const half *B,
                                   OutputDataType* C,
                                   const size_t M_Global, const size_t N_Global, const size_t K_Global,
                                   int Split_K)
 {
-  #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
-    static_assert(false, "Quant-LLM kernel: At least Turing generation (sm75) is required.");
-    // __trap(); // fails at runtime instead of compile time
-  #endif
   #ifdef DEBUG_MODE
     assert(K_Global%TilingConfig::TILE_K==0);
     assert(M_Global%TilingConfig::TILE_M==0);
@@ -233,3 +230,15 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales,
         }
     }
 }
+#else
+// Stub implementation for older architectures
+template<typename TilingConfig, typename InputDataType, typename OutputDataType, int EXPONENT, int MANTISSA>
+__global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales,
+                                  const half *B,
+                                  OutputDataType* C,
+                                  const size_t M_Global, const size_t N_Global, const size_t K_Global,
+                                  int Split_K)
+{
+    // NOOP, should never actually be called
+}
+#endif

torchao/experimental/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -40,6 +40,7 @@ include_directories(${TORCHAO_INCLUDE_DIRS})
 if(TORCHAO_BUILD_CPU_AARCH64)
     message(STATUS "Building with cpu/aarch64")
     add_compile_definitions(TORCHAO_BUILD_CPU_AARCH64)
+    add_compile_definitions(TORCHAO_ENABLE_ARM_NEON_DOT)

     # Defines torchao_kernels_aarch64
     add_subdirectory(kernels/cpu/aarch64)
