Add test case for cache #20

Open: wants to merge 2 commits into main

278 changes: 274 additions & 4 deletions tests/models/test_decoders.py
@@ -24,7 +24,7 @@
)
import json
from aiu_fms_testing_utils.utils.aiu_setup import dprint, aiu_dist_setup

import shutil
import os

try:
@@ -148,6 +148,7 @@
os.environ["VLLM_DT_MAX_CONTEXT_LEN"] = str((((max(common_seq_lengths) + max(common_max_new_tokens)) // 64) + 1) * 64)
os.environ["VLLM_DT_MAX_BATCH_SIZE"] = str(max(common_batch_sizes))


# thresholds are chosen based on 1024 tokens per sequence
# 1% error threshold rate between cpu fp32 and cuda fp16
# if a model's failure thresholds do not exist in this dict, default to the default_metrics_threshold defined above
@@ -214,7 +215,7 @@ def reset_compiler():
torch.compiler.reset()
torch._dynamo.reset()
os.environ.pop("COMPILATION_MODE", None)

    os.environ.pop("TORCH_SENDNN_CACHE_ENABLE", None)

# TODO: Currently, gptq does not have the same level of support as non-gptq models for get_model. This method provides the extra requirements for gptq for get_model,
# however ideally, these fixes should be done in foundation-model-stack.
@@ -260,7 +261,6 @@ def __maybe_get_gptq_kwargs(model_path):
pass
return gptq_kwargs_aiu, gptq_kwargs_cpu


def __prepare_inputs(batch_size, seq_length, tokenizer, seed=0):
prompts_and_sizes = sample_sharegpt_requests(
SHARE_GPT_DATASET_PATH,
@@ -425,7 +425,7 @@ def test_common_shapes(model_path, batch_size, seq_length, max_new_tokens, persi

# prepare the AIU model
model = persistent_model.get_or_create(is_gptq, **gptq_kwargs_aiu, **get_model_kwargs)

# prepare the cpu model
validation_model = get_model(
device_type="cpu",
@@ -620,3 +620,273 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor):
print("passed validation level 1")
else:
print("passed validation level 0")

@pytest.mark.parametrize("cache_status", ["miss", "hit"])
def test_cache(cache_status):
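    """Exercise the torch_sendnn compile cache: the "miss" case starts from an empty
    cache directory and checks that it gets populated, while the "hit" case reuses
    that directory and additionally validates the generated tokens."""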
torch.manual_seed(42)
torch.set_grad_enabled(False)
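    # enable the torch_sendnn cache and point it at a .cache directory under the
    # current working directory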
os.environ["TORCH_SENDNN_CACHE_ENABLE"] = "1"
    cache_dir = os.path.join(os.getcwd(), ".cache")
    os.environ["TORCH_SENDNN_CACHE_DIR"] = cache_dir
os.environ["COMPILATION_MODE"] = "offline_decoder"

    if cache_status == "miss" and os.path.isdir(cache_dir):
        # Remove cache from previous runs
        shutil.rmtree(cache_dir)

model_path = "ibm-granite/granite-3.3-8b-instruct"
batch_size = common_batch_sizes[0]
seq_length = common_seq_lengths[0]
max_new_tokens = common_max_new_tokens[0]

dprint(f"testing with cache: model={model_path}, batch_size={batch_size}, seq_length={seq_length}, max_new_tokens={max_new_tokens}, micro_model={USE_MICRO_MODELS}, cache={cache_status}")

# we don't currently support inferring gptq from get_model, so we must use an adapter with hf_configured
gptq_kwargs_aiu, gptq_kwargs_cpu = __maybe_get_gptq_kwargs(model_path)
is_gptq = len(gptq_kwargs_aiu) != 0

micro_model_path = micro_model_mapping.get(model_path, None)
if USE_MICRO_MODELS and micro_model_path is None:
dprint("using randomly initialized model")
micro_model_kwargs = {"architecture": "hf_configured", "nlayers": 3}
else:
dprint("using trained model")
micro_model_kwargs = {"architecture": "hf_pretrained"}

if not USE_MICRO_MODELS and os.path.exists(model_path):
model_path_kwargs = {"model_path": model_path}
elif USE_MICRO_MODELS and micro_model_path is not None:
model_path_kwargs = {"model_path": micro_model_path}
else:
model_path_kwargs = {"variant": model_path}

distributed_kwargs = {}
if USE_DISTRIBUTED:
distributed_kwargs["distributed_strategy"] = "tp"
distributed_kwargs["group"] = dist.group.WORLD

get_model_kwargs = {}
if not is_gptq:
get_model_kwargs = {
**model_path_kwargs,
**micro_model_kwargs,
**distributed_kwargs,
}

tokenizer = tokenizers.get_tokenizer(model_path)

# prepare the AIU model
model = get_model(
device_type="cpu",
data_type=None if is_gptq else torch.float16,
fused_weights=False,
**get_model_kwargs,
)

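    # the sendnn backend compilation below is the step the TORCH_SENDNN_CACHE_*
    # settings above are expected to affect (populated on "miss", reused on "hit")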
model.eval()
model.compile(backend="sendnn")

# prepare the cpu model
validation_model = get_model(
device_type="cpu",
data_type=None if is_gptq else torch.float32,
fused_weights=False,
**gptq_kwargs_cpu,
**get_model_kwargs,
)

if USE_MICRO_MODELS:
serialization.load_state_dict_into_model(
validation_model, model.state_dict(), **__custom_adapter
)

# prepare input_ids
input_ids, extra_kwargs = __prepare_inputs(batch_size, seq_length, tokenizer)
extra_kwargs["attn_name"] = ATTN_NAME

# warmup aiu model
warmup_model(model, input_ids, max_new_tokens, compile_dynamic_sendnn, **extra_kwargs)

# generate cpu validation info
cpu_validation_info = __load_validation_info(
model_path, batch_size, seq_length, max_new_tokens, tokenizer, 0
)
if cpu_validation_info is None:
cpu_validation_info = extract_validation_information(
validation_model,
input_ids,
max_new_tokens,
LogitsExtractorHook(),
attn_algorithm="math",
**extra_kwargs,
)

if save_validation_info_outputs:
cpu_validation_info.save(
__get_validation_info_full_path(
model_path, batch_size, seq_length, max_new_tokens, 0
)
)
cpu_static_tokens = cpu_validation_info.get_info("tokens")
eos_indexes = __find_eos_index(
cpu_static_tokens, tokenizer.eos_token_id, seq_length, max_new_tokens
)
dprint(
"cpu validation info extracted for validation level 0 and validation level 1 (iter=0)"
)

# first test validation level 0
aiu_validation_info = extract_validation_information(
model, input_ids, max_new_tokens, None, only_last_token="paged" not in ATTN_NAME, **extra_kwargs
)

Contributor

Is there something we need to assert for miss/hit?

dprint("aiu validation info extracted for validation level 0")

# check cache status before validating cached results
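    # the miss run only verifies that the cache directory was populated (the assertion
    # below expects one cache artifact per generated token) and returns early; the hit
    # run continues on to validate the generated tokens against the cpu model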
    updated_cache_len = len(os.listdir(cache_dir)) if os.path.isdir(cache_dir) else 0
if cache_status == "miss":
assert updated_cache_len == max_new_tokens, (
"cache directory not populated on cache miss"
)
return
else:
assert updated_cache_len == max_new_tokens, (
"cache miss occurred when hit was expected"
)

# validate level 0
failed_responses = validate_level_0(
aiu_validation_info.get_info("tokens"), cpu_static_tokens
)

failed_validation_level_0 = len(failed_responses) != 0

# if level 0 fails validation, validate level 1
if FORCE_VALIDATION_LEVEL_1 or failed_validation_level_0:

if failed_validation_level_0:
dprint("failed validation level 0, testing validation level 1")
else:
dprint("passed validation level 0, testing validation level 1")

# metric calculator based on the cross-entropy and mean diff for each decode step
def _metric_calculator(r: torch.Tensor, t: torch.Tensor):
cross_entropy = torch.nn.CrossEntropyLoss()(
r, t.softmax(dim=1).to(dtype=torch.float32)
)
diff = torch.mean(
torch.abs(
r.softmax(dim=1).to(dtype=torch.float32)
- t.softmax(dim=1).to(dtype=torch.float32)
)
)
return (cross_entropy, diff)

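        # run enough iterations to cover roughly 1024 generated tokens per sequence,
        # matching the token count the failure thresholds above were calibrated on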
iters = 1024 // max_new_tokens
ce_fail_responses_list = []
diff_fail_responses_list = []
total_tokens = 0
for i in range(iters):
# for iteration 0, we have computed the cpu validation info in the prior step for seed=0, so skip
if i != 0:
input_ids, extra_kwargs = __prepare_inputs(
batch_size, seq_length, tokenizer, seed=i
)
extra_kwargs["attn_name"] = ATTN_NAME
cpu_validation_info = __load_validation_info(
model_path, batch_size, seq_length, max_new_tokens, tokenizer, i
)
if cpu_validation_info is None:
cpu_validation_info = extract_validation_information(
validation_model,
input_ids,
max_new_tokens,
LogitsExtractorHook(),
attn_algorithm="math",
**extra_kwargs,
)
dprint(
f"cpu validation info extracted for validation level 1 - iter={i}"
)
if save_validation_info_outputs:
cpu_validation_info.save(
__get_validation_info_full_path(
model_path, batch_size, seq_length, max_new_tokens, i
)
)
cpu_static_tokens = cpu_validation_info.get_info("tokens")
eos_indexes = __find_eos_index(
cpu_static_tokens,
tokenizer.eos_token_id,
seq_length,
max_new_tokens,
)

# generate aiu validation info
aiu_validation_info = extract_validation_information(
model,
input_ids,
max_new_tokens,
GoldenTokenHook(cpu_static_tokens),
only_last_token=ATTN_TYPE != "paged",
**extra_kwargs,
)
dprint(f"aiu validation info extracted for validation level 1 - iter={i}")
if save_validation_info_outputs:
aiu_validation_info.save(
__get_validation_info_full_path(
model_path, batch_size, seq_length, max_new_tokens, i, "aiu"
)
)

# capture all level 1 metrics
level_1_metrics = capture_level_1_metrics(
cpu_validation_info.get_info("logits"),
aiu_validation_info.get_info("logits"),
top_k_loss_calculator(20, _metric_calculator),
)
# only consider those metrics captured prior to the eos
level_1_metrics = __filter_before_eos(level_1_metrics, eos_indexes)

# if we do not have real model weights, use a default_metrics_threshold
if USE_MICRO_MODELS and micro_model_path is None:
ce_threshold, diff_threshold = default_metrics_threshold
# if we have real weights, try and get the proper validation metrics threshold
else:
# if we have a micro model with real weights, but no real thresholds, default to the full model thresholds
if USE_MICRO_MODELS:
ce_threshold, diff_threshold = fail_thresholds.get(
(model_path, True), fail_thresholds.get((model_path, False), default_metrics_threshold)
)
else:
ce_threshold, diff_threshold = fail_thresholds.get(
(model_path, False), default_metrics_threshold
)

# get all failed responses for each metric
ce_fail_responses = filter_failed_level_1_cases(
level_1_metrics, lambda m: m[0] >= ce_threshold
)
diff_fail_responses = filter_failed_level_1_cases(
level_1_metrics,
lambda m: m[1] >= diff_threshold,
)

ce_fail_responses_list.extend(ce_fail_responses)
diff_fail_responses_list.extend(diff_fail_responses)
total_tokens += len(level_1_metrics)

        # test the failure rates across all tokens
diff_failure_rate = len(diff_fail_responses_list) / total_tokens
ce_failure_rate = len(ce_fail_responses_list) / total_tokens
dprint(f"mean diff failure rate: {diff_failure_rate}")
dprint(f"cross entropy loss failure rate: {ce_failure_rate}")
if "mean_diff" not in skip_assertions:
assert diff_failure_rate < failure_rate_threshold, (
f"failure rate for mean diff was too high: {diff_failure_rate}"
)
if "ce" not in skip_assertions:
assert ce_failure_rate < failure_rate_threshold, (
f"failure rate for cross entropy loss was too high: {ce_failure_rate}"
)
print("passed validation level 1")
else:
print("passed validation level 0")