Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix smoothquant ignore, Fix typing, Add glm mappings #1015

Merged
merged 4 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions src/llmcompressor/modifiers/smoothquant/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
-from typing import Callable, Dict, List, Optional, Tuple
+from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
from compressed_tensors.utils.offload import is_module_offloaded
Expand All @@ -14,7 +14,11 @@
)
from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward
from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
-from llmcompressor.utils.pytorch.module import get_layers, get_matching_layer
+from llmcompressor.utils.pytorch.module import (
+get_layers,
+get_matching_layer,
+match_targets,
+)

MINIMUM_SMOOTHING_SCALE = 1e-5

Expand Down Expand Up @@ -95,7 +99,7 @@ class SmoothQuantModifier(Modifier):
"""

smoothing_strength: float = 0.5
-mappings: Optional[List[Tuple]] = None
+mappings: Optional[List[Union[Tuple, List]]] = None
ignore: Optional[List[str]] = None
num_calibration_steps: Optional[int] = None
calibration_function: Optional[Callable] = None
Expand Down Expand Up @@ -176,7 +180,7 @@ def _resolve_mappings(self, model: Module) -> List:
for to_balance, to_smooth in self.mappings:
to_smooth_layers = get_layers(to_smooth, model)
for layer_name, smooth_layer in to_smooth_layers.items():
-if layer_name not in self.ignore:
+if not match_targets(layer_name, self.ignore)[0]:
balance_layers = []
for balance_suffix in to_balance:
# find the submodule that matches the activation layer
Expand Down
5 changes: 3 additions & 2 deletions src/llmcompressor/modifiers/smoothquant/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
smooth_layers="re:.*post_attention_layernorm",
),
]
-MIXTRAL_MAPPINGS: List[LayerMap] = [
+MIXTRAL_SMOOTHQUANT_MAPPINGS: List[LayerMap] = [
LayerMap(
balance_layers=["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"],
smooth_layers="re:.*input_layernorm",
Expand All @@ -49,10 +49,11 @@
# Add more mappings here
MAPPINGS_REGISTRY: Dict[str, List[LayerMap]] = {
"LlamaForCausalLM": DEFAULT_SMOOTHQUANT_MAPPINGS,
-"MixtralForCausalLM": MIXTRAL_MAPPINGS,
+"MixtralForCausalLM": MIXTRAL_SMOOTHQUANT_MAPPINGS,
"MistralForCausalLM": DEFAULT_SMOOTHQUANT_MAPPINGS,
"Qwen2ForCausalLM": DEFAULT_SMOOTHQUANT_MAPPINGS,
"BloomForCausalLM": BLOOM_SMOOTHQUANT_MAPPINGS,
+"ChatGLMForConditionalGeneration": BLOOM_SMOOTHQUANT_MAPPINGS,
}


Expand Down
Loading