Skip to content
This repository was archived by the owner on Jul 28, 2025. It is now read-only.

Commit 70305f4

Browse files
Merge pull request #373 from CogStack/CU2e77a5x-cdb-merge-function
CU2e77a5x - Add a CDB merge function. Given two CDBs, a new CDB will be created combining the entries from both CDBs.
2 parents 45cef2b + c74fe1f commit 70305f4

File tree

3 files changed

+195
-0
lines changed

3 files changed

+195
-0
lines changed

medcat/utils/cdb_utils.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
import logging
2+
import numpy as np
3+
4+
from copy import deepcopy
5+
from medcat.cdb import CDB
6+
7+
logger = logging.getLogger(__name__)  # module-specific logger (named via __name__), separate from the package-level one
8+
9+
10+
def _merge_context_vectors(vectors1, vectors2, count1, count2, prefer_second):
    """Combine the context vectors stored for a single CUI in two CDBs.

    Args:
        vectors1 (dict): Context type -> vector mapping taken from the first CDB.
        vectors2 (dict): Context type -> vector mapping taken from the second CDB.
        count1 (int): Training count of the CUI in the first CDB.
        count2 (int): Training count of the CUI in the second CDB.
        prefer_second (bool): When True, a vector found in `vectors2` replaces the
            one in `vectors1` instead of being averaged with it.

    Returns:
        dict: A new mapping that covers every context type found in either input.
    """
    total = count1 + count2
    if total > 0:
        weight1, weight2 = np.divide(count1, total), np.divide(count2, total)
    else:
        # No training counts to weight by: fall back to a plain average rather
        # than dividing by zero.
        weight1 = weight2 = 0.5

    merged = {}
    for context_type in set(vectors1) | set(vectors2):
        vec1 = vectors1.get(context_type)
        vec2 = vectors2.get(context_type)
        if vec1 is None or (prefer_second and vec2 is not None):
            merged[context_type] = deepcopy(vec2)
        elif vec2 is None:
            # Only one CDB knows this context type: keep its vector untouched
            # instead of averaging it with an implicit zero vector.
            merged[context_type] = deepcopy(vec1)
        else:
            merged[context_type] = weight1 * vec1 + weight2 * vec2
    return merged


def merge_cdb(cdb1: "CDB",
              cdb2: "CDB",
              overwrite_training: int = 0,
              full_build: bool = False) -> "CDB":
    """Merge two CDBs into a new, single CDB.

    The result starts as a deep copy of `cdb1`; every concept of `cdb2` is then
    added through `CDB.add_concept` and its training data is merged in. Values
    taken from the inputs are copied before being stored on the result, so the
    merged CDB does not share containers with `cdb1` or `cdb2`.
    `addl_info` can not be perfectly merged and prioritises `cdb1`, see `full_build`.

    Args:
        cdb1 (medcat.cdb.CDB):
            The first medcat cdb to merge. Where merging isn't suitable (such as
            cui2preferred_name), the values of this cdb are prioritised over cdb2.
        cdb2 (medcat.cdb.CDB):
            The second medcat cdb to merge.
        overwrite_training (int):
            Choose to prioritise one CDB's training data (counts, context vectors,
            vocab) over merging gracefully. 0 - no prio, 1 - CDB1, 2 - CDB2
        full_build (bool):
            Add additional information from the "addl_info" dicts "cui2ontologies"
            and "cui2description".

    Returns:
        medcat.cdb.CDB: The newly created, merged CDB.
    """
    config = deepcopy(cdb1.config)
    cdb = CDB(config)

    # Copy CDB 1 - as all settings from CDB 1 will be carried over
    cdb.cui2names = deepcopy(cdb1.cui2names)
    cdb.cui2snames = deepcopy(cdb1.cui2snames)
    cdb.cui2count_train = deepcopy(cdb1.cui2count_train)
    cdb.cui2info = deepcopy(cdb1.cui2info)
    cdb.cui2context_vectors = deepcopy(cdb1.cui2context_vectors)
    cdb.cui2tags = deepcopy(cdb1.cui2tags)
    cdb.cui2type_ids = deepcopy(cdb1.cui2type_ids)
    cdb.cui2preferred_name = deepcopy(cdb1.cui2preferred_name)
    cdb.name2cuis = deepcopy(cdb1.name2cuis)
    cdb.name2cuis2status = deepcopy(cdb1.name2cuis2status)
    cdb.name2count_train = deepcopy(cdb1.name2count_train)
    cdb.name_isupper = deepcopy(cdb1.name_isupper)
    if full_build:
        cdb.addl_info = deepcopy(cdb1.addl_info)

    prefer_cdb1 = overwrite_training == 1
    prefer_cdb2 = overwrite_training == 2

    # handles cui2names, cui2snames, name_isupper, name2cuis, name2cuis2status, cui2preferred_name
    for cui in cdb2.cui2names:
        names = {}
        name_status = 'A'  # default status, also used when the CUI has no names
        for name in cdb2.cui2names[cui]:
            names[name] = {
                # copy so the merged CDB never aliases cdb2's own set
                'snames': set(cdb2.cui2snames.get(cui, set())),
                'is_upper': cdb2.name_isupper.get(name, False),
                'tokens': {},
                'raw_name': cdb2.get_name(cui),
            }
            # get the name status if it exists, default to 'A'
            name_status = cdb2.name2cuis2status.get(name, {}).get(cui, 'A')

        # For addl_info check cui2original_names as they MUST be added
        ontologies = set()
        description = ''
        to_build = False
        if full_build:
            addl_info = cdb2.addl_info
            if (cui in addl_info.get('cui2original_names', {})
                    or cui in addl_info.get('cui2description', {})):
                to_build = True
                ontologies.update(addl_info.get('cui2ontologies', {}).get(cui, ()))
                description = addl_info.get('cui2description', {}).get(cui, '')
        cdb.add_concept(cui=cui, names=names, ontologies=ontologies, name_status=name_status,
                        type_ids=set(cdb2.cui2type_ids.get(cui, set())),
                        description=description, full_build=to_build)

        if cui in cdb1.cui2names:
            # Concept known to both CDBs: merge the training data.
            count1 = cdb1.cui2count_train.get(cui, 0)
            count2 = cdb2.cui2count_train.get(cui, 0)
            in_cdb1 = cui in cdb1.cui2count_train
            in_cdb2 = cui in cdb2.cui2count_train
            if (in_cdb1 or in_cdb2) and not (prefer_cdb1 and in_cdb1):
                if prefer_cdb2 and in_cdb2:
                    cdb.cui2count_train[cui] = cdb2.cui2count_train[cui]
                else:
                    cdb.cui2count_train[cui] = count1 + count2

            if cui in cdb2.cui2context_vectors:
                if cui not in cdb1.cui2context_vectors:
                    # cdb1 has nothing to prioritise, so take cdb2's vectors as they are
                    cdb.cui2context_vectors[cui] = deepcopy(cdb2.cui2context_vectors[cui])
                elif not prefer_cdb1:
                    cdb.cui2context_vectors[cui] = _merge_context_vectors(
                        cdb1.cui2context_vectors[cui], cdb2.cui2context_vectors[cui],
                        count1, count2, prefer_cdb2)

            tags2 = cdb2.cui2tags.get(cui)
            if tags2:
                merged_tags = cdb.cui2tags.setdefault(cui, [])
                for tag in tags2:
                    # extend the flat tag list (no nested list, no duplicates)
                    if tag not in merged_tags:
                        merged_tags.append(tag)

            if cui in cdb1.cui2type_ids:
                cdb.cui2type_ids[cui] = cdb1.cui2type_ids[cui].union(cdb2.cui2type_ids.get(cui, set()))
        else:
            # Concept only known to cdb2: copy its training data over unchanged.
            if cui in cdb2.cui2count_train:
                cdb.cui2count_train[cui] = cdb2.cui2count_train[cui]
            if cui in cdb2.cui2info:
                cdb.cui2info[cui] = deepcopy(cdb2.cui2info[cui])
            if cui in cdb2.cui2context_vectors:
                cdb.cui2context_vectors[cui] = deepcopy(cdb2.cui2context_vectors[cui])
            if cui in cdb2.cui2tags:
                cdb.cui2tags[cui] = deepcopy(cdb2.cui2tags[cui])
            if cui in cdb2.cui2type_ids:
                cdb.cui2type_ids[cui] = deepcopy(cdb2.cui2type_ids[cui])

    if not prefer_cdb1:
        for name in cdb2.name2cuis:
            if name not in cdb2.name2count_train:
                continue
            if (overwrite_training == 0 and name in cdb1.name2cuis
                    and name in cdb1.name2count_train):
                # The sum is stored as a string, as the original implementation did
                # (it notes that these counts are strings).
                cdb.name2count_train[name] = str(int(cdb1.name2count_train[name]) + int(cdb2.name2count_train[name]))
            else:
                cdb.name2count_train[name] = cdb2.name2count_train[name]

    # snames
    cdb.snames = cdb1.snames.union(cdb2.snames)

    # vocab, adding counts if they occur in both
    cdb.vocab = deepcopy(cdb1.vocab)
    if not prefer_cdb1:
        for word in cdb2.vocab:
            if word in cdb.vocab and overwrite_training == 0:
                cdb.vocab[word] += cdb2.vocab[word]
            else:
                cdb.vocab[word] = cdb2.vocab[word]

    return cdb

tests/helper.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import numpy as np
77

88
from medcat.vocab import Vocab
9+
from medcat.cdb_maker import CDBMaker
10+
from medcat.config import Config
911

1012

1113
class AsyncMock(unittest.mock.MagicMock):
@@ -86,3 +88,36 @@ def check_or_download(self):
8688
return
8789
with open(self.vocab_path, 'wb') as f:
8890
f.write(tmp.content)
91+
92+
93+
class ForCDBMerging:
    """Test fixture exposing two CDBs (`cdb1`, `cdb2`) prepared for merge tests.

    Both CDBs are built from the same sample CSV and then given differing
    context vectors, training counts, extra names and (for `cdb2` only)
    `addl_info` entries.
    """

    def __init__(self) -> None:
        # One CDBMaker per CDB: per the original note, a single maker would
        # hand back the same CDB object twice.
        config = Config()
        config.general["spacy_model"] = "en_core_web_md"
        first_maker = CDBMaker(config)
        second_maker = CDBMaker(config)
        tests_dir = os.path.dirname(os.path.realpath(__file__))
        sample_csv = os.path.join(tests_dir, "model_creator", "umls_sample.csv")
        self.cdb1 = first_maker.prepare_csvs(csv_paths=[sample_csv])
        self.cdb2 = second_maker.prepare_csvs(csv_paths=[sample_csv])

        # Context vectors and counts used to test the weighted average
        # (weights are based off cui2count_train).
        all_zeros = np.zeros(shape=(1, 300))
        all_ones = np.ones(shape=(1, 300))
        for count, cui in enumerate(self.cdb1.cui2names, start=1):
            self.cdb1.cui2context_vectors[cui] = {"short": all_ones}
            self.cdb2.cui2context_vectors[cui] = {"short": all_zeros}
            self.cdb1.cui2count_train[cui] = 1
            self.cdb2.cui2count_train[cui] = count

        def make_name_entry():
            # A fresh dict per call so the two CDBs never share one object.
            return {
                "test": {
                    'tokens': "test_token",
                    'snames': ["test_name"],
                    'raw_name': "test_raw_name",
                    "is_upper": "P",
                },
            }

        # New name on an existing CUI in cdb1, and a brand new CUI in cdb2,
        # both checked after merging.
        self.cdb1.add_names("C0006826", make_name_entry())
        self.cdb2.add_names("UniqueTest", make_name_entry())
        self.cdb2.cui2context_vectors["UniqueTest"] = {"short": all_zeros}

        # Every CUI of cdb2 gets an ontology set and a description.
        self.cdb2.addl_info["cui2ontologies"] = {
            cui: {"test_ontology"} for cui in self.cdb2.cui2names
        }
        self.cdb2.addl_info["cui2description"] = {
            cui: "test_description" for cui in self.cdb2.cui2names
        }

tests/utils/test_cdb_utils.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import unittest
2+
import numpy as np
3+
from tests.helper import ForCDBMerging
4+
from medcat.utils.cdb_utils import merge_cdb
5+
6+
7+
class CDBMergeTests(unittest.TestCase):
    """Tests for `merge_cdb`, run against the two CDBs built by `ForCDBMerging`."""

    @classmethod
    def setUpClass(cls):
        # Build the fixture once and merge it twice: a default merge, and a merge
        # that prioritises cdb2's training data with a full build.
        to_merge = ForCDBMerging()
        cls.cdb1 = to_merge.cdb1
        cls.cdb2 = to_merge.cdb2
        cls.merged_cdb = merge_cdb(cdb1=cls.cdb1, cdb2=cls.cdb2)
        cls.overwrite_cdb = merge_cdb(cdb1=cls.cdb1, cdb2=cls.cdb2, overwrite_training=2, full_build=True)
        cls.zeroes = np.zeros(shape=(1, 300))
        cls.ones = np.ones(shape=(1, 300))

    def test_merge_inserts(self):
        self.assertIn("test", self.merged_cdb.cui2names["C0006826"])
        self.assertIn("test_name", self.merged_cdb.cui2snames["C0006826"])
        self.assertEqual("Cancer", self.merged_cdb.cui2preferred_name["C0006826"])

    def test_no_full_build(self):
        # Without full_build neither addl_info dict may be populated. The second
        # assertion previously re-checked "cui2ontologies".
        # NOTE(review): assumes a fresh CDB exposes an empty "cui2description" - confirm.
        self.assertEqual(self.merged_cdb.addl_info["cui2ontologies"], dict())
        self.assertEqual(self.merged_cdb.addl_info["cui2description"], dict())

    def test_full_build(self):
        for cui in self.cdb2.cui2names:
            with self.subTest(cui=cui):
                self.assertEqual(self.overwrite_cdb.addl_info["cui2ontologies"][cui], {"test_ontology"})
                self.assertEqual(self.overwrite_cdb.addl_info["cui2description"][cui], "test_description")

    def test_vector_merge(self):
        self.assertTrue(np.array_equal(self.zeroes, self.merged_cdb.cui2context_vectors["UniqueTest"]["short"]))
        # The fixture gives cdb1 a count of 1 (vector of ones) and cdb2 a count of
        # i + 1 (vector of zeroes), so the weighted average is ones / (i + 2).
        for i, cui in enumerate(self.cdb1.cui2names):
            with self.subTest(cui=cui):
                self.assertTrue(np.array_equal(self.merged_cdb.cui2context_vectors[cui]["short"], np.divide(self.ones, i + 2)))

    def test_overwrite_parameter(self):
        for cui in self.cdb2.cui2names:
            with self.subTest(cui=cui):
                self.assertTrue(np.array_equal(self.overwrite_cdb.cui2context_vectors[cui]["short"], self.zeroes))
                self.assertEqual(self.overwrite_cdb.addl_info["cui2ontologies"][cui], {"test_ontology"})
                self.assertEqual(self.overwrite_cdb.addl_info["cui2description"][cui], "test_description")

0 commit comments

Comments
 (0)