1
1
import logging
2
+ import pathlib
2
3
import random
3
4
from collections import defaultdict
4
5
from enum import Enum
5
6
from functools import reduce
6
7
from math import inf
7
8
from pathlib import Path
8
- from typing import Dict , List , Optional , Union
9
+ from typing import Dict , List , Literal , NamedTuple , Optional , Union
9
10
11
+ from numpy import ndarray
10
12
from scipy .stats import pearsonr , spearmanr
13
+ from scipy .stats ._stats_py import PearsonRResult , SignificanceResult
11
14
from sklearn .metrics import mean_absolute_error , mean_squared_error
12
15
from torch .optim import Optimizer
13
16
from torch .utils .data import Dataset
14
17
15
18
import flair
16
- from flair .data import DT , Dictionary , Sentence , _iter_dataset
19
+ from flair .class_utils import StringLike
20
+ from flair .data import DT , Dictionary , Sentence , Token , _iter_dataset
17
21
18
- log = logging .getLogger ("flair" )
22
+ MinMax = Literal ["min" , "max" ]
23
+ logger = logging .getLogger ("flair" )
19
24
20
25
21
26
class Result :
22
27
def __init__ (
23
28
self ,
24
29
main_score : float ,
25
30
detailed_results : str ,
26
- classification_report : dict = {} ,
27
- scores : dict = {} ,
31
+ classification_report : Optional [ Dict ] = None ,
32
+ scores : Optional [ Dict ] = None ,
28
33
) -> None :
29
- assert "loss" in scores , "No loss provided."
34
+ assert scores is not None and "loss" in scores , "No loss provided."
30
35
31
36
self .main_score : float = main_score
32
37
self .scores = scores
33
38
self .detailed_results : str = detailed_results
34
- self .classification_report = classification_report
39
+ self .classification_report = classification_report if classification_report is not None else {}
35
40
36
41
@property
37
42
def loss (self ):
@@ -42,40 +47,36 @@ def __str__(self) -> str:
42
47
43
48
44
49
class MetricRegression :
45
- def __init__ (self , name ) -> None :
50
+ def __init__ (self , name : str ) -> None :
46
51
self .name = name
47
52
48
53
self .true : List [float ] = []
49
54
self .pred : List [float ] = []
50
55
51
- def mean_squared_error (self ):
56
+ def mean_squared_error (self ) -> Union [ float , ndarray ] :
52
57
return mean_squared_error (self .true , self .pred )
53
58
54
59
def mean_absolute_error (self ):
55
60
return mean_absolute_error (self .true , self .pred )
56
61
57
- def pearsonr (self ):
62
+ def pearsonr (self ) -> PearsonRResult :
58
63
return pearsonr (self .true , self .pred )[0 ]
59
64
60
- def spearmanr (self ):
65
+ def spearmanr (self ) -> SignificanceResult :
61
66
return spearmanr (self .true , self .pred )[0 ]
62
67
63
- # dummy return to fulfill trainer.train() needs
64
- def micro_avg_f_score (self ):
65
- return self .mean_squared_error ()
66
-
67
- def to_tsv (self ):
68
+ def to_tsv (self ) -> str :
68
69
return f"{ self .mean_squared_error ()} \t { self .mean_absolute_error ()} \t { self .pearsonr ()} \t { self .spearmanr ()} "
69
70
70
71
@staticmethod
71
- def tsv_header (prefix = None ):
72
+ def tsv_header (prefix : StringLike = None ) -> str :
72
73
if prefix :
73
74
return f"{ prefix } _MEAN_SQUARED_ERROR\t { prefix } _MEAN_ABSOLUTE_ERROR\t { prefix } _PEARSON\t { prefix } _SPEARMAN"
74
75
75
76
return "MEAN_SQUARED_ERROR\t MEAN_ABSOLUTE_ERROR\t PEARSON\t SPEARMAN"
76
77
77
78
@staticmethod
78
- def to_empty_tsv ():
79
+ def to_empty_tsv () -> str :
79
80
return "\t _\t _\t _\t _"
80
81
81
82
def __str__ (self ) -> str :
@@ -99,13 +100,13 @@ def __init__(self, directory: Union[str, Path], number_of_weights: int = 10) ->
99
100
self .weights_dict : Dict [str , Dict [int , List [float ]]] = defaultdict (lambda : defaultdict (list ))
100
101
self .number_of_weights = number_of_weights
101
102
102
- def extract_weights (self , state_dict , iteration ) :
103
+ def extract_weights (self , state_dict : Dict , iteration : int ) -> None :
103
104
for key in state_dict :
104
105
vec = state_dict [key ]
105
- # print(vec)
106
106
try :
107
107
weights_to_watch = min (self .number_of_weights , reduce (lambda x , y : x * y , list (vec .size ())))
108
- except Exception :
108
+ except Exception as e :
109
+ logger .debug (e )
109
110
continue
110
111
111
112
if key not in self .weights_dict :
@@ -193,15 +194,15 @@ class AnnealOnPlateau:
193
194
def __init__ (
194
195
self ,
195
196
optimizer ,
196
- mode = "min" ,
197
- aux_mode = "min" ,
198
- factor = 0.1 ,
199
- patience = 10 ,
200
- initial_extra_patience = 0 ,
201
- verbose = False ,
202
- cooldown = 0 ,
203
- min_lr = 0 ,
204
- eps = 1e-8 ,
197
+ mode : MinMax = "min" ,
198
+ aux_mode : MinMax = "min" ,
199
+ factor : float = 0.1 ,
200
+ patience : int = 10 ,
201
+ initial_extra_patience : int = 0 ,
202
+ verbose : bool = False ,
203
+ cooldown : int = 0 ,
204
+ min_lr : float = 0. 0 ,
205
+ eps : float = 1e-8 ,
205
206
) -> None :
206
207
if factor >= 1.0 :
207
208
raise ValueError ("Factor should be < 1.0." )
@@ -212,6 +213,7 @@ def __init__(
212
213
raise TypeError (f"{ type (optimizer ).__name__ } is not an Optimizer" )
213
214
self .optimizer = optimizer
214
215
216
+ self .min_lrs : List [float ]
215
217
if isinstance (min_lr , (list , tuple )):
216
218
if len (min_lr ) != len (optimizer .param_groups ):
217
219
raise ValueError (f"expected { len (optimizer .param_groups )} min_lrs, got { len (min_lr )} " )
@@ -229,7 +231,7 @@ def __init__(
229
231
self .best = None
230
232
self .best_aux = None
231
233
self .num_bad_epochs = None
232
- self .mode_worse = None # the worse value for the chosen mode
234
+ self .mode_worse : Optional [ float ] = None # the worse value for the chosen mode
233
235
self .eps = eps
234
236
self .last_epoch = 0
235
237
self ._init_is_better (mode = mode )
@@ -256,7 +258,7 @@ def step(self, metric, auxiliary_metric=None) -> bool:
256
258
if self .mode == "max" and current > self .best :
257
259
is_better = True
258
260
259
- if current == self .best and auxiliary_metric :
261
+ if current == self .best and auxiliary_metric is not None :
260
262
current_aux = float (auxiliary_metric )
261
263
if self .aux_mode == "min" and current_aux < self .best_aux :
262
264
is_better = True
@@ -287,20 +289,20 @@ def step(self, metric, auxiliary_metric=None) -> bool:
287
289
288
290
return reduce_learning_rate
289
291
290
- def _reduce_lr (self , epoch ) :
292
+ def _reduce_lr (self , epoch : int ) -> None :
291
293
for i , param_group in enumerate (self .optimizer .param_groups ):
292
294
old_lr = float (param_group ["lr" ])
293
295
new_lr = max (old_lr * self .factor , self .min_lrs [i ])
294
296
if old_lr - new_lr > self .eps :
295
297
param_group ["lr" ] = new_lr
296
298
if self .verbose :
297
- log .info (f" - reducing learning rate of group { epoch } to { new_lr } " )
299
+ logger .info (f" - reducing learning rate of group { epoch } to { new_lr } " )
298
300
299
301
@property
300
302
def in_cooldown (self ):
301
303
return self .cooldown_counter > 0
302
304
303
- def _init_is_better (self , mode ) :
305
+ def _init_is_better (self , mode : MinMax ) -> None :
304
306
if mode not in {"min" , "max" }:
305
307
raise ValueError ("mode " + mode + " is unknown!" )
306
308
@@ -311,10 +313,10 @@ def _init_is_better(self, mode):
311
313
312
314
self .mode = mode
313
315
314
- def state_dict (self ):
316
+ def state_dict (self ) -> Dict :
315
317
return {key : value for key , value in self .__dict__ .items () if key != "optimizer" }
316
318
317
- def load_state_dict (self , state_dict ) :
319
+ def load_state_dict (self , state_dict : Dict ) -> None :
318
320
self .__dict__ .update (state_dict )
319
321
self ._init_is_better (mode = self .mode )
320
322
@@ -348,11 +350,11 @@ def convert_labels_to_one_hot(label_list: List[List[str]], label_dict: Dictionar
348
350
return [[1 if label in labels else 0 for label in label_dict .get_items ()] for labels in label_list ]
349
351
350
352
351
- def log_line (log ) :
353
+ def log_line (log : logging . Logger ) -> None :
352
354
log .info ("-" * 100 , stacklevel = 3 )
353
355
354
356
355
- def add_file_handler (log , output_file ) :
357
+ def add_file_handler (log : logging . Logger , output_file : pathlib . Path ) -> logging . FileHandler :
356
358
init_output_file (output_file .parents [0 ], output_file .name )
357
359
fh = logging .FileHandler (output_file , mode = "w" , encoding = "utf-8" )
358
360
fh .setLevel (logging .INFO )
@@ -363,12 +365,21 @@ def add_file_handler(log, output_file):
363
365
364
366
365
367
def store_embeddings (
366
- data_points : Union [List [DT ], Dataset ], storage_mode : str , dynamic_embeddings : Optional [List [str ]] = None
367
- ):
368
+ data_points : Union [List [DT ], Dataset ],
369
+ storage_mode : str ,
370
+ dynamic_embeddings : Optional [List [str ]] = None ,
371
+ ) -> None :
372
+ """Stores embeddings of data points in memory or on disk.
373
+
374
+ Args:
375
+ data_points: a DataSet or list of DataPoints for which embeddings should be stored
376
+ storage_mode: store in either CPU or GPU memory, or delete them if set to 'none'
377
+ dynamic_embeddings: these are always deleted. If not passed, they are identified automatically.
378
+ """
368
379
if isinstance (data_points , Dataset ):
369
380
data_points = list (_iter_dataset (data_points ))
370
381
371
- # if memory mode option 'none' delete everything
382
+ # if storage mode option 'none' delete everything
372
383
if storage_mode == "none" :
373
384
dynamic_embeddings = None
374
385
@@ -387,7 +398,7 @@ def store_embeddings(
387
398
data_point .to ("cpu" , pin_memory = pin_memory )
388
399
389
400
390
- def identify_dynamic_embeddings (data_points : List [DT ]):
401
+ def identify_dynamic_embeddings (data_points : List [DT ]) -> Optional [ List [ str ]] :
391
402
dynamic_embeddings = []
392
403
all_embeddings = []
393
404
for data_point in data_points :
@@ -407,3 +418,130 @@ def identify_dynamic_embeddings(data_points: List[DT]):
407
418
if not all_embeddings :
408
419
return None
409
420
return list (set (dynamic_embeddings ))
421
+
422
+
423
+ class TokenEntity (NamedTuple ):
424
+ """Entity represented by token indices."""
425
+
426
+ start_token_idx : int
427
+ end_token_idx : int
428
+ label : str
429
+ value : str = "" # text value of the entity
430
+ score : float = 1.0
431
+
432
+
433
+ class CharEntity (NamedTuple ):
434
+ """Entity represented by character indices."""
435
+
436
+ start_char_idx : int
437
+ end_char_idx : int
438
+ label : str
439
+ value : str
440
+ score : float = 1.0
441
+
442
+
443
+ def create_labeled_sentence_from_tokens (
444
+ tokens : Union [List [Token ]], token_entities : List [TokenEntity ], type_name : str = "ner"
445
+ ) -> Sentence :
446
+ """Creates a new Sentence object from a list of tokens or strings and applies entity labels.
447
+
448
+ Tokens are recreated with the same text, but not attached to the previous sentence.
449
+
450
+ Args:
451
+ tokens: a list of Token objects or strings - only the text is used, not any labels
452
+ token_entities: a list of TokenEntity objects representing entity annotations
453
+ type_name: the type of entity label to apply
454
+ Returns:
455
+ A labeled Sentence object
456
+ """
457
+ tokens = [Token (token .text ) for token in tokens ] # create new tokens that do not already belong to a sentence
458
+ sentence = Sentence (tokens , use_tokenizer = True )
459
+ for entity in token_entities :
460
+ sentence [entity .start_token_idx : entity .end_token_idx ].add_label (type_name , entity .label , score = entity .score )
461
+ return sentence
462
+
463
+
464
+ def create_flair_sentence (
465
+ text : str ,
466
+ entities : List [CharEntity ],
467
+ token_limit : int = 512 ,
468
+ use_context : bool = True ,
469
+ overlap : int = 0 , # TODO: implement overlap
470
+ ) -> List [Sentence ]:
471
+ """Constructs a Flair Sentence from text and a list of entity annotations.
472
+
473
+ The function explicitly tokenizes the text and labels separately, ensuring entity labels are
474
+ not partially split across tokens.
475
+
476
+ Args:
477
+ text (str): The full text to be tokenized and labeled.
478
+ entities (list of tuples): Ordered non-overlapping entity annotations with each tuple in the
479
+ format (start_char_index, end_char_index, entity_class, entity_text).
480
+ token_limit: numerical value that determines the maximum size of a chunk. use inf to not perform chunking
481
+ use_context: whether to add context to the sentence
482
+ overlap: the size of overlap between chunks, repeating the last n tokens of previous chunk to preserve context
483
+
484
+ Returns:
485
+ A list of labeled Sentence objects representing the chunks of the original text
486
+ """
487
+ chunks = []
488
+
489
+ tokens : List [Token ] = []
490
+ current_index = 0
491
+ token_entities : List [TokenEntity ] = []
492
+ end_token_idx = 0
493
+
494
+ for entity in entities :
495
+
496
+ if entity .start_char_idx > current_index : # add non-entity text
497
+ non_entity_tokens = Sentence (text [current_index : entity .start_char_idx ]).tokens
498
+ while end_token_idx + len (non_entity_tokens ) > token_limit :
499
+ num_tokens = token_limit - len (tokens )
500
+ tokens .extend (non_entity_tokens [:num_tokens ])
501
+ non_entity_tokens = non_entity_tokens [num_tokens :]
502
+ # skip any fully negative samples, they cause fine_tune to fail with
503
+ # `torch.cat(): expected a non-empty list of Tensors`
504
+ if len (token_entities ) > 0 :
505
+ chunks .append (create_labeled_sentence_from_tokens (tokens , token_entities ))
506
+ tokens , token_entities = [], []
507
+ end_token_idx = 0
508
+ tokens .extend (non_entity_tokens )
509
+
510
+ # add new entity tokens
511
+ start_token_idx = len (tokens )
512
+ entity_sentence = Sentence (text [entity .start_char_idx : entity .end_char_idx ])
513
+ if len (entity_sentence ) > token_limit :
514
+ logger .warning (f"Entity length is greater than token limit! { len (entity_sentence )} > { token_limit } " )
515
+ end_token_idx = start_token_idx + len (entity_sentence )
516
+
517
+ if end_token_idx >= token_limit : # create chunk from existing and add this entity to next chunk
518
+ chunks .append (create_labeled_sentence_from_tokens (tokens , token_entities ))
519
+
520
+ tokens , token_entities = [], []
521
+ start_token_idx , end_token_idx = 0 , len (entity_sentence )
522
+
523
+ token_entity = TokenEntity (start_token_idx , end_token_idx , entity .label , entity .value , entity .score )
524
+ token_entities .append (token_entity )
525
+ tokens .extend (entity_sentence )
526
+
527
+ current_index = entity .end_char_idx
528
+
529
+ # add any remaining tokens to a new chunk
530
+ if current_index < len (text ):
531
+ remaining_sentence = Sentence (text [current_index :])
532
+ if end_token_idx + len (remaining_sentence ) > token_limit :
533
+ chunks .append (create_labeled_sentence_from_tokens (tokens , token_entities ))
534
+ tokens , token_entities = [], []
535
+ tokens .extend (remaining_sentence )
536
+
537
+ if tokens :
538
+ chunks .append (create_labeled_sentence_from_tokens (tokens , token_entities ))
539
+
540
+ for chunk in chunks :
541
+ if len (chunk ) > token_limit :
542
+ logger .warning (f"Chunk size is longer than token limit: { len (chunk )} > { token_limit } " )
543
+
544
+ if use_context :
545
+ Sentence .set_context_for_sentences (chunks )
546
+
547
+ return chunks
0 commit comments