Skip to content
This repository was archived by the owner on Jul 28, 2025. It is now read-only.

Commit fc992a7

Browse files
committed
CU-86938vf30 add trainer callbacks for Transformer NER
1 parent 76b75cc commit fc992a7

File tree

2 files changed

+60
-2
lines changed

2 files changed

+60
-2
lines changed

medcat/ner/transformers_ner.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
import json
33
import logging
4+
import datasets
45
from spacy.tokens import Doc
56
from datetime import datetime
67
from typing import Iterable, Iterator, Optional, Dict, List, cast, Union
@@ -18,7 +19,7 @@
1819

1920
from transformers import Trainer, AutoModelForTokenClassification, AutoTokenizer
2021
from transformers import pipeline, TrainingArguments
21-
import datasets
22+
from transformers.trainer_callback import TrainerCallback
2223

2324
# It should be safe to do this always, as all other multiprocessing
2425
#will be finished before data comes to meta_cat
@@ -137,7 +138,12 @@ def merge_data_loaded(base, other):
137138

138139
return out_path
139140

140-
def train(self, json_path: Union[str, list, None]=None, ignore_extra_labels=False, dataset=None, meta_requirements=None):
141+
def train(self,
142+
json_path: Union[str, list, None]=None,
143+
ignore_extra_labels=False,
144+
dataset=None,
145+
meta_requirements=None,
146+
trainer_callbacks: Optional[List[TrainerCallback]]=None):
141147
"""Train or continue training a model give a json_path containing a MedCATtrainer export. It will
142148
continue training if an existing model is loaded or start new training if the model is blank/new.
143149
@@ -149,6 +155,9 @@ def train(self, json_path: Union[str, list, None]=None, ignore_extra_labels=Fals
149155
ignore_extra_labels:
150156
Makes only sense when an existing deid model was loaded and from the new data we want to ignore
151157
labels that did not exist in the old model.
158+
trainer_callbacks (List[TrainerCallback]):
159+
A list of trainer callbacks for collecting metrics during the training at the client side. The
160+
transformers Trainer object will be passed in when each callback is called.
152161
"""
153162

154163
if dataset is None and json_path is not None:
@@ -193,6 +202,9 @@ def train(self, json_path: Union[str, list, None]=None, ignore_extra_labels=Fals
193202
compute_metrics=lambda p: metrics(p, tokenizer=self.tokenizer, dataset=encoded_dataset['test'], verbose=self.config.general['verbose_metrics']),
194203
data_collator=data_collator, # type: ignore
195204
tokenizer=None)
205+
if trainer_callbacks:
206+
for callback in trainer_callbacks:
207+
trainer.add_callback(callback(trainer))
196208

197209
trainer.train() # type: ignore
198210

tests/ner/test_transformers_ner.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import os
2+
import unittest
3+
from spacy.lang.en import English
4+
from spacy.tokens import Doc, Span
5+
from transformers import TrainerCallback
6+
from medcat.ner.transformers_ner import TransformersNER
7+
from medcat.config import Config
8+
from medcat.cdb_maker import CDBMaker
9+
10+
11+
class TransformerNERTest(unittest.TestCase):
12+
13+
@classmethod
14+
def setUpClass(cls) -> None:
15+
config = Config()
16+
config.general["spacy_model"] = "en_core_web_md"
17+
cdb_maker = CDBMaker(config)
18+
cdb_csv = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..", "examples", "cdb.csv")
19+
cdb = cdb_maker.prepare_csvs([cdb_csv], full_build=True)
20+
Doc.set_extension("ents", default=[], force=True)
21+
Span.set_extension("confidence", default=-1, force=True)
22+
Span.set_extension("id", default=0, force=True)
23+
Span.set_extension("detected_name", default=None, force=True)
24+
Span.set_extension("link_candidates", default=None, force=True)
25+
Span.set_extension("cui", default=-1, force=True)
26+
Span.set_extension("context_similarity", default=-1, force=True)
27+
cls.undertest = TransformersNER(cdb)
28+
cls.undertest.create_eval_pipeline()
29+
30+
def test_pipe(self):
31+
doc = English().make_doc("Intracerebral hemorrhage is not Movar Virus")
32+
doc = next(self.undertest.pipe([doc]))
33+
assert len(doc.ents) > 0, "No entities were recognised"
34+
35+
def test_train_with_callbacks(self):
36+
tracker = unittest.mock.Mock()
37+
class _DummyCallback(TrainerCallback):
38+
def __init__(self, trainer) -> None:
39+
self._trainer = trainer
40+
def on_epoch_end(self, *args, **kwargs) -> None:
41+
tracker.call()
42+
43+
train_data = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "resources", "deid_train_data.json")
44+
self.undertest.training_arguments.num_train_epochs = 1
45+
self.undertest.train(train_data, trainer_callbacks=[_DummyCallback, _DummyCallback])
46+
self.assertEqual(tracker.call.call_count, 2)

0 commit comments

Comments
 (0)