Skip to content

Commit e0fa3af

Browse files
authored
CU-8697qfvzz train metacat on sup train (#516)
* CU-8697qfvzz: Add new optional keyword argument to allow training MetaCAT models during supervised training * CU-8697qfvzz: Add tests regarding training meta-cats during supervised training * CU-8697qfvzz: Fix small typo in comment * CU-8697qfvzz: Allow using alternative category names if/when training meta cats through CAT.train_supervised
1 parent 6f3403e commit e0fa3af

File tree

2 files changed

+68
-4
lines changed

2 files changed

+68
-4
lines changed

medcat/cat.py

+23-3
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from medcat.utils.saving.serializer import SPECIALITY_NAMES, ONE2MANY
4141
from medcat.utils.saving.envsnapshot import get_environment_info, ENV_SNAPSHOT_FILE_NAME
4242
from medcat.stats.stats import get_stats
43+
from medcat.stats.mctexport import count_all_annotations, iter_anns
4344
from medcat.utils.filters import set_project_filters
4445
from medcat.utils.usage_monitoring import UsageMonitor
4546

@@ -808,7 +809,8 @@ def train_supervised_from_json(self,
808809
retain_extra_cui_filter: bool = False,
809810
checkpoint: Optional[Checkpoint] = None,
810811
retain_filters: bool = False,
811-
is_resumed: bool = False) -> Tuple:
812+
is_resumed: bool = False,
813+
train_meta_cats: bool = False) -> Tuple:
812814
"""
813815
Run supervised training on a dataset from MedCATtrainer in JSON format.
814816
@@ -825,7 +827,7 @@ def train_supervised_from_json(self,
825827
devalue_others, use_groups, never_terminate,
826828
train_from_false_positives, extra_cui_filter,
827829
retain_extra_cui_filter, checkpoint,
828-
retain_filters, is_resumed)
830+
retain_filters, is_resumed, train_meta_cats)
829831

830832
def train_supervised_raw(self,
831833
data: Dict[str, List[Dict[str, dict]]],
@@ -845,7 +847,8 @@ def train_supervised_raw(self,
845847
retain_extra_cui_filter: bool = False,
846848
checkpoint: Optional[Checkpoint] = None,
847849
retain_filters: bool = False,
848-
is_resumed: bool = False) -> Tuple:
850+
is_resumed: bool = False,
851+
train_meta_cats: bool = False) -> Tuple:
849852
"""Train supervised based on the raw data provided.
850853
851854
The raw data is expected in the following format:
@@ -922,6 +925,8 @@ def train_supervised_raw(self,
922925
a ValueError is raised. The merging is done in the first epoch.
923926
is_resumed (bool):
924927
If True resume the previous training; If False, start a fresh new training.
928+
train_meta_cats (bool):
929+
If True, also trains the appropriate MetaCATs.
925930
926931
Raises:
927932
ValueError: If attempting to retain filters while training over multiple projects.
@@ -1081,6 +1086,21 @@ def train_supervised_raw(self,
10811086
use_overlaps=use_overlaps,
10821087
use_groups=use_groups,
10831088
extra_cui_filter=extra_cui_filter)
1089+
if (train_meta_cats and
1090+
# NOTE: if there are no annotations, there is no point in training
1091+
count_all_annotations(data) > 0): # type: ignore
1092+
# NOTE: the first annotation is used to find out which meta-annotation names are present
1093+
logger.info("Training MetaCATs within train_supervised_raw")
1094+
_, _, ann0 = next(iter_anns(data)) # type: ignore
1095+
for meta_cat in self._meta_cats:
1096+
# only consider meta-cats that have been defined for the category
1097+
if 'meta_anns' in ann0:
1098+
ann_names = ann0['meta_anns'].keys() # type: ignore
1099+
# adapt to alternative names if applicable
1100+
cat_name = meta_cat.config.general.get_applicable_category_name(ann_names)
1101+
if cat_name in ann_names:
1102+
logger.debug("Training MetaCAT %s", meta_cat.config.general.category_name)
1103+
meta_cat.train_raw(data)
10841104

10851105
# reset the state of filters
10861106
self.config.linking.filters = orig_filters

tests/test_cat.py

+45-1
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,27 @@ def test_train_supervised_can_retain_MCT_filters(self, extra_cui_filter=None, re
383383
with self.subTest(f'CUI: {filtered_cui}'):
384384
self.assertTrue(filtered_cui in self.undertest.config.linking.filters.cuis)
385385

386+
def _test_train_sup_with_meta_cat(self, train_meta_cats: bool):
    """Check whether supervised training triggers MetaCAT training.

    Builds a CAT with a single MetaCAT, runs ``train_supervised_raw`` while
    ``MetaCAT.train_raw`` is mocked out, and verifies that the mock was
    called exactly when ``train_meta_cats`` is True.

    Args:
        train_meta_cats (bool): Value forwarded to ``train_supervised_raw``;
            also decides which assertion is made on the mock.
    """
    trained_meta_cat = _get_meta_cat(self.meta_cat_dir)
    undertest = CAT(cdb=self.cdb, config=self.cdb.config, vocab=self.vocab,
                    meta_cats=[trained_meta_cat])
    # ``__call__`` is stubbed to hand the document straight back, so only the
    # (mocked) ``train_raw`` entry point is of interest here.
    with patch.object(MetaCAT, "train_raw") as mock_train, \
            patch.object(MetaCAT, "__call__", side_effect=lambda doc: doc):
        undertest.train_supervised_raw(get_fixed_meta_cat_data(),
                                       never_terminate=True,
                                       train_meta_cats=train_meta_cats)
    verify = mock_train.assert_called if train_meta_cats else mock_train.assert_not_called
    verify()
400+
401+
def test_train_supervised_does_not_train_meta_cat_by_default(self):
    """With ``train_meta_cats=False`` no MetaCAT training should happen."""
    self._test_train_sup_with_meta_cat(train_meta_cats=False)
403+
404+
def test_train_supervised_can_train_meta_cats(self):
    """With ``train_meta_cats=True`` the MetaCAT training should be invoked."""
    self._test_train_sup_with_meta_cat(train_meta_cats=True)
406+
386407
def test_train_supervised_no_leak_extra_cui_filters(self):
387408
self.test_train_supervised_does_not_retain_MCT_filters_default(extra_cui_filter={'C123', 'C111'})
388409

@@ -799,6 +820,9 @@ def test_loading_model_pack_without_any_config_raises_exception(self):
799820
CAT.load_model_pack(self.temp_dir.name)
800821

801822

823+
# Location of the JSON fixture shared by the MetaCAT helpers in this module;
# resolved relative to this file so it does not depend on the working directory.
META_CAT_JSON_PATH = os.path.join(
    os.path.dirname(os.path.realpath(__file__)),
    "resources",
    "mct_export_for_meta_cat_test.json",
)
824+
825+
802826
def _get_meta_cat(meta_cat_dir):
803827
config = ConfigMetaCAT()
804828
config.general["category_name"] = "Status"
@@ -808,11 +832,31 @@ def _get_meta_cat(meta_cat_dir):
808832
embeddings=None,
809833
config=config)
810834
os.makedirs(meta_cat_dir, exist_ok=True)
811-
json_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "resources", "mct_export_for_meta_cat_test.json")
835+
json_path = META_CAT_JSON_PATH
812836
meta_cat.train_from_json(json_path, save_dir_path=meta_cat_dir)
813837
return meta_cat
814838

815839

840+
def get_fixed_meta_cat_data(path: str = META_CAT_JSON_PATH):
    """Load a JSON export and fill in the fields it may be missing.

    The loaded structure is patched in place:

    * every project gets a ``name``, ``cuis`` and ``id`` if absent;
    * a document that has ``entities`` but no ``annotations`` has the
      values of ``entities`` moved into an ``annotations`` list;
    * an annotation that has ``pretty_name`` but no ``value`` has the
      former renamed to the latter.

    Args:
        path (str): Path of the JSON file to read.

    Returns:
        The parsed (and patched) JSON content.

    Raises:
        KeyError: If a document has neither ``entities`` nor ``annotations``.
    """
    with open(path) as json_file:
        export = json.load(json_file)
    for index, proj in enumerate(export['projects']):
        # setdefault only writes when the key is not already there
        proj.setdefault('name', f"Proj_{index}")
        proj.setdefault('cuis', '')
        proj.setdefault('id', f'P{index}')
        for document in proj['documents']:
            if 'annotations' not in document and 'entities' in document:
                document['annotations'] = list(document.pop("entities").values())
            for annotation in document['annotations']:
                if 'value' in annotation or 'pretty_name' not in annotation:
                    continue
                annotation['value'] = annotation.pop('pretty_name')
    return export
858+
859+
816860
class TestLoadingOldWeights(unittest.TestCase):
817861
cdb_path = os.path.join(os.path.dirname(os.path.realpath(__file__)),
818862
"..", "examples", "cdb_old_broken_weights_in_config.dat")

0 commit comments

Comments
 (0)