Skip to content
This repository was archived by the owner on Jan 19, 2022. It is now read-only.

Add converter.py #5

Merged
merged 1 commit into from
Nov 11, 2017
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 90 additions & 0 deletions converter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import numpy as xp # TODO: cupyにも対応する

# データ変換クラスの定義
class DataConverter:
def __init__(self, batch_col_size=15):
'''
クラスの初期化
:param batch_col_size: 学習時のミニバッチ単語数サイズ
'''
self.mecab = MeCab.Tagger('-d /usr/local/Cellar/mecab-ipadic/2.7.0-20070801/lib/mecab/dic/ipadic') # 形態素解析器
self.vocab = { "<eos>": 0, "<unk>": 1 } # 単語辞書
self.batch_col_size = batch_col_size

def load(self, inputs, outputs):
'''
学習時に、教師データを読み込んでミニバッチサイズに対応したNumpy配列に変換する
:param inputs: inputデータ(list)
:param outputs: outputデータ(list)
'''
# 単語辞書の登録
self.vocab = { "<eos>": 0, "<unk>": 1 } # 単語辞書を初期化
queries, responses = [], []
for sentence_in, sentence_out in zip(inputs, outputs):
words_in = self.sentence2words(sentence_in)
words_out = self.sentence2words(sentence_out)
# 教師データのID化と整理
queries.append(self.sentence2ids(words=words_in, sentence_type="query"))
responses.append(self.sentence2ids(words=words_out, sentence_type="response"))
for word_in, word_out in zip(words_in, words_out):
if word_in not in self.vocab:
self.vocab[word_in] = len(self.vocab)
if word_out not in self.vocab:
self.vocab[word_out] = len(self.vocab)
self.train_queries = xp.vstack(queries)
self.train_responses = xp.vstack(responses)

def sentence2words(self, sentence):
'''
文章を単語の配列にして返却する
:param sentence: 文章文字列
:return: mecabでparseして単語ごとに分割したsentence
'''
sentence_words = []
for m in self.mecab.parse(sentence).split("\n"): # 形態素解析で単語に分解する
w = m.split("\t")[0].lower() # 単語
if len(w) == 0 or w == "eos": # 不正文字、eosは省略
continue
sentence_words.append(w)
sentence_words.append("<eos>") # 最後にvocabに登録している<eos>を代入する
return sentence_words

def sentence2ids(self, words, train=True, sentence_type="query"):
'''
文章を単語IDのNumpy配列に変換して返却する
:param sentence: 文章文字列
:param train: 学習用かどうか
:sentence_type: 学習用でミニバッチ対応のためのサイズ補填方向をクエリー・レスポンスで変更するため"query"or"response"を指定 
:return: 単語IDのNumpy配列
'''
ids = [] # 単語IDに変換して格納する配列
for word in words:
if word in self.vocab: # 単語辞書に存在する単語ならば、IDに変換する
ids.append(self.vocab[word])
else: # 単語辞書に存在しない単語ならば、<unk>に変換する
ids.append(self.vocab["<unk>"])
# 学習時は、ミニバッチ対応のため、単語数サイズを調整してNumpy変換する
if train:
if sentence_type == "query": # クエリーの場合は前方にミニバッチ単語数サイズになるまで-1を補填する
while len(ids) > self.batch_col_size: # ミニバッチ単語サイズよりも大きければ、ミニバッチ単語サイズになるまで先頭から削る
ids.pop(0)
ids = xp.array([-1]*(self.batch_col_size-len(ids))+ids, dtype="int32")
elif sentence_type == "response": # レスポンスの場合は後方にミニバッチ単語数サイズになるまで-1を補填する
while len(ids) > self.batch_col_size: # ミニバッチ単語サイズよりも大きければ、ミニバッチ単語サイズになるまで末尾から削る
ids.pop()
ids = xp.array(ids+[-1]*(self.batch_col_size-len(ids)), dtype="int32")
else: # 予測時は、そのままNumpy変換する
ids = xp.array([ids], dtype="int32")
return ids

# 予測時に使用する関数
def ids2words(self, ids):
'''
予測時に、単語IDのNumpy配列を単語に変換して返却する
:param ids: 単語IDのNumpy配列
:return: 単語の配列
'''
words = [] # 単語を格納する配列
for i in ids: # 順番に単語IDを単語辞書から参照して単語に変換する
words.append(list(self.vocab.keys())[list(self.vocab.values()).index(i)])
return words