Skip to content

Commit 0b23ef6

Browse files
committed
feat: add chunking function to allow sequence tagger training on sentences exceeding the token limit, including tests
1 parent e17ab12 commit 0b23ef6

File tree

5 files changed

+463
-45
lines changed

5 files changed

+463
-45
lines changed

Diff for: flair/class_utils.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
import importlib
22
import inspect
33
from types import ModuleType
4-
from typing import Any, Iterable, List, Optional, Type, TypeVar, Union, overload
4+
from typing import Any, Iterable, List, Optional, Protocol, Type, TypeVar, Union, overload
55

66
T = TypeVar("T")
77

88

9+
class StringLike(Protocol):
10+
def __str__(self) -> str: ...
11+
12+
913
def get_non_abstract_subclasses(cls: Type[T]) -> Iterable[Type[T]]:
1014
for subclass in cls.__subclasses__():
1115
yield from get_non_abstract_subclasses(subclass)

Diff for: flair/training_utils.py

+181-43
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,42 @@
11
import logging
2+
import pathlib
23
import random
34
from collections import defaultdict
45
from enum import Enum
56
from functools import reduce
67
from math import inf
78
from pathlib import Path
8-
from typing import Dict, List, Optional, Union
9+
from typing import Dict, List, Literal, NamedTuple, Optional, Union
910

11+
from numpy import ndarray
1012
from scipy.stats import pearsonr, spearmanr
13+
from scipy.stats._stats_py import PearsonRResult, SignificanceResult
1114
from sklearn.metrics import mean_absolute_error, mean_squared_error
1215
from torch.optim import Optimizer
1316
from torch.utils.data import Dataset
1417

1518
import flair
16-
from flair.data import DT, Dictionary, Sentence, _iter_dataset
19+
from flair.class_utils import StringLike
20+
from flair.data import DT, Dictionary, Sentence, Token, _iter_dataset
1721

18-
log = logging.getLogger("flair")
22+
MinMax = Literal["min", "max"]
23+
logger = logging.getLogger("flair")
1924

2025

2126
class Result:
2227
def __init__(
2328
self,
2429
main_score: float,
2530
detailed_results: str,
26-
classification_report: dict = {},
27-
scores: dict = {},
31+
classification_report: Optional[Dict] = None,
32+
scores: Optional[Dict] = None,
2833
) -> None:
29-
assert "loss" in scores, "No loss provided."
34+
assert scores is not None and "loss" in scores, "No loss provided."
3035

3136
self.main_score: float = main_score
3237
self.scores = scores
3338
self.detailed_results: str = detailed_results
34-
self.classification_report = classification_report
39+
self.classification_report = classification_report if classification_report is not None else {}
3540

3641
@property
3742
def loss(self):
@@ -42,40 +47,36 @@ def __str__(self) -> str:
4247

4348

4449
class MetricRegression:
45-
def __init__(self, name) -> None:
50+
def __init__(self, name: str) -> None:
4651
self.name = name
4752

4853
self.true: List[float] = []
4954
self.pred: List[float] = []
5055

51-
def mean_squared_error(self):
56+
def mean_squared_error(self) -> Union[float, ndarray]:
5257
return mean_squared_error(self.true, self.pred)
5358

5459
def mean_absolute_error(self):
5560
return mean_absolute_error(self.true, self.pred)
5661

57-
def pearsonr(self):
62+
def pearsonr(self) -> PearsonRResult:
5863
return pearsonr(self.true, self.pred)[0]
5964

60-
def spearmanr(self):
65+
def spearmanr(self) -> SignificanceResult:
6166
return spearmanr(self.true, self.pred)[0]
6267

63-
# dummy return to fulfill trainer.train() needs
64-
def micro_avg_f_score(self):
65-
return self.mean_squared_error()
66-
67-
def to_tsv(self):
68+
def to_tsv(self) -> str:
6869
return f"{self.mean_squared_error()}\t{self.mean_absolute_error()}\t{self.pearsonr()}\t{self.spearmanr()}"
6970

7071
@staticmethod
71-
def tsv_header(prefix=None):
72+
def tsv_header(prefix: StringLike = None) -> str:
7273
if prefix:
7374
return f"{prefix}_MEAN_SQUARED_ERROR\t{prefix}_MEAN_ABSOLUTE_ERROR\t{prefix}_PEARSON\t{prefix}_SPEARMAN"
7475

7576
return "MEAN_SQUARED_ERROR\tMEAN_ABSOLUTE_ERROR\tPEARSON\tSPEARMAN"
7677

7778
@staticmethod
78-
def to_empty_tsv():
79+
def to_empty_tsv() -> str:
7980
return "\t_\t_\t_\t_"
8081

8182
def __str__(self) -> str:
@@ -99,13 +100,13 @@ def __init__(self, directory: Union[str, Path], number_of_weights: int = 10) ->
99100
self.weights_dict: Dict[str, Dict[int, List[float]]] = defaultdict(lambda: defaultdict(list))
100101
self.number_of_weights = number_of_weights
101102

102-
def extract_weights(self, state_dict, iteration):
103+
def extract_weights(self, state_dict: Dict, iteration: int) -> None:
103104
for key in state_dict:
104105
vec = state_dict[key]
105-
# print(vec)
106106
try:
107107
weights_to_watch = min(self.number_of_weights, reduce(lambda x, y: x * y, list(vec.size())))
108-
except Exception:
108+
except Exception as e:
109+
logger.debug(e)
109110
continue
110111

111112
if key not in self.weights_dict:
@@ -193,15 +194,15 @@ class AnnealOnPlateau:
193194
def __init__(
194195
self,
195196
optimizer,
196-
mode="min",
197-
aux_mode="min",
198-
factor=0.1,
199-
patience=10,
200-
initial_extra_patience=0,
201-
verbose=False,
202-
cooldown=0,
203-
min_lr=0,
204-
eps=1e-8,
197+
mode: MinMax = "min",
198+
aux_mode: MinMax = "min",
199+
factor: float = 0.1,
200+
patience: int = 10,
201+
initial_extra_patience: int = 0,
202+
verbose: bool = False,
203+
cooldown: int = 0,
204+
min_lr: float = 0.0,
205+
eps: float = 1e-8,
205206
) -> None:
206207
if factor >= 1.0:
207208
raise ValueError("Factor should be < 1.0.")
@@ -212,6 +213,7 @@ def __init__(
212213
raise TypeError(f"{type(optimizer).__name__} is not an Optimizer")
213214
self.optimizer = optimizer
214215

216+
self.min_lrs: List[float]
215217
if isinstance(min_lr, (list, tuple)):
216218
if len(min_lr) != len(optimizer.param_groups):
217219
raise ValueError(f"expected {len(optimizer.param_groups)} min_lrs, got {len(min_lr)}")
@@ -229,7 +231,7 @@ def __init__(
229231
self.best = None
230232
self.best_aux = None
231233
self.num_bad_epochs = None
232-
self.mode_worse = None # the worse value for the chosen mode
234+
self.mode_worse: Optional[float] = None # the worse value for the chosen mode
233235
self.eps = eps
234236
self.last_epoch = 0
235237
self._init_is_better(mode=mode)
@@ -256,7 +258,7 @@ def step(self, metric, auxiliary_metric=None) -> bool:
256258
if self.mode == "max" and current > self.best:
257259
is_better = True
258260

259-
if current == self.best and auxiliary_metric:
261+
if current == self.best and auxiliary_metric is not None:
260262
current_aux = float(auxiliary_metric)
261263
if self.aux_mode == "min" and current_aux < self.best_aux:
262264
is_better = True
@@ -287,20 +289,20 @@ def step(self, metric, auxiliary_metric=None) -> bool:
287289

288290
return reduce_learning_rate
289291

290-
def _reduce_lr(self, epoch):
292+
def _reduce_lr(self, epoch: int) -> None:
291293
for i, param_group in enumerate(self.optimizer.param_groups):
292294
old_lr = float(param_group["lr"])
293295
new_lr = max(old_lr * self.factor, self.min_lrs[i])
294296
if old_lr - new_lr > self.eps:
295297
param_group["lr"] = new_lr
296298
if self.verbose:
297-
log.info(f" - reducing learning rate of group {epoch} to {new_lr}")
299+
logger.info(f" - reducing learning rate of group {epoch} to {new_lr}")
298300

299301
@property
300302
def in_cooldown(self):
301303
return self.cooldown_counter > 0
302304

303-
def _init_is_better(self, mode):
305+
def _init_is_better(self, mode: MinMax) -> None:
304306
if mode not in {"min", "max"}:
305307
raise ValueError("mode " + mode + " is unknown!")
306308

@@ -311,10 +313,10 @@ def _init_is_better(self, mode):
311313

312314
self.mode = mode
313315

314-
def state_dict(self):
316+
def state_dict(self) -> Dict:
315317
return {key: value for key, value in self.__dict__.items() if key != "optimizer"}
316318

317-
def load_state_dict(self, state_dict):
319+
def load_state_dict(self, state_dict: Dict) -> None:
318320
self.__dict__.update(state_dict)
319321
self._init_is_better(mode=self.mode)
320322

@@ -348,11 +350,11 @@ def convert_labels_to_one_hot(label_list: List[List[str]], label_dict: Dictionar
348350
return [[1 if label in labels else 0 for label in label_dict.get_items()] for labels in label_list]
349351

350352

351-
def log_line(log):
353+
def log_line(log: logging.Logger) -> None:
352354
log.info("-" * 100, stacklevel=3)
353355

354356

355-
def add_file_handler(log, output_file):
357+
def add_file_handler(log: logging.Logger, output_file: pathlib.Path) -> logging.FileHandler:
356358
init_output_file(output_file.parents[0], output_file.name)
357359
fh = logging.FileHandler(output_file, mode="w", encoding="utf-8")
358360
fh.setLevel(logging.INFO)
@@ -363,12 +365,21 @@ def add_file_handler(log, output_file):
363365

364366

365367
def store_embeddings(
366-
data_points: Union[List[DT], Dataset], storage_mode: str, dynamic_embeddings: Optional[List[str]] = None
367-
):
368+
data_points: Union[List[DT], Dataset],
369+
storage_mode: str,
370+
dynamic_embeddings: Optional[List[str]] = None,
371+
) -> None:
372+
"""Stores embeddings of data points in memory or on disk.
373+
374+
Args:
375+
data_points: a DataSet or list of DataPoints for which embeddings should be stored
376+
storage_mode: store in either CPU or GPU memory, or delete them if set to 'none'
377+
dynamic_embeddings: these are always deleted. If not passed, they are identified automatically.
378+
"""
368379
if isinstance(data_points, Dataset):
369380
data_points = list(_iter_dataset(data_points))
370381

371-
# if memory mode option 'none' delete everything
382+
# if storage mode option 'none' delete everything
372383
if storage_mode == "none":
373384
dynamic_embeddings = None
374385

@@ -387,7 +398,7 @@ def store_embeddings(
387398
data_point.to("cpu", pin_memory=pin_memory)
388399

389400

390-
def identify_dynamic_embeddings(data_points: List[DT]):
401+
def identify_dynamic_embeddings(data_points: List[DT]) -> Optional[List[str]]:
391402
dynamic_embeddings = []
392403
all_embeddings = []
393404
for data_point in data_points:
@@ -407,3 +418,130 @@ def identify_dynamic_embeddings(data_points: List[DT]):
407418
if not all_embeddings:
408419
return None
409420
return list(set(dynamic_embeddings))
421+
422+
423+
class TokenEntity(NamedTuple):
424+
"""Entity represented by token indices."""
425+
426+
start_token_idx: int
427+
end_token_idx: int
428+
label: str
429+
value: str = "" # text value of the entity
430+
score: float = 1.0
431+
432+
433+
class CharEntity(NamedTuple):
434+
"""Entity represented by character indices."""
435+
436+
start_char_idx: int
437+
end_char_idx: int
438+
label: str
439+
value: str
440+
score: float = 1.0
441+
442+
443+
def create_labeled_sentence_from_tokens(
444+
tokens: Union[List[Token]], token_entities: List[TokenEntity], type_name: str = "ner"
445+
) -> Sentence:
446+
"""Creates a new Sentence object from a list of tokens or strings and applies entity labels.
447+
448+
Tokens are recreated with the same text, but not attached to the previous sentence.
449+
450+
Args:
451+
tokens: a list of Token objects or strings - only the text is used, not any labels
452+
token_entities: a list of TokenEntity objects representing entity annotations
453+
type_name: the type of entity label to apply
454+
Returns:
455+
A labeled Sentence object
456+
"""
457+
tokens = [Token(token.text) for token in tokens] # create new tokens that do not already belong to a sentence
458+
sentence = Sentence(tokens, use_tokenizer=True)
459+
for entity in token_entities:
460+
sentence[entity.start_token_idx : entity.end_token_idx].add_label(type_name, entity.label, score=entity.score)
461+
return sentence
462+
463+
464+
def create_flair_sentence(
465+
text: str,
466+
entities: List[CharEntity],
467+
token_limit: int = 512,
468+
use_context: bool = True,
469+
overlap: int = 0, # TODO: implement overlap
470+
) -> List[Sentence]:
471+
"""Constructs a Flair Sentence from text and a list of entity annotations.
472+
473+
The function explicitly tokenizes the text and labels separately, ensuring entity labels are
474+
not partially split across tokens.
475+
476+
Args:
477+
text (str): The full text to be tokenized and labeled.
478+
entities (list of tuples): Ordered non-overlapping entity annotations with each tuple in the
479+
format (start_char_index, end_char_index, entity_class, entity_text).
480+
token_limit: numerical value that determines the maximum size of a chunk. use inf to not perform chunking
481+
use_context: whether to add context to the sentence
482+
overlap: the size of overlap between chunks, repeating the last n tokens of previous chunk to preserve context
483+
484+
Returns:
485+
A list of labeled Sentence objects representing the chunks of the original text
486+
"""
487+
chunks = []
488+
489+
tokens: List[Token] = []
490+
current_index = 0
491+
token_entities: List[TokenEntity] = []
492+
end_token_idx = 0
493+
494+
for entity in entities:
495+
496+
if entity.start_char_idx > current_index: # add non-entity text
497+
non_entity_tokens = Sentence(text[current_index : entity.start_char_idx]).tokens
498+
while end_token_idx + len(non_entity_tokens) > token_limit:
499+
num_tokens = token_limit - len(tokens)
500+
tokens.extend(non_entity_tokens[:num_tokens])
501+
non_entity_tokens = non_entity_tokens[num_tokens:]
502+
# skip any fully negative samples, they cause fine_tune to fail with
503+
# `torch.cat(): expected a non-empty list of Tensors`
504+
if len(token_entities) > 0:
505+
chunks.append(create_labeled_sentence_from_tokens(tokens, token_entities))
506+
tokens, token_entities = [], []
507+
end_token_idx = 0
508+
tokens.extend(non_entity_tokens)
509+
510+
# add new entity tokens
511+
start_token_idx = len(tokens)
512+
entity_sentence = Sentence(text[entity.start_char_idx : entity.end_char_idx])
513+
if len(entity_sentence) > token_limit:
514+
logger.warning(f"Entity length is greater than token limit! {len(entity_sentence)} > {token_limit}")
515+
end_token_idx = start_token_idx + len(entity_sentence)
516+
517+
if end_token_idx >= token_limit: # create chunk from existing and add this entity to next chunk
518+
chunks.append(create_labeled_sentence_from_tokens(tokens, token_entities))
519+
520+
tokens, token_entities = [], []
521+
start_token_idx, end_token_idx = 0, len(entity_sentence)
522+
523+
token_entity = TokenEntity(start_token_idx, end_token_idx, entity.label, entity.value, entity.score)
524+
token_entities.append(token_entity)
525+
tokens.extend(entity_sentence)
526+
527+
current_index = entity.end_char_idx
528+
529+
# add any remaining tokens to a new chunk
530+
if current_index < len(text):
531+
remaining_sentence = Sentence(text[current_index:])
532+
if end_token_idx + len(remaining_sentence) > token_limit:
533+
chunks.append(create_labeled_sentence_from_tokens(tokens, token_entities))
534+
tokens, token_entities = [], []
535+
tokens.extend(remaining_sentence)
536+
537+
if tokens:
538+
chunks.append(create_labeled_sentence_from_tokens(tokens, token_entities))
539+
540+
for chunk in chunks:
541+
if len(chunk) > token_limit:
542+
logger.warning(f"Chunk size is longer than token limit: {len(chunk)} > {token_limit}")
543+
544+
if use_context:
545+
Sentence.set_context_for_sentences(chunks)
546+
547+
return chunks

0 commit comments

Comments
 (0)