From 6f752c8d2d10768fe9e3a822eb26a1b5aa973aa7 Mon Sep 17 00:00:00 2001 From: adam-sutton-1992 Date: Wed, 13 Dec 2023 20:09:14 +0000 Subject: [PATCH] bug fixes, additional tests, and more documentation --- medcat/cdb.py | 51 ++++++++++++++++++++++++----------------------- tests/test_cdb.py | 31 ++++++++++++++++++---------- 2 files changed, 47 insertions(+), 35 deletions(-) diff --git a/medcat/cdb.py b/medcat/cdb.py index d773d1f4f..1737b4bad 100644 --- a/medcat/cdb.py +++ b/medcat/cdb.py @@ -932,6 +932,7 @@ def merge_cdb(cdb1: "CDB", overwrite_training: int = 0, full_build: bool = False): """Merge two CDB's together to produce a new, single CDB. The contents of inputs CDBs will not be changed. + `addl_info` can not be perfectly merged, and will prioritise cdb1. see `full_build` Args: cdb1 (medcat.cdb.CDB): @@ -967,35 +968,33 @@ def merge_cdb(cdb1: "CDB", for cui in cdb2.cui2names: names = dict() for name in cdb2.cui2names[cui]: - names[name] = {'snames': cdb2.cui2snames.get(cui, set()), 'is_upper': cdb2.name_isupper.get(name, False), 'tokens': {}} + names[name] = {'snames': cdb2.cui2snames.get(cui, set()), 'is_upper': cdb2.name_isupper.get(name, False), 'tokens': {}, 'raw_name': cdb2.get_name(cui)} name_status = cdb2.name2cuis2status.get(name, 'A').get(cui, 'A') # get the name status if it exists, default to 'A' + # For addl_info check cui2original_names as they MUST be added ontologies = set() description = '' - # For addl_info check cui2original_names as they MUST be added - if full_build and cui in cdb2.addl_info['cui2original_names']: + to_build = False + if full_build and (cui in cdb2.addl_info['cui2original_names'] or cui in cdb2.addl_info['cui2description']): + to_build = True if 'cui2ontologies' in cdb2.addl_info: ontologies.update(cdb2.addl_info['cui2ontologies'][cui]) if 'cui2description' in cdb2.addl_info: description = cdb2.addl_info['cui2description'][cui] cdb.add_concept(cui=cui, names=names, ontologies=ontologies, name_status=name_status, - type_ids=cdb2.cui2type_ids[cui], description=description, full_build=full_build) + type_ids=cdb2.cui2type_ids[cui], description=description, full_build=to_build) if cui in cdb1.cui2names: - if cui in cdb1.cui2count_train or cui in cdb2.cui2count_train: - if overwrite_training == 1 and cui in cdb1.cui2count_train[cui]: - cdb.cui2count_train[cui] = cdb1.cui2count_train[cui] - elif overwrite_training == 2 and cui in cdb2.cui2count_train[cui]: + if (cui in cdb1.cui2count_train or cui in cdb2.cui2count_train) and not (overwrite_training == 1 and cui in cdb1.cui2count_train): + if overwrite_training == 2 and cui in cdb2.cui2count_train: cdb.cui2count_train[cui] = cdb2.cui2count_train[cui] else: cdb.cui2count_train[cui] = cdb1.cui2count_train.get(cui, 0) + cdb2.cui2count_train.get(cui, 0) - if cui in cdb1.cui2context_vectors: - contexts = set(list(cdb1.cui2context_vectors.get(cui, {}).keys()) + list(cdb2.cui2context_vectors.get(cui, {}).keys())) # xlong, long, medium, short - if overwrite_training == 1 and cui in cdb1.cui2context_vectors[cui]: - weights = [1, 0] - elif overwrite_training == 2 and cui in cdb2.cui2context_vectors[cui]: + if cui in cdb1.cui2context_vectors and not (overwrite_training == 1 and cui in cdb1.cui2context_vectors[cui]): + if overwrite_training == 2 and cui in cdb2.cui2context_vectors: weights = [0, 1] else: norm = cdb.cui2count_train[cui] weights = [np.divide(cdb1.cui2count_train.get(cui, 0), norm), np.divide(cdb2.cui2count_train.get(cui, 0), norm)] + contexts = set(list(cdb1.cui2context_vectors.get(cui, {}).keys()) + list(cdb2.cui2context_vectors.get(cui, {}).keys())) # xlong, long, medium, short for s in contexts: cdb.cui2context_vectors[cui][s] = (weights[0] * cdb1.cui2context_vectors[cui].get(s, np.zeros(shape=(300)))) + (weights[1] * cdb2.cui2context_vectors[cui].get(s, np.zeros(shape=(300)))) if cui in cdb1.cui2tags: @@ -1014,23 +1013,25 @@ def merge_cdb(cdb1: "CDB", if cui in cdb2.cui2type_ids: cdb.cui2type_ids[cui] = cdb2.cui2type_ids[cui] - for name in cdb2.name2cuis: - if name in cdb1.name2cuis: # if they exist in both cdbs - if name in cdb1.name2count_train and name in cdb2.name2count_train: - cdb.name2count_train[name] = str(int(cdb1.name2count_train[name]) + int(cdb2.name2count_train[name])) # these are strings for some reason - else: - if name in cdb2.name2count_train: - cdb.name2count_train[name] = cdb2.name2count_train[name] + if overwrite_training != 1: + for name in cdb2.name2cuis: + if name in cdb1.name2cuis and overwrite_training == 0: # if they exist in both cdbs + if name in cdb1.name2count_train and name in cdb2.name2count_train: + cdb.name2count_train[name] = str(int(cdb1.name2count_train[name]) + int(cdb2.name2count_train[name])) # these are strings for some reason + else: + if name in cdb2.name2count_train: + cdb.name2count_train[name] = cdb2.name2count_train[name] # snames cdb.snames = cdb1.snames.union(cdb2.snames) # vocab, adding counts if they occur in both cdb.vocab = deepcopy(cdb1.vocab) - for word in cdb2.vocab: - if word in cdb.vocab: - cdb.vocab[word] += cdb2.vocab[word] - else: - cdb.vocab[word] = cdb2.vocab[word] + if overwrite_training != 1: + for word in cdb2.vocab: + if word in cdb.vocab and overwrite_training == 0: + cdb.vocab[word] += cdb2.vocab[word] + else: + cdb.vocab[word] = cdb2.vocab[word] return cdb diff --git a/tests/test_cdb.py b/tests/test_cdb.py index 29c603daa..3ff7e5dad 100644 --- a/tests/test_cdb.py +++ b/tests/test_cdb.py @@ -101,32 +101,43 @@ def test_merge_cdb(self): cdb1 = maker1.prepare_csvs(csv_paths=[path]) cdb2 = maker2.prepare_csvs(csv_paths=[path]) - # generating vectors and setting up + # generating context vectors here for for testing the weighted average function (based off cui2count_train) zeroes = np.zeros(shape=(1,300)) ones = np.ones(shape=(1,300)) for i, cui in enumerate(cdb1.cui2names): - cdb1.cui2context_vectors[cui] = {"short" : ones} - cdb2.cui2context_vectors[cui] = {"short" : zeroes} + cdb1.cui2context_vectors[cui] = {"short": ones} + cdb2.cui2context_vectors[cui] = {"short": zeroes} cdb1.cui2count_train[cui] = 1 cdb2.cui2count_train[cui] = i + 1 - test_add = {"test": {'tokens': "test_token", 'snames': ["test_name"], 'raw_name': "test_raw_name", "is_upper" : "P"}} + # adding new names and cuis to each cdb to test after merging + test_add = {"test": {'tokens': "test_token", 'snames': ["test_name"], 'raw_name': "test_raw_name", "is_upper": "P"}} cdb1.add_names("C0006826", test_add) - unique_test = {"test": {'tokens': "test_token", 'snames': ["test_name"], 'raw_name': "test_raw_name", "is_upper" : "P"}} + unique_test = {"test": {'tokens': "test_token", 'snames': ["test_name"], 'raw_name': "test_raw_name", "is_upper": "P"}} cdb2.add_names("UniqueTest", unique_test) - cdb2.cui2context_vectors["UniqueTest"] = {"short" : ones} + cdb2.cui2context_vectors["UniqueTest"] = {"short": zeroes} + cdb2.addl_info["cui2ontologies"] = {} + cdb2.addl_info["cui2description"] = {} + for cui in cdb2.cui2names: + cdb2.addl_info["cui2ontologies"][cui] = ["test_ontology"] + cdb2.addl_info["cui2description"][cui] = "test_description" # merging cdb = CDB.merge_cdb(cdb1=cdb1, cdb2=cdb2) + overwrite_cdb = CDB.merge_cdb(cdb1=cdb1, cdb2=cdb2, overwrite_training=2, full_build=True) # tests self.assertIn("test", cdb.cui2names["C0006826"]) self.assertIn("test_name", cdb.cui2snames["C0006826"]) self.assertEqual("Cancer", cdb.cui2preferred_name["C0006826"]) - self.assertTrue(np.array_equal(np.ones(shape=(1,300)), cdb.cui2context_vectors["UniqueTest"]["short"])) - base = np.ones(shape=(1,300)) + self.assertTrue(np.array_equal(zeroes, cdb.cui2context_vectors["UniqueTest"]["short"])) for i, cui in enumerate(cdb1.cui2names): - self.assertTrue(np.array_equal(cdb.cui2context_vectors[cui]["short"], np.divide(base, i+2))) - + self.assertTrue(np.array_equal(cdb.cui2context_vectors[cui]["short"], np.divide(ones, i+2))) + self.assertEqual(cdb.addl_info["cui2ontologies"], dict()) + self.assertEqual(cdb.addl_info["cui2ontologies"], dict()) + for cui in cdb2.cui2names: + self.assertTrue(np.array_equal(overwrite_cdb.cui2context_vectors[cui]["short"], zeroes)) + self.assertEqual(overwrite_cdb.addl_info["cui2ontologies"][cui], {"test_ontology"}) + self.assertEqual(overwrite_cdb.addl_info["cui2description"][cui], "test_description") if __name__ == '__main__':