Skip to content

Improvements to sentence data structure #538

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 14, 2018
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 69 additions & 9 deletions xnmt/sent.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,18 @@ def __init__(self, idx: Optional[int] = None, score: Optional[numbers.Real] = No
self.idx = idx
self.score = score

def __getitem__(self, key):
    """
    Look up a single position or a slice of the sentence.

    This base version is a placeholder only and always raises.

    Args:
      key: index or slice

    Returns:
      A single word when an index is given as key, a Sentence object when a slice is given as key.

    Raises:
      NotImplementedError: always
    """
    raise NotImplementedError("must be implemented by subclasses")

def sent_len(self) -> int:
"""
Return length of input, included padded tokens.
Expand Down Expand Up @@ -57,6 +69,18 @@ def create_truncated_sent(self, trunc_len: numbers.Integral) -> 'Sentence':
"""
raise NotImplementedError("must be implemented by subclasses")

def get_unpadded_sent(self) -> 'Sentence':
    """
    Return the sentence without its padding.

    When the total length differs from the unpadded length, the leading unpadded part is obtained by slicing self;
    otherwise self is handed back unchanged.

    Returns:
      self if no padding is present, else the result of slicing self up to the unpadded length.
    """
    # Guard clause: padding is present, so cut it off via the slicing protocol.
    if self.sent_len() != self.len_unpadded():
        return self[:self.len_unpadded()]
    return self

class ReadableSentence(Sentence):
"""
A base class for sentences based on readable strings.
Expand Down Expand Up @@ -121,6 +145,15 @@ def __init__(self, value: numbers.Integral, idx: Optional[numbers.Integral] = No
super().__init__(idx=idx, score=score)
self.value = value
self.vocab = vocab
def __getitem__(self, key):
    """
    Return the scalar value for index 0, or the sentence itself for a slice selecting its single position.

    Args:
      key: index or slice

    Returns:
      ``self.value`` for index 0; ``self`` for a slice that selects exactly the one element.

    Raises:
      IndexError: if the index is not 0, or the slice does not select exactly the single element
      TypeError: if key is neither an integer nor a slice
    """
    if isinstance(key, numbers.Integral):
        if key != 0:
            raise IndexError("index out of range for a sentence of length 1")
        return self.value
    if not isinstance(key, slice):
        raise TypeError("key must be an integer or a slice")
    # Resolve None / negative / overshooting bounds against length 1 the way built-in sequences do.
    # The previous check (`start != 0 and stop != 1`) let through slices such as [3:1] or [0:0],
    # which select nothing, and returned the full sentence for them.
    if len(range(*key.indices(1))) != 1:
        raise IndexError("slice does not select the single element of the sentence")
    return self
def sent_len(self) -> int:
    """Return the length of this sentence, which is always 1."""
    return 1
def len_unpadded(self) -> int:
Expand All @@ -133,6 +166,8 @@ def create_truncated_sent(self, trunc_len: numbers.Integral) -> 'ScalarSentence'
if trunc_len != 0:
raise ValueError("ScalarSentence cannot be truncated")
return self
def get_unpadded_sent(self):
    """Return the unpadded sentence, which is this very object."""
    # scalar sentences are always unpadded
    return self
def str_tokens(self, **kwargs) -> List[str]:
    """
    Return the value as a one-element list of string tokens.

    Args:
      **kwargs: accepted but not used here

    Returns:
      A list holding the vocab entry for the value if a (truthy) vocab is set, else the value converted with str().
    """
    token = self.vocab[self.value] if self.vocab else str(self.value)
    return [token]
Expand All @@ -151,6 +186,8 @@ def __init__(self, sents: Sequence[Sentence]) -> None:
if s.idx != self.idx:
raise ValueError("CompoundSentence must contain sentences of consistent idx.")
self.sents = sents
def __getitem__(self, item):
    """
    Reject indexing and slicing on the compound as a whole.

    Args:
      item: index or slice (ignored)

    Raises:
      ValueError: always; the sub-inputs have to be indexed individually
    """
    raise ValueError("not supported with CompoundSentence, must be called on one of the sub-inputs instead.")
def sent_len(self) -> int:
    """Return the combined length of all contained sub-sentences."""
    total = 0
    for sub_sent in self.sents:
        total += sub_sent.sent_len()
    return total
def len_unpadded(self) -> int:
Expand All @@ -159,6 +196,8 @@ def create_padded_sent(self, pad_len):
raise ValueError("not supported with CompoundSentence, must be called on one of the sub-inputs instead.")
def create_truncated_sent(self, trunc_len):
    """
    Reject truncation of the compound as a whole.

    Args:
      trunc_len: number of tokens to truncate (ignored)

    Raises:
      ValueError: always; truncation has to be requested on the individual sub-inputs
    """
    raise ValueError("not supported with CompoundSentence, must be called on one of the sub-inputs instead.")
def get_unpadded_sent(self):
    """
    Reject unpadding of the compound as a whole.

    Raises:
      ValueError: always; the unpadded sentence has to be requested from the individual sub-inputs
    """
    raise ValueError("not supported with CompoundSentence, must be called on one of the sub-inputs instead.")


class SimpleSentence(ReadableSentence):
Expand All @@ -172,24 +211,27 @@ class SimpleSentence(ReadableSentence):
score: a score given to this sentence by a model
output_procs: output processors to be applied when calling sent_str()
pad_token: special token used for padding
unpadded_sent: reference to original, unpadded sentence if available
"""
def __init__(self,
             words: Sequence[numbers.Integral],
             idx: Optional[numbers.Integral] = None,
             vocab: Optional[Vocab] = None,
             score: Optional[numbers.Real] = None,
             # NOTE(review): mutable default is shared between calls; kept unchanged for interface compatibility.
             output_procs: Union[OutputProcessor, Sequence[OutputProcessor]] = [],
             pad_token: numbers.Integral = Vocab.ES,
             unpadded_sent: Optional['SimpleSentence'] = None) -> None:
    """
    Store the sentence contents; idx, score and output_procs are forwarded to the parent initializer.

    The stale pre-change ``pad_token`` line that closed the signature early has been dropped, so that
    ``unpadded_sent`` is a proper trailing keyword parameter.
    """
    super().__init__(idx=idx, score=score, output_procs=output_procs)
    self.pad_token = pad_token
    self.words = words
    self.vocab = vocab
    # Reference to the original, unpadded sentence if available (None otherwise).
    self.unpadded_sent = unpadded_sent

def __getitem__(self, key):
    """
    Get a word or a sub-sentence.

    Args:
      key: index or slice

    Returns:
      The word at the given index, or a new SimpleSentence wrapping the sliced words when key is a slice.
      The new sentence carries over idx, vocab, score, output_procs, pad_token and unpadded_sent.
    """
    ret = self.words[key]
    # Decide by the key rather than by the result type, so slicing also yields a sentence when the words
    # are stored in a tuple or another non-list sequence.
    if isinstance(key, slice):
        return SimpleSentence(words=ret, idx=self.idx, vocab=self.vocab, score=self.score,
                              output_procs=self.output_procs, pad_token=self.pad_token,
                              unpadded_sent=self.unpadded_sent)
    return ret

def sent_len(self):
Expand All @@ -209,6 +251,10 @@ def create_truncated_sent(self, trunc_len: numbers.Integral) -> 'SimpleSentence'
return self
return self.sent_with_words(self.words[:-trunc_len])

def get_unpadded_sent(self):
    """
    Return the unpadded sentence.

    Returns:
      The stored reference to the original unpadded sentence if one is set, otherwise whatever the parent
      class implementation returns.
    """
    # Explicit None check: a stored sentence must not be ignored merely because it evaluates as falsy
    # (e.g. if the sentence type reports a length of zero).
    if self.unpadded_sent is not None:
        return self.unpadded_sent
    return super().get_unpadded_sent()

def str_tokens(self, exclude_ss_es=True, exclude_unk=False, exclude_padded=True, **kwargs) -> List[str]:
exclude_set = set()
if exclude_ss_es:
Expand All @@ -221,12 +267,16 @@ def str_tokens(self, exclude_ss_es=True, exclude_unk=False, exclude_padded=True,
else: return [str(w) for w in ret_toks]

def sent_with_new_words(self, new_words):
    """
    Create a copy of this sentence that holds different words.

    Args:
      new_words: the words of the new sentence

    Returns:
      A new SimpleSentence with idx, vocab, score, output_procs and pad_token taken from self. Its unpadded_sent
      is self.unpadded_sent if set; otherwise self when self carries no padding; otherwise None.
    """
    unpadded_sent = self.unpadded_sent
    # Explicit None check so that a stored reference is never discarded for merely being falsy.
    if unpadded_sent is None and self.sent_len() == self.len_unpadded():
        unpadded_sent = self
    return SimpleSentence(words=new_words,
                          idx=self.idx,
                          vocab=self.vocab,
                          score=self.score,
                          output_procs=self.output_procs,
                          pad_token=self.pad_token,
                          unpadded_sent=unpadded_sent)

class SegmentedSentence(SimpleSentence):
def __init__(self, segment=[], **kwargs) -> None:
Expand All @@ -240,7 +290,8 @@ def sent_with_new_words(self, new_words):
score=self.score,
output_procs=self.output_procs,
pad_token=self.pad_token,
segment=self.segment)
segment=self.segment,
unpadded_sent=self.unpadded_sent)


class ArraySentence(Sentence):
Expand All @@ -257,14 +308,16 @@ class ArraySentence(Sentence):
def __init__(self,
             nparr: np.ndarray,
             idx: Optional[numbers.Integral] = None,
             padded_len: numbers.Integral = 0,
             score: Optional[numbers.Real] = None,
             unpadded_sent: Optional['ArraySentence'] = None) -> None:
    """
    Store the array contents; idx and score are forwarded to the parent initializer.

    The stale pre-change ``padded_len``/``score`` lines that closed the signature early have been dropped,
    so that ``unpadded_sent`` is a proper trailing keyword parameter.
    """
    super().__init__(idx=idx, score=score)
    self.nparr = nparr
    self.padded_len = padded_len
    # Reference to the original, unpadded sentence if available (None otherwise).
    self.unpadded_sent = unpadded_sent

def __getitem__(self, key):
    """
    Get the array entry at an integer index.

    Args:
      key: integer index

    Returns:
      The result of indexing the underlying array with key.

    Raises:
      NotImplementedError: if key is not an integer (e.g. a slice)
    """
    # The former `assert isinstance(...)` was removed: it raised AssertionError before this explicit check
    # could run, and is stripped entirely under `python -O`.
    if not isinstance(key, numbers.Integral):
        raise NotImplementedError()
    return self.nparr[key]

def sent_len(self):
Expand All @@ -279,13 +332,20 @@ def create_padded_sent(self, pad_len: numbers.Integral) -> 'ArraySentence':
return self
new_nparr = np.append(self.nparr, np.broadcast_to(np.reshape(self.nparr[:, -1], (self.nparr.shape[0], 1)),
(self.nparr.shape[0], pad_len)), axis=1)
return ArraySentence(new_nparr, idx=self.idx, score=self.score, padded_len=self.padded_len + pad_len)
return ArraySentence(new_nparr, idx=self.idx, score=self.score, padded_len=self.padded_len + pad_len,
unpadded_sent=self if self.padded_len==0 else self.unpadded_sent)

def create_truncated_sent(self, trunc_len: numbers.Integral) -> 'ArraySentence':
    """
    Create a sentence with trunc_len trailing entries removed.

    Args:
      trunc_len: number of entries to cut off at the end

    Returns:
      self if trunc_len is 0, otherwise a new ArraySentence over the shortened array. Its padded_len is reduced
      by trunc_len (not below 0), and its unpadded_sent is self when self is unpadded, else self.unpadded_sent.
    """
    if trunc_len == 0:
        return self
    # NOTE(review): this slices along the first axis, whereas create_padded_sent appends along axis=1 —
    # confirm which axis is meant to be the time axis.
    new_nparr = np.asarray(self.nparr[:-trunc_len])
    # Only one return remains: the stale return without `unpadded_sent` used to come first, leaving the
    # intended one unreachable.
    return ArraySentence(new_nparr, idx=self.idx, score=self.score,
                         padded_len=max(0, self.padded_len - trunc_len),
                         unpadded_sent=self if self.padded_len == 0 else self.unpadded_sent)

def get_unpadded_sent(self):
    """
    Return the unpadded sentence.

    Returns:
      self if padded_len is 0; otherwise the stored reference to the original unpadded sentence if one is set;
      otherwise whatever the parent class implementation returns.
    """
    if self.padded_len == 0:
        return self
    # Explicit None check: a stored sentence must not be ignored merely because it evaluates as falsy.
    if self.unpadded_sent is not None:
        return self.unpadded_sent
    return super().get_unpadded_sent()

def get_array(self):
    """Return the underlying array object held by this sentence (no copy is made)."""
    return self.nparr
Expand Down