
Commit 6a6fe67

Update cache test, add validation for cached run
1 parent 87b3ac8 commit 6a6fe67

File tree

1 file changed: +243 −28 lines changed

1 file changed

+243
-28
lines changed

tests/models/test_decoders.py

Lines changed: 243 additions & 28 deletions
@@ -24,7 +24,7 @@
 )
 import json
 from aiu_fms_testing_utils.utils.aiu_setup import dprint, aiu_dist_setup
-
+import shutil
 import os

 try:
@@ -148,7 +148,6 @@
 os.environ["VLLM_DT_MAX_CONTEXT_LEN"] = str((((max(common_seq_lengths) + max(common_max_new_tokens)) // 64) + 1) * 64)
 os.environ["VLLM_DT_MAX_BATCH_SIZE"] = str(max(common_batch_sizes))

-cache_params = list(itertools.product([common_model_paths[0]], [common_batch_sizes[0]], [common_seq_lengths[0]], [common_max_new_tokens[0]], ["miss", "hit"]))

 # thresholds are chosen based on 1024 tokens per sequence
 # 1% error threshold rate between cpu fp32 and cuda fp16
@@ -182,8 +181,7 @@
     USE_MICRO_MODELS = False
     common_model_paths = []
     frequency = int(model_configuration_frequency)
-    with open(model_configuration_path, 'r') as f:
-        for line in f:
+    for line in f:
         try:
             model_config = json.loads(line)
             if model_config["frequency"] <= frequency:
@@ -426,7 +424,7 @@ def test_common_shapes(model_path, batch_size, seq_length, max_new_tokens, persistent_model):

     # prepare the AIU model
     model = persistent_model.get_or_create(is_gptq, **gptq_kwargs_aiu, **get_model_kwargs)
-
+
     # prepare the cpu model
     validation_model = get_model(
         device_type="cpu",
@@ -555,6 +553,7 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor):
             model,
             input_ids,
             max_new_tokens,
+            max_new_tokens,
             GoldenTokenHook(cpu_static_tokens),
             only_last_token=ATTN_TYPE != "paged",
             **extra_kwargs,
@@ -622,56 +621,272 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor):
     else:
         print("passed validation level 0")

-@pytest.mark.parametrize("model_path,batch_size,seq_length,max_new_tokens,cache_status", cache_params)
-def test_cache(model_path, batch_size, seq_length, max_new_tokens, cache_status):
+@pytest.mark.parametrize("cache_status", ["miss", "hit"])
+def test_cache(cache_status):
     torch.manual_seed(42)
+    torch.set_grad_enabled(False)
     os.environ["TORCH_SENDNN_CACHE_ENABLE"] = "1"
+    os.environ["TORCH_SENDNN_CACHE_DIR"] = os.getcwd()+"/.cache"
     os.environ["COMPILATION_MODE"] = "offline_decoder"

+    if cache_status == "miss" and os.path.isdir(os.getcwd()+"/.cache"):
+        # Remove cache from previous runs
+        shutil.rmtree(os.getcwd()+"/.cache")
+
+    model_path = "ibm-granite/granite-3.3-8b-instruct"
+    batch_size = common_batch_sizes[0]
+    seq_length = common_seq_lengths[0]
+    max_new_tokens = common_max_new_tokens[0]
+
     dprint(f"testing with cache: model={model_path}, batch_size={batch_size}, seq_length={seq_length}, max_new_tokens={max_new_tokens}, micro_model={USE_MICRO_MODELS}, cache={cache_status}")

-    if USE_MICRO_MODELS:
+    # we don't currently support inferring gptq from get_model, so we must use an adapter with hf_configured
+    gptq_kwargs_aiu, gptq_kwargs_cpu = __maybe_get_gptq_kwargs(model_path)
+    is_gptq = len(gptq_kwargs_aiu) != 0
+
+    micro_model_path = micro_model_mapping.get(model_path, None)
+    if USE_MICRO_MODELS and micro_model_path is None:
+        dprint("using randomly initialized model")
         micro_model_kwargs = {"architecture": "hf_configured", "nlayers": 3}
     else:
-        micro_model_kwargs = {"architecture": "hf_pretrained"}
-
+        dprint("using trained model")
+        micro_model_kwargs = {"architecture": "hf_pretrained"}
+
     if not USE_MICRO_MODELS and os.path.exists(model_path):
         model_path_kwargs = {"model_path": model_path}
+    elif USE_MICRO_MODELS and micro_model_path is not None:
+        model_path_kwargs = {"model_path": micro_model_path}
     else:
         model_path_kwargs = {"variant": model_path}
-
+
     distributed_kwargs = {}
     if USE_DISTRIBUTED:
-        distributed_kwargs["distr_param"] = "tp"
+        distributed_kwargs["distributed_strategy"] = "tp"
         distributed_kwargs["group"] = dist.group.WORLD
-    get_model_kwargs = {**model_path_kwargs, **micro_model_kwargs, **distributed_kwargs}
+
+    get_model_kwargs = {}
+    if not is_gptq:
+        get_model_kwargs = {
+            **model_path_kwargs,
+            **micro_model_kwargs,
+            **distributed_kwargs,
+        }

     tokenizer = tokenizers.get_tokenizer(model_path)

     # prepare the AIU model
     model = get_model(
+        device_type="cpu",
+        data_type=None if is_gptq else torch.float16,
+        fused_weights=False,
+        **get_model_kwargs,
+    )
+
+    model.eval()
+    model.compile(backend="sendnn")
+
+    # prepare the cpu model
+    validation_model = get_model(
         device_type="cpu",
+        data_type=None if is_gptq else torch.float32,
         fused_weights=False,
-        **get_model_kwargs
+        **gptq_kwargs_cpu,
+        **get_model_kwargs,
     )

-    model.eval()
-    torch.set_grad_enabled(False)
-    model.compile(backend="sendnn_decoder")
-
+    if USE_MICRO_MODELS:
+        serialization.load_state_dict_into_model(
+            validation_model, model.state_dict(), **__custom_adapter
+        )

     # prepare input_ids
-    input_ids, padding_kwargs = __prepare_inputs(batch_size, seq_length, tokenizer)
+    input_ids, extra_kwargs = __prepare_inputs(batch_size, seq_length, tokenizer)
+    extra_kwargs["attn_name"] = ATTN_NAME

     # warmup aiu model
-    warmup_model(model, input_ids, max_new_tokens, **padding_kwargs)
+    warmup_model(model, input_ids, max_new_tokens, compile_dynamic_sendnn, **extra_kwargs)
+
+    # generate cpu validation info
+    cpu_validation_info = __load_validation_info(
+        model_path, batch_size, seq_length, max_new_tokens, tokenizer, 0
+    )
+    if cpu_validation_info is None:
+        cpu_validation_info = extract_validation_information(
+            validation_model,
+            input_ids,
+            max_new_tokens,
+            LogitsExtractorHook(),
+            attn_algorithm="math",
+            **extra_kwargs,
+        )

-    # aiu validatation
+        if save_validation_info_outputs:
+            cpu_validation_info.save(
+                __get_validation_info_full_path(
+                    model_path, batch_size, seq_length, max_new_tokens, 0
+                )
+            )
+    cpu_static_tokens = cpu_validation_info.get_info("tokens")
+    eos_indexes = __find_eos_index(
+        cpu_static_tokens, tokenizer.eos_token_id, seq_length, max_new_tokens
+    )
+    dprint(
+        "cpu validation info extracted for validation level 0 and validation level 1 (iter=0)"
+    )
+
+    # first test validation level 0
     aiu_validation_info = extract_validation_information(
-        model,
-        input_ids,
-        max_new_tokens,
-        None,
-        only_last_token=True,
-        **padding_kwargs
-    )
+        model, input_ids, max_new_tokens, None, only_last_token="paged" not in ATTN_NAME, **extra_kwargs
+    )
+    dprint("aiu validation info extracted for validation level 0")
+
+    # check cache status before validating cached results
+    updated_cache_len = len(os.listdir(os.getcwd()+"/.cache")) if os.path.isdir(os.getcwd()+"/.cache") else 0
+    if cache_status == "miss":
+        assert updated_cache_len == max_new_tokens, (
+            "cache directory not populated on cache miss"
+        )
+        return
+    else:
+        assert updated_cache_len == max_new_tokens, (
+            "cache miss occurred when hit was expected"
+        )
+
+    # validate level 0
+    failed_responses = validate_level_0(
+        aiu_validation_info.get_info("tokens"), cpu_static_tokens
+    )
+
+    failed_validation_level_0 = len(failed_responses) != 0
+
+    # if level 0 fails validation, validate level 1
+    if FORCE_VALIDATION_LEVEL_1 or failed_validation_level_0:
+
+        if failed_validation_level_0:
+            dprint("failed validation level 0, testing validation level 1")
+        else:
+            dprint("passed validation level 0, testing validation level 1")
+
+        # metric calculator based on the cross-entropy and mean diff for each decode step
+        def _metric_calculator(r: torch.Tensor, t: torch.Tensor):
+            cross_entropy = torch.nn.CrossEntropyLoss()(
+                r, t.softmax(dim=1).to(dtype=torch.float32)
+            )
+            diff = torch.mean(
+                torch.abs(
+                    r.softmax(dim=1).to(dtype=torch.float32)
+                    - t.softmax(dim=1).to(dtype=torch.float32)
+                )
+            )
+            return (cross_entropy, diff)
+
+        iters = 1024 // max_new_tokens
+        ce_fail_responses_list = []
+        diff_fail_responses_list = []
+        total_tokens = 0
+        for i in range(iters):
+            # for iteration 0, we have computed the cpu validation info in the prior step for seed=0, so skip
+            if i != 0:
+                input_ids, extra_kwargs = __prepare_inputs(
+                    batch_size, seq_length, tokenizer, seed=i
+                )
+                extra_kwargs["attn_name"] = ATTN_NAME
+                cpu_validation_info = __load_validation_info(
+                    model_path, batch_size, seq_length, max_new_tokens, tokenizer, i
+                )
+                if cpu_validation_info is None:
+                    cpu_validation_info = extract_validation_information(
+                        validation_model,
+                        input_ids,
+                        max_new_tokens,
+                        LogitsExtractorHook(),
+                        attn_algorithm="math",
+                        **extra_kwargs,
+                    )
+                    dprint(
+                        f"cpu validation info extracted for validation level 1 - iter={i}"
+                    )
+                    if save_validation_info_outputs:
+                        cpu_validation_info.save(
+                            __get_validation_info_full_path(
+                                model_path, batch_size, seq_length, max_new_tokens, i
+                            )
+                        )
+                cpu_static_tokens = cpu_validation_info.get_info("tokens")
+                eos_indexes = __find_eos_index(
+                    cpu_static_tokens,
+                    tokenizer.eos_token_id,
+                    seq_length,
+                    max_new_tokens,
+                )
+
+            # generate aiu validation info
+            aiu_validation_info = extract_validation_information(
+                model,
+                input_ids,
+                max_new_tokens,
+                GoldenTokenHook(cpu_static_tokens),
+                only_last_token=ATTN_TYPE != "paged",
+                **extra_kwargs,
+            )
+            dprint(f"aiu validation info extracted for validation level 1 - iter={i}")
+            if save_validation_info_outputs:
+                aiu_validation_info.save(
+                    __get_validation_info_full_path(
+                        model_path, batch_size, seq_length, max_new_tokens, i, "aiu"
+                    )
+                )
+
+            # capture all level 1 metrics
+            level_1_metrics = capture_level_1_metrics(
+                cpu_validation_info.get_info("logits"),
+                aiu_validation_info.get_info("logits"),
+                top_k_loss_calculator(20, _metric_calculator),
+            )
+            # only consider those metrics captured prior to the eos
+            level_1_metrics = __filter_before_eos(level_1_metrics, eos_indexes)
+
+            # if we do not have real model weights, use a default_metrics_threshold
+            if USE_MICRO_MODELS and micro_model_path is None:
+                ce_threshold, diff_threshold = default_metrics_threshold
+            # if we have real weights, try and get the proper validation metrics threshold
+            else:
+                # if we have a micro model with real weights, but no real thresholds, default to the full model thresholds
+                if USE_MICRO_MODELS:
+                    ce_threshold, diff_threshold = fail_thresholds.get(
+                        (model_path, True), fail_thresholds.get((model_path, False), default_metrics_threshold)
+                    )
+                else:
+                    ce_threshold, diff_threshold = fail_thresholds.get(
+                        (model_path, False), default_metrics_threshold
+                    )
+
+            # get all failed responses for each metric
+            ce_fail_responses = filter_failed_level_1_cases(
+                level_1_metrics, lambda m: m[0] >= ce_threshold
+            )
+            diff_fail_responses = filter_failed_level_1_cases(
+                level_1_metrics,
+                lambda m: m[1] >= diff_threshold,
+            )
+
+            ce_fail_responses_list.extend(ce_fail_responses)
+            diff_fail_responses_list.extend(diff_fail_responses)
+            total_tokens += len(level_1_metrics)
+
+        # test the failure rates for across all tokens
+        diff_failure_rate = len(diff_fail_responses_list) / total_tokens
+        ce_failure_rate = len(ce_fail_responses_list) / total_tokens
+        dprint(f"mean diff failure rate: {diff_failure_rate}")
+        dprint(f"cross entropy loss failure rate: {ce_failure_rate}")
+        if "mean_diff" not in skip_assertions:
+            assert diff_failure_rate < failure_rate_threshold, (
+                f"failure rate for mean diff was too high: {diff_failure_rate}"
+            )
+        if "ce" not in skip_assertions:
+            assert ce_failure_rate < failure_rate_threshold, (
+                f"failure rate for cross entropy loss was too high: {ce_failure_rate}"
+            )
+        print("passed validation level 1")
+    else:
+        print("passed validation level 0")
