Commit 6adefcc

Merge pull request #3653 from flairNLP/GH-3632-multitask-learning-tutorial
Multitask learning tutorial and performance improvement
2 parents c0e427c + e0764b2 commit 6adefcc

5 files changed: +227 -125 lines changed

@@ -0,0 +1,206 @@
# Train a Multitask Model

In some cases, you might want to train a single model that can complete multiple tasks. For instance, you might want to
train a model that can do both part-of-speech tagging and syntactic chunking. Or a model that can predict both entities
and relations.

In such cases, you typically have a single language model as backbone and multiple prediction heads with task-specific
prediction logic. The potential advantage is twofold:

1. Instead of having two separate language models for two tasks, you have a single model, making it more compact if you keep it in memory.
2. The language model will simultaneously learn from the training data of both tasks. This may result in higher accuracy for both tasks if they are semantically close. (In practice, however, such effects are rarely observed.)


## Example 1: Two token-level tasks

Our first multitask example is a single model that predicts part-of-speech tags and syntactic chunks. Both tasks are
token-level prediction tasks that are syntactic and therefore closely related.

The following script loads a single embedding for both tasks, but loads two separate training corpora. From these separate
training corpora, it creates two label dictionaries and instantiates two prediction models (both [TokenClassifier](#flair.models.TokenClassifier)):

```python
from flair.embeddings import TransformerWordEmbeddings
from flair.datasets import UD_ENGLISH, CONLL_2000
from flair.models import TokenClassifier
from flair.trainers import ModelTrainer
from flair.nn.multitask import make_multitask_model_and_corpus

# --- Embeddings that are shared by both models --- #
shared_embedding = TransformerWordEmbeddings("distilbert-base-uncased",
                                             fine_tune=True)

# --- Task 1: Part-of-Speech tagging --- #
corpus_1 = UD_ENGLISH().downsample(0.1)

model_1 = TokenClassifier(shared_embedding,
                          label_dictionary=corpus_1.make_label_dictionary("pos"),
                          label_type="pos")

# --- Task 2: Syntactic Chunking --- #
corpus_2 = CONLL_2000().downsample(0.1)

model_2 = TokenClassifier(shared_embedding,
                          label_dictionary=corpus_2.make_label_dictionary("np"),
                          label_type="np",
                          )

# --- Define mapping (which model should train on which corpus) --- #
multitask_model, multicorpus = make_multitask_model_and_corpus(
    [
        (model_1, corpus_1),
        (model_2, corpus_2),
    ]
)

# --- Create model trainer and train --- #
trainer = ModelTrainer(multitask_model, multicorpus)
trainer.fine_tune("resources/taggers/multitask_test")
```

The key is the function `make_multitask_model_and_corpus`, which takes these individual models and corpora and creates a
single multitask model and corpus out of them. These are then passed to the model trainer as usual.
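
To make the pairing idea concrete, here is a toy sketch that does not use Flair at all. Each (model, corpus) pair gets a task ID, and every sentence is tagged with the ID of the task it belongs to, so that a merged corpus can be routed back to the right prediction head. The `Task_0` / `Task_1` naming mirrors what Flair prints in its training log. The helper `pair_models_with_corpora`, the string "models", and the dictionary-based sentences are hypothetical stand-ins, and this is not a description of how Flair implements the function internally.

```python
# Toy illustration only: plain Python stand-ins, not Flair classes.

def pair_models_with_corpora(pairs):
    """Assign a task ID to each (model, corpus) pair and merge the corpora.

    Hypothetical helper: returns a dict of task ID -> model and one list of
    sentences, each tagged with the task it belongs to.
    """
    models = {}
    merged_corpus = []
    for index, (model, corpus) in enumerate(pairs):
        task_id = f"Task_{index}"
        models[task_id] = model
        for text in corpus:
            # The key name "multitask_id" is illustrative.
            merged_corpus.append({"text": text, "multitask_id": task_id})
    return models, merged_corpus


models, merged_corpus = pair_models_with_corpora(
    [
        ("pos-tagger", ["Mr Smith loves New York"]),
        ("chunker", ["IBM made a lot of profit ."]),
    ]
)

# Each sentence is routed to the model registered under its task ID.
for sentence in merged_corpus:
    print(sentence["multitask_id"], "->", models[sentence["multitask_id"]])
```

The point of the sketch is only the bookkeeping: one shared container of models, one merged corpus, and a per-sentence marker that links the two.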

When you run this script, it prints the usual training log, except that at the end of each epoch
both tasks are evaluated on the dev split:

```
2025-03-26 22:29:16,187 ----------------------------------------------------------------------------------------------------
2025-03-26 22:29:16,187 EPOCH 1 done: loss 1.3350 - lr: 0.000049
2025-03-26 22:29:16,525 Task_0 - TokenClassifier - loss: 0.28858715295791626 - f1-score (micro avg) 0.9292
2025-03-26 22:29:16,776 Task_1 - TokenClassifier - loss: 7.221250534057617 - f1-score (micro avg) 0.9077
2025-03-26 22:29:16,777 DEV : loss 3.7549188137054443 - f1-score (micro avg) 0.9184
2025-03-26 22:29:16,783 ----------------------------------------------------------------------------------------------------
```
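
In this particular log, the combined `DEV` line looks like the unweighted mean of the two per-task lines. That is an observation from the numbers shown here, not a statement about how Flair computes the score in general. The short check below only redoes the arithmetic on the values from the log:

```python
# Values copied from the log excerpt above.
task_losses = [0.28858715295791626, 7.221250534057617]
task_f1 = [0.9292, 0.9077]

mean_loss = sum(task_losses) / len(task_losses)
mean_f1 = sum(task_f1) / len(task_f1)

print(f"mean loss: {mean_loss:.4f}")  # compare with the DEV loss in the log (3.7549...)
print(f"mean f1:   {mean_f1:.5f}")    # compare with the DEV f1-score in the log (0.9184)
```

If this reading holds, a large loss on one task can dominate the combined dev loss even when the other task is doing well, so it is worth reading the per-task lines as well.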

### Prediction

A trained model can be loaded and used for prediction like any other model. However, it will make predictions for all
tasks it was trained with:

```python
from flair.data import Sentence
from flair.models import MultitaskModel

# load the trained multitask model
model = MultitaskModel.load("resources/taggers/multitask_test/final-model.pt")

# create example sentence
sentence = Sentence("Mr Smith loves New York")

# predict (this triggers prediction of all tasks in the multitask model)
model.predict(sentence)

# print the sentence with POS and chunk tags
print(sentence)

# inspect the POS tags only
print("\nPart of speech tags are: ")
for label in sentence.get_labels('pos'):
    print(f'"{label.data_point.text}" {label.value}')

# inspect the chunks only
print("\nChunks are: ")
for label in sentence.get_labels('np'):
    print(f'"{label.data_point.text}" {label.value}')
```

This will print:

```
Sentence[5]: "Mr Smith loves New York" → ["Mr"/NNP, "Mr Smith"/NP, "Smith"/NNP, "loves"/VBZ, "loves"/VP, "New"/NNP, "New York"/NP, "York"/NNP]

Part of speech tags are:
"Mr" NNP
"Smith" NNP
"loves" VBZ
"New" NNP
"York" NNP

Chunks are:
"Mr Smith" NP
"loves" VP
"New York" NP
```


## Example 2: A token and a document-level task

In some cases, you may want to train a multitask model using [TransformerWordEmbeddings](#flair.embeddings.transformer.TransformerWordEmbeddings) (token-level embeddings)
and [TransformerDocumentEmbeddings](#flair.embeddings.transformer.TransformerDocumentEmbeddings) (text-level embeddings). For instance, you may want to train a model that can both
detect topics and entities in online news articles.

The code is similar to example 1, but you need more general [TransformerEmbeddings](#flair.embeddings.transformer.TransformerEmbeddings) that can produce both token- and text-level
embeddings. You also need two different model classes: a [TextClassifier](#flair.models.TextClassifier) for predicting topics and a [TokenClassifier](#flair.models.TokenClassifier) for
predicting NER tags:

```python
from flair.embeddings import TransformerEmbeddings
from flair.datasets import AGNEWS, CLEANCONLL
from flair.models import TokenClassifier, TextClassifier
from flair.trainers import ModelTrainer
from flair.nn.multitask import make_multitask_model_and_corpus

# --- Embeddings that are shared by both models --- #
# use a transformer that can do both: sentence-embedding and word-embedding
shared_embedding = TransformerEmbeddings("distilbert-base-uncased",
                                         fine_tune=True,
                                         is_token_embedding=True,
                                         is_document_embedding=True)

# --- Task 1: Newswire topics, use a TextClassifier for this task --- #
corpus_1 = AGNEWS().downsample(0.01)

model_1 = TextClassifier(shared_embedding,
                         label_dictionary=corpus_1.make_label_dictionary("topic"),
                         label_type="topic")

# --- Task 2: Named entities on newswire data, use a TokenClassifier for this task --- #
corpus_2 = CLEANCONLL().downsample(0.1)

model_2 = TokenClassifier(shared_embedding,
                          label_dictionary=corpus_2.make_label_dictionary("ner"),
                          label_type="ner",
                          )

# --- Define mapping (which model should train on which corpus) --- #
multitask_model, multicorpus = make_multitask_model_and_corpus(
    [
        (model_1, corpus_1),
        (model_2, corpus_2),
    ]
)

# --- Create model trainer and train --- #
trainer = ModelTrainer(multitask_model, multicorpus)
trainer.fine_tune("resources/taggers/multitask_news_data")
```


### Prediction

Similar to example 1, you can load the model and predict labels for both tasks with a single predict call:

```python
from flair.data import Sentence
from flair.models import MultitaskModel

# load the trained multitask model
model = MultitaskModel.load("resources/taggers/multitask_news_data/final-model.pt")

# create example sentence
sentence = Sentence("IBM made a lot of profit.")

model.predict(sentence)

# print the sentence with predicted topic and entity tag
print(sentence)
```

This prints:

```
Sentence[7]: "IBM made a lot of profit." → Business (0.8645) → ["IBM"/ORG]
```

This shows that the topic "Business" and the entity "IBM" were detected in this sentence.
@@ -1,4 +1,4 @@
-# Train a span classifier
+# Train a Span Classifier
 
 Span Classification models are used to model problems such as entity linking, where you already have extracted some
 relevant spans within the `Sentence` and want to predict some more fine-grained labels.
@@ -107,127 +107,6 @@ data_folder = '/path/to/data/folder'
 corpus: Corpus = ColumnCorpus(data_folder, columns)
 ```
 
-## constructing a dataset in memory
-
-If you have a pipeline where you need to construct your dataset from a different data source,
-you can always construct a [Corpus](#flair.data.Corpus) with [FlairDatapointDataset](#flair.datasets.base.FlairDatapointDataset) by hand.
-Let's assume you create a function `create_datapoint(datapoint) -> Sentence` that looks somewhat like this:
-```python
-from flair.data import Sentence
-
-def create_sentence(datapoint) -> Sentence:
-    tokens = ... # calculate the tokens from your internal data structure (e.g. pandas dataframe or json dictionary)
-    spans = ... # create a list of tuples (start_token, end_token, label) from your data structure
-    sentence = Sentence(tokens)
-    for (start, end, label) in spans:
-        sentence[start:end+1].add_label("nel", label)
-```
-Then you can use this function to create a full dataset:
-```python
-from flair.data import Corpus
-from flair.datasets import FlairDatapointDataset
-
-def construct_corpus(data):
-    return Corpus(
-        train=FlairDatapointDataset([create_sentence(datapoint for datapoint in data["train"])]),
-        dev=FlairDatapointDataset([create_sentence(datapoint for datapoint in data["dev"])]),
-        test=FlairDatapointDataset([create_sentence(datapoint for datapoint in data["test"])]),
-    )
-```
-And use this to construct a corpus instead of loading a dataset.
-
-
-## Combining NEL with Mention Detection
-
-often, you don't just want to use a Named Entity Linking model alone, but combine it with a Mention Detection or Named Entity Recognition model.
-For this, you can use a [Multitask Model](#flair.models.MultitaskModel) to combine a [SequenceTagger](#flair.models.SequenceTagger) and a [Span Classifier](#flair.models.SpanClassifier).
-
-```python
-from flair.datasets import NER_MULTI_WIKINER, ZELDA
-from flair.embeddings import TransformerWordEmbeddings
-from flair.models import SequenceTagger, SpanClassifier
-from flair.models.entity_linker_model import CandidateGenerator
-from flair.trainers import ModelTrainer
-from flair.nn import PrototypicalDecoder
-from flair.nn.multitask import make_multitask_model_and_corpus
-
-# 1. get the corpus
-ner_corpus = NER_MULTI_WIKINER()
-nel_corpus = ZELDA(column_format={0: "text", 2: "nel"}) # need to set the label type to be the same as the ner one
-
-# --- Embeddings that are shared by both models --- #
-shared_embeddings = TransformerWordEmbeddings("distilbert-base-uncased", fine_tune=True)
-
-ner_label_dict = ner_corpus.make_label_dictionary("ner", add_unk=False)
-
-ner_model = SequenceTagger(
-    embeddings=shared_embeddings,
-    tag_dictionary=ner_label_dict,
-    tag_type="ner",
-    use_rnn=False,
-    use_crf=False,
-    reproject_embeddings=False,
-)
-
-
-nel_label_dict = nel_corpus.make_label_dictionary("nel", add_unk=True)
-
-nel_model = SpanClassifier(
-    embeddings=shared_embeddings,
-    label_dictionary=nel_label_dict,
-    label_type="nel",
-    span_label_type="ner",
-    decoder=PrototypicalDecoder(
-        num_prototypes=len(nel_label_dict),
-        embeddings_size=shared_embeddings.embedding_length * 2, # we use "first_last" encoding for spans
-        distance_function="dot_product",
-    ),
-    candidates=CandidateGenerator("zelda"),
-)
-
-
-# -- Define mapping (which tagger should train on which model) -- #
-multitask_model, multicorpus = make_multitask_model_and_corpus(
-    [
-        (ner_model, ner_corpus),
-        (nel_model, nel_corpus),
-    ]
-)
-
-# -- Create model trainer and train -- #
-trainer = ModelTrainer(multitask_model, multicorpus)
-trainer.fine_tune(f"resources/taggers/zelda_with_mention")
-```
-
-Here, the [make_multitask_model_and_corpus](#flair.nn.multitask.make_multitask_model_and_corpus) method creates a multitask model and a multicorpus where each sub-model is aligned for a sub-corpus.
-
-### Multitask with aligned training data
-
-If you have sentences with both annotations for ner and for nel, you might want to use a single corpus for both models.
-
-This means, that you need to manually the `multitask_id` to the sentences:
-
-```python
-from flair.data import Sentence
-
-def create_sentence(datapoint) -> Sentence:
-    tokens = ... # calculate the tokens from your internal data structure (e.g. pandas dataframe or json dictionary)
-    spans = ... # create a list of tuples (start_token, end_token, label) from your data structure
-    sentence = Sentence(tokens)
-    for (start, end, ner_label, nel_label) in spans:
-        sentence[start:end+1].add_label("ner", ner_label)
-        sentence[start:end+1].add_label("nel", nel_label)
-    sentence.add_label("multitask_id", "Task_0") # Task_0 for the NER model
-    sentence.add_label("multitask_id", "Task_1") # Task_1 for the NEL model
-```
-
-Then you can run the multitask training script with the exception that you create the [MultitaskModel](#flair.models.MultitaskModel) directly.
-
-```python
-...
-multitask_model = MultitaskModel([ner_model, nel_model], use_all_tasks=True)
-```
-
-Here, setting `use_all_tasks=True` means that we will jointly train on both tasks at the same time. This will save a lot of training time,
-as the shared embedding will be calculated once but used twice (once for each model).
+## Next
 
+Next, learn [how to train a multitask model in Flair](how-to-train-multitask-model.md).

docs/tutorial/tutorial-training/how-to-train-text-classifier.md

+1 -1
@@ -1,4 +1,4 @@
-# Train a text classifier
+# Train a Text Classifier
 
 This tutorial shows you how to train your own text classifier models with Flair. For instance, you
 could train your own sentiment analysis model, or offensive language detection model.

docs/tutorial/tutorial-training/index.rst

+1
@@ -14,3 +14,4 @@ This tutorial illustrates how you can train your own state-of-the-art NLP models
 how-to-train-sequence-tagger
 how-to-train-text-classifier
 how-to-train-span-classifier
+how-to-train-multitask-model

flair/models/multitask_model.py

+16
@@ -88,9 +88,25 @@ def predict(
         sentences,
         **predictargs,
     ):
+
+        if not isinstance(sentences, list):
+            sentences = [sentences]
+
+        # if not specified, set embedding storage mode to "cpu" to ensure that embeddings are reused
+        remove_embeddings_after_prediction = False
+        if "embedding_storage_mode" not in predictargs:
+            predictargs["embedding_storage_mode"] = "cpu"
+            remove_embeddings_after_prediction = True
+
+        # predict for each task separately
         for task in self.tasks.values():
             task.predict(sentences, **predictargs)
 
+        # if embeddings were stored only to be reused for prediction, they can be removed after
+        if remove_embeddings_after_prediction:
+            for sentence in sentences:
+                sentence.clear_embeddings()
+
     @staticmethod
     def split_batch_to_task_ids(
         sentences: Union[list[Sentence], Sentence], all_tasks: bool = False
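
The effect of the `predict` change can be pictured with a small stand-alone sketch. When embeddings stay attached to the sentences between tasks, the shared embedding is computed once and reused by the second task, and it is cleared afterwards only if the caller did not choose a storage mode. `ToySentence`, `ToyTask`, `embed`, `predict_all`, and the `stats` counter are hypothetical stand-ins written for this illustration. They are not Flair classes, and the storage-mode handling is simplified to "none" versus anything else.

```python
# Toy model of the caching behaviour; none of these names exist in Flair.
stats = {"embed_calls": 0}


class ToySentence:
    def __init__(self, text):
        self.text = text
        self.embedding = None  # filled lazily by embed()

    def clear_embeddings(self):
        self.embedding = None


def embed(sentence):
    """Compute the (expensive) shared embedding unless it is already stored."""
    if sentence.embedding is None:
        stats["embed_calls"] += 1
        sentence.embedding = [float(len(sentence.text))]
    return sentence.embedding


class ToyTask:
    def predict(self, sentences, embedding_storage_mode="none"):
        for sentence in sentences:
            embed(sentence)
            # Simplified rule: "none" discards embeddings right after use.
            if embedding_storage_mode == "none":
                sentence.clear_embeddings()


def predict_all(tasks, sentences, **predictargs):
    if not isinstance(sentences, list):
        sentences = [sentences]
    # Default to keeping embeddings so that later tasks can reuse them.
    remove_after = False
    if "embedding_storage_mode" not in predictargs:
        predictargs["embedding_storage_mode"] = "cpu"
        remove_after = True
    for task in tasks:
        task.predict(sentences, **predictargs)
    # Clean up only if storage was enabled implicitly.
    if remove_after:
        for sentence in sentences:
            sentence.clear_embeddings()


sentence = ToySentence("IBM made a lot of profit.")
predict_all([ToyTask(), ToyTask()], sentence)
reused_calls = stats["embed_calls"]  # one embedding pass shared by two tasks

stats["embed_calls"] = 0
predict_all([ToyTask(), ToyTask()], sentence, embedding_storage_mode="none")
recomputed_calls = stats["embed_calls"]  # one embedding pass per task
```

In this toy setup the default path embeds the sentence once for both tasks, while an explicit "none" forces one embedding pass per task. In both cases the sentence ends up without stored embeddings.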
