
Commit 5020847

added lint.yml and ran ruff format
Signed-off-by: kcirred <[email protected]>
1 parent e53e1f1 commit 5020847

File tree

18 files changed: +774 -343 lines


.github/workflows/lint.yml

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+name: Lint
+
+on: [pull_request]
+
+jobs:
+  lint:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - uses: astral-sh/ruff-action@v3
+        with:
+          src: "."
+          version: "~= 0.9.5"
+      - run: ruff check
+      - run: ruff format --check
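The same checks can be reproduced locally before pushing: assuming ruff is installed at the pinned version (for example via `pip install "ruff~=0.9.5"`), running `ruff check` and `ruff format --check` from the repository root performs the same two invocations as the workflow.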

aiu_fms_testing_utils/testing/validation.py

Lines changed: 119 additions & 46 deletions
@@ -6,44 +6,77 @@
 from aiu_fms_testing_utils.utils.aiu_setup import dprint
 import os
 
-class LogitsExtractorHook(Callable[[int, torch.Tensor, torch.Tensor, MutableMapping[str, Any]], Tuple[torch.Tensor, MutableMapping[str, Any]],]):
 
+class LogitsExtractorHook(
+    Callable[
+        [int, torch.Tensor, torch.Tensor, MutableMapping[str, Any]],
+        Tuple[torch.Tensor, MutableMapping[str, Any]],
+    ]
+):
     def __init__(self):
         super().__init__()
         self.extracted_logits: Optional[torch.Tensor] = None
 
-    def __call__(self, token_position: torch.Tensor, logits: torch.Tensor, next_val: torch.Tensor, kwargs):
+    def __call__(
+        self,
+        token_position: torch.Tensor,
+        logits: torch.Tensor,
+        next_val: torch.Tensor,
+        kwargs,
+    ):
         if self.extracted_logits is None:
             self.extracted_logits = logits.unsqueeze(1)
         else:
-            self.extracted_logits = torch.cat((self.extracted_logits, logits.unsqueeze(1)), dim=1)
+            self.extracted_logits = torch.cat(
+                (self.extracted_logits, logits.unsqueeze(1)), dim=1
+            )
         return next_val, kwargs
 
-class StaticTokenInjectorHook(Callable[[int, torch.Tensor, torch.Tensor, MutableMapping[str, Any]], Tuple[torch.Tensor, MutableMapping[str, Any]],]):
 
-    def __init__(self, static_tokens: List[torch.Tensor], device_type: str="cpu"):
+class StaticTokenInjectorHook(
+    Callable[
+        [int, torch.Tensor, torch.Tensor, MutableMapping[str, Any]],
+        Tuple[torch.Tensor, MutableMapping[str, Any]],
+    ]
+):
+    def __init__(self, static_tokens: List[torch.Tensor], device_type: str = "cpu"):
         super().__init__()
-        self.static_tokens = torch.tensor(static_tokens, device=device_type).t() # transposing so batch tokens per token_position
+        self.static_tokens = torch.tensor(
+            static_tokens, device=device_type
+        ).t()  # transposing so batch tokens per token_position
 
-    def __call__(self, token_position: int, logits: torch.Tensor, next_val: torch.Tensor, kwargs):
+    def __call__(
+        self, token_position: int, logits: torch.Tensor, next_val: torch.Tensor, kwargs
+    ):
         next_val.copy_(self.static_tokens[token_position].unsqueeze(1))
         return next_val, kwargs
 
-class GoldenTokenHook(Callable[[int, torch.Tensor, torch.Tensor, MutableMapping[str, Any]], Tuple[torch.Tensor, MutableMapping[str, Any]],]):
 
-    def __init__(self, static_tokens: torch.Tensor, device_type: str="cpu"):
+class GoldenTokenHook(
+    Callable[
+        [int, torch.Tensor, torch.Tensor, MutableMapping[str, Any]],
+        Tuple[torch.Tensor, MutableMapping[str, Any]],
+    ]
+):
+    def __init__(self, static_tokens: torch.Tensor, device_type: str = "cpu"):
         super().__init__()
         self.logits_extractor = LogitsExtractorHook()
         self.extracted_logits = None
-        self.token_injector = StaticTokenInjectorHook(static_tokens, device_type=device_type)
+        self.token_injector = StaticTokenInjectorHook(
+            static_tokens, device_type=device_type
+        )
 
-    def __call__(self, token_position: int, logits: torch.Tensor, next_val: torch.Tensor, kwargs):
-        next_val, kwargs = self.logits_extractor(token_position, logits, next_val, kwargs)
+    def __call__(
+        self, token_position: int, logits: torch.Tensor, next_val: torch.Tensor, kwargs
+    ):
+        next_val, kwargs = self.logits_extractor(
+            token_position, logits, next_val, kwargs
+        )
         self.extracted_logits = self.logits_extractor.extracted_logits
         return self.token_injector(token_position, logits, next_val, kwargs)
 
-class ValidationInfo:
 
+class ValidationInfo:
    def __init__(self, validation_info_list):
        super().__init__()
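For orientation, here is a minimal self-contained sketch (not part of the commit; the token values and vocabulary size are made up) of the contract these hooks implement: generation calls the hook once per decoded position with (token_position, logits, next_val, kwargs), and GoldenTokenHook records the logits before forcing the reference tokens back into next_val.

import torch

from aiu_fms_testing_utils.testing.validation import GoldenTokenHook

# Reference ("golden") continuations for a batch of 2 sentences, 3 steps each.
golden = [[11, 12, 13], [21, 22, 23]]
hook = GoldenTokenHook(golden, device_type="cpu")

batch_size, vocab = 2, 50
for pos in range(3):
    logits = torch.randn(batch_size, vocab)        # stand-in model output
    next_val = logits.argmax(dim=1, keepdim=True)  # what generation sampled
    next_val, _ = hook(pos, logits, next_val, {})  # next_val now holds the golden tokens

print(hook.extracted_logits.shape)  # torch.Size([2, 3, 50]): (batch, steps, vocab)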

@@ -54,7 +87,10 @@ def __iter__(self):
             yield vi
 
     def get_info(self, info_name):
-        return [[t.unsqueeze(0) for t in sentence[info_name]] for sentence in self._validation_info_list]
+        return [
+            [t.unsqueeze(0) for t in sentence[info_name]]
+            for sentence in self._validation_info_list
+        ]
 
     def save(self, save_dir_path: str):
         """Save the validation information into a directory.
@@ -86,12 +122,17 @@ def save(self, save_dir_path: str):
 
     def __len__(self):
         return len(self._validation_info_list)
-
-def get_default_validation_prefix(model_id: str, max_new_tokens: int, batch_size: int, seq_length: int, dtype: str):
+
+
+def get_default_validation_prefix(
+    model_id: str, max_new_tokens: int, batch_size: int, seq_length: int, dtype: str
+):
     return f"{model_id.replace('/', '--')}_max-new-tokens-{max_new_tokens}_batch-size-{batch_size}_seq-length-{seq_length}_dtype-{dtype}"
 
 
-def load_validation_information(validation_path, validation_files_type, batch_size, tokenizer=None):
+def load_validation_information(
+    validation_path, validation_files_type, batch_size, tokenizer=None
+):
     """Load the validation information from a directory
 
     The files will be assumed to be in the following structure:
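For a quick illustration of the reformatted prefix helper (the model id below is hypothetical):

from aiu_fms_testing_utils.testing.validation import get_default_validation_prefix

# Prints: ibm-granite--granite-7b_max-new-tokens-16_batch-size-1_seq-length-64_dtype-fp16
print(get_default_validation_prefix("ibm-granite/granite-7b", 16, 1, 64, "fp16"))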
@@ -107,17 +148,15 @@ def load_validation_information(validation_path, validation_files_type, batch_si
     if containing only tokens - torch.tensor
     if containing tokens and logits - dict[tokens -> torch.tensor, logits -> torch.tensor]
     if containing text - str
-
+
     :param validation_path: path to validation info files
     :param validation_files_type: validation file type to load, one of text, tokens, or logits
     :param batch_size: the number of prompts to load
     :param tokenizer: an optional tokenizer, required when validation_files_type=text
     :return: a new validation info
     """
     if isinstance(validation_path, str):
-        validation_files_path, sep, glob_pattern = validation_path.partition(
-            "*"
-        )
+        validation_files_path, sep, glob_pattern = validation_path.partition("*")
     else:
         sep = ""
         glob_pattern = ""
@@ -146,22 +185,24 @@ def load_validation_information(validation_path, validation_files_type, batch_si
         validation_files_paths = [validation_files_path]
 
     # Check if we found some files
-    assert (
-        len(validation_files_paths) > 0
-    ), f"Can't find any validation files at {validation_files_path}"
+    assert len(validation_files_paths) > 0, (
+        f"Can't find any validation files at {validation_files_path}"
+    )
 
     # Check if we have enough files
-    assert (
-        len(validation_files_paths) >= batch_size
-    ), f"Not enough validation files at {validation_files_path} for a batch size of {batch_size}"
+    assert len(validation_files_paths) >= batch_size, (
+        f"Not enough validation files at {validation_files_path} for a batch size of {batch_size}"
+    )
 
     validation_info = []
     for i, validation_file_path in enumerate(validation_files_paths):
         if i == batch_size:
             break
         if validation_files_type == "text":
             if tokenizer is None:
-                raise ValueError("must provide a tokenizer when validation_files_type=text")
+                raise ValueError(
+                    "must provide a tokenizer when validation_files_type=text"
+                )
             # Text format will get tokenized
             validation_info.append(
                 {
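A hypothetical invocation of the loader reformatted above (the path and file layout are placeholders; per the docstring, the type must be one of text, tokens, or logits, and a tokenizer is only required for text):

from aiu_fms_testing_utils.testing.validation import load_validation_information

# The "*" splits the path into a base directory and a glob pattern; at most
# batch_size matched files are read, and the asserts above fire otherwise.
val_info = load_validation_information("/tmp/validation/*.out", "tokens", batch_size=4)
print(len(val_info))  # 4, one entry per prompt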
@@ -187,14 +228,27 @@ def load_validation_information(validation_path, validation_files_type, batch_si
 
     return ValidationInfo(validation_info)
 
-def extract_validation_information(model, input_ids, max_new_tokens, post_iteration_hook, attn_algorithm=None, eos_token_id = None, only_last_token=False, timing="", attn_type="sdpa", **padding_kwargs):
+
+def extract_validation_information(
+    model,
+    input_ids,
+    max_new_tokens,
+    post_iteration_hook,
+    attn_algorithm=None,
+    eos_token_id=None,
+    only_last_token=False,
+    timing="",
+    attn_type="sdpa",
+    **padding_kwargs,
+):
     max_seq_len = model.config.max_expected_seq_len
     attention_specific_kwargs = {}
     if attn_type == "paged":
         from aiu_fms_testing_utils.utils.paged import generate
     else:
         # TODO: Add a unified generation dependent on attn_type
         from fms.utils.generation import generate
+
     attention_specific_kwargs["contiguous_cache"] = True
     attention_specific_kwargs["max_seq_len"] = max_seq_len
 
@@ -215,7 +269,7 @@ def extract_validation_information(model, input_ids, max_new_tokens, post_iterat
         eos_token_id=eos_token_id,
         timing=timing,
         extra_kwargs=extra_generation_kwargs,
-        **attention_specific_kwargs
+        **attention_specific_kwargs,
     )
 
     if timing != "":
@@ -226,7 +280,7 @@ def extract_validation_information(model, input_ids, max_new_tokens, post_iterat
         if timing == "e2e":
             dprint(f"E2E timing information: {timings[0]:.3f}s")
         elif timing == "per-token":
-            timings = [f"{t*1000:.3f}" for t in timings]
+            timings = [f"{t * 1000:.3f}" for t in timings]
             dprint(f"Per-token timing information: {', '.join(timings)} ms")
 
     if len(result.shape) == 1:
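A sketch of a call into extract_validation_information as reformatted above (model and input_ids are placeholders for a loaded FMS model and a padded prompt tensor; timing="per-token" exercises the branch shown in this hunk):

from aiu_fms_testing_utils.testing.validation import (
    LogitsExtractorHook,
    extract_validation_information,
)

hook = LogitsExtractorHook()
cpu_validation_info = extract_validation_information(
    model,       # placeholder: a loaded FMS model
    input_ids,   # placeholder: padded prompt ids of shape (batch, seq_len)
    max_new_tokens=16,
    post_iteration_hook=hook,
    timing="per-token",
)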
@@ -235,75 +289,88 @@ def extract_validation_information(model, input_ids, max_new_tokens, post_iterat
     if hasattr(post_iteration_hook, "extracted_logits"):
         validation_info = [
             {"tokens": t.to("cpu"), "logits": l.to("cpu")}
-            for t, l in zip(torch.unbind(result), torch.unbind(post_iteration_hook.extracted_logits))
+            for t, l in zip(
+                torch.unbind(result), torch.unbind(post_iteration_hook.extracted_logits)
+            )
         ]
     else:
         validation_info = [{"tokens": t.to("cpu")} for t in torch.unbind(result)]
     return ValidationInfo(validation_info)
 
+
 def validate_level_0(aiu_tokens_per_sentence, validation_tokens_per_sentence):
     failed_cases = []
 
     for sentence_idx, (aiu_sentence, validation_sentence) in enumerate(
-            zip(aiu_tokens_per_sentence, validation_tokens_per_sentence)
+        zip(aiu_tokens_per_sentence, validation_tokens_per_sentence)
     ):
         for token_idx, (aiu_token, validation_token) in enumerate(
-                zip(aiu_sentence, validation_sentence)
+            zip(aiu_sentence, validation_sentence)
         ):
             if aiu_token != validation_token:
                 failed_cases.append((sentence_idx, token_idx))
     return failed_cases
 
-def top_k_loss_calculator(top_k: int, loss_f: Callable[[torch.Tensor, torch.Tensor], float]):
+
+def top_k_loss_calculator(
+    top_k: int, loss_f: Callable[[torch.Tensor, torch.Tensor], float]
+):
     """
     Function which will take the top_k logits indexes / values from a reference validation info and retrieve the same indexes from the test validation info logits
     and perform a loss function over the 2 tensors
 
     :param top_k: number of values to take from reference
     :param loss_f: a loss function between the reference and test logits
     """
+
     def loss_func(reference_logits, test_logits):
         reference_logits_prob = reference_logits.to(dtype=torch.float32)
         test_logits_prob = test_logits.to(dtype=torch.float32)
 
-        reference_values, reference_indices = torch.topk(reference_logits_prob, top_k, dim=1)
+        reference_values, reference_indices = torch.topk(
+            reference_logits_prob, top_k, dim=1
+        )
         test_values = test_logits_prob[:, reference_indices.squeeze(0)]
 
         return loss_f(reference_values, test_values)
+
     return loss_func
 
 
-def capture_level_1_metrics(reference_logits_per_sentence, test_logits_per_sentence, metrics_calculator=None):
+def capture_level_1_metrics(
+    reference_logits_per_sentence, test_logits_per_sentence, metrics_calculator=None
+):
     loss_metrics = []
 
     for sentence_idx, (reference_sentence, test_sentence) in enumerate(
-            zip(reference_logits_per_sentence, test_logits_per_sentence)
+        zip(reference_logits_per_sentence, test_logits_per_sentence)
    ):
         for token_idx, (reference_logits, test_logits) in enumerate(
-                zip(reference_sentence, test_sentence)
+            zip(reference_sentence, test_sentence)
         ):
             # computing cross entropy loss per token
             if metrics_calculator is None:
                 loss_fn = torch.nn.CrossEntropyLoss()
                 metrics_value = loss_fn(
                     reference_logits.to(dtype=torch.float32),
-                    test_logits.softmax(dim=1).to(dtype=torch.float32)
+                    test_logits.softmax(dim=1).to(dtype=torch.float32),
                 )
             else:
                 metrics_value = metrics_calculator(reference_logits, test_logits)
 
             loss_metrics.append((sentence_idx, token_idx, metrics_value))
 
     return loss_metrics
-
+
+
 def filter_failed_level_1_cases(level_1_loss_metrics, fail_f, print_failed=False):
     failed_cases = []
-    for (sentence_idx, token_idx, metrics_value) in level_1_loss_metrics:
+    for sentence_idx, token_idx, metrics_value in level_1_loss_metrics:
         if fail_f(metrics_value):
             failed_cases.append((sentence_idx, token_idx, metrics_value))
             if print_failed:
                 dprint(
-                    f"In sentence {sentence_idx+1}, the metric for token {token_idx} is {metrics_value}"
+                    f"In sentence {sentence_idx + 1}, the metric for token {token_idx} is {metrics_value}"
                 )
     return failed_cases
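A self-contained sketch of the level-1 flow above (random logits stand in for real model output; the top-10 mean-absolute-difference metric and the 0.5 threshold are illustrative, not project defaults):

import torch

from aiu_fms_testing_utils.testing.validation import (
    capture_level_1_metrics,
    filter_failed_level_1_cases,
    top_k_loss_calculator,
)

# 2 sentences x 3 token positions over a vocabulary of 50; the "test" logits
# are a lightly perturbed copy of the reference.
reference = [[torch.randn(1, 50) for _ in range(3)] for _ in range(2)]
test = [[t + 0.05 * torch.randn(1, 50) for t in sent] for sent in reference]

mad_top10 = top_k_loss_calculator(10, lambda ref, tst: (ref - tst).abs().mean().item())
metrics = capture_level_1_metrics(reference, test, metrics_calculator=mad_top10)
failed = filter_failed_level_1_cases(metrics, lambda m: m >= 0.5, print_failed=True)
print(f"{len(failed)} token(s) above the threshold")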

@@ -313,6 +380,12 @@ def print_failed_cases(failed_cases, aiu_tokens, validation_tokens, tokenizer):
         aiu_token = aiu_tokens[sentence_index][token_index]
         validation_token = validation_tokens[sentence_index][token_index]
 
-        aiu_str = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(aiu_token))
-        validation_str = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(validation_token))
-        print(f"In sentence {sentence_index+1}/{len(aiu_tokens)}, token {token_index}, AIU outputs {aiu_token} instead of {validation_token} -- AIU val={aiu_str} -- CPU val={validation_str}")
+        aiu_str = tokenizer.convert_tokens_to_string(
+            tokenizer.convert_ids_to_tokens(aiu_token)
+        )
+        validation_str = tokenizer.convert_tokens_to_string(
+            tokenizer.convert_ids_to_tokens(validation_token)
+        )
+        print(
+            f"In sentence {sentence_index + 1}/{len(aiu_tokens)}, token {token_index}, AIU outputs {aiu_token} instead of {validation_token} -- AIU val={aiu_str} -- CPU val={validation_str}"
+        )
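And a toy run of the level-0 comparison (the token ids are made up; real inputs come from ValidationInfo.get_info("tokens")):

from aiu_fms_testing_utils.testing.validation import validate_level_0

aiu_tokens = [[1, 2, 3], [4, 5, 6]]
cpu_tokens = [[1, 2, 3], [4, 9, 6]]  # one mismatch: sentence index 1, token index 1

print(validate_level_0(aiu_tokens, cpu_tokens))  # [(1, 1)]
# print_failed_cases(failed, aiu_tokens, cpu_tokens, tokenizer) additionally decodes
# the ids, given any tokenizer exposing convert_ids_to_tokens / convert_tokens_to_string.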
