-
Notifications
You must be signed in to change notification settings - Fork 104
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
CU-8696nbm9j: Add module to convert vocab vectors and a few simple tests
- Loading branch information
Showing
2 changed files
with
230 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
import numpy as np | ||
import logging | ||
from typing import Type | ||
|
||
from medcat.cdb import CDB | ||
from medcat.vocab import Vocab | ||
|
||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def calc_matrix(vocab: Vocab, target_size: int) -> np.ndarray:
    """Calculate the transformation matrix based on the word vectors in the Vocab.

    Performs Principal Component Analysis (PCA):
    1. All the (non-None) word vectors in the Vocab are stacked and mean-centered.
    2. The covariance matrix of the centered vectors is computed.
    3. The eigenvalues and eigenvectors of the covariance matrix are calculated.
    4. The `target_size` eigenvectors corresponding to the largest
       eigenvalues are selected to create the transformation matrix.

    Args:
        vocab (Vocab): The Vocab.
        target_size (int): The target vector size.

    Raises:
        ValueError: If `target_size` is not between 1 and the current
            vector size (inclusive).

    Returns:
        np.ndarray: The transformation matrix of shape (target_size, N),
            where N is the current vector size.
    """
    all_vecs = np.vstack(
        [value['vec'] for value in vocab.vocab.values() if value['vec'] is not None]
    )
    logger.debug("Vocab vectors have a total shape of %s", np.shape(all_vecs))
    cur_size = all_vecs.shape[1]
    # Slicing with an out-of-range size would silently return a matrix of
    # the wrong shape, so reject it explicitly.
    if not 1 <= target_size <= cur_size:
        raise ValueError(
            f"Target size must be between 1 and the current vector size "
            f"({cur_size}), got {target_size}")
    all_vecs_meaned = all_vecs - np.mean(all_vecs, axis=0)
    # rowvar=False: each row is an observation, each column a vector component
    cov_matrix = np.cov(all_vecs_meaned, rowvar=False)
    # eigenvectors are the columns of the returned matrix
    eigenvalues, eigenvectors = np.linalg.eigh(cov_matrix)
    # indices of the eigenvalues from largest to smallest
    sorted_idx = np.argsort(eigenvalues)[::-1]
    if logger.isEnabledFor(logging.DEBUG):
        # only build the formatted list when it will actually be logged
        logger.debug("The sorted eigenvalues are as follows: %s",
                     [f"{v:5.2f}" for v in eigenvalues[sorted_idx]])
    sorted_eigenvectors = eigenvectors[:, sorted_idx]
    transformation_matrix = sorted_eigenvectors[:, :target_size]
    # transpose so that `matrix @ vector` maps N components to target_size
    return transformation_matrix.T
|
||
|
||
def convert_vec(cur: np.ndarray, matrix: np.ndarray,
                target_dtype: Type = np.float32) -> np.ndarray:
    """Project a single vector with the transformation matrix.

    The result is always cast to `target_dtype`, which guarantees uniform
    typing of the converted vectors (in our experience some vectors may be
    of a different type beforehand, e.g. np.float64).

    Args:
        cur (np.ndarray): The current vector.
        matrix (np.ndarray): The transformation matrix.
        target_dtype (Type): The target element data type. Defaults to np.float32.

    Returns:
        np.ndarray: The transformed vector.
    """
    projected = matrix @ cur
    return projected.astype(target_dtype)
|
||
|
||
def convert_vocab(vocab: Vocab, matrix: np.ndarray,
                  unigram_table_size: int = 10_000_000) -> None:
    """Convert every word vector in the Vocab with the transformation matrix.

    Words without a vector (None) are left untouched. Once all vectors
    have been converted, the unigram table is rebuilt.

    Args:
        vocab (Vocab): The Vocab.
        matrix (np.ndarray): The transformation matrix.
        unigram_table_size (int): The unigram table size. Defaults to 10 000 000.
    """
    for word_data in vocab.vocab.values():
        original_vec = word_data['vec']
        if original_vec is not None:
            word_data['vec'] = convert_vec(original_vec, matrix)
    logger.info("Recalc unigram table")
    vocab.make_unigram_table(unigram_table_size)
|
||
|
||
def convert_context_vectors(cdb: CDB, matrix: np.ndarray) -> None:
    """Convert the context vectors within the CDB with the transformation matrix.

    Every vector in `cdb.cui2context_vectors` is replaced by its converted
    counterpart and the CDB is marked as dirty afterwards.

    Args:
        cdb (CDB): The Concept Database.
        matrix (np.ndarray): The transformation matrix.
    """
    for context_vectors in cdb.cui2context_vectors.values():
        # iterate over a snapshot of the keys since values are replaced in place
        for cv_type in list(context_vectors):
            context_vectors[cv_type] = convert_vec(context_vectors[cv_type], matrix)
    cdb.is_dirty = True
|
||
|
||
def convert_vocab_vector_size(cdb: CDB, vocab: Vocab, vec_size: int,
                              unigram_table_size: int = 10_000_000) -> None:
    """Convert the vocab vector size to a smaller one.

    This uses Principal Component Analysis (PCA). The idea is that we
    first center all the word vectors (in Vocab), then compute the
    covariance matrix, then find the eigenvalues and eigenvectors,
    and then we select the top `vec_size` eigenvectors.
    This produces a transformation matrix of shape (vec_size, N),
    where N is the current vector length in the vocab.

    After that, we perform the transformation. First we transform all
    the vectors in the Vocab. And then we transform all the context
    vectors defined within the CDB.

    NOTE: This requires the CDB as well since the per concept context
          vectors stored within it are based on the vectors in the vocab
          and thus they also need to be transformed.

    Args:
        cdb (CDB): The Concept Database.
        vocab (Vocab): The Vocab.
        vec_size (int): The target vector size.
        unigram_table_size (int): The size of the unigram table that is
            rebuilt after converting the vocab. Defaults to 10 000 000.
    """
    logger.info("Converting Vocab and CDB to size %s. Calculating "
                "transformation matrix", vec_size)
    matrix = calc_matrix(vocab, vec_size)
    logger.info("Found transformation matrix with shape %s. "
                "Now converting vocab.", matrix.shape)
    # pass the table size through so callers are not stuck with the default
    convert_vocab(vocab, matrix, unigram_table_size=unigram_table_size)
    logger.info("Done converting vocab, now converting the per concept "
                "context vectors defined in the CDB.")
    convert_context_vectors(cdb, matrix)
    logger.info("Done with the conversion to vocab vector size %s.",
                vec_size)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
from medcat.vocab import Vocab | ||
from medcat.utils import vocab_utils | ||
from medcat.cdb import CDB | ||
|
||
import unittest | ||
import numpy as np | ||
import os | ||
import random | ||
|
||
|
||
# Test fixture: (word, count, word vector) triples.
# Every vector has 6 components; the vectors are wrapped in numpy arrays.
WORDS = [
    (word, count, np.array(components))
    for word, count, components in (
        ("word1", 12, [0, 1, 2, 1, 1, 0]),
        ("word2", 21, [2, -1, 0, 1, -1, -1]),
        ("word3", 32, [2, -1, 0, 0, 0, 1]),
        ("word4", 42, [-1, 0, -1, -1, 0, 2]),
        ("word5", 24, [0, 3, -2, 5, -1, 3]),
        ("word6", 46, [3, -5, 10, 1, 10, -2]),
        ("word7", 31, [-2, 4, -1, -2, 1, 2]),
        ("word8", 28, [-3, 3, -2, 4, 9, 2]),
        ("word9", 19, [-4, 2, -3, -6, 3, 2]),
        ("word10", 1, [4, 1, -4, 0, 5, 2]),
    )
]
|
||
|
||
class TestWithTransformationMatrixBase(unittest.TestCase):
    """Base class that builds a Vocab from WORDS and its transformation matrix."""
    # length of the word vectors before the conversion
    ORIG_SIZE = len(WORDS[0][-1])
    # length of the word vectors after the conversion
    TARGET_SIZE = 3

    @classmethod
    def setUpClass(cls):
        vocab = Vocab()
        for word_entry in WORDS:
            # each entry is a (word, count, vector) triple
            vocab.add_word(*word_entry)
        cls.vocab = vocab
        cls.TM = vocab_utils.calc_matrix(vocab, cls.TARGET_SIZE)
|
||
|
||
class TransformationMatrixTests(TestWithTransformationMatrixBase):
    """Sanity checks on the calculated transformation matrix."""

    def test_transformation_matrix_correct_size(self):
        self.assertEqual(self.TM.shape, (self.TARGET_SIZE, self.ORIG_SIZE))

    def test_transformation_matrix_reasonable(self):
        # Use the dedicated numpy predicates instead of arithmetic tricks:
        # `x - 100 == x` is also true for finite values of a very large
        # magnitude, so it is not a reliable check for infinity.
        self.assertFalse(np.any(np.isnan(self.TM)), "Shouldn't have NaNs")
        self.assertFalse(np.any(np.isinf(self.TM)), "Shouldn't have infinity")
|
||
|
||
class TestWithTMAndCDBBase(TestWithTransformationMatrixBase):
    """Base class that additionally loads a CDB and fills in fake context vectors."""
    CDB_PATH = os.path.join(os.path.dirname(__file__), "..", "..",
                            "examples", "cdb.dat")
    UNIGRAM_TABLE_SIZE = 100

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls.cdb = CDB.load(cls.CDB_PATH)
        cls.add_fake_context_vectors()

    @classmethod
    def add_fake_context_vectors(cls, words: int = 4):
        """Give every CUI in the CDB a fake context vector per context type.

        Each fake vector is the sum of `words` randomly picked word
        vectors from WORDS.

        Args:
            words (int): The number of word vectors to sum. Defaults to 4.
        """
        # NOTE: in original size!
        all_context_vectors = cls.cdb.cui2context_vectors
        cv_types = cls.cdb.config.linking.context_vector_sizes
        for cui in cls.cdb.cui2names:
            all_context_vectors[cui] = {
                # sum of the original (unconverted) vectors of random words
                cv_type: sum(random.choice(WORDS)[2] for _ in range(words))
                for cv_type in cv_types
            }
|
||
|
||
class VocabTransformationTests(TestWithTMAndCDBBase):
    """Checks the vector sizes after converting the Vocab and CDB separately."""

    @classmethod
    def do_conversion(cls):
        """Convert the Vocab and the CDB context vectors with the class-level matrix."""
        vocab_utils.convert_vocab(cls.vocab, cls.TM,
                                  unigram_table_size=cls.UNIGRAM_TABLE_SIZE)
        vocab_utils.convert_context_vectors(cls.cdb, cls.TM)

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        # the conversion is done once; the tests only inspect the result
        cls.do_conversion()

    def test_can_transform_vocab(self):
        for word in self.vocab.vocab:
            with self.subTest(word):
                self.assertEqual(len(self.vocab.vec(word)), self.TARGET_SIZE)

    def test_can_transform_cdb(self):
        for cui, per_type_vectors in self.cdb.cui2context_vectors.items():
            for cv_type, context_vec in per_type_vectors.items():
                with self.subTest(f"{cui}-{cv_type}"):
                    self.assertEqual(len(context_vec), self.TARGET_SIZE)
|
||
|
||
class OverallTransformationTests(VocabTransformationTests):
    """Runs the inherited size checks after the single-call conversion."""

    @classmethod
    def do_conversion(cls):
        """Convert the Vocab and CDB via `convert_vocab_vector_size`."""
        cdb, vocab = cls.cdb, cls.vocab
        vocab_utils.convert_vocab_vector_size(cdb, vocab, cls.TARGET_SIZE)