40
40
from medcat .utils .saving .serializer import SPECIALITY_NAMES , ONE2MANY
41
41
from medcat .utils .saving .envsnapshot import get_environment_info , ENV_SNAPSHOT_FILE_NAME
42
42
from medcat .stats .stats import get_stats
43
+ from medcat .stats .mctexport import count_all_annotations , iter_anns
43
44
from medcat .utils .filters import set_project_filters
44
45
from medcat .utils .usage_monitoring import UsageMonitor
45
46
@@ -808,7 +809,8 @@ def train_supervised_from_json(self,
808
809
retain_extra_cui_filter : bool = False ,
809
810
checkpoint : Optional [Checkpoint ] = None ,
810
811
retain_filters : bool = False ,
811
- is_resumed : bool = False ) -> Tuple :
812
+ is_resumed : bool = False ,
813
+ train_meta_cats : bool = False ) -> Tuple :
812
814
"""
813
815
Run supervised training on a dataset from MedCATtrainer in JSON format.
814
816
@@ -825,7 +827,7 @@ def train_supervised_from_json(self,
825
827
devalue_others , use_groups , never_terminate ,
826
828
train_from_false_positives , extra_cui_filter ,
827
829
retain_extra_cui_filter , checkpoint ,
828
- retain_filters , is_resumed )
830
+ retain_filters , is_resumed , train_meta_cats )
829
831
830
832
def train_supervised_raw (self ,
831
833
data : Dict [str , List [Dict [str , dict ]]],
@@ -845,7 +847,8 @@ def train_supervised_raw(self,
845
847
retain_extra_cui_filter : bool = False ,
846
848
checkpoint : Optional [Checkpoint ] = None ,
847
849
retain_filters : bool = False ,
848
- is_resumed : bool = False ) -> Tuple :
850
+ is_resumed : bool = False ,
851
+ train_meta_cats : bool = False ) -> Tuple :
849
852
"""Train supervised based on the raw data provided.
850
853
851
854
The raw data is expected in the following format:
@@ -922,6 +925,8 @@ def train_supervised_raw(self,
922
925
a ValueError is raised. The merging is done in the first epoch.
923
926
is_resumed (bool):
924
927
If True resume the previous training; If False, start a fresh new training.
928
+ train_meta_cats (bool):
929
+ If True, also trains the appropriate MetaCATs.
925
930
926
931
Raises:
927
932
ValueError: If attempting to retain filters while training over multiple projects.
@@ -1081,6 +1086,21 @@ def train_supervised_raw(self,
1081
1086
use_overlaps = use_overlaps ,
1082
1087
use_groups = use_groups ,
1083
1088
extra_cui_filter = extra_cui_filter )
1089
+ if (train_meta_cats and
1090
+ # NOTE: if there are no annotations, there is nothing to train the MetaCATs on
1091
+ count_all_annotations (data ) > 0 ): # type: ignore
1092
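+ # NOTE: the check above found at least one annotation, so the next() call below is expected to yield a first annotation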
1093
+ logger .info ("Training MetaCATs within train_supervised_raw" )
1094
+ _ , _ , ann0 = next (iter_anns (data )) # type: ignore
1095
+ for meta_cat in self ._meta_cats :
1096
+ # only consider meta-cats that have been defined for the category
1097
+ if 'meta_anns' in ann0 :
1098
+ ann_names = ann0 ['meta_anns' ].keys () # type: ignore
1099
+ # adapt to alternative names if applicable
1100
+ cat_name = meta_cat .config .general .get_applicable_category_name (ann_names )
1101
+ if cat_name in ann_names :
1102
+ logger .debug ("Training MetaCAT %s" , meta_cat .config .general .category_name )
1103
+ meta_cat .train_raw (data )
1084
1104
1085
1105
# reset the state of filters
1086
1106
self .config .linking .filters = orig_filters
0 commit comments