Skip to content
This repository was archived by the owner on Jul 28, 2025. It is now read-only.

Commit 6ec9abb

Browse files
committed
CU-8697qfvzz: Add tests regarding training meta-cats during supervised training
1 parent c84dcde commit 6ec9abb

File tree

1 file changed

+45
-1
lines changed

1 file changed

+45
-1
lines changed

tests/test_cat.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,27 @@ def test_train_supervised_can_retain_MCT_filters(self, extra_cui_filter=None, re
383383
with self.subTest(f'CUI: {filtered_cui}'):
384384
self.assertTrue(filtered_cui in self.undertest.config.linking.filters.cuis)
385385

386+
def _test_train_sup_with_meta_cat(self, train_meta_cats: bool):
387+
# def side_effect(doc, *args, **kwargs):
388+
# raise ValueError()
389+
# # return doc
390+
meta_cat = _get_meta_cat(self.meta_cat_dir)
391+
cat = CAT(cdb=self.cdb, config=self.cdb.config, vocab=self.vocab, meta_cats=[meta_cat])
392+
with patch.object(MetaCAT, "train_raw") as mock_train:
393+
with patch.object(MetaCAT, "__call__", side_effect=lambda doc: doc):
394+
cat.train_supervised_raw(get_fixed_meta_cat_data(), never_terminate=True,
395+
train_meta_cats=train_meta_cats)
396+
if train_meta_cats:
397+
mock_train.assert_called()
398+
else:
399+
mock_train.assert_not_called()
400+
401+
def test_train_supervised_does_not_train_meta_cat_by_default(self):
402+
self._test_train_sup_with_meta_cat(False)
403+
404+
def test_train_supervised_can_train_meta_cats(self):
405+
self._test_train_sup_with_meta_cat(True)
406+
386407
def test_train_supervised_no_leak_extra_cui_filters(self):
387408
self.test_train_supervised_does_not_retain_MCT_filters_default(extra_cui_filter={'C123', 'C111'})
388409

@@ -799,6 +820,9 @@ def test_loading_model_pack_without_any_config_raises_exception(self):
799820
CAT.load_model_pack(self.temp_dir.name)
800821

801822

823+
META_CAT_JSON_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "resources", "mct_export_for_meta_cat_test.json")
824+
825+
802826
def _get_meta_cat(meta_cat_dir):
803827
config = ConfigMetaCAT()
804828
config.general["category_name"] = "Status"
@@ -808,11 +832,31 @@ def _get_meta_cat(meta_cat_dir):
808832
embeddings=None,
809833
config=config)
810834
os.makedirs(meta_cat_dir, exist_ok=True)
811-
json_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "resources", "mct_export_for_meta_cat_test.json")
835+
json_path = META_CAT_JSON_PATH
812836
meta_cat.train_from_json(json_path, save_dir_path=meta_cat_dir)
813837
return meta_cat
814838

815839

840+
def get_fixed_meta_cat_data(path: str = META_CAT_JSON_PATH):
841+
with open(path) as f:
842+
data = json.load(f)
843+
for proj_num, project in enumerate(data['projects']):
844+
if 'name' not in project:
845+
project['name'] = f"Proj_{proj_num}"
846+
if 'cuis' not in project:
847+
project['cuis'] = ''
848+
if 'id' not in project:
849+
project['id'] = f'P{proj_num}'
850+
for doc in project['documents']:
851+
if 'entities' in doc and 'annotations' not in doc:
852+
ents = doc.pop("entities")
853+
doc['annotations'] = list(ents.values())
854+
for ann in doc['annotations']:
855+
if 'pretty_name' in ann and 'value' not in ann:
856+
ann['value'] = ann.pop('pretty_name')
857+
return data
858+
859+
816860
class TestLoadingOldWeights(unittest.TestCase):
817861
cdb_path = os.path.join(os.path.dirname(os.path.realpath(__file__)),
818862
"..", "examples", "cdb_old_broken_weights_in_config.dat")

0 commit comments

Comments
 (0)