Skip to content
This repository was archived by the owner on Jul 28, 2025. It is now read-only.

Commit 85cbe77

Browse files
authored
add: metacat can predict on spans in arbitrary spangroups (#391)
* add: ability to predict on other spangroups * add: pr comments and better error * fix: typo * fix: linting
1 parent 4de8931 commit 85cbe77

File tree

3 files changed

+67
-12
lines changed

3 files changed

+67
-12
lines changed

medcat/config_meta_cat.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ class General(MixingConfig, BaseModel):
3737
a deployment."""
3838
pipe_batch_size_in_chars: int = 20000000
3939
"""How many characters are piped at once into the meta_cat class"""
40+
span_group: Optional[str] = None
41+
"""If set, the spacy span group that the metacat model will assign annotations.
42+
Otherwise defaults to doc._.ents or doc.ents per the annotate_overlapping settings"""
4043

4144
class Config:
4245
extra = Extra.allow

medcat/meta_cat.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import numpy
66
from multiprocessing import Lock
77
from torch import nn, Tensor
8-
from spacy.tokens import Doc
8+
from spacy.tokens import Doc, Span
99
from datetime import datetime
1010
from typing import Iterable, Iterator, Optional, Dict, List, Tuple, cast, Union
1111
from medcat.utils.hasher import Hasher
@@ -357,6 +357,20 @@ def load(cls, save_dir_path: str, config_dict: Optional[Dict] = None) -> "MetaCA
357357

358358
return meta_cat
359359

360+
def get_ents(self, doc: Doc) -> Iterable[Span]:
    """Select the spans of `doc` that this MetaCAT instance annotates.

    If `config.general.span_group` is set, the spans stored under that key
    in `doc.spans` are used. Otherwise `doc._.ents` is returned when
    `config.general['annotate_overlapping']` is truthy, else `doc.ents`.

    Args:
        doc (Doc): The document to take the spans from.

    Returns:
        Iterable[Span]: The spans to annotate.

    Raises:
        Exception: If a span group is configured but is not present on `doc`.
    """
    general = self.config.general
    spangroup_name = general.span_group
    if spangroup_name:
        try:
            return doc.spans[spangroup_name]
        except KeyError as err:
            # Chain explicitly so the traceback reports the missing key as the direct cause.
            raise Exception(f"Configuration error MetaCAT was configured to set meta_anns on {spangroup_name} but this spangroup was not set on the doc.") from err

    # Should we annotate overlapping entities
    if general['annotate_overlapping']:
        return doc._.ents

    return doc.ents
373+
360374
def prepare_document(self, doc: Doc, input_ids: List, offset_mapping: List, lowercase: bool) -> Tuple:
361375
"""Prepares document.
362376
@@ -381,11 +395,7 @@ def prepare_document(self, doc: Doc, input_ids: List, offset_mapping: List, lowe
381395
cntx_right = config.general['cntx_right']
382396
replace_center = config.general['replace_center']
383397

384-
# Should we annotate overlapping entities
385-
if config.general['annotate_overlapping']:
386-
ents = doc._.ents
387-
else:
388-
ents = doc.ents
398+
ents = self.get_ents(doc)
389399

390400
samples = []
391401
last_ind = 0
@@ -522,10 +532,7 @@ def _set_meta_anns(self,
522532

523533
predictions = all_predictions[start_ind:end_ind]
524534
confidences = all_confidences[start_ind:end_ind]
525-
if config.general['annotate_overlapping']:
526-
ents = doc._.ents
527-
else:
528-
ents = doc.ents
535+
ents = self.get_ents(doc)
529536

530537
for ent in ents:
531538
ent_ind = ent_id2ind[ent._.id]

tests/test_meta_cat.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
from medcat.meta_cat import MetaCAT
88
from medcat.config_meta_cat import ConfigMetaCAT
99
from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperBERT
10-
10+
import spacy
11+
from spacy.tokens import Span
1112

1213
class MetaCATTests(unittest.TestCase):
1314

@@ -19,7 +20,7 @@ def setUpClass(cls) -> None:
1920
config.train['nepochs'] = 1
2021
config.model['input_size'] = 100
2122

22-
cls.meta_cat = MetaCAT(tokenizer=tokenizer, embeddings=None, config=config)
23+
cls.meta_cat: MetaCAT = MetaCAT(tokenizer=tokenizer, embeddings=None, config=config)
2324

2425
cls.tmp_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "tmp")
2526
os.makedirs(cls.tmp_dir, exist_ok=True)
@@ -44,6 +45,50 @@ def test_save_load(self):
4445

4546
self.assertEqual(f1, n_f1)
4647

48+
def _prepare_doc_w_spangroup(self, spangroup_name: str):
    """Create a small doc with two spans stored under an arbitrary spangroup key.

    Args:
        spangroup_name (str): Key in `doc.spans` under which the spans are stored.

    Returns:
        The blank-English doc for "Pt has diabetes and copd." with the spans
        "diabetes" and "copd" set on `doc.spans[spangroup_name]`.
    """
    # Register the span extensions used by this test; force=True allows re-registration.
    Span.set_extension('id', default=0, force=True)
    Span.set_extension('meta_anns', default=None, force=True)
    nlp = spacy.blank("en")
    doc = nlp("Pt has diabetes and copd.")

    # Use unittest assertions rather than bare `assert` so the checks are not
    # stripped under -O and a None span fails with a clear message.
    span_0 = doc.char_span(7, 15, label="diabetes")
    self.assertIsNotNone(span_0)
    self.assertEqual(span_0.text, 'diabetes')

    span_1 = doc.char_span(20, 24, label="copd")
    self.assertIsNotNone(span_1)
    self.assertEqual(span_1.text, 'copd')

    doc.spans[spangroup_name] = [span_0, span_1]
    return doc
63+
64+
def test_predict_spangroup(self):
    """Check that MetaCAT annotates spans held in a configured span group.

    Also checks that a configured span group missing from the doc raises an
    error whose message contains "Configuration error".
    """
    json_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'resources', 'mct_export_for_meta_cat_test.json')
    self.meta_cat.train(json_path, save_dir_path=self.tmp_dir)
    self.meta_cat.save(self.tmp_dir)
    n_meta_cat = MetaCAT.load(self.tmp_dir)

    spangroup_name = "mock_span_group"
    n_meta_cat.config.general.span_group = spangroup_name

    doc = self._prepare_doc_w_spangroup(spangroup_name)
    doc = n_meta_cat(doc)
    spans = doc.spans[spangroup_name]
    self.assertEqual(len(spans), 2)

    # All spans are annotated
    for span in spans:
        self.assertEqual(span._.meta_anns['Status']['value'], "Affirmed")

    # Informative error if spangroup is not set.
    # assertRaises makes the test fail when no error is raised at all;
    # a try/except with the assertion in the handler would pass silently.
    doc = self._prepare_doc_w_spangroup("foo")
    n_meta_cat.config.general.span_group = "bar"
    with self.assertRaises(Exception) as raised:
        n_meta_cat(doc)
    self.assertIn("Configuration error", str(raised.exception))

    n_meta_cat.config.general.span_group = None
91+
4792

4893
if __name__ == '__main__':
4994
unittest.main()

0 commit comments

Comments
 (0)