|
| 1 | +import logging |
| 2 | +import numpy as np |
| 3 | + |
| 4 | +from copy import deepcopy |
| 5 | +from medcat.cdb import CDB |
| 6 | + |
| 7 | +logger = logging.getLogger(__name__) # separate logger from the package-level one |
| 8 | + |
| 9 | + |
| 10 | +def merge_cdb(cdb1: "CDB", |
| 11 | + cdb2: "CDB", |
| 12 | + overwrite_training: int = 0, |
| 13 | + full_build: bool = False): |
| 14 | + """Merge two CDB's together to produce a new, single CDB. The contents of inputs CDBs will not be changed. |
| 15 | + `addl_info` can not be perfectly merged, and will prioritise cdb1. see `full_build` |
| 16 | +
|
| 17 | + Args: |
| 18 | + cdb1 (medcat.cdb.CDB): |
| 19 | + The first medcat cdb to merge. In cases where merging isn't suitable isn't ideal (such as |
| 20 | + cui2preferred_name), this cdb values will be prioritised over cdb2. |
| 21 | + cdb2 (medcat.cdb.CDB): |
| 22 | + The second medcat cdb to merge. |
| 23 | + overwrite_training (int): |
| 24 | + Choose to prioritise a CDB's context vectors values over merging gracefully. 0 - no prio, 1 - CDB1, 2 - CDB2 |
| 25 | + full_build (bool): |
| 26 | + Add additional information from "addl_info" dicts "cui2ontologies" and "cui2description" |
| 27 | + """ |
| 28 | + config = deepcopy(cdb1.config) |
| 29 | + cdb = CDB(config) |
| 30 | + |
| 31 | + # Copy CDB 1 - as all settings from CDB 1 will be carried over |
| 32 | + cdb.cui2names = deepcopy(cdb1.cui2names) |
| 33 | + cdb.cui2snames = deepcopy(cdb1.cui2snames) |
| 34 | + cdb.cui2count_train = deepcopy(cdb1.cui2count_train) |
| 35 | + cdb.cui2info = deepcopy(cdb1.cui2info) |
| 36 | + cdb.cui2context_vectors = deepcopy(cdb1.cui2context_vectors) |
| 37 | + cdb.cui2tags = deepcopy(cdb1.cui2tags) |
| 38 | + cdb.cui2type_ids = deepcopy(cdb1.cui2type_ids) |
| 39 | + cdb.cui2preferred_name = deepcopy(cdb1.cui2preferred_name) |
| 40 | + cdb.name2cuis = deepcopy(cdb1.name2cuis) |
| 41 | + cdb.name2cuis2status = deepcopy(cdb1.name2cuis2status) |
| 42 | + cdb.name2count_train = deepcopy(cdb1.name2count_train) |
| 43 | + cdb.name_isupper = deepcopy(cdb1.name_isupper) |
| 44 | + if full_build: |
| 45 | + cdb.addl_info = deepcopy(cdb1.addl_info) |
| 46 | + |
| 47 | + # handles cui2names, cui2snames, name_isupper, name2cuis, name2cuis2status, cui2preferred_name |
| 48 | + for cui in cdb2.cui2names: |
| 49 | + names = dict() |
| 50 | + for name in cdb2.cui2names[cui]: |
| 51 | + names[name] = {'snames': cdb2.cui2snames.get(cui, set()), 'is_upper': cdb2.name_isupper.get(name, False), 'tokens': {}, 'raw_name': cdb2.get_name(cui)} |
| 52 | + name_status = cdb2.name2cuis2status.get(name, 'A').get(cui, 'A') # get the name status if it exists, default to 'A' |
| 53 | + # For addl_info check cui2original_names as they MUST be added |
| 54 | + ontologies = set() |
| 55 | + description = '' |
| 56 | + to_build = False |
| 57 | + if full_build and (cui in cdb2.addl_info['cui2original_names'] or cui in cdb2.addl_info['cui2description']): |
| 58 | + to_build = True |
| 59 | + if 'cui2ontologies' in cdb2.addl_info: |
| 60 | + ontologies.update(cdb2.addl_info['cui2ontologies'][cui]) |
| 61 | + if 'cui2description' in cdb2.addl_info: |
| 62 | + description = cdb2.addl_info['cui2description'][cui] |
| 63 | + cdb.add_concept(cui=cui, names=names, ontologies=ontologies, name_status=name_status, |
| 64 | + type_ids=cdb2.cui2type_ids[cui], description=description, full_build=to_build) |
| 65 | + if cui in cdb1.cui2names: |
| 66 | + if (cui in cdb1.cui2count_train or cui in cdb2.cui2count_train) and not (overwrite_training == 1 and cui in cdb1.cui2count_train): |
| 67 | + if overwrite_training == 2 and cui in cdb2.cui2count_train: |
| 68 | + cdb.cui2count_train[cui] = cdb2.cui2count_train[cui] |
| 69 | + else: |
| 70 | + cdb.cui2count_train[cui] = cdb1.cui2count_train.get(cui, 0) + cdb2.cui2count_train.get(cui, 0) |
| 71 | + if cui in cdb1.cui2context_vectors and not (overwrite_training == 1 and cui in cdb1.cui2context_vectors[cui]): |
| 72 | + if overwrite_training == 2 and cui in cdb2.cui2context_vectors: |
| 73 | + weights = [0, 1] |
| 74 | + else: |
| 75 | + norm = cdb.cui2count_train[cui] |
| 76 | + weights = [np.divide(cdb1.cui2count_train.get(cui, 0), norm), np.divide(cdb2.cui2count_train.get(cui, 0), norm)] |
| 77 | + contexts = set(list(cdb1.cui2context_vectors.get(cui, {}).keys()) + list(cdb2.cui2context_vectors.get(cui, {}).keys())) # xlong, long, medium, short |
| 78 | + for s in contexts: |
| 79 | + cdb.cui2context_vectors[cui][s] = (weights[0] * cdb1.cui2context_vectors[cui].get(s, np.zeros(shape=(300)))) + (weights[1] * cdb2.cui2context_vectors[cui].get(s, np.zeros(shape=(300)))) |
| 80 | + if cui in cdb1.cui2tags: |
| 81 | + cdb.cui2tags[cui].append(cdb2.cui2tags[cui]) |
| 82 | + if cui in cdb1.cui2type_ids: |
| 83 | + cdb.cui2type_ids[cui] = cdb1.cui2type_ids[cui].union(cdb2.cui2type_ids[cui]) |
| 84 | + else: |
| 85 | + if cui in cdb2.cui2count_train: |
| 86 | + cdb.cui2count_train[cui] = cdb2.cui2names[cui] |
| 87 | + if cui in cdb2.cui2info: |
| 88 | + cdb.cui2info[cui] = cdb2.cui2info[cui] |
| 89 | + if cui in cdb2.cui2context_vectors: |
| 90 | + cdb.cui2context_vectors[cui] = cdb2.cui2context_vectors[cui] |
| 91 | + if cui in cdb2.cui2tags: |
| 92 | + cdb.cui2tags[cui] = cdb2.cui2tags[cui] |
| 93 | + if cui in cdb2.cui2type_ids: |
| 94 | + cdb.cui2type_ids[cui] = cdb2.cui2type_ids[cui] |
| 95 | + |
| 96 | + if overwrite_training != 1: |
| 97 | + for name in cdb2.name2cuis: |
| 98 | + if name in cdb1.name2cuis and overwrite_training == 0: # if they exist in both cdbs |
| 99 | + if name in cdb1.name2count_train and name in cdb2.name2count_train: |
| 100 | + cdb.name2count_train[name] = str(int(cdb1.name2count_train[name]) + int(cdb2.name2count_train[name])) # these are strings for some reason |
| 101 | + else: |
| 102 | + if name in cdb2.name2count_train: |
| 103 | + cdb.name2count_train[name] = cdb2.name2count_train[name] |
| 104 | + |
| 105 | + # snames |
| 106 | + cdb.snames = cdb1.snames.union(cdb2.snames) |
| 107 | + |
| 108 | + # vocab, adding counts if they occur in both |
| 109 | + cdb.vocab = deepcopy(cdb1.vocab) |
| 110 | + if overwrite_training != 1: |
| 111 | + for word in cdb2.vocab: |
| 112 | + if word in cdb.vocab and overwrite_training == 0: |
| 113 | + cdb.vocab[word] += cdb2.vocab[word] |
| 114 | + else: |
| 115 | + cdb.vocab[word] = cdb2.vocab[word] |
| 116 | + |
| 117 | + return cdb |
0 commit comments