-
Notifications
You must be signed in to change notification settings - Fork 105
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
CU-8697qfvzz train metacat on sup train #516
Changes from 2 commits
c84dcde
6ec9abb
66be372
5a1c8cb
06978de
2b2f259
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -40,6 +40,7 @@ | |
from medcat.utils.saving.serializer import SPECIALITY_NAMES, ONE2MANY | ||
from medcat.utils.saving.envsnapshot import get_environment_info, ENV_SNAPSHOT_FILE_NAME | ||
from medcat.stats.stats import get_stats | ||
from medcat.stats.mctexport import count_all_annotations, iter_anns | ||
from medcat.utils.filters import set_project_filters | ||
from medcat.utils.usage_monitoring import UsageMonitor | ||
|
||
|
@@ -808,7 +809,8 @@ def train_supervised_from_json(self, | |
retain_extra_cui_filter: bool = False, | ||
checkpoint: Optional[Checkpoint] = None, | ||
retain_filters: bool = False, | ||
is_resumed: bool = False) -> Tuple: | ||
is_resumed: bool = False, | ||
train_meta_cats: bool = False) -> Tuple: | ||
""" | ||
Run supervised training on a dataset from MedCATtrainer in JSON format. | ||
|
||
|
@@ -825,7 +827,7 @@ def train_supervised_from_json(self, | |
devalue_others, use_groups, never_terminate, | ||
train_from_false_positives, extra_cui_filter, | ||
retain_extra_cui_filter, checkpoint, | ||
retain_filters, is_resumed) | ||
retain_filters, is_resumed, train_meta_cats) | ||
|
||
def train_supervised_raw(self, | ||
data: Dict[str, List[Dict[str, dict]]], | ||
|
@@ -845,7 +847,8 @@ def train_supervised_raw(self, | |
retain_extra_cui_filter: bool = False, | ||
checkpoint: Optional[Checkpoint] = None, | ||
retain_filters: bool = False, | ||
is_resumed: bool = False) -> Tuple: | ||
is_resumed: bool = False, | ||
train_meta_cats: bool = False) -> Tuple: | ||
"""Train supervised based on the raw data provided. | ||
|
||
The raw data is expected in the following format: | ||
|
@@ -922,6 +925,8 @@ def train_supervised_raw(self, | |
a ValueError is raised. The merging is done in the first epoch. | ||
is_resumed (bool): | ||
If True resume the previous training; If False, start a fresh new training. | ||
train_meta_cats (bool): | ||
If True, also trains the appropriate MetaCATs. | ||
|
||
Raises: | ||
ValueError: If attempting to retain filters while training over multiple projects. | ||
|
@@ -1081,6 +1086,20 @@ def train_supervised_raw(self, | |
use_overlaps=use_overlaps, | ||
use_groups=use_groups, | ||
extra_cui_filter=extra_cui_filter) | ||
if (train_meta_cats and | ||
# NOTE: if no annotations, no point | ||
count_all_annotations(data) > 0): # type: ignore | ||
# NOTE: if there | ||
logger.info("Training MetaCATs within train_supervised_raw") | ||
_, _, ann0 = next(iter_anns(data)) # type: ignore | ||
for meta_cat in self._meta_cats: | ||
# only consider meta-cats that have been defined for the category | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The logic that does the checking if the category can be exchanged with a different name, if that's available on There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yeah, I did have a note about this in the PR comment:
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Though it doesn't look like it's easily callable. Would need some kind of general-use category name resolver. |
||
# NOTE: as of PR #515 this may become more complicated since it could work | ||
# without this exact match as well | ||
cat_name = meta_cat.config.general.category_name | ||
if 'meta_anns' in ann0 and cat_name in ann0['meta_anns']: # type: ignore | ||
logger.debug("Training MetaCAT %s", meta_cat.config.general.category_name) | ||
meta_cat.train_raw(data) | ||
|
||
# reset the state of filters | ||
self.config.linking.filters = orig_filters | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
typo in the comment?
no point