from __future__ import print_function
from tqdm import tqdm
# from tqdm import tqdm_gui
import matplotlib
# matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import sys, pdb, os, shutil, pickle
from pprint import pprint

import torch
import torch.optim as optim
import torch.nn as nn

# Getting SummaryWriter to work can be a little tricky, since it needs a suitable PyTorch
# version. If SummaryWriter can be imported from torch.utils.tensorboard, this script will
# record tensorboard summaries; otherwise it will not.
try:
    from torch.utils.tensorboard import SummaryWriter
    write_summary = True
except ImportError:
    write_summary = False

from model import Word2Vec_neg_sampling
from utils_modified import count_parameters
from datasets import word2vec_dataset
from config import *
from test import print_nearest_words
from utils_modified import q

# needed for the tensorboard embedding projector to work properly
import tensorflow as tf
import tensorboard as tb
tf.io.gfile = tb.compat.tensorflow_stub.io.gfile
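# (The patch above is the commonly used workaround for a TensorFlow/TensorBoard version clash:
# when a full TensorFlow install shadows tf.io.gfile, the embedding projector can fail to write
# its files, so gfile is pointed at TensorBoard's stub implementation instead.)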

# remove MODEL_DIR if it exists
if os.path.exists(MODEL_DIR):
    shutil.rmtree(MODEL_DIR)
# create MODEL_DIR
os.makedirs(MODEL_DIR)

# SUMMARY_DIR is the directory where the tensorboard SummaryWriter files are written
if write_summary:
    if os.path.exists(SUMMARY_DIR):
        # the directory is removed if it already exists
        shutil.rmtree(SUMMARY_DIR)

    writer = SummaryWriter(SUMMARY_DIR)  # this automatically creates the directory at SUMMARY_DIR
    summary_counter = 0
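    # Once training is running, the scalars and embeddings written here can be inspected with
    # the standard TensorBoard CLI, e.g.:  tensorboard --logdir <SUMMARY_DIR>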

# make training data
if not os.path.exists(PREPROCESSED_DATA_PATH):
    train_dataset = word2vec_dataset(DATA_SOURCE, CONTEXT_SIZE, FRACTION_DATA, SUBSAMPLING, SAMPLING_RATE)

    if not os.path.exists(PREPROCESSED_DATA_DIR):
        os.makedirs(PREPROCESSED_DATA_DIR)

    # pickle dump
    print('\ndumping pickle...')
    outfile = open(PREPROCESSED_DATA_PATH, 'wb')
    pickle.dump(train_dataset, outfile)
    outfile.close()
    print('pickle dumped\n')

else:
    # pickle load
    print('\nloading pickle...')
    infile = open(PREPROCESSED_DATA_PATH, 'rb')
    train_dataset = pickle.load(infile)
    infile.close()
    print('pickle loaded\n')

vocab = train_dataset.vocab
word_to_ix = train_dataset.word_to_ix
ix_to_word = train_dataset.ix_to_word

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False)
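# Each batch is presumably a pair of LongTensors produced by word2vec_dataset: x_batch with
# center-word indices and y_batch with the matching context-word indices, as consumed by the
# training loop below.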
print('len(train_dataset): ', len(train_dataset))
print('len(train_loader): ', len(train_loader))
print('len(vocab): ', len(vocab), '\n')

# make noise distribution to sample negative examples from
word_freqs = np.array(list(vocab.values()))
unigram_dist = word_freqs / sum(word_freqs)
noise_dist = torch.from_numpy(unigram_dist**0.75 / np.sum(unigram_dist**0.75))
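# Raising the unigram distribution to the 0.75 power follows the heuristic from the original
# word2vec negative-sampling paper (Mikolov et al., 2013): it flattens the distribution a bit,
# so frequent words are sampled as negatives somewhat less often and rare words somewhat more.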

losses = []

model = Word2Vec_neg_sampling(EMBEDDING_DIM, len(vocab), DEVICE, noise_dist, NEGATIVE_SAMPLES).to(DEVICE)
print('\nWe have {} million trainable parameters in the model'.format(count_parameters(model)))
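# Word2Vec_neg_sampling (defined in model.py) presumably holds separate input and output
# embedding tables and returns the negative-sampling loss for a batch of (center, context)
# index pairs; only embeddings_input is read out below for checkpoints and visualization.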

# optimizer = optim.SGD(model.parameters(), lr = 0.008, momentum=0.9)
optimizer = optim.Adam(model.parameters(), lr=LR)
# print(model, '\n')

for epoch in tqdm(range(NUM_EPOCHS)):
    print('\n===== EPOCH {}/{} ====='.format(epoch + 1, NUM_EPOCHS))
    # print('\nTRAINING...')

    # model.train()
    for batch_idx, (x_batch, y_batch) in enumerate(train_loader):
        print('batch# ' + str(batch_idx + 1).zfill(len(str(len(train_loader)))) + '/' + str(len(train_loader)), end='\r')

        model.train()

        x_batch = x_batch.to(DEVICE)
        y_batch = y_batch.to(DEVICE)

        optimizer.zero_grad()
        loss = model(x_batch, y_batch)
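        # The forward pass presumably computes the negative-sampling objective directly,
        # roughly -log sigmoid(u_ctx . v_center) - sum_k log sigmoid(-u_neg_k . v_center),
        # with the k negatives drawn from noise_dist.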

        loss.backward()
        optimizer.step()

        losses.append(loss.item())
        if write_summary:
            # write tensorboard summaries
            writer.add_scalar('batch_loss', loss.item(), summary_counter)
            summary_counter += 1

        if batch_idx % DISPLAY_EVERY_N_BATCH == 0 and DISPLAY_BATCH_LOSS:
            print(f'Batch: {batch_idx + 1}/{len(train_loader)}, Loss: {loss.item()}')
            # show 5 closest words to some test words
            print_nearest_words(model, TEST_WORDS, word_to_ix, ix_to_word, top=5)

    # write embeddings every SAVE_EVERY_N_EPOCH epoch
    # (guarded on write_summary so this does not raise a NameError when tensorboard is unavailable)
    if write_summary and epoch % SAVE_EVERY_N_EPOCH == 0:
        writer.add_embedding(model.embeddings_input.weight.data,
                             metadata=[ix_to_word[k] for k in range(len(ix_to_word))],
                             global_step=epoch)

    torch.save({'model_state_dict': model.state_dict(),
                'losses': losses,
                'word_to_ix': word_to_ix,
                'ix_to_word': ix_to_word
                },
               '{}/model{}.pth'.format(MODEL_DIR, epoch))
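# A saved checkpoint can later be restored along these lines (a sketch; it assumes the model is
# rebuilt with the same config values as above):
#   checkpoint = torch.load('{}/model{}.pth'.format(MODEL_DIR, epoch))
#   model.load_state_dict(checkpoint['model_state_dict'])
#   word_to_ix, ix_to_word = checkpoint['word_to_ix'], checkpoint['ix_to_word']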

plt.figure(figsize=(50, 50))
plt.xlabel("batches")
plt.ylabel("batch_loss")
plt.title("loss vs #batch")

plt.plot(losses)
plt.savefig('losses.png')
plt.show()

# '''
EMBEDDINGS = model.embeddings_input.weight.data
print('EMBEDDINGS.shape: ', EMBEDDINGS.shape)

from sklearn.manifold import TSNE

print('\n', 'running TSNE...')
tsne = TSNE(n_components=2).fit_transform(EMBEDDINGS.cpu())
print('tsne.shape: ', tsne.shape)  # (len(vocab), 2)
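# t-SNE projects each EMBEDDING_DIM-dimensional word vector down to a 2-D point for plotting;
# words the model embeds close together should (approximately) land near each other, though
# t-SNE preserves local neighbourhoods much better than global distances.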

############ VISUALIZING ############
x, y = [], []
annotations = []
for idx, coord in enumerate(tsne):
    # print(coord)
    annotations.append(ix_to_word[idx])
    x.append(coord[0])
    y.append(coord[1])

# test_words = ['king', 'queen', 'berlin', 'capital', 'germany', 'palace', 'stays']
# test_words = ['sun', 'moon', 'earth', 'while', 'open', 'run', 'distance', 'energy', 'coal', 'exploit']
# test_words = ['amazing', 'beautiful', 'work', 'breakfast', 'husband', 'hotel', 'quick', 'cockroach']

test_words = TEST_WORDS_VIZ
print('test_words: ', test_words)

plt.figure(figsize=(50, 50))
for i in range(len(test_words)):
    word = test_words[i]
    # print('word: ', word)
    vocab_idx = word_to_ix[word]
    # print('vocab_idx: ', vocab_idx)
    plt.scatter(x[vocab_idx], y[vocab_idx])
    plt.annotate(word, xy=(x[vocab_idx], y[vocab_idx]),
                 ha='right', va='bottom')

plt.savefig("w2v.png")
plt.show()
# '''