-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathBadWord.py
77 lines (55 loc) · 2.51 KB
/
BadWord.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import os
import pickle
from string import ascii_lowercase, ascii_uppercase
from collections.abc import Iterable
import numpy as np
from hgtk.text import compose, decompose
from tensorflow import reduce_sum
from tensorflow.keras import callbacks, layers, metrics
from tensorflow.keras.utils import plot_model, to_categorical
from tensorflow.keras.models import load_model, Model
from tensorflow.keras.preprocessing.sequence import pad_sequences
# Absolute directory that contains this module file.
directory = os.path.split(os.path.abspath(__file__))[0]
# The 'model' folder next to this module; get_path() resolves names against it.
path = os.path.join(directory, 'model')
def get_path(filename):
    """Return ``filename`` joined onto the module-level ``path`` directory."""
    resolved = os.path.join(path, filename)
    return resolved
# Character -> integer code table used by encode().
# Code 0 is reserved for padding and code 1 for out-of-vocabulary characters,
# so real characters start at 2.
if os.path.isfile(get_path('chardict.pkl')):
    # NOTE(security): pickle.load can execute arbitrary code if the file has
    # been tampered with. Presumably this is only ever the file written by the
    # branch below -- do not point it at untrusted data.
    with open(get_path('chardict.pkl'), 'rb') as f:
        char_dict = pickle.load(f)
else:
    # Hangul consonant jamo (including double and compound consonants).
    jaem = ['ㄱ', 'ㄴ', 'ㄷ', 'ㄹ', 'ㅁ', 'ㅂ', 'ㅅ', 'ㅇ', 'ㅈ', 'ㅊ', 'ㅋ', 'ㅌ', 'ㅍ', 'ㅎ', 'ㄲ', 'ㄸ', 'ㅃ', 'ㅆ', 'ㅉ', 'ㄳ', 'ㄵ', 'ㄶ', 'ㄺ', 'ㄻ', 'ㄼ', 'ㄽ', 'ㄾ', 'ㄿ', 'ㅀ', 'ㅄ']
    # Hangul vowel jamo.
    moem = ['ㅏ', 'ㅑ', 'ㅓ', 'ㅕ', 'ㅗ', 'ㅛ', 'ㅜ', 'ㅠ', 'ㅡ', 'ㅣ', 'ㅐ', 'ㅒ', 'ㅔ', 'ㅖ', 'ㅘ', 'ㅙ', 'ㅚ', 'ㅝ', 'ㅞ', 'ㅟ', 'ㅢ']
    english = list(ascii_lowercase) + list(ascii_uppercase)
    # Space plus punctuation. The backslash is written as '\\' because '\|' is
    # an invalid escape sequence; the resulting characters are unchanged.
    sign = list(''' `~!@#$%^&*()+-/=_,.?;:'"[]{}<>\\|''')
    # The order of this list determines every code, so it must not change.
    link_list = sign + jaem + moem + english + [str(i) for i in range(10)]
    char_dict = {k: code + 2 for code, k in enumerate(link_list)}  # 0 is padding, 1 is OOV
    # Create the target directory first; otherwise open(..., 'wb') raises
    # FileNotFoundError when the 'model' folder does not exist yet.
    os.makedirs(os.path.dirname(get_path('chardict.pkl')), exist_ok=True)
    with open(get_path('chardict.pkl'), 'wb') as f:
        pickle.dump(char_dict, f, pickle.HIGHEST_PROTOCOL)

vocab_size = len(char_dict) + 2  # includes the padding and OOV codes
maxlen = 60  # sequence length handed to pad_sequences in preprocessing()
def encode(text: str) -> list:
    """Convert a string into a list of integer character codes.

    The text is passed through ``decompose`` and every 'ᴥ' character is
    removed from the result (presumably the syllable-boundary marker that
    ``decompose`` inserts -- confirm against hgtk). Each remaining character
    is looked up in ``char_dict``; characters that are not in the table get
    code 1.

    Args:
        text: The string to encode.

    Returns:
        A list with one integer code per remaining character.

    Raises:
        AssertionError: If ``text`` is not a ``str``.
    """
    assert isinstance(text, str), "text argument must be str."
    decomposed = decompose(str(text))
    stripped = decomposed.replace('ᴥ', '')
    codes = []
    for symbol in stripped:
        codes.append(char_dict.get(symbol, 1))
    return codes
def preprocessing(data: Iterable) -> np.ndarray:
    """Encode raw text and convert it into one-hot model input.

    Each string is encoded with ``encode``, the code sequences are passed
    through ``pad_sequences`` with ``maxlen``, and the result is one-hot
    encoded by ``to_categorical`` with ``vocab_size`` classes.

    Args:
        data: A single string, or an iterable of strings.

    Returns:
        The array returned by ``to_categorical`` (presumably shaped
        ``(n_samples, maxlen, vocab_size)`` -- confirm against Keras).

    Raises:
        TypeError: If ``data`` is neither a ``str`` nor an ``Iterable``.
    """
    # A str is itself Iterable, so it has to be checked first; otherwise a
    # single sentence would be treated as one sample per character.
    if isinstance(data, str):
        encoded = [encode(data)]
    elif isinstance(data, Iterable):
        encoded = [encode(text) for text in data]
    else:
        # The previous ``assert True, ...`` could never fail, so invalid input
        # slipped through to pad_sequences. Raise explicitly instead; unlike
        # an assert, this also survives ``python -O``.
        raise TypeError("data argument must be str or Iterable object.")
    padded = pad_sequences(encoded, maxlen)
    return to_categorical(padded, vocab_size)
def load_badword_model() -> Model:
    """Load the saved model and compile it for binary classification.

    The model is read from the file that ``get_path('model.h5')`` resolves
    to, then compiled with binary cross-entropy loss, the Adam optimizer, and
    binary accuracy, recall and precision metrics.

    Returns:
        The loaded and compiled model.
    """
    badword_model = load_model(get_path('model.h5'))
    # The metric objects are created after the model is loaded, as before.
    compile_options = {
        "loss": "binary_crossentropy",
        "optimizer": "adam",
        "metrics": [
            metrics.BinaryAccuracy(name="acc"),
            metrics.Recall(name="recall"),
            metrics.Precision(name="prec"),
        ],
    }
    badword_model.compile(**compile_options)
    return badword_model