@@ -1,6 +1,7 @@
import os
import json
import logging
+ import datasets
from spacy.tokens import Doc
from datetime import datetime
from typing import Iterable, Iterator, Optional, Dict, List, cast, Union
@@ -18,7 +19,7 @@

from transformers import Trainer, AutoModelForTokenClassification, AutoTokenizer
from transformers import pipeline, TrainingArguments
- import datasets
+ from transformers.trainer_callback import TrainerCallback

# It should be safe to do this always, as all other multiprocessing
# will be finished before data comes to meta_cat
@@ -137,7 +138,12 @@ def merge_data_loaded(base, other):

        return out_path

-     def train(self, json_path: Union[str, list, None]=None, ignore_extra_labels=False, dataset=None, meta_requirements=None):
+     def train(self,
+               json_path: Union[str, list, None]=None,
+               ignore_extra_labels=False,
+               dataset=None,
+               meta_requirements=None,
+               trainer_callbacks: Optional[List[TrainerCallback]]=None):
"""Train or continue training a model give a json_path containing a MedCATtrainer export. It will
142
148
continue training if an existing model is loaded or start new training if the model is blank/new.
143
149
@@ -149,6 +155,9 @@ def train(self, json_path: Union[str, list, None]=None, ignore_extra_labels=Fals
            ignore_extra_labels:
                Only makes sense when an existing deid model was loaded and we want to ignore labels in the
                new data that did not exist in the old model.
+             trainer_callbacks (List[TrainerCallback]):
+                 A list of trainer callbacks for collecting metrics on the client side during training. The
+                 transformers Trainer object is passed to each callback when it is instantiated.
        """

        if dataset is None and json_path is not None:
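The new trainer_callbacks parameter pairs with the hunk below, where each list element is invoked as callback(trainer). Each element should therefore be a TrainerCallback subclass (or any factory callable) whose constructor accepts the Trainer instance. A minimal sketch of such a client-side callback; the class name MetricsCollector and its attributes are illustrative, not part of this change:

    from transformers.trainer_callback import TrainerCallback

    class MetricsCollector(TrainerCallback):
        # Hypothetical client-side callback. The constructor receives the
        # Trainer instance, matching the callback(trainer) call in the
        # hunk below.
        def __init__(self, trainer):
            self.trainer = trainer
            self.collected = []

        def on_log(self, args, state, control, logs=None, **kwargs):
            # on_log fires each time the Trainer logs metrics
            # (training loss, eval results, learning rate, ...).
            if logs is not None:
                self.collected.append(dict(logs))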
@@ -193,6 +202,9 @@ def train(self, json_path: Union[str, list, None]=None, ignore_extra_labels=Fals
                          compute_metrics=lambda p: metrics(p, tokenizer=self.tokenizer, dataset=encoded_dataset['test'], verbose=self.config.general['verbose_metrics']),
                          data_collator=data_collator,  # type: ignore
                          tokenizer=None)
+         if trainer_callbacks:
+             for callback in trainer_callbacks:
+                 trainer.add_callback(callback(trainer))

        trainer.train()  # type: ignore
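A hedged usage sketch against the patched signature; deid_model and the export path are placeholders for whatever instance and MedCATtrainer export the client actually has:

    # The list holds the callback class (a factory), not an instance:
    # train() instantiates each entry as callback(trainer), so the
    # callback can keep a reference to the underlying Trainer.
    deid_model.train(json_path="medcat_trainer_export.json",
                     trainer_callbacks=[MetricsCollector])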