Commit 0e07855

Add test case for caching
Signed-off-by: Avery Blanchard <[email protected]>
1 parent 27b67c2 commit 0e07855

File tree

1 file changed: +53 -0 lines changed

tests/models/test_decoders.py
Lines changed: 53 additions & 0 deletions
@@ -63,6 +63,7 @@
 common_max_new_tokens = [int(mnt) for mnt in common_max_new_tokens.split(",")]
 
 common_shapes = list(itertools.product(common_model_paths, common_batch_sizes, common_seq_lengths, common_max_new_tokens))
+cache_params = list(itertools.product([common_model_paths[0]], [common_batch_sizes[0]], [common_seq_lengths[0]], [common_max_new_tokens[0]], ["miss", "hit"]))
 
 # thresholds are chosen based on 1024 tokens per sequence
 # 1% error threshold rate between cpu fp32 and cuda fp16
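
For context, a minimal sketch of what the new cache_params parametrization expands to. The concrete values below are hypothetical placeholders; the real test pulls them from the common_* lists configured earlier in tests/models/test_decoders.py. Only the first common shape is used, paired once with a cache "miss" run and once with a "hit" run:

```python
import itertools

# hypothetical placeholder values; the actual lists are environment-driven
common_model_paths = ["example/model-path"]
common_batch_sizes = [1]
common_seq_lengths = [64]
common_max_new_tokens = [8]

# same construction as in the diff above: one shape, two cache statuses
cache_params = list(itertools.product(
    [common_model_paths[0]],
    [common_batch_sizes[0]],
    [common_seq_lengths[0]],
    [common_max_new_tokens[0]],
    ["miss", "hit"],
))

print(cache_params)
# [('example/model-path', 1, 64, 8, 'miss'),
#  ('example/model-path', 1, 64, 8, 'hit')]
```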
@@ -78,6 +79,7 @@ def reset_compiler():
     torch.compiler.reset()
     torch._dynamo.reset()
     os.environ.pop('COMPILATION_MODE', None)
+    os.environ.pop('TORCH_SENDNN_CACHE_ENABLE', None)
     if ORIGINAL_HF_HOME is None:
         os.environ.pop('HF_HOME', None)
     else:
@@ -287,5 +289,56 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor):
     else:
         print("passed validation level 0")
 
+@pytest.mark.parametrize("model_path,batch_size,seq_length,max_new_tokens,cache_status", cache_params)
+def test_cache(model_path, batch_size, seq_length, max_new_tokens, cache_status):
+    torch.manual_seed(42)
+    os.environ["TORCH_SENDNN_CACHE_ENABLE"] = "1"
+    os.environ["COMPILATION_MODE"] = "offline_decoder"
+
+    dprint(f"testing with cache: model={model_path}, batch_size={batch_size}, seq_length={seq_length}, max_new_tokens={max_new_tokens}, micro_model={USE_MICRO_MODELS}, cache={cache_status}")
 
+    if USE_MICRO_MODELS:
+        micro_model_kwargs = {"architecture": "hf_configured", "nlayers": 3}
+    else:
+        micro_model_kwargs = {"architecture": "hf_pretrained"}
+
+    if not USE_MICRO_MODELS and os.path.exists(model_path):
+        model_path_kwargs = {"model_path": model_path}
+    else:
+        model_path_kwargs = {"variant": model_path}
+
+    distributed_kwargs = {}
+    if USE_DISTRIBUTED:
+        distributed_kwargs["distr_param"] = "tp"
+        distributed_kwargs["group"] = dist.group.WORLD
+    get_model_kwargs = {**model_path_kwargs, **micro_model_kwargs, **distributed_kwargs}
 
+    tokenizer = tokenizers.get_tokenizer(model_path)
+
+    # prepare the AIU model
+    model = get_model(
+        device_type="cpu",
+        fused_weights=False,
+        **get_model_kwargs
+    )
+
+    model.eval()
+    torch.set_grad_enabled(False)
+    model.compile(backend="sendnn_decoder")
+
+
+    # prepare input_ids
+    input_ids, padding_kwargs = __prepare_inputs(batch_size, seq_length, tokenizer)
+
+    # warmup aiu model
+    warmup_model(model, input_ids, max_new_tokens, **padding_kwargs)
+
+    # aiu validation
+    aiu_validation_info = extract_validation_information(
+        model,
+        input_ids,
+        max_new_tokens,
+        None,
+        only_last_token=True,
+        **padding_kwargs
+    )
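
A hedged sketch, not part of the commit, of how one might run only the new cache test locally (assuming the sendnn/AIU test dependencies are installed and the command is run from the repository root). By default pytest executes the parametrized cases in the listed order, so the "miss" case runs before the "hit" case and the second run can exercise a warm cache:

```python
# run_cache_test.py -- hypothetical helper script, not part of this commit
import sys

import pytest

# -k selects the test_cache cases added in this commit; -v prints each
# parametrization ("miss" and "hit") as it runs
sys.exit(pytest.main(["-v", "-k", "test_cache", "tests/models/test_decoders.py"]))
```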
