Add model loader arg to generate_layers_metrics #81

Open · wants to merge 10 commits into base: main

Changes from all commits
35 changes: 35 additions & 0 deletions aiu_fms_testing_utils/utils/metrics_utils.py
@@ -3,6 +3,41 @@
import torch.nn as nn


def get_model_prefix(model_path: str,
shapes_size: int,
max_new_tokens: int,
batch_size: int,
seq_length: int,
dtype: str,
include_shapes: bool):
"""
Generate a prefix for a model based on its path and other parameters.

Args:
model_path (str): The path to the model file.
shapes_size (int): The number of shape combinations being processed; shape details are appended when it is greater than 1.
max_new_tokens (int): The maximum number of new tokens allowed for generation.
batch_size (int): The batch size used for inference.
seq_length (int): The sequence length used for inference.
dtype (str): The data type used for inference.
include_shapes (bool): Whether to append the shape details (max new tokens, batch size, sequence length and dtype) to the prefix.
Returns:
str: A prefix for the model based on its path and other parameters.
"""
if model_path.count("/") > 1:
# the model_path does NOT match the HF org/model pattern
# Eg.: /home/another-dir/another/ibm-granite/granite-3.3-8b-base
model_prefix = model_path.split("/")[-2] + "--" + model_path.split("/")[-1]
else:
# the model_path matches the HF org/model pattern
# Eg.: ibm-granite/granite-3.3-8b-base
model_prefix = model_path.replace("/", "--")

if shapes_size > 1 or include_shapes:
model_prefix = f"{model_prefix}_max-new-tokens-{max_new_tokens}_batch-size-{batch_size}_seq-length-{seq_length}_dtype-{dtype}"

return model_prefix

def abs_diff_linalg_norm(res_vector):
"""
Calculates the Euclidean norm (also known as the L2 norm) of a given array res_vector. This is equivalent to finding the square
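
For reference, a quick sketch of what the new helper produces for the two path shapes it distinguishes; the paths and shape values below are illustrative, not taken from the PR:

```python
from aiu_fms_testing_utils.utils.metrics_utils import get_model_prefix

# HF-style model id (a single "/"): the separator is simply replaced with "--".
print(get_model_prefix("ibm-granite/granite-3.3-8b-base",
                       shapes_size=1, max_new_tokens=128, batch_size=1,
                       seq_length=64, dtype="float16", include_shapes=False))
# -> ibm-granite--granite-3.3-8b-base

# Local checkout (more than one "/"): only the last two path components are kept.
# With shapes_size > 1 the shape details are appended even when include_shapes=False.
print(get_model_prefix("/home/models/ibm-granite/granite-3.3-8b-base",
                       shapes_size=2, max_new_tokens=128, batch_size=1,
                       seq_length=64, dtype="float16", include_shapes=False))
# -> ibm-granite--granite-3.3-8b-base_max-new-tokens-128_batch-size-1_seq-length-64_dtype-float16
```
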
175 changes: 127 additions & 48 deletions scripts/generate_layers_metrics.py
@@ -1,4 +1,5 @@
import os
import sys
import time
import logging
import argparse
@@ -11,16 +12,12 @@
from fms.models import get_model
from fms.utils.generation import generate

from aiu_fms_testing_utils.testing.validation import get_default_validation_prefix
from transformers import AutoModelForCausalLM, AutoTokenizer

from aiu_fms_testing_utils.utils import prepare_inputs
from aiu_fms_testing_utils.utils.metrics_utils import tensor_abs_diff, tensor_cos_sim
from aiu_fms_testing_utils.utils.metrics_utils import tensor_abs_diff, tensor_cos_sim, get_model_prefix


logger = logging.getLogger(__name__)
LOG_LEVEL = os.environ.get('LOG_LEVEL', 'INFO').upper()
logging.basicConfig(level=LOG_LEVEL, format="%(asctime)s %(message)s")

parser = argparse.ArgumentParser(
description="Script to generate the model's metrics by layer"
)
Expand All @@ -47,6 +44,13 @@
required=True,
help="Sets the output generation mode."
)
parser.add_argument(
"--model_loader",
choices=["fms", "hf"],
default="fms",
required=True,
help="Which model loader/runner to use: 'fms' for IBM's Foundation Model Stack, or 'hf' for Hugging Face Transformers."
)
parser.add_argument(
"--batch_sizes",
type=str,
@@ -86,6 +90,22 @@
output_path = args.output_path
sharegpt_path = args.sharegpt_path

if not os.path.exists(os.path.join(output_path,"layers-input-output-logs")):
os.makedirs(os.path.join(output_path,"layers-input-output-logs"))

logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s %(name)-12s %(message)s',
datefmt='%m-%d %H:%M',
filename=os.path.join(output_path, "layers-input-output-logs", "layers_input.log"),
filemode='w')
console = logging.StreamHandler()
console.setLevel(os.getenv("LOG_LEVEL", "INFO"))
formatter = logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s')
console.setFormatter(formatter)
logging.getLogger('').addHandler(console)

logger = logging.getLogger('generate_layers_metrics')

common_model_paths = args.model_path if args.model_path else args.variant
if isinstance(common_model_paths, str):
common_model_paths = [str(bs) for bs in common_model_paths.split(",")]
@@ -134,34 +154,44 @@ def __infer_layer(model, max_len, device, max_new_tokens, batch_size, tokenizer)

do_sample = False
use_cache = True
result = None

prompts = prepare_inputs(batch_size, max_len, tokenizer, sharegpt_path)
ids, pad_input_ids = prompts

if "cuda" in device:
ids = ids.to("cuda")

if hasattr(model.config, "ntk_scaling") and model.config.ntk_scaling:
if args.model_loader == "hf":
max_seq_len = max_len
elif hasattr(model.config, "ntk_scaling") and model.config.ntk_scaling:
max_seq_len = max(max_len, model.config.max_expected_seq_len)
else:
# without ntk scaling, extending the seq length too far gives bogus results.
max_seq_len = model.config.max_expected_seq_len

if "generate" in mode:
with torch.no_grad():
result = generate(
model,
ids,
max_new_tokens=max_new_tokens,
use_cache=use_cache,
do_sample=do_sample,
max_seq_len=max_seq_len,
timing="e2e",
eos_token_id=None,
contiguous_cache=True,
extra_kwargs={},
)
result, timings = result
if args.model_loader == "fms":
result = generate(
model,
ids,
max_new_tokens=max_new_tokens,
use_cache=use_cache,
do_sample=do_sample,
max_seq_len=max_seq_len,
timing="e2e",
eos_token_id=None,
contiguous_cache=True,
extra_kwargs={},
)
result, timings = result
if args.model_loader == "hf":
result = model.generate(ids,
max_length=max_seq_len,
max_new_tokens=max_new_tokens,
do_sample=do_sample,
use_cache=use_cache)
logger.info(f"Generation completed: Result len is {len(result)}")
if len(result.shape) == 1:
result = result.unsqueeze(0)
@@ -304,7 +334,7 @@ def write_csv(values, path, metric, gpu_layer_shape, cpu_layer_shape, output_sha
f.write(f"{values}\n")
f.close()

def generate_layers_metrics(model_path, batch_size, seq_length, max_new_tokens):
def generate_layers_metrics(model_path, batch_size, seq_length, max_new_tokens, model_thresholds_folder):
"""
Generate metrics for layers in a given model.

@@ -313,6 +343,7 @@ def generate_layers_metrics(model_path, batch_size, seq_length, max_new_tokens):
batch_size (int): The batch size used for inference.
seq_length (int): The sequence length used for inference.
max_new_tokens (int): The maximum number of new tokens allowed for generation.
model_thresholds_folder (str): The directory where the per-layer metric CSV files will be saved.

Returns:
None
@@ -324,6 +355,14 @@ def generate_layers_metrics(model_path, batch_size, seq_length, max_new_tokens):
if "HF_HOME" not in os.environ:
os.environ["HF_HOME"] = "/tmp/models/hf_cache"

model_prefix = get_model_prefix(model_path=model_path,
shapes_size=0,
max_new_tokens=max_new_tokens,
batch_size=batch_size,
seq_length=seq_length,
dtype="",
include_shapes=False)

model_path_kwargs = {"variant": model_path} if args.variant else {"model_path": model_path}
micro_model_kwargs = {"architecture": args.architecture}

@@ -332,30 +371,46 @@ def generate_layers_metrics(model_path, batch_size, seq_length, max_new_tokens):
**micro_model_kwargs,
}

tokenizer = tokenizers.get_tokenizer(model_path)

# prepare the cpu model
validation_model = get_model(
device_type="cpu",
data_type=torch.float32,
fused_weights=False,
**get_model_kwargs,
)

# prepare the cuda model
validation_model_cuda = get_model(
device_type="cuda",
data_type=torch.float16,
fused_weights=False,
**get_model_kwargs,
)
if args.model_loader == "hf":
tokenizer = AutoTokenizer.from_pretrained(model_path)

# prepare the cpu model
validation_model = AutoModelForCausalLM.from_pretrained(model_path,
device_map="cpu",
torch_dtype=torch.float32
)
# prepare the cuda model
validation_model_cuda = AutoModelForCausalLM.from_pretrained(model_path,
device_map="cuda",
torch_dtype=torch.float16
)
if args.model_loader == "fms":
tokenizer = tokenizers.get_tokenizer(model_path)

# prepare the cpu model
validation_model = get_model(
device_type="cpu",
data_type=torch.float32,
fused_weights=False,
**get_model_kwargs,
)

# prepare the cuda model
validation_model_cuda = get_model(
device_type="cuda",
data_type=torch.float16,
fused_weights=False,
**get_model_kwargs,
)

layer_stack_cpu = __register_call_layers(model=validation_model,
batch_size=batch_size,
device="cpu",
seq_length=seq_length, max_new_tokens=max_new_tokens,
tokenizer=tokenizer)

torch.save(layer_stack_cpu, os.path.join(output_path, "layers-input-output-logs", f"{model_prefix}-{mode}-layer_stack_cpu.pt"))

global generate_iters
generate_iters = 0
logger.info(f"Finished registering CPU layers")
@@ -365,6 +420,8 @@ def generate_layers_metrics(model_path, batch_size, seq_length, max_new_tokens):
device="cuda",
seq_length=seq_length, max_new_tokens=max_new_tokens,
tokenizer=tokenizer)

torch.save(layer_stack_cuda, os.path.join(output_path, "layers-input-output-logs", f"{model_prefix}-{mode}-layer_stack_cuda.pt"))

assert len(layer_stack_cuda.keys()) == len(layer_stack_cpu.keys())

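For anyone consuming the new `.pt` dumps saved above, a rough sketch of how they could be inspected offline. The file name is illustrative, and the entry layout (a layer name mapped to a captured output that is either a tensor or a tuple of tensors) is an assumption based on how the stacks are iterated in the metric loop below:

```python
import torch

# Hypothetical path; the actual prefix and mode depend on the run configuration.
stack = torch.load(
    "/tmp/output/layers-input-output-logs/"
    "ibm-granite--granite-3.2-8b-instruct-generate-layer_stack_cpu.pt"
)
for layer_name, output in stack.items():
    if torch.is_tensor(output):
        print(layer_name, tuple(output.shape))
    else:
        # tuple/list of tensors (e.g. per-head outputs), mirroring the type checks in the loop
        print(layer_name, [tuple(t.shape) for t in output if torch.is_tensor(t)])
```
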
@@ -389,9 +446,7 @@ def generate_layers_metrics(model_path, batch_size, seq_length, max_new_tokens):
tensor_cuda_out = cuda_output[-1]
tensor_cpu_out = cpu_output[-1]
for i in range(len(cpu_output)):
logger.debug(f"inputs: {cuda_output[i].shape} {cpu_output[i].to('cuda').shape}")
cos_sim.append(tensor_cos_sim(cuda_output[i], cpu_output[i].to('cuda')))
logger.debug(f"cos_sim output:{tensor_cos_sim(cuda_output[i], cpu_output[i].to('cuda')).shape}")
abs_diff.append(tensor_abs_diff(cuda_output[i], cpu_output[i].to('cuda')))
else:
head_tensor_cpu = cpu_output[-1]
@@ -401,28 +456,32 @@ def generate_layers_metrics(model_path, batch_size, seq_length, max_new_tokens):
for j in range(len(head_tensor_gpu[i])):
tensor_cuda_out = head_tensor_gpu[i][j]
tensor_cpu_out = head_tensor_cpu[i][j]
logger.debug(f"inputs: {head_tensor_gpu[i][j].shape} {head_tensor_cpu[i][j].to('cuda').shape}")
cos_sim.append(tensor_cos_sim(head_tensor_cpu[i][j].to('cuda'), head_tensor_gpu[i][j]))
logger.debug(f"cos_sim output:{tensor_cos_sim(head_tensor_cpu[i][j].to('cuda'), head_tensor_gpu[i][j]).shape}")
abs_diff.append(tensor_abs_diff(head_tensor_cpu[i][j].to('cuda'), head_tensor_gpu[i][j]))
else:
tensor_cuda_out = head_tensor_gpu[i]
tensor_cpu_out = head_tensor_cpu[i]
logger.debug(f"inputs: {head_tensor_gpu[i].shape} {head_tensor_cpu[i].to('cuda').shape}")
cos_sim.append(tensor_cos_sim(head_tensor_cpu[i].to('cuda'), head_tensor_gpu[i]))
logger.debug(f"cos_sim output:{tensor_cos_sim(head_tensor_cpu[i].to('cuda'), head_tensor_gpu[i]).shape}")
abs_diff.append(tensor_abs_diff(head_tensor_cpu[i].to('cuda'), head_tensor_gpu[i]))
else:
tensor_cpu_out = cpu_output.to('cuda')
tensor_cuda_out = cuda_output
abs_diff = tensor_abs_diff(tensor_cpu_out, cuda_output)
cos_sim = tensor_cos_sim(tensor_cpu_out, cuda_output)

prefix = get_default_validation_prefix(model_path, max_new_token, batch_size, seq_length, 'float16')
layer_name = str(layer_key).replace('[','').replace(']', '')

abs_diff_path = os.path.join(output_path, f"{prefix}--{layer_name}.abs_diff.csv")
cos_sim_path = os.path.join(output_path, f"{prefix}--{layer_name}.cos_sim.csv")
prefix = get_model_prefix(model_path=model_path,
shapes_size=len(common_shapes),
max_new_tokens=max_new_tokens,
batch_size=batch_size,
seq_length=seq_length,
dtype='float16',
include_shapes=True
)

abs_diff_path = os.path.join(model_thresholds_folder, f"{prefix}--{layer_name}.abs_diff.csv")
cos_sim_path = os.path.join(model_thresholds_folder, f"{prefix}--{layer_name}.cos_sim.csv")

cos_sim_res, cos_shape = get_metric_values(cos_sim)
abs_diff_res, abs_diff_shape = get_metric_values(abs_diff)
@@ -437,5 +496,25 @@ def generate_layers_metrics(model_path, batch_size, seq_length, max_new_tokens):
logger.info(f"Completed {model_path} layers' metrics generation with {mode} mode")

for model_id, batch_size, sequence_length, max_new_token in common_shapes:

model_prefix = get_model_prefix(model_id,
shapes_size=len(common_shapes),
max_new_tokens=max_new_token,
batch_size=batch_size,
seq_length=sequence_length,
dtype="",
include_shapes=False
)

model_thresholds_folder = os.path.join(output_path, model_prefix)

if not os.path.exists(model_thresholds_folder):
os.makedirs(model_thresholds_folder)

logger.info(f"testing model_id-{model_id}, max_new_tokens-{max_new_token}, batch_size-{batch_size}, seq_length-{sequence_length}")
generate_layers_metrics(model_path=model_id, batch_size=batch_size, seq_length=sequence_length, max_new_tokens=max_new_token)
generate_layers_metrics(model_path=model_id,
batch_size=batch_size,
seq_length=sequence_length,
max_new_tokens=max_new_token,
model_thresholds_folder=model_thresholds_folder
)
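
Net effect on the output layout: the per-layer CSVs now land in a per-model subfolder (`<output_path>/<model_prefix>/`), while the raw layer stacks and the run log land in `<output_path>/layers-input-output-logs/`. As a reading aid for the metric computation itself, here is a tiny self-contained sketch; the only assumption is that `tensor_cos_sim` and `tensor_abs_diff` take two same-device tensors, which is how they are called in the diff above:

```python
import torch

from aiu_fms_testing_utils.utils.metrics_utils import tensor_abs_diff, tensor_cos_sim

# Stand-ins for one layer's CPU (reference) and GPU outputs.
cpu_out = torch.randn(2, 8, 16)
gpu_out = cpu_out + 0.01 * torch.randn_like(cpu_out)

cos_sim = tensor_cos_sim(cpu_out, gpu_out)    # cosine similarity between the two outputs
abs_diff = tensor_abs_diff(cpu_out, gpu_out)  # absolute difference between the two outputs
print(cos_sim.shape, abs_diff.shape)
```
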
8 changes: 5 additions & 3 deletions tests/LAYERS.md
@@ -19,7 +19,7 @@ The idea is to run the prompts through the model with the pre- and post-hooks a
The script [generate_layers_metrics.py](../scripts/generate_layers_metrics.py) requires the following arguments to be run:

```bash
usage: generate_layers_metrics.py [-h] [--architecture ARCHITECTURE] [--variant VARIANT] [--model_path MODEL_PATH] --mode {generate,model-forward} --batch_sizes BATCH_SIZES --seq_lengths SEQ_LENGTHS --max_new_tokens MAX_NEW_TOKENS [--output_path OUTPUT_PATH] [--sharegpt_path SHAREGPT_PATH]
usage: generate_layers_metrics.py [-h] [--architecture ARCHITECTURE] [--variant VARIANT] [--model_path MODEL_PATH] --mode {generate,model-forward} --model_loader {fms,hf} --batch_sizes BATCH_SIZES --seq_lengths SEQ_LENGTHS --max_new_tokens MAX_NEW_TOKENS [--output_path OUTPUT_PATH] [--sharegpt_path SHAREGPT_PATH]

Script to generate the model's metrics by layer

@@ -32,6 +32,8 @@ options:
Paths to the directory containing model's weights (.pth files sharded by tensor parallel rank, not HF weights)
--mode {generate,model-forward}
Sets the output generation mode.
--model_loader {fms,hf}
Which model loader/runner to use: 'fms' for IBM's Foundation Model Stack, or 'hf' for Hugging Face Transformers.
--batch_sizes BATCH_SIZES
Batch sizes separated by comma. Eg.: 1,2
--seq_lengths SEQ_LENGTHS
@@ -79,7 +81,7 @@ cd aiu-fms-testing-utils/tests/resources

mkdir /tmp/output

python3 generate_layers_metrics.py --mode model-forward --variant ibm-granite/granite-3.2-8b-instruct --architecture hf_pretrained --batch_sizes 1 --seq_lengths 64 --max_new_tokens 128
python3 generate_layers_metrics.py --mode model-forward --variant ibm-granite/granite-3.2-8b-instruct --architecture hf_pretrained --batch_sizes 1 --seq_lengths 64 --max_new_tokens 128 --model_loader fms
```
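
To exercise the new Hugging Face loader instead, only the last flag changes (a sketch; every other argument stays as above):

```bash
python3 generate_layers_metrics.py --mode model-forward --variant ibm-granite/granite-3.2-8b-instruct --architecture hf_pretrained --batch_sizes 1 --seq_lengths 64 --max_new_tokens 128 --model_loader hf
```
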
The files should be created in the `/tmp/output` dir:
```bash
@@ -95,7 +97,7 @@ To get the second step of the flow and get the thresholds by layer, run:
```bash
cd /aiu-fms-testing-utils/tests/resources

python3 get_thresholds.py --models ibm-granite/granite-3.2-8b-instruct --metrics abs_diff cos_sim_avg cos_sim_men --file_base /tmp/output --layer_io
python3 get_thresholds.py --models ibm-granite/granite-3.2-8b-instruct --metrics abs_diff cos_sim_avg cos_sim_mean --file_base /tmp/output --layer_io
```
It should print the metric of each layer: