@@ -794,6 +794,44 @@ def test_cyclic_kv_cache_beam_search(self):
                  ])
 
 
+class TestMistral7B(CliFlowAccuracyTestHarness):
+    MODEL_NAME = "mistralai/Mistral-7B-v0.1"
+    MODEL_PATH = f"{llm_models_root()}/mistral-7b-v0.1"
+    EXAMPLE_FOLDER = "models/core/llama"
+
+    @skip_pre_blackwell
+    def test_beam_search(self):
+        self.run(extra_acc_spec="beam_width=4",
+                 extra_build_args=["--gemm_plugin=auto", "--max_beam_width=4"],
+                 extra_summarize_args=["--num_beams=4"])
+        import gc
+
+        import torch
+        for num_beams in [1, 2]:
+            gc.collect()
+            torch.cuda.empty_cache()
+            self.extra_acc_spec = f"beam_width={num_beams}"
+            self.extra_summarize_args = [f"--num_beams={num_beams}"]
+            self.evaluate()
+
+    @skip_pre_ada
+    @pytest.mark.skip_less_device(8)
+    def test_fp8_tp4pp2(self):
+        self.run(quant_algo=QuantAlgo.FP8,
+                 tp_size=4,
+                 pp_size=2,
+                 extra_convert_args=["--calib_size=4"],
+                 extra_build_args=["--gemm_plugin=auto"])
+
+    @skip_post_blackwell
+    @pytest.mark.skip_less_device(4)
+    def test_smooth_quant_tp4pp1(self):
+        self.run(quant_algo=QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN,
+                 tp_size=4,
+                 pp_size=1,
+                 extra_build_args=["--gemm_plugin=auto"])
+
+
 class TestMixtral8x7B(CliFlowAccuracyTestHarness):
     MODEL_NAME = "mistralai/Mixtral-8x7B-v0.1"
     MODEL_PATH = f"{llm_models_root()}/Mixtral-8x7B-v0.1"
@@ -804,6 +842,43 @@ class TestMixtral8x7B(CliFlowAccuracyTestHarness):
     def test_tp2(self):
         self.run(dtype='auto', tp_size=2)
 
+    @skip_post_blackwell
+    @pytest.mark.skip_less_device(8)
+    @pytest.mark.skip_less_device_memory(45000)
+    @pytest.mark.parametrize(
+        "moe_tp_size", [1, 4, 8],
+        ids=['expert_parallel', 'mixed_parallel', 'tensor_parallel'])
+    def test_ootb_except_mha_tp8(self, moe_tp_size):
+        self.run(tp_size=8,
+                 extra_convert_args=[
+                     f"--moe_tp_size={moe_tp_size}",
+                     f"--moe_ep_size={8 // moe_tp_size}",
+                     f"--moe_renorm_mode={0}"
+                 ],
+                 extra_build_args=[
+                     "--gemm_plugin=disable", "--moe_plugin=disable",
+                     f"--max_seq_len={8192}"
+                 ])
+
+    @pytest.mark.skip_less_device(8)
+    @pytest.mark.skip_less_device_memory(45000)
+    @pytest.mark.parametrize(
+        "moe_tp_size", [1, 4, 8],
+        ids=['expert_parallel', 'mixed_parallel', 'tensor_parallel'])
+    @pytest.mark.parametrize("moe_renorm_mode", [0, 1],
+                             ids=['no_renormalize', 'renormalize'])
+    def test_plugin_tp8(self, moe_tp_size, moe_renorm_mode):
+        self.run(tp_size=8,
+                 extra_convert_args=[
+                     f"--moe_tp_size={moe_tp_size}",
+                     f"--moe_ep_size={8 // moe_tp_size}",
+                     f"--moe_renorm_mode={moe_renorm_mode}"
+                 ],
+                 extra_build_args=[
+                     "--gemm_plugin=auto", "--moe_plugin=auto",
+                     f"--max_seq_len={8192}"
+                 ])
+
     @skip_pre_ada
     @pytest.mark.skip_less_device(2)
     @pytest.mark.skip_less_device_memory(80000)
@@ -835,6 +910,43 @@ def test_fp8_tp2pp2_manage_weights(self):
                  pp_size=2,
                  extra_build_args=["--fast_build"])
 
+    @pytest.mark.skip_less_device(2)
+    @pytest.mark.skip_less_device_memory(80000)
+    def test_weight_only_int4_tp2(self):
+        self.run(quant_algo=QuantAlgo.W4A16,
+                 tp_size=2,
+                 extra_build_args=["--gemm_plugin=auto"])
+
+    @pytest.mark.skip_less_device(2)
+    @pytest.mark.skip_less_device_memory(80000)
+    def test_weight_only_int8_tp2(self):
+        self.run(quant_algo=QuantAlgo.W8A16,
+                 tp_size=2,
+                 extra_build_args=["--gemm_plugin=auto"])
+
+    @skip_post_blackwell
+    @pytest.mark.skip_less_device(4)
+    @pytest.mark.skip_less_device_memory(45000)
+    def test_pp_reduce_scatter_tp2pp2(self):
+        self.run(quant_algo=QuantAlgo.W8A16,
+                 tp_size=2,
+                 pp_size=2,
+                 extra_build_args=[
+                     "--gemm_plugin=auto", "--pp_reduce_scatter=enable"
+                 ])
+
+    @skip_pre_blackwell
+    @pytest.mark.skip_less_device_memory(180000)
+    def test_fp4_plugin(self):
+        build_args = [
+            "--max_input_len=2048", "--gemm_plugin=nvfp4",
+            "--use_paged_context_fmha=enable", "--use_fp8_context_fmha=enable"
+        ]
+        self.run(tasks=[MMLU(self.MODEL_NAME)],
+                 quant_algo=QuantAlgo.NVFP4,
+                 kv_cache_quant_algo=QuantAlgo.FP8,
+                 extra_build_args=build_args)
+
     @skip_pre_blackwell
     def test_nvfp4_prequantized(self, mocker):
         mocker.patch.object(
@@ -845,6 +957,45 @@ def test_nvfp4_prequantized(self, mocker):
                  kv_cache_quant_algo=QuantAlgo.FP8)
 
 
+class TestMixtral8x22B(CliFlowAccuracyTestHarness):
+    MODEL_NAME = "mistralai/Mixtral-8x22B-v0.1"
+    MODEL_PATH = f"{llm_models_root()}/Mixtral-8x22B-v0.1"
+    EXAMPLE_FOLDER = "models/core/llama"
+
+    @skip_pre_ada
+    @pytest.mark.skip_less_device(4)
+    @pytest.mark.skip_less_device_memory(80000)
+    def test_fp8_tp2pp2(self):
+        self.run(tasks=[CnnDailymail(self.MODEL_NAME),
+                        MMLU(self.MODEL_NAME)],
+                 quant_algo=QuantAlgo.FP8,
+                 tp_size=2,
+                 pp_size=2,
+                 extra_convert_args=["--calib_size=32"],
+                 extra_build_args=["--gemm_plugin=auto"])
+
+    @skip_post_blackwell
+    @pytest.mark.skip_less_device(8)
+    @pytest.mark.skip_less_device_memory(45000)
+    @pytest.mark.parametrize(
+        "moe_tp_size", [1, 4, 8],
+        ids=['expert_parallel', 'mixed_parallel', 'tensor_parallel'])
+    @pytest.mark.parametrize("moe_renorm_mode", [0, 1],
+                             ids=['no_renormalize', 'renormalize'])
+    def test_int8_plugin_tp8(self, moe_tp_size, moe_renorm_mode):
+        self.run(quant_algo=QuantAlgo.W8A16,
+                 tp_size=8,
+                 extra_convert_args=[
+                     f"--moe_tp_size={moe_tp_size}",
+                     f"--moe_ep_size={8 // moe_tp_size}",
+                     f"--moe_renorm_mode={moe_renorm_mode}"
+                 ],
+                 extra_build_args=[
+                     "--max_beam_width=4", "--gemm_plugin=auto",
+                     "--moe_plugin=auto", f"--max_seq_len={8192}"
+                 ])
+
+
 class TestGemma2B(CliFlowAccuracyTestHarness):
     MODEL_NAME = "google/gemma-2b"
     MODEL_PATH = f"{llm_models_root()}/gemma/gemma-2b"