Skip to content

Commit e4763f4

Browse files
authored
added function count_parameters
1 parent 16f7c9d commit e4763f4

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

utils_modified.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@
44
def q():
55
sys.exit()
66

7+
# define a function to count the total number of trainable parameters
8+
def count_parameters(model):
9+
num_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
10+
return num_parameters/1e6 # in terms of millions
11+
712
# TEST
813
def nearest_word(inp, emb, top = 5, debug = False):
914
euclidean_dis = np.linalg.norm(inp - emb, axis = 1)
@@ -18,4 +23,4 @@ def nearest_word(inp, emb, top = 5, debug = False):
1823
print('emb_ranking: ', emb_ranking)
1924
print(f'top {top} embeddings are: {emb_ranking[:top]} with respective distances\n {euclidean_dis_top}')
2025

21-
return emb_ranking_top, euclidean_dis_top
26+
return emb_ranking_top, euclidean_dis_top

0 commit comments

Comments
 (0)