diff --git a/src/utils.py b/src/utils.py index bc36625..815129b 100755 --- a/src/utils.py +++ b/src/utils.py @@ -30,7 +30,7 @@ class ImageCaptioningTools: @staticmethod def process_raw_data(dp, tokenizer: BunTokenizer, image_size: list, patch_size: list): - """ To process a data point from IC dataloader + """ To process a data point which is LIKELY from IC dataloader for the inference phase of ImageTextCasualLM """ text_inputs = tokenizer(' ', return_tensors='pt') @@ -69,7 +69,7 @@ class VisualQuestionAnswerTools: @staticmethod def process_raw_data(dp, tokenizer: BunTokenizer, image_size: list, patch_size: list): - """ To process a data point from VQA dataloader + """ To process a data point which is LIKELY from VQA dataloader for the inference phase of ImageTextCasualLM """ text_inputs = tokenizer(dp['question'] + ' ? ', return_tensors='pt') @@ -103,3 +103,35 @@ def prettify_output(output, tokenizer: BunTokenizer, return_str: bool = True): return tokenizer.decode(answer) else: return answer + + +class ImageTextMatchingTools: + """ Tools for pre-process and post-forward + for inference mode of ImageTextForPretraining + """ + + @staticmethod + def process_raw_data(dp, tokenizer: BunTokenizer, image_size: list, patch_size: list): + text_inputs = tokenizer(dp['caption'], return_tensors='pt') + image_inputs = torch.stack([resize(read_image(str(dp['img_file']), + ImageReadMode.RGB), + image_size)], + 0).float() + num_patches = (image_size[0] // patch_size[0]) * (image_size[1] // patch_size[1]) + # Extend the shape of text_inputs.attention_masks to cover image_inputs + extra_attention_mask = torch.ones(1, num_patches, dtype=text_inputs.attention_mask.dtype) + # [:,:-1] to ignore [SEP] token to enable sentence completion + attention_mask = torch.cat((text_inputs.attention_mask[:, :-1], extra_attention_mask), dim=1) + + return BatchEncoding({'input_ids': text_inputs.input_ids[:, :-1], + 'attention_mask': attention_mask, + 'token_type_ids': text_inputs.token_type_ids[:, :-1], + 'image_input': image_inputs}) + + @staticmethod + def prettify_output(output): + seq_relationship_logits = output['seq_relationship_logits'] + if seq_relationship_logits[0, 0] < seq_relationship_logits[0, 1]: # this is reversed of BERT NSP + return True + else: + return False