# Train a Multitask Model

In some cases, you might want to train a single model that can complete multiple tasks. For instance, you might want to
train a model that can do both part-of-speech tagging and syntactic chunking. Or a model that can predict both entities
and relations.

In such cases, you typically have a single language model as backbone and multiple prediction heads with task-specific
prediction logic (see the sketch after this list). The potential advantage is twofold:

1. Instead of two separate language models for two tasks, you have a single model, which saves memory if you keep it loaded.
2. The language model learns simultaneously from the training data of both tasks. This may result in higher accuracy for both tasks if they are semantically close. (In practice, however, such effects are rarely observed.)
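
To make the architecture concrete, here is a minimal conceptual sketch in plain PyTorch (not the Flair API; all layer
sizes are made up for illustration): one shared backbone computes representations once, and each task adds its own small
prediction head on top.

```python
import torch
from torch import nn

# hypothetical sizes, chosen only for illustration
backbone = nn.Linear(100, 64)    # stand-in for a shared language model
pos_head = nn.Linear(64, 17)     # head 1: scores over 17 POS tags
chunk_head = nn.Linear(64, 23)   # head 2: scores over 23 chunk tags

token_features = torch.randn(5, 100)  # 5 tokens, 100-dim input features
shared = backbone(token_features)     # computed once, shared by both heads
pos_scores = pos_head(shared)         # task 1 prediction logic
chunk_scores = chunk_head(shared)     # task 2 prediction logic
```
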
## Example 1: Two token-level tasks

Our first multitask example is a single model that predicts part-of-speech tags and syntactic chunks. Both are
token-level, syntactic prediction tasks and therefore closely related.

The following script loads a single embedding shared by both tasks, but two separate training corpora. From these
corpora, it creates two label dictionaries and instantiates two prediction models (both [TokenClassifier](#flair.models.TokenClassifier)):

```python
from flair.embeddings import TransformerWordEmbeddings
from flair.datasets import UD_ENGLISH, CONLL_2000
from flair.models import TokenClassifier
from flair.trainers import ModelTrainer
from flair.nn.multitask import make_multitask_model_and_corpus

# --- Embeddings that are shared by both models --- #
shared_embedding = TransformerWordEmbeddings("distilbert-base-uncased",
                                             fine_tune=True)

# --- Task 1: Part-of-Speech tagging --- #
corpus_1 = UD_ENGLISH().downsample(0.1)

model_1 = TokenClassifier(shared_embedding,
                          label_dictionary=corpus_1.make_label_dictionary("pos"),
                          label_type="pos")

# --- Task 2: Syntactic Chunking --- #
corpus_2 = CONLL_2000().downsample(0.1)

model_2 = TokenClassifier(shared_embedding,
                          label_dictionary=corpus_2.make_label_dictionary("np"),
                          label_type="np")

# --- Define mapping (which model should train on which corpus) --- #
multitask_model, multicorpus = make_multitask_model_and_corpus(
    [
        (model_1, corpus_1),
        (model_2, corpus_2),
    ]
)

# --- Create model trainer and train --- #
trainer = ModelTrainer(multitask_model, multicorpus)
trainer.fine_tune("resources/taggers/multitask_test")
```

The key is the function `make_multitask_model_and_corpus`, which takes the individual models and corpora and creates a
single multitask model and corpus out of them. These are then passed to the model trainer as usual.

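Before training, you can optionally inspect what the function returned. This is just a sanity-check sketch, assuming the
returned objects print like a regular Flair model and corpus:

```python
# continuing from the script above
print(multitask_model)  # the shared embedding with both prediction heads
print(multicorpus)      # the combined train/dev/test splits of both corpora
```
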
When you run this script, it prints a training log as usual, with the difference that at the end of each epoch,
both tasks are evaluated on the dev split (note that the reported DEV score is the average of the two task scores):

```
2025-03-26 22:29:16,187 ----------------------------------------------------------------------------------------------------
2025-03-26 22:29:16,187 EPOCH 1 done: loss 1.3350 - lr: 0.000049
2025-03-26 22:29:16,525 Task_0 - TokenClassifier - loss: 0.28858715295791626 - f1-score (micro avg) 0.9292
2025-03-26 22:29:16,776 Task_1 - TokenClassifier - loss: 7.221250534057617 - f1-score (micro avg) 0.9077
2025-03-26 22:29:16,777 DEV : loss 3.7549188137054443 - f1-score (micro avg) 0.9184
2025-03-26 22:29:16,783 ----------------------------------------------------------------------------------------------------
```

### Prediction

A trained model can be loaded and used for prediction like any other model. However, it will make predictions for all
tasks it was trained with:

```python
from flair.data import Sentence
from flair.models import MultitaskModel

# load the trained multitask model
model = MultitaskModel.load("resources/taggers/multitask_test/final-model.pt")

# create example sentence
sentence = Sentence("Mr Smith loves New York")

# predict (this triggers prediction of all tasks in the multitask model)
model.predict(sentence)

# print the sentence with POS and chunk tags
print(sentence)

# inspect the POS tags only
print("\nPart of speech tags are: ")
for label in sentence.get_labels('pos'):
    print(f'"{label.data_point.text}" {label.value}')

# inspect the chunks only
print("\nChunks are: ")
for label in sentence.get_labels('np'):
    print(f'"{label.data_point.text}" {label.value}')
```

This will print:

```
Sentence[5]: "Mr Smith loves New York" → ["Mr"/NNP, "Mr Smith"/NP, "Smith"/NNP, "loves"/VBZ, "loves"/VP, "New"/NNP, "New York"/NP, "York"/NNP]

Part of speech tags are:
"Mr" NNP
"Smith" NNP
"loves" VBZ
"New" NNP
"York" NNP

Chunks are:
"Mr Smith" NP
"loves" VP
"New York" NP
```
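
If you prefer to iterate token by token instead, something like the following should also work (a sketch assuming
Flair's `get_label` accessor, which falls back to the value `'O'` when a token carries no tag of the requested type).
Note that the chunk tags attach to spans such as "Mr Smith" rather than to single tokens, which is why the loop above
reads them from the sentence:

```python
# read the POS tag stored directly on each token
for token in sentence:
    print(f"{token.text}: {token.get_label('pos').value}")
```
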

## Example 2: A token-level and a document-level task

In some cases, you may want to train a multitask model using [TransformerWordEmbeddings](#flair.embeddings.transformer.TransformerWordEmbeddings) (token-level embeddings)
and [TransformerDocumentEmbeddings](#flair.embeddings.transformer.TransformerDocumentEmbeddings) (text-level embeddings). For instance, you may want to train a model that can both
detect topics and entities in online news articles.

The code is similar to example 1, but you need the more general [TransformerEmbeddings](#flair.embeddings.transformer.TransformerEmbeddings), which can produce both token- and text-level
embeddings. You also need two different model classes: a [TextClassifier](#flair.models.TextClassifier) for predicting topics and a [TokenClassifier](#flair.models.TokenClassifier) for
predicting NER tags:

```python
from flair.embeddings import TransformerEmbeddings
from flair.datasets import AGNEWS, CLEANCONLL
from flair.models import TokenClassifier, TextClassifier
from flair.trainers import ModelTrainer
from flair.nn.multitask import make_multitask_model_and_corpus

# --- Embeddings that are shared by both models --- #
# use a transformer that can do both: sentence-embedding and word-embedding
shared_embedding = TransformerEmbeddings("distilbert-base-uncased",
                                         fine_tune=True,
                                         is_token_embedding=True,
                                         is_document_embedding=True)

# --- Task 1: Newswire topics, use a TextClassifier for this task --- #
corpus_1 = AGNEWS().downsample(0.01)

model_1 = TextClassifier(shared_embedding,
                         label_dictionary=corpus_1.make_label_dictionary("topic"),
                         label_type="topic")

# --- Task 2: Named entities on newswire data, use a TokenClassifier for this task --- #
corpus_2 = CLEANCONLL().downsample(0.1)

model_2 = TokenClassifier(shared_embedding,
                          label_dictionary=corpus_2.make_label_dictionary("ner"),
                          label_type="ner")

# --- Define mapping (which model should train on which corpus) --- #
multitask_model, multicorpus = make_multitask_model_and_corpus(
    [
        (model_1, corpus_1),
        (model_2, corpus_2),
    ]
)

# --- Create model trainer and train --- #
trainer = ModelTrainer(multitask_model, multicorpus)
trainer.fine_tune("resources/taggers/multitask_news_data")
```
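
Because the shared embedding is both a token embedding and a document embedding, embedding a sentence populates one
sentence-level vector plus one vector per token. A quick way to see this is a check like the following (a sketch; the
exact vector sizes depend on the transformer used):

```python
from flair.data import Sentence
from flair.embeddings import TransformerEmbeddings

embedding = TransformerEmbeddings("distilbert-base-uncased",
                                  is_token_embedding=True,
                                  is_document_embedding=True)

sentence = Sentence("IBM made a lot of profit.")
embedding.embed(sentence)

# one vector for the whole sentence (used by the TextClassifier) ...
print(sentence.get_embedding().shape)
# ... and one vector per token (used by the TokenClassifier)
print(sentence[0].get_embedding().shape)
```
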

### Prediction

Similar to example 1, you can load the model and predict labels for both tasks with a single predict call:

```python
from flair.data import Sentence
from flair.models import MultitaskModel

# load the trained multitask model
model = MultitaskModel.load("resources/taggers/multitask_news_data/final-model.pt")

# create example sentence
sentence = Sentence("IBM made a lot of profit.")

# predict both the topic and the entities
model.predict(sentence)

# print the sentence with predicted topic and entity tag
print(sentence)

This prints:

```
Sentence[7]: "IBM made a lot of profit." → Business (0.8645) → ["IBM"/ORG]
```

This shows that both the topic "Business" (with confidence 0.8645) and the entity "IBM" (of type ORG) were detected in this sentence.
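
As in example 1, you can also inspect each task's labels separately by their label type:

```python
# sentence-level topic prediction
print(sentence.get_labels('topic'))

# token-level entity predictions
print(sentence.get_labels('ner'))
```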