Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CU-8697qfvzz train metacat on sup train #516

Merged
merged 6 commits into from
Feb 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 23 additions & 3 deletions medcat/cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from medcat.utils.saving.serializer import SPECIALITY_NAMES, ONE2MANY
from medcat.utils.saving.envsnapshot import get_environment_info, ENV_SNAPSHOT_FILE_NAME
from medcat.stats.stats import get_stats
from medcat.stats.mctexport import count_all_annotations, iter_anns
from medcat.utils.filters import set_project_filters
from medcat.utils.usage_monitoring import UsageMonitor

Expand Down Expand Up @@ -808,7 +809,8 @@ def train_supervised_from_json(self,
retain_extra_cui_filter: bool = False,
checkpoint: Optional[Checkpoint] = None,
retain_filters: bool = False,
is_resumed: bool = False) -> Tuple:
is_resumed: bool = False,
train_meta_cats: bool = False) -> Tuple:
"""
Run supervised training on a dataset from MedCATtrainer in JSON format.

Expand All @@ -825,7 +827,7 @@ def train_supervised_from_json(self,
devalue_others, use_groups, never_terminate,
train_from_false_positives, extra_cui_filter,
retain_extra_cui_filter, checkpoint,
retain_filters, is_resumed)
retain_filters, is_resumed, train_meta_cats)

def train_supervised_raw(self,
data: Dict[str, List[Dict[str, dict]]],
Expand All @@ -845,7 +847,8 @@ def train_supervised_raw(self,
retain_extra_cui_filter: bool = False,
checkpoint: Optional[Checkpoint] = None,
retain_filters: bool = False,
is_resumed: bool = False) -> Tuple:
is_resumed: bool = False,
train_meta_cats: bool = False) -> Tuple:
"""Train supervised based on the raw data provided.

The raw data is expected in the following format:
Expand Down Expand Up @@ -922,6 +925,8 @@ def train_supervised_raw(self,
a ValueError is raised. The merging is done in the first epoch.
is_resumed (bool):
If True resume the previous training; If False, start a fresh new training.
train_meta_cats (bool):
If True, also trains the appropriate MetaCATs.

Raises:
ValueError: If attempting to retain filters with while training over multiple projects.
Expand Down Expand Up @@ -1081,6 +1086,21 @@ def train_supervised_raw(self,
use_overlaps=use_overlaps,
use_groups=use_groups,
extra_cui_filter=extra_cui_filter)
if (train_meta_cats and
# NOTE if no annnotaitons, no point
count_all_annotations(data) > 0): # type: ignore
# NOTE: if there
logger.info("Training MetaCATs within train_supervised_raw")
_, _, ann0 = next(iter_anns(data)) # type: ignore
for meta_cat in self._meta_cats:
# only consider meta-cats that have been defined for the category
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the logic that does the checking if the category can be exchanged with a different name, if thats available on medcat.meta_cat? @shubham-s-agarwal - is the category checking, swapping logic callable?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I did have a note about this in the PR comment:

PS:
With #515 the check for suitable MetaCATs according to category name may need adjusting since the PR allows for alterantives category names.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Though doesn't look like it's easily callable. Would need some kind of general use category name resolver.

if 'meta_anns' in ann0:
ann_names = ann0['meta_anns'].keys() # type: ignore
# adapt to alternative names if applicable
cat_name = meta_cat.config.general.get_applicable_category_name(ann_names)
if cat_name in ann_names:
logger.debug("Training MetaCAT %s", meta_cat.config.general.category_name)
meta_cat.train_raw(data)

# reset the state of filters
self.config.linking.filters = orig_filters
Expand Down
46 changes: 45 additions & 1 deletion tests/test_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,27 @@ def test_train_supervised_can_retain_MCT_filters(self, extra_cui_filter=None, re
with self.subTest(f'CUI: {filtered_cui}'):
self.assertTrue(filtered_cui in self.undertest.config.linking.filters.cuis)

def _test_train_sup_with_meta_cat(self, train_meta_cats: bool):
# def side_effect(doc, *args, **kwargs):
# raise ValueError()
# # return doc
meta_cat = _get_meta_cat(self.meta_cat_dir)
cat = CAT(cdb=self.cdb, config=self.cdb.config, vocab=self.vocab, meta_cats=[meta_cat])
with patch.object(MetaCAT, "train_raw") as mock_train:
with patch.object(MetaCAT, "__call__", side_effect=lambda doc: doc):
cat.train_supervised_raw(get_fixed_meta_cat_data(), never_terminate=True,
train_meta_cats=train_meta_cats)
if train_meta_cats:
mock_train.assert_called()
else:
mock_train.assert_not_called()

def test_train_supervised_does_not_train_meta_cat_by_default(self):
self._test_train_sup_with_meta_cat(False)

def test_train_supervised_can_train_meta_cats(self):
self._test_train_sup_with_meta_cat(True)

def test_train_supervised_no_leak_extra_cui_filters(self):
self.test_train_supervised_does_not_retain_MCT_filters_default(extra_cui_filter={'C123', 'C111'})

Expand Down Expand Up @@ -799,6 +820,9 @@ def test_loading_model_pack_without_any_config_raises_exception(self):
CAT.load_model_pack(self.temp_dir.name)


META_CAT_JSON_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "resources", "mct_export_for_meta_cat_test.json")


def _get_meta_cat(meta_cat_dir):
config = ConfigMetaCAT()
config.general["category_name"] = "Status"
Expand All @@ -808,11 +832,31 @@ def _get_meta_cat(meta_cat_dir):
embeddings=None,
config=config)
os.makedirs(meta_cat_dir, exist_ok=True)
json_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "resources", "mct_export_for_meta_cat_test.json")
json_path = META_CAT_JSON_PATH
meta_cat.train_from_json(json_path, save_dir_path=meta_cat_dir)
return meta_cat


def get_fixed_meta_cat_data(path: str = META_CAT_JSON_PATH):
with open(path) as f:
data = json.load(f)
for proj_num, project in enumerate(data['projects']):
if 'name' not in project:
project['name'] = f"Proj_{proj_num}"
if 'cuis' not in project:
project['cuis'] = ''
if 'id' not in project:
project['id'] = f'P{proj_num}'
for doc in project['documents']:
if 'entities' in doc and 'annotations' not in doc:
ents = doc.pop("entities")
doc['annotations'] = list(ents.values())
for ann in doc['annotations']:
if 'pretty_name' in ann and 'value' not in ann:
ann['value'] = ann.pop('pretty_name')
return data


class TestLoadingOldWeights(unittest.TestCase):
cdb_path = os.path.join(os.path.dirname(os.path.realpath(__file__)),
"..", "examples", "cdb_old_broken_weights_in_config.dat")
Expand Down