# sparsification.py
"""
Helper functions for retrieving information related to model sparsification
"""

import json
from typing import (
    Any,
    Callable,
    Dict,
    Generator,
    Iterable,
    Iterator,
    List,
    Optional,
    Tuple,
    Union,
)
import torch
from accelerate.accelerator import get_state_dict_offloaded_model
from loguru import logger
from torch.nn import Module
from tqdm import tqdm
from llmcompressor.pytorch.utils.helpers import get_quantized_layers, tensor_sparsity

__all__ = [
    "ModuleSparsificationInfo",
    "GradSampler",
]


class ModuleSparsificationInfo:
    """
    Helper class for providing information related to torch Module parameters
    and the amount of sparsification applied. Includes information for pruning
    and quantization.

    :param module: torch Module to analyze
    :param state_dict: optional state_dict to analyze in place of the torch model.
        This is used when analyzing an FSDP model, where the full weights may not
        be accessible
    """

    def __init__(
        self, module: Module, state_dict: Optional[Dict[str, torch.Tensor]] = None
    ):
        self.module = module

        if state_dict is not None:
            # when analyzing an FSDP model, the state_dict does not differentiate
            # between trainable and non-trainable parameters (e.g. it can contain
            # buffers); this means that self.trainable_params may be overestimated
            self.trainable_params = state_dict
        else:
            if hasattr(module, "_hf_hook"):
                self.trainable_params = get_state_dict_offloaded_model(module)
            else:
                self.trainable_params = {
                    k: v for k, v in self.module.named_parameters() if v.requires_grad
                }

    def __str__(self):
        return json.dumps(
            {
                "params_summary": {
                    "total": self.params_total,
                    "sparse": self.params_sparse,
                    "sparsity_percent": self.params_sparse_percent,
                    "quantized": self.params_quantized,
                    "quantized_percent": self.params_quantized_percent,
                },
                "params_info": self.params_info,
            }
        )

    @property
    def params_total(self) -> int:
        """
        :return: total number of trainable parameters in the model
        """
        return sum(torch.numel(param) for param in self.trainable_params.values())

    @property
    def params_sparse(self) -> int:
        """
        :return: total number of sparse (zero-valued) trainable parameters in the model
        """
        return sum(
            round(tensor_sparsity(param).item() * torch.numel(param))
            for param in tqdm(
                self.trainable_params.values(), desc="Calculating model sparsity"
            )
        )

    @property
    def params_sparse_percent(self) -> float:
        """
        :return: percent of sparsified parameters in the entire model
        """
        return self.params_sparse / float(self.params_total) * 100

    @property
    def params_quantized(self) -> int:
        """
        :return: number of parameters across quantized layers
        """
        num_params = 0
        for name, layer in get_quantized_layers(self.module):
            weight = self.trainable_params.get(f"{name}.weight")
            if weight is None:
                logger.warning(f"{name} is not recognized in trainable_params")
                continue
            num_params += torch.numel(weight)
            if hasattr(layer, "bias") and layer.bias is not None:
                num_params += torch.numel(layer.bias)
        return num_params

    @property
    def params_quantized_percent(self) -> float:
        """
        :return: percentage of parameters that have been quantized
        """
        return self.params_quantized / float(self.params_total) * 100
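
# Illustrative usage sketch (not part of the original file): the model below is a
# placeholder torch Module; any module works the same way. The second call shows
# the optional state_dict path described in the class docstring for FSDP models,
# assuming the caller has already gathered the full weights.
#
#     from torch.nn import Linear, Sequential
#
#     model = Sequential(Linear(256, 256), Linear(256, 10))
#     info = ModuleSparsificationInfo(model)
#     print(info.params_total)
#     print(f"{info.params_sparse_percent:.2f}% of parameters are zero")
#
#     info = ModuleSparsificationInfo(model, state_dict=model.state_dict())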


class GradSampler:
    """
    Class for computing gradient samples for a model given a sample data loader
    and loss function.

    :param data_loader: iterator of data samples to use as model inputs and their
        loss targets. Items must be tuples of
        (forward_args: List, forward_kwargs: Dict, loss_targets: Any)
        where the forward pass will be outputs = model(*forward_args, **forward_kwargs)
        and the loss will be loss = loss_fn(outputs, loss_targets)
    :param loss_fn: function to be called on model outputs to compute the loss at
        each step
    """

    def __init__(
        self,
        data_loader: Union[Iterator[Tuple[List[Any], Dict[str, Any], Any]], Callable],
        loss_fn: Callable[[Any, Any], Any],
    ):
        if not isinstance(data_loader, Iterable) and not callable(data_loader):
            raise ValueError(
                "data_loader for GradSampler must be Iterable or Callable, received "
                f"object of type {type(data_loader)}"
            )
        if not callable(loss_fn):
            raise ValueError(
                "loss_fn for GradSampler must be callable, given input "
                f"with type {type(loss_fn)}"
            )

        self._data_loader = data_loader
        self._loss_fn = loss_fn

    def iter_module_backwards(
        self,
        module: Module,
        num_grads: int,
        progress_bar: bool = True,
    ) -> Generator[int, None, None]:
        """
        :param module: module to compute gradients for
        :param num_grads: number of gradient samples to compute
        :param progress_bar: True to display a tqdm progress bar while gradients
            are collected, False to disable it
        :return: generator that yields after every gradient is computed with the index
            of the gradient sample number
        """
        computed_grads = 0
        pbar = tqdm(
            total=num_grads, desc="Collecting gradients", disable=not progress_bar
        )

        with pbar:
            while computed_grads < num_grads:
                data_loader = (
                    self._data_loader()
                    if callable(self._data_loader)
                    else self._data_loader
                )
                for forward_args, forward_kwargs, loss_target in data_loader:
                    module.zero_grad()
                    # run sample forward and backwards pass
                    model_outputs = module(*forward_args, **forward_kwargs)
                    # Image classification models have been overridden to compute both
                    # the logit values and the probabilities, returning a tuple.
                    # No other models do this.
                    if model_outputs.__class__ == tuple:
                        model_outputs = model_outputs[0]
                    loss = self._loss_fn(model_outputs, loss_target)
                    loss.backward()

                    # yield so gradients can be collected
                    computed_grads += 1
                    yield computed_grads

                    if progress_bar:
                        pbar.update(1)
                    if computed_grads >= num_grads:
                        break
                if computed_grads < num_grads:
                    logger.warning(
                        f"The requested num_grads ({num_grads}) is greater than the "
                        "number of samples available in the dataset. Proceeding with "
                        "fewer gradients than requested; reduce num_grads to suppress "
                        "this warning."
                    )
                    break

        module.zero_grad()
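

# Illustrative usage sketch (not part of the original file): the model, data, and
# loss below are placeholders chosen only to show the
# (forward_args, forward_kwargs, loss_target) tuple format that data_loader items
# are expected to follow.
#
#     import torch.nn.functional as F
#     from torch.nn import Linear
#
#     model = Linear(16, 4)
#     samples = [
#         ([torch.randn(8, 16)], {}, torch.randint(0, 4, (8,)))
#         for _ in range(4)
#     ]
#     sampler = GradSampler(samples, F.cross_entropy)
#     for grad_idx in sampler.iter_module_backwards(model, num_grads=4):
#         # each parameter's .grad now holds the gradients for this sample
#         pass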