56
56
)
57
57
USE_MICRO_MODELS = os .environ .get ("FMS_TEST_SHAPES_USE_MICRO_MODELS" , "1" ) == "1"
58
58
USE_DISTRIBUTED = os .environ .get ("FMS_TEST_SHAPES_DISTRIBUTED" , "0" ) == "1"
59
+ TIMING = os .environ .get ("TIMING" , "" )
59
60
60
61
ATTN_TYPE = os .environ .get ("FMS_TEST_SHAPES_ATTN_TYPE" , "sdpa" )
61
62
attention_map = {
@@ -342,11 +343,11 @@ def get_or_create(self, is_gptq, **kwargs):
342
343
343
344
if compile_dynamic_sendnn :
344
345
self .model = model
345
-
346
+
346
347
return model
347
348
else :
348
349
return self .model
349
-
350
+
350
351
# TODO: This was added as we require a special reset for gptq models. Ideally, we would be able to do something like this reset when calling reset_parameters() on the model
351
352
# however the gptq modules are yet to support this
352
353
@staticmethod
@@ -458,6 +459,7 @@ def test_common_shapes(model_path, batch_size, seq_length, max_new_tokens, persi
458
459
max_new_tokens ,
459
460
LogitsExtractorHook (),
460
461
attn_algorithm = "math" ,
462
+ timing = TIMING ,
461
463
** extra_kwargs ,
462
464
)
463
465
@@ -477,7 +479,7 @@ def test_common_shapes(model_path, batch_size, seq_length, max_new_tokens, persi
477
479
478
480
# first test validation level 0
479
481
aiu_validation_info = extract_validation_information (
480
- model , input_ids , max_new_tokens , None , only_last_token = "paged" not in ATTN_NAME , ** extra_kwargs
482
+ model , input_ids , max_new_tokens , None , only_last_token = "paged" not in ATTN_NAME , timing = TIMING , ** extra_kwargs
481
483
)
482
484
dprint ("aiu validation info extracted for validation level 0" )
483
485
@@ -530,6 +532,7 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor):
530
532
max_new_tokens ,
531
533
LogitsExtractorHook (),
532
534
attn_algorithm = "math" ,
535
+ timing = TIMING ,
533
536
** extra_kwargs ,
534
537
)
535
538
dprint (
@@ -556,6 +559,7 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor):
556
559
max_new_tokens ,
557
560
GoldenTokenHook (cpu_static_tokens ),
558
561
only_last_token = ATTN_TYPE != "paged" ,
562
+ timing = TIMING ,
559
563
** extra_kwargs ,
560
564
)
561
565
dprint (f"aiu validation info extracted for validation level 1 - iter={ i } " )
0 commit comments