
Add option to retokenize ColumnCorpus #3658


Merged
merged 5 commits on Apr 27, 2025
244 changes: 172 additions & 72 deletions flair/data.py
@@ -1697,92 +1697,180 @@ def truncate(self, max_tokens: int) -> None:
)
]

def _clear_internal_state(self) -> None:
"""
Resets the internal tokenization and annotation state of the sentence.
Used before operations like retokenization that rebuild the sentence structure.
"""
# Clear the central annotation registry
self.annotation_layers: dict[str, list[Label]] = {}
# Clear token list
self._tokens = []
# Clear known spans/relations cache
self._known_spans = {}
# Reset cached tokenized string representation
self.tokenized = None

def retokenize(self, tokenizer):
"""
Retokenizes the sentence using the provided tokenizer while preserving span labels.
Retokenizes the sentence using the provided tokenizer while attempting to preserve
span, relation, and sentence labels. Token-level labels are discarded.

Note: Relation preservation depends on successfully re-mapping both constituent spans
based on character offsets, which might fail if tokenization changes boundaries significantly.

Args:
tokenizer: The tokenizer to use for retokenization
"""
# --- Step 0: Initial Setup ---
import logging # Ensure logging is available

Example::

# Create a sentence with default tokenization
sentence = Sentence("01-03-2025 New York")

# Add span labels
sentence.get_span(1, 3).add_label('ner', "LOC")
sentence.get_span(0, 1).add_label('ner', "DATE")
log = logging.getLogger("flair") # Use Flair's logger

# Retokenize with a different tokenizer while preserving labels
sentence.retokenize(StaccatoTokenizer())
"""
# Store the original text
original_text = self.to_original_text()

# Save all span-level labels with their text spans and character positions
span_labels = {}
for label_type in list(self.annotation_layers.keys()):
spans = self.get_spans(label_type)
if spans:
if label_type not in span_labels:
span_labels[label_type] = []

for span in spans:
# Store the span text, character positions, and its labels
span_labels[label_type].append(
(
span.text,
span.start_position,
span.end_position,
[label.value for label in span.labels],
[label.score for label in span.labels],
# --- Step 1: Save Sentence-Level Labels ---
sentence_level_labels = [
(label.typename, label.value, label.score, label.metadata)
for label in self.labels
if label.data_point is self
]

# --- Step 2: Save Span and Relation Info ---
original_known_spans_relations = list(self._known_spans.values())
span_data_to_reapply = {}
relations_to_reapply = []

# Identify all label types associated with spans and relations for later clearing
span_relation_label_types = set()

for dp in original_known_spans_relations:
if isinstance(dp, Span):
span_id = dp.unlabeled_identifier
if span_id not in span_data_to_reapply:
span_data_to_reapply[span_id] = {
"text": dp.text,
"start": dp.start_position,
"end": dp.end_position,
"labels": [],
}
# Save only labels actually belonging to this span object
for label in dp.labels:
if label.data_point is dp:
span_data_to_reapply[span_id]["labels"].append(
(label.typename, label.value, label.score, label.metadata)
)
)
span_relation_label_types.add(label.typename) # Track type

elif isinstance(dp, Relation):
relation_info = {
"first_span_id": dp.first.unlabeled_identifier,
"second_span_id": dp.second.unlabeled_identifier,
"labels": [],
}
for label in dp.labels:
if label.data_point is dp:
relation_info["labels"].append((label.typename, label.value, label.score, label.metadata))
span_relation_label_types.add(label.typename) # Track type
relations_to_reapply.append(relation_info)

# --- Step 3: Clear Internal State ---
self._clear_internal_state()

# --- Step 4: Retokenize ---
temp_sentence = Sentence(original_text, use_tokenizer=tokenizer)
self._tokens = []
self._known_spans = {} # CRITICAL: Clear known spans cache before reconstruction
self.tokenized = None

# Remove all labels of this type
self.remove_labels(label_type)
for token in temp_sentence.tokens:
token.sentence = self
token._internal_index = len(self._tokens) + 1
self._tokens.append(token)

# --- Step 5: Reconstruct Spans and Build Mapping ---
reconstructed_span_map = {} # Map: original_span_identifier -> new_span_object

for original_span_id, span_data in span_data_to_reapply.items():
start_pos = span_data["start"]
end_pos = span_data["end"]

token_indices = []
# Find tokens based on character overlap
for i, token in enumerate(self.tokens):
token_start = token.start_position
token_end = token.end_position
# Check if token is within or overlaps with the span
# A token is part of the span if:
# 1. It starts within the span, or
# 2. It ends within the span, or
# 3. It completely contains the span
if (
(token_start >= start_pos and token_start < end_pos)
or (token_end > start_pos and token_end <= end_pos)
or (token_start <= start_pos and token_end >= end_pos)
):
token_indices.append(i)

if token_indices:
span_start_idx = min(token_indices)
span_end_idx = max(token_indices) + 1

# Get/Create the new span using slicing (handles caching via __new__)
new_span = self[span_start_idx:span_end_idx]

# Add the saved labels back to this new span object.
# add_label propagates to the sentence layer.
for typename, value, score, metadata in span_data["labels"]:
new_span.add_label(typename, value, score, **metadata)

# Add mapping from original ID to the NEW span object
reconstructed_span_map[original_span_id] = new_span
else:
log.warning(
f"Could not map original span '{original_span_id}' with text '{span_data['text']}' to new tokens after retokenization."
)

# Create a new sentence with the same text but using the new tokenizer
new_sentence = Sentence(original_text, use_tokenizer=tokenizer)
# --- Step 6: Reconstruct Relations ---
for relation_info in relations_to_reapply:
original_first_id = relation_info["first_span_id"]
original_second_id = relation_info["second_span_id"]

# Replace the tokens in the current sentence with the tokens from the new sentence
self.tokens.clear()
for token in new_sentence.tokens:
self.tokens.append(token)
# Update the token's sentence reference to point to this sentence
token.sentence = self
# Find the corresponding NEW spans using the map built in Step 5
new_first_span = reconstructed_span_map.get(original_first_id)
new_second_span = reconstructed_span_map.get(original_second_id)

if new_first_span and new_second_span:
# If both constituent spans were successfully reconstructed, create the relation
# Relation.__new__ handles caching in self._known_spans
new_relation = Relation(new_first_span, new_second_span)

# Add the saved relation labels back using add_label for propagation
for typename, value, score, metadata in relation_info["labels"]:
new_relation.add_label(typename, value, score, **metadata)
else:
# Log warning if relation couldn't be reconstructed
log.warning(
f"Could not reconstruct relation between original spans '{original_first_id}' and "
f"'{original_second_id}' because one or both spans failed to map after retokenization."
)

# Reapply span labels based on character positions
for label_type, spans in span_labels.items():
for span_text, start_pos, end_pos, label_values, label_scores in spans:
# Find tokens that are fully or partially contained within the span
token_indices = []

for i, token in enumerate(self.tokens):
# Check if token is within or overlaps with the span
# A token is part of the span if:
# 1. It starts within the span, or
# 2. It ends within the span, or
# 3. It completely contains the span
token_start = token.start_position
token_end = token.end_position

if (
(token_start >= start_pos and token_start < end_pos)
or (token_end > start_pos and token_end <= end_pos) # Token starts within span
or (token_start <= start_pos and token_end >= end_pos) # Token ends within span
): # Token contains span
token_indices.append(i)

# If we found tokens covering this span
if token_indices:
span_start = min(token_indices)
span_end = max(token_indices) + 1

# Create the span and add labels
span = self.get_span(span_start, span_end)
for value, score in zip(label_values, label_scores):
span.add_label(label_type, value, score)
# --- Step 7: Reapply Sentence-Level Labels ---
# Clear only sentence-level labels from sentence layer before reapplying
sentence_only_label_types = {label[0] for label in sentence_level_labels}
for label_type in sentence_only_label_types:
if label_type in self.annotation_layers:
# Keep only labels NOT attached to the sentence itself
self.annotation_layers[label_type] = [
lbl for lbl in self.annotation_layers[label_type] if lbl.data_point is not self
]
# Optional cleanup
if not self.annotation_layers[label_type]:
del self.annotation_layers[label_type]

# Add sentence labels back
for typename, value, score, metadata in sentence_level_labels:
self.add_label(typename, value, score, **metadata) # Attaches only to self
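Not part of the diff, but a minimal usage sketch of the reworked method, assuming StaccatoTokenizer is importable from flair.tokenization; it mirrors the docstring example above: span labels added on the default tokenization are re-attached after retokenization, and sentence-level labels survive as well.

```python
from flair.data import Sentence
from flair.tokenization import StaccatoTokenizer  # assumed import path

# default tokenization typically keeps "01-03-2025" as a single token
sentence = Sentence("01-03-2025 New York")

# span labels on the default tokens, plus one sentence-level label
sentence.get_span(0, 1).add_label("ner", "DATE")
sentence.get_span(1, 3).add_label("ner", "LOC")
sentence.add_label("topic", "travel")

# retokenize: tokens change, but span and sentence labels are re-attached
# via the character-offset mapping described in Steps 5 and 7 above
sentence.retokenize(StaccatoTokenizer())

print(sentence.get_spans("ner"))     # DATE and LOC spans over the new tokens
print(sentence.get_labels("topic"))  # sentence-level label preserved
```

Token-level labels, by contrast, are discarded, as the new docstring states.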


class DataPair(DataPoint, typing.Generic[DT, DT2]):
@@ -2021,6 +2109,9 @@ def __init__(
self._test: Optional[Dataset[T_co]] = test
self._dev: Optional[Dataset[T_co]] = dev

# --- Add attribute to store tokenizer (will be set by subclasses) ---
self.tokenizer: Optional[Tokenizer] = None

@property
def train(self) -> Optional[Dataset[T_co]]:
"""The training split as a :class:`torch.utils.data.Dataset` object."""
@@ -2036,6 +2127,15 @@ def test(self) -> Optional[Dataset[T_co]]:
"""The test split as a :class:`torch.utils.data.Dataset` object."""
return self._test

@property
def corpus_tokenizer(self) -> Optional[Tokenizer]:
"""
Returns the custom tokenizer provided at corpus initialization for retokenization, if any.
Returns None if no custom tokenizer was specified.
"""
# The tokenizer attribute is set by subclasses like ColumnCorpus during their init
return self.tokenizer
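As an illustration of why the property is exposed (not in the diff): downstream code can re-apply the stored tokenizer so that ad-hoc sentences match the corpus tokenization. The corpus variable below is assumed to be a corpus built with use_tokenizer set, as in the dataset changes further down.

```python
from flair.data import Sentence

# sketch: keep an ad-hoc sentence consistent with the corpus tokenization;
# corpus is assumed to be a ColumnCorpus built with use_tokenizer=...
if corpus.corpus_tokenizer is not None:
    extra_sentence = Sentence("Berlin, 01-03-2025")
    extra_sentence.retokenize(corpus.corpus_tokenizer)
```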

def downsample(
self,
percentage: float = 0.1,
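Before the dataset changes, a standalone restatement of the character-overlap test that Step 5 of retokenize uses to map a saved span onto the new tokens; the offsets below are hypothetical and correspond to the "01-03-2025 New York" example, where the DATE span covers characters 0 to 10.

```python
def token_overlaps_span(token_start: int, token_end: int, span_start: int, span_end: int) -> bool:
    """A token belongs to a saved span if it starts inside it, ends inside it,
    or completely contains it (same predicate as in Step 5 above)."""
    return (
        (span_start <= token_start < span_end)
        or (span_start < token_end <= span_end)
        or (token_start <= span_start and token_end >= span_end)
    )

# "01-03-2025" saved as one DATE span over characters [0, 10); after StaccatoTokenizer
# splits it into "01", "-", "03", "-", "2025", every piece still overlaps, so the
# rebuilt span covers those tokens and the label is re-attached.
assert token_overlaps_span(0, 2, 0, 10)        # "01"
assert token_overlaps_span(6, 10, 0, 10)       # "2025"
assert not token_overlaps_span(11, 14, 0, 10)  # "New" lies outside the DATE span
```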
20 changes: 20 additions & 0 deletions flair/datasets/sequence_labeling.py
@@ -328,6 +328,7 @@ def __init__(
default_whitespace_after: int = 1,
every_sentence_is_independent: bool = False,
documents_as_sentences: bool = False,
use_tokenizer: Optional[Tokenizer] = None,
**corpusargs,
) -> None:
r"""Instantiates a Corpus from CoNLL column-formatted task data such as CoNLL03 or CoNLL2000.
@@ -346,6 +347,7 @@ def __init__(
in_memory: If set to True, the dataset is kept in memory as Sentence objects, otherwise does disk reads
label_name_map: Optionally map tag names to different schema.
banned_sentences: Optionally remove sentences from the corpus. Works only if `in_memory` is true
use_tokenizer: Optionally provide a flair Tokenizer object for retokenization.
"""
# get train data
train: Optional[Dataset] = (
@@ -365,6 +367,7 @@ def __init__(
default_whitespace_after=default_whitespace_after,
every_sentence_is_independent=every_sentence_is_independent,
documents_as_sentences=documents_as_sentences,
use_tokenizer=use_tokenizer,
)
for train_file in train_files
]
@@ -391,6 +394,7 @@ def __init__(
default_whitespace_after=default_whitespace_after,
every_sentence_is_independent=every_sentence_is_independent,
documents_as_sentences=documents_as_sentences,
use_tokenizer=use_tokenizer,
)
for test_file in test_files
]
@@ -417,6 +421,7 @@ def __init__(
default_whitespace_after=default_whitespace_after,
every_sentence_is_independent=every_sentence_is_independent,
documents_as_sentences=documents_as_sentences,
use_tokenizer=use_tokenizer,
)
for dev_file in dev_files
]
@@ -427,6 +432,9 @@ def __init__(

super().__init__(train, dev, test, **corpusargs)

# --- Store the retokenizer in the Corpus ---
self.tokenizer: Optional[Tokenizer] = use_tokenizer


class ColumnCorpus(MultiFileColumnCorpus):
def __init__(
@@ -439,6 +447,7 @@ def __init__(
autofind_splits: bool = True,
name: Optional[str] = None,
comment_symbol="# ",
use_tokenizer: Optional[Tokenizer] = None,
**corpusargs,
) -> None:
r"""Instantiates a Corpus from CoNLL column-formatted task data such as CoNLL03 or CoNLL2000.
@@ -468,6 +477,7 @@ def __init__(
test_files=[test_file] if test_file else [],
name=name if data_folder is None else str(data_folder),
comment_symbol=comment_symbol,
use_tokenizer=use_tokenizer,
**corpusargs,
)
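A hedged end-to-end sketch of the new option; the folder, file names, and column layout are placeholders. Passing use_tokenizer here retokenizes every parsed sentence, with the column-derived span labels carried over by Sentence.retokenize.

```python
from flair.datasets import ColumnCorpus
from flair.tokenization import StaccatoTokenizer  # assumed import path

columns = {0: "text", 1: "ner"}  # hypothetical CoNLL-style layout
corpus = ColumnCorpus(
    data_folder="resources/tasks/my_ner",  # placeholder path
    column_format=columns,
    train_file="train.txt",
    dev_file="dev.txt",
    test_file="test.txt",
    use_tokenizer=StaccatoTokenizer(),  # each sentence is retokenized after parsing
)

print(corpus.train[0])  # tokens follow StaccatoTokenizer; NER spans re-attached
```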

@@ -495,6 +505,7 @@ def __init__(
label_name_map: Optional[dict[str, str]] = None,
default_whitespace_after: int = 1,
documents_as_sentences: bool = False,
use_tokenizer: Optional[Tokenizer] = None,
) -> None:
r"""Instantiates a column dataset.

@@ -530,6 +541,8 @@ def __init__(
# store either Sentence objects in memory, or only file offsets
self.in_memory = in_memory

self.use_tokenizer = use_tokenizer

self.total_sentence_count: int = 0

# most data sets have the token text in the first column, if not, pass 'text' as column
@@ -588,6 +601,10 @@ def __init__(
if previous_sentence:
previous_sentence._next_sentence = sentence

# retokenize sentence if custom tokenizer is provided
if self.use_tokenizer is not None and sentence:
sentence.retokenize(self.use_tokenizer)

# append parsed sentence to list in memory
self.sentences.append(sentence)

@@ -865,6 +882,9 @@ def __getitem__(self, index: int = 0) -> Sentence:
span_level_tag_columns=self.span_level_tag_columns,
)

if self.use_tokenizer is not None:
sentence.retokenize(self.use_tokenizer)

# set sentence context using partials TODO: pointer to dataset is really inefficient
sentence._has_context = True
sentence._position_in_dataset = (self, index)
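The lazy path above behaves the same way: with in_memory=False only file offsets are stored, so each __getitem__ parses the sentence and then retokenizes it, which is worth keeping in mind for large corpora. A small sketch under the same placeholder assumptions as before:

```python
from flair.datasets import ColumnCorpus
from flair.tokenization import StaccatoTokenizer  # assumed import path

# lazy loading: only file offsets are kept, so parsing AND retokenization
# happen on every __getitem__ call
corpus = ColumnCorpus(
    data_folder="resources/tasks/my_ner",  # placeholder path
    column_format={0: "text", 1: "ner"},
    in_memory=False,
    use_tokenizer=StaccatoTokenizer(),
)
sentence = corpus.train[0]  # parsed from disk and retokenized at access time
```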