Commit 39ec21e

Merge pull request #3641 from flairNLP/GH-3635-lazy-tokenization
Lazy Tokenization in Flair
2 parents: fd5d1ad + 4999a4b, commit 39ec21e
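
To illustrate the intent of this change, here is a minimal usage sketch based on the API visible in the diff below. The example sentence, tokenizer choice, and printed values are illustrative only, not taken from the PR: creating a Sentence now just stores the raw text and the tokenizer, and tokenization runs lazily the first time tokens are actually needed.

from flair.data import Sentence
from flair.tokenization import SegtokTokenizer

# Creating the Sentence stores the text and tokenizer; no tokenization happens yet.
sentence = Sentence("The grass is green.", use_tokenizer=SegtokTokenizer())

# These read the stored string directly and do not trigger tokenization.
print(sentence.text)
print(sentence.to_original_text())

# Accessing tokens (or len()/iteration) tokenizes on first use.
print(len(sentence))                       # 5 with SegtokTokenizer ("." is split off)
print([token.text for token in sentence])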

8 files changed (+311, -106 lines)

flair/data.py (+172 -76)

@@ -818,22 +818,21 @@ def __init__(
         """
         super().__init__()

-        self.tokens: list[Token] = []
+        self._tokens: Optional[list[Token]] = None
+        self._text: str = ""  # Change from Optional[str] to str with empty string default

-        # private field for all known spans
-        self._known_spans: dict[str, _PartOfSentence] = {}
+        # private field for all known spans with explicit typing
+        self._known_spans: dict[str, Union[Span, Relation]] = {}

         self.language_code: Optional[str] = language_code

         self._start_position = start_position

         # the tokenizer used for this sentence
         if isinstance(use_tokenizer, Tokenizer):
-            tokenizer = use_tokenizer
-
+            self._tokenizer = use_tokenizer
         elif isinstance(use_tokenizer, bool):
-            tokenizer = SegtokTokenizer() if use_tokenizer else SpaceTokenizer()
-
+            self._tokenizer = SegtokTokenizer() if use_tokenizer else SpaceTokenizer()
         else:
             raise AssertionError("Unexpected type of parameter 'use_tokenizer'. Parameter should be bool or Tokenizer")

@@ -848,24 +847,79 @@ def __init__(
         self._next_sentence: Optional[Sentence] = None
         self._position_in_dataset: Optional[tuple[Dataset, int]] = None

-        # if text is passed, instantiate sentence with tokens (words)
-        if isinstance(text, str):
-            text = Sentence._handle_problem_characters(text)
-            words = tokenizer.tokenize(text)
-        elif text and isinstance(text[0], Token):
-            for t in text:
-                self._add_token(t)
-            self.tokens[-1].whitespace_after = 0
-            return
+        # if list of strings or tokens is passed, create tokens directly
+        if not isinstance(text, str):
+            self._tokens = []
+
+            # First construct the text from tokens to ensure proper text reconstruction
+            if len(text) > 0:
+                # Type check the input list and cast
+                if all(isinstance(t, Token) for t in text):
+                    tokens = cast(list[Token], text)
+                    reconstructed_text = ""
+                    for i, token in enumerate(tokens):
+                        reconstructed_text += token.text
+                        if i < len(tokens) - 1:  # Add whitespace between tokens
+                            reconstructed_text += " " * token.whitespace_after
+                    self._text = reconstructed_text
+                elif all(isinstance(t, str) for t in text):
+                    strings = cast(list[str], text)
+                    self._text = " ".join(strings)
+                else:
+                    raise TypeError("All elements must be either Token or str")
+            else:
+                self._text = ""
+
+            # Now add the tokens
+            current_position = 0
+            for i, item in enumerate(text):
+                # create Token if string, otherwise use existing Token
+                if isinstance(item, str):
+                    # For strings, create new Token with default whitespace
+                    token = Token(text=item)
+                    token.whitespace_after = 0 if i == len(text) - 1 else 1
+                elif isinstance(item, Token):
+                    # For existing Tokens, preserve their whitespace_after
+                    token = item
+
+                # Set start position for the token
+                token.start_position = current_position
+                current_position += len(token.text) + token.whitespace_after
+
+                self._add_token(token)
+
+            if len(text) > 0:
+                # convention: the last token has no whitespace after
+                self.tokens[-1].whitespace_after = 0
         else:
-            words = cast(list[str], text)
-            text = " ".join(words)
+            self._text = Sentence._handle_problem_characters(text)
+
+        # log a warning if the dataset is empty
+        if self._text == "":
+            log.warning("Warning: An empty Sentence was created! Are there empty strings in your dataset?")
+
+    @property
+    def tokens(self) -> list[Token]:
+        """Gets the tokens of this sentence. Automatically triggers tokenization if not yet tokenized."""
+        if self._tokens is None:
+            self._tokenize()
+            if self._tokens is None:
+                raise ValueError("Tokens are None after tokenization - this indicates a bug in the tokenization process")
+        return self._tokens
+
+    def _tokenize(self) -> None:
+        """Internal method that performs tokenization."""
+
+        # tokenize the text
+        words = self._tokenizer.tokenize(self._text)

         # determine token positions and whitespace_after flag
         current_offset: int = 0
         previous_token: Optional[Token] = None
+        self._tokens = []
+
         for word in words:
-            word_start_position: int = text.index(word, current_offset)
+            word_start_position: int = self._text.index(word, current_offset)
             delta_offset: int = word_start_position - current_offset

             token: Token = Token(text=word, start_position=word_start_position)

@@ -878,17 +932,56 @@ def __init__(
             previous_token = token

         # the last token has no whitespace after
-        if len(self) > 0:
-            self.tokens[-1].whitespace_after = 0
+        if len(self._tokens) > 0:
+            self._tokens[-1].whitespace_after = 0

-        # log a warning if the dataset is empty
-        if text == "":
-            log.warning("Warning: An empty Sentence was created! Are there empty strings in your dataset?")
+    def __iter__(self):
+        """Allows iteration over tokens. Triggers tokenization if not yet tokenized."""
+        return iter(self.tokens)
+
+    def __len__(self) -> int:
+        """Returns the number of tokens in this sentence. Triggers tokenization if not yet tokenized."""
+        return len(self.tokens)

     @property
     def unlabeled_identifier(self):
         return f'Sentence[{len(self)}]: "{self.text}"'

+    @property
+    def text(self) -> str:
+        """Returns the original text of this sentence. Does not trigger tokenization."""
+        return self._text
+
+    def to_original_text(self) -> str:
+        """Returns the original text of this sentence."""
+        return self._text
+
+    def to_tagged_string(self, main_label: Optional[str] = None) -> str:
+        # For sentence-level labels, we don't need tokenization
+        if not self._tokens:
+            output = f'Sentence: "{self.text}"'
+            if self.labels:
+                output += self._printout_labels(main_label)
+            return output
+
+        # Only tokenize if we have token-level labels or spans to print
+        already_printed = [self]
+        output = super().__str__()
+
+        label_append = []
+        for label in self.get_labels(main_label):
+            if label.data_point in already_printed:
+                continue
+            label_append.append(
+                f'"{label.data_point.text}"{label.data_point._printout_labels(main_label=main_label, add_score=False)}'
+            )
+            already_printed.append(label.data_point)
+
+        if len(label_append) > 0:
+            output += f"{flair._arrow}[" + ", ".join(label_append) + "]"
+
+        return output
+
     def get_relations(self, label_type: Optional[str] = None) -> list[Relation]:
         relations: list[Relation] = []
         for label in self.get_labels(label_type):

@@ -951,11 +1044,13 @@ def to(self, device: str, pin_memory: bool = False):
             token.to(device, pin_memory)

     def clear_embeddings(self, embedding_names: Optional[list[str]] = None):
+        # clear sentence embeddings
         super().clear_embeddings(embedding_names)

-        # clear token embeddings
-        for token in self:
-            token.clear_embeddings(embedding_names)
+        # clear token embeddings if sentence is tokenized
+        if self._is_tokenized():
+            for token in self.tokens:
+                token.clear_embeddings(embedding_names)

     def left_context(self, context_length: int, respect_document_boundaries: bool = True) -> list[Token]:
         sentence = self

@@ -987,29 +1082,6 @@ def right_context(self, context_length: int, respect_document_boundaries: bool =
     def __str__(self) -> str:
         return self.to_tagged_string()

-    def to_tagged_string(self, main_label: Optional[str] = None) -> str:
-        already_printed = [self]
-
-        output = super().__str__()
-
-        label_append = []
-        for label in self.get_labels(main_label):
-            if label.data_point in already_printed:
-                continue
-            label_append.append(
-                f'"{label.data_point.text}"{label.data_point._printout_labels(main_label=main_label, add_score=False)}'
-            )
-            already_printed.append(label.data_point)
-
-        if len(label_append) > 0:
-            output += f"{flair._arrow}[" + ", ".join(label_append) + "]"
-
-        return output
-
-    @property
-    def text(self) -> str:
-        return self.to_original_text()
-
     def to_tokenized_string(self) -> str:
         if self.tokenized is None:
             self.tokenized = " ".join([t.text for t in self.tokens])

@@ -1056,15 +1128,6 @@ def infer_space_after(self):
             last_token = token
         return self

-    def to_original_text(self) -> str:
-        # if sentence has no tokens, return empty string
-        if len(self) == 0:
-            return ""
-        # otherwise, return concatenation of tokens with the correct offsets
-        return (self[0].start_position - self.start_position) * " " + "".join(
-            [t.text + t.whitespace_after * " " for t in self.tokens]
-        ).strip()
-
     def to_dict(self, tag_type: Optional[str] = None) -> dict[str, Any]:
         return {
             "text": self.to_original_text(),

@@ -1090,12 +1153,6 @@ def __getitem__(self, subscript):
         else:
             return self.tokens[subscript]

-    def __iter__(self):
-        return iter(self.tokens)
-
-    def __len__(self) -> int:
-        return len(self.tokens)
-
     def __repr__(self) -> str:
         return self.__str__()

@@ -1233,20 +1290,59 @@ def get_labels(self, label_type: Optional[str] = None):
             return []

     def remove_labels(self, typename: str):
-        # labels also need to be deleted at all tokens
-        for token in self:
-            token.remove_labels(typename)
-
-        # labels also need to be deleted at all known spans
-        for span in self._known_spans.values():
-            span.remove_labels(typename)
+        # only access tokens if already tokenized
+        if self._is_tokenized():
+            # labels also need to be deleted at all tokens
+            for token in self.tokens:
+                token.remove_labels(typename)

-        # remove spans without labels
-        self._known_spans = {k: v for k, v in self._known_spans.items() if len(v.labels) > 0}
+            # labels also need to be deleted at all known spans
+            for span in self._known_spans.values():
+                span.remove_labels(typename)

-        # delete labels at object itself
+        # delete labels at object itself first
         super().remove_labels(typename)

+    def _is_tokenized(self) -> bool:
+        return self._tokens is not None
+
+    def truncate(self, max_tokens: int) -> None:
+        """Truncates the sentence to a maximum number of tokens and updates all annotations accordingly."""
+        if len(self.tokens) <= max_tokens:
+            return
+
+        # Truncate tokens
+        self._tokens = self.tokens[:max_tokens]
+
+        # Remove spans that reference removed tokens
+        self._known_spans = {
+            identifier: span
+            for identifier, span in self._known_spans.items()
+            if isinstance(span, Span) and all(token.idx <= max_tokens for token in span.tokens)
+        }
+
+        # Remove relations that reference removed spans
+        self._known_spans = {
+            identifier: relation
+            for identifier, relation in self._known_spans.items()
+            if not isinstance(relation, Relation)
+            or (
+                all(token.idx <= max_tokens for token in relation.first.tokens)
+                and all(token.idx <= max_tokens for token in relation.second.tokens)
+            )
+        }
+
+        # Clean up any labels that reference removed spans/relations
+        for typename in list(self.annotation_layers.keys()):
+            self.annotation_layers[typename] = [
+                label
+                for label in self.annotation_layers[typename]
+                if (
+                    not isinstance(label.data_point, (Span, Relation))
+                    or label.data_point.unlabeled_identifier in self._known_spans
+                )
+            ]
+

 class DataPair(DataPoint, typing.Generic[DT, DT2]):
     def __init__(self, first: DT, second: DT2) -> None:

@@ -1375,7 +1471,7 @@ class Corpus(typing.Generic[T_co]):
     """The main object in Flair for holding a dataset used for training and testing.

     A corpus consists of three splits: A `train` split used for training, a `dev` split used for model selection
-    and/or early stopping and a `test` split used for testing. All three splits are optional, so it is possible
+    or early stopping and a `test` split used for testing. All three splits are optional, so it is possible
     to create a corpus only using one or two splits. If the option `sample_missing_splits` is set to True,
     missing splits will be randomly sampled from the training split.
     """

flair/datasets/base.py (+1 -1)

@@ -207,7 +207,7 @@ def _parse_document_to_sentence(
                 sentence.add_label(self.tag_type, label)

             if self.max_tokens_per_doc > 0:
-                sentence.tokens = sentence.tokens[: min(len(sentence), self.max_tokens_per_doc)]
+                sentence.truncate(self.max_tokens_per_doc)

             return sentence
         return None

flair/datasets/document_classification.py (+6 -3)

@@ -268,12 +268,15 @@ def _parse_line_to_sentence(self, line: str, label_prefix: str, tokenizer: Union
         if text and (labels or self.allow_examples_without_labels):
             sentence = Sentence(text, use_tokenizer=tokenizer)

+            if 0 < self.truncate_to_max_tokens < len(sentence):
+                # Create new sentence with truncated text
+                truncated_text = " ".join(token.text for token in sentence.tokens[: self.truncate_to_max_tokens])
+                sentence = Sentence(truncated_text, use_tokenizer=tokenizer)
+
+            # Add the labels
             for label in labels:
                 sentence.add_label(self.label_type, label)

-            if sentence is not None and 0 < self.truncate_to_max_tokens < len(sentence):
-                sentence.tokens = sentence.tokens[: self.truncate_to_max_tokens]
-
             return sentence
         return None
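
Note that this reader now truncates before the labels are added, by joining the texts of the first truncate_to_max_tokens tokens and re-tokenizing the shorter string. A small stand-alone sketch of that pattern; the tokenizer choice, example text, and max_tokens value are illustrative only:

from flair.data import Sentence
from flair.tokenization import SegtokTokenizer

tokenizer = SegtokTokenizer()
max_tokens = 4

sentence = Sentence("This movie was surprisingly good fun .", use_tokenizer=tokenizer)

if 0 < max_tokens < len(sentence):
    # Rebuild a shorter sentence from the space-joined texts of the kept tokens.
    truncated_text = " ".join(token.text for token in sentence.tokens[:max_tokens])
    sentence = Sentence(truncated_text, use_tokenizer=tokenizer)

print(len(sentence), sentence.text)  # -> 4 This movie was surprisingly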
