Add NAS to Minitron pruning for parameter based auto-pruning #720
Changes from all commits: afde981, b6d775d, 0fd1e31, cae3178, 5d3be27, 4b645f3, f73ab04, 7011abd, b015cd0, 1234e55, c7968d3, f6e9992
```diff
@@ -18,26 +18,18 @@
 import types
 from abc import ABC
 from collections.abc import Callable, Sequence
 from typing import Any

 import torch
 import torch.nn as nn
 from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
 from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
 from megatron.core.models.gpt import GPTModel
-from megatron.core.parallel_state import (
-    get_data_parallel_group,
-    get_pipeline_model_parallel_group,
-    get_tensor_model_parallel_group,
-    is_pipeline_first_stage,
-    is_pipeline_last_stage,
-)
+from megatron.core.parallel_state import is_pipeline_first_stage, is_pipeline_last_stage
 from megatron.core.tensor_parallel.layers import (
     ColumnParallelLinear,
     RowParallelLinear,
     VocabParallelEmbedding,
 )
 from megatron.core.transformer import MegatronModule
 from megatron.core.transformer.attention import SelfAttention
 from megatron.core.transformer.dot_product_attention import DotProductAttention
 from megatron.core.transformer.mlp import MLP
@@ -51,29 +43,14 @@
 from modelopt.torch.nas.modules import DynamicModuleList
 from modelopt.torch.opt.dynamic import DynamicModule
 from modelopt.torch.opt.hparam import HPType
-from modelopt.torch.opt.searcher import ConstraintsDict
 from modelopt.torch.trace import Symbol
-from modelopt.torch.utils import distributed as dist
-from modelopt.torch.utils import (
-    get_module_device,
-    make_divisible,
-    param_num_from_forward,
-    print_rank_0,
-    random,
-)
+from modelopt.torch.utils import make_divisible

-from ..algorithms import (
-    MODULE_TYPE_TO_CONSTRAINTS_FUNC,
-    ConstraintEvalFunc,
-    ConstraintInterpolator,
-    ConstraintsFunc,
-    ConstraintsRes,
-)
 from ..hparams.concat import build_concat_hp
 from ..modules import _DynamicLayerNorm
 from ..modules.utils import get_sliced_tensor, get_sliced_tensor_by_slices
 from ..registry import DMRegistry
-from ..search_space import SampleFunc
 from ..traced_hp import TracedHp


 SUPPORTED_MODELS = {GPTModel: "megatron.core.models.gpt.GPTModel"}
```
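For orientation: `SUPPORTED_MODELS` keys the dynamic-module registry that this plugin populates. A minimal sketch of the registration pattern, with the decorator usage and class body assumed from the registry conventions used elsewhere in modelopt rather than taken from this diff:

```python
from megatron.core.models.gpt import GPTModel

from modelopt.torch.nas.registry import DMRegistry
from modelopt.torch.opt.dynamic import DynamicModule

SUPPORTED_MODELS = {GPTModel: "megatron.core.models.gpt.GPTModel"}


# Hypothetical illustration: registering a dynamic wrapper for GPTModel so
# that converting a model into a search space swaps in this class.
@DMRegistry.register(SUPPORTED_MODELS)
class _DynamicGPTModelSketch(DynamicModule):
    def _setup(self) -> None:
        # Turn static attributes (hidden_size, num_layers, ...) into
        # searchable hyperparameters; details omitted in this sketch.
        ...
```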
```diff
@@ -634,7 +611,6 @@ def modify(

     def _export_reinit_token_dispatcher(self) -> None:
         """Reinitialize the token dispatcher after pruning."""
-        print_rank_0("Reinitializing token dispatcher after pruning")
         if hasattr(moe_utils, "get_default_model_comm_pgs"):
             model_comm_pgs = moe_utils.get_default_model_comm_pgs()
         else:
```
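The hunk drops the log line (its `print_rank_0` import was removed above) but keeps the version probe. A minimal standalone sketch of that probe pattern, with the module stand-in and the fallback branch hypothetical since the diff is truncated here:

```python
import types

# Stand-in for megatron.core.transformer.moe.moe_utils; the real code
# imports the actual module.
moe_utils = types.SimpleNamespace()


def resolve_model_comm_pgs():
    # Newer Megatron-Core exposes get_default_model_comm_pgs(); probe for it
    # with hasattr and fall back to a legacy path on older versions.
    if hasattr(moe_utils, "get_default_model_comm_pgs"):
        return moe_utils.get_default_model_comm_pgs()
    return None  # hypothetical legacy fallback; the actual branch is elided
```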
```diff
@@ -1045,27 +1021,30 @@ def modify(
         *,
         hidden_size_divisor: int = 1,
         ffn_hidden_size_divisor: int = 1,
         mamba_num_heads_divisor: int = 1,
         mamba_head_dim_divisor: int = 1,
         num_moe_experts_divisor: int = 1,
+        num_layers_divisor: int = 1,
     ):
         """Modify the dynamic choices of the module according to provided keyword arguments.

         Args:
             hidden_size_divisor: The divisor of the hidden_size.
             ffn_hidden_size_divisor: The divisor of the mlp ffn_hidden_size.
             mamba_num_heads_divisor: The divisor of the mamba num_heads.
             mamba_head_dim_divisor: The divisor of the mamba head_dim.
             num_moe_experts_divisor: The divisor of the number of MoE experts.
+            num_layers_divisor: The divisor of the number of layers.
         """
-        hp = self.get_hparam("hidden_size")
-        choices = {int(make_divisible(c, hidden_size_divisor)) for c in hp.choices}  # type: ignore[arg-type]
-        hp.choices = list(set(hp.choices) & choices | {hp.original})
+        for hp_name, divisor in [
+            ("hidden_size", hidden_size_divisor),
+            ("num_layers", num_layers_divisor),
+        ]:
+            hp = self.get_hparam(hp_name)
+            choices = {int(make_divisible(c, divisor)) for c in hp.choices}  # type: ignore[arg-type]
+            hp.choices = list(set(hp.choices) & choices | {hp.original})
```
Comment on lines +1037 to +1043

Contributor:
Quick question: I see the individual layers have a separate param for

Collaborator (Author):
I first extract the
```diff

         for layer in self.decoder.layers:
             layer.modify(
                 ffn_hidden_size_divisor=ffn_hidden_size_divisor,
                 mamba_num_heads_divisor=mamba_num_heads_divisor,
                 mamba_head_dim_divisor=mamba_head_dim_divisor,
                 num_moe_experts_divisor=num_moe_experts_divisor,
             )
```
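For intuition on what the loop above does: each hyperparameter keeps only the choices that survive divisor rounding, plus its original value so the unpruned configuration always remains valid. A standalone sketch, where this `make_divisible` is the common MobileNet-style helper and only assumed to match modelopt's in spirit:

```python
def make_divisible(v: int, divisor: int) -> int:
    # Round to the nearest multiple of divisor, never below divisor itself,
    # and never more than ~10% below the input.
    new_v = max(divisor, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


choices = [512, 768, 1024, 1536, 2048]
original = 2048
divisible = {make_divisible(c, 512) for c in choices}
# Intersect with the existing choices and re-add the original, mirroring
# `hp.choices = list(set(hp.choices) & choices | {hp.original})`.
kept = sorted(set(choices) & divisible | {original})
print(kept)  # [512, 1024, 1536, 2048] -- 768 is rounded away
```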
```diff
@@ -1084,86 +1063,3 @@ def export(self) -> torch.nn.Module:
         ).export()
         self.output_layer.export()
         return super().export()
-
-
-class MegatronConstraintsFunc(ConstraintsFunc):
-    """A Functor class to check if sub-net satisfied all provided constraints.
-
-    We intentionally expose some attributes like `limits` s.t. we can modify it manually.
-    """
-
-    _sample_points_dict: dict[tuple[str, ...], dict[str, SampleFunc]] = {
-        ("params",): {"min": min, "centroid": random.centroid, "max": max},
-    }
-
-    def __init__(
-        self,
-        model: MegatronModule,
-        constraints: ConstraintsDict,
-        dummy_input: Any | tuple[Any, ...],
-        deployment: dict | None = None,
-        fast_eval: bool = True,
-    ):
-        """Initialize with additional data parallel group info from megatron."""
-        for key in constraints:
-            if key != "params":
-                raise ValueError("Only params constraints is supported for MegatronModule!")
-
-        self.model = model
-        self.dummy_input = dummy_input
-        self.deployment = deployment
-        self._fast_eval = fast_eval
-
-        # Getting data parallel group for
-        self.dp_group = get_data_parallel_group()
-
-        # initialize latency interpolator
-        keys_for_interpolation = ("params",)
-        if ConstraintsFunc.is_configurable(self.model, "depth"):
-            keys_for_interpolation += ("flops_min_depth",)
-        self._latency_interpolator = ConstraintInterpolator(
-            self.model,
-            points_funcs={k: self.constraint_eval_funcs[k] for k in keys_for_interpolation},
-            value_func=self._get_true_latency,
-        )
-        # set fast/regular mode for latency interpolator
-        self._latency_interpolator.collect_mode = not self.fast_eval
-
-        # set limit at the end with setter to use sanity checks on constraints
-        self._limits = {}
-        self.limits = constraints
-
-    @property
-    def constraint_eval_funcs(self) -> dict[str, ConstraintEvalFunc]:
-        """Get constraint eval fns."""
-        return {
-            "params": self._get_params,
-        }
-
-    def _get_params(self, _: ConstraintsRes | None = None) -> float:
-        """Get number of model parameters from forward pass."""
-        params = param_num_from_forward(self.model, args=self.dummy_input, unit=1.0)
-        reduced_params = torch.Tensor([params]).to(device=get_module_device(self.model))
-        torch.distributed.all_reduce(reduced_params, group=get_pipeline_model_parallel_group())
-        torch.distributed.all_reduce(reduced_params, group=get_tensor_model_parallel_group())
-        return reduced_params.item()
-
-    def _get_flops(self, _: ConstraintsRes | None = None) -> float:
-        """Get inference FLOPs."""
-        raise NotImplementedError
-
-    def _get_flops_min_depth(self, _: ConstraintsRes | None = None) -> float:
-        """Get inference FLOPs with depth set to minimum."""
-        raise NotImplementedError
-
-    def _get_true_latency(self, _: ConstraintsRes | None = None) -> float:
-        """Get true inference latency."""
-        raise NotImplementedError
-
-    def _get_latency(self, precomputed: ConstraintsRes | None = None) -> float:
-        """Get inference latency from interpolator."""
-        raise NotImplementedError
-
-
-# Clear the mapping and reinsert.
-MODULE_TYPE_TO_CONSTRAINTS_FUNC[MegatronModule] = MegatronConstraintsFunc
```
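The removed `_get_params` was the heart of the old constraints path: each rank measures the parameters touched by a forward pass (`param_num_from_forward`), then all-reduces across the pipeline and tensor parallel groups, since each pipeline stage holds different layers and each tensor-parallel rank holds a shard of each layer; data-parallel replicas are identical and need no reduction. A minimal sketch of just that reduction, with the process groups assumed to be already initialized by the caller:

```python
import torch
import torch.distributed as dist


def global_param_count(local_params: float, pp_group, tp_group) -> float:
    """Sum a per-rank parameter count over pipeline and tensor parallel groups."""
    # With the NCCL backend the tensor must live on the GPU, matching the
    # original code's .to(device=...) call.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    counted = torch.tensor([local_params], dtype=torch.float64, device=device)
    # Each pipeline stage and each tensor-parallel shard holds a disjoint
    # slice of the model, so two all-reduce sums recover the global total.
    dist.all_reduce(counted, group=pp_group)
    dist.all_reduce(counted, group=tp_group)
    return counted.item()
```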