[Test Run] Oneshot refactor + hfquantizer #1055

Draft: wants to merge 58 commits into main

Commits (58)
92fbddc
add quantization then finetune -- run_compressed=False
horheynm Dec 9, 2024
299eed3
add test
horheynm Dec 9, 2024
aebae9a
Merge branch 'main' into quant-then-finetune
horheynm Dec 9, 2024
9ea94ed
clean up
horheynm Dec 9, 2024
b7a968e
update test_run_compressed
horheynm Dec 11, 2024
7067ad0
better var name
horheynm Dec 11, 2024
c47ca6a
fix conseq oneshot
horheynm Dec 11, 2024
4492e53
fix logic in is_model_quat_from_path
horheynm Dec 11, 2024
126d3d5
add decompress tests
horheynm Dec 11, 2024
e17c190
Merge branch 'main' into run_compressed-tests
horheynm Dec 11, 2024
5da01c4
fix test - use automodelforcausallm decompress
horheynm Dec 11, 2024
a264fc0
Merge branch 'main' into quant-then-finetune
horheynm Dec 11, 2024
569ef80
Merge remote-tracking branch 'origin/run_compressed-tests' into agg-test
horheynm Dec 16, 2024
32f8503
Merge remote-tracking branch 'origin/fix-test-conseq-oneshot' into ag…
horheynm Dec 16, 2024
5395623
Merge remote-tracking branch 'origin/fix-test_compress-tensors-utils'…
horheynm Dec 16, 2024
ac03b3b
agg test
horheynm Dec 16, 2024
e01f314
fix names
horheynm Dec 16, 2024
ecc6b86
Merge branch 'run_compressed-tests' into agg-test
horheynm Dec 16, 2024
2c73ff9
Merge branch 'main' into agg-test
horheynm Dec 16, 2024
2cd4809
Merge branch 'main' into run_compressed-tests
horheynm Dec 16, 2024
f36cbac
fix typo
horheynm Dec 17, 2024
12d3706
Merge branch 'run_compressed-tests' into agg-test
horheynm Dec 17, 2024
a851cf3
Merge branch 'agg-test' of github.com:vllm-project/llm-compressor int…
horheynm Dec 17, 2024
b4c8828
add compressedlinear vs linear generation
horheynm Dec 17, 2024
cfcd7e3
add run compressed forward
horheynm Dec 17, 2024
fd0745f
revert folder struct for config
horheynm Dec 17, 2024
c0a552a
Merge branch 'main' into fix-test_compress-tensors-utils
horheynm Dec 23, 2024
5ba651a
Merge branch 'main' into fix-test-conseq-oneshot
horheynm Dec 23, 2024
98d8ecf
Merge branch 'main' into run_compressed-tests
horheynm Dec 23, 2024
ee4c70d
comments
horheynm Dec 23, 2024
0d32d23
Merge branch 'main' into quant-then-finetune
horheynm Dec 23, 2024
74280d0
Merge branch 'fix-test-conseq-oneshot' into agg-test
horheynm Dec 23, 2024
d388da1
Merge remote-tracking branch 'origin/run_compressed-tests' into agg-test
horheynm Dec 23, 2024
bff7dcd
Merge remote-tracking branch 'origin/fix-test_compress-tensors-utils'…
horheynm Dec 23, 2024
794ea91
Merge remote-tracking branch 'origin/quant-then-finetune' into agg-test
horheynm Dec 23, 2024
db8dba9
revert to main gha flow
horheynm Dec 23, 2024
276b779
init
horheynm Jan 7, 2025
c690043
decouple main and successful fp8 run
horheynm Jan 7, 2025
166e4df
remove stage runner
horheynm Jan 7, 2025
40c73eb
run calib
horheynm Jan 7, 2025
7747bd6
Merge branch 'main' into oneshot-refac-1
horheynm Jan 7, 2025
3b7fd6a
potential non use of session
horheynm Jan 7, 2025
b3031c0
Merge branch 'oneshot-refac-1' of github.com:vllm-project/llm-compres…
horheynm Jan 7, 2025
1cd3d90
get rid of session, use oneshotclass
horheynm Jan 7, 2025
a5d0fd7
pass existing tests
horheynm Jan 8, 2025
33e1b16
Merge branch 'main' into oneshot-refac-1
horheynm Jan 8, 2025
e7407b9
pass finetune tests not dep on HF release
horheynm Jan 8, 2025
d352e4c
Merge branch 'oneshot-refac-1' of github.com:vllm-project/llm-compres…
horheynm Jan 8, 2025
bc532e7
remove unnecessary changes 1
horheynm Jan 8, 2025
137c02e
remove duplicate code
horheynm Jan 8, 2025
d691652
Merge branch 'main' into agg-test
horheynm Jan 9, 2025
6d5cdbc
remove duplicate code, set output_dir and save_tensors as training_ar…
horheynm Jan 9, 2025
2c7c5f0
pass tests pre HFQuantizer check
horheynm Jan 9, 2025
463a043
Merge branch 'agg-test-main' into oneshot-hfquantizer
horheynm Jan 10, 2025
f1d4539
fix devices error
horheynm Jan 10, 2025
1f2a505
clear func name
horheynm Jan 10, 2025
88817ba
Merge branch 'run_compressed-tests' into agg-test
horheynm Jan 10, 2025
826e23c
Merge branch 'agg-test' into oneshot-hfquantizer
horheynm Jan 10, 2025
4 changes: 3 additions & 1 deletion src/llmcompressor/core/lifecycle.py
@@ -20,7 +20,9 @@
 from llmcompressor.modifiers import StageModifiers
 from llmcompressor.recipe import RecipeContainer

-__all__ = ["CompressionLifecycle"]
+__all__ = [
+    "CompressionLifecycle",
+]


 @dataclass
19 changes: 11 additions & 8 deletions src/llmcompressor/pytorch/utils/sparsification.py
@@ -105,15 +105,18 @@ def params_quantized(self) -> int:
         """
         :return: number of parameters across quantized layers
         """
-        return sum(
-            torch.numel(self.trainable_params[f"{name}.weight"])
-            + (
-                torch.numel(self.trainable_params[f"{name}.bias"])
-                if hasattr(layer, "bias") and layer.bias is not None
-                else 0
-            )
-            for (name, layer) in get_quantized_layers(self.module)
-        )
+        num_params = 0
+        for name, layer in get_quantized_layers(self.module):
+            # .get() falls back to an empty tensor, so a missing key
+            # shows up as a parameter count of 0
+            num_param = torch.numel(
+                self.trainable_params.get(f"{name}.weight", torch.tensor([]))
+            )
+            if num_param == 0:
+                logger.warning(f"{name} is not recognized in trainable_params")
+                continue
+            num_params += num_param
+            if hasattr(layer, "bias") and layer.bias is not None:
+                num_params += torch.numel(layer.bias)
+
+        return num_params

     @property
     def params_quantized_percent(self) -> float:
135 changes: 135 additions & 0 deletions src/llmcompressor/transformers/calibration/oneshot.py
@@ -0,0 +1,135 @@
from pathlib import PosixPath
from typing import Optional

from loguru import logger
from torch.utils.data import DataLoader

from llmcompressor.core.lifecycle import CompressionLifecycle
from llmcompressor.transformers.finetune.data.data_helpers import (
    get_calibration_dataloader,
)
from llmcompressor.transformers.finetune.text_generation import (
    initialize_model_from_path,
    initialize_processor_from_path,
    parse_args,
)
from llmcompressor.transformers.finetune.training_args import DEFAULT_OUTPUT_DIR
from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
    modify_save_pretrained,
    patch_tied_tensors_bug,
)


class Oneshot:
    """
    Class responsible for carrying out oneshot calibration.

    Usage:

    ```python
    oneshot_calibrator = Oneshot(model=model, recipe=recipe, dataset=dataset)
    oneshot_calibrator.run()

    model = oneshot_calibrator.model
    tokenizer_or_processor = oneshot_calibrator.tokenizer_or_processor
    recipe = oneshot_calibrator.recipe
    ```
    """

    MODIFIER_LIFECYCLE_ACTIONS = (
        "initialize",
        "finalize",
    )

    def __init__(self, **kwargs):
        self.model_args, self.data_args, self.recipe_args, training_args = parse_args(
            **kwargs
        )
        self.lifecycle = CompressionLifecycle()
        self.output_dir = training_args.output_dir

        # Preprocess the model and tokenizer/processor
        self._pre_process()

        # Set instance attributes
        self.model = self.model_args.model
        self.tokenizer_or_processor = self.model_args.processor
        self.recipe = self.recipe_args.recipe
        self.modifiers = self.lifecycle.modifiers

    def run(self):
        """Perform oneshot calibration."""
        calibration_dataloader = get_calibration_dataloader(
            self.data_args, self.tokenizer_or_processor
        )
        self._apply_recipe_modifiers(calibration_dataloader)
        self._post_process()

    def _apply_recipe_modifiers(self, calibration_dataloader: Optional[DataLoader]):
        """Apply recipe modifiers to the model."""
        for action in self.MODIFIER_LIFECYCLE_ACTIONS:
            lifecycle = getattr(self.lifecycle, action)
            lifecycle(
                model=self.model,
                recipe=self.recipe,
                recipe_args=self.recipe_args.recipe_args,
                calib_data=calibration_dataloader,
                start=-1,  # oneshot-specific argument
                copy_data=False,
                min_tokens_per_module=getattr(self, "min_tokens_per_module", None),
            )

    def _pre_process(self):
        """Preprocess model and tokenizer/processor"""
        self._warn_tied_embeddings()

        # Initialize model
        if isinstance(self.model_args.model, (str, PosixPath)):
            self.model_args.model, _ = initialize_model_from_path(self.model_args)

        patch_tied_tensors_bug(self.model_args.model)
        modify_save_pretrained(self.model_args.model)

        # Initialize processor
        if isinstance(self.model_args.processor, (str, type(None))):
            self.model_args.processor = initialize_processor_from_path(
                self.model_args, self.model_args.model
            )

        # Set minimum tokens per module if data arguments are provided
        if self.data_args:
            self.min_tokens_per_module = self.data_args.min_tokens_per_module

    def _warn_tied_embeddings(self):
        if self.model_args.tie_word_embeddings:
            logger.debug(
                "The tie_word_embeddings flag is by default set to False. "
                "This guarantees that the one-shot algorithm saves the final "
                "weights without errors. Detected tie_word_embeddings=True. "
                "This may cause issues with the one-shot algorithm on save."
            )

    def _post_process(self):
        """Save model and reset the lifecycle if requested"""
        if (
            isinstance(self.model_args.model, str)
            or self.output_dir != DEFAULT_OUTPUT_DIR
        ):
            self.save()

        if self.recipe_args.clear_sparse_session:
            self.reset_lifecycle()

    def save(self):
        """Save the model and tokenizer/processor to the output directory"""
        self.model.save_pretrained(
            self.output_dir,
            save_compressed=self.model_args.save_compressed,
            stage_modifiers=self.lifecycle.modifiers,
        )
        if self.tokenizer_or_processor:
            self.tokenizer_or_processor.save_pretrained(self.output_dir)

    def reset_lifecycle(self):
        """Reset the CompressionLifecycle."""
        self.lifecycle.reset()
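
For context, here is how the new class might be exercised end to end. This is an illustrative sketch rather than code from this PR: the keyword arguments are forwarded to parse_args(), so the accepted names follow the existing finetune argument dataclasses, and the model stub, recipe path, and dataset name below are placeholders.

```python
from llmcompressor.transformers.calibration.oneshot import Oneshot

oneshot = Oneshot(
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # placeholder model stub
    recipe="recipe.yaml",                        # placeholder recipe path
    dataset="open_platypus",                     # placeholder calibration set
    output_dir="./oneshot-out",                  # triggers save() in _post_process
)
oneshot.run()  # builds the calibration dataloader and applies recipe modifiers

model = oneshot.model
tokenizer_or_processor = oneshot.tokenizer_or_processor
```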
14 changes: 10 additions & 4 deletions src/llmcompressor/transformers/compression/sparsity_config.py
@@ -5,7 +5,8 @@
 from torch import Tensor
 from torch.nn import Module

-from llmcompressor.core import active_session
+from llmcompressor.core import CompressionLifecycle, active_session
+from llmcompressor.modifiers.stage import StageModifiers
 from llmcompressor.pytorch.utils import ModuleSparsificationInfo
 from llmcompressor.transformers.compression.helpers import (
     infer_sparse_targets_and_ignores,
@@ -40,7 +41,10 @@ def infer_global_sparsity(
         return global_sparsity

     @staticmethod
-    def infer_sparsity_structure(model: Optional[Module] = None) -> str:
+    def infer_sparsity_structure(
+        model: Optional[Module] = None,
+        stage_modifiers: Optional[CompressionLifecycle] = None,
+    ) -> str:
         """
         Determines what sparsity structure, if any, was applied.

@@ -58,7 +62,7 @@ def infer_sparsity_structure(model: Optional[Module] = None) -> str:
         sparsity_structure = None

         current_session = active_session()
-        stage_modifiers = current_session.lifecycle.modifiers
+        stage_modifiers = stage_modifiers or current_session.lifecycle.modifiers
         if stage_modifiers:
             sparsity_structure = infer_sparsity_structure_from_stage_modifiers(
                 stage_modifiers
@@ -74,6 +78,7 @@ def from_pretrained(
         model: Module,
         state_dict: Optional[Dict[str, Tensor]] = None,
         compress: bool = False,
+        stage_modifiers: Optional[StageModifiers] = None,
     ) -> Optional["SparsityCompressionConfig"]:
         """
         Determines compression type and informational parameters for a given model
@@ -93,7 +98,8 @@ def from_pretrained(
             return None

         sparsity_structure = SparsityConfigMetadata.infer_sparsity_structure(
-            model=model
+            model=model,
+            stage_modifiers=stage_modifiers,
         )
         if is_model_quantized(model):
             # compressing a sparse quantized model is not supported yet
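
The practical effect of the new stage_modifiers parameter is that a caller which owns its own lifecycle (such as the Oneshot class above) no longer has to rely on the global session. A minimal sketch, assuming an already-populated CompressionLifecycle:

```python
from torch.nn import Module

from llmcompressor.core.lifecycle import CompressionLifecycle
from llmcompressor.transformers.compression.sparsity_config import (
    SparsityConfigMetadata,
)


def build_sparsity_config(model: Module, lifecycle: CompressionLifecycle):
    # Pass the caller-owned modifiers explicitly; with stage_modifiers=None
    # the method still falls back to active_session().lifecycle.modifiers,
    # so existing call sites are unaffected.
    return SparsityConfigMetadata.from_pretrained(
        model,
        compress=True,
        stage_modifiers=lifecycle.modifiers,
    )
```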
76 changes: 76 additions & 0 deletions src/llmcompressor/transformers/finetune/data/data_helpers.py
@@ -1,9 +1,11 @@
 import logging
 import os
+import re
 from typing import Any, Callable, Dict, List, Optional

 import torch
 from datasets import Dataset, load_dataset
+from loguru import logger
 from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
 from transformers.data import default_data_collator

@@ -15,6 +17,7 @@
     "get_raw_dataset",
     "make_dataset_splits",
     "get_custom_datasets_from_path",
+    "get_calibration_dataloader",
 ]


@@ -243,3 +246,76 @@ def do_transform(candidate: str) -> bool:
                 transform_dataset_key(dataset_key)

     return data_files
+
+
+def get_calibration_dataloader(
+    data_args,
+    processor,
+    add_labels: bool = False,  # for oneshot
+    do_oneshot: bool = True,
+):
+    """
+    Load and tokenize the calibration dataset described by data_args, then
+    return it as a DataLoader (or None when no calibration dataset is set).
+
+    :param processor: processor or tokenizer to use for dataset tokenization
+    :param add_labels: if True, add labels column to dataset splits
+    """
+    if data_args.dataset is None:
+        logger.info(
+            "Running oneshot without calibration data. This is expected for "
+            "weight-only and dynamic quantization"
+        )
+        return
+
+    splits = data_args.splits
+    tokenized_datasets = {}
+
+    def _get_split_name(inp_str):
+        # strip out split name, for ex train[60%:] -> train
+        match = re.match(r"(\w*)\[.*\]", inp_str)
+        if match is not None:
+            return match.group(1)
+        return inp_str
+
+    if splits is None:
+        splits = {"all": None}
+    elif isinstance(splits, str):
+        splits = {_get_split_name(splits): splits}
+    elif isinstance(splits, List):
+        splits = {_get_split_name(s): s for s in splits}
+
+    # default to custom dataset if dataset provided isn't a string
+    registry_id = data_args.dataset if isinstance(data_args.dataset, str) else "custom"
+    for split_name, split_str in splits.items():
+        dataset = data_args.dataset
+        if hasattr(dataset, "column_names") and "input_ids" in dataset.column_names:
+            # dataset is already tokenized
+            tokenized_datasets[split_name] = dataset
+        else:
+            # dataset needs to be tokenized
+            from llmcompressor.transformers.finetune.data.base import (
+                TextGenerationDataset,
+            )
+
+            dataset_manager = TextGenerationDataset.load_from_registry(
+                registry_id,
+                data_args=data_args,
+                split=split_str,
+                processor=processor,
+            )
+            tokenized_datasets[split_name] = dataset_manager(add_labels=add_labels)
+
+    datasets = make_dataset_splits(
+        tokenized_datasets,
+        do_oneshot=do_oneshot,
+    )
+
+    calibration_dataset = datasets.get("calibration")
+
+    return format_calibration_data(
+        tokenized_dataset=calibration_dataset,
+        num_calibration_samples=data_args.num_calibration_samples,
+        do_shuffle=data_args.shuffle_calibration_samples,
+        collate_fn=data_args.data_collator,
+    )
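
A rough usage sketch of the extracted helper. The DataTrainingArguments class and the field names used here (dataset, num_calibration_samples) are assumed from the existing finetune module; the model and dataset names are placeholders.

```python
from transformers import AutoTokenizer

from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments
from llmcompressor.transformers.finetune.data.data_helpers import (
    get_calibration_dataloader,
)

tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
data_args = DataTrainingArguments(
    dataset="open_platypus",      # placeholder dataset name
    num_calibration_samples=512,
)
dataloader = get_calibration_dataloader(data_args, tokenizer)
if dataloader is not None:  # None when data_args.dataset is unset
    batch = next(iter(dataloader))  # dict with input_ids, attention_mask, ...
```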
31 changes: 19 additions & 12 deletions src/llmcompressor/transformers/finetune/model_args.py
@@ -5,7 +5,9 @@
 @dataclass
 class ModelArguments:
     """
-    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from
+    Model variables used for oneshot calibration, training/finetuning, and
+    stage runners (which alternate between oneshot and finetune stages)
+
     """

     model: str = field(
@@ -44,17 +46,7 @@ class ModelArguments:
         default=None,
         metadata={"help": "Where to store the pretrained data from huggingface.co"},
     )
-    use_fast_tokenizer: bool = field(
-        default=True,
-        metadata={"help": "Whether to use one of the fast tokenizers. Default True"},
-    )
-    model_revision: str = field(
-        default="main",
-        metadata={
-            "help": "The specific model version to use "
-            "(can be a branch name, tag name or commit id)"
-        },
-    )
+
     use_auth_token: bool = field(
         default=False,
         metadata={
@@ -83,3 +75,18 @@ class ModelArguments:
             "repositories you trust and in which you have read the code"
         },
     )
+    save_compressed: Optional[bool] = field(
+        default=True,
+        metadata={"help": "Whether to compress sparse models during save"},
+    )
+    oneshot_device: Optional[str] = field(
+        default="cuda:0",
+        metadata={"help": "Device to run oneshot calibration on"},
+    )
+    model_revision: str = field(
+        default="main",
+        metadata={
+            "help": "The specific model version to use "
+            "(can be a branch name, tag name or commit id)"
+        },
+    )
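
Since ModelArguments is a plain dataclass, the two new fields can be checked directly; a minimal sketch with a placeholder model stub:

```python
from llmcompressor.transformers.finetune.model_args import ModelArguments

args = ModelArguments(
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # placeholder model stub
    save_compressed=True,     # new: compress sparse models during save
    oneshot_device="cuda:0",  # new: device for oneshot calibration
)
```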