Skip to content

Commit

Permalink
Add Image Text Matching tools
Browse files Browse the repository at this point in the history
  • Loading branch information
dinhanhx committed Oct 2, 2022
1 parent c05d770 commit 4c2d66b
Showing 1 changed file with 34 additions and 2 deletions.
36 changes: 34 additions & 2 deletions src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class ImageCaptioningTools:

@staticmethod
def process_raw_data(dp, tokenizer: BunTokenizer, image_size: list, patch_size: list):
""" To process a data point from IC dataloader
""" To process a data point which is LIKELY from IC dataloader
for the inference phase of ImageTextCasualLM
"""
text_inputs = tokenizer(' ', return_tensors='pt')
Expand Down Expand Up @@ -69,7 +69,7 @@ class VisualQuestionAnswerTools:

@staticmethod
def process_raw_data(dp, tokenizer: BunTokenizer, image_size: list, patch_size: list):
""" To process a data point from VQA dataloader
""" To process a data point which is LIKELY from VQA dataloader
for the inference phase of ImageTextCasualLM
"""
text_inputs = tokenizer(dp['question'] + ' ? ', return_tensors='pt')
Expand Down Expand Up @@ -103,3 +103,35 @@ def prettify_output(output, tokenizer: BunTokenizer, return_str: bool = True):
return tokenizer.decode(answer)
else:
return answer


class ImageTextMatchingTools:
    """ Tools for pre-process and post-forward
    for inference mode of ImageTextForPretraining
    """

    @staticmethod
    def process_raw_data(dp, tokenizer: 'BunTokenizer', image_size: list, patch_size: list):
        """ To process a data point which is LIKELY from an ITM dataloader
        for the inference phase of ImageTextForPretraining

        dp is expected to provide 'caption' (str) and 'img_file' (path-like);
        image_size and patch_size are [height, width] pairs.
        Returns a BatchEncoding with text tensors plus 'image_input'.
        """
        text_inputs = tokenizer(dp['caption'], return_tensors='pt')
        image_inputs = torch.stack([resize(read_image(str(dp['img_file']),
                                                      ImageReadMode.RGB),
                                           image_size)],
                                   0).float()
        num_patches = (image_size[0] // patch_size[0]) * (image_size[1] // patch_size[1])
        # Extend the shape of text_inputs.attention_mask to also cover the image patches
        extra_attention_mask = torch.ones(1, num_patches, dtype=text_inputs.attention_mask.dtype)
        # [:, :-1] drops the trailing [SEP] token, mirroring the IC/VQA preprocessing.
        # NOTE(review): the original comment claimed this "enables sentence completion",
        # which applies to generation, not matching — confirm the ITM head was trained
        # without [SEP] before relying on this truncation.
        attention_mask = torch.cat((text_inputs.attention_mask[:, :-1], extra_attention_mask), dim=1)

        return BatchEncoding({'input_ids': text_inputs.input_ids[:, :-1],
                              'attention_mask': attention_mask,
                              'token_type_ids': text_inputs.token_type_ids[:, :-1],
                              'image_input': image_inputs})

    @staticmethod
    def prettify_output(output):
        """ Reduce the model's seq_relationship_logits to a plain bool:
        True when the image and the text are predicted to match.
        """
        seq_relationship_logits = output['seq_relationship_logits']
        # Index 1 means "match" here — the reverse of BERT's NSP head,
        # where index 0 is the "is next sentence" class.
        return bool(seq_relationship_logits[0, 0] < seq_relationship_logits[0, 1])

0 comments on commit 4c2d66b

Please sign in to comment.