
Commit fde76e8

refactored code
1 parent 7b8fe80 commit fde76e8

10 files changed: +529 −652 lines

changes_to_be_made.txt (−2 lines)

This file was deleted.

config.py (new file, +54 lines)
import os
import torch

DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print('DEVICE: ', DEVICE)

DATA_SOURCE = 'toy'          # 'toy' or 'gensim'
MODEL_ID = DATA_SOURCE
DISPLAY_BATCH_LOSS = True

if DATA_SOURCE == 'toy':
    DISPLAY_EVERY_N_BATCH = 5000
    SAVE_EVERY_N_EPOCH = 100
    BATCH_SIZE = 32
    NUM_EPOCHS = int(1e+3)

    CONTEXT_SIZE = 3
    FRACTION_DATA = 1
    SUBSAMPLING = False
    SAMPLING_RATE = 0.001
    NEGATIVE_SAMPLES = 0     # set to 0 to disable negative sampling

    EMBEDDING_DIM = 3
    LR = 0.001

    TEST_WORDS = ['word1', 'word3', 'word6', 'word13', 'word14']
    TEST_WORDS_VIZ = ['word1', 'word2', 'word3', 'word4', 'word5', 'word6', 'word7', 'word8',
                      'word9', 'word10', 'word11', 'word12', 'word13', 'word14', 'word15']

elif DATA_SOURCE == 'gensim':
    DISPLAY_EVERY_N_BATCH = 1000
    SAVE_EVERY_N_EPOCH = 1
    BATCH_SIZE = 1024*16
    NUM_EPOCHS = 10

    CONTEXT_SIZE = 5
    FRACTION_DATA = 0.1      # fraction of the text8 corpus to use
    SUBSAMPLING = True
    SAMPLING_RATE = 0.001
    NEGATIVE_SAMPLES = 10    # set to 0 to disable negative sampling

    EMBEDDING_DIM = 64
    LR = 0.0011

    if FRACTION_DATA == 1:
        TEST_WORDS = ['india', 'computer', 'gold', 'football', 'cars', 'war', 'apple', 'music', 'helicopter']
        TEST_WORDS_VIZ = ['india', 'asia', 'guitar', 'piano', 'album', 'music', 'war', 'soldiers', 'helicopter']
    else:
        TEST_WORDS = ['human', 'boy', 'office', 'woman']
        TEST_WORDS_VIZ = TEST_WORDS

PREPROCESSED_DATA_DIR = os.path.join(MODEL_ID, 'preprocessed_data')
PREPROCESSED_DATA_PATH = os.path.join(PREPROCESSED_DATA_DIR, 'preprocessed_data_' + MODEL_ID + '_' + str(FRACTION_DATA) + '.pickle')
SUMMARY_DIR = os.path.join(MODEL_ID, 'summary')
MODEL_DIR = os.path.join(MODEL_ID, 'models')
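
For a quick sanity check, the paths derived from the active 'toy' configuration resolve as below (a minimal sketch, assuming config.py is on the import path and a POSIX-style path separator):

import config   # prints DEVICE as a side effect of the import

print(config.MODEL_ID)                # 'toy'
print(config.PREPROCESSED_DATA_PATH)  # 'toy/preprocessed_data/preprocessed_data_toy_1.pickle'
print(config.SUMMARY_DIR)             # 'toy/summary'
print(config.MODEL_DIR)               # 'toy/models'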

datasets.py (new file, +152 lines)
from __future__ import print_function

import nltk
nltk.download('punkt')
nltk.download('stopwords')
nltk.download('wordnet')

from nltk.stem.wordnet import WordNetLemmatizer
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords

import numpy as np
import torch
from torch.utils.data import Dataset

from tqdm import tqdm

from utils_modified import q

lem = WordNetLemmatizer()


class word2vec_dataset(Dataset):
    def __init__(self, DATA_SOURCE, CONTEXT_SIZE, FRACTION_DATA, SUBSAMPLING, SAMPLING_RATE):
        print("Parsing text and loading training data...")
        vocab, word_to_ix, ix_to_word, training_data = self.load_data(DATA_SOURCE, CONTEXT_SIZE, FRACTION_DATA, SUBSAMPLING, SAMPLING_RATE)

        self.vocab = vocab
        self.word_to_ix = word_to_ix
        self.ix_to_word = ix_to_word

        # training_data is a list of [center_idx, context_idx] pairs
        self.data = torch.tensor(training_data, dtype=torch.long)

    def __getitem__(self, index):
        x = self.data[index, 0]   # center word index
        y = self.data[index, 1]   # context word index
        return x, y

    def __len__(self):
        return len(self.data)

    def gather_training_data(self, split_text, word_to_ix, context_size):
        training_data = []

        # for each sentence
        print('preparing training data (x, y)...')
        for sentence in tqdm(split_text):
            indices = [word_to_ix[word] for word in sentence]

            # for each word treated as center word
            for center_word_pos in range(len(indices)):

                # for each window position
                for w in range(-context_size, context_size + 1):
                    context_word_pos = center_word_pos + w

                    # make sure we don't jump out of the sentence
                    if context_word_pos < 0 or context_word_pos >= len(indices) or center_word_pos == context_word_pos:
                        continue

                    context_word_idx = indices[context_word_pos]
                    center_word_idx = indices[center_word_pos]

                    # the same word may appear within its own context window; skip such pairs
                    if center_word_idx == context_word_idx:
                        continue

                    training_data.append([center_word_idx, context_word_idx])

        return training_data

    def load_data(self, data_source, context_size, fraction_data, subsampling, sampling_rate):
        stop_words = set(stopwords.words('english'))

        if data_source == 'toy':
            sents = [
                'word1 word2 word3 word4 word5',
                'word6 word7 word8 word9 word10',
                'word11 word12 word13 word14 word15'
            ]

        elif data_source == 'gensim':
            import gensim.downloader as api
            dataset = api.load("text8")
            data = list(dataset)
            data = data[:int(fraction_data * len(data))]
            print(f'fraction of data taken: {fraction_data}/1')

            sents = []
            print('forming sentences by joining tokenized words...')
            for d in tqdm(data):
                sents.append(' '.join(d))

        sent_list_tokenized = [word_tokenize(s) for s in sents]
        print('len(sent_list_tokenized): ', len(sent_list_tokenized))

        # lemmatize and remove the stopwords
        sent_list_tokenized_filtered = []
        print('lemmatizing and removing stopwords...')
        for s in tqdm(sent_list_tokenized):
            sent_list_tokenized_filtered.append([lem.lemmatize(w, 'v') for w in s if w not in stop_words])

        sent_list_tokenized_filtered, vocab, word_to_ix, ix_to_word = self.gather_word_freqs(sent_list_tokenized_filtered, subsampling, sampling_rate)

        training_data = self.gather_training_data(sent_list_tokenized_filtered, word_to_ix, context_size)

        return vocab, word_to_ix, ix_to_word, training_data

    def gather_word_freqs(self, split_text, subsampling, sampling_rate):   # split_text is a list of tokenized sentences
        vocab = {}
        ix_to_word = {}
        word_to_ix = {}
        total = 0.0

        print('building vocab...')
        for word_tokens in tqdm(split_text):
            for word in word_tokens:            # every token, possibly repeated
                if word not in vocab:           # only new words get an index
                    vocab[word] = 0
                    ix_to_word[len(word_to_ix)] = word
                    word_to_ix[word] = len(word_to_ix)
                vocab[word] += 1.0              # per-word count
                total += 1.0                    # total number of tokens in the corpus

        print('\nsubsampling: ', subsampling)
        if subsampling:
            print('performing subsampling...')
            # drop frequent words with probability 1 - sqrt(sampling_rate / frequency);
            # rebuilding each sentence avoids deleting from a list while iterating over it
            for sent_idx, word_tokens in enumerate(tqdm(split_text)):
                kept = []
                for word in word_tokens:
                    frac = vocab[word] / total
                    drop_prob = 1 - np.sqrt(sampling_rate / frac)
                    if np.random.sample() >= drop_prob:
                        kept.append(word)
                split_text[sent_idx] = kept

        return split_text, vocab, word_to_ix, ix_to_word
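
To make the (center, context) pairing and the subsampling rule concrete, here is a small standalone sketch (it re-implements the two ideas outside the class; the sentence, context size, and frequency value are made up for illustration):

import numpy as np

# skip-gram pairs for one tokenized sentence, mirroring gather_training_data
sentence = ['word1', 'word2', 'word3', 'word4', 'word5']
word_to_ix = {w: i for i, w in enumerate(sentence)}
context_size = 2

pairs = []
indices = [word_to_ix[w] for w in sentence]
for center in range(len(indices)):
    for w in range(-context_size, context_size + 1):
        ctx = center + w
        if ctx < 0 or ctx >= len(indices) or ctx == center:
            continue
        pairs.append((indices[center], indices[ctx]))

print(pairs[:4])   # [(0, 1), (0, 2), (1, 0), (1, 2)] -- (center, context) index pairs

# subsampling: a word is dropped with probability 1 - sqrt(sampling_rate / frequency)
freq = 0.05                            # suppose one word makes up 5% of all tokens
drop_prob = 1 - np.sqrt(0.001 / freq)  # SAMPLING_RATE = 0.001
print(round(drop_prob, 3))             # 0.859, i.e. that word is kept only ~14% of the time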

main.py (new file, +182 lines)
from __future__ import print_function
from tqdm import tqdm
import matplotlib
# matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import sys, pdb, os, shutil, pickle
from pprint import pprint

import torch
import torch.optim as optim
import torch.nn as nn

# Importing SummaryWriter can be tricky depending on the installed pytorch version.
# If torch.utils.tensorboard is importable, this script records summaries; otherwise it skips them.
try:
    from torch.utils.tensorboard import SummaryWriter
    write_summary = True
except ImportError:
    write_summary = False

from model import Word2Vec_neg_sampling
from utils_modified import count_parameters
from datasets import word2vec_dataset
from config import *
from test import print_nearest_words
from utils_modified import q

# for tensorboard to work properly on embedding projections
import tensorflow as tf
import tensorboard as tb
tf.io.gfile = tb.compat.tensorflow_stub.io.gfile

# remove MODEL_DIR if it exists, then recreate it
if os.path.exists(MODEL_DIR):
    shutil.rmtree(MODEL_DIR)
os.makedirs(MODEL_DIR)

# SUMMARY_DIR is the directory where the tensorboard SummaryWriter files are written
if write_summary:
    # the directory is removed if it already exists
    if os.path.exists(SUMMARY_DIR):
        shutil.rmtree(SUMMARY_DIR)

    writer = SummaryWriter(SUMMARY_DIR)   # this call creates the directory at SUMMARY_DIR
    summary_counter = 0

# make training data (or load the cached preprocessed dataset)
if not os.path.exists(PREPROCESSED_DATA_PATH):
    train_dataset = word2vec_dataset(DATA_SOURCE, CONTEXT_SIZE, FRACTION_DATA, SUBSAMPLING, SAMPLING_RATE)

    if not os.path.exists(PREPROCESSED_DATA_DIR):
        os.makedirs(PREPROCESSED_DATA_DIR)

    # pickle dump
    print('\ndumping pickle...')
    with open(PREPROCESSED_DATA_PATH, 'wb') as outfile:
        pickle.dump(train_dataset, outfile)
    print('pickle dumped\n')

else:
    # pickle load
    print('\nloading pickle...')
    with open(PREPROCESSED_DATA_PATH, 'rb') as infile:
        train_dataset = pickle.load(infile)
    print('pickle loaded\n')

vocab = train_dataset.vocab
word_to_ix = train_dataset.word_to_ix
ix_to_word = train_dataset.ix_to_word

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False)
print('len(train_dataset): ', len(train_dataset))
print('len(train_loader): ', len(train_loader))
print('len(vocab): ', len(vocab), '\n')

# make the noise distribution to sample negative examples from:
# the unigram distribution raised to the 3/4 power
word_freqs = np.array(list(vocab.values()))
unigram_dist = word_freqs / sum(word_freqs)
noise_dist = torch.from_numpy(unigram_dist**0.75 / np.sum(unigram_dist**0.75))

losses = []

model = Word2Vec_neg_sampling(EMBEDDING_DIM, len(vocab), DEVICE, noise_dist, NEGATIVE_SAMPLES).to(DEVICE)
print('\nThe model has {} million trainable parameters'.format(count_parameters(model)))

# optimizer = optim.SGD(model.parameters(), lr=0.008, momentum=0.9)
optimizer = optim.Adam(model.parameters(), lr=LR)

for epoch in tqdm(range(NUM_EPOCHS)):
    print('\n===== EPOCH {}/{} ====='.format(epoch + 1, NUM_EPOCHS))

    for batch_idx, (x_batch, y_batch) in enumerate(train_loader):
        print('batch# ' + str(batch_idx + 1).zfill(len(str(len(train_loader)))) + '/' + str(len(train_loader)), end='\r')

        model.train()

        x_batch = x_batch.to(DEVICE)
        y_batch = y_batch.to(DEVICE)

        optimizer.zero_grad()
        loss = model(x_batch, y_batch)

        loss.backward()
        optimizer.step()

        losses.append(loss.item())
        if write_summary:
            # write tensorboard summaries
            writer.add_scalar('batch_loss', loss.item(), summary_counter)
            summary_counter += 1

        if batch_idx % DISPLAY_EVERY_N_BATCH == 0 and DISPLAY_BATCH_LOSS:
            print(f'Batch: {batch_idx+1}/{len(train_loader)}, Loss: {loss.item()}')
            # show the 5 closest words to some test words
            print_nearest_words(model, TEST_WORDS, word_to_ix, ix_to_word, top=5)

    # write embeddings and save a checkpoint every SAVE_EVERY_N_EPOCH epochs
    if epoch % SAVE_EVERY_N_EPOCH == 0:
        if write_summary:
            writer.add_embedding(model.embeddings_input.weight.data,
                                 metadata=[ix_to_word[k] for k in range(len(ix_to_word))],
                                 global_step=epoch)

        torch.save({'model_state_dict': model.state_dict(),
                    'losses': losses,
                    'word_to_ix': word_to_ix,
                    'ix_to_word': ix_to_word},
                   '{}/model{}.pth'.format(MODEL_DIR, epoch))

plt.figure(figsize=(50, 50))
plt.xlabel("batches")
plt.ylabel("batch_loss")
plt.title("loss vs #batch")

plt.plot(losses)
plt.savefig('losses.png')
plt.show()

EMBEDDINGS = model.embeddings_input.weight.data
print('EMBEDDINGS.shape: ', EMBEDDINGS.shape)

from sklearn.manifold import TSNE

print('\n', 'running TSNE...')
tsne = TSNE(n_components=2).fit_transform(EMBEDDINGS.cpu())
print('tsne.shape: ', tsne.shape)   # (len(vocab), 2)

############ VISUALIZING ############
x, y = [], []
annotations = []
for idx, coord in enumerate(tsne):
    annotations.append(ix_to_word[idx])
    x.append(coord[0])
    y.append(coord[1])

test_words = TEST_WORDS_VIZ
print('test_words: ', test_words)

plt.figure(figsize=(50, 50))
for word in test_words:
    vocab_idx = word_to_ix[word]
    plt.scatter(x[vocab_idx], y[vocab_idx])
    plt.annotate(word, xy=(x[vocab_idx], y[vocab_idx]), ha='right', va='bottom')

plt.savefig("w2v.png")
plt.show()
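
The negative-sampling model itself (model.py) is not part of this diff, but the noise_dist built above is the usual unigram distribution raised to the 3/4 power. A rough sketch of how negatives could be drawn from it (toy counts and shapes are assumptions, not the repo's actual sampling code):

import numpy as np
import torch

# toy word counts -> unigram distribution raised to the 3/4 power, as in main.py
word_freqs = np.array([50., 30., 15., 5.])
unigram_dist = word_freqs / word_freqs.sum()
noise_dist = torch.from_numpy(unigram_dist**0.75 / np.sum(unigram_dist**0.75))
print(noise_dist)   # rare words get a relatively larger share than under raw counts

# draw NEGATIVE_SAMPLES noise words for each (center, context) pair in a batch
batch_size, negative_samples = 4, 10
noise_words = torch.multinomial(noise_dist, batch_size * negative_samples, replacement=True)
print(noise_words.shape)   # torch.Size([40]), typically reshaped to (batch_size, negative_samples)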

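Each checkpoint saved above bundles the state dict with word_to_ix / ix_to_word, so it can be reloaded later without rebuilding the dataset. A minimal sketch, assuming the 'toy' config and epoch 0 (the path and map_location are illustrative):

import torch

ckpt = torch.load('toy/models/model0.pth', map_location='cpu')
print(ckpt.keys())   # dict_keys(['model_state_dict', 'losses', 'word_to_ix', 'ix_to_word'])

word_to_ix = ckpt['word_to_ix']
ix_to_word = ckpt['ix_to_word']
# model = Word2Vec_neg_sampling(EMBEDDING_DIM, len(word_to_ix), DEVICE, noise_dist, NEGATIVE_SAMPLES)
# model.load_state_dict(ckpt['model_state_dict'])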