- Adding support for BT information in training examples
- Checking whether noised examples become too long (needs rework on List[examples] to make sense)
- Adding extra padding for max_seq_len when deciding if examples are too long
- Changing the behaviour of max_seq_len validation in the BT dataset
HaukurPall committed Sep 4, 2023
1 parent aab7a0d commit 21c2e4e
Showing 5 changed files with 59 additions and 33 deletions.
5 changes: 5 additions & 0 deletions src/greynirseq/nicenlp/data/encoders.py
@@ -33,6 +33,7 @@ def __init__(
global_skip_noise_prob: float,
word_noise_config: WordNoiserConfig,
char_noise_config: CharacterNoiserConfig,
max_sequence_length: int,
):
self.bpe = bpe
self.dictionary = dictionary
@@ -46,6 +47,7 @@ def __init__(
self.char_noiser = CharacterNoiser(
char_noiser_config=char_noise_config,
)
self.max_sequence_length = max_sequence_length

def encode(self, sequence: str) -> torch.Tensor:
"""Encode a sequence of tokens into a sequence of integers using BPE and then the fairseq dictionary."""
@@ -60,4 +62,7 @@ def encode_noisy(self, sequence: str) -> torch.Tensor:
res = self.word_noiser.apply(sequence)
res = self.noisy_subword_enc.apply(res)
seq_tensor = self.fragment_noiser.apply(res.sequence, res.noise_allowed_mask)
# If the noisy sequence is too long, we encode it again without noise
if len(seq_tensor) > self.max_sequence_length:
return self.encode(sequence)
return seq_tensor
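A minimal, self-contained sketch of the fallback pattern added above: encode with noise, but re-encode the original sequence without noise if the noisy version exceeds the length budget. The encode/noise callables below are toy stand-ins, not the repository's WordNoiser/FragmentNoiser pipeline.

import torch

def encode_noisy_with_fallback(sequence, encode_fn, noise_fn, max_sequence_length):
    """encode_fn: str -> LongTensor (clean encoding); noise_fn: str -> str (hypothetical noiser)."""
    noisy_tensor = encode_fn(noise_fn(sequence))
    # Noising (e.g. character insertions) can push the subword length past the model limit,
    # in which case we fall back to the clean encoding of the original sequence.
    if len(noisy_tensor) > max_sequence_length:
        return encode_fn(sequence)
    return noisy_tensor

# Toy usage: the "noiser" doubles the sequence, so the fallback kicks in at max length 3.
toy_encode = lambda s: torch.tensor([hash(w) % 100 for w in s.split()], dtype=torch.long)
toy_noise = lambda s: s + " " + s
print(encode_noisy_with_fallback("halló heimur", toy_encode, toy_noise, max_sequence_length=3))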
@@ -139,9 +139,15 @@ def __getitem__(self, index):
tgt_segments: List[str] = [self.flat_tgt[int(i)]["segment"] for i in item[KEYS.TARGET_INDICES]]
src_langs: List[str] = [self.flat_src[int(i)]["lang"] for i in item[KEYS.SOURCE_INDICES]]
tgt_langs: List[str] = [self.flat_tgt[int(i)]["lang"] for i in item[KEYS.TARGET_INDICES]]
assert len(set(src_langs)) == 1, "source segments must be from the same language"
assert len(set(tgt_langs)) == 1, "target segments must be from the same language"

if len(set(src_langs)) != 1:
self.log_example(index=index)
raise ValueError("source segments must be from the same language")
if len(set(tgt_langs)) != 1:
self.log_example(index=index)
raise ValueError("target segments must be from the same language")

# Experimental: add BT information
bt_info = self.encoder.encode("BT") if is_bt else torch.tensor([], dtype=torch.long)
with data_utils.numpy_seed(self.seed, self.epoch, index):
insert_sep = np.random.randint(2, dtype=bool)
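The bt_info tensor above is deliberately empty for non-back-translated examples, so concatenating it in the next hunk is a no-op. A tiny sketch of that pattern, with a made-up dictionary index for the encoded "BT" marker:

import torch

BT_MARKER_INDEX = 62_000  # hypothetical dictionary index of the encoded "BT" tag

def bt_info_tensor(is_bt: bool) -> torch.Tensor:
    # Empty LongTensor when the example is not back-translated, so torch.cat() is a no-op.
    return torch.tensor([BT_MARKER_INDEX], dtype=torch.long) if is_bt else torch.tensor([], dtype=torch.long)

tokens = torch.tensor([10, 11, 12], dtype=torch.long)
print(torch.cat([tokens, bt_info_tensor(False)]))  # tensor([10, 11, 12])
print(torch.cat([tokens, bt_info_tensor(True)]))   # tensor([10, 11, 12, 62000])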

@@ -159,16 +165,23 @@ def __getitem__(self, index):

# This language code handling is like the mBart-50 model and nllb-200
src_out = torch.cat(
[torch.tensor([self.dictionary.index(src_langs[0])])] + src_out + [torch.tensor([self.dictionary.eos()])]
[torch.tensor([self.dictionary.index(src_langs[0])])]
+ src_out
+ [torch.tensor([self.dictionary.eos()]), bt_info]
)
tgt_out = torch.cat(
[torch.tensor([self.dictionary.index(tgt_langs[0])])] + tgt_out + [torch.tensor([self.dictionary.eos()])]
)

if len(src_out) > 1020 or len(tgt_out) > 1020:
print(f"Source: {self.encoder.bpe.decode(self.src_dict.string(src_out))}")
print(f"Target: {self.encoder.bpe.decode(self.src_dict.string(tgt_out))}")
assert False
if len(src_out) > self.max_seq_len or len(tgt_out) > self.max_seq_len:
logger.warning(
f"Truncating example at index={index} because it is too long: src={len(src_out)}, tgt={len(tgt_out)}"
)
self.log_example(index=index)
# Head + tail truncation: keep the first and last half_seq_len tokens
half_seq_len = self.max_seq_len // 2
src_out = torch.cat([src_out[:half_seq_len], src_out[-half_seq_len:]])
tgt_out = torch.cat([tgt_out[:half_seq_len], tgt_out[-half_seq_len:]])

example = {
"id": index,
@@ -185,6 +198,18 @@ def decode(self, example):
print(f"{self.encoder.bpe.decode(tgt_string)}")
print()

def log_example(self, index: int):
"""For debugging"""
item = self.index_dataset[int(index)]
logger.error(f"index={index}")
print(f"item={item}")
is_bt = any(item[KEYS.SOURCE_INDICES] >= self.bt_src_start)
logger.error(f"is_bt={is_bt}")
logger.error([self.flat_src[int(i)]["segment"] for i in item[KEYS.SOURCE_INDICES]])
logger.error([self.flat_tgt[int(i)]["segment"] for i in item[KEYS.TARGET_INDICES]])
logger.error([self.flat_src[int(i)]["lang"] for i in item[KEYS.SOURCE_INDICES]])
logger.error([self.flat_tgt[int(i)]["lang"] for i in item[KEYS.TARGET_INDICES]])

def set_epoch(self, epoch: int):
self.epoch = epoch
logger.info(f"Preparing epoch {epoch}")
@@ -876,7 +876,7 @@ def _inner(
]
doc_idxs, pg_idxs, weights, all_src_idxs, all_tgt_idxs, skip = [ex_batched[k] for k in keys]

# these keys are present only if we are dealing with a IndexedParallelBTDocumentsDataset
# The SOURCE_OFFSETS and TARGET_OFFSETS keys are present only if we are dealing with an IndexedParallelBTDocumentsDataset,
# where many flat_src and flat_tgt are concatenated together. That concatenation would skew the source/target indices,
# so we need to offset them by the document offsets
if KEYS.SOURCE_OFFSETS in ex_batched:
@@ -885,12 +885,14 @@ def _inner(
[idx + document_offset for idx in idxs]
for (idxs, document_offset) in zip(all_src_idxs, src_doc_offsets)
]
doc_idxs = [idx + document_offset for (idx, document_offset) in zip(doc_idxs, src_doc_offsets)]
if KEYS.TARGET_OFFSETS in ex_batched:
tgt_doc_offsets = ex_batched[KEYS.TARGET_OFFSETS]
all_tgt_idxs = [
[idx + document_offset for idx in idxs]
for (idxs, document_offset) in zip(all_tgt_idxs, tgt_doc_offsets)
]

# cast for readability
all_src_idxs = cast(List[List[int]], all_src_idxs)
all_tgt_idxs = cast(List[List[int]], all_tgt_idxs)
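A small sketch of the offsetting performed here: when several flat source/target datasets are concatenated into one IndexedParallelBTDocumentsDataset, the per-example sentence indices (and now also the document indices) are local, so the document offsets must be added to make them global. Values below are hypothetical.

from typing import List

def apply_document_offsets(all_idxs: List[List[int]], doc_offsets: List[int]) -> List[List[int]]:
    # Each example carries indices into its own flat dataset; adding the offset of the
    # document (i.e. where that flat dataset starts in the concatenation) makes them global.
    return [[idx + offset for idx in idxs] for idxs, offset in zip(all_idxs, doc_offsets)]

all_src_idxs = [[0, 1], [0, 1, 2]]  # per-example indices, local to each document
src_doc_offsets = [0, 1000]         # hypothetical start positions in the concatenated dataset
print(apply_document_offsets(all_src_idxs, src_doc_offsets))  # [[0, 1], [1000, 1001, 1002]]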
@@ -909,7 +911,8 @@ def _inner(
# set up reproducible rng state that depends implicitly on batch_size but is invariant to num_proc
rng = np.random.default_rng((seed, abs_align_indices))
# these are the sequence lengths of the bins we will use to merge sentences
maximum_lengths = np.array([50, 100, 150, 350, max_seq_len], dtype=np.int64)
extra_padding = max_merges + 1 + 1 + 1 # SENT_SEP*max_merges + Start + End + Maybe BT info
maximum_lengths = np.array([50, 100, 150, 350, max_seq_len - extra_padding], dtype=np.int64)

# fetch maximum number of rolls to minimize fn calls
bin_idxs = np.clip(rng.poisson(_POISSON_MEAN, size=len(doc_idxs)), 0, len(maximum_lengths) - 1)
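A sketch of the binning logic with the new extra_padding reserve: each example draws a Poisson-distributed bin index into maximum_lengths, and the top bin is shrunk so that SENT_SEP tokens, start/end tokens and the optional BT token still fit. _POISSON_MEAN is assumed to be 1.0 here purely for illustration.

import numpy as np

_POISSON_MEAN = 1.0  # assumed value; the real constant lives elsewhere in the module

def sample_target_lengths(num_examples, max_seq_len, max_merges, seed=1):
    # Reserve room for SENT_SEP * max_merges plus start, end and an optional BT token,
    # so merged examples cannot exceed the model's maximum length after decoration.
    extra_padding = max_merges + 1 + 1 + 1
    maximum_lengths = np.array([50, 100, 150, 350, max_seq_len - extra_padding], dtype=np.int64)
    rng = np.random.default_rng(seed)
    bin_idxs = np.clip(rng.poisson(_POISSON_MEAN, size=num_examples), 0, len(maximum_lengths) - 1)
    return maximum_lengths[bin_idxs]

print(sample_target_lengths(num_examples=8, max_seq_len=1000, max_merges=10))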
@@ -934,7 +937,7 @@ def _inner(
) in enumerate(zip(doc_idxs, pg_idxs, weights, all_src_idxs, all_tgt_idxs, skip)):
# This loop is a bit tricky to understand; here is a high-level overview:
# We iterate over the examples in the batch, and for each example we roll a die to decide whether to
# merge it with the next example or not. If we decide to merge, we check whether the accumulated weight
# merge it with the next example or not. If we decide to merge, we check whether the accumulated weight/length
# is within the maximum sequence length, if it is, we merge the current example with the accumulator
# and continue. If it is not, we store the accumulator and reset it to the current example.
# An implicit assumption here is that the examples have not been shuffled,
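A simplified, standalone sketch of the loop described in the comment above: walk over consecutive (unshuffled) examples, roll whether to keep merging, and flush the accumulator whenever the roll says stop or the accumulated length would exceed the rolled bin length. Documents, paragraphs, weights and skipping, which the real loop also tracks, are omitted.

import numpy as np

def greedy_merge(lengths, target_lengths, no_merge_prob=0.1, seed=1):
    """lengths[i]: token length of example i; target_lengths[i]: its rolled bin length."""
    rng = np.random.default_rng(seed)
    merged, acc, acc_len = [], [], 0
    for length, target in zip(lengths, target_lengths):
        roll_no_merge = rng.random() < no_merge_prob
        if acc and (roll_no_merge or acc_len + length > target):
            # Either we rolled "do not merge" or the accumulator would overflow the bin:
            # store the accumulator and restart it with the current example.
            merged.append(acc)
            acc, acc_len = [], 0
        acc.append(length)
        acc_len += length
    if acc:
        merged.append(acc)
    return merged

print(greedy_merge([40, 60, 30, 200, 10], [100, 100, 150, 150, 50]))
# e.g. [[40, 60, 30], [200], [10]], depending on the merge rolls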
2 changes: 1 addition & 1 deletion src/greynirseq/nicenlp/data/word_noise.py
@@ -53,7 +53,7 @@ def __init__(self, config: WordNoiserConfig):
max_shuffle_distance=config.max_shift_distance, shift_prob=config.shift_prob
)

def apply(self, sequence: str):
def apply(self, sequence: str) -> str:
return word_noise(
sequence,
self.config.drop_word_prob,
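The change in this file is only a return-type annotation; for readers unfamiliar with this kind of noising, here is a toy illustration of word dropout plus small positional shifts. It makes no claim to match word_noise's actual implementation, and the parameter names only loosely mirror the config fields shown above.

import random

def toy_word_noise(sequence: str, drop_word_prob: float, max_shift_distance: int, shift_prob: float, seed: int = 0) -> str:
    rng = random.Random(seed)
    words = sequence.split()
    # Randomly drop words (keep at least one so the sequence never becomes empty).
    kept = [w for w in words if rng.random() >= drop_word_prob] or words[:1]
    # With probability shift_prob, nudge a word up to max_shift_distance positions away.
    keyed = [
        (i + (rng.randint(-max_shift_distance, max_shift_distance) if rng.random() < shift_prob else 0), i, w)
        for i, w in enumerate(kept)
    ]
    return " ".join(w for _, _, w in sorted(keyed))

print(toy_word_noise("this is a short example sentence", drop_word_prob=0.1, max_shift_distance=3, shift_prob=0.2))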
@@ -42,16 +42,21 @@
@dataclass
class DocumentTranslationFromPretrainedBARTConfig(TranslationFromPretrainedBARTConfig):
max_sequence_length: int = field(
default=int(1024 * 0.75),
metadata={"help": "max sequence length"},
default=1000,
metadata={
"help": "Maximum sequence length to train model on. Some sequences might be longer due to sentence separator tokens."
},
)
num_preprocess_workers: int = field(
default=2,
metadata={"help": "number of workers to preprocess the data"},
)
bt_subset: str = field(
default="",
metadata={"help": "comma separated list of subsets to use for backtranslation"},
metadata={
"help": "Prefix of the backtranslation subsets included in the training data.\
Should start with the 'train_subset' prefix, e.g. if train_subset='train' then bt_subset='train_bt' will work."
},
)
parallel_prob: float = field(
default=0.33,
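A tiny sketch of the naming convention the new bt_subset help text describes, and of how membership is later decided with a startswith check (see the is_bt line further down). The subset names are hypothetical.

train_subset = "train"
bt_subset = "train_bt"  # shares the "train" prefix, as the help text requires
subset_names = ["train", "train_bt", "valid"]  # hypothetical subsets found in the data directory

for name in subset_names:
    is_bt = name.startswith(bt_subset)
    print(f"{name}: is_bt={is_bt}")
# train: is_bt=False, train_bt: is_bt=True, valid: is_bt=False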
@@ -104,11 +109,6 @@ def __init__(
super().__init__(cfg, src_dict=the_dict, tgt_dict=copy.deepcopy(the_dict))
# this is for typing only
self.the_dict = the_dict
# TODO: This is a temp hack for NLLB-200
# the_dict.add_symbol("<mask1>")
# the_dict.add_symbol("<mask2>")
# self.tgt_dict = the_dict
# Hack done
self.language_mappings = language_mappings

@classmethod
Expand All @@ -124,7 +124,7 @@ def setup_task(cls, cfg: DocumentTranslationFromPretrainedBARTConfig, **kwargs):
raise ValueError("Must specify languages to train on")
the_dict = cls.load_dictionary(cfg.dict_path)
logger.info("dictionary: {} types".format(len(the_dict)))
# langcode and translation direction
# langcode and translation direction, e.g. "en:en_XX,is:is_IS"
language_mappings_pairs = cfg.data_language_mappings.split(",")
language_mappings = {}
for pair in language_mappings_pairs:
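A sketch of the parse this loop performs on data_language_mappings, using the example string from the comment above; the exact validation and error handling in the task are not shown here.

from typing import Dict

def parse_language_mappings(data_language_mappings: str) -> Dict[str, str]:
    # "en:en_XX,is:is_IS" -> {"en": "en_XX", "is": "is_IS"}
    language_mappings = {}
    for pair in data_language_mappings.split(","):
        lang, lang_code = pair.split(":")
        language_mappings[lang] = lang_code
    return language_mappings

print(parse_language_mappings("en:en_XX,is:is_IS"))  # {'en': 'en_XX', 'is': 'is_IS'}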
@@ -177,8 +177,6 @@ def metadata_from_filename(path: pathlib.Path):
assert lang1 in self.language_mappings
assert lang2 in self.language_mappings
assert file_type in [lang1, lang2, "align"]
print(self.the_dict.index(self.language_mappings[lang1]), lang1)
print(self.the_dict.index(self.language_mappings[lang2]), lang2)
if file_type == "align":
file_type = "align"
elif file_type == lang1:
@@ -202,9 +200,8 @@ def metadata_from_filename(path: pathlib.Path):
datasets_by_name_and_direction[name + direction] = []
datasets_by_name_and_direction[name + direction].append(dataset_metadata)

bt_dataset_names = self.cfg.bt_subset.split(",")
self.cfg.bt_subset.split(",")

logger.info(datasets_by_name_and_direction)
# We combine the lists into a single entity with the same name and direction
datasets_for_loading = {}
for name_direction, datasets in datasets_by_name_and_direction.items():
@@ -223,18 +220,13 @@ def metadata_from_filename(path: pathlib.Path):
"src_path": src_dataset["path"],
"tgt_path": tgt_dataset["path"],
"align_path": align_dataset["path"] if align_dataset is not None else None,
"is_bt": datasets[0]["name"] in bt_dataset_names,
"is_bt": datasets[0]["name"].startswith(self.cfg.bt_subset),
}

# sanity checks
assert (
self.cfg.max_sequence_length <= self.cfg.max_source_positions
), "The maximum training sequence length should be lesser than the positional encoding."
max_seq_len = self.cfg.max_sequence_length

logger.info(f"Max sequence length={max_seq_len}")
logger.info(f"Max merges={self.cfg.max_merges}")
print(self.cfg)

bpe = SentencepieceBPE(SentencepieceConfig(sentencepiece_model=self.cfg.spm_model))
noisy_bpe = SentencepieceBPE(
@@ -252,6 +244,7 @@ def metadata_from_filename(path: pathlib.Path):
global_skip_noise_prob=self.cfg.global_skip_noise_prob,
word_noise_config=self.cfg.word_noise_config,
char_noise_config=self.cfg.char_noise_config,
max_sequence_length=self.cfg.max_sequence_length,
)

def decode(example):
@@ -275,7 +268,7 @@ def decode(example):
dictionary=self.the_dict,
encoder=my_enc,
data_language_mapper=self.language_mappings,
max_seq_len=max_seq_len,
max_seq_len=self.cfg.max_sequence_length,
max_merges=self.cfg.max_merges,
align_path=dataset_values["align_path"],
num_proc=self.cfg.num_preprocess_workers,
@@ -294,7 +287,7 @@ def decode(example):
encoder=my_enc,
parallel_prob=self.cfg.parallel_prob,
seed=self.cfg.seed,
max_seq_len=max_seq_len,
max_seq_len=self.cfg.max_sequence_length,
max_merges=self.cfg.max_merges,
num_proc=self.cfg.num_preprocess_workers,
no_merge_prob=self.cfg.no_merge_prob,
