import warnings

import numpy as np
from typing import Union, List
from sklearn.metrics import pairwise_distances
from gensim.models import KeyedVectors


from whatlies.embedding import Embedding
from whatlies.embeddingset import EmbeddingSet
from whatlies.language.common import SklearnTransformerMixin

class GensimLanguage(SklearnTransformerMixin):
    """
    This object is used to lazily fetch [Embedding][whatlies.embedding.Embedding]s or
    [EmbeddingSet][whatlies.embeddingset.EmbeddingSet]s from a keyed vector file.
    These files are generated by [gensim](https://radimrehurek.com/gensim/models/word2vec.html).
    This object is meant for retrieval, not plotting.

    Important:
        The vectors are not provided by this library; they must be downloaded or created upfront.
        A potential benefit of this is that you can train your own embeddings using
        gensim and visualise them using this library.

    Here's a snippet that you can use to train your own (very limited) word2vec embeddings.

    ```
    from gensim.test.utils import common_texts
    from gensim.models import Word2Vec
    model = Word2Vec(common_texts, size=10, window=5, min_count=1, workers=4)
    model.wv.save("wordvectors.kv")
    ```

    Note that if a word is not available in the keyed vectors file then we'll assume
    a zero vector. If you pass a sentence then we'll add together the embedding vectors
    of the separate words.
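
    For instance (a small sketch, assuming the `wordvectors.kv` file trained in the snippet above):

    ```python
    > lang = GensimLanguage("wordvectors.kv")
    > # "human computer" contains a space, so the vectors of "human" and
    > # "computer" are added together into a single embedding
    > lang['human computer'].vector
    ```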

    Arguments:
        keyedfile: name of the model to load, be sure that it's downloaded or trained beforehand

    **Usage**:

    ```python
    > from whatlies.language import GensimLanguage
    > lang = GensimLanguage("wordvectors.kv")
    > lang['computer']
    > lang[['computer', 'human', 'dog']]
    ```
    """

    def __init__(self, keyedfile):
        self.kv = KeyedVectors.load(keyedfile)

    def __getitem__(self, query: Union[str, List[str]]):
        """
        Retrieve a single embedding or a set of embeddings.

        Arguments:
            query: single string or list of strings

        **Usage**
        ```python
        > from whatlies.language import GensimLanguage
        > lang = GensimLanguage("wordvectors.kv")
        > lang['computer']
        > lang[['computer', 'human', 'dog']]
        ```
        """
        if isinstance(query, str):
            # a sentence: add together the embedding vectors of the individual words
            if " " in query:
                return Embedding(
                    query, np.sum([self[q].vector for q in query.split(" ")], axis=0)
                )
            # a single word: fall back to a zero vector when it is out of vocabulary
            try:
                vec = np.sum([self.kv[q] for q in query.split(" ")], axis=0)
            except KeyError:
                vec = np.zeros(self.kv.vector_size)
            return Embedding(query, vec)
        return EmbeddingSet(*[self[tok] for tok in query])

    def _prepare_queries(self, lower):
        # fetch every token in the keyed vectors vocabulary, optionally keeping lowercase tokens only
        queries = [w for w in self.kv.vocab.keys()]
        if lower:
            queries = [w for w in queries if w.lower() == w]
        return queries

    def _calculate_distances(self, emb, queries, metric):
        vec = emb.vector
        vector_matrix = np.array([self[w].vector for w in queries])
        # some vectors come back as NaN (reason still to be investigated);
        # replace them with zero vectors so the distance calculation does not fail
        vector_matrix = np.array(
            [np.zeros(v.shape) if np.any(np.isnan(v)) else v for v in vector_matrix]
        )
        return pairwise_distances(vector_matrix, vec.reshape(1, -1), metric=metric)

    def score_similar(
        self, emb: Union[str, Embedding], n: int = 10, metric="cosine", lower=False,
    ) -> List:
        """
        Retrieve a list of (Embedding, score) tuples that are the most similar to the passed query.

        Arguments:
            emb: query to use
            n: the number of items you'd like to see returned
            metric: metric to use to calculate distance, must be scipy or sklearn compatible
            lower: only fetch lower case tokens

        Returns:
            A list of ([Embedding][whatlies.embedding.Embedding], score) tuples.
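
        **Usage** (a minimal sketch, assuming the `wordvectors.kv` file trained in the class docstring):

        ```python
        > from whatlies.language import GensimLanguage
        > lang = GensimLanguage("wordvectors.kv")
        > lang.score_similar("computer", n=5)
        ```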
        """
        if isinstance(emb, str):
            emb = self[emb]

        queries = self._prepare_queries(lower=lower)
        distances = self._calculate_distances(emb=emb, queries=queries, metric=metric)
        by_similarity = sorted(zip(queries, distances), key=lambda z: z[1])

        if len(queries) < n:
            warnings.warn(
                f"We could only find {len(queries)} feasible words. Consider changing `n` or `lower`.",
                UserWarning,
            )

        return [(self[q], float(d)) for q, d in by_similarity[:n]]

    def embset_similar(
        self, emb: Union[str, Embedding], n: int = 10, lower=False, metric="cosine",
    ) -> EmbeddingSet:
        """
        Retrieve an [EmbeddingSet][whatlies.embeddingset.EmbeddingSet] with the embeddings that are most similar to the passed query.

        Arguments:
            emb: query to use
            n: the number of items you'd like to see returned
            metric: metric to use to calculate distance, must be scipy or sklearn compatible
            lower: only fetch lower case tokens

        Returns:
            An [EmbeddingSet][whatlies.embeddingset.EmbeddingSet] containing the similar embeddings.
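
        **Usage** (a minimal sketch, assuming the `wordvectors.kv` file trained in the class docstring):

        ```python
        > from whatlies.language import GensimLanguage
        > lang = GensimLanguage("wordvectors.kv")
        > lang.embset_similar("computer", n=5)
        ```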
        """
        embs = [
            w[0] for w in self.score_similar(emb=emb, n=n, lower=lower, metric=metric)
        ]
        return EmbeddingSet({w.name: w for w in embs})