Commit 6243040

Add test case generator for groupwise low bit LUT based quantization (#2359)
* Add test case generator for groupwise low bit LUT based quantization kernel
* Add granularity to LUT and scale generation in test cases
* Update LUT test case generation: scale_group_size and lut_group_size control the frequency of group change
* Add has_scales tag to the LUT test case generation
1 parent 01b43cb commit 6243040
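For context on the grouping parameters mentioned above: the generated test data models a LUT-based dequantization in which each weight stores a low-bit index into a lookup table, the active LUT changes every lut_group_size weights, and the active scale changes every scale_group_size weights. Below is a minimal sketch of that reference mapping; the helper function and its name are illustrative only and are not part of the commit.

#include <cstdint>
#include <vector>

// Illustrative reference mapping assumed by the generated test data (not part of the commit).
// weight_idx runs over the n * k weights in row-major (n, k) order.
float dequantize_weight(
    int weight_idx,
    const std::vector<uint8_t>& qval_indices, // low-bit index per weight
    const std::vector<float>& luts,           // contiguous LUTs, each of size 1 << weight_nbit
    const std::vector<float>& scales,         // one scale per scale group
    int scale_group_size, int lut_group_size, int weight_nbit) {
  const int lut_size = 1 << weight_nbit;
  const int scale_idx = weight_idx / scale_group_size; // scale switches every scale_group_size weights
  const int lut_idx = weight_idx / lut_group_size;     // LUT switches every lut_group_size weights
  return scales[scale_idx] * luts[lut_idx * lut_size + qval_indices[weight_idx]];
}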

1 file changed: +186 -0 lines changed

torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h

Lines changed: 186 additions & 0 deletions
@@ -575,6 +575,192 @@ struct lowbit_embedding_test_case {
  }
};

struct groupwise_lowbit_weight_lut_test_case {
  //--------------------------------------------------------------------------
  // Parameters
  //--------------------------------------------------------------------------
  int m, k, n;
  int scale_group_size;
  int lut_group_size;
  int weight_nbit;
  bool has_scales, has_bias, has_clamp;
  float clamp_min, clamp_max;

  //--------------------------------------------------------------------------
  // Data Tensors
  //--------------------------------------------------------------------------
  std::vector<float> expected_output;
  std::vector<float> activations;
  std::vector<float> bias;
  std::vector<uint8_t> weight_qval_indices; // Indices into a LUT for each weight
  std::vector<float> weight_luts;           // The pool of unique LUTs
  std::vector<float> weight_scales;         // The pool of unique scales

  //--------------------------------------------------------------------------
  // Constructor
  //--------------------------------------------------------------------------
  groupwise_lowbit_weight_lut_test_case(
      int m_, int k_, int n_,
      int scale_group_size_, int lut_group_size_, int weight_nbit_,
      bool has_scales_, bool has_bias_, bool has_clamp_,
      float clamp_min_, float clamp_max_,
      std::vector<float> expected_output_, std::vector<float> activations_,
      std::vector<float> bias_, std::vector<uint8_t> weight_qval_indices_,
      std::vector<float> weight_luts_, std::vector<float> weight_scales_)
      : m(m_), k(k_), n(n_),
        scale_group_size(scale_group_size_), lut_group_size(lut_group_size_),
        weight_nbit(weight_nbit_), has_scales(has_scales_),
        has_bias(has_bias_), has_clamp(has_clamp_),
        clamp_min(clamp_min_), clamp_max(clamp_max_),
        expected_output(expected_output_),
        activations(activations_),
        bias(bias_),
        weight_qval_indices(weight_qval_indices_),
        weight_luts(weight_luts_),
        weight_scales(weight_scales_) {}

  //--------------------------------------------------------------------------
  // Generator Functions (Factories)
  //--------------------------------------------------------------------------

 private:
  /**
   * @brief The private "master" generator that provides maximum flexibility.
   *
   * This function is the core engine. It takes the exact number of scales and
   * LUTs to generate and constructs the test case. All other public generators
   * are wrappers around this one.
   */
  static groupwise_lowbit_weight_lut_test_case _generate_master(
      int m, int k, int n,
      int scale_group_size, // Directly controls scale change frequency
      int lut_group_size,   // Directly controls LUT change frequency
      int weight_nbit, bool has_scales,
      bool has_bias, bool has_clamp) {

    // --- 0. Validation and Setup ---
    const int total_weights = n * k;
    // Frequencies are controlled by their group sizes.
    assert(total_weights % scale_group_size == 0);
    assert(total_weights % lut_group_size == 0);

    // The number of unique scales/LUTs is derived directly from their group size.
    const int num_scales = total_weights / scale_group_size;
    const int num_luts = total_weights / lut_group_size;
    const int lut_size = 1 << weight_nbit;
    std::mt19937 gen(std::random_device{}());

    // --- 1. Generate Primary Inputs ---
    auto activations = get_random_vector(m * k, -1.0f, 1.0f);
    std::vector<float> bias_vec(n, 0.0f);
    if (has_bias) bias_vec = get_random_vector(n, -0.5f, 0.5f);
    float clamp_min = -std::numeric_limits<float>::infinity(),
          clamp_max = std::numeric_limits<float>::infinity();
    if (has_clamp) {
      auto r = get_random_vector(2, -5.0f, 5.0f);
      clamp_min = std::min(r[0], r[1]);
      clamp_max = std::max(r[0], r[1]);
    }

    // --- 2. Generate Quantization Data ---
    // 2a. Generate the pools of unique scales and LUTs.
    std::vector<float> weight_scales;
    if (has_scales) {
      // Normal case: generate random scales.
      weight_scales = get_random_vector(num_scales, 0.001f, 0.1f);
    } else {
      // LUT-only case: create a vector where every scale is 1.0f.
      weight_scales.assign(num_scales, 1.0f);
    }

    auto weight_luts = get_random_vector(num_luts * lut_size, -0.2f, 0.2f); // Independent random LUTs

    // 2b. Generate random quantized indices for each weight.
    auto weight_qval_indices = std::vector<uint8_t>(total_weights);
    std::uniform_int_distribution<int> qval_dis(0, lut_size - 1);
    for (int i = 0; i < total_weights; ++i)
      weight_qval_indices[i] = static_cast<uint8_t>(qval_dis(gen));

    // --- 3. Compute Expected Output using the IMPLICIT mappings ---
    std::vector<float> expected_output(m * n);
    for (int m_idx = 0; m_idx < m; ++m_idx) {
      for (int n_idx = 0; n_idx < n; ++n_idx) {
        float res = 0.0f;
        for (int k_idx = 0; k_idx < k; ++k_idx) {
          float activation_val = activations[m_idx * k + k_idx];
          int weight_idx = n_idx * k + k_idx;
          uint8_t qval_idx = weight_qval_indices[weight_idx];

          int32_t scale_idx = weight_idx / scale_group_size;
          int32_t lut_idx = weight_idx / lut_group_size;

          // Dequantize: scale * LUT_value
          float scale = weight_scales[scale_idx];
          float lut_val = weight_luts[lut_idx * lut_size + qval_idx];
          res += activation_val * (scale * lut_val);
        }
        res += bias_vec[n_idx];
        if (has_clamp) { res = std::clamp(res, clamp_min, clamp_max); }
        expected_output[m_idx * n + n_idx] = res;
      }
    }

    // --- 4. Construct and Return ---
    return groupwise_lowbit_weight_lut_test_case(
        m, k, n, scale_group_size, lut_group_size, weight_nbit, has_scales,
        has_bias, has_clamp, clamp_min, clamp_max,
        expected_output,
        activations,
        bias_vec,
        weight_qval_indices,
        weight_luts,
        weight_scales);
  }

 public:
  /**
   * @brief OVERLOAD 1: Simple generator where scales and LUTs share the same grouping.
   *
   * This is for the simplest case where a block of weights gets one scale and
   * one LUT, and this pattern repeats.
   */
  static groupwise_lowbit_weight_lut_test_case generate_per_group(
      int m, int k, int n,
      int group_size, // The size of the block for both scales and LUTs
      int weight_nbit, bool has_scales,
      bool has_bias, bool has_clamp) {

    std::cout << "[Generator Info] Using 'Per-Group' model.\n"
              << "  - Both scales and LUTs will switch every " << group_size
              << " weights." << std::endl;

    // Just call the decoupled generator with the same group size for both.
    return _generate_master(
        m, k, n,
        group_size, /* scale_group_size */
        group_size, /* lut_group_size */
        weight_nbit,
        has_scales,
        has_bias, has_clamp);
  }

  /**
   * @brief OVERLOAD 2: Advanced generator with separate grouping for scales and LUTs.
   */
  static groupwise_lowbit_weight_lut_test_case generate_with_decoupled_grouping(
      int m, int k, int n,
      int scale_group_size, int lut_group_size, int weight_nbit, bool has_scales,
      bool has_bias, bool has_clamp) {

    std::cout << "[Generator Info] Using 'Decoupled Grouping' model.\n"
              << "  - Scales will switch every " << scale_group_size << " weights.\n"
              << "  - LUTs will switch every " << lut_group_size << " weights." << std::endl;

    return _generate_master(
        m, k, n,
        scale_group_size, lut_group_size,
        weight_nbit, has_scales,
        has_bias, has_clamp);
  }
};

} // namespace torchao

#endif // defined(__aarch64__) || defined(__ARM_NEON)
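A minimal usage sketch of the new factories follows. It is a hypothetical test, not part of the commit: the test suite name, parameter values, and shape checks are illustrative, and it assumes the GoogleTest harness already used by the tests that include test_utils.h.

#include <gtest/gtest.h>
// Assumes test_utils.h (the file changed in this commit) has been included.

TEST(test_groupwise_lowbit_weight_lut, generates_consistent_shapes) {
  const int m = 4, k = 64, n = 8, weight_nbit = 4;
  // Decoupled grouping: scales switch every 32 weights, LUTs every 256 weights.
  auto tc = torchao::groupwise_lowbit_weight_lut_test_case::generate_with_decoupled_grouping(
      m, k, n,
      /*scale_group_size=*/32, /*lut_group_size=*/256, weight_nbit,
      /*has_scales=*/true, /*has_bias=*/true, /*has_clamp=*/false);

  // The generated tensors would be fed to the kernel under test and its output
  // compared against tc.expected_output.
  EXPECT_EQ(tc.activations.size(), static_cast<size_t>(m * k));
  EXPECT_EQ(tc.weight_qval_indices.size(), static_cast<size_t>(n * k));
  EXPECT_EQ(tc.weight_scales.size(), static_cast<size_t>(n * k / 32));
  EXPECT_EQ(tc.weight_luts.size(), static_cast<size_t>((n * k / 256) * (1 << weight_nbit)));
  EXPECT_EQ(tc.expected_output.size(), static_cast<size_t>(m * n));
}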
