 import kenlm
-from request import ModelRequest
+from request import ModelRequest, ModelUpdateRequest
 import Levenshtein
 
+from symspellpy import SymSpell, Verbosity
+
+from collections import Counter
+
 model_paths = {
     'ory': '5gram_model.bin',
     'eng': '5gram_model_eng.bin'
...
     'eng': 'lexicon_eng.txt'
 }
 
+freq_dict_paths = {
+    'ory': 'freq_dict.txt',
+    'eng': 'freq_dict_eng.txt'
+}
+
 
 class TextCorrector:
-    def __init__(self, model_paths, vocab_paths):
+    def __init__(self, model_paths, vocab_paths, freq_dict_paths):
         # Initialize both models and vocabularies
         self.models = {
             'ory': kenlm.Model(model_paths['ory']),
@@ -24,13 +33,19 @@ def __init__(self, model_paths, vocab_paths):
             'ory': self.create_vocab_lexicon(vocab_paths['ory']),
             'eng': self.create_vocab_lexicon(vocab_paths['eng'])
         }
+
+        self.symspell_models = {
+            'ory': self.create_symspell_model(freq_dict_paths['ory']),
+            'eng': self.create_symspell_model(freq_dict_paths['eng'])
+        }
         # Set the default language
         self.set_language('ory')
 
     def set_language(self, lang):
         # Switch the model and vocabulary based on language
         self.model = self.models[lang]
         self.vocab = self.vocabs[lang]
+        self.symspell_model = self.symspell_models[lang]
 
     def create_vocab_lexicon(self, lexicon_path):
         vocabulary = []
@@ -40,14 +55,23 @@ def create_vocab_lexicon(self, lexicon_path):
             vocabulary.append(word)
         return vocabulary
 
+    def create_symspell_model(self, freq_dict_path):
+        sym_spell = SymSpell(max_dictionary_edit_distance=2, prefix_length=7)
+        sym_spell.load_dictionary(freq_dict_path, term_index=0, count_index=1, separator=' ')
+        return sym_spell
+
+    # def generate_candidates(self, word, max_distance=1):
+    #     len_range = range(len(word) - max_distance, len(word) + max_distance + 1)
+    #     filtered_vocab = [vocab_word for vocab_word in self.vocab if len(vocab_word) in len_range]
+    #     return [vocab_word for vocab_word in filtered_vocab if 0 <= Levenshtein.distance(word, vocab_word) <= max_distance]
+
     def generate_candidates(self, word, max_distance=1):
-        len_range = range(len(word) - max_distance, len(word) + max_distance + 1)
-        filtered_vocab = [vocab_word for vocab_word in self.vocab if len(vocab_word) in len_range]
-        return [vocab_word for vocab_word in filtered_vocab if 0 <= Levenshtein.distance(word, vocab_word) <= max_distance]
+        suggestions = self.symspell_model.lookup(word, Verbosity.CLOSEST, max_distance)
+        return [suggestion.term for suggestion in suggestions]
 
     def beam_search(self, chunk, BEAM_WIDTH=5, SCORE_THRESHOLD=1.5, max_distance=1):
         original_score = self.model.score(' '.join(chunk))
-
+
         initial_candidates = self.generate_candidates(chunk[0], max_distance=1)
         if not initial_candidates:
             initial_candidates = [chunk[0]]
@@ -88,11 +112,55 @@ def correct_text_with_beam_search(self, text, BEAM_WIDTH=5, SCORE_THRESHOLD=1.5,
             corrected_sentences.append(best_sentence)
 
         return ' '.join(corrected_sentences)
+
+    def load_freq_dict(self, freq_dict_path):
+        freq_dict = {}
+        with open(freq_dict_path, 'r') as f:
+            for line in f:
+                word, freq = line.split()
+                freq_dict[word] = int(freq)
+        return freq_dict
+
+    def make_updation_counter(self, text):
+
+        if type(text) == list:
+            text = ' '.join(text)
+
+        # remove punctuation from the text
+        text = ''.join(e for e in text if e.isalnum() or e.isspace())
+        words = text.split()
+
+        # create a dictionary of words and their frequencies
+        dict = Counter(words)
+
+        return dict
+
+    def update_symspell_model(self, lang, text):
+        # update the frequency dictionary
+        current_freq_dict_counter = Counter(self.load_freq_dict(freq_dict_paths[lang]))
+        new_freq_dict_counter = self.make_updation_counter(text)
+
+        # merge the two frequency dictionaries
+        freq_dict_counter = current_freq_dict_counter + new_freq_dict_counter
+
+        freq_dict = {}
+        for word, freq in freq_dict_counter.items():
+            freq_dict[word] = int(freq)
+
+        with open(freq_dict_paths[lang], 'w') as f:
+            for word, freq in freq_dict.items():
+                f.write(word + ' ' + str(freq) + '\n')
+
+        # retrain the model with the updated frequency dictionary
+        self.symspell_models[lang] = self.create_symspell_model(freq_dict_paths[lang])
+
+        return 'Model updated successfully'
+
 
 class Model():
-    def __init__(self, context, model_paths, vocab_paths):
+    def __init__(self, context, model_paths, vocab_paths, freq_dict_paths):
         self.context = context
-        self.text_corrector = TextCorrector(model_paths, vocab_paths)
+        self.text_corrector = TextCorrector(model_paths, vocab_paths, freq_dict_paths)
 
     async def inference(self, request: ModelRequest):
         # Set the correct language model based on the request
@@ -105,3 +173,12 @@ async def inference(self, request: ModelRequest):
             max_distance=request.max_distance
         )
         return corrected_text
+
+    async def update_symspell(self, request: ModelUpdateRequest):
+        # Set the correct language model based on the request
+        self.text_corrector.set_language(request.lang)
+
+        # Update the model with the new data
+        self.text_corrector.update_symspell_model(request.lang, request.text)
+
+        return 'Model updated successfully'
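
For reviewers who want to sanity-check the SymSpell-based candidate generation introduced above, the snippet below is a minimal standalone sketch, not part of this commit. It assumes symspellpy is installed; the file name freq_dict.txt and the sample sentence are purely illustrative.

# Minimal sketch (not part of this commit). Assumes `pip install symspellpy`;
# the path 'freq_dict.txt' and the sample text are illustrative only.
from collections import Counter
from symspellpy import SymSpell, Verbosity

# Write a tiny frequency dictionary: one "<term> <count>" pair per line,
# the same format create_symspell_model() loads with term_index=0, count_index=1.
counts = Counter("the quick brown fox jumps over the lazy dog".split())
with open('freq_dict.txt', 'w') as f:
    for word, freq in counts.items():
        f.write(word + ' ' + str(freq) + '\n')

sym_spell = SymSpell(max_dictionary_edit_distance=2, prefix_length=7)
sym_spell.load_dictionary('freq_dict.txt', term_index=0, count_index=1, separator=' ')

# Verbosity.CLOSEST keeps only the suggestions at the smallest edit distance found,
# mirroring what the new generate_candidates() returns for each word.
suggestions = sym_spell.lookup('quik', Verbosity.CLOSEST, max_edit_distance=1)
print([s.term for s in suggestions])  # expected: ['quick']

The design trade-off behind the switch: SymSpell precomputes delete-only variants of every dictionary term, so each lookup is close to constant time instead of scanning the whole lexicon with Levenshtein.distance as the commented-out generate_candidates did, and update_symspell_model folds new text in by merging Counter objects, rewriting the "<term> <count>" file, and reloading it.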