This repository was archived by the owner on Jan 19, 2022. It is now read-only.
-
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #5 from matken11235/converter
Add converter.py
- Loading branch information
Showing
1 changed file
with
90 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
import numpy as xp # TODO: cupyにも対応する | ||
|
||
# データ変換クラスの定義 | ||
class DataConverter: | ||
def __init__(self, batch_col_size=15): | ||
''' | ||
クラスの初期化 | ||
:param batch_col_size: 学習時のミニバッチ単語数サイズ | ||
''' | ||
self.mecab = MeCab.Tagger('-d /usr/local/Cellar/mecab-ipadic/2.7.0-20070801/lib/mecab/dic/ipadic') # 形態素解析器 | ||
self.vocab = { "<eos>": 0, "<unk>": 1 } # 単語辞書 | ||
self.batch_col_size = batch_col_size | ||
|
||
def load(self, inputs, outputs): | ||
''' | ||
学習時に、教師データを読み込んでミニバッチサイズに対応したNumpy配列に変換する | ||
:param inputs: inputデータ(list) | ||
:param outputs: outputデータ(list) | ||
''' | ||
# 単語辞書の登録 | ||
self.vocab = { "<eos>": 0, "<unk>": 1 } # 単語辞書を初期化 | ||
queries, responses = [], [] | ||
for sentence_in, sentence_out in zip(inputs, outputs): | ||
words_in = self.sentence2words(sentence_in) | ||
words_out = self.sentence2words(sentence_out) | ||
# 教師データのID化と整理 | ||
queries.append(self.sentence2ids(words=words_in, sentence_type="query")) | ||
responses.append(self.sentence2ids(words=words_out, sentence_type="response")) | ||
for word_in, word_out in zip(words_in, words_out): | ||
if word_in not in self.vocab: | ||
self.vocab[word_in] = len(self.vocab) | ||
if word_out not in self.vocab: | ||
self.vocab[word_out] = len(self.vocab) | ||
self.train_queries = xp.vstack(queries) | ||
self.train_responses = xp.vstack(responses) | ||
|
||
def sentence2words(self, sentence): | ||
''' | ||
文章を単語の配列にして返却する | ||
:param sentence: 文章文字列 | ||
:return: mecabでparseして単語ごとに分割したsentence | ||
''' | ||
sentence_words = [] | ||
for m in self.mecab.parse(sentence).split("\n"): # 形態素解析で単語に分解する | ||
w = m.split("\t")[0].lower() # 単語 | ||
if len(w) == 0 or w == "eos": # 不正文字、eosは省略 | ||
continue | ||
sentence_words.append(w) | ||
sentence_words.append("<eos>") # 最後にvocabに登録している<eos>を代入する | ||
return sentence_words | ||
|
||
def sentence2ids(self, words, train=True, sentence_type="query"): | ||
''' | ||
文章を単語IDのNumpy配列に変換して返却する | ||
:param sentence: 文章文字列 | ||
:param train: 学習用かどうか | ||
:sentence_type: 学習用でミニバッチ対応のためのサイズ補填方向をクエリー・レスポンスで変更するため"query"or"response"を指定 | ||
:return: 単語IDのNumpy配列 | ||
''' | ||
ids = [] # 単語IDに変換して格納する配列 | ||
for word in words: | ||
if word in self.vocab: # 単語辞書に存在する単語ならば、IDに変換する | ||
ids.append(self.vocab[word]) | ||
else: # 単語辞書に存在しない単語ならば、<unk>に変換する | ||
ids.append(self.vocab["<unk>"]) | ||
# 学習時は、ミニバッチ対応のため、単語数サイズを調整してNumpy変換する | ||
if train: | ||
if sentence_type == "query": # クエリーの場合は前方にミニバッチ単語数サイズになるまで-1を補填する | ||
while len(ids) > self.batch_col_size: # ミニバッチ単語サイズよりも大きければ、ミニバッチ単語サイズになるまで先頭から削る | ||
ids.pop(0) | ||
ids = xp.array([-1]*(self.batch_col_size-len(ids))+ids, dtype="int32") | ||
elif sentence_type == "response": # レスポンスの場合は後方にミニバッチ単語数サイズになるまで-1を補填する | ||
while len(ids) > self.batch_col_size: # ミニバッチ単語サイズよりも大きければ、ミニバッチ単語サイズになるまで末尾から削る | ||
ids.pop() | ||
ids = xp.array(ids+[-1]*(self.batch_col_size-len(ids)), dtype="int32") | ||
else: # 予測時は、そのままNumpy変換する | ||
ids = xp.array([ids], dtype="int32") | ||
return ids | ||
|
||
# 予測時に使用する関数 | ||
def ids2words(self, ids): | ||
''' | ||
予測時に、単語IDのNumpy配列を単語に変換して返却する | ||
:param ids: 単語IDのNumpy配列 | ||
:return: 単語の配列 | ||
''' | ||
words = [] # 単語を格納する配列 | ||
for i in ids: # 順番に単語IDを単語辞書から参照して単語に変換する | ||
words.append(list(self.vocab.keys())[list(self.vocab.values()).index(i)]) | ||
return words |