Skip to content

Multitask learning tutorial and performance improvement #3653

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Mar 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 206 additions & 0 deletions docs/tutorial/tutorial-training/how-to-train-multitask-model.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
# Train a Multitask Model

In some cases, you might want to train a single model that can complete multiple tasks. For instance, you might want to
train a model that can do both part-of-speech tagging and syntactic chunking. Or a model that can predict both entities
and relations.

In such cases, you typically have a single language model as backbone and multiple prediction heads with task-specific
prediction logic. The potential advantage is twofold:

1. Instead of having two separate language models for two tasks, you have a single model, making it more compact if you keep it in memory.
2. The language model will simultaneously learn from the training data of both tasks. This may result in higher accuracy for both tasks if they are semantically close. (In practice however, such effects are rarely observed.)


## Example 1: Two token-level tasks

Our first multitask example is a single model that predicts part-of-speech tags and syntactic chunks. Both tasks are
token-level prediction tasks that are syntactic and therefore closely related.

The following script loads a single embedding for both tasks, but loads two separate training corpora. From these separate
training corpora, it creates two label dictionaries and instantiates two prediction models (both [TokenClassifier](#flair.models.TokenClassifier)):

```python
from flair.embeddings import TransformerWordEmbeddings
from flair.datasets import UD_ENGLISH, CONLL_2000
from flair.models import TokenClassifier
from flair.trainers import ModelTrainer
from flair.nn.multitask import make_multitask_model_and_corpus

# --- Embeddings that are shared by both models --- #
shared_embedding = TransformerWordEmbeddings("distilbert-base-uncased",
fine_tune=True)

# --- Task 1: Part-of-Speech tagging --- #
corpus_1 = UD_ENGLISH().downsample(0.1)

model_1 = TokenClassifier(shared_embedding,
label_dictionary=corpus_1.make_label_dictionary("pos"),
label_type="pos")

# -- Task 2: Syntactic Chunking -- #
corpus_2 = CONLL_2000().downsample(0.1)

model_2 = TokenClassifier(shared_embedding,
label_dictionary=corpus_2.make_label_dictionary("np"),
label_type="np",
)

# -- Define mapping (which tagger should train on which model) -- #
multitask_model, multicorpus = make_multitask_model_and_corpus(
[
(model_1, corpus_1),
(model_2, corpus_2),
]
)

# -- Create model trainer and train -- #
trainer = ModelTrainer(multitask_model, multicorpus)
trainer.fine_tune(f"resources/taggers/multitask_test")
```

The key is the function `make_multitask_model_and_corpus`, which takes these individual models and corpora and creates a
single multitask model and corpus out of them. These are then passed to the model trainer as usual.

When you run this script, it should print a training log like always, just with the difference that at the end of each epoch,
both tasks are evaluated on the dev split:

```
2025-03-26 22:29:16,187 ----------------------------------------------------------------------------------------------------
2025-03-26 22:29:16,187 EPOCH 1 done: loss 1.3350 - lr: 0.000049
2025-03-26 22:29:16,525 Task_0 - TokenClassifier - loss: 0.28858715295791626 - f1-score (micro avg) 0.9292
2025-03-26 22:29:16,776 Task_1 - TokenClassifier - loss: 7.221250534057617 - f1-score (micro avg) 0.9077
2025-03-26 22:29:16,777 DEV : loss 3.7549188137054443 - f1-score (micro avg) 0.9184
2025-03-26 22:29:16,783 ----------------------------------------------------------------------------------------------------
```

### Prediction

A trained model can be loaded and used for prediction like any other model. However, it will make predictions for all
tasks it was trained with:

```python
from flair.data import Sentence
from flair.models import MultitaskModel

# load the trained multitask model
model = MultitaskModel.load("resources/taggers/multitask_test/final-model.pt")

# create example sentence
sentence = Sentence("Mr Smith loves New York")

# predict (this triggers prediction of all tasks in the multitask model)
model.predict(sentence)

# print the sentence with POS and chunk tags
print(sentence)

# inspect the POS tags only
print("\nPart of speech tags are: ")
for label in sentence.get_labels('pos'):
print(f'"{label.data_point.text}" {label.value}')

# inspect the chunks only
print("\nChunks are: ")
for label in sentence.get_labels('np'):
print(f'"{label.data_point.text}" {label.value}')
```

This will print:

```
Sentence[5]: "Mr Smith loves New York" → ["Mr"/NNP, "Mr Smith"/NP, "Smith"/NNP, "loves"/VBZ, "loves"/VP, "New"/NNP, "New York"/NP, "York"/NNP]

Part of speech tags are:
"Mr" NNP
"Smith" NNP
"loves" VBZ
"New" NNP
"York" NNP

Chunks are:
"Mr Smith" NP
"loves" VP
"New York" NP
```


## Example 2: A token and a document-level task

In some cases, you may want to train a multitask model using [TransformerWordEmbeddings](#flair.embeddings.transformer.TransformerWordEmbeddings) (token-level embeddings)
and [TransformerDocumentEmbeddings](#flair.embeddings.transformer.TransformerDocumentEmbeddings) (text-level embeddings). For instance, you may want to train a model that can both
detect topics and entities in online news articles.

The code is similar to example 1, but you need more general [TransformerEmbeddings](#flair.embeddings.transformer.TransformerEmbeddings) that can produce both token- and text-level
embeddings. You also need two different model classes: A [TextClassifier](#flair.models.TextClassifier) for predicting topics and a [TokenClassifier](#flair.models.TokenClassifier) for
prediction NER tags:

```python
from flair.embeddings import TransformerEmbeddings
from flair.datasets import AGNEWS, CLEANCONLL
from flair.models import TokenClassifier, TextClassifier
from flair.trainers import ModelTrainer
from flair.nn.multitask import make_multitask_model_and_corpus

# --- Embeddings that are shared by both models --- #
# use a transformer that can do both: sentence-embedding and word-embedding
shared_embedding = TransformerEmbeddings("distilbert-base-uncased",
fine_tune=True,
is_token_embedding=True,
is_document_embedding=True)

# --- Task 1: Newswire topics, use a TextClassifier for this task --- #
corpus_1 = AGNEWS().downsample(0.01)

model_1 = TextClassifier(shared_embedding,
label_dictionary=corpus_1.make_label_dictionary("topic"),
label_type="topic")

# -- Task 2: Named entities on newswire data, use a TokenClassifier for this task --- #
corpus_2 = CLEANCONLL().downsample(0.1)

model_2 = TokenClassifier(shared_embedding,
label_dictionary=corpus_2.make_label_dictionary("ner"),
label_type="ner",
)

# -- Define mapping (which tagger should train on which model) -- #
multitask_model, multicorpus = make_multitask_model_and_corpus(
[
(model_1, corpus_1),
(model_2, corpus_2),
]
)

# -- Create model trainer and train -- #
trainer = ModelTrainer(multitask_model, multicorpus)
trainer.fine_tune(f"resources/taggers/multitask_news_data")
```


### Prediction

Similar to example 1, you can load and predict tags for both classes with a single predict call:

```python
from flair.data import Sentence
from flair.models import MultitaskModel

# load the trained multitask model
model = MultitaskModel.load("resources/taggers/multitask_news_data/final-model.pt")

# create example sentence
sentence = Sentence("IBM made a lot of profit.")

model.predict(sentence)

# print the sentence with predicted topic and entity tag
print(sentence)
```

This prints:

```
Sentence[7]: "IBM made a lot of profit." → Business (0.8645) → ["IBM"/ORG]
```

Showing that the topic "Business" and the entity "IBM" were detected in this sentence.
127 changes: 3 additions & 124 deletions docs/tutorial/tutorial-training/how-to-train-span-classifier.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Train a span classifier
# Train a Span Classifier

Span Classification models are used to model problems such as entity linking, where you already have extracted some
relevant spans within the `Sentence` and want to predict some more fine-grained labels.
Expand Down Expand Up @@ -107,127 +107,6 @@ data_folder = '/path/to/data/folder'
corpus: Corpus = ColumnCorpus(data_folder, columns)
```

## constructing a dataset in memory

If you have a pipeline where you need to construct your dataset from a different data source,
you can always construct a [Corpus](#flair.data.Corpus) with [FlairDatapointDataset](#flair.datasets.base.FlairDatapointDataset) by hand.
Let's assume you create a function `create_datapoint(datapoint) -> Sentence` that looks somewhat like this:
```python
from flair.data import Sentence

def create_sentence(datapoint) -> Sentence:
tokens = ... # calculate the tokens from your internal data structure (e.g. pandas dataframe or json dictionary)
spans = ... # create a list of tuples (start_token, end_token, label) from your data structure
sentence = Sentence(tokens)
for (start, end, label) in spans:
sentence[start:end+1].add_label("nel", label)
```
Then you can use this function to create a full dataset:
```python
from flair.data import Corpus
from flair.datasets import FlairDatapointDataset

def construct_corpus(data):
return Corpus(
train=FlairDatapointDataset([create_sentence(datapoint for datapoint in data["train"])]),
dev=FlairDatapointDataset([create_sentence(datapoint for datapoint in data["dev"])]),
test=FlairDatapointDataset([create_sentence(datapoint for datapoint in data["test"])]),
)
```
And use this to construct a corpus instead of loading a dataset.


## Combining NEL with Mention Detection

often, you don't just want to use a Named Entity Linking model alone, but combine it with a Mention Detection or Named Entity Recognition model.
For this, you can use a [Multitask Model](#flair.models.MultitaskModel) to combine a [SequenceTagger](#flair.models.SequenceTagger) and a [Span Classifier](#flair.models.SpanClassifier).

```python
from flair.datasets import NER_MULTI_WIKINER, ZELDA
from flair.embeddings import TransformerWordEmbeddings
from flair.models import SequenceTagger, SpanClassifier
from flair.models.entity_linker_model import CandidateGenerator
from flair.trainers import ModelTrainer
from flair.nn import PrototypicalDecoder
from flair.nn.multitask import make_multitask_model_and_corpus

# 1. get the corpus
ner_corpus = NER_MULTI_WIKINER()
nel_corpus = ZELDA(column_format={0: "text", 2: "nel"}) # need to set the label type to be the same as the ner one

# --- Embeddings that are shared by both models --- #
shared_embeddings = TransformerWordEmbeddings("distilbert-base-uncased", fine_tune=True)

ner_label_dict = ner_corpus.make_label_dictionary("ner", add_unk=False)

ner_model = SequenceTagger(
embeddings=shared_embeddings,
tag_dictionary=ner_label_dict,
tag_type="ner",
use_rnn=False,
use_crf=False,
reproject_embeddings=False,
)


nel_label_dict = nel_corpus.make_label_dictionary("nel", add_unk=True)

nel_model = SpanClassifier(
embeddings=shared_embeddings,
label_dictionary=nel_label_dict,
label_type="nel",
span_label_type="ner",
decoder=PrototypicalDecoder(
num_prototypes=len(nel_label_dict),
embeddings_size=shared_embeddings.embedding_length * 2, # we use "first_last" encoding for spans
distance_function="dot_product",
),
candidates=CandidateGenerator("zelda"),
)


# -- Define mapping (which tagger should train on which model) -- #
multitask_model, multicorpus = make_multitask_model_and_corpus(
[
(ner_model, ner_corpus),
(nel_model, nel_corpus),
]
)

# -- Create model trainer and train -- #
trainer = ModelTrainer(multitask_model, multicorpus)
trainer.fine_tune(f"resources/taggers/zelda_with_mention")
```

Here, the [make_multitask_model_and_corpus](#flair.nn.multitask.make_multitask_model_and_corpus) method creates a multitask model and a multicorpus where each sub-model is aligned for a sub-corpus.

### Multitask with aligned training data

If you have sentences with both annotations for ner and for nel, you might want to use a single corpus for both models.

This means, that you need to manually the `multitask_id` to the sentences:

```python
from flair.data import Sentence

def create_sentence(datapoint) -> Sentence:
tokens = ... # calculate the tokens from your internal data structure (e.g. pandas dataframe or json dictionary)
spans = ... # create a list of tuples (start_token, end_token, label) from your data structure
sentence = Sentence(tokens)
for (start, end, ner_label, nel_label) in spans:
sentence[start:end+1].add_label("ner", ner_label)
sentence[start:end+1].add_label("nel", nel_label)
sentence.add_label("multitask_id", "Task_0") # Task_0 for the NER model
sentence.add_label("multitask_id", "Task_1") # Task_1 for the NEL model
```

Then you can run the multitask training script with the exception that you create the [MultitaskModel](#flair.models.MultitaskModel) directly.

```python
...
multitask_model = MultitaskModel([ner_model, nel_model], use_all_tasks=True)
```

Here, setting `use_all_tasks=True` means that we will jointly train on both tasks at the same time. This will save a lot of training time,
as the shared embedding will be calculated once but used twice (once for each model).
## Next

Next, learn [how to train a multitask model in Flair](how-to-train-multitask-model.md).
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Train a text classifier
# Train a Text Classifier

This tutorial shows you how to train your own text classifier models with Flair. For instance, you
could train your own sentiment analysis model, or offensive language detection model.
Expand Down
1 change: 1 addition & 0 deletions docs/tutorial/tutorial-training/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ This tutorial illustrates how you can train your own state-of-the-art NLP models
how-to-train-sequence-tagger
how-to-train-text-classifier
how-to-train-span-classifier
how-to-train-multitask-model
16 changes: 16 additions & 0 deletions flair/models/multitask_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,25 @@ def predict(
sentences,
**predictargs,
):

if not isinstance(sentences, list):
sentences = [sentences]

# if not specified, set embedding storage mode to "cpu" to ensure that embeddings are reused
remove_embeddings_after_prediction = False
if "embedding_storage_mode" not in predictargs:
predictargs["embedding_storage_mode"] = "cpu"
remove_embeddings_after_prediction = True

# predict for each task separately
for task in self.tasks.values():
task.predict(sentences, **predictargs)

# if embeddings were stored only to be reused for prediction, they can be removed after
if remove_embeddings_after_prediction:
for sentence in sentences:
sentence.clear_embeddings()

@staticmethod
def split_batch_to_task_ids(
sentences: Union[list[Sentence], Sentence], all_tasks: bool = False
Expand Down