24 | 24 | ) |
25 | 25 | import json |
26 | 26 | from aiu_fms_testing_utils.utils.aiu_setup import dprint, aiu_dist_setup |
27 | | - |
| 27 | +import shutil |
28 | 28 | import os |
29 | 29 |
30 | 30 | try: |
148 | 148 | os.environ["VLLM_DT_MAX_CONTEXT_LEN"] = str((((max(common_seq_lengths) + max(common_max_new_tokens)) // 64) + 1) * 64) |
149 | 149 | os.environ["VLLM_DT_MAX_BATCH_SIZE"] = str(max(common_batch_sizes)) |
150 | 150 |
151 | | -cache_params = list(itertools.product([common_model_paths[0]], [common_batch_sizes[0]], [common_seq_lengths[0]], [common_max_new_tokens[0]], ["miss", "hit"])) |
152 | 151 |
153 | 152 | # thresholds are chosen based on 1024 tokens per sequence |
154 | 153 | # 1% error threshold rate between cpu fp32 and cuda fp16 |
182 | 181 | USE_MICRO_MODELS = False |
183 | 182 | common_model_paths = [] |
184 | 183 | frequency = int(model_configuration_frequency) |
185 | 184 | with open(model_configuration_path, 'r') as f: |
186 | 185 | for line in f: |
187 | 185 | try: |
188 | 186 | model_config = json.loads(line) |
189 | 187 | if model_config["frequency"] <= frequency: |
@@ -426,7 +424,7 @@ def test_common_shapes(model_path, batch_size, seq_length, max_new_tokens, persi |
426 | 424 |
427 | 425 | # prepare the AIU model |
428 | 426 | model = persistent_model.get_or_create(is_gptq, **gptq_kwargs_aiu, **get_model_kwargs) |
429 | | - |
| 427 | + |
430 | 428 | # prepare the cpu model |
431 | 429 | validation_model = get_model( |
432 | 430 | device_type="cpu", |
@@ -555,6 +553,7 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor): |
555 | 553 | model, |
556 | 554 | input_ids, |
557 | 555 | max_new_tokens, |
558 | 557 | GoldenTokenHook(cpu_static_tokens), |
559 | 558 | only_last_token=ATTN_TYPE != "paged", |
560 | 559 | **extra_kwargs, |
@@ -622,56 +621,272 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor): |
622 | 621 | else: |
623 | 622 | print("passed validation level 0") |
624 | 623 |
625 | | -@pytest.mark.parametrize("model_path,batch_size,seq_length,max_new_tokens,cache_status", cache_params) |
626 | | -def test_cache(model_path, batch_size, seq_length, max_new_tokens, cache_status): |
| 624 | +@pytest.mark.parametrize("cache_status", ["miss", "hit"]) |
| 625 | +def test_cache(cache_status): |
627 | 626 | torch.manual_seed(42) |
| 627 | + torch.set_grad_enabled(False) |
628 | 628 | os.environ["TORCH_SENDNN_CACHE_ENABLE"] = "1" |
| 629 | + os.environ["TORCH_SENDNN_CACHE_DIR"] = os.getcwd()+"/.cache" |
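| | + # the cache lives in ./.cache under the current working directory; the "miss" case below starts from an empty directory, while the "hit" case is expected to reuse what the miss run wrote |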
629 | 630 | os.environ["COMPILATION_MODE"] = "offline_decoder" |
630 | 631 |
| 632 | + if cache_status == "miss" and os.path.isdir(os.getcwd()+"/.cache"): |
| 633 | + # Remove cache from previous runs |
| 634 | + shutil.rmtree(os.getcwd()+"/.cache") |
| 635 | + |
| 636 | + model_path = "ibm-granite/granite-3.3-8b-instruct" |
| 637 | + batch_size = common_batch_sizes[0] |
| 638 | + seq_length = common_seq_lengths[0] |
| 639 | + max_new_tokens = common_max_new_tokens[0] |
| 640 | + |
631 | 641 | dprint(f"testing with cache: model={model_path}, batch_size={batch_size}, seq_length={seq_length}, max_new_tokens={max_new_tokens}, micro_model={USE_MICRO_MODELS}, cache={cache_status}") |
632 | 642 |
633 | | - if USE_MICRO_MODELS: |
| 643 | + # we don't currently support inferring gptq from get_model, so we must use an adapter with hf_configured |
| 644 | + gptq_kwargs_aiu, gptq_kwargs_cpu = __maybe_get_gptq_kwargs(model_path) |
| 645 | + is_gptq = len(gptq_kwargs_aiu) != 0 |
| 646 | + |
| 647 | + micro_model_path = micro_model_mapping.get(model_path, None) |
| 648 | + if USE_MICRO_MODELS and micro_model_path is None: |
| 649 | + dprint("using randomly initialized model") |
634 | 650 | micro_model_kwargs = {"architecture": "hf_configured", "nlayers": 3} |
635 | 651 | else: |
636 | | - micro_model_kwargs = {"architecture": "hf_pretrained"} |
637 | | - |
| 652 | + dprint("using trained model") |
| 653 | + micro_model_kwargs = {"architecture": "hf_pretrained"} |
| 654 | + |
638 | 655 | if not USE_MICRO_MODELS and os.path.exists(model_path): |
639 | 656 | model_path_kwargs = {"model_path": model_path} |
| 657 | + elif USE_MICRO_MODELS and micro_model_path is not None: |
| 658 | + model_path_kwargs = {"model_path": micro_model_path} |
640 | 659 | else: |
641 | 660 | model_path_kwargs = {"variant": model_path} |
642 | | - |
| 661 | + |
643 | 662 | distributed_kwargs = {} |
644 | 663 | if USE_DISTRIBUTED: |
645 | | - distributed_kwargs["distr_param"] = "tp" |
| 664 | + distributed_kwargs["distributed_strategy"] = "tp" |
646 | 665 | distributed_kwargs["group"] = dist.group.WORLD |
647 | | - get_model_kwargs = {**model_path_kwargs, **micro_model_kwargs, **distributed_kwargs} |
| 666 | + |
| 667 | + get_model_kwargs = {} |
| 668 | + if not is_gptq: |
| 669 | + get_model_kwargs = { |
| 670 | + **model_path_kwargs, |
| 671 | + **micro_model_kwargs, |
| 672 | + **distributed_kwargs, |
| 673 | + } |
648 | 674 |
649 | 675 | tokenizer = tokenizers.get_tokenizer(model_path) |
650 | 676 |
651 | 677 | # prepare the AIU model |
652 | 678 | model = get_model( |
| 679 | + device_type="cpu", |
| 680 | + data_type=None if is_gptq else torch.float16, |
| 681 | + fused_weights=False, |
| 682 | + **get_model_kwargs, |
| 683 | + ) |
| 684 | + |
| 685 | + model.eval() |
| 686 | + model.compile(backend="sendnn") |
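| | + # note: compilation is lazy, so the actual sendnn compile is expected to happen on first use during warmup below |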
| 687 | + |
| 688 | + # prepare the cpu model |
| 689 | + validation_model = get_model( |
653 | 690 | device_type="cpu", |
| 691 | + data_type=None if is_gptq else torch.float32, |
654 | 692 | fused_weights=False, |
655 | | - **get_model_kwargs |
| 693 | + **gptq_kwargs_cpu, |
| 694 | + **get_model_kwargs, |
656 | 695 | ) |
657 | 696 |
658 | | - model.eval() |
659 | | - torch.set_grad_enabled(False) |
660 | | - model.compile(backend="sendnn_decoder") |
661 | | - |
| 697 | + if USE_MICRO_MODELS: |
| 698 | + serialization.load_state_dict_into_model( |
| 699 | + validation_model, model.state_dict(), **__custom_adapter |
| 700 | + ) |
662 | 701 |
663 | 702 | # prepare input_ids |
664 | | - input_ids, padding_kwargs = __prepare_inputs(batch_size, seq_length, tokenizer) |
| 703 | + input_ids, extra_kwargs = __prepare_inputs(batch_size, seq_length, tokenizer) |
| 704 | + extra_kwargs["attn_name"] = ATTN_NAME |
665 | 705 |
666 | 706 | # warmup aiu model |
667 | | - warmup_model(model, input_ids, max_new_tokens, **padding_kwargs) |
| 707 | + warmup_model(model, input_ids, max_new_tokens, compile_dynamic_sendnn, **extra_kwargs) |
| 708 | + |
| 709 | + # generate cpu validation info |
| 710 | + cpu_validation_info = __load_validation_info( |
| 711 | + model_path, batch_size, seq_length, max_new_tokens, tokenizer, 0 |
| 712 | + ) |
| 713 | + if cpu_validation_info is None: |
| 714 | + cpu_validation_info = extract_validation_information( |
| 715 | + validation_model, |
| 716 | + input_ids, |
| 717 | + max_new_tokens, |
| 718 | + LogitsExtractorHook(), |
| 719 | + attn_algorithm="math", |
| 720 | + **extra_kwargs, |
| 721 | + ) |
668 | 722 |
669 | | - # aiu validatation |
| 723 | + if save_validation_info_outputs: |
| 724 | + cpu_validation_info.save( |
| 725 | + __get_validation_info_full_path( |
| 726 | + model_path, batch_size, seq_length, max_new_tokens, 0 |
| 727 | + ) |
| 728 | + ) |
| 729 | + cpu_static_tokens = cpu_validation_info.get_info("tokens") |
| 730 | + eos_indexes = __find_eos_index( |
| 731 | + cpu_static_tokens, tokenizer.eos_token_id, seq_length, max_new_tokens |
| 732 | + ) |
| 733 | + dprint( |
| 734 | + "cpu validation info extracted for validation level 0 and validation level 1 (iter=0)" |
| 735 | + ) |
| 736 | + |
| 737 | + # first test validation level 0 |
670 | 738 | aiu_validation_info = extract_validation_information( |
671 | | - model, |
672 | | - input_ids, |
673 | | - max_new_tokens, |
674 | | - None, |
675 | | - only_last_token=True, |
676 | | - **padding_kwargs |
677 | | -) |
| 739 | + model, input_ids, max_new_tokens, None, only_last_token="paged" not in ATTN_NAME, **extra_kwargs |
| 740 | + ) |
| 741 | + dprint("aiu validation info extracted for validation level 0") |
| 742 | + |
| 743 | + # check cache status before validating cached results |
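| | + # in both cases the directory should hold exactly max_new_tokens entries: freshly written on a miss, already present on a hit |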
| 744 | + updated_cache_len = len(os.listdir(os.getcwd()+"/.cache")) if os.path.isdir(os.getcwd()+"/.cache") else 0 |
| 745 | + if cache_status == "miss": |
| 746 | + assert updated_cache_len == max_new_tokens, ( |
| 747 | + "cache directory not populated on cache miss" |
| 748 | + ) |
| 749 | + return |
| 750 | + else: |
| 751 | + assert updated_cache_len == max_new_tokens, ( |
| 752 | + "cache miss occurred when hit was expected" |
| 753 | + ) |
| 754 | + |
| 755 | + # validate level 0 |
| 756 | + failed_responses = validate_level_0( |
| 757 | + aiu_validation_info.get_info("tokens"), cpu_static_tokens |
| 758 | + ) |
| 759 | + |
| 760 | + failed_validation_level_0 = len(failed_responses) != 0 |
| 761 | + |
| 762 | + # if level 0 fails validation, validate level 1 |
| 763 | + if FORCE_VALIDATION_LEVEL_1 or failed_validation_level_0: |
| 764 | + |
| 765 | + if failed_validation_level_0: |
| 766 | + dprint("failed validation level 0, testing validation level 1") |
| 767 | + else: |
| 768 | + dprint("passed validation level 0, testing validation level 1") |
| 769 | + |
| 770 | + # metric calculator based on the cross-entropy and mean diff for each decode step |
| 771 | + def _metric_calculator(r: torch.Tensor, t: torch.Tensor): |
| 772 | + cross_entropy = torch.nn.CrossEntropyLoss()( |
| 773 | + r, t.softmax(dim=1).to(dtype=torch.float32) |
| 774 | + ) |
| 775 | + diff = torch.mean( |
| 776 | + torch.abs( |
| 777 | + r.softmax(dim=1).to(dtype=torch.float32) |
| 778 | + - t.softmax(dim=1).to(dtype=torch.float32) |
| 779 | + ) |
| 780 | + ) |
| 781 | + return (cross_entropy, diff) |
| 782 | + |
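| | + # the fail thresholds above are based on 1024 tokens per sequence, so run enough iterations to cover roughly 1024 decode steps |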
| 783 | + iters = 1024 // max_new_tokens |
| 784 | + ce_fail_responses_list = [] |
| 785 | + diff_fail_responses_list = [] |
| 786 | + total_tokens = 0 |
| 787 | + for i in range(iters): |
| 788 | + # for iteration 0, we have computed the cpu validation info in the prior step for seed=0, so skip |
| 789 | + if i != 0: |
| 790 | + input_ids, extra_kwargs = __prepare_inputs( |
| 791 | + batch_size, seq_length, tokenizer, seed=i |
| 792 | + ) |
| 793 | + extra_kwargs["attn_name"] = ATTN_NAME |
| 794 | + cpu_validation_info = __load_validation_info( |
| 795 | + model_path, batch_size, seq_length, max_new_tokens, tokenizer, i |
| 796 | + ) |
| 797 | + if cpu_validation_info is None: |
| 798 | + cpu_validation_info = extract_validation_information( |
| 799 | + validation_model, |
| 800 | + input_ids, |
| 801 | + max_new_tokens, |
| 802 | + LogitsExtractorHook(), |
| 803 | + attn_algorithm="math", |
| 804 | + **extra_kwargs, |
| 805 | + ) |
| 806 | + dprint( |
| 807 | + f"cpu validation info extracted for validation level 1 - iter={i}" |
| 808 | + ) |
| 809 | + if save_validation_info_outputs: |
| 810 | + cpu_validation_info.save( |
| 811 | + __get_validation_info_full_path( |
| 812 | + model_path, batch_size, seq_length, max_new_tokens, i |
| 813 | + ) |
| 814 | + ) |
| 815 | + cpu_static_tokens = cpu_validation_info.get_info("tokens") |
| 816 | + eos_indexes = __find_eos_index( |
| 817 | + cpu_static_tokens, |
| 818 | + tokenizer.eos_token_id, |
| 819 | + seq_length, |
| 820 | + max_new_tokens, |
| 821 | + ) |
| 822 | + |
| 823 | + # generate aiu validation info |
| 824 | + aiu_validation_info = extract_validation_information( |
| 825 | + model, |
| 826 | + input_ids, |
| 827 | + max_new_tokens, |
| 828 | + GoldenTokenHook(cpu_static_tokens), |
| 829 | + only_last_token=ATTN_TYPE != "paged", |
| 830 | + **extra_kwargs, |
| 831 | + ) |
| 832 | + dprint(f"aiu validation info extracted for validation level 1 - iter={i}") |
| 833 | + if save_validation_info_outputs: |
| 834 | + aiu_validation_info.save( |
| 835 | + __get_validation_info_full_path( |
| 836 | + model_path, batch_size, seq_length, max_new_tokens, i, "aiu" |
| 837 | + ) |
| 838 | + ) |
| 839 | + |
| 840 | + # capture all level 1 metrics |
| 841 | + level_1_metrics = capture_level_1_metrics( |
| 842 | + cpu_validation_info.get_info("logits"), |
| 843 | + aiu_validation_info.get_info("logits"), |
| 844 | + top_k_loss_calculator(20, _metric_calculator), |
| 845 | + ) |
| 846 | + # only consider those metrics captured prior to the eos |
| 847 | + level_1_metrics = __filter_before_eos(level_1_metrics, eos_indexes) |
| 848 | + |
| 849 | + # if we do not have real model weights, use a default_metrics_threshold |
| 850 | + if USE_MICRO_MODELS and micro_model_path is None: |
| 851 | + ce_threshold, diff_threshold = default_metrics_threshold |
| 852 | + # if we have real weights, try and get the proper validation metrics threshold |
| 853 | + else: |
| 854 | + # if we have a micro model with real weights, but no real thresholds, default to the full model thresholds |
| 855 | + if USE_MICRO_MODELS: |
| 856 | + ce_threshold, diff_threshold = fail_thresholds.get( |
| 857 | + (model_path, True), fail_thresholds.get((model_path, False), default_metrics_threshold) |
| 858 | + ) |
| 859 | + else: |
| 860 | + ce_threshold, diff_threshold = fail_thresholds.get( |
| 861 | + (model_path, False), default_metrics_threshold |
| 862 | + ) |
| 863 | + |
| 864 | + # get all failed responses for each metric |
| 865 | + ce_fail_responses = filter_failed_level_1_cases( |
| 866 | + level_1_metrics, lambda m: m[0] >= ce_threshold |
| 867 | + ) |
| 868 | + diff_fail_responses = filter_failed_level_1_cases( |
| 869 | + level_1_metrics, |
| 870 | + lambda m: m[1] >= diff_threshold, |
| 871 | + ) |
| 872 | + |
| 873 | + ce_fail_responses_list.extend(ce_fail_responses) |
| 874 | + diff_fail_responses_list.extend(diff_fail_responses) |
| 875 | + total_tokens += len(level_1_metrics) |
| 876 | + |
| 877 | + # test the failure rates across all tokens |
| 878 | + diff_failure_rate = len(diff_fail_responses_list) / total_tokens |
| 879 | + ce_failure_rate = len(ce_fail_responses_list) / total_tokens |
| 880 | + dprint(f"mean diff failure rate: {diff_failure_rate}") |
| 881 | + dprint(f"cross entropy loss failure rate: {ce_failure_rate}") |
| 882 | + if "mean_diff" not in skip_assertions: |
| 883 | + assert diff_failure_rate < failure_rate_threshold, ( |
| 884 | + f"failure rate for mean diff was too high: {diff_failure_rate}" |
| 885 | + ) |
| 886 | + if "ce" not in skip_assertions: |
| 887 | + assert ce_failure_rate < failure_rate_threshold, ( |
| 888 | + f"failure rate for cross entropy loss was too high: {ce_failure_rate}" |
| 889 | + ) |
| 890 | + print("passed validation level 1") |
| 891 | + else: |
| 892 | + print("passed validation level 0") |