Skip to content

Commit

Permalink
bug fixes, additional tests, and more documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-sutton-1992 committed Dec 13, 2023
1 parent 1975b1c commit 6f752c8
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 35 deletions.
51 changes: 26 additions & 25 deletions medcat/cdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -932,6 +932,7 @@ def merge_cdb(cdb1: "CDB",
overwrite_training: int = 0,
full_build: bool = False):
"""Merge two CDB's together to produce a new, single CDB. The contents of inputs CDBs will not be changed.
    `addl_info` cannot be perfectly merged and will prioritise cdb1; see `full_build`.
Args:
cdb1 (medcat.cdb.CDB):
Expand Down Expand Up @@ -967,35 +968,33 @@ def merge_cdb(cdb1: "CDB",
for cui in cdb2.cui2names:
names = dict()
for name in cdb2.cui2names[cui]:
names[name] = {'snames': cdb2.cui2snames.get(cui, set()), 'is_upper': cdb2.name_isupper.get(name, False), 'tokens': {}}
names[name] = {'snames': cdb2.cui2snames.get(cui, set()), 'is_upper': cdb2.name_isupper.get(name, False), 'tokens': {}, 'raw_name': cdb2.get_name(cui)}
name_status = cdb2.name2cuis2status.get(name, 'A').get(cui, 'A') # get the name status if it exists, default to 'A'
# For addl_info check cui2original_names as they MUST be added
ontologies = set()
description = ''
# For addl_info check cui2original_names as they MUST be added
if full_build and cui in cdb2.addl_info['cui2original_names']:
to_build = False
if full_build and (cui in cdb2.addl_info['cui2original_names'] or cui in cdb2.addl_info['cui2description']):
to_build = True
if 'cui2ontologies' in cdb2.addl_info:
ontologies.update(cdb2.addl_info['cui2ontologies'][cui])
if 'cui2description' in cdb2.addl_info:
description = cdb2.addl_info['cui2description'][cui]
cdb.add_concept(cui=cui, names=names, ontologies=ontologies, name_status=name_status,
type_ids=cdb2.cui2type_ids[cui], description=description, full_build=full_build)
type_ids=cdb2.cui2type_ids[cui], description=description, full_build=to_build)
if cui in cdb1.cui2names:
if cui in cdb1.cui2count_train or cui in cdb2.cui2count_train:
if overwrite_training == 1 and cui in cdb1.cui2count_train[cui]:
cdb.cui2count_train[cui] = cdb1.cui2count_train[cui]
elif overwrite_training == 2 and cui in cdb2.cui2count_train[cui]:
if (cui in cdb1.cui2count_train or cui in cdb2.cui2count_train) and not (overwrite_training == 1 and cui in cdb1.cui2count_train):
if overwrite_training == 2 and cui in cdb2.cui2count_train:
cdb.cui2count_train[cui] = cdb2.cui2count_train[cui]
else:
cdb.cui2count_train[cui] = cdb1.cui2count_train.get(cui, 0) + cdb2.cui2count_train.get(cui, 0)
if cui in cdb1.cui2context_vectors:
contexts = set(list(cdb1.cui2context_vectors.get(cui, {}).keys()) + list(cdb2.cui2context_vectors.get(cui, {}).keys())) # xlong, long, medium, short
if overwrite_training == 1 and cui in cdb1.cui2context_vectors[cui]:
weights = [1, 0]
elif overwrite_training == 2 and cui in cdb2.cui2context_vectors[cui]:
if cui in cdb1.cui2context_vectors and not (overwrite_training == 1 and cui in cdb1.cui2context_vectors[cui]):
if overwrite_training == 2 and cui in cdb2.cui2context_vectors:
weights = [0, 1]
else:
norm = cdb.cui2count_train[cui]
weights = [np.divide(cdb1.cui2count_train.get(cui, 0), norm), np.divide(cdb2.cui2count_train.get(cui, 0), norm)]
contexts = set(list(cdb1.cui2context_vectors.get(cui, {}).keys()) + list(cdb2.cui2context_vectors.get(cui, {}).keys())) # xlong, long, medium, short
for s in contexts:
cdb.cui2context_vectors[cui][s] = (weights[0] * cdb1.cui2context_vectors[cui].get(s, np.zeros(shape=(300)))) + (weights[1] * cdb2.cui2context_vectors[cui].get(s, np.zeros(shape=(300))))
if cui in cdb1.cui2tags:
Expand All @@ -1014,23 +1013,25 @@ def merge_cdb(cdb1: "CDB",
if cui in cdb2.cui2type_ids:
cdb.cui2type_ids[cui] = cdb2.cui2type_ids[cui]

for name in cdb2.name2cuis:
if name in cdb1.name2cuis: # if they exist in both cdbs
if name in cdb1.name2count_train and name in cdb2.name2count_train:
cdb.name2count_train[name] = str(int(cdb1.name2count_train[name]) + int(cdb2.name2count_train[name])) # these are strings for some reason
else:
if name in cdb2.name2count_train:
cdb.name2count_train[name] = cdb2.name2count_train[name]
if overwrite_training != 1:
for name in cdb2.name2cuis:
if name in cdb1.name2cuis and overwrite_training == 0: # if they exist in both cdbs
if name in cdb1.name2count_train and name in cdb2.name2count_train:
cdb.name2count_train[name] = str(int(cdb1.name2count_train[name]) + int(cdb2.name2count_train[name])) # these are strings for some reason
else:
if name in cdb2.name2count_train:
cdb.name2count_train[name] = cdb2.name2count_train[name]

# snames
cdb.snames = cdb1.snames.union(cdb2.snames)

# vocab, adding counts if they occur in both
cdb.vocab = deepcopy(cdb1.vocab)
for word in cdb2.vocab:
if word in cdb.vocab:
cdb.vocab[word] += cdb2.vocab[word]
else:
cdb.vocab[word] = cdb2.vocab[word]
if overwrite_training != 1:
for word in cdb2.vocab:
if word in cdb.vocab and overwrite_training == 0:
cdb.vocab[word] += cdb2.vocab[word]
else:
cdb.vocab[word] = cdb2.vocab[word]

return cdb
31 changes: 21 additions & 10 deletions tests/test_cdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,32 +101,43 @@ def test_merge_cdb(self):
cdb1 = maker1.prepare_csvs(csv_paths=[path])
cdb2 = maker2.prepare_csvs(csv_paths=[path])

# generating vectors and setting up
    # generating context vectors here for testing the weighted average function (based off cui2count_train)
zeroes = np.zeros(shape=(1,300))
ones = np.ones(shape=(1,300))
for i, cui in enumerate(cdb1.cui2names):
cdb1.cui2context_vectors[cui] = {"short" : ones}
cdb2.cui2context_vectors[cui] = {"short" : zeroes}
cdb1.cui2context_vectors[cui] = {"short": ones}
cdb2.cui2context_vectors[cui] = {"short": zeroes}
cdb1.cui2count_train[cui] = 1
cdb2.cui2count_train[cui] = i + 1
test_add = {"test": {'tokens': "test_token", 'snames': ["test_name"], 'raw_name': "test_raw_name", "is_upper" : "P"}}
# adding new names and cuis to each cdb to test after merging
test_add = {"test": {'tokens': "test_token", 'snames': ["test_name"], 'raw_name': "test_raw_name", "is_upper": "P"}}
cdb1.add_names("C0006826", test_add)
unique_test = {"test": {'tokens': "test_token", 'snames': ["test_name"], 'raw_name': "test_raw_name", "is_upper" : "P"}}
unique_test = {"test": {'tokens': "test_token", 'snames': ["test_name"], 'raw_name': "test_raw_name", "is_upper": "P"}}
cdb2.add_names("UniqueTest", unique_test)
cdb2.cui2context_vectors["UniqueTest"] = {"short" : ones}
cdb2.cui2context_vectors["UniqueTest"] = {"short": zeroes}
cdb2.addl_info["cui2ontologies"] = {}
cdb2.addl_info["cui2description"] = {}
for cui in cdb2.cui2names:
cdb2.addl_info["cui2ontologies"][cui] = ["test_ontology"]
cdb2.addl_info["cui2description"][cui] = "test_description"

# merging
cdb = CDB.merge_cdb(cdb1=cdb1, cdb2=cdb2)
overwrite_cdb = CDB.merge_cdb(cdb1=cdb1, cdb2=cdb2, overwrite_training=2, full_build=True)

# tests
self.assertIn("test", cdb.cui2names["C0006826"])
self.assertIn("test_name", cdb.cui2snames["C0006826"])
self.assertEqual("Cancer", cdb.cui2preferred_name["C0006826"])
self.assertTrue(np.array_equal(np.ones(shape=(1,300)), cdb.cui2context_vectors["UniqueTest"]["short"]))
base = np.ones(shape=(1,300))
self.assertTrue(np.array_equal(zeroes, cdb.cui2context_vectors["UniqueTest"]["short"]))
for i, cui in enumerate(cdb1.cui2names):
self.assertTrue(np.array_equal(cdb.cui2context_vectors[cui]["short"], np.divide(base, i+2)))

self.assertTrue(np.array_equal(cdb.cui2context_vectors[cui]["short"], np.divide(ones, i+2)))
self.assertEqual(cdb.addl_info["cui2ontologies"], dict())
self.assertEqual(cdb.addl_info["cui2ontologies"], dict())
for cui in cdb2.cui2names:
self.assertTrue(np.array_equal(overwrite_cdb.cui2context_vectors[cui]["short"], zeroes))
self.assertEqual(overwrite_cdb.addl_info["cui2ontologies"][cui], {"test_ontology"})
self.assertEqual(overwrite_cdb.addl_info["cui2description"][cui], "test_description")


if __name__ == '__main__':
Expand Down

0 comments on commit 6f752c8

Please sign in to comment.