@@ -575,6 +575,192 @@ struct lowbit_embedding_test_case {
575
575
}
576
576
};
577
577
578
+ struct groupwise_lowbit_weight_lut_test_case {
579
+ // --------------------------------------------------------------------------
580
+ // Parameters
581
+ // --------------------------------------------------------------------------
582
+ int m, k, n;
583
+ int scale_group_size;
584
+ int lut_group_size;
585
+ int weight_nbit;
586
+ bool has_scales, has_bias, has_clamp;
587
+ float clamp_min, clamp_max;
588
+
589
+ // --------------------------------------------------------------------------
590
+ // Data Tensors
591
+ // --------------------------------------------------------------------------
592
+ std::vector<float > expected_output;
593
+ std::vector<float > activations;
594
+ std::vector<float > bias;
595
+ std::vector<uint8_t > weight_qval_indices; // Indices into a LUT for each weight
596
+ std::vector<float > weight_luts; // The pool of unique LUTs
597
+ std::vector<float > weight_scales; // The pool of unique scales
598
+
599
+ // --------------------------------------------------------------------------
600
+ // Constructor
601
+ // --------------------------------------------------------------------------
602
+ groupwise_lowbit_weight_lut_test_case (
603
+ int m_, int k_, int n_, int scale_group_size_, int lut_group_size_, int weight_nbit_, bool has_scales_, bool has_bias_, bool has_clamp_,
604
+ float clamp_min_, float clamp_max_,
605
+ std::vector<float > expected_output_, std::vector<float > activations_,
606
+ std::vector<float > bias_, std::vector<uint8_t > weight_qval_indices_,
607
+ std::vector<float > weight_luts_, std::vector<float > weight_scales_)
608
+ : m(m_), k(k_), n(n_),
609
+ scale_group_size (scale_group_size_), lut_group_size(lut_group_size_), weight_nbit(weight_nbit_),
610
+ has_scales(has_scales_),
611
+ has_bias(has_bias_), has_clamp(has_clamp_), clamp_min(clamp_min_), clamp_max(clamp_max_),
612
+ expected_output(expected_output_),
613
+ activations(activations_),
614
+ bias(bias_),
615
+ weight_qval_indices(weight_qval_indices_),
616
+ weight_luts(weight_luts_),
617
+ weight_scales(weight_scales_)
618
+ {}
619
+
620
+ // --------------------------------------------------------------------------
621
+ // Generator Functions (Factories)
622
+ // --------------------------------------------------------------------------
623
+
624
+ private:
625
+ /* *
626
+ * @brief The private "master" generator that provides maximum flexibility.
627
+ *
628
+ * This function is the core engine. It takes the exact number of scales and LUTs
629
+ * to generate and constructs the test case. All other public generators are
630
+ * wrappers around this one.
631
+ */
632
+ static groupwise_lowbit_weight_lut_test_case _generate_master (
633
+ int m, int k, int n,
634
+ int scale_group_size, // Directly controls scale change frequency
635
+ int lut_group_size, // Directly controls LUT change frequency
636
+ int weight_nbit, bool has_scales,
637
+ bool has_bias, bool has_clamp) {
638
+
639
+ // --- 0. Validation and Setup ---
640
+ const int total_weights = n * k;
641
+ // Frequencies are controlled by their group sizes.
642
+ assert (total_weights % scale_group_size == 0 );
643
+ assert (total_weights % lut_group_size == 0 );
644
+
645
+ // The number of unique scales/LUTs is derived directly from their group size.
646
+ const int num_scales = total_weights / scale_group_size;
647
+ const int num_luts = total_weights / lut_group_size;
648
+ const int lut_size = 1 << weight_nbit;
649
+ std::mt19937 gen (std::random_device{}());
650
+
651
+ // --- 1. Generate Primary Inputs ---
652
+ auto activations = get_random_vector (m * k, -1 .0f , 1 .0f );
653
+ std::vector<float > bias_vec (n, 0 .0f );
654
+ if (has_bias) bias_vec = get_random_vector (n, -0 .5f , 0 .5f );
655
+ float clamp_min = -std::numeric_limits<float >::infinity (), clamp_max = std::numeric_limits<float >::infinity ();
656
+ if (has_clamp) {
657
+ auto r = get_random_vector (2 , -5 .0f , 5 .0f );
658
+ clamp_min = std::min (r[0 ], r[1 ]); clamp_max = std::max (r[0 ], r[1 ]);
659
+ }
660
+
661
+ // --- 2. Generate Quantization Data ---
662
+ // 2a. Generate the pools of unique scales and LUTs.
663
+ std::vector<float > weight_scales;
664
+ if (has_scales) {
665
+ // Normal case: generate random scales.
666
+ weight_scales = get_random_vector (num_scales, 0 .001f , 0 .1f );
667
+ } else {
668
+ // LUT-only case: create a vector where every scale is 1.0f.
669
+ weight_scales.assign (num_scales, 1 .0f );
670
+ }
671
+
672
+ auto weight_luts = get_random_vector (num_luts * lut_size, -0 .2f , 0 .2f ); // Independent random LUTs
673
+
674
+ // 2b. Generate random quantized indices for each weight.
675
+ auto weight_qval_indices = std::vector<uint8_t >(total_weights);
676
+ std::uniform_int_distribution<int > qval_dis (0 , lut_size - 1 );
677
+ for (int i = 0 ; i < total_weights; ++i) weight_qval_indices[i] = static_cast <uint8_t >(qval_dis (gen));
678
+
679
+ // --- 3. Compute Expected Output using the IMPLICIT mappings ---
680
+ std::vector<float > expected_output (m * n);
681
+ for (int m_idx = 0 ; m_idx < m; ++m_idx) {
682
+ for (int n_idx = 0 ; n_idx < n; ++n_idx) {
683
+ float res = 0 .0f ;
684
+ for (int k_idx = 0 ; k_idx < k; ++k_idx) {
685
+ float activation_val = activations[m_idx * k + k_idx];
686
+ int weight_idx = n_idx * k + k_idx;
687
+ uint8_t qval_idx = weight_qval_indices[weight_idx];
688
+
689
+ int32_t scale_idx = weight_idx / scale_group_size;
690
+ int32_t lut_idx = weight_idx / lut_group_size;
691
+
692
+ // Dequantize: scale * LUT_value
693
+ float scale = weight_scales[scale_idx];
694
+ float lut_val = weight_luts[lut_idx * lut_size + qval_idx];
695
+ res += activation_val * (scale * lut_val);
696
+ }
697
+ res += bias_vec[n_idx];
698
+ if (has_clamp) { res = std::clamp (res, clamp_min, clamp_max); }
699
+ expected_output[m_idx * n + n_idx] = res;
700
+ }
701
+ }
702
+
703
+ // --- 4. Construct and Return ---
704
+ return groupwise_lowbit_weight_lut_test_case (
705
+ m, k, n, scale_group_size, lut_group_size, weight_nbit, has_scales,
706
+ has_bias, has_clamp, clamp_min, clamp_max,
707
+ expected_output,
708
+ activations,
709
+ bias_vec,
710
+ weight_qval_indices,
711
+ weight_luts,
712
+ weight_scales);
713
+
714
+ }
715
+
716
+ public:
717
+ /* *
718
+ * @brief OVERLOAD 1: Simple generator where scales and LUTs share the same grouping.
719
+ *
720
+ * This is for the simplest case where a block of weights gets one scale and one LUT,
721
+ * and this pattern repeats.
722
+ */
723
+ static groupwise_lowbit_weight_lut_test_case generate_per_group (
724
+ int m, int k, int n,
725
+ int group_size, // The size of the block for both scales and LUTs
726
+ int weight_nbit, bool has_scales,
727
+ bool has_bias, bool has_clamp) {
728
+
729
+ std::cout << " [Generator Info] Using 'Per-Group' model.\n "
730
+ << " - Both scales and LUTs will switch every " << group_size << " weights." << std::endl;
731
+
732
+ // Just call the decoupled generator with the same group size for both.
733
+ return _generate_master (
734
+ m, k, n,
735
+ group_size, /* scale_group_size */
736
+ group_size, /* lut_group_size */
737
+ weight_nbit,
738
+ has_scales,
739
+ has_bias, has_clamp
740
+ );
741
+ }
742
+
743
+ /* *
744
+ * @brief OVERLOAD 2: Advanced generator with separate grouping for scales and LUTs.
745
+ */
746
+ static groupwise_lowbit_weight_lut_test_case generate_with_decoupled_grouping (
747
+ int m, int k, int n,
748
+ int scale_group_size, int lut_group_size, int weight_nbit, bool has_scales,
749
+ bool has_bias, bool has_clamp) {
750
+
751
+ std::cout << " [Generator Info] Using 'Decoupled Grouping' model.\n "
752
+ << " - Scales will switch every " << scale_group_size << " weights.\n "
753
+ << " - LUTs will switch every " << lut_group_size << " weights." << std::endl;
754
+
755
+ return _generate_master (
756
+ m, k, n,
757
+ scale_group_size, lut_group_size,
758
+ weight_nbit, has_scales,
759
+ has_bias, has_clamp
760
+ );
761
+ }
762
+ };
763
+
578
764
} // namespace torchao
579
765
580
766
#endif // defined(__aarch64__) || defined(__ARM_NEON)
0 commit comments