Raml Text Reader and SwitchOut config file #535

Merged · 14 commits · Oct 11, 2018
58 changes: 58 additions & 0 deletions examples/23_switchout.yaml
@@ -0,0 +1,58 @@
# Implements SwitchOut, a data augmentation strategy for NMT.
# RAML corrupts the target side only, while SwitchOut corrupts both source and target.
# https://arxiv.org/pdf/1808.07512.pdf
switchout: !Experiment
  # global parameters shared throughout the experiment
  exp_global: !ExpGlobal
    # {EXP_DIR} is a placeholder for the directory in which the config file lies.
    # {EXP} is a placeholder for the experiment name (here: 'switchout')
    model_file: '{EXP_DIR}/models/{EXP}.mod'
    log_file: '{EXP_DIR}/logs/{EXP}.log'
    default_layer_dim: 512
    dropout: 0.3
  # model architecture
  model: !DefaultTranslator
    src_reader: !RamlTextReader
      vocab: !Vocab {vocab_file: examples/data/head.ja.vocab}
      tau: 0.8
    trg_reader: !RamlTextReader
      vocab: !Vocab {vocab_file: examples/data/head.en.vocab}
      tau: 0.8
    src_embedder: !SimpleWordEmbedder
      emb_dim: 512
    encoder: !BiLSTMSeqTransducer
      layers: 1
    attender: !MlpAttender
      hidden_dim: 512
      state_dim: 512
      input_dim: 512
    trg_embedder: !SimpleWordEmbedder
      emb_dim: 512
    decoder: !AutoRegressiveDecoder
      rnn: !UniLSTMSeqTransducer
        layers: 1
      transform: !AuxNonLinear
        output_dim: 512
        activation: 'tanh'
      bridge: !CopyBridge {}
      scorer: !Softmax {}
  # training parameters
  train: !SimpleTrainingRegimen
    batcher: !SrcBatcher
      batch_size: 32
    trainer: !AdamTrainer
      alpha: 0.001
    run_for_epochs: 2
    src_file: examples/data/head.ja
    trg_file: examples/data/head.en
    dev_tasks:
      - !LossEvalTask
        src_file: examples/data/head.ja
        ref_file: examples/data/head.en
  # final evaluation
  evaluate:
    - !AccuracyEvalTask
      eval_metrics: bleu
      src_file: examples/data/head.ja
      ref_file: examples/data/head.en
      hyp_file: examples/output/{EXP}.test_hyp
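
For readers skimming the config: the corruption step that `RamlTextReader` applies at training time (see the `xnmt/input_readers.py` diff below) can be sketched in isolation. This is a minimal, self-contained NumPy re-implementation for illustration only — `switchout_corrupt` and its toy inputs are made-up names, not part of this PR, and ids 0 and 1 are assumed to be reserved special tokens, matching the `np.arange(2, len(self.vocab))` range used in the diff.

```python
import numpy as np

def switchout_corrupt(word_ids, vocab_size, tau=0.8):
    word_ids = np.array(word_ids)
    length = len(word_ids)
    # Distribution over the number of corrupted positions n = 0..length-1,
    # with p(n) proportional to exp(-n * tau): larger tau => fewer corruptions.
    logits = np.arange(length) * (-1) * tau
    probs = np.exp(logits - np.max(logits))
    probs /= probs.sum()
    num_words = np.random.choice(length, p=probs)
    # Pick positions to corrupt with a per-position coin flip at rate num_words/length.
    corrupt_pos = np.random.binomial(1, p=num_words / length, size=(length,))
    # Replacements are uniform over the vocabulary, skipping reserved ids 0 and 1.
    sampled = np.random.choice(np.arange(2, vocab_size), size=(int(np.sum(corrupt_pos)),))
    word_ids[corrupt_pos == 1] = sampled
    return word_ids

print(switchout_corrupt([5, 6, 7, 8, 9], vocab_size=100, tau=0.8))
```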
52 changes: 51 additions & 1 deletion xnmt/input_readers.py
@@ -226,11 +226,61 @@ def count_words(self, trg_words):
  def vocab_size(self):
    return len(self.vocab)

class RamlTextReader(BaseTextReader, Serializable):
  """
  Handles RAML sampling; it can be used on the target side only (RAML) or on both
  the source and target sides (SwitchOut).
  Randomly replaces words, with the number of replacements sampled according to Hamming distance.
  https://arxiv.org/pdf/1808.07512.pdf
  https://arxiv.org/pdf/1609.00150.pdf
  """
  yaml_tag = '!RamlTextReader'

  @register_xnmt_handler
  @serializable_init
  def __init__(self, tau: Optional[float] = 1., vocab: Optional[Vocab] = None, output_proc=[]):
    """
    Args:
      tau: The temperature that controls the peakiness of the sampling distribution
      vocab: The vocabulary
    """
    self.tau = tau
    self.vocab = vocab
    self.output_procs = output.OutputProcessor.get_output_processor(output_proc)

  @handle_xnmt_event
  def on_set_train(self, val):
    self.train = val

  def read_sent(self, line, idx):
    words = line.strip().split()
    if not self.train:
      # At test time, pass the sentence through unchanged.
      return SimpleSentence(idx=idx,
                            words=[self.vocab.convert(word) for word in words] + [Vocab.ES],
                            vocab=self.vocab,
                            output_procs=self.output_procs)
    word_ids = np.array([self.vocab.convert(word) for word in words])
    length = len(word_ids)
    # Sample the number of corrupted words n with p(n) proportional to exp(-n * tau).
    logits = np.arange(length) * (-1) * self.tau
    logits = np.exp(logits - np.max(logits))
    probs = logits / np.sum(logits)
    num_words = np.random.choice(length, p=probs)
    # Choose the positions to corrupt via per-position coin flips at rate num_words/length.
    corrupt_pos = np.random.binomial(1, p=num_words/length, size=(length,))
    num_words_to_sample = np.sum(corrupt_pos)
    # Draw replacement ids uniformly, skipping the reserved indices 0 and 1.
    sampled_words = np.random.choice(np.arange(2, len(self.vocab)), size=(num_words_to_sample,))
    word_ids[np.where(corrupt_pos==1)[0].tolist()] = sampled_words
    return SimpleSentence(idx=idx,
                          words=word_ids.tolist() + [Vocab.ES],
                          vocab=self.vocab,
                          output_procs=self.output_procs)

  def needs_reload(self) -> bool:
    return True

class CharFromWordTextReader(PlainTextReader, Serializable):
  """
  Reads in a word-based corpus and turns it into SegmentedSentences.
  A SegmentedSentence's words are characters, but it retains the word segmentation information.

  x = SegmentedSentence("i code today")
  (TRUE) x.words == ["i", "c", "o", "d", "e", "t", "o", "d", "a", "y"]
  (TRUE) x.segment == [0, 4, 9]
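
A quick sanity check on what `tau: 0.8` means under this parameterization: since `read_sent` sets `logits = np.arange(length) * (-1) * self.tau`, the number of corrupted positions n follows p(n) ∝ exp(-n · tau), so larger tau concentrates mass on small n (fewer replacements per sentence). A minimal sketch using only that logits computation:

```python
import numpy as np

# Probability of corrupting n = 0..4 words in a length-5 sentence,
# per the logits used in RamlTextReader.read_sent: p(n) ~ exp(-n * tau).
for tau in (0.2, 0.8, 2.0):
    logits = np.arange(5) * (-1) * tau
    probs = np.exp(logits - np.max(logits))
    probs /= probs.sum()
    print(f"tau={tau}: {np.round(probs, 3)}")
# Larger tau shifts mass toward n = 0, i.e. milder corruption.
```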