
Commit d079702
Merge branch 'master' into faq
2 parents: e18fab5 + d91f3b6

File tree: 10 files changed, +189 -1 lines

.gitignore (+3 -1)

@@ -143,5 +143,7 @@ s2v_old/
 *.ipynb
 cc.*.bin
 tests/*.bin
-node_modules
+
 /package-lock.json
+/node_modules/
+

docs/api/language/gensim_lang.md (+3, new file)

@@ -0,0 +1,3 @@
+# `whatlies.language.GensimLanguage`
+
+::: whatlies.language.GensimLanguage

mkdocs.yml (+1)

@@ -30,6 +30,7 @@ nav:
       - fasttext: api/language/fasttext_lang.md
       - CountVector: api/language/countvector_lang.md
       - BPEmbLang: api/language/bpemb_lang.md
+      - Gensim: api/language/gensim_lang.md
   - Roadmap: roadmap.md
 plugins:
   - mkdocstrings

setup.py (+1)

@@ -13,6 +13,7 @@
     "sense2vec>=1.0.2",
     "fasttext>=0.9.1",
     "bpemb>=0.3.0",
+    "gensim>=3.8.3",
 ]

 docs_packages = [

tests/cache/custom_gensim_vectors.kv (1.7 KB)

Binary file not shown.

tests/prepare_gensim_kv.py (+5, new file)

@@ -0,0 +1,5 @@
+from gensim.test.utils import common_texts
+from gensim.models import Word2Vec
+
+model = Word2Vec(common_texts, size=10, window=5, min_count=1, workers=4)
+model.wv.save("tests/cache/custom_gensim_vectors.kv")

File renamed without changes.
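For context, the fixture written by this script can be sanity-checked by loading it back with gensim's `KeyedVectors` — a minimal sketch, assuming gensim 3.x (the API the `size=` keyword above belongs to):

```python
from gensim.models import KeyedVectors

# Load the keyed vectors saved by tests/prepare_gensim_kv.py.
kv = KeyedVectors.load("tests/cache/custom_gensim_vectors.kv")
print(kv.vector_size)        # 10, the `size` passed to Word2Vec above
print(kv["computer"].shape)  # (10,); "computer" occurs in gensim's common_texts
```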

tests/test_lang/test_gensim.py (+27, new file)

@@ -0,0 +1,27 @@
+import pytest
+
+from whatlies.language import GensimLanguage
+
+
+@pytest.fixture()
+def lang():
+    return GensimLanguage("tests/cache/custom_gensim_vectors.kv")
+
+
+def test_missing_retrieval(lang):
+    assert lang["doesnotexist"].vector.shape == (10,)
+
+
+def test_spaces_retrieval(lang):
+    assert lang["graph trees"].vector.shape == (10,)
+    assert lang["graph trees dog"].vector.shape == (10,)
+
+
+def test_single_token_words(lang):
+    assert lang["computer"].vector.shape == (10,)
+    assert len(lang[["red", "blue"]]) == 2
+
+
+def test_similar_retrieval(lang):
+    assert len(lang.score_similar("hi", 10)) == 10
+    assert len(lang.embset_similar("hi", 10)) == 10
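The first test leans on the out-of-vocabulary fallback implemented in `GensimLanguage.__getitem__` (see the new module below): unknown tokens yield a zero vector of the model's dimensionality. A minimal sketch of that behaviour, assuming the fixture file above exists:

```python
import numpy as np
from whatlies.language import GensimLanguage

lang = GensimLanguage("tests/cache/custom_gensim_vectors.kv")
emb = lang["doesnotexist"]  # token absent from the vocabulary
# The fallback is a zero vector with the model's dimensionality (10 here).
assert emb.vector.shape == (10,)
assert np.allclose(emb.vector, 0.0)
```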

whatlies/language/__init__.py (+2)

@@ -3,11 +3,13 @@
 from .fasttext_lang import FasttextLanguage
 from .countvector_lang import CountVectorLanguage
 from .bpemblang import BytePairLang
+from .gensim_lang import GensimLanguage

 __all__ = [
     "SpacyLanguage",
     "Sense2VecLanguage",
     "FasttextLanguage",
     "CountVectorLanguage",
     "BytePairLang",
+    "GensimLanguage",
 ]
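With the re-export in place, both import paths resolve to the same class; a trivial check, assuming the package is installed:

```python
from whatlies.language import GensimLanguage
from whatlies.language.gensim_lang import GensimLanguage as DirectGensimLanguage

# The package-level name is the same object as the module-level one.
assert GensimLanguage is DirectGensimLanguage
```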

whatlies/language/gensim_lang.py (+147, new file)

@@ -0,0 +1,147 @@
+import warnings
+
+import numpy as np
+from typing import Union, List
+from sklearn.metrics import pairwise_distances
+from gensim.models import KeyedVectors
+
+
+from whatlies.embedding import Embedding
+from whatlies.embeddingset import EmbeddingSet
+from whatlies.language.common import SklearnTransformerMixin
+
+
+class GensimLanguage(SklearnTransformerMixin):
+    """
+    This object is used to lazily fetch [Embedding][whatlies.embedding.Embedding]s or
+    [EmbeddingSet][whatlies.embeddingset.EmbeddingSet]s from a keyed vector file.
+    These files are generated by [gensim](https://radimrehurek.com/gensim/models/word2vec.html).
+    This object is meant for retrieval, not plotting.
+
+    Important:
+        The vectors are not provided by this library; they must be downloaded or created upfront.
+        A potential benefit of this is that you can train your own embeddings using
+        gensim and visualise them using this library.
+
+        Here's a snippet that you can use to train your own (very limited) word2vec embeddings.
+
+        ```
+        from gensim.test.utils import common_texts
+        from gensim.models import Word2Vec
+        model = Word2Vec(common_texts, size=10, window=5, min_count=1, workers=4)
+        model.wv.save("wordvectors.kv")
+        ```
+
+        Note that if a word is not available in the keyed vectors file then we'll assume
+        a zero vector. If you pass a sentence then we'll add together the embedding vectors
+        of the separate words.
+
+    Arguments:
+        keyedfile: name of the model to load; be sure that it's downloaded or trained beforehand
+
+    **Usage**:
+
+    ```python
+    > from whatlies.language import GensimLanguage
+    > lang = GensimLanguage("wordvectors.kv")
+    > lang['computer']
+    > lang = GensimLanguage("wordvectors.kv")
+    > lang[['computer', 'human', 'dog']]
+    ```
+    """
+
+    def __init__(self, keyedfile):
+        self.kv = KeyedVectors.load(keyedfile)
+
+    def __getitem__(self, query: Union[str, List[str]]):
+        """
+        Retrieve a single embedding or a set of embeddings.
+
+        Arguments:
+            query: single string or list of strings
+
+        **Usage**
+        ```python
+        > from whatlies.language import GensimLanguage
+        > lang = GensimLanguage("wordvectors.kv")
+        > lang['computer']
+        > lang = GensimLanguage("wordvectors.kv")
+        > lang[['computer', 'human', 'dog']]
+        ```
+        """
+        if isinstance(query, str):
+            if " " in query:
+                return Embedding(
+                    query, np.sum([self[q].vector for q in query.split(" ")], axis=0)
+                )
+            try:
+                vec = np.sum([self.kv[q] for q in query.split(" ")], axis=0)
+            except KeyError:
+                vec = np.zeros(self.kv.vector_size)
+            return Embedding(query, vec)
+        return EmbeddingSet(*[self[tok] for tok in query])
+
+    def _prepare_queries(self, lower):
+        queries = [w for w in self.kv.vocab.keys()]
+        if lower:
+            queries = [w for w in queries if w.lower() == w]
+        return queries
+
+    def _calculate_distances(self, emb, queries, metric):
+        vec = emb.vector
+        vector_matrix = np.array([self[w].vector for w in queries])
+        # some vectors come back as NaNs; worth investigating later why that happens
+        vector_matrix = np.array(
+            [np.zeros(v.shape) if np.any(np.isnan(v)) else v for v in vector_matrix]
+        )
+        return pairwise_distances(vector_matrix, vec.reshape(1, -1), metric=metric)
+
+    def score_similar(
+        self, emb: Union[str, Embedding], n: int = 10, metric="cosine", lower=False,
+    ) -> List:
+        """
+        Retrieve a list of (Embedding, score) tuples that are the most similar to the passed query.
+
+        Arguments:
+            emb: query to use
+            n: the number of items you'd like to see returned
+            metric: metric to use to calculate distance, must be scipy or sklearn compatible
+            lower: only fetch lower case tokens
+
+        Returns:
+            A list of ([Embedding][whatlies.embedding.Embedding], score) tuples.
+        """
+        if isinstance(emb, str):
+            emb = self[emb]
+
+        queries = self._prepare_queries(lower=lower)
+        distances = self._calculate_distances(emb=emb, queries=queries, metric=metric)
+        by_similarity = sorted(zip(queries, distances), key=lambda z: z[1])
+
+        if len(queries) < n:
+            warnings.warn(
+                f"We could only find {len(queries)} feasible words. Consider changing `n` or `lower`",
+                UserWarning,
+            )
+
+        return [(self[q], float(d)) for q, d in by_similarity[:n]]
+
+    def embset_similar(
+        self, emb: Union[str, Embedding], n: int = 10, lower=False, metric="cosine",
+    ) -> EmbeddingSet:
+        """
+        Retrieve an [EmbeddingSet][whatlies.embeddingset.EmbeddingSet] of the embeddings most similar to the passed query.
+
+        Arguments:
+            emb: query to use
+            n: the number of items you'd like to see returned
+            metric: metric to use to calculate distance, must be scipy or sklearn compatible
+            lower: only fetch lower case tokens
+
+        Returns:
+            An [EmbeddingSet][whatlies.embeddingset.EmbeddingSet] containing the similar embeddings.
+        """
+        embs = [
+            w[0] for w in self.score_similar(emb=emb, n=n, lower=lower, metric=metric)
+        ]
+        return EmbeddingSet({w.name: w for w in embs})
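Putting the pieces of this commit together, here is a hedged end-to-end sketch: train the tiny word2vec model from the docstring, then query it through the new class. It assumes gensim 3.x (matching the `gensim>=3.8.3` pin in setup.py, where `Word2Vec` still takes `size=`):

```python
from gensim.test.utils import common_texts
from gensim.models import Word2Vec

from whatlies.language import GensimLanguage

# Train and save a tiny word2vec model, as in tests/prepare_gensim_kv.py.
model = Word2Vec(common_texts, size=10, window=5, min_count=1, workers=4)
model.wv.save("wordvectors.kv")

lang = GensimLanguage("wordvectors.kv")
print(lang["computer"].vector.shape)        # (10,)
print(lang["computer human"].vector.shape)  # sentences sum their word vectors: (10,)

# Nearest neighbours by cosine distance (lower scores are more similar).
for emb, score in lang.score_similar("computer", n=3):
    print(emb.name, round(score, 3))
```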
