from aiu_fms_testing_utils.utils.aiu_setup import dprint
import os

-class LogitsExtractorHook(Callable[[int, torch.Tensor, torch.Tensor, MutableMapping[str, Any]], Tuple[torch.Tensor, MutableMapping[str, Any]],]):
+class LogitsExtractorHook(
+    Callable[
+        [int, torch.Tensor, torch.Tensor, MutableMapping[str, Any]],
+        Tuple[torch.Tensor, MutableMapping[str, Any]],
+    ]
+):
    def __init__(self):
        super().__init__()
        self.extracted_logits: Optional[torch.Tensor] = None

-    def __call__(self, token_position: torch.Tensor, logits: torch.Tensor, next_val: torch.Tensor, kwargs):
+    def __call__(
+        self,
+        token_position: torch.Tensor,
+        logits: torch.Tensor,
+        next_val: torch.Tensor,
+        kwargs,
+    ):
        if self.extracted_logits is None:
            self.extracted_logits = logits.unsqueeze(1)
        else:
-            self.extracted_logits = torch.cat((self.extracted_logits, logits.unsqueeze(1)), dim=1)
+            self.extracted_logits = torch.cat(
+                (self.extracted_logits, logits.unsqueeze(1)), dim=1
+            )
        return next_val, kwargs

-class StaticTokenInjectorHook(Callable[[int, torch.Tensor, torch.Tensor, MutableMapping[str, Any]], Tuple[torch.Tensor, MutableMapping[str, Any]],]):
-    def __init__(self, static_tokens: List[torch.Tensor], device_type: str = "cpu"):
+class StaticTokenInjectorHook(
+    Callable[
+        [int, torch.Tensor, torch.Tensor, MutableMapping[str, Any]],
+        Tuple[torch.Tensor, MutableMapping[str, Any]],
+    ]
+):
+    def __init__(self, static_tokens: List[torch.Tensor], device_type: str = "cpu"):
        super().__init__()
-        self.static_tokens = torch.tensor(static_tokens, device=device_type).t()  # transposing so batch tokens per token_position
+        self.static_tokens = torch.tensor(
+            static_tokens, device=device_type
+        ).t()  # transposing so batch tokens per token_position

-    def __call__(self, token_position: int, logits: torch.Tensor, next_val: torch.Tensor, kwargs):
+    def __call__(
+        self, token_position: int, logits: torch.Tensor, next_val: torch.Tensor, kwargs
+    ):
        next_val.copy_(self.static_tokens[token_position].unsqueeze(1))
        return next_val, kwargs

-class GoldenTokenHook(Callable[[int, torch.Tensor, torch.Tensor, MutableMapping[str, Any]], Tuple[torch.Tensor, MutableMapping[str, Any]],]):
-    def __init__(self, static_tokens: torch.Tensor, device_type: str = "cpu"):
+class GoldenTokenHook(
+    Callable[
+        [int, torch.Tensor, torch.Tensor, MutableMapping[str, Any]],
+        Tuple[torch.Tensor, MutableMapping[str, Any]],
+    ]
+):
+    def __init__(self, static_tokens: torch.Tensor, device_type: str = "cpu"):
        super().__init__()
        self.logits_extractor = LogitsExtractorHook()
        self.extracted_logits = None
-        self.token_injector = StaticTokenInjectorHook(static_tokens, device_type=device_type)
+        self.token_injector = StaticTokenInjectorHook(
+            static_tokens, device_type=device_type
+        )

-    def __call__(self, token_position: int, logits: torch.Tensor, next_val: torch.Tensor, kwargs):
-        next_val, kwargs = self.logits_extractor(token_position, logits, next_val, kwargs)
+    def __call__(
+        self, token_position: int, logits: torch.Tensor, next_val: torch.Tensor, kwargs
+    ):
+        next_val, kwargs = self.logits_extractor(
+            token_position, logits, next_val, kwargs
+        )
        self.extracted_logits = self.logits_extractor.extracted_logits
        return self.token_injector(token_position, logits, next_val, kwargs)
-class ValidationInfo:
+
+class ValidationInfo:
    def __init__(self, validation_info_list):
        super().__init__()
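Note, for context (not part of this commit's changes): these hooks are passed as the post_iteration_hook argument of FMS generation, which is how extract_validation_information below consumes them. A minimal sketch of the intended flow, assuming an already-loaded FMS model and tokenized input_ids:

    # Sketch only: `model` and `input_ids` are assumed to already exist.
    hook = LogitsExtractorHook()
    cpu_validation_info = extract_validation_information(
        model, input_ids, max_new_tokens=8, post_iteration_hook=hook
    )
    # hook.extracted_logits stacks one logits slice per generated step:
    # shape (batch, max_new_tokens, vocab_size)

GoldenTokenHook composes the two simpler hooks: it records logits at each step while StaticTokenInjectorHook forces generation to follow a fixed reference token sequence, so two runs can be compared position by position.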
@@ -54,7 +87,10 @@ def __iter__(self):
            yield vi

    def get_info(self, info_name):
-        return [[t.unsqueeze(0) for t in sentence[info_name]] for sentence in self._validation_info_list]
+        return [
+            [t.unsqueeze(0) for t in sentence[info_name]]
+            for sentence in self._validation_info_list
+        ]

    def save(self, save_dir_path: str):
        """Save the validation information into a directory.
@@ -86,12 +122,17 @@ def save(self, save_dir_path: str):

    def __len__(self):
        return len(self._validation_info_list)
-
-def get_default_validation_prefix(model_id: str, max_new_tokens: int, batch_size: int, seq_length: int, dtype: str):
+
+
+def get_default_validation_prefix(
+    model_id: str, max_new_tokens: int, batch_size: int, seq_length: int, dtype: str
+):
    return f"{model_id.replace('/', '--')}_max-new-tokens-{max_new_tokens}_batch-size-{batch_size}_seq-length-{seq_length}_dtype-{dtype}"
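For illustration, with hypothetical arguments the prefix comes out as:

    get_default_validation_prefix("ibm/model-a", 128, 4, 2048, "fp16")
    # 'ibm--model-a_max-new-tokens-128_batch-size-4_seq-length-2048_dtype-fp16'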
-def load_validation_information(validation_path, validation_files_type, batch_size, tokenizer=None):
+def load_validation_information(
+    validation_path, validation_files_type, batch_size, tokenizer=None
+):
    """Load the validation information from a directory

    The files will be assumed to be in the following structure:
@@ -107,17 +148,15 @@ def load_validation_information(validation_path, validation_files_type, batch_si
        if containing only tokens - torch.tensor
        if containing tokens and logits - dict[tokens -> torch.tensor, logits -> torch.tensor]
        if containing text - str
-
+
    :param validation_path: path to validation info files
    :param validation_files_type: validation file type to load, one of text, tokens, or logits
    :param batch_size: the number of prompts to load
    :param tokenizer: an optional tokenizer, required when validation_files_type=text
    :return: a new validation info
    """
    if isinstance(validation_path, str):
-        validation_files_path, sep, glob_pattern = validation_path.partition(
-            "*"
-        )
+        validation_files_path, sep, glob_pattern = validation_path.partition("*")
    else:
        sep = ""
        glob_pattern = ""
@@ -146,22 +185,24 @@ def load_validation_information(validation_path, validation_files_type, batch_si
        validation_files_paths = [validation_files_path]

    # Check if we found some files
-    assert (
-        len(validation_files_paths) > 0
-    ), f"Can't find any validation files at {validation_files_path}"
+    assert len(validation_files_paths) > 0, (
+        f"Can't find any validation files at {validation_files_path}"
+    )

    # Check if we have enough files
-    assert (
-        len(validation_files_paths) >= batch_size
-    ), f"Not enough validation files at {validation_files_path} for a batch size of {batch_size}"
+    assert len(validation_files_paths) >= batch_size, (
+        f"Not enough validation files at {validation_files_path} for a batch size of {batch_size}"
+    )

    validation_info = []
    for i, validation_file_path in enumerate(validation_files_paths):
        if i == batch_size:
            break
        if validation_files_type == "text":
            if tokenizer is None:
-                raise ValueError("must provide a tokenizer when validation_files_type=text")
+                raise ValueError(
+                    "must provide a tokenizer when validation_files_type=text"
+                )
            # Text format will get tokenized
            validation_info.append(
                {
@@ -187,14 +228,27 @@ def load_validation_information(validation_path, validation_files_type, batch_si

    return ValidationInfo(validation_info)

-def extract_validation_information(model, input_ids, max_new_tokens, post_iteration_hook, attn_algorithm=None, eos_token_id=None, only_last_token=False, timing="", attn_type="sdpa", **padding_kwargs):
+
+def extract_validation_information(
+    model,
+    input_ids,
+    max_new_tokens,
+    post_iteration_hook,
+    attn_algorithm=None,
+    eos_token_id=None,
+    only_last_token=False,
+    timing="",
+    attn_type="sdpa",
+    **padding_kwargs,
+):
    max_seq_len = model.config.max_expected_seq_len
    attention_specific_kwargs = {}
    if attn_type == "paged":
        from aiu_fms_testing_utils.utils.paged import generate
    else:
        # TODO: Add a unified generation dependent on attn_type
        from fms.utils.generation import generate
+
    attention_specific_kwargs["contiguous_cache"] = True
    attention_specific_kwargs["max_seq_len"] = max_seq_len
@@ -215,7 +269,7 @@ def extract_validation_information(model, input_ids, max_new_tokens, post_iterat
        eos_token_id=eos_token_id,
        timing=timing,
        extra_kwargs=extra_generation_kwargs,
-        **attention_specific_kwargs
+        **attention_specific_kwargs,
    )

    if timing != "":
@@ -226,7 +280,7 @@ def extract_validation_information(model, input_ids, max_new_tokens, post_iterat
        if timing == "e2e":
            dprint(f"E2E timing information: {timings[0]:.3f}s")
        elif timing == "per-token":
-            timings = [f"{t*1000:.3f}" for t in timings]
+            timings = [f"{t * 1000:.3f}" for t in timings]
            dprint(f"Per-token timing information: {', '.join(timings)} ms")

    if len(result.shape) == 1:
@@ -235,75 +289,88 @@ def extract_validation_information(model, input_ids, max_new_tokens, post_iterat
    if hasattr(post_iteration_hook, "extracted_logits"):
        validation_info = [
            {"tokens": t.to("cpu"), "logits": l.to("cpu")}
-            for t, l in zip(torch.unbind(result), torch.unbind(post_iteration_hook.extracted_logits))
+            for t, l in zip(
+                torch.unbind(result), torch.unbind(post_iteration_hook.extracted_logits)
+            )
        ]
    else:
        validation_info = [{"tokens": t.to("cpu")} for t in torch.unbind(result)]
    return ValidationInfo(validation_info)

+
def validate_level_0(aiu_tokens_per_sentence, validation_tokens_per_sentence):
    failed_cases = []

    for sentence_idx, (aiu_sentence, validation_sentence) in enumerate(
-            zip(aiu_tokens_per_sentence, validation_tokens_per_sentence)
+        zip(aiu_tokens_per_sentence, validation_tokens_per_sentence)
    ):
        for token_idx, (aiu_token, validation_token) in enumerate(
-                zip(aiu_sentence, validation_sentence)
+            zip(aiu_sentence, validation_sentence)
        ):
            if aiu_token != validation_token:
                failed_cases.append((sentence_idx, token_idx))
    return failed_cases
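Level 0 validation is a strict token-by-token equality check. A short usage sketch (the validation-info variable names are assumed, not from this commit):

    aiu_tokens = aiu_validation_info.get_info("tokens")
    cpu_tokens = cpu_validation_info.get_info("tokens")
    failed = validate_level_0(aiu_tokens, cpu_tokens)
    if len(failed) == 0:
        print("level 0 validation passed")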
-def top_k_loss_calculator(top_k: int, loss_f: Callable[[torch.Tensor, torch.Tensor], float]):
+
+def top_k_loss_calculator(
+    top_k: int, loss_f: Callable[[torch.Tensor, torch.Tensor], float]
+):
    """
    Function which will take the top_k logits indexes / values from a reference validation info and retrieve the same indexes from the test validation info logits
    and perform a loss function over the 2 tensors

    :param top_k: number of values to take from reference
    :param loss_f: a loss function between the reference and test logits
    """
+
    def loss_func(reference_logits, test_logits):
        reference_logits_prob = reference_logits.to(dtype=torch.float32)
        test_logits_prob = test_logits.to(dtype=torch.float32)

-        reference_values, reference_indices = torch.topk(reference_logits_prob, top_k, dim=1)
+        reference_values, reference_indices = torch.topk(
+            reference_logits_prob, top_k, dim=1
+        )
        test_values = test_logits_prob[:, reference_indices.squeeze(0)]

        return loss_f(reference_values, test_values)
+
    return loss_func
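For example, a top-64 mean-absolute-difference metric could be built as follows (the k value and loss function are illustrative, not from this commit):

    mad_loss = top_k_loss_calculator(
        64, lambda ref, test: torch.mean(torch.abs(ref - test)).item()
    )
    # mad_loss(reference_logits, test_logits) returns a float and can be
    # passed as metrics_calculator to capture_level_1_metrics below.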
-def capture_level_1_metrics(reference_logits_per_sentence, test_logits_per_sentence, metrics_calculator=None):
+def capture_level_1_metrics(
+    reference_logits_per_sentence, test_logits_per_sentence, metrics_calculator=None
+):
    loss_metrics = []

    for sentence_idx, (reference_sentence, test_sentence) in enumerate(
-            zip(reference_logits_per_sentence, test_logits_per_sentence)
+        zip(reference_logits_per_sentence, test_logits_per_sentence)
    ):
        for token_idx, (reference_logits, test_logits) in enumerate(
-                zip(reference_sentence, test_sentence)
+            zip(reference_sentence, test_sentence)
        ):
            # computing cross entropy loss per token
            if metrics_calculator is None:
                loss_fn = torch.nn.CrossEntropyLoss()
                metrics_value = loss_fn(
                    reference_logits.to(dtype=torch.float32),
-                    test_logits.softmax(dim=1).to(dtype=torch.float32)
+                    test_logits.softmax(dim=1).to(dtype=torch.float32),
                )
            else:
                metrics_value = metrics_calculator(reference_logits, test_logits)

            loss_metrics.append((sentence_idx, token_idx, metrics_value))

    return loss_metrics
-
+
+
def filter_failed_level_1_cases(level_1_loss_metrics, fail_f, print_failed=False):
    failed_cases = []
-    for (sentence_idx, token_idx, metrics_value) in level_1_loss_metrics:
+    for sentence_idx, token_idx, metrics_value in level_1_loss_metrics:
        if fail_f(metrics_value):
            failed_cases.append((sentence_idx, token_idx, metrics_value))
            if print_failed:
                dprint(
-                    f"In sentence {sentence_idx+1}, the metric for token {token_idx} is {metrics_value}"
+                    f"In sentence {sentence_idx + 1}, the metric for token {token_idx} is {metrics_value}"
                )
    return failed_cases
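Putting the level 1 pieces together (the threshold and variable names are illustrative):

    cpu_logits = cpu_validation_info.get_info("logits")
    aiu_logits = aiu_validation_info.get_info("logits")
    metrics = capture_level_1_metrics(cpu_logits, aiu_logits)
    failed = filter_failed_level_1_cases(
        metrics, lambda m: float(m) >= 2.5, print_failed=True
    )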
@@ -313,6 +380,12 @@ def print_failed_cases(failed_cases, aiu_tokens, validation_tokens, tokenizer):
        aiu_token = aiu_tokens[sentence_index][token_index]
        validation_token = validation_tokens[sentence_index][token_index]

-        aiu_str = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(aiu_token))
-        validation_str = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(validation_token))
-        print(f"In sentence {sentence_index + 1}/{len(aiu_tokens)}, token {token_index}, AIU outputs {aiu_token} instead of {validation_token} -- AIU val={aiu_str} -- CPU val={validation_str}")
+        aiu_str = tokenizer.convert_tokens_to_string(
+            tokenizer.convert_ids_to_tokens(aiu_token)
+        )
+        validation_str = tokenizer.convert_tokens_to_string(
+            tokenizer.convert_ids_to_tokens(validation_token)
+        )
+        print(
+            f"In sentence {sentence_index + 1}/{len(aiu_tokens)}, token {token_index}, AIU outputs {aiu_token} instead of {validation_token} -- AIU val={aiu_str} -- CPU val={validation_str}"
+        )