From 46bf713d07db87c98287e2f3d14d1fda8195e8d2 Mon Sep 17 00:00:00 2001 From: matken11235 Date: Sun, 12 Nov 2017 00:58:13 +0900 Subject: [PATCH] Add converter.py --- converter.py | 90 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) create mode 100644 converter.py diff --git a/converter.py b/converter.py new file mode 100644 index 0000000..c51843c --- /dev/null +++ b/converter.py @@ -0,0 +1,90 @@ +import numpy as xp # TODO: cupyにも対応する + +# データ変換クラスの定義 +class DataConverter: + def __init__(self, batch_col_size=15): + ''' + クラスの初期化 + :param batch_col_size: 学習時のミニバッチ単語数サイズ + ''' + self.mecab = MeCab.Tagger('-d /usr/local/Cellar/mecab-ipadic/2.7.0-20070801/lib/mecab/dic/ipadic') # 形態素解析器 + self.vocab = { "": 0, "": 1 } # 単語辞書 + self.batch_col_size = batch_col_size + + def load(self, inputs, outputs): + ''' + 学習時に、教師データを読み込んでミニバッチサイズに対応したNumpy配列に変換する + :param inputs: inputデータ(list) + :param outputs: outputデータ(list) + ''' + # 単語辞書の登録 + self.vocab = { "": 0, "": 1 } # 単語辞書を初期化 + queries, responses = [], [] + for sentence_in, sentence_out in zip(inputs, outputs): + words_in = self.sentence2words(sentence_in) + words_out = self.sentence2words(sentence_out) + # 教師データのID化と整理 + queries.append(self.sentence2ids(words=words_in, sentence_type="query")) + responses.append(self.sentence2ids(words=words_out, sentence_type="response")) + for word_in, word_out in zip(words_in, words_out): + if word_in not in self.vocab: + self.vocab[word_in] = len(self.vocab) + if word_out not in self.vocab: + self.vocab[word_out] = len(self.vocab) + self.train_queries = xp.vstack(queries) + self.train_responses = xp.vstack(responses) + + def sentence2words(self, sentence): + ''' + 文章を単語の配列にして返却する + :param sentence: 文章文字列 + :return: mecabでparseして単語ごとに分割したsentence + ''' + sentence_words = [] + for m in self.mecab.parse(sentence).split("\n"): # 形態素解析で単語に分解する + w = m.split("\t")[0].lower() # 単語 + if len(w) == 0 or w == "eos": # 不正文字、eosは省略 + continue + sentence_words.append(w) + sentence_words.append("") # 最後にvocabに登録しているを代入する + return sentence_words + + def sentence2ids(self, words, train=True, sentence_type="query"): + ''' + 文章を単語IDのNumpy配列に変換して返却する + :param sentence: 文章文字列 + :param train: 学習用かどうか + :sentence_type: 学習用でミニバッチ対応のためのサイズ補填方向をクエリー・レスポンスで変更するため"query"or"response"を指定  + :return: 単語IDのNumpy配列 + ''' + ids = [] # 単語IDに変換して格納する配列 + for word in words: + if word in self.vocab: # 単語辞書に存在する単語ならば、IDに変換する + ids.append(self.vocab[word]) + else: # 単語辞書に存在しない単語ならば、に変換する + ids.append(self.vocab[""]) + # 学習時は、ミニバッチ対応のため、単語数サイズを調整してNumpy変換する + if train: + if sentence_type == "query": # クエリーの場合は前方にミニバッチ単語数サイズになるまで-1を補填する + while len(ids) > self.batch_col_size: # ミニバッチ単語サイズよりも大きければ、ミニバッチ単語サイズになるまで先頭から削る + ids.pop(0) + ids = xp.array([-1]*(self.batch_col_size-len(ids))+ids, dtype="int32") + elif sentence_type == "response": # レスポンスの場合は後方にミニバッチ単語数サイズになるまで-1を補填する + while len(ids) > self.batch_col_size: # ミニバッチ単語サイズよりも大きければ、ミニバッチ単語サイズになるまで末尾から削る + ids.pop() + ids = xp.array(ids+[-1]*(self.batch_col_size-len(ids)), dtype="int32") + else: # 予測時は、そのままNumpy変換する + ids = xp.array([ids], dtype="int32") + return ids + + # 予測時に使用する関数 + def ids2words(self, ids): + ''' + 予測時に、単語IDのNumpy配列を単語に変換して返却する + :param ids: 単語IDのNumpy配列 + :return: 単語の配列 + ''' + words = [] # 単語を格納する配列 + for i in ids: # 順番に単語IDを単語辞書から参照して単語に変換する + words.append(list(self.vocab.keys())[list(self.vocab.values()).index(i)]) + return words \ No newline at end of file