 )
 import json
 from aiu_fms_testing_utils.utils.aiu_setup import dprint, aiu_dist_setup
+import shutil
 import os

 try:
@@ -795,6 +796,7 @@ def _run_cpu_aiu_validation_test(
     cpu_model,
     aiu_model,
     micro_model_path,
+    verify_cache_state=None,
 ):
     # Get the tokenizer and AIU / CPU models to compare
     tokenizer = tokenizers.get_tokenizer(model_path)
@@ -820,6 +822,12 @@ def _run_cpu_aiu_validation_test(
         aiu_model,
     )

+    # Used only for cache tests; this is a no-argument closure that
+    # should assert that the torch sendnn cache is in the correct state
+    # for this test
+    if verify_cache_state is not None:
+        verify_cache_state()
+
     # if level 0 fails validation, validate level 1
     if FORCE_VALIDATION_LEVEL_1 or failed_validation_level_0:
         if failed_validation_level_0:
@@ -841,6 +849,87 @@ def _run_cpu_aiu_validation_test(
     )


+def _reset_cache_settings(purge_cache_dir):
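+    # Enable the torch sendnn cache; the cache location is taken from
+    # TORCH_SENDNN_CACHE_DIR, which is expected to already be set in the environment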
+    os.environ["TORCH_SENDNN_CACHE_ENABLE"] = "1"
+    os.environ["COMPILATION_MODE"] = "offline_decoder"
+    cache_dir = os.environ["TORCH_SENDNN_CACHE_DIR"]
+
+    # Ensure we start in a clean state
+    if purge_cache_dir and os.path.isdir(cache_dir):
+        shutil.rmtree(cache_dir)
+        os.mkdir(cache_dir)
+
+    from torch_sendnn.backends import cache
+
+    # Explicitly clear cache paths from the global torch sendnn graph;
+    # TODO: it would be better to add a helper to torch sendnn that does this
+    # explicitly
+    cache.cache = {}
+
+
+@pytest.fixture
+def use_cached_model():
+    """Configures the torch sendnn cache and runs the AIU model prior to test
+    execution; this is computationally expensive and should only be used in
+    situations like testing cache hit correctness.
+    """
+    torch.manual_seed(42)
+    torch.set_grad_enabled(False)
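+    # Purge any existing cache contents so the run below is guaranteed to start
+    # from a cache miss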
+    _reset_cache_settings(purge_cache_dir=True)
+
+    model_path, batch_size, seq_length, max_new_tokens = _get_cache_test_params()
+    micro_model_path = MICRO_MODEL_MAPPING.get(model_path, None)
+
+    def verify_cache_miss():
+        cache_dir = os.environ.get("TORCH_SENDNN_CACHE_DIR")
+        updated_cache_len = (
+            len(os.listdir(cache_dir)) if os.path.isdir(cache_dir) else 0
+        )
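+        # NOTE: this assumes the cache writes one entry per generated token, so
+        # a fully populated cache should contain max_new_tokens entries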
+        assert updated_cache_len == max_new_tokens, (
+            "cache directory not populated on cache miss"
+        )
+
+    dprint(
+        f"Setting up cache [i.e., cache miss check] for model={model_path}, batch_size={batch_size}, seq_length={seq_length}, max_new_tokens={max_new_tokens}, micro_model={USE_MICRO_MODELS}"
+    )
+
+    # we don't currently support inferring gptq from get_model, so we must use an adapter with hf_configured
+    gptq_kwargs_aiu, gptq_kwargs_cpu = __maybe_get_gptq_kwargs(model_path)
+
+    model = _get_aiu_model(
+        model_path,
+        gptq_kwargs_aiu,
+        persistent_model_inst=None,
+    )
+
+    validation_model = _get_cpu_model(
+        model_path,
+        gptq_kwargs_cpu,
+        micro_model_state_dict=model.state_dict() if USE_MICRO_MODELS else None,
+    )
+
+    _run_cpu_aiu_validation_test(
+        model_path,
+        batch_size,
+        seq_length,
+        max_new_tokens,
+        validation_model,
+        model,
+        micro_model_path,
+        verify_cache_state=verify_cache_miss,
+    )
+
+
+def _get_cache_test_params():
+    # NOTE: currently we always use granite 3.3 for the cache test;
+    # TODO: make this configurable as tests are refactored
+    model_path = GRANITE_3p3_8B_INSTRUCT
+    batch_size = COMMON_BATCH_SIZES[0]
+    seq_length = COMMON_SEQ_LENGTHS[0]
+    max_new_tokens = COMMON_MAX_NEW_TOKENS[0]
+    return [model_path, batch_size, seq_length, max_new_tokens]
+
+
 @pytest.mark.parametrize(
     "model_path,batch_size,seq_length,max_new_tokens", common_shapes
 )
@@ -879,3 +968,51 @@ def test_common_shapes(
         model,
         micro_model_path,
     )
+
+
+def test_cache(use_cached_model):
+    torch.manual_seed(42)
+    torch.set_grad_enabled(False)
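+    # Keep the cache directory populated by the use_cached_model fixture so that
+    # this run exercises the cache hit path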
+    _reset_cache_settings(purge_cache_dir=False)
+
+    model_path, batch_size, seq_length, max_new_tokens = _get_cache_test_params()
+    micro_model_path = MICRO_MODEL_MAPPING.get(model_path, None)
+
+    def verify_cache_hit():
+        cache_dir = os.environ.get("TORCH_SENDNN_CACHE_DIR")
+        updated_cache_len = (
+            len(os.listdir(cache_dir)) if os.path.isdir(cache_dir) else 0
+        )
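+        # The fixture run should already have populated max_new_tokens cache
+        # entries; verify the count is unchanged after running against the warm cache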
+        assert updated_cache_len == max_new_tokens, (
+            "cache miss occurred when hit was expected"
+        )
+
+    dprint(
+        f"testing cache hit: model={model_path}, batch_size={batch_size}, seq_length={seq_length}, max_new_tokens={max_new_tokens}, micro_model={USE_MICRO_MODELS}"
+    )
+
+    # we don't currently support inferring gptq from get_model, so we must use an adapter with hf_configured
+    gptq_kwargs_aiu, gptq_kwargs_cpu = __maybe_get_gptq_kwargs(model_path)
+
+    model = _get_aiu_model(
+        model_path,
+        gptq_kwargs_aiu,
+        persistent_model_inst=None,
+    )
+
+    validation_model = _get_cpu_model(
+        model_path,
+        gptq_kwargs_cpu,
+        micro_model_state_dict=model.state_dict() if USE_MICRO_MODELS else None,
+    )
+
+    _run_cpu_aiu_validation_test(
+        model_path,
+        batch_size,
+        seq_length,
+        max_new_tokens,
+        validation_model,
+        model,
+        micro_model_path,
+        verify_cache_state=verify_cache_hit,
+    )