63
63
# Convert the comma-separated string into a list of integers.
common_max_new_tokens = list(map(int, common_max_new_tokens.split(",")))
64
64
65
65
# Every (model_path, batch_size, seq_length, max_new_tokens) combination,
# with the right-most dimension varying fastest.
common_shapes = [
    (shape_model_path, shape_batch_size, shape_seq_length, shape_max_new_tokens)
    for shape_model_path in common_model_paths
    for shape_batch_size in common_batch_sizes
    for shape_seq_length in common_seq_lengths
    for shape_max_new_tokens in common_max_new_tokens
]
66
# Parameters for the cache test: the first entry of each common list,
# paired once with "miss" and once with "hit" (in that order).
cache_params = [
    (
        common_model_paths[0],
        common_batch_sizes[0],
        common_seq_lengths[0],
        common_max_new_tokens[0],
        cache_state,
    )
    for cache_state in ("miss", "hit")
]
66
67
67
68
# thresholds are chosen based on 1024 tokens per sequence
68
69
# 1% error threshold rate between cpu fp32 and cuda fp16
@@ -78,6 +79,7 @@ def reset_compiler():
78
79
torch .compiler .reset ()
79
80
torch ._dynamo .reset ()
80
81
os .environ .pop ('COMPILATION_MODE' , None )
82
+ os .environ .pop ('TORCH_SENDNN_CACHE_ENABLE' , None )
81
83
if ORIGINAL_HF_HOME is None :
82
84
os .environ .pop ('HF_HOME' , None )
83
85
else :
@@ -287,5 +289,56 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor):
287
289
else :
288
290
print ("passed validation level 0" )
289
291
292
@pytest.mark.parametrize("model_path,batch_size,seq_length,max_new_tokens,cache_status", cache_params)
def test_cache(model_path, batch_size, seq_length, max_new_tokens, cache_status):
    """Compile and run a model with TORCH_SENDNN_CACHE_ENABLE switched on.

    The test is parametrized over ``cache_params``. It loads the model on CPU,
    compiles it with the ``sendnn_decoder`` backend, warms it up and extracts
    validation information for the last token.

    ``cache_status`` is only used in the log message here; this function does
    not assert anything about the outcome.
    NOTE(review): presumably the "miss" run fills the cache and the "hit" run
    reuses it - confirm, and consider asserting on it.
    """
    torch.manual_seed(42)

    # Environment switches for this run: enable the cache and select the
    # offline decoder compilation mode.
    os.environ["TORCH_SENDNN_CACHE_ENABLE"] = "1"
    os.environ["COMPILATION_MODE"] = "offline_decoder"

    dprint(f"testing with cache: model={model_path} , batch_size={batch_size} , seq_length={seq_length} , max_new_tokens={max_new_tokens} , micro_model={USE_MICRO_MODELS} , cache={cache_status} ")

    # Micro models are built from the HF config with 3 layers; otherwise the
    # pretrained architecture is requested.
    micro_model_kwargs = (
        {"architecture": "hf_configured", "nlayers": 3}
        if USE_MICRO_MODELS
        else {"architecture": "hf_pretrained"}
    )

    # Pass model_path as a variant name unless it is an existing local path
    # and micro models are not in use.
    if USE_MICRO_MODELS or not os.path.exists(model_path):
        model_path_kwargs = {"variant": model_path}
    else:
        model_path_kwargs = {"model_path": model_path}

    distributed_kwargs = (
        {"distr_param": "tp", "group": dist.group.WORLD} if USE_DISTRIBUTED else {}
    )

    get_model_kwargs = {}
    for kwargs_part in (model_path_kwargs, micro_model_kwargs, distributed_kwargs):
        get_model_kwargs.update(kwargs_part)

    tokenizer = tokenizers.get_tokenizer(model_path)

    # Build the model on CPU, then compile it with the sendnn_decoder backend.
    model = get_model(device_type="cpu", fused_weights=False, **get_model_kwargs)
    model.eval()
    torch.set_grad_enabled(False)
    model.compile(backend="sendnn_decoder")

    # Prepare input_ids and the accompanying padding kwargs.
    input_ids, padding_kwargs = __prepare_inputs(batch_size, seq_length, tokenizer)

    # Warm up the compiled model before collecting validation information.
    warmup_model(model, input_ids, max_new_tokens, **padding_kwargs)

    # AIU validation (last token only).
    aiu_validation_info = extract_validation_information(
        model, input_ids, max_new_tokens, None, only_last_token=True, **padding_kwargs
    )
0 commit comments