Skip to content

Commit 63509d3

Browse files
committed
[test_decoders] add timing
Signed-off-by: kcirred <[email protected]>
1 parent f43cc04 commit 63509d3

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

tests/models/test_decoders.py

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -56,6 +56,7 @@
5656
)
5757
USE_MICRO_MODELS = os.environ.get("FMS_TEST_SHAPES_USE_MICRO_MODELS", "1") == "1"
5858
USE_DISTRIBUTED = os.environ.get("FMS_TEST_SHAPES_DISTRIBUTED", "0") == "1"
59+
TIMING = os.environ.get("TIMING", "")
5960

6061
ATTN_TYPE = os.environ.get("FMS_TEST_SHAPES_ATTN_TYPE", "sdpa")
6162
attention_map = {
@@ -342,11 +343,11 @@ def get_or_create(self, is_gptq, **kwargs):
342343

343344
if compile_dynamic_sendnn:
344345
self.model = model
345-
346+
346347
return model
347348
else:
348349
return self.model
349-
350+
350351
# TODO: This was added as we require a special reset for gptq models. Ideally, we would be able to do something like this reset when calling reset_parameters() on the model
351352
# however the gptq modules are yet to support this
352353
@staticmethod
@@ -458,6 +459,7 @@ def test_common_shapes(model_path, batch_size, seq_length, max_new_tokens, persi
458459
max_new_tokens,
459460
LogitsExtractorHook(),
460461
attn_algorithm="math",
462+
timing=TIMING,
461463
**extra_kwargs,
462464
)
463465

@@ -477,7 +479,7 @@ def test_common_shapes(model_path, batch_size, seq_length, max_new_tokens, persi
477479

478480
# first test validation level 0
479481
aiu_validation_info = extract_validation_information(
480-
model, input_ids, max_new_tokens, None, only_last_token="paged" not in ATTN_NAME, **extra_kwargs
482+
model, input_ids, max_new_tokens, None, only_last_token="paged" not in ATTN_NAME, timing=TIMING, **extra_kwargs
481483
)
482484
dprint("aiu validation info extracted for validation level 0")
483485

@@ -530,6 +532,7 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor):
530532
max_new_tokens,
531533
LogitsExtractorHook(),
532534
attn_algorithm="math",
535+
timing=TIMING,
533536
**extra_kwargs,
534537
)
535538
dprint(
@@ -556,6 +559,7 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor):
556559
max_new_tokens,
557560
GoldenTokenHook(cpu_static_tokens),
558561
only_last_token=ATTN_TYPE != "paged",
562+
timing=TIMING,
559563
**extra_kwargs,
560564
)
561565
dprint(f"aiu validation info extracted for validation level 1 - iter={i}")

0 commit comments

Comments
 (0)