@@ -383,6 +383,27 @@ def test_train_supervised_can_retain_MCT_filters(self, extra_cui_filter=None, re
383383 with self .subTest (f'CUI: { filtered_cui } ' ):
384384 self .assertTrue (filtered_cui in self .undertest .config .linking .filters .cuis )
385385
386+ def _test_train_sup_with_meta_cat (self , train_meta_cats : bool ):
387+ # def side_effect(doc, *args, **kwargs):
388+ # raise ValueError()
389+ # # return doc
390+ meta_cat = _get_meta_cat (self .meta_cat_dir )
391+ cat = CAT (cdb = self .cdb , config = self .cdb .config , vocab = self .vocab , meta_cats = [meta_cat ])
392+ with patch .object (MetaCAT , "train_raw" ) as mock_train :
393+ with patch .object (MetaCAT , "__call__" , side_effect = lambda doc : doc ):
394+ cat .train_supervised_raw (get_fixed_meta_cat_data (), never_terminate = True ,
395+ train_meta_cats = train_meta_cats )
396+ if train_meta_cats :
397+ mock_train .assert_called ()
398+ else :
399+ mock_train .assert_not_called ()
400+
401+ def test_train_supervised_does_not_train_meta_cat_by_default (self ):
402+ self ._test_train_sup_with_meta_cat (False )
403+
404+ def test_train_supervised_can_train_meta_cats (self ):
405+ self ._test_train_sup_with_meta_cat (True )
406+
386407 def test_train_supervised_no_leak_extra_cui_filters (self ):
387408 self .test_train_supervised_does_not_retain_MCT_filters_default (extra_cui_filter = {'C123' , 'C111' })
388409
@@ -799,6 +820,9 @@ def test_loading_model_pack_without_any_config_raises_exception(self):
799820 CAT .load_model_pack (self .temp_dir .name )
800821
801822
823+ META_CAT_JSON_PATH = os .path .join (os .path .dirname (os .path .realpath (__file__ )), "resources" , "mct_export_for_meta_cat_test.json" )
824+
825+
802826def _get_meta_cat (meta_cat_dir ):
803827 config = ConfigMetaCAT ()
804828 config .general ["category_name" ] = "Status"
@@ -808,11 +832,31 @@ def _get_meta_cat(meta_cat_dir):
808832 embeddings = None ,
809833 config = config )
810834 os .makedirs (meta_cat_dir , exist_ok = True )
811- json_path = os . path . join ( os . path . dirname ( os . path . realpath ( __file__ )), "resources" , "mct_export_for_meta_cat_test.json" )
835+ json_path = META_CAT_JSON_PATH
812836 meta_cat .train_from_json (json_path , save_dir_path = meta_cat_dir )
813837 return meta_cat
814838
815839
840+ def get_fixed_meta_cat_data (path : str = META_CAT_JSON_PATH ):
841+ with open (path ) as f :
842+ data = json .load (f )
843+ for proj_num , project in enumerate (data ['projects' ]):
844+ if 'name' not in project :
845+ project ['name' ] = f"Proj_{ proj_num } "
846+ if 'cuis' not in project :
847+ project ['cuis' ] = ''
848+ if 'id' not in project :
849+ project ['id' ] = f'P{ proj_num } '
850+ for doc in project ['documents' ]:
851+ if 'entities' in doc and 'annotations' not in doc :
852+ ents = doc .pop ("entities" )
853+ doc ['annotations' ] = list (ents .values ())
854+ for ann in doc ['annotations' ]:
855+ if 'pretty_name' in ann and 'value' not in ann :
856+ ann ['value' ] = ann .pop ('pretty_name' )
857+ return data
858+
859+
816860class TestLoadingOldWeights (unittest .TestCase ):
817861 cdb_path = os .path .join (os .path .dirname (os .path .realpath (__file__ )),
818862 ".." , "examples" , "cdb_old_broken_weights_in_config.dat" )
0 commit comments