diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..337c774
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,10 @@
+/X101-features
+/X152-features
+/m2_annotations
+evaluation/spice/*
+*.pyc
+*.jar
+/saved_transformer_models
+/tensorboard_logs
+/visualization
+/.vscode
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..1f0dcad
--- /dev/null
+++ b/README.md
@@ -0,0 +1,112 @@
+# $\mathcal{S}^2$ Transformer for Image Captioning
+
+[![](https://img.shields.io/badge/python-3.7.11-orange.svg)](https://www.python.org/)
+[![](https://img.shields.io/apm/l/vim-mode.svg)](https://github.com/zchoi/S2-Transformer/blob/main/LICENSE)
+[![](https://img.shields.io/badge/Pytorch-1.7.1-red.svg)](https://pytorch.org/)
+
+This repository contains the official code implementation for the paper [_S2 Transformer for Image Captioning_ (IJCAI 2022)](https://openaccess.thecvf.com/content/CVPR2021/papers/Zhang_RSTNet_Captioning_With_Adaptive_Attention_on_Visual_and_Non-Visual_Words_CVPR_2021_paper.pdf).
+

+  [Figure: Relationship-Sensitive Transformer (model overview)]

+
+## Table of Contents
+- [Environment setup](#environment-setup)
+- [Data Preparation](#data-preparation)
+- [Training](#training)
+- [Evaluation](#evaluation)
+- [Reference and Citation](#reference-and-citation)
+- [Acknowledgements](#acknowledgements)
+
+## Environment setup
+
+Clone this repository and create the `m2release` conda environment using the `environment.yaml` file:
+```
+conda env create -f environment.yaml
+conda activate m2release
+```
+
+Then download the spaCy language data by executing the following command:
+```
+python -m spacy download en_core_web_md
+```
+
+**Note:** Python 3 is required to run our code. If you have network problems, please download the ```en_core_web_md``` package from [here](https://drive.google.com/file/d/1jf6ecYDzIomaGt3HgOqO_7rEL6oiTjgN/view?usp=sharing), unzip it, and place it in ```/your/anaconda/path/envs/m2release/lib/python*/site-packages/```
+
+
+## Data Preparation
+
+* **Annotation**. Download the annotation file [annotation.zip](https://drive.google.com/file/d/1Zc2P3-MIBg3JcHT1qKeYuQt9CnQcY5XJ/view?usp=sharing) [1], then extract it and put it in the project root directory.
+* **Feature**. Download the processed [ResNeXt-101](https://stduestceducn-my.sharepoint.com/:f:/g/personal/zhn_std_uestc_edu_cn/EssZY4Xdb0JErCk0A1Yx3vUBaRbXau88scRvYw4r1ZuwPg?e=f2QFGp) and [ResNeXt-152](https://stduestceducn-my.sharepoint.com/:f:/g/personal/zhn_std_uestc_edu_cn/EssZY4Xdb0JErCk0A1Yx3vUBaRbXau88scRvYw4r1ZuwPg?e=f2QFGp) image features [2] and put them in the project root directory.
+
+
+
+## Training
+Run `python train_transformer.py` using the following arguments:
+
+| Argument | Possible values |
+|------|------|
+| `--exp_name` | Experiment name |
+| `--batch_size` | Batch size (default: 50) |
+| `--workers` | Number of data-loading workers (speeds up training in the XE stage) |
+| `--head` | Number of heads (default: 8) |
+| `--resume_last` | If used, the training will be resumed from the last checkpoint. |
+| `--resume_best` | If used, the training will be resumed from the best checkpoint. |
+| `--features_path` | Path to the visual features file (h5py) |
+| `--annotation_folder` | Path to annotations |
+| `--num_clusters` | Number of pseudo regions |
+
+For example, to train the model, run the following command:
+```
+python train_transformer.py --exp_name S2 --batch_size 50 --m 40 --head 8 --features_path /path/to/features --num_clusters 5
+```
+or just run:
+```
+bash train.sh
+```
+**Note:** We use `torch.distributed` to train our model; set `worldSize` in [train_transformer.py]() to the number of GPUs you want to train on.
+
+## Evaluation
+### Offline Evaluation
+Run `python test_transformer.py` to evaluate the model, for example:
+```
+python test_transformer.py --batch_size 10 --features_path /path/to/features --model_path /path/to/saved_transformer_models/ckpt --num_clusters 5
+```
+
+**Note:** We have removed the ```SPICE``` metric from training-time evaluation because it is time-consuming. You can add it back when evaluating the model: download this [file](https://drive.google.com/file/d/1vEVsbEFjDstmSvoWhu4UdKaJjX1jJXpR/view?usp=sharing), put it in ```/path/to/evaluation/```, and uncomment the related code in [evaluation/__init__.py]() (see the sketch below).
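+
+For reference, the following is a minimal sketch of what `evaluation/__init__.py` could look like with SPICE enabled, assuming the downloaded files are placed under `evaluation/spice/` (the folder listed in `.gitignore`); the exact lines to uncomment may differ:
+```
+from .bleu import Bleu
+from .meteor import Meteor
+from .rouge import Rouge
+from .cider import Cider
+from .spice import Spice  # previously commented out; requires the downloaded SPICE files
+from .tokenizer import PTBTokenizer
+
+def compute_scores(gts, gen):
+    # Spice() is added to the metric tuple used during evaluation
+    metrics = (Bleu(), Meteor(), Rouge(), Cider(), Spice())
+    all_score, all_scores = {}, {}
+    for metric in metrics:
+        score, scores = metric.compute_score(gts, gen)
+        all_score[str(metric)] = score
+        all_scores[str(metric)] = scores
+    return all_score, all_scores
+```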
+
+We provide a pretrained model [here](https://drive.google.com/file/d/1Y133r4Wd9ediS1Jqlwc1qtL15vCK_Mik/view?usp=sharing); evaluating it should give the results in the second row of the table below:
+
+| Model | B@1 | B@4 | M | R | C | S |
+|:---------:|:-------:|:-:|:---------------:|:--------------------------:|:-------:|:-------:|
+| Our Paper (ResNeXt-101) | 81.1 | 39.6 | 29.6 | 59.1 | 133.5 | 23.2 |
+| Reproduced Model (ResNeXt-101) | 81.2 | 39.9 | 29.6 | 59.1 | 133.7 | 23.3 |
+
+(B@1/B@4: BLEU-1/BLEU-4; M: METEOR; R: ROUGE-L; C: CIDEr; S: SPICE)
+
+
+### Online Evaluation
+We also report the performance of our model on the online COCO test server, using an ensemble of four S2 models. The detailed online test code can be obtained from this [repo](https://github.com/zhangxuying1004/RSTNet).
+
+## Reference and Citation
+### Reference
+[1] Marcella Cornia, Matteo Stefanini, Lorenzo Baraldi, and Rita Cucchiara. Meshed-memory transformer for image captioning. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, 2020.
+[2] Xuying Zhang, Xiaoshuai Sun, Yunpeng Luo, Jiayi Ji, Yiyi Zhou, Yongjian Wu, Feiyue Huang, and Rongrong Ji. RSTNet: Captioning with adaptive attention on visual and non-visual words. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 15465–15474, 2021.
+### Citation
+```
+@inproceedings{S2,
+  author    = {Pengpeng Zeng and
+               Haonan Zhang and
+               Jingkuan Song and
+               Lianli Gao},
+  title     = {S2 Transformer for Image Captioning},
+  booktitle = {IJCAI},
+  % pages   = {????--????}
+  year      = {2022}
+}
+```
+## Acknowledgements
+Thanks to Zhang _et al._ for releasing the visual features (ResNeXt-101 and ResNeXt-152); our code implementation is also based on their [repo](https://github.com/zhangxuying1004/RSTNet).
+Thanks also to [M2 Transformer](https://github.com/aimagelab/meshed-memory-transformer) for the original annotations and to [grid-feats-vqa](https://github.com/facebookresearch/grid-feats-vqa) for the effective visual representations.
+ + diff --git a/data/__init__.py b/data/__init__.py new file mode 100644 index 0000000..202e709 --- /dev/null +++ b/data/__init__.py @@ -0,0 +1,7 @@ +from .field import RawField, Merge, ImageDetectionsField, TextField +from .dataset import COCO +from torch.utils.data import DataLoader as TorchDataLoader + +class DataLoader(TorchDataLoader): + def __init__(self, dataset, *args, **kwargs): + super(DataLoader, self).__init__(dataset, *args, collate_fn=dataset.collate_fn(), **kwargs) diff --git a/data/dataset.py b/data/dataset.py new file mode 100644 index 0000000..264684f --- /dev/null +++ b/data/dataset.py @@ -0,0 +1,284 @@ +import os +import numpy as np +import itertools +import collections +import torch +from .example import Example +from .utils import nostdout +from pycocotools.coco import COCO as pyCOCO + + +class Dataset(object): + def __init__(self, examples, fields): + self.examples = examples + self.fields = dict(fields) + + def collate(self, batch): + if len(self.fields) == 1: + batch = [batch, ] + else: + batch = list(zip(*batch)) + + tensors = [] + for field, data in zip(self.fields.values(), batch): + tensor = field.process(data) + if isinstance(tensor, collections.Sequence) and any(isinstance(t, torch.Tensor) for t in tensor): + tensors.extend(tensor) + else: + tensors.append(tensor) + + if len(tensors) > 1: + return tensors + else: + return tensors[0] + + def collate_fn(self): + return self.collate + + def text(self): + return [x.text for x in self.examples] + + def __getitem__(self, i): + example = self.examples[i] + data = [] + for field_name, field in self.fields.items(): + data.append(field.preprocess(getattr(example, field_name))) + + if len(data) == 1: + data = data[0] + return data + + def __len__(self): + return len(self.examples) + + +class ValueDataset(Dataset): + def __init__(self, examples, fields, dictionary): + super(ValueDataset, self).__init__(examples, fields) + self.dictionary = dictionary + + def valuecollate(self, batch): + value_batch_flattened = list(itertools.chain(*batch)) + value_tensors_flattened = super(ValueDataset, self).collate_fn()(value_batch_flattened) + + lengths = [0, ] + list(itertools.accumulate([len(x) for x in batch])) + if isinstance(value_tensors_flattened, collections.Sequence) \ + and any(isinstance(t, torch.Tensor) for t in value_tensors_flattened): + value_tensors = [[vt[s:e] for (s, e) in zip(lengths[:-1], lengths[1:])] for vt in value_tensors_flattened] + else: + value_tensors = [value_tensors_flattened[s:e] for (s, e) in zip(lengths[:-1], lengths[1:])] + + return value_tensors + + def collate_fn(self): + return self.valuecollate + + def __getitem__(self, i): + if i not in self.dictionary: + raise IndexError + + values_data = [] + for idx in self.dictionary[i]: + value_data = super(ValueDataset, self).__getitem__(idx) + values_data.append(value_data) + return values_data + + def __len__(self): + return len(self.dictionary) + + +class DictionaryDataset(Dataset): + def __init__(self, examples, fields, key_fields): + if not isinstance(key_fields, (tuple, list)): + key_fields = (key_fields,) + for field in key_fields: + assert (field in fields) + caption_field = fields.pop('add_text') + dictionary = collections.defaultdict(list) + # {'image': image_field} + key_fields = {k: fields[k] for k in key_fields} + # {’text':RawField()} + value_fields = {k: fields[k] for k in fields.keys() if k not in key_fields} + key_examples = [] + key_dict = dict() + value_examples = [] + + for i, e in enumerate(examples): + key_example = 
Example.fromdict({k: getattr(e, k) for k in key_fields}) + value_example = Example.fromdict({v: getattr(e, v) for v in value_fields}) + if key_example not in key_dict: + key_dict[key_example] = len(key_examples) + key_examples.append(key_example) + + value_examples.append(value_example) + dictionary[key_dict[key_example]].append(i) + + self.key_dataset = Dataset(key_examples, key_fields) + self.value_dataset = ValueDataset(value_examples, value_fields, dictionary) + self.not_key_dataset = Dataset(value_examples, {'text': caption_field}) + + super(DictionaryDataset, self).__init__(examples, fields) + + def dictcollate(self, batch): + key_batch, value_batch, not_key_batch = list(zip(*batch)) + key_tensors = self.key_dataset.collate_fn()(key_batch) + value_tensors = self.value_dataset.collate_fn()(value_batch) + not_key_tensors = self.not_key_dataset.collate_fn()(not_key_batch) + return key_tensors, value_tensors, not_key_tensors + def collate_fn(self): + return self.dictcollate + + + def __getitem__(self, i): + return self.key_dataset[i], self.value_dataset[i], self.not_key_dataset[i] + + def __len__(self): + return len(self.key_dataset) + + +def unique(sequence): + seen = set() + if isinstance(sequence[0], list): + return [x for x in sequence if not (tuple(x) in seen or seen.add(tuple(x)))] + else: + return [x for x in sequence if not (x in seen or seen.add(x))] + + +class PairedDataset(Dataset): + def __init__(self, examples, fields): + assert ('image' in fields) + assert ('text' in fields) + super(PairedDataset, self).__init__(examples, fields) + self.image_field = self.fields['image'] + self.text_field = self.fields['text'] + + def image_set(self): + img_list = [e.image for e in self.examples] + image_set = unique(img_list) + examples = [Example.fromdict({'image': i}) for i in image_set] + dataset = Dataset(examples, {'image': self.image_field}) + return dataset + + def text_set(self): + text_list = [e.text for e in self.examples] + text_list = unique(text_list) + examples = [Example.fromdict({'text': t}) for t in text_list] + dataset = Dataset(examples, {'text': self.text_field}) + return dataset + + def image_dictionary(self, fields=None): + if not fields: + fields = self.fields + dataset = DictionaryDataset(self.examples, fields, key_fields='image') + return dataset + + def text_dictionary(self, fields=None): + if not fields: + fields = self.fields + dataset = DictionaryDataset(self.examples, fields, key_fields='text') + return dataset + + @property + def splits(self): + raise NotImplementedError + + +class COCO(PairedDataset): + def __init__(self, image_field, text_field, img_root, ann_root, id_root=None, use_restval=True, + cut_validation=False): + roots = {} + roots['train'] = { + 'img': os.path.join(img_root, 'train2014'), + 'cap': os.path.join(ann_root, 'captions_train2014.json') + } + roots['val'] = { + 'img': os.path.join(img_root, 'val2014'), + 'cap': os.path.join(ann_root, 'captions_val2014.json') + } + roots['test'] = { + 'img': os.path.join(img_root, 'val2014'), + 'cap': os.path.join(ann_root, 'captions_val2014.json') + } + roots['trainrestval'] = { + 'img': (roots['train']['img'], roots['val']['img']), + 'cap': (roots['train']['cap'], roots['val']['cap']) + } + + if id_root is not None: + ids = {} + ids['train'] = np.load(os.path.join(id_root, 'coco_train_ids.npy')) + ids['val'] = np.load(os.path.join(id_root, 'coco_dev_ids.npy')) + if cut_validation: + ids['val'] = ids['val'][:5000] + ids['test'] = np.load(os.path.join(id_root, 'coco_test_ids.npy')) + ids['trainrestval'] 
= ( + ids['train'], + np.load(os.path.join(id_root, 'coco_restval_ids.npy'))) + + if use_restval: + roots['train'] = roots['trainrestval'] + ids['train'] = ids['trainrestval'] + else: + ids = None + + with nostdout(): + self.train_examples, self.val_examples, self.test_examples = self.get_samples(roots, ids) + examples = self.train_examples + self.val_examples + self.test_examples + super(COCO, self).__init__(examples, {'image': image_field, 'text': text_field}) + + @property + def splits(self): + train_split = PairedDataset(self.train_examples, self.fields) + val_split = PairedDataset(self.val_examples, self.fields) + test_split = PairedDataset(self.test_examples, self.fields) + return train_split, val_split, test_split + + @classmethod + def get_samples(cls, roots, ids_dataset=None): + train_samples = [] + val_samples = [] + test_samples = [] + + for split in ['train', 'val', 'test']: + if isinstance(roots[split]['cap'], tuple): + coco_dataset = (pyCOCO(roots[split]['cap'][0]), pyCOCO(roots[split]['cap'][1])) + root = roots[split]['img'] + else: + coco_dataset = (pyCOCO(roots[split]['cap']),) + root = (roots[split]['img'],) + + if ids_dataset is None: + ids = list(coco_dataset.anns.keys()) + else: + ids = ids_dataset[split] + + if isinstance(ids, tuple): + bp = len(ids[0]) + ids = list(ids[0]) + list(ids[1]) + else: + bp = len(ids) + + for index in range(len(ids)): + if index < bp: + coco = coco_dataset[0] + img_root = root[0] + else: + coco = coco_dataset[1] + img_root = root[1] + + ann_id = ids[index] + caption = coco.anns[ann_id]['caption'] + img_id = coco.anns[ann_id]['image_id'] + filename = coco.loadImgs(img_id)[0]['file_name'] + + example = Example.fromdict({'image': os.path.join(img_root, filename), 'text': caption}) + + if split == 'train': + train_samples.append(example) + elif split == 'val': + val_samples.append(example) + elif split == 'test': + test_samples.append(example) + + return train_samples, val_samples, test_samples + diff --git a/data/example.py b/data/example.py new file mode 100644 index 0000000..d46c07f --- /dev/null +++ b/data/example.py @@ -0,0 +1,27 @@ + +class Example(object): + """Defines a single training or test example. + Stores each column of the example as an attribute. + """ + @classmethod + def fromdict(cls, data): + ex = cls(data) + return ex + + def __init__(self, data): + for key, val in data.items(): + super(Example, self).__setattr__(key, val) + + def __setattr__(self, key, value): + raise AttributeError + + def __hash__(self): + return hash(tuple(x for x in self.__dict__.values())) + + def __eq__(self, other): + this = tuple(x for x in self.__dict__.values()) + other = tuple(x for x in other.__dict__.values()) + return this == other + + def __ne__(self, other): + return not self.__eq__(other) diff --git a/data/field.py b/data/field.py new file mode 100644 index 0000000..5173096 --- /dev/null +++ b/data/field.py @@ -0,0 +1,329 @@ +# coding: utf8 +from collections import Counter, OrderedDict +from torch.utils.data.dataloader import default_collate +from itertools import chain +import six +import torch +import numpy as np +import h5py +import os +import warnings +import shutil + +from .dataset import Dataset +from .vocab import Vocab +from .utils import get_tokenizer + + +class RawField(object): + """ Defines a general datatype. + + Every dataset consists of one or more types of data. For instance, + a machine translation dataset contains paired examples of text, while + an image captioning dataset contains images and texts. 
+ Each of these types of data is represented by a RawField object. + An RawField object does not assume any property of the data type and + it holds parameters relating to how a datatype should be processed. + + Attributes: + preprocessing: The Pipeline that will be applied to examples + using this field before creating an example. + Default: None. + postprocessing: A Pipeline that will be applied to a list of examples + using this field before assigning to a batch. + Function signature: (batch(list)) -> object + Default: None. + """ + + def __init__(self, preprocessing=None, postprocessing=None): + self.preprocessing = preprocessing + self.postprocessing = postprocessing + + def preprocess(self, x): + """ Preprocess an example if the `preprocessing` Pipeline is provided. """ + if self.preprocessing is not None: + return self.preprocessing(x) + else: + return x + + def process(self, batch, *args, **kwargs): + """ Process a list of examples to create a batch. + + Postprocess the batch with user-provided Pipeline. + + Args: + batch (list(object)): A list of object from a batch of examples. + Returns: + object: Processed object given the input and custom + postprocessing Pipeline. + """ + if self.postprocessing is not None: + batch = self.postprocessing(batch) + return default_collate(batch) + + +class Merge(RawField): + def __init__(self, *fields): + super(Merge, self).__init__() + self.fields = fields + + def preprocess(self, x): + return tuple(f.preprocess(x) for f in self.fields) + + def process(self, batch, *args, **kwargs): + if len(self.fields) == 1: + batch = [batch, ] + else: + batch = list(zip(*batch)) + + out = list(f.process(b, *args, **kwargs) for f, b in zip(self.fields, batch)) + return out + +class ImageDetectionsField(RawField): + def __init__(self, preprocessing=None, postprocessing=None, detections_path=None, max_detections=100, + sort_by_prob=False, load_in_tmp=True): + self.max_detections = max_detections + self.detections_path = detections_path + self.sort_by_prob = sort_by_prob + + tmp_detections_path = os.path.join('/tmp', os.path.basename(detections_path)) + + if load_in_tmp: + if not os.path.isfile(tmp_detections_path): + if shutil.disk_usage("/tmp")[-1] < os.path.getsize(detections_path): + warnings.warn('Loading from %s, because /tmp has no enough space.' 
% detections_path) + else: + warnings.warn("Copying detection file to /tmp") + shutil.copyfile(detections_path, tmp_detections_path) + warnings.warn("Done.") + self.detections_path = tmp_detections_path + else: + self.detections_path = tmp_detections_path + + super(ImageDetectionsField, self).__init__(preprocessing, postprocessing) + + def preprocess(self, x, avoid_precomp=False): + image_id = int(x.split('_')[-1].split('.')[0]) + try: + f = h5py.File(self.detections_path, 'r') +# precomp_data = f['%d_features' % image_id][()] + precomp_data = f['%d_grids' % image_id][()] + if self.sort_by_prob: + precomp_data = precomp_data[np.argsort(np.max(f['%d_cls_prob' % image_id][()], -1))[::-1]] + except KeyError: + warnings.warn('Could not find detections for %d' % image_id) + precomp_data = np.random.rand(10,2048) + + delta = self.max_detections - precomp_data.shape[0] + if delta > 0: + precomp_data = np.concatenate([precomp_data, np.zeros((delta, precomp_data.shape[1]))], axis=0) + elif delta < 0: + precomp_data = precomp_data[:self.max_detections] + + return precomp_data.astype(np.float32) + + +class TextField(RawField): + vocab_cls = Vocab + # Dictionary mapping PyTorch tensor dtypes to the appropriate Python + # numeric type. + dtypes = { + torch.float32: float, + torch.float: float, + torch.float64: float, + torch.double: float, + torch.float16: float, + torch.half: float, + + torch.uint8: int, + torch.int8: int, + torch.int16: int, + torch.short: int, + torch.int32: int, + torch.int: int, + torch.int64: int, + torch.long: int, + } + punctuations = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-", \ + ".", "?", "!", ",", ":", "-", "--", "...", ";"] + + def __init__(self, use_vocab=True, init_token=None, eos_token=None, fix_length=None, dtype=torch.long, + preprocessing=None, postprocessing=None, lower=False, tokenize=(lambda s: s.split()), + remove_punctuation=False, include_lengths=False, batch_first=True, pad_token="", + unk_token="", pad_first=False, truncate_first=False, vectors=None, nopoints=True): + self.use_vocab = use_vocab + self.init_token = init_token + self.eos_token = eos_token + self.fix_length = fix_length + self.dtype = dtype + self.lower = lower + self.tokenize = get_tokenizer(tokenize) + self.remove_punctuation = remove_punctuation + self.include_lengths = include_lengths + self.batch_first = batch_first + self.pad_token = pad_token + self.unk_token = unk_token + self.pad_first = pad_first + self.truncate_first = truncate_first + self.vocab = None + self.vectors = vectors + if nopoints: + self.punctuations.append("..") + + super(TextField, self).__init__(preprocessing, postprocessing) + + def preprocess(self, x): + if six.PY2 and isinstance(x, six.string_types) and not isinstance(x, six.text_type): + x = six.text_type(x, encoding='utf-8') + if self.lower: + x = six.text_type.lower(x) + x = self.tokenize(x.rstrip('\n')) + if self.remove_punctuation: + x = [w for w in x if w not in self.punctuations] + if self.preprocessing is not None: + return self.preprocessing(x) + else: + return x + + def process(self, batch, device=None): + padded = self.pad(batch) + tensor = self.numericalize(padded, device=device) + return tensor + + def build_vocab(self, *args, **kwargs): + counter = Counter() + sources = [] + for arg in args: + if isinstance(arg, Dataset): + sources += [getattr(arg, name) for name, field in arg.fields.items() if field is self] + else: + sources.append(arg) + + for data in sources: + for x in data: + x = self.preprocess(x) + try: + counter.update(x) + 
except TypeError: + counter.update(chain.from_iterable(x)) + + specials = list(OrderedDict.fromkeys([ + tok for tok in [self.unk_token, self.pad_token, self.init_token, + self.eos_token] + if tok is not None])) + self.vocab = self.vocab_cls(counter, specials=specials, **kwargs) + + def pad(self, minibatch): + """Pad a batch of examples using this field. + Pads to self.fix_length if provided, otherwise pads to the length of + the longest example in the batch. Prepends self.init_token and appends + self.eos_token if those attributes are not None. Returns a tuple of the + padded list and a list containing lengths of each example if + `self.include_lengths` is `True`, else just + returns the padded list. + """ + minibatch = list(minibatch) + if self.fix_length is None: + max_len = max(len(x) for x in minibatch) + else: + max_len = self.fix_length + ( + self.init_token, self.eos_token).count(None) - 2 + padded, lengths = [], [] + for x in minibatch: + if self.pad_first: + padded.append( + [self.pad_token] * max(0, max_len - len(x)) + + ([] if self.init_token is None else [self.init_token]) + + list(x[-max_len:] if self.truncate_first else x[:max_len]) + + ([] if self.eos_token is None else [self.eos_token])) + else: + padded.append( + ([] if self.init_token is None else [self.init_token]) + + list(x[-max_len:] if self.truncate_first else x[:max_len]) + + ([] if self.eos_token is None else [self.eos_token]) + + [self.pad_token] * max(0, max_len - len(x))) + lengths.append(len(padded[-1]) - max(0, max_len - len(x))) + if self.include_lengths: + return padded, lengths + return padded + + def numericalize(self, arr, device=None): + """Turn a batch of examples that use this field into a list of Variables. + If the field has include_lengths=True, a tensor of lengths will be + included in the return value. + Arguments: + arr (List[List[str]], or tuple of (List[List[str]], List[int])): + List of tokenized and padded examples, or tuple of List of + tokenized and padded examples and List of lengths of each + example if self.include_lengths is True. + device (str or torch.device): A string or instance of `torch.device` + specifying which device the Variables are going to be created on. + If left as default, the tensors will be created on cpu. Default: None. + """ + if self.include_lengths and not isinstance(arr, tuple): + raise ValueError("Field has include_lengths set to True, but " + "input data is not a tuple of " + "(data batch, batch lengths).") + if isinstance(arr, tuple): + arr, lengths = arr + lengths = torch.tensor(lengths, dtype=self.dtype, device=device) + + if self.use_vocab: + arr = [[self.vocab.stoi[x] for x in ex] for ex in arr] + + if self.postprocessing is not None: + arr = self.postprocessing(arr, self.vocab) + + var = torch.tensor(arr, dtype=self.dtype, device=device) + else: + if self.vectors: + arr = [[self.vectors[x] for x in ex] for ex in arr] + if self.dtype not in self.dtypes: + raise ValueError( + "Specified Field dtype {} can not be used with " + "use_vocab=False because we do not know how to numericalize it. " + "Please raise an issue at " + "https://github.com/pytorch/text/issues".format(self.dtype)) + numericalization_func = self.dtypes[self.dtype] + # It doesn't make sense to explictly coerce to a numeric type if + # the data is sequential, since it's unclear how to coerce padding tokens + # to a numeric type. 
+ arr = [numericalization_func(x) if isinstance(x, six.string_types) + else x for x in arr] + + if self.postprocessing is not None: + arr = self.postprocessing(arr, None) + + var = torch.cat([torch.cat([a.unsqueeze(0) for a in ar]).unsqueeze(0) for ar in arr]) + + # var = torch.tensor(arr, dtype=self.dtype, device=device) + if not self.batch_first: + var.t_() + var = var.contiguous() + + if self.include_lengths: + return var, lengths + return var + + def decode(self, word_idxs, join_words=True): + if isinstance(word_idxs, list) and len(word_idxs) == 0: + return self.decode([word_idxs, ], join_words)[0] + if isinstance(word_idxs, list) and isinstance(word_idxs[0], int): + return self.decode([word_idxs, ], join_words)[0] + elif isinstance(word_idxs, np.ndarray) and word_idxs.ndim == 1: + return self.decode(word_idxs.reshape((1, -1)), join_words)[0] + elif isinstance(word_idxs, torch.Tensor) and word_idxs.ndimension() == 1: + return self.decode(word_idxs.unsqueeze(0), join_words)[0] + + captions = [] + for wis in word_idxs: + caption = [] + for wi in wis: + word = self.vocab.itos[int(wi)] + if word == self.eos_token: + break + caption.append(word) + if join_words: + caption = ' '.join(caption) + captions.append(caption) + return captions diff --git a/data/utils.py b/data/utils.py new file mode 100644 index 0000000..e1dc21d --- /dev/null +++ b/data/utils.py @@ -0,0 +1,91 @@ +import contextlib, sys + +class DummyFile(object): + def write(self, x): pass + +@contextlib.contextmanager +def nostdout(): + save_stdout = sys.stdout + sys.stdout = DummyFile() + yield + sys.stdout = save_stdout + +def reporthook(t): + """https://github.com/tqdm/tqdm""" + last_b = [0] + + def inner(b=1, bsize=1, tsize=None): + """ + b: int, optionala + Number of blocks just transferred [default: 1]. + bsize: int, optional + Size of each block (in tqdm units) [default: 1]. + tsize: int, optional + Total size (in tqdm units). If [default: None] remains unchanged. + """ + if tsize is not None: + t.total = tsize + t.update((b - last_b[0]) * bsize) + last_b[0] = b + return inner + +import revtok +import spacy +spacy_en = spacy.load('en_core_web_md') + + +def _getTokenizerrevtok(x): + return revtok.tokenize(x, decap=True) + +def _getTokenizerspacy(s): + return [tok.text for tok in spacy_en.tokenizer(s)] + + +def get_tokenizer(tokenizer): + if callable(tokenizer): + return tokenizer + if tokenizer == "spacy": + try: + return _getTokenizerspacy + except ImportError: + print("Please install SpaCy and the SpaCy English tokenizer. " + "See the docs at https://spacy.io for more information.") + raise + except AttributeError: + print("Please install SpaCy and the SpaCy English tokenizer. " + "See the docs at https://spacy.io for more information.") + raise + elif tokenizer == "moses": + try: + from nltk.tokenize.moses import MosesTokenizer + moses_tokenizer = MosesTokenizer() + return moses_tokenizer.tokenize + except ImportError: + print("Please install NLTK. " + "See the docs at http://nltk.org for more information.") + raise + except LookupError: + print("Please install the necessary NLTK corpora. 
" + "See the docs at http://nltk.org for more information.") + raise + elif tokenizer == 'revtok': + try: + import revtok + return revtok.tokenize + except ImportError: + print("Please install revtok.") + raise + elif tokenizer == 'subword': + try: + return _getTokenizerrevtok + except ImportError: + print("Please install revtok.") + raise + raise ValueError("Requested tokenizer {}, valid choices are a " + "callable that takes a single string as input, " + "\"revtok\" for the revtok reversible tokenizer, " + "\"subword\" for the revtok caps-aware tokenizer, " + "\"spacy\" for the SpaCy English tokenizer, or " + "\"moses\" for the NLTK port of the Moses tokenization " + "script.".format(tokenizer)) + diff --git a/data/vocab.py b/data/vocab.py new file mode 100644 index 0000000..7614248 --- /dev/null +++ b/data/vocab.py @@ -0,0 +1,372 @@ +from __future__ import unicode_literals +import array +from collections import defaultdict +from functools import partial +import io +import logging +import os +import zipfile + +import six +from six.moves.urllib.request import urlretrieve +import torch +from tqdm import tqdm +import tarfile + +from .utils import reporthook + +logger = logging.getLogger(__name__) + + +class Vocab(object): + """Defines a vocabulary object that will be used to numericalize a field. + + Attributes: + freqs: A collections.Counter object holding the frequencies of tokens + in the data used to build the Vocab. + stoi: A collections.defaultdict instance mapping token strings to + numerical identifiers. + itos: A list of token strings indexed by their numerical identifiers. + """ + def __init__(self, counter, max_size=None, min_freq=1, specials=[''], + vectors=None, unk_init=None, vectors_cache=None): + """Create a Vocab object from a collections.Counter. + + Arguments: + counter: collections.Counter object holding the frequencies of + each value found in the data. + max_size: The maximum size of the vocabulary, or None for no + maximum. Default: None. + min_freq: The minimum frequency needed to include a token in the + vocabulary. Values less than 1 will be set to 1. Default: 1. + specials: The list of special tokens (e.g., padding or eos) that + will be prepended to the vocabulary in addition to an + token. Default: [''] + vectors: One of either the available pretrained vectors + or custom pretrained vectors (see Vocab.load_vectors); + or a list of aforementioned vectors + unk_init (callback): by default, initialize out-of-vocabulary word vectors + to zero vectors; can be any function that takes in a Tensor and + returns a Tensor of the same size. Default: torch.Tensor.zero_ + vectors_cache: directory for cached vectors. 
Default: '.vector_cache' + """ + self.freqs = counter + counter = counter.copy() + min_freq = max(min_freq, 1) + + self.itos = list(specials) + # frequencies of special tokens are not counted when building vocabulary + # in frequency order + for tok in specials: + del counter[tok] + + max_size = None if max_size is None else max_size + len(self.itos) + + # sort by frequency, then alphabetically + words_and_frequencies = sorted(counter.items(), key=lambda tup: tup[0]) + words_and_frequencies.sort(key=lambda tup: tup[1], reverse=True) + + for word, freq in words_and_frequencies: + if freq < min_freq or len(self.itos) == max_size: + break + self.itos.append(word) + + self.stoi = defaultdict(_default_unk_index) + # stoi is simply a reverse dict for itos + self.stoi.update({tok: i for i, tok in enumerate(self.itos)}) + + self.vectors = None + if vectors is not None: + self.load_vectors(vectors, unk_init=unk_init, cache=vectors_cache) + else: + assert unk_init is None and vectors_cache is None + + def __eq__(self, other): + if self.freqs != other.freqs: + return False + if self.stoi != other.stoi: + return False + if self.itos != other.itos: + return False + if self.vectors != other.vectors: + return False + return True + + def __len__(self): + return len(self.itos) + + def extend(self, v, sort=False): + words = sorted(v.itos) if sort else v.itos + for w in words: + if w not in self.stoi: + self.itos.append(w) + self.stoi[w] = len(self.itos) - 1 + + def load_vectors(self, vectors, **kwargs): + """ + Arguments: + vectors: one of or a list containing instantiations of the + GloVe, CharNGram, or Vectors classes. Alternatively, one + of or a list of available pretrained vectors: + charngram.100d + fasttext.en.300d + fasttext.simple.300d + glove.42B.300d + glove.840B.300d + glove.twitter.27B.25d + glove.twitter.27B.50d + glove.twitter.27B.100d + glove.twitter.27B.200d + glove.6B.50d + glove.6B.100d + glove.6B.200d + glove.6B.300d + Remaining keyword arguments: Passed to the constructor of Vectors classes. + """ + if not isinstance(vectors, list): + vectors = [vectors] + for idx, vector in enumerate(vectors): + if six.PY2 and isinstance(vector, str): + vector = six.text_type(vector) + if isinstance(vector, six.string_types): + # Convert the string pretrained vector identifier + # to a Vectors object + if vector not in pretrained_aliases: + raise ValueError( + "Got string input vector {}, but allowed pretrained " + "vectors are {}".format( + vector, list(pretrained_aliases.keys()))) + vectors[idx] = pretrained_aliases[vector](**kwargs) + elif not isinstance(vector, Vectors): + raise ValueError( + "Got input vectors of type {}, expected str or " + "Vectors object".format(type(vector))) + + tot_dim = sum(v.dim for v in vectors) + self.vectors = torch.Tensor(len(self), tot_dim) + for i, token in enumerate(self.itos): + start_dim = 0 + for v in vectors: + end_dim = start_dim + v.dim + self.vectors[i][start_dim:end_dim] = v[token.strip()] + start_dim = end_dim + assert(start_dim == tot_dim) + + def set_vectors(self, stoi, vectors, dim, unk_init=torch.Tensor.zero_): + """ + Set the vectors for the Vocab instance from a collection of Tensors. + + Arguments: + stoi: A dictionary of string to the index of the associated vector + in the `vectors` input argument. + vectors: An indexed iterable (or other structure supporting __getitem__) that + given an input index, returns a FloatTensor representing the vector + for the token associated with the index. 
For example, + vector[stoi["string"]] should return the vector for "string". + dim: The dimensionality of the vectors. + unk_init (callback): by default, initialize out-of-vocabulary word vectors + to zero vectors; can be any function that takes in a Tensor and + returns a Tensor of the same size. Default: torch.Tensor.zero_ + """ + self.vectors = torch.Tensor(len(self), dim) + for i, token in enumerate(self.itos): + wv_index = stoi.get(token, None) + if wv_index is not None: + self.vectors[i] = vectors[wv_index] + else: + self.vectors[i] = unk_init(self.vectors[i]) + + +class Vectors(object): + + def __init__(self, name, cache=None, + url=None, unk_init=None): + """ + Arguments: + name: name of the file that contains the vectors + cache: directory for cached vectors + url: url for download if vectors not found in cache + unk_init (callback): by default, initalize out-of-vocabulary word vectors + to zero vectors; can be any function that takes in a Tensor and + returns a Tensor of the same size + """ + cache = '.vector_cache' if cache is None else cache + self.unk_init = torch.Tensor.zero_ if unk_init is None else unk_init + self.cache(name, cache, url=url) + + def __getitem__(self, token): + if token in self.stoi: + return self.vectors[self.stoi[token]] + else: + return self.unk_init(torch.Tensor(self.dim)) # self.unk_init(torch.Tensor(1, self.dim)) + + def cache(self, name, cache, url=None): + if os.path.isfile(name): + path = name + path_pt = os.path.join(cache, os.path.basename(name)) + '.pt' + else: + path = os.path.join(cache, name) + path_pt = path + '.pt' + + if not os.path.isfile(path_pt): + if not os.path.isfile(path) and url: + logger.info('Downloading vectors from {}'.format(url)) + if not os.path.exists(cache): + os.makedirs(cache) + dest = os.path.join(cache, os.path.basename(url)) + if not os.path.isfile(dest): + with tqdm(unit='B', unit_scale=True, miniters=1, desc=dest) as t: + try: + urlretrieve(url, dest, reporthook=reporthook(t)) + except KeyboardInterrupt as e: # remove the partial zip file + os.remove(dest) + raise e + logger.info('Extracting vectors into {}'.format(cache)) + ext = os.path.splitext(dest)[1][1:] + if ext == 'zip': + with zipfile.ZipFile(dest, "r") as zf: + zf.extractall(cache) + elif ext == 'gz': + with tarfile.open(dest, 'r:gz') as tar: + tar.extractall(path=cache) + if not os.path.isfile(path): + raise RuntimeError('no vectors found at {}'.format(path)) + + # str call is necessary for Python 2/3 compatibility, since + # argument must be Python 2 str (Python 3 bytes) or + # Python 3 str (Python 2 unicode) + itos, vectors, dim = [], array.array(str('d')), None + + # Try to read the whole file with utf-8 encoding. + binary_lines = False + try: + with io.open(path, encoding="utf8") as f: + lines = [line for line in f] + # If there are malformed lines, read in binary mode + # and manually decode each word from utf-8 + except: + logger.warning("Could not read {} as UTF8 file, " + "reading file as bytes and skipping " + "words with malformed UTF8.".format(path)) + with open(path, 'rb') as f: + lines = [line for line in f] + binary_lines = True + + logger.info("Loading vectors from {}".format(path)) + for line in tqdm(lines, total=len(lines)): + # Explicitly splitting on " " is important, so we don't + # get rid of Unicode non-breaking spaces in the vectors. 
+ entries = line.rstrip().split(b" " if binary_lines else " ") + + word, entries = entries[0], entries[1:] + if dim is None and len(entries) > 1: + dim = len(entries) + elif len(entries) == 1: + logger.warning("Skipping token {} with 1-dimensional " + "vector {}; likely a header".format(word, entries)) + continue + elif dim != len(entries): + raise RuntimeError( + "Vector for token {} has {} dimensions, but previously " + "read vectors have {} dimensions. All vectors must have " + "the same number of dimensions.".format(word, len(entries), dim)) + + if binary_lines: + try: + if isinstance(word, six.binary_type): + word = word.decode('utf-8') + except: + logger.info("Skipping non-UTF8 token {}".format(repr(word))) + continue + vectors.extend(float(x) for x in entries) + itos.append(word) + + self.itos = itos + self.stoi = {word: i for i, word in enumerate(itos)} + self.vectors = torch.Tensor(vectors).view(-1, dim) + self.dim = dim + logger.info('Saving vectors to {}'.format(path_pt)) + if not os.path.exists(cache): + os.makedirs(cache) + torch.save((self.itos, self.stoi, self.vectors, self.dim), path_pt) + else: + logger.info('Loading vectors from {}'.format(path_pt)) + self.itos, self.stoi, self.vectors, self.dim = torch.load(path_pt) + + +class GloVe(Vectors): + url = { + '42B': 'http://nlp.stanford.edu/data/glove.42B.300d.zip', + '840B': 'http://nlp.stanford.edu/data/glove.840B.300d.zip', + 'twitter.27B': 'http://nlp.stanford.edu/data/glove.twitter.27B.zip', + '6B': 'http://nlp.stanford.edu/data/glove.6B.zip', + } + + def __init__(self, name='840B', dim=300, **kwargs): + url = self.url[name] + name = 'glove.{}.{}d.txt'.format(name, str(dim)) + super(GloVe, self).__init__(name, url=url, **kwargs) + + +class FastText(Vectors): + + url_base = 'https://s3-us-west-1.amazonaws.com/fasttext-vectors/wiki.{}.vec' + + def __init__(self, language="en", **kwargs): + url = self.url_base.format(language) + name = os.path.basename(url) + super(FastText, self).__init__(name, url=url, **kwargs) + + +class CharNGram(Vectors): + + name = 'charNgram.txt' + url = ('http://www.logos.t.u-tokyo.ac.jp/~hassy/publications/arxiv2016jmt/' + 'jmt_pre-trained_embeddings.tar.gz') + + def __init__(self, **kwargs): + super(CharNGram, self).__init__(self.name, url=self.url, **kwargs) + + def __getitem__(self, token): + vector = torch.Tensor(1, self.dim).zero_() + if token == "": + return self.unk_init(vector) + # These literals need to be coerced to unicode for Python 2 compatibility + # when we try to join them with read ngrams from the files. 
+ chars = ['#BEGIN#'] + list(token) + ['#END#'] + num_vectors = 0 + for n in [2, 3, 4]: + end = len(chars) - n + 1 + grams = [chars[i:(i + n)] for i in range(end)] + for gram in grams: + gram_key = '{}gram-{}'.format(n, ''.join(gram)) + if gram_key in self.stoi: + vector += self.vectors[self.stoi[gram_key]] + num_vectors += 1 + if num_vectors > 0: + vector /= num_vectors + else: + vector = self.unk_init(vector) + return vector + + +def _default_unk_index(): + return 0 + + +pretrained_aliases = { + "charngram.100d": partial(CharNGram), + "fasttext.en.300d": partial(FastText, language="en"), + "fasttext.simple.300d": partial(FastText, language="simple"), + "glove.42B.300d": partial(GloVe, name="42B", dim="300"), + "glove.840B.300d": partial(GloVe, name="840B", dim="300"), + "glove.twitter.27B.25d": partial(GloVe, name="twitter.27B", dim="25"), + "glove.twitter.27B.50d": partial(GloVe, name="twitter.27B", dim="50"), + "glove.twitter.27B.100d": partial(GloVe, name="twitter.27B", dim="100"), + "glove.twitter.27B.200d": partial(GloVe, name="twitter.27B", dim="200"), + "glove.6B.50d": partial(GloVe, name="6B", dim="50"), + "glove.6B.100d": partial(GloVe, name="6B", dim="100"), + "glove.6B.200d": partial(GloVe, name="6B", dim="200"), + "glove.6B.300d": partial(GloVe, name="6B", dim="300") +} +"""Mapping from string name to factory function""" diff --git a/environment.yaml b/environment.yaml new file mode 100644 index 0000000..3990d3a --- /dev/null +++ b/environment.yaml @@ -0,0 +1,194 @@ +name: ic +channels: + - defaults +dependencies: + - _libgcc_mutex=0.1=main + - _openmp_mutex=4.5=1_gnu + - argcomplete=1.12.3=pyhd3eb1b0_0 + - backcall=0.2.0=pyhd3eb1b0_0 + - ca-certificates=2021.7.5=h06a4308_1 + - certifi=2021.5.30=py37h06a4308_0 + - debugpy=1.4.1=py37h295c915_0 + - decorator=5.0.9=pyhd3eb1b0_0 + - entrypoints=0.3=py37_0 + - importlib-metadata=3.10.0=py37h06a4308_0 + - importlib_metadata=3.10.0=hd3eb1b0_0 + - ipykernel=6.2.0=py37h06a4308_1 + - ipython=7.26.0=py37hb070fc8_0 + - ipython_genutils=0.2.0=pyhd3eb1b0_1 + - jedi=0.18.0=py37h06a4308_1 + - jupyter_client=7.0.1=pyhd3eb1b0_0 + - jupyter_core=4.7.1=py37h06a4308_0 + - ld_impl_linux-64=2.35.1=h7274673_9 + - libffi=3.3=he6710b0_2 + - libgcc-ng=9.3.0=h5101ec6_17 + - libgomp=9.3.0=h5101ec6_17 + - libsodium=1.0.18=h7b6447c_0 + - libstdcxx-ng=9.3.0=hd4cf53a_17 + - matplotlib-inline=0.1.2=pyhd3eb1b0_2 + - ncurses=6.2=he6710b0_1 + - nest-asyncio=1.5.1=pyhd3eb1b0_0 + - openssl=1.1.1l=h7f8727e_0 + - parso=0.8.2=pyhd3eb1b0_0 + - pexpect=4.8.0=pyhd3eb1b0_3 + - pickleshare=0.7.5=pyhd3eb1b0_1003 + - pip=21.0.1=py37h06a4308_0 + - prompt-toolkit=3.0.17=pyhca03da5_0 + - ptyprocess=0.7.0=pyhd3eb1b0_2 + - pygments=2.10.0=pyhd3eb1b0_0 + - python=3.7.11=h12debd9_0 + - python-dateutil=2.8.2=pyhd3eb1b0_0 + - pyzmq=22.2.1=py37h295c915_1 + - readline=8.1=h27cfd23_0 + - setuptools=52.0.0=py37h06a4308_0 + - six=1.16.0=pyhd3eb1b0_0 + - sqlite=3.36.0=hc218d9a_0 + - tk=8.6.10=hbc83047_0 + - tornado=6.1=py37h27cfd23_0 + - traitlets=5.0.5=pyhd3eb1b0_0 + - typing_extensions=3.10.0.0=pyhca03da5_0 + - wcwidth=0.2.5=pyhd3eb1b0_0 + - wheel=0.37.0=pyhd3eb1b0_1 + - xz=5.2.5=h7b6447c_0 + - zeromq=4.3.4=h2531618_0 + - zipp=3.5.0=pyhd3eb1b0_0 + - zlib=1.2.11=h7b6447c_3 + - pip: + - absl-py==0.13.0 + - addict==2.4.0 + - attrs==21.2.0 + - av==8.1.0 + - bert-serving-client==1.10.0 + - bert-serving-server==1.10.0 + - blessings==1.7 + - blis==0.7.4 + - boto3==1.18.48 + - botocore==1.21.48 + - cached-property==1.5.2 + - cachetools==4.2.2 + - catalogue==1.0.0 + - charset-normalizer==2.0.4 + 
- click==7.1.2 + - cycler==0.10.0 + - cymem==2.0.5 + - cython==0.29.24 + - distlib==0.3.4 + - easydict==1.9 + - einops==0.4.0 + - en-core-web-md==2.3.1 + - en-core-web-sm==3.1.0 + - en-vectors-web-lg==2.1.0 + - filelock==3.6.0 + - future==0.18.2 + - gensim==4.1.2 + - google-auth==1.35.0 + - google-auth-oauthlib==0.4.6 + - gpustat==0.6.0 + - gputil==1.4.0 + - grad-cam==1.3.7 + - grpcio==1.40.0 + - h5py==3.4.0 + - huggingface-hub==0.0.17 + - idna==3.2 + - imageio==2.13.3 + - iniconfig==1.1.1 + - jinja2==3.0.1 + - jmespath==0.10.0 + - joblib==1.0.1 + - jsonschema==4.1.0 + - kiwisolver==1.3.2 + - lmdb==1.2.1 + - lxml==4.6.3 + - lz4==3.1.3 + - markdown==3.3.4 + - markupsafe==2.0.1 + - matplotlib==3.4.3 + - middle==0.2.4 + - mime==0.1.0 + - mmcv==1.4.0 + - msgpack==1.0.3 + - murmurhash==1.0.5 + - nbformat==5.1.3 + - networkx==2.6.3 + - nltk==3.6.7 + - nose==1.3.7 + - numpy==1.21.2 + - nvidia-ml-py3==7.352.0 + - oauthlib==3.1.1 + - opencv-python==4.5.5.62 + - packaging==21.0 + - pandas==1.3.3 + - pathy==0.6.0 + - pillow==8.3.2 + - plac==1.1.3 + - platformdirs==2.5.1 + - plotly==5.3.1 + - pluggy==1.0.0 + - preshed==3.0.5 + - progressbar==2.5 + - protobuf==3.17.3 + - psutil==5.8.0 + - py==1.11.0 + - pyasn1==0.4.8 + - pyasn1-modules==0.2.8 + - pycocotools==2.0.2 + - pydantic==1.8.2 + - pyemd==0.5.1 + - pyparsing==2.4.7 + - pyrsistent==0.18.0 + - pytest==7.1.0 + - python-docx==0.8.11 + - pytorch-transformers==1.2.0 + - pytz==2021.1 + - pywavelets==1.2.0 + - pyyaml==5.4.1 + - regex==2021.8.28 + - requests==2.26.0 + - requests-oauthlib==1.3.0 + - revtok==0.0.3 + - rsa==4.7.2 + - s3transfer==0.5.0 + - sacremoses==0.0.45 + - scikit-image==0.19.0 + - scikit-learn==1.0 + - scipy==1.7.1 + - seaborn==0.11.2 + - sentencepiece==0.1.96 + - sklearn==0.0 + - smart-open==5.2.1 + - spacy==2.3.7 + - spacy-legacy==3.0.8 + - spherecluster==0.1.7 + - srsly==1.0.5 + - tenacity==8.0.1 + - tensorboard==2.6.0 + - tensorboard-data-server==0.6.1 + - tensorboard-logger==0.1.0 + - tensorboard-plugin-wit==1.8.0 + - tensorboardx==2.5 + - termcolor==1.1.0 + - thinc==7.4.5 + - thop==0.0.31-2005241907 + - threadpoolctl==3.0.0 + - tifffile==2021.11.2 + - timm==0.4.12 + - tokenizers==0.10.3 + - tomli==2.0.1 + - torch==1.7.1+cu110 + - torch-tb-profiler==0.2.1 + - torchsummary==1.5.1 + - torchtext==0.8.0 + - torchvision==0.8.2 + - tqdm==4.62.2 + - transformers==4.10.2 + - ttach==0.0.3 + - typer==0.3.2 + - typing-extensions==3.10.0.2 + - ujson==5.1.0 + - urllib3==1.26.6 + - virtualenv==20.14.0 + - visualize==0.5.1 + - wasabi==0.8.2 + - werkzeug==2.0.1 + - yapf==0.32.0 +prefix: /home/zhanghaonan/anaconda3/envs/ic diff --git a/evaluation/__init__.py b/evaluation/__init__.py new file mode 100644 index 0000000..150e32d --- /dev/null +++ b/evaluation/__init__.py @@ -0,0 +1,17 @@ +from .bleu import Bleu +from .meteor import Meteor +from .rouge import Rouge +from .cider import Cider +# from .spice import Spice +from .tokenizer import PTBTokenizer + +def compute_scores(gts, gen): + metrics = (Bleu(), Meteor(), Rouge(), Cider()) + all_score = {} + all_scores = {} + for metric in metrics: + score, scores = metric.compute_score(gts, gen) + all_score[str(metric)] = score + all_scores[str(metric)] = scores + + return all_score, all_scores diff --git a/evaluation/bleu/__init__.py b/evaluation/bleu/__init__.py new file mode 100644 index 0000000..8000b20 --- /dev/null +++ b/evaluation/bleu/__init__.py @@ -0,0 +1 @@ +from .bleu import Bleu \ No newline at end of file diff --git a/evaluation/bleu/bleu.py b/evaluation/bleu/bleu.py new file mode 100644 
index 0000000..6d3139d --- /dev/null +++ b/evaluation/bleu/bleu.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python +# +# File Name : bleu.py +# +# Description : Wrapper for BLEU scorer. +# +# Creation Date : 06-01-2015 +# Last Modified : Thu 19 Mar 2015 09:13:28 PM PDT +# Authors : Hao Fang and Tsung-Yi Lin + +from .bleu_scorer import BleuScorer + + +class Bleu: + def __init__(self, n=4): + # default compute Blue score up to 4 + self._n = n + self._hypo_for_image = {} + self.ref_for_image = {} + + def compute_score(self, gts, res): + + assert(gts.keys() == res.keys()) + imgIds = gts.keys() + + bleu_scorer = BleuScorer(n=self._n) + for id in imgIds: + hypo = res[id] + ref = gts[id] + + # Sanity check. + assert(type(hypo) is list) + assert(len(hypo) == 1) + assert(type(ref) is list) + assert(len(ref) >= 1) + + bleu_scorer += (hypo[0], ref) + + # score, scores = bleu_scorer.compute_score(option='shortest') + score, scores = bleu_scorer.compute_score(option='closest', verbose=0) + # score, scores = bleu_scorer.compute_score(option='average', verbose=1) + + return score, scores + + def __str__(self): + return 'BLEU' diff --git a/evaluation/bleu/bleu_scorer.py b/evaluation/bleu/bleu_scorer.py new file mode 100644 index 0000000..8047f46 --- /dev/null +++ b/evaluation/bleu/bleu_scorer.py @@ -0,0 +1,272 @@ +#!/usr/bin/env python + +# bleu_scorer.py +# David Chiang + +# Copyright (c) 2004-2006 University of Maryland. All rights +# reserved. Do not redistribute without permission from the +# author. Not for commercial use. + +# Modified by: +# Hao Fang +# Tsung-Yi Lin + +''' Provides: +cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test(). +cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked(). +''' + +import copy +import sys, math, re +from collections import defaultdict + + +def precook(s, n=4, out=False): + """Takes a string as input and returns an object that can be given to + either cook_refs or cook_test. This is optional: cook_refs and cook_test + can take string arguments as well.""" + words = s.split() + counts = defaultdict(int) + for k in range(1, n + 1): + for i in range(len(words) - k + 1): + ngram = tuple(words[i:i + k]) + counts[ngram] += 1 + return (len(words), counts) + + +def cook_refs(refs, eff=None, n=4): ## lhuang: oracle will call with "average" + '''Takes a list of reference sentences for a single segment + and returns an object that encapsulates everything that BLEU + needs to know about them.''' + + reflen = [] + maxcounts = {} + for ref in refs: + rl, counts = precook(ref, n) + reflen.append(rl) + for (ngram, count) in counts.items(): + maxcounts[ngram] = max(maxcounts.get(ngram, 0), count) + + # Calculate effective reference sentence length. + if eff == "shortest": + reflen = min(reflen) + elif eff == "average": + reflen = float(sum(reflen)) / len(reflen) + + ## lhuang: N.B.: leave reflen computaiton to the very end!! + + ## lhuang: N.B.: in case of "closest", keep a list of reflens!! (bad design) + + return (reflen, maxcounts) + + +def cook_test(test, ref_tuple, eff=None, n=4): + '''Takes a test sentence and returns an object that + encapsulates everything that BLEU needs to know about it.''' + + testlen, counts = precook(test, n, True) + reflen, refmaxcounts = ref_tuple + + result = {} + + # Calculate effective reference sentence length. 
+ + if eff == "closest": + result["reflen"] = min((abs(l - testlen), l) for l in reflen)[1] + else: ## i.e., "average" or "shortest" or None + result["reflen"] = reflen + + result["testlen"] = testlen + + result["guess"] = [max(0, testlen - k + 1) for k in range(1, n + 1)] + + result['correct'] = [0] * n + for (ngram, count) in counts.items(): + result["correct"][len(ngram) - 1] += min(refmaxcounts.get(ngram, 0), count) + + return result + + +class BleuScorer(object): + """Bleu scorer. + """ + + __slots__ = "n", "crefs", "ctest", "_score", "_ratio", "_testlen", "_reflen", "special_reflen" + + # special_reflen is used in oracle (proportional effective ref len for a node). + + def copy(self): + ''' copy the refs.''' + new = BleuScorer(n=self.n) + new.ctest = copy.copy(self.ctest) + new.crefs = copy.copy(self.crefs) + new._score = None + return new + + def __init__(self, test=None, refs=None, n=4, special_reflen=None): + ''' singular instance ''' + + self.n = n + self.crefs = [] + self.ctest = [] + self.cook_append(test, refs) + self.special_reflen = special_reflen + + def cook_append(self, test, refs): + '''called by constructor and __iadd__ to avoid creating new instances.''' + + if refs is not None: + self.crefs.append(cook_refs(refs)) + if test is not None: + cooked_test = cook_test(test, self.crefs[-1]) + self.ctest.append(cooked_test) ## N.B.: -1 + else: + self.ctest.append(None) # lens of crefs and ctest have to match + + self._score = None ## need to recompute + + def ratio(self, option=None): + self.compute_score(option=option) + return self._ratio + + def score_ratio(self, option=None): + ''' + return (bleu, len_ratio) pair + ''' + + return self.fscore(option=option), self.ratio(option=option) + + def score_ratio_str(self, option=None): + return "%.4f (%.2f)" % self.score_ratio(option) + + def reflen(self, option=None): + self.compute_score(option=option) + return self._reflen + + def testlen(self, option=None): + self.compute_score(option=option) + return self._testlen + + def retest(self, new_test): + if type(new_test) is str: + new_test = [new_test] + assert len(new_test) == len(self.crefs), new_test + self.ctest = [] + for t, rs in zip(new_test, self.crefs): + self.ctest.append(cook_test(t, rs)) + self._score = None + + return self + + def rescore(self, new_test): + ''' replace test(s) with new test(s), and returns the new score.''' + + return self.retest(new_test).compute_score() + + def size(self): + assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest)) + return len(self.crefs) + + def __iadd__(self, other): + '''add an instance (e.g., from another sentence).''' + + if type(other) is tuple: + ## avoid creating new BleuScorer instances + self.cook_append(other[0], other[1]) + else: + assert self.compatible(other), "incompatible BLEUs." 
+ self.ctest.extend(other.ctest) + self.crefs.extend(other.crefs) + self._score = None ## need to recompute + + return self + + def compatible(self, other): + return isinstance(other, BleuScorer) and self.n == other.n + + def single_reflen(self, option="average"): + return self._single_reflen(self.crefs[0][0], option) + + def _single_reflen(self, reflens, option=None, testlen=None): + + if option == "shortest": + reflen = min(reflens) + elif option == "average": + reflen = float(sum(reflens)) / len(reflens) + elif option == "closest": + reflen = min((abs(l - testlen), l) for l in reflens)[1] + else: + assert False, "unsupported reflen option %s" % option + + return reflen + + def recompute_score(self, option=None, verbose=0): + self._score = None + return self.compute_score(option, verbose) + + def compute_score(self, option=None, verbose=0): + n = self.n + small = 1e-9 + tiny = 1e-15 ## so that if guess is 0 still return 0 + bleu_list = [[] for _ in range(n)] + + if self._score is not None: + return self._score + + if option is None: + option = "average" if len(self.crefs) == 1 else "closest" + + self._testlen = 0 + self._reflen = 0 + totalcomps = {'testlen': 0, 'reflen': 0, 'guess': [0] * n, 'correct': [0] * n} + + # for each sentence + for comps in self.ctest: + testlen = comps['testlen'] + self._testlen += testlen + + if self.special_reflen is None: ## need computation + reflen = self._single_reflen(comps['reflen'], option, testlen) + else: + reflen = self.special_reflen + + self._reflen += reflen + + for key in ['guess', 'correct']: + for k in range(n): + totalcomps[key][k] += comps[key][k] + + # append per image bleu score + bleu = 1. + for k in range(n): + bleu *= (float(comps['correct'][k]) + tiny) \ + / (float(comps['guess'][k]) + small) + bleu_list[k].append(bleu ** (1. / (k + 1))) + ratio = (testlen + tiny) / (reflen + small) ## N.B.: avoid zero division + if ratio < 1: + for k in range(n): + bleu_list[k][-1] *= math.exp(1 - 1 / ratio) + + if verbose > 1: + print(comps, reflen) + + totalcomps['reflen'] = self._reflen + totalcomps['testlen'] = self._testlen + + bleus = [] + bleu = 1. + for k in range(n): + bleu *= float(totalcomps['correct'][k] + tiny) \ + / (totalcomps['guess'][k] + small) + bleus.append(bleu ** (1. 
/ (k + 1))) + ratio = (self._testlen + tiny) / (self._reflen + small) ## N.B.: avoid zero division + if ratio < 1: + for k in range(n): + bleus[k] *= math.exp(1 - 1 / ratio) + + if verbose > 0: + print(totalcomps) + print("ratio:", ratio) + + self._score = bleus + return self._score, bleu_list diff --git a/evaluation/cider/__init__.py b/evaluation/cider/__init__.py new file mode 100644 index 0000000..aaa32ec --- /dev/null +++ b/evaluation/cider/__init__.py @@ -0,0 +1 @@ +from .cider import Cider \ No newline at end of file diff --git a/evaluation/cider/cider.py b/evaluation/cider/cider.py new file mode 100644 index 0000000..57ae309 --- /dev/null +++ b/evaluation/cider/cider.py @@ -0,0 +1,42 @@ +# Filename: cider.py +# +# Description: Describes the class to compute the CIDEr (Consensus-Based Image Description Evaluation) Metric +# by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726) +# +# Creation Date: Sun Feb 8 14:16:54 2015 +# +# Authors: Ramakrishna Vedantam and Tsung-Yi Lin + +from .cider_scorer import CiderScorer + +class Cider: + """ + Main Class to compute the CIDEr metric + + """ + def __init__(self, gts=None, n=4, sigma=6.0): + # set cider to sum over 1 to 4-grams + self._n = n + # set the standard deviation parameter for gaussian penalty + self._sigma = sigma + self.doc_frequency = None + self.ref_len = None + if gts is not None: + tmp_cider = CiderScorer(gts, n=self._n, sigma=self._sigma) + self.doc_frequency = tmp_cider.doc_frequency + self.ref_len = tmp_cider.ref_len + + def compute_score(self, gts, res): + """ + Main function to compute CIDEr score + :param gts (dict) : dictionary with key and value + res (dict) : dictionary with key and value + :return: cider (float) : computed CIDEr score for the corpus + """ + assert(gts.keys() == res.keys()) + cider_scorer = CiderScorer(gts, test=res, n=self._n, sigma=self._sigma, doc_frequency=self.doc_frequency, + ref_len=self.ref_len) + return cider_scorer.compute_score() + + def __str__(self): + return 'CIDEr' diff --git a/evaluation/cider/cider_scorer.py b/evaluation/cider/cider_scorer.py new file mode 100644 index 0000000..37243ef --- /dev/null +++ b/evaluation/cider/cider_scorer.py @@ -0,0 +1,167 @@ +#!/usr/bin/env python +# Tsung-Yi Lin +# Ramakrishna Vedantam + +import copy +from collections import defaultdict +import numpy as np +import math + +def precook(s, n=4): + """ + Takes a string as input and returns an object that can be given to + either cook_refs or cook_test. This is optional: cook_refs and cook_test + can take string arguments as well. + :param s: string : sentence to be converted into ngrams + :param n: int : number of ngrams for which representation is calculated + :return: term frequency vector for occuring ngrams + """ + words = s.split() + counts = defaultdict(int) + for k in range(1,n+1): + for i in range(len(words)-k+1): + ngram = tuple(words[i:i+k]) + counts[ngram] += 1 + return counts + +def cook_refs(refs, n=4): ## lhuang: oracle will call with "average" + '''Takes a list of reference sentences for a single segment + and returns an object that encapsulates everything that BLEU + needs to know about them. + :param refs: list of string : reference sentences for some image + :param n: int : number of ngrams for which (ngram) representation is calculated + :return: result (list of dict) + ''' + return [precook(ref, n) for ref in refs] + +def cook_test(test, n=4): + '''Takes a test sentence and returns an object that + encapsulates everything that BLEU needs to know about it. 
+ :param test: list of string : hypothesis sentence for some image + :param n: int : number of ngrams for which (ngram) representation is calculated + :return: result (dict) + ''' + return precook(test, n) + +class CiderScorer(object): + """CIDEr scorer. + """ + + def __init__(self, refs, test=None, n=4, sigma=6.0, doc_frequency=None, ref_len=None): + ''' singular instance ''' + self.n = n + self.sigma = sigma + self.crefs = [] + self.ctest = [] + self.doc_frequency = defaultdict(float) + self.ref_len = None + + for k in refs.keys(): + self.crefs.append(cook_refs(refs[k])) + if test is not None: + self.ctest.append(cook_test(test[k][0])) ## N.B.: -1 + else: + self.ctest.append(None) # lens of crefs and ctest have to match + + if doc_frequency is None and ref_len is None: + # compute idf + self.compute_doc_freq() + # compute log reference length + self.ref_len = np.log(float(len(self.crefs))) + else: + self.doc_frequency = doc_frequency + self.ref_len = ref_len + + def compute_doc_freq(self): + ''' + Compute term frequency for reference data. + This will be used to compute idf (inverse document frequency later) + The term frequency is stored in the object + :return: None + ''' + for refs in self.crefs: + # refs, k ref captions of one image + for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]): + self.doc_frequency[ngram] += 1 + # maxcounts[ngram] = max(maxcounts.get(ngram,0), count) + + def compute_cider(self): + def counts2vec(cnts): + """ + Function maps counts of ngram to vector of tfidf weights. + The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights. + The n-th entry of array denotes length of n-grams. + :param cnts: + :return: vec (array of dict), norm (array of float), length (int) + """ + vec = [defaultdict(float) for _ in range(self.n)] + length = 0 + norm = [0.0 for _ in range(self.n)] + for (ngram,term_freq) in cnts.items(): + # give word count 1 if it doesn't appear in reference corpus + df = np.log(max(1.0, self.doc_frequency[ngram])) + # ngram index + n = len(ngram)-1 + # tf (term_freq) * idf (precomputed idf) for n-grams + vec[n][ngram] = float(term_freq)*(self.ref_len - df) + # compute norm for the vector. the norm will be used for computing similarity + norm[n] += pow(vec[n][ngram], 2) + + if n == 1: + length += term_freq + norm = [np.sqrt(n) for n in norm] + return vec, norm, length + + def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref): + ''' + Compute the cosine similarity of two vectors. 
+ :param vec_hyp: array of dictionary for vector corresponding to hypothesis + :param vec_ref: array of dictionary for vector corresponding to reference + :param norm_hyp: array of float for vector corresponding to hypothesis + :param norm_ref: array of float for vector corresponding to reference + :param length_hyp: int containing length of hypothesis + :param length_ref: int containing length of reference + :return: array of score for each n-grams cosine similarity + ''' + delta = float(length_hyp - length_ref) + # measure consine similarity + val = np.array([0.0 for _ in range(self.n)]) + for n in range(self.n): + # ngram + for (ngram,count) in vec_hyp[n].items(): + # vrama91 : added clipping + val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram] + + if (norm_hyp[n] != 0) and (norm_ref[n] != 0): + val[n] /= (norm_hyp[n]*norm_ref[n]) + + assert(not math.isnan(val[n])) + # vrama91: added a length based gaussian penalty + val[n] *= np.e**(-(delta**2)/(2*self.sigma**2)) + return val + + scores = [] + for test, refs in zip(self.ctest, self.crefs): + # compute vector for test captions + vec, norm, length = counts2vec(test) + # compute vector for ref captions + score = np.array([0.0 for _ in range(self.n)]) + for ref in refs: + vec_ref, norm_ref, length_ref = counts2vec(ref) + score += sim(vec, vec_ref, norm, norm_ref, length, length_ref) + # change by vrama91 - mean of ngram scores, instead of sum + score_avg = np.mean(score) + # divide by number of references + score_avg /= len(refs) + # multiply score by 10 + score_avg *= 10.0 + # append score of an image to the score list + scores.append(score_avg) + return scores + + def compute_score(self): + # compute cider score + score = self.compute_cider() + # debug + # print score + return np.mean(np.array(score)), np.array(score) \ No newline at end of file diff --git a/evaluation/meteor/__init__.py b/evaluation/meteor/__init__.py new file mode 100644 index 0000000..35b4a26 --- /dev/null +++ b/evaluation/meteor/__init__.py @@ -0,0 +1 @@ +from .meteor import Meteor \ No newline at end of file diff --git a/evaluation/meteor/data/paraphrase-en.gz b/evaluation/meteor/data/paraphrase-en.gz new file mode 100644 index 0000000..88033c8 Binary files /dev/null and b/evaluation/meteor/data/paraphrase-en.gz differ diff --git a/evaluation/meteor/meteor.py b/evaluation/meteor/meteor.py new file mode 100644 index 0000000..e41fb18 --- /dev/null +++ b/evaluation/meteor/meteor.py @@ -0,0 +1,75 @@ +# Python wrapper for METEOR implementation, by Xinlei Chen +# Acknowledge Michael Denkowski for the generous discussion and help + +import os +import subprocess +import threading +import tarfile +from utils import download_from_url + +METEOR_GZ_URL = 'http://aimagelab.ing.unimore.it/speaksee/data/meteor.tgz' +METEOR_JAR = 'meteor-1.5.jar' + +class Meteor: + def __init__(self): + base_path = os.path.dirname(os.path.abspath(__file__)) + jar_path = os.path.join(base_path, METEOR_JAR) + gz_path = os.path.join(base_path, os.path.basename(METEOR_GZ_URL)) + if not os.path.isfile(jar_path): + if not os.path.isfile(gz_path): + download_from_url(METEOR_GZ_URL, gz_path) + tar = tarfile.open(gz_path, "r") + tar.extractall(path=os.path.dirname(os.path.abspath(__file__))) + tar.close() + os.remove(gz_path) + + self.meteor_cmd = ['java', '-jar', '-Xmx2G', METEOR_JAR, \ + '-', '-', '-stdio', '-l', 'en', '-norm'] + self.meteor_p = subprocess.Popen(self.meteor_cmd, \ + cwd=os.path.dirname(os.path.abspath(__file__)), \ + stdin=subprocess.PIPE, \ + stdout=subprocess.PIPE, 
\ + stderr=subprocess.PIPE) + # Used to guarantee thread safety + self.lock = threading.Lock() + + def compute_score(self, gts, res): + assert(gts.keys() == res.keys()) + imgIds = gts.keys() + scores = [] + + eval_line = 'EVAL' + self.lock.acquire() + for i in imgIds: + assert(len(res[i]) == 1) + stat = self._stat(res[i][0], gts[i]) + eval_line += ' ||| {}'.format(stat) + + self.meteor_p.stdin.write('{}\n'.format(eval_line).encode()) + self.meteor_p.stdin.flush() + for i in range(0,len(imgIds)): + scores.append(float(self.meteor_p.stdout.readline().strip())) + score = float(self.meteor_p.stdout.readline().strip()) + self.lock.release() + + return score, scores + + def _stat(self, hypothesis_str, reference_list): + # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words + hypothesis_str = hypothesis_str.replace('|||','').replace(' ',' ') + score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str)) + self.meteor_p.stdin.write('{}\n'.format(score_line).encode()) + self.meteor_p.stdin.flush() + raw = self.meteor_p.stdout.readline().decode().strip() + numbers = [str(int(float(n))) for n in raw.split()] + return ' '.join(numbers) + + def __del__(self): + self.lock.acquire() + self.meteor_p.stdin.close() + self.meteor_p.kill() + self.meteor_p.wait() + self.lock.release() + + def __str__(self): + return 'METEOR' diff --git a/evaluation/rouge/__init__.py b/evaluation/rouge/__init__.py new file mode 100644 index 0000000..59397f8 --- /dev/null +++ b/evaluation/rouge/__init__.py @@ -0,0 +1 @@ +from .rouge import Rouge \ No newline at end of file diff --git a/evaluation/rouge/rouge.py b/evaluation/rouge/rouge.py new file mode 100644 index 0000000..06529f3 --- /dev/null +++ b/evaluation/rouge/rouge.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python +# +# File Name : rouge.py +# +# Description : Computes ROUGE-L metric as described by Lin and Hovey (2004) +# +# Creation Date : 2015-01-07 06:03 +# Author : Ramakrishna Vedantam + +import numpy as np +import pdb + + +def my_lcs(string, sub): + """ + Calculates longest common subsequence for a pair of tokenized strings + :param string : list of str : tokens from a string split using whitespace + :param sub : list of str : shorter string, also split using whitespace + :returns: length (list of int): length of the longest common subsequence between the two strings + + Note: my_lcs only gives length of the longest common subsequence, not the actual LCS + """ + if (len(string) < len(sub)): + sub, string = string, sub + + lengths = [[0 for i in range(0, len(sub) + 1)] for j in range(0, len(string) + 1)] + + for j in range(1, len(sub) + 1): + for i in range(1, len(string) + 1): + if (string[i - 1] == sub[j - 1]): + lengths[i][j] = lengths[i - 1][j - 1] + 1 + else: + lengths[i][j] = max(lengths[i - 1][j], lengths[i][j - 1]) + + return lengths[len(string)][len(sub)] + + +class Rouge(): + ''' + Class for computing ROUGE-L score for a set of candidate sentences for the MS COCO test set + + ''' + + def __init__(self): + # vrama91: updated the value below based on discussion with Hovey + self.beta = 1.2 + + def calc_score(self, candidate, refs): + """ + Compute ROUGE-L score given one candidate and references for an image + :param candidate: str : candidate sentence to be evaluated + :param refs: list of str : COCO reference sentences for the particular image to be evaluated + :returns score: int (ROUGE-L score for the candidate evaluated against references) + """ + assert (len(candidate) == 1) + assert (len(refs) > 0) + prec = [] + rec 
= [] + + # split into tokens + token_c = candidate[0].split(" ") + + for reference in refs: + # split into tokens + token_r = reference.split(" ") + # compute the longest common subsequence + lcs = my_lcs(token_r, token_c) + prec.append(lcs / float(len(token_c))) + rec.append(lcs / float(len(token_r))) + + prec_max = max(prec) + rec_max = max(rec) + + if (prec_max != 0 and rec_max != 0): + score = ((1 + self.beta ** 2) * prec_max * rec_max) / float(rec_max + self.beta ** 2 * prec_max) + else: + score = 0.0 + return score + + def compute_score(self, gts, res): + """ + Computes Rouge-L score given a set of reference and candidate sentences for the dataset + Invoked by evaluate_captions.py + :param hypo_for_image: dict : candidate / test sentences with "image name" key and "tokenized sentences" as values + :param ref_for_image: dict : reference MS-COCO sentences with "image name" key and "tokenized sentences" as values + :returns: average_score: float (mean ROUGE-L score computed by averaging scores for all the images) + """ + assert (gts.keys() == res.keys()) + imgIds = gts.keys() + + score = [] + for id in imgIds: + hypo = res[id] + ref = gts[id] + + score.append(self.calc_score(hypo, ref)) + + # Sanity check. + assert (type(hypo) is list) + assert (len(hypo) == 1) + assert (type(ref) is list) + assert (len(ref) > 0) + + average_score = np.mean(np.array(score)) + return average_score, np.array(score) + + def __str__(self): + return 'ROUGE' diff --git a/evaluation/tokenizer.py b/evaluation/tokenizer.py new file mode 100644 index 0000000..73fb49c --- /dev/null +++ b/evaluation/tokenizer.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python +# +# File Name : ptbtokenizer.py +# +# Description : Do the PTB Tokenization and remove punctuations. +# +# Creation Date : 29-12-2014 +# Last Modified : Thu Mar 19 09:53:35 2015 +# Authors : Hao Fang and Tsung-Yi Lin + +import os +import subprocess +import tempfile + +class PTBTokenizer(object): + """Python wrapper of Stanford PTBTokenizer""" + + corenlp_jar = 'stanford-corenlp-3.4.1.jar' + punctuations = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-", \ + ".", "?", "!", ",", ":", "-", "--", "...", ";"] + + @classmethod + def tokenize(cls, corpus): + cmd = ['java', '-cp', cls.corenlp_jar, \ + 'edu.stanford.nlp.process.PTBTokenizer', \ + '-preserveLines', '-lowerCase'] + + if isinstance(corpus, list) or isinstance(corpus, tuple): + if isinstance(corpus[0], list) or isinstance(corpus[0], tuple): + corpus = {i:c for i, c in enumerate(corpus)} + else: + corpus = {i: [c, ] for i, c in enumerate(corpus)} + + # prepare data for PTB Tokenizer + tokenized_corpus = {} + image_id = [k for k, v in list(corpus.items()) for _ in range(len(v))] + sentences = '\n'.join([c.replace('\n', ' ') for k, v in corpus.items() for c in v]) + + # save sentences to temporary file + path_to_jar_dirname=os.path.dirname(os.path.abspath(__file__)) + tmp_file = tempfile.NamedTemporaryFile(delete=False, dir=path_to_jar_dirname) + tmp_file.write(sentences.encode()) + tmp_file.close() + + # tokenize sentence + cmd.append(os.path.basename(tmp_file.name)) + p_tokenizer = subprocess.Popen(cmd, cwd=path_to_jar_dirname, \ + stdout=subprocess.PIPE, stderr=open(os.devnull, 'w')) + token_lines = p_tokenizer.communicate(input=sentences.rstrip())[0] + token_lines = token_lines.decode() + lines = token_lines.split('\n') + # remove temp file + os.remove(tmp_file.name) + + # create dictionary for tokenized captions + for k, line in zip(image_id, lines): + if not k in tokenized_corpus: + 
tokenized_corpus[k] = [] + tokenized_caption = ' '.join([w for w in line.rstrip().split(' ') \ + if w not in cls.punctuations]) + tokenized_corpus[k].append(tokenized_caption) + + return tokenized_corpus \ No newline at end of file diff --git a/framework.png b/framework.png new file mode 100644 index 0000000..7d40153 Binary files /dev/null and b/framework.png differ diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..943ce68 --- /dev/null +++ b/models/__init__.py @@ -0,0 +1,2 @@ +from .transformer import Transformer +from .captioning_model import CaptioningModel diff --git a/models/beam_search/__init__.py b/models/beam_search/__init__.py new file mode 100644 index 0000000..21ac612 --- /dev/null +++ b/models/beam_search/__init__.py @@ -0,0 +1 @@ +from .beam_search import BeamSearch diff --git a/models/beam_search/beam_search.py b/models/beam_search/beam_search.py new file mode 100644 index 0000000..bfe99b4 --- /dev/null +++ b/models/beam_search/beam_search.py @@ -0,0 +1,145 @@ +import torch +import utils + + +class BeamSearch(object): + def __init__(self, model, max_len: int, eos_idx: int, beam_size: int): + self.model = model + self.max_len = max_len + self.eos_idx = eos_idx + self.beam_size = beam_size + self.b_s = None + self.device = None + self.seq_mask = None + self.seq_logprob = None + self.outputs = None + self.log_probs = None + self.selected_words = None + self.all_log_probs = None + + def _expand_state(self, selected_beam, cur_beam_size): + def fn(s): + shape = [int(sh) for sh in s.shape] + beam = selected_beam + for _ in shape[1:]: + beam = beam.unsqueeze(-1) + s = torch.gather(s.view(*([self.b_s, cur_beam_size] + shape[1:])), 1, + beam.expand(*([self.b_s, self.beam_size] + shape[1:]))) + s = s.view(*([-1, ] + shape[1:])) + return s + + return fn + + def _expand_visual(self, visual: utils.TensorOrSequence, cur_beam_size: int, selected_beam: torch.Tensor): + if isinstance(visual, torch.Tensor): + visual_shape = visual.shape + visual_exp_shape = (self.b_s, cur_beam_size) + visual_shape[1:] + visual_red_shape = (self.b_s * self.beam_size,) + visual_shape[1:] + selected_beam_red_size = (self.b_s, self.beam_size) + tuple(1 for _ in range(len(visual_exp_shape) - 2)) + selected_beam_exp_size = (self.b_s, self.beam_size) + visual_exp_shape[2:] + visual_exp = visual.view(visual_exp_shape) + selected_beam_exp = selected_beam.view(selected_beam_red_size).expand(selected_beam_exp_size) + visual = torch.gather(visual_exp, 1, selected_beam_exp).view(visual_red_shape) + else: + new_visual = [] + for im in visual: + visual_shape = im.shape + visual_exp_shape = (self.b_s, cur_beam_size) + visual_shape[1:] + visual_red_shape = (self.b_s * self.beam_size,) + visual_shape[1:] + selected_beam_red_size = (self.b_s, self.beam_size) + tuple(1 for _ in range(len(visual_exp_shape) - 2)) + selected_beam_exp_size = (self.b_s, self.beam_size) + visual_exp_shape[2:] + visual_exp = im.view(visual_exp_shape) + selected_beam_exp = selected_beam.view(selected_beam_red_size).expand(selected_beam_exp_size) + new_im = torch.gather(visual_exp, 1, selected_beam_exp).view(visual_red_shape) + new_visual.append(new_im) + visual = tuple(new_visual) + return visual + + def apply(self, visual: utils.TensorOrSequence, out_size=1, return_probs=False, **kwargs): + self.b_s = utils.get_batch_size(visual) + self.device = utils.get_device(visual) + self.seq_mask = torch.ones((self.b_s, self.beam_size, 1), device=self.device) + self.seq_logprob = torch.zeros((self.b_s, 1, 1), 
device=self.device) + self.log_probs = [] + self.selected_words = None + if return_probs: + self.all_log_probs = [] + + outputs = [] + with self.model.statefulness(self.b_s): + for t in range(self.max_len): + visual, outputs = self.iter(t, visual, outputs, return_probs, **kwargs) + + # Sort result + seq_logprob, sort_idxs = torch.sort(self.seq_logprob, 1, descending=True) + outputs = torch.cat(outputs, -1) + outputs = torch.gather(outputs, 1, sort_idxs.expand(self.b_s, self.beam_size, self.max_len)) + log_probs = torch.cat(self.log_probs, -1) + log_probs = torch.gather(log_probs, 1, sort_idxs.expand(self.b_s, self.beam_size, self.max_len)) + if return_probs: + all_log_probs = torch.cat(self.all_log_probs, 2) + all_log_probs = torch.gather(all_log_probs, 1, sort_idxs.unsqueeze(-1).expand(self.b_s, self.beam_size, + self.max_len, + all_log_probs.shape[-1])) + outputs = outputs.contiguous()[:, :out_size] + log_probs = log_probs.contiguous()[:, :out_size] + if out_size == 1: + outputs = outputs.squeeze(1) + log_probs = log_probs.squeeze(1) + + if return_probs: + return outputs, log_probs, all_log_probs + else: + return outputs, log_probs + + def select(self, t, candidate_logprob, **kwargs): + selected_logprob, selected_idx = torch.sort(candidate_logprob.view(self.b_s, -1), -1, descending=True) + selected_logprob, selected_idx = selected_logprob[:, :self.beam_size], selected_idx[:, :self.beam_size] + return selected_idx, selected_logprob + + def iter(self, t: int, visual: utils.TensorOrSequence, outputs, return_probs, **kwargs): + cur_beam_size = 1 if t == 0 else self.beam_size + + word_logprob = self.model.step(t, self.selected_words, visual, None, mode='feedback', **kwargs) + word_logprob = word_logprob.view(self.b_s, cur_beam_size, -1) + candidate_logprob = self.seq_logprob + word_logprob + + # Mask sequence if it reaches EOS + if t > 0: + mask = (self.selected_words.view(self.b_s, cur_beam_size) != self.eos_idx).float().unsqueeze(-1) + self.seq_mask = self.seq_mask * mask + word_logprob = word_logprob * self.seq_mask.expand_as(word_logprob) + old_seq_logprob = self.seq_logprob.expand_as(candidate_logprob).contiguous() + old_seq_logprob[:, :, 1:] = -999 + candidate_logprob = self.seq_mask * candidate_logprob + old_seq_logprob * (1 - self.seq_mask) + + selected_idx, selected_logprob = self.select(t, candidate_logprob, **kwargs) + # selected_beam = selected_idx / candidate_logprob.shape[-1] # // + # selected_words = selected_idx - selected_beam * candidate_logprob.shape[-1] # 取余 + selected_beam = selected_idx // candidate_logprob.shape[-1] # // + selected_words = selected_idx % candidate_logprob.shape[-1] # 取余 + + self.model.apply_to_states(self._expand_state(selected_beam, cur_beam_size)) + visual = self._expand_visual(visual, cur_beam_size, selected_beam) + + self.seq_logprob = selected_logprob.unsqueeze(-1) + self.seq_mask = torch.gather(self.seq_mask, 1, selected_beam.unsqueeze(-1)) + outputs = list(torch.gather(o, 1, selected_beam.unsqueeze(-1)) for o in outputs) + outputs.append(selected_words.unsqueeze(-1)) + + if return_probs: + if t == 0: + self.all_log_probs.append(word_logprob.expand((self.b_s, self.beam_size, -1)).unsqueeze(2)) + else: + self.all_log_probs.append(word_logprob.unsqueeze(2)) + + this_word_logprob = torch.gather(word_logprob, 1, + selected_beam.unsqueeze(-1).expand(self.b_s, self.beam_size, + word_logprob.shape[-1])) + this_word_logprob = torch.gather(this_word_logprob, 2, selected_words.unsqueeze(-1)) + self.log_probs = list( + torch.gather(o, 1, 
selected_beam.unsqueeze(-1).expand(self.b_s, self.beam_size, 1)) for o in self.log_probs) + self.log_probs.append(this_word_logprob) + self.selected_words = selected_words.view(-1, 1) + + return visual, outputs diff --git a/models/captioning_model.py b/models/captioning_model.py new file mode 100644 index 0000000..c170ac0 --- /dev/null +++ b/models/captioning_model.py @@ -0,0 +1,70 @@ +import torch +from torch import distributions +import utils +from models.containers import Module +from models.beam_search import * + + +class CaptioningModel(Module): + def __init__(self): + super(CaptioningModel, self).__init__() + + def init_weights(self): + raise NotImplementedError + + def step(self, t, prev_output, visual, seq, mode='teacher_forcing', **kwargs): + raise NotImplementedError + + def forward(self, images, seq, *args): + device = images.device + b_s = images.size(0) + seq_len = seq.size(1) + state = self.init_state(b_s, device) + out = None + + outputs = [] + for t in range(seq_len): + out, state = self.step(t, state, out, images, seq, *args, mode='teacher_forcing') + outputs.append(out) + + outputs = torch.cat([o.unsqueeze(1) for o in outputs], 1) + return outputs + + def test(self, visual: utils.TensorOrSequence, max_len: int, eos_idx: int, **kwargs) -> utils.Tuple[torch.Tensor, torch.Tensor]: + b_s = utils.get_batch_size(visual) + device = utils.get_device(visual) + outputs = [] + log_probs = [] + + mask = torch.ones((b_s,), device=device) + with self.statefulness(b_s): + out = None + for t in range(max_len): + log_probs_t = self.step(t, out, visual, None, mode='feedback', **kwargs) + out = torch.max(log_probs_t, -1)[1] + mask = mask * (out.squeeze(-1) != eos_idx).float() + log_probs.append(log_probs_t * mask.unsqueeze(-1).unsqueeze(-1)) + outputs.append(out) + + return torch.cat(outputs, 1), torch.cat(log_probs, 1) + + def sample_rl(self, visual: utils.TensorOrSequence, max_len: int, **kwargs) -> utils.Tuple[torch.Tensor, torch.Tensor]: + b_s = utils.get_batch_size(visual) + outputs = [] + log_probs = [] + + with self.statefulness(b_s): + out = None + for t in range(max_len): + out = self.step(t, out, visual, None, mode='feedback', **kwargs) + distr = distributions.Categorical(logits=out[:, 0]) + out = distr.sample().unsqueeze(1) + outputs.append(out) + log_probs.append(distr.log_prob(out).unsqueeze(1)) + + return torch.cat(outputs, 1), torch.cat(log_probs, 1) + + def beam_search(self, visual: utils.TensorOrSequence, max_len: int, eos_idx: int, beam_size: int, out_size=1, + return_probs=False, **kwargs): + bs = BeamSearch(self, max_len, eos_idx, beam_size) + return bs.apply(visual, out_size, return_probs, **kwargs) diff --git a/models/containers.py b/models/containers.py new file mode 100644 index 0000000..52296cd --- /dev/null +++ b/models/containers.py @@ -0,0 +1,80 @@ +from contextlib import contextmanager +from torch import nn +from utils.typing import * + + +class Module(nn.Module): + def __init__(self): + super(Module, self).__init__() + self._is_stateful = False + self._state_names = [] + self._state_defaults = dict() + + def register_state(self, name: str, default: TensorOrNone): + self._state_names.append(name) + if default is None: + self._state_defaults[name] = None + else: + self._state_defaults[name] = default.clone().detach() + self.register_buffer(name, default) + + def states(self): + for name in self._state_names: + yield self._buffers[name] + for m in self.children(): + if isinstance(m, Module): + yield from m.states() + + def apply_to_states(self, fn): + for name in 
self._state_names: + self._buffers[name] = fn(self._buffers[name]) + for m in self.children(): + if isinstance(m, Module): + m.apply_to_states(fn) + + def _init_states(self, batch_size: int): + for name in self._state_names: + if self._state_defaults[name] is None: + self._buffers[name] = None + else: + self._buffers[name] = self._state_defaults[name].clone().detach().to(self._buffers[name].device) + self._buffers[name] = self._buffers[name].unsqueeze(0) + self._buffers[name] = self._buffers[name].expand([batch_size, ] + list(self._buffers[name].shape[1:])) + self._buffers[name] = self._buffers[name].contiguous() + + def _reset_states(self): + for name in self._state_names: + if self._state_defaults[name] is None: + self._buffers[name] = None + else: + self._buffers[name] = self._state_defaults[name].clone().detach().to(self._buffers[name].device) + + def enable_statefulness(self, batch_size: int): + for m in self.children(): + if isinstance(m, Module): + m.enable_statefulness(batch_size) + self._init_states(batch_size) + self._is_stateful = True + + def disable_statefulness(self): + for m in self.children(): + if isinstance(m, Module): + m.disable_statefulness() + self._reset_states() + self._is_stateful = False + + @contextmanager + def statefulness(self, batch_size: int): + self.enable_statefulness(batch_size) + try: + yield + finally: + self.disable_statefulness() + + +class ModuleList(nn.ModuleList, Module): + pass + + +class ModuleDict(nn.ModuleDict, Module): + pass diff --git a/models/transformer/__init__.py b/models/transformer/__init__.py new file mode 100644 index 0000000..c52d903 --- /dev/null +++ b/models/transformer/__init__.py @@ -0,0 +1,4 @@ +from .transformer import * +from .encoders import * +from .decoders import * +from .attention import * diff --git a/models/transformer/attention.py b/models/transformer/attention.py new file mode 100644 index 0000000..e4fc0e6 --- /dev/null +++ b/models/transformer/attention.py @@ -0,0 +1,184 @@ +import numpy as np +import torch +from torch import nn +from models.containers import Module + + +class ScaledDotProductAttention(nn.Module): + ''' + Scaled dot-product attention + ''' + + def __init__(self, d_model, d_k, d_v, h, dropout=.1, comment=None): + ''' + :param d_model: Output dimensionality of the model + :param d_k: Dimensionality of queries and keys + :param d_v: Dimensionality of values + :param h: Number of heads + ''' + super(ScaledDotProductAttention, self).__init__() + self.fc_q = nn.Linear(d_model, h * d_k) + self.fc_k = nn.Linear(d_model, h * d_k) + self.fc_v = nn.Linear(d_model, h * d_v) + self.fc_o = nn.Linear(h * d_v, d_model) + self.dropout = nn.Dropout(dropout) + + self.d_model = d_model + self.d_k = d_k + self.d_v = d_v + self.h = h + + self.init_weights() + + self.comment = comment + + def init_weights(self): + nn.init.xavier_uniform_(self.fc_q.weight) + nn.init.xavier_uniform_(self.fc_k.weight) + nn.init.xavier_uniform_(self.fc_v.weight) + nn.init.xavier_uniform_(self.fc_o.weight) + nn.init.constant_(self.fc_q.bias, 0) + nn.init.constant_(self.fc_k.bias, 0) + nn.init.constant_(self.fc_v.bias, 0) + nn.init.constant_(self.fc_o.bias, 0) + + def forward(self, queries, keys, values, attention_mask=None, attention_weights=None): + ''' + Computes + :param queries: Queries (b_s, nq, d_model) + :param keys: Keys (b_s, nk, d_model) + :param values: Values (b_s, nk, d_model) + :param attention_mask: Mask over attention values (b_s, h, nq, nk). True indicates masking. 
+ :param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk). + :return: + ''' + + b_s, nq = queries.shape[:2] + nk = keys.shape[1] + + q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3) # (b_s, h, nq, d_k) + k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1) # (b_s, h, d_k, nk) + v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3) # (b_s, h, nk, d_v) + + att = torch.matmul(q, k) / np.sqrt(self.d_k) # (b_s, h, nq, nk) + if attention_weights is not None: + att = att * attention_weights + if attention_mask is not None: + att = att.masked_fill(attention_mask, -np.inf) + att = torch.softmax(att, -1) + att = self.dropout(att) + + out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v) # (b_s, nq, h*d_v) + out = self.fc_o(out) # (b_s, nq, d_model) + return out + + +class ScaledDotProductAttentionMemory(nn.Module): + ''' + Scaled dot-product attention with memory + ''' + + def __init__(self, d_model, d_k, d_v, h, m): + ''' + :param d_model: Output dimensionality of the model + :param d_k: Dimensionality of queries and keys + :param d_v: Dimensionality of values + :param h: Number of heads + :param m: Number of memory slots + ''' + super(ScaledDotProductAttentionMemory, self).__init__() + self.fc_q = nn.Linear(d_model, h * d_k) + self.fc_k = nn.Linear(d_model, h * d_k) + self.fc_v = nn.Linear(d_model, h * d_v) + self.fc_o = nn.Linear(h * d_v, d_model) + self.m_k = nn.Parameter(torch.FloatTensor(1, m, h * d_k)) + self.m_v = nn.Parameter(torch.FloatTensor(1, m, h * d_v)) + + self.d_model = d_model + self.d_k = d_k + self.d_v = d_v + self.h = h + self.m = m + + self.init_weights() + + def init_weights(self): + nn.init.xavier_uniform_(self.fc_q.weight) + nn.init.xavier_uniform_(self.fc_k.weight) + nn.init.xavier_uniform_(self.fc_v.weight) + nn.init.xavier_uniform_(self.fc_o.weight) + nn.init.normal_(self.m_k, 0, 1 / self.d_k) + nn.init.normal_(self.m_v, 0, 1 / self.m) + nn.init.constant_(self.fc_q.bias, 0) + nn.init.constant_(self.fc_k.bias, 0) + nn.init.constant_(self.fc_v.bias, 0) + nn.init.constant_(self.fc_o.bias, 0) + + def forward(self, queries, keys, values, attention_mask=None, attention_weights=None): + ''' + Computes + :param queries: Queries (b_s, nq, d_model) + :param keys: Keys (b_s, nk, d_model) + :param values: Values (b_s, nk, d_model) + :param attention_mask: Mask over attention values (b_s, h, nq, nk). True indicates masking. + :param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk). 
+ :return: + ''' + b_s, nq = queries.shape[:2] + nk = keys.shape[1] + + m_k = np.sqrt(self.d_k) * self.m_k.expand(b_s, self.m, self.h * self.d_k) + m_v = np.sqrt(self.m) * self.m_v.expand(b_s, self.m, self.h * self.d_v) + + q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3) # (b_s, h, nq, d_k) + k = torch.cat([self.fc_k(keys), m_k], 1).view(b_s, nk + self.m, self.h, self.d_k).permute(0, 2, 3, 1) # (b_s, h, d_k, nk) + v = torch.cat([self.fc_v(values), m_v], 1).view(b_s, nk + self.m, self.h, self.d_v).permute(0, 2, 1, 3) # (b_s, h, nk, d_v) + + att = torch.matmul(q, k) / np.sqrt(self.d_k) # (b_s, h, nq, nk) + if attention_weights is not None: + att = torch.cat([att[:, :, :, :nk] * attention_weights, att[:, :, :, nk:]], -1) + if attention_mask is not None: + att[:, :, :, :nk] = att[:, :, :, :nk].masked_fill(attention_mask, -np.inf) + att = torch.softmax(att, -1) + out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v) # (b_s, nq, h*d_v) + out = self.fc_o(out) # (b_s, nq, d_model) + return out + + +class MultiHeadAttention(Module): + ''' + Multi-head attention layer with Dropout and Layer Normalization. + ''' + + def __init__(self, d_model, d_k, d_v, h, dropout=.1, identity_map_reordering=False, can_be_stateful=False, + attention_module=None, attention_module_kwargs=None, comment=None): + super(MultiHeadAttention, self).__init__() + self.identity_map_reordering = identity_map_reordering + self.attention = ScaledDotProductAttention(d_model=d_model, d_k=d_k, d_v=d_v, h=h, comment=comment) + self.dropout = nn.Dropout(p=dropout) + self.layer_norm = nn.LayerNorm(d_model) + + self.can_be_stateful = can_be_stateful + if self.can_be_stateful: + self.register_state('running_keys', torch.zeros((0, d_model))) + self.register_state('running_values', torch.zeros((0, d_model))) + + def forward(self, queries, keys, values, attention_mask=None, attention_weights=None): + if self.can_be_stateful and self._is_stateful: + self.running_keys = torch.cat([self.running_keys, keys], 1) + keys = self.running_keys + + self.running_values = torch.cat([self.running_values, values], 1) + values = self.running_values + + if self.identity_map_reordering: + q_norm = self.layer_norm(queries) + k_norm = self.layer_norm(keys) + v_norm = self.layer_norm(values) + out = self.attention(q_norm, k_norm, v_norm, attention_mask, attention_weights) + out = queries + self.dropout(torch.relu(out)) + else: + out = self.attention(queries, keys, values, attention_mask, attention_weights) + out = self.dropout(out) + out = self.layer_norm(queries + out) + return out diff --git a/models/transformer/decoders.py b/models/transformer/decoders.py new file mode 100644 index 0000000..4d57413 --- /dev/null +++ b/models/transformer/decoders.py @@ -0,0 +1,84 @@ +import torch +from torch import nn +from torch.nn import functional as F + +from models.transformer.attention import MultiHeadAttention +from models.transformer.utils import sinusoid_encoding_table, PositionWiseFeedForward +from models.containers import Module, ModuleList + + +class DecoderLayer(Module): + def __init__(self, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, self_att_module=None, + enc_att_module=None, self_att_module_kwargs=None, enc_att_module_kwargs=None): + super(DecoderLayer, self).__init__() + self.self_att = MultiHeadAttention(d_model, d_k, d_v, h, dropout, can_be_stateful=True, + attention_module=self_att_module, + attention_module_kwargs=self_att_module_kwargs) + self.enc_att = 
MultiHeadAttention(d_model, d_k, d_v, h, dropout, can_be_stateful=False, + attention_module=enc_att_module, + attention_module_kwargs=enc_att_module_kwargs) + + self.dropout1 = nn.Dropout(dropout) + self.lnorm1 = nn.LayerNorm(d_model) + + self.dropout2 = nn.Dropout(dropout) + self.lnorm2 = nn.LayerNorm(d_model) + self.pwff = PositionWiseFeedForward(d_model, d_ff, dropout) + + def forward(self, input, enc_output, mask_pad, mask_self_att, mask_enc_att): + # MHA+AddNorm + self_att = self.self_att(input, input, input, mask_self_att) + self_att = self.lnorm1(input + self.dropout1(self_att)) + self_att = self_att * mask_pad + # MHA+AddNorm:Image + enc_att = self.enc_att(self_att, enc_output, enc_output, mask_enc_att) + enc_att = self.lnorm2(self_att + self.dropout2(enc_att)) + enc_att = enc_att * mask_pad + + ff = self.pwff(enc_att) + ff = ff * mask_pad + return ff + +class TransformerDecoderLayer(Module): + def __init__(self, vocab_size, max_len, N_dec, padding_idx, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, + self_att_module=None, enc_att_module=None, self_att_module_kwargs=None, enc_att_module_kwargs=None): + super(TransformerDecoderLayer, self).__init__() + self.d_model = d_model + self.word_emb = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx) + self.pos_emb = nn.Embedding.from_pretrained(sinusoid_encoding_table(max_len + 1, d_model, 0), freeze=True) + self.layers = ModuleList( + [DecoderLayer(d_model, d_k, d_v, h, d_ff, dropout, self_att_module=self_att_module, enc_att_module=enc_att_module, self_att_module_kwargs=self_att_module_kwargs, enc_att_module_kwargs=enc_att_module_kwargs) for _ in range(N_dec)]) + self.fc = nn.Linear(d_model, vocab_size, bias=False) + self.max_len = max_len + self.padding_idx = padding_idx + self.N = N_dec + + self.register_state('running_mask_self_attention', torch.zeros((1, 1, 0)).byte()) + self.register_state('running_seq', torch.zeros((1,)).long()) + + def forward(self, input, encoder_output, mask_encoder): + # input (b_s, seq_len) + b_s, seq_len = input.shape[:2] + mask_queries = (input != self.padding_idx).unsqueeze(-1).float() # (b_s, seq_len, 1) + mask_self_attention = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.uint8, device=input.device), + diagonal=1) + mask_self_attention = mask_self_attention.unsqueeze(0).unsqueeze(0) # (1, 1, seq_len, seq_len) + mask_self_attention = mask_self_attention + (input == self.padding_idx).unsqueeze(1).unsqueeze(1).byte() + mask_self_attention = mask_self_attention.gt(0) # (b_s, 1, seq_len, seq_len) + if self._is_stateful: + self.running_mask_self_attention = torch.cat([self.running_mask_self_attention.type_as(mask_self_attention), mask_self_attention], -1) + mask_self_attention = self.running_mask_self_attention + + seq = torch.arange(1, seq_len + 1).view(1, -1).expand(b_s, -1).to(input.device) # (b_s, seq_len) + seq = seq.masked_fill(mask_queries.squeeze(-1) == 0, 0) + if self._is_stateful: + self.running_seq.add_(1) + seq = self.running_seq + + out = self.word_emb(input) + self.pos_emb(seq) + + for i, l in enumerate(self.layers): + out = l(out, encoder_output, mask_queries, mask_self_attention, mask_encoder) + + out = self.fc(out) + return F.log_softmax(out, dim=-1) \ No newline at end of file diff --git a/models/transformer/encoders.py b/models/transformer/encoders.py new file mode 100644 index 0000000..1bcad8d --- /dev/null +++ b/models/transformer/encoders.py @@ -0,0 +1,84 @@ +from torch.nn import functional as F +from torch.nn.modules.activation import LeakyReLU +from 
models.transformer.utils import PositionWiseFeedForward +import torch +from torch import nn +from models.transformer.attention import MultiHeadAttention + + +class SR(nn.Module): + def __init__(self, N, d_model=512): + super(SR, self).__init__() + self.MLP = nn.Sequential( + nn.Linear(N*d_model, N*d_model), + nn.LeakyReLU(), + nn.Linear(N*d_model, d_model), + nn.LeakyReLU() + ) + def forward(self, x, layers, attention_mask = None, attention_weights = None): + out = x + outs = [] + for l in layers: + out = l(out, out, out, attention_mask, attention_weights) + outs.append(out) + outs = self.MLP(torch.cat(outs, -1)) + out = 0.2 * outs + out + return out + +class EncoderLayer(nn.Module): + def __init__(self, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, identity_map_reordering=False, + attention_module=None, attention_module_kwargs=None): + super(EncoderLayer, self).__init__() + self.identity_map_reordering = identity_map_reordering + self.mhatt = MultiHeadAttention(d_model, d_k, d_v, h, dropout, identity_map_reordering=identity_map_reordering, + attention_module=attention_module, + attention_module_kwargs=attention_module_kwargs) + self.dropout = nn.Dropout(dropout) + self.lnorm = nn.LayerNorm(d_model) + self.pwff = PositionWiseFeedForward(d_model, d_ff, dropout, identity_map_reordering=identity_map_reordering) + + def forward(self, queries, keys, values, attention_mask=None, attention_weights=None): + + att = self.mhatt(queries, keys, values, attention_mask, attention_weights) + att = self.lnorm(queries + self.dropout(att)) + ff = self.pwff(att) + return ff + + +class MultiLevelEncoder(nn.Module): + def __init__(self, N, padding_idx, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, + identity_map_reordering=False, attention_module=None, attention_module_kwargs=None): + super(MultiLevelEncoder, self).__init__() + self.d_model = d_model + self.dropout = dropout + self.layers = nn.ModuleList([EncoderLayer(d_model, d_k, d_v, h, d_ff, dropout, + identity_map_reordering=identity_map_reordering, + attention_module=attention_module, # ScaledDotProductAttention + attention_module_kwargs=attention_module_kwargs) # {'m': args.m} + for _ in range(N)]) + self.SR = SR(N, d_model) + self.padding_idx = padding_idx + + def forward(self, input, attention_weights=None): + + # input (b_s, seq_len, d_in) + attention_mask = (torch.sum(input, -1) == self.padding_idx).unsqueeze(1).unsqueeze(1) # (b_s, 1, 1, seq_len) + out = self.SR(input, self.layers, attention_mask, attention_weights) + + return out, attention_mask + + +class TransformerEncoder(MultiLevelEncoder): + def __init__(self, N, padding_idx, d_in=2048, **kwargs): + super(TransformerEncoder, self).__init__(N, padding_idx, **kwargs) + self.fc = nn.Linear(d_in, self.d_model) + self.dropout = nn.Dropout(p=self.dropout) + self.layer_norm = nn.LayerNorm(self.d_model) + + def forward(self, input, attention_weights=None): + mask = (torch.sum(input, dim=-1) == 0).unsqueeze(-1) + out = F.relu(self.fc(input)) + out = self.dropout(out) + out = self.layer_norm(out) + out = out.masked_fill(mask, 0) + return super(TransformerEncoder, self).forward(out, attention_weights=attention_weights) diff --git a/models/transformer/transformer.py b/models/transformer/transformer.py new file mode 100644 index 0000000..1d1e40a --- /dev/null +++ b/models/transformer/transformer.py @@ -0,0 +1,155 @@ +from matplotlib import image +import torch +import torch.nn.functional as F +from torch import nn +import copy +from models.containers import ModuleList +from 
models.transformer.utils import sinusoid_encoding_table +from models.beam_search import * +from ..captioning_model import CaptioningModel + + +class SP(nn.Module): + """SP layer implementation + + Args: + num_clusters : int + The number of pseudo regions + dim : int + Dimension of pseudo regions + alpha : float + Parameter of initialization. Larger value is harder assignment. + normalize_input : bool + If true, pseudo regions-wise L2 normalization is applied to input. + """ + def __init__(self, num_regions=64, dim=128, alpha=100.0, normalize_input=True): + super().__init__() + self.num_regions = num_regions + self.dim = dim + self.alpha = alpha + self.normalize_input = normalize_input + self.conv = nn.Conv2d(dim, num_regions, kernel_size=(1, 1), bias=True) + self.centroids = nn.Parameter(torch.rand(num_regions, dim)) + self.init_weights() + def init_weights(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + def forward(self, grids): + + N, C = grids.shape[0], grids.shape[-1] + + grids = grids.view(N, 7, 7, -1).permute(0,3,1,2).contiguous() + + if self.normalize_input: + grids = F.normalize(grids, p=2, dim=1) # across descriptor dim + + soft_assign = self.conv(grids).view(N, self.num_regions, -1) + soft_assign = F.softmax(soft_assign, dim=1) + + x_flatten = grids.view(N, C, -1) + + residual = x_flatten.expand(self.num_regions, -1, -1, -1).permute(1, 0, 2, 3).contiguous() - \ + self.centroids.expand(x_flatten.size(-1), -1, -1).permute(1, 2, 0).contiguous().unsqueeze(0) + + residual *= soft_assign.unsqueeze(2) + p = residual.sum(dim=-1) + + p = F.normalize(p, p=2, dim=2) # intra-normalization + p = p.view(grids.size(0), -1) + p = F.normalize(p, p=2, dim=1) # L2 normalize + + return p + +class Transformer(CaptioningModel): + def __init__(self, bos_idx, encoder, decoder,num_clusters, vocab_size, max_len, padding_idx, text_d_model=512): + super(Transformer, self).__init__() + self.bos_idx = bos_idx + self.encoder = encoder + self.decoder = decoder + self.text_d_model = text_d_model + self.num_clusters=num_clusters + self.padding_idx = padding_idx + self.word_emb = nn.Embedding(vocab_size, text_d_model, padding_idx=padding_idx) + self.pos_emb = nn.Embedding.from_pretrained(sinusoid_encoding_table(max_len + 1, text_d_model, 0), freeze=True) + + self.SP = SP(num_regions=self.num_clusters, dim=2048) + + self.softmax = nn.Softmax(dim=-1) + self.register_state('enc_output', None) + self.register_state('mask_enc', None) + self.init_weights() + + @property + def d_model(self): + return self.decoder.d_model + + def init_weights(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, mode, images, seq=None, max_len=None, eos_idx=None, beam_size=None, out_size=1, return_probs=False): + ''' + images: torch.Size([50, 49, 2048]) + seq: torch.Size([50, 27]) + ''' + if mode == 'xe': + bs, _, vis_dim = images.size() + # Grid feature + grid_enc_output, grid_mask_enc = self.encoder(images) + + # Pseudo-region feature + pseudo_region = self.SP(images).view(bs, -1, vis_dim) # (N, num_clusters*2048) -> (N, num_clusters, 2048) + pseudo_region_enc_output, pseudo_region_mask_enc = self.encoder(pseudo_region) + + output, mask = torch.cat([grid_enc_output, pseudo_region_enc_output],dim=1), torch.cat([grid_mask_enc, pseudo_region_mask_enc], dim=-1) + dec_output = self.decoder(seq, output, mask) + + return dec_output + + elif mode == 'rl': + bs = BeamSearch(self, max_len, eos_idx, beam_size) + return bs.apply(images, out_size, return_probs) 
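A minimal shape walkthrough of the `xe` branch of `Transformer.forward` above, for readers skimming the diff. The sizes follow the comments in the source (grid features of shape `(B, 49, 2048)`, `--num_clusters` of 5, `d_model` of 512); the random tensors merely stand in for the real `SP` and encoder outputs, so this is an illustrative sketch, not part of the patch.

```python
# Illustrative shape check for the xe forward pass: grid features and SP
# pseudo-regions are encoded separately, then concatenated along the
# sequence dimension before being fed to the decoder.
import torch

B, G, D, d_model = 2, 49, 2048, 512
num_clusters = 5                                    # --num_clusters

images = torch.randn(B, G, D)                       # backbone grid features
pseudo = torch.randn(B, num_clusters * D)           # stand-in for SP(images)
pseudo = pseudo.view(B, -1, D)                      # -> (B, num_clusters, 2048)

grid_enc = torch.randn(B, G, d_model)               # stand-in for encoder(images) output
region_enc = torch.randn(B, num_clusters, d_model)  # stand-in for encoder(pseudo) output

enc_output = torch.cat([grid_enc, region_enc], dim=1)
print(enc_output.shape)                             # torch.Size([2, 54, 512])
```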
+ + def init_state(self, b_s, device): + return [torch.zeros((b_s, 0), dtype=torch.long, device=device), + None, None] + + def step(self, t, prev_output, visual, seq, mode='teacher_forcing', **kwargs): + it = None + if mode == 'teacher_forcing': + raise NotImplementedError + elif mode == 'feedback': + if t == 0: + grid_enc_output, grid_mask_enc = self.encoder(visual) + bs, _, vis_dim = visual.size() + pseudo_region = self.SP(visual).view(bs, -1, vis_dim) + pseudo_region_enc_output, pseudo_region_mask_enc = self.encoder(pseudo_region) + self.enc_output, self.mask_enc = torch.cat([grid_enc_output, pseudo_region_enc_output],dim=1), torch.cat([grid_mask_enc, pseudo_region_mask_enc], dim=-1) + + if isinstance(visual, torch.Tensor): + it = visual.data.new_full((visual.shape[0], 1), self.bos_idx).long() # self.bos_idx: '' + else: + it = visual[0].data.new_full((visual[0].shape[0], 1), self.bos_idx).long() + else: + it = prev_output + return self.decoder(it, self.enc_output, self.mask_enc) + + +class TransformerEnsemble(CaptioningModel): + def __init__(self, model: Transformer, weight_files): + super(TransformerEnsemble, self).__init__() + self.n = len(weight_files) + self.models = ModuleList([copy.deepcopy(model) for _ in range(self.n)]) + for i in range(self.n): + state_dict_i = torch.load(weight_files[i])['state_dict'] + self.models[i].load_state_dict(state_dict_i) + + def step(self, t, prev_output, visual, seq, mode='teacher_forcing', **kwargs): + out_ensemble = [] + for i in range(self.n): + out_i = self.models[i].step(t, prev_output, visual, seq, mode, **kwargs) + out_ensemble.append(out_i.unsqueeze(0)) + + return torch.mean(torch.cat(out_ensemble, 0), dim=0) diff --git a/models/transformer/utils.py b/models/transformer/utils.py new file mode 100644 index 0000000..dc01168 --- /dev/null +++ b/models/transformer/utils.py @@ -0,0 +1,50 @@ +import torch +from torch import nn +from torch.nn import functional as F + + +def position_embedding(input, d_model): + input = input.view(-1, 1) + dim = torch.arange(d_model // 2, dtype=torch.float32, device=input.device).view(1, -1) + sin = torch.sin(input / 10000 ** (2 * dim / d_model)) + cos = torch.cos(input / 10000 ** (2 * dim / d_model)) + + out = torch.zeros((input.shape[0], d_model), device=input.device) + out[:, ::2] = sin + out[:, 1::2] = cos + return out + + +def sinusoid_encoding_table(max_len, d_model, padding_idx=None): + pos = torch.arange(max_len, dtype=torch.float32) + out = position_embedding(pos, d_model) + + if padding_idx is not None: + out[padding_idx] = 0 + return out + + +class PositionWiseFeedForward(nn.Module): + ''' + Position-wise feed forward layer + ''' + + def __init__(self, d_model=512, d_ff=2048, dropout=.1, identity_map_reordering=False): + super(PositionWiseFeedForward, self).__init__() + self.identity_map_reordering = identity_map_reordering + self.fc1 = nn.Linear(d_model, d_ff) + self.fc2 = nn.Linear(d_ff, d_model) + self.dropout = nn.Dropout(p=dropout) + self.dropout_2 = nn.Dropout(p=dropout) + self.layer_norm = nn.LayerNorm(d_model) + + def forward(self, input): + if self.identity_map_reordering: + out = self.layer_norm(input) + out = self.fc2(self.dropout_2(F.relu(self.fc1(out)))) + out = input + self.dropout(torch.relu(out)) + else: + out = self.fc2(self.dropout_2(F.relu(self.fc1(input)))) + out = self.dropout(out) + out = self.layer_norm(input + out) + return out diff --git a/test_transformer.py b/test_transformer.py new file mode 100644 index 0000000..90b2aeb --- /dev/null +++ b/test_transformer.py @@ -0,0 +1,93 @@ 
+import random +import os +from data import ImageDetectionsField, TextField, RawField +from data import COCO, DataLoader +import evaluation +from models.transformer import Transformer, TransformerEncoder, TransformerDecoderLayer, ScaledDotProductAttention +from visualize import visualize_grid_attention_v2 +import torch +from tqdm import tqdm +import argparse +import pickle +import numpy as np +import time + +random.seed(1234) +torch.manual_seed(1234) +np.random.seed(1234) + + +def predict_captions(model, dataloader, text_field): + import itertools + model.eval() + seq_len = 20 + beam_size = 5 + gen = {} + gts = {} + with tqdm(desc='Evaluation', unit='it', total=len(dataloader)) as pbar: + for it, (images, caps_gt, _) in enumerate(iter(dataloader)): + images = images.to(device) + with torch.no_grad(): + out, _ = model(mode='rl', images=images, max_len=seq_len, eos_idx=text_field.vocab.stoi[''], beam_size=beam_size, out_size=1) + # print(out.size(), att_map.size()) + caps_gen = text_field.decode(out, join_words=False) + for i, (gts_i, gen_i) in enumerate(zip(caps_gt, caps_gen)): + gen_i = ' '.join([k for k, g in itertools.groupby(gen_i)]) + gen['%d_%d' % (it, i)] = [gen_i.strip(), ] + gts['%d_%d' % (it, i)] = gts_i + pbar.update() + + gts = evaluation.PTBTokenizer.tokenize(gts) + gen = evaluation.PTBTokenizer.tokenize(gen) + scores, _ = evaluation.compute_scores(gts, gen) + + return scores + + +if __name__ == '__main__': + start_time = time.time() + device = torch.device('cuda') + + parser = argparse.ArgumentParser(description='Transformer') + parser.add_argument('--batch_size', type=int, default=10) + parser.add_argument('--workers', type=int, default=4) + parser.add_argument('--m', type=int, default=40) + + parser.add_argument('--features_path', type=str, default='/home/zhanghaonan/RSTNet-master/X101-features/X101_grid_feats_coco_trainval.hdf5') + parser.add_argument('--annotation_folder', type=str, default='/home/zhanghaonan/RSTNet-master/m2_annotations') + + # the path of tested model and vocabulary + parser.add_argument('--model_path', type=str, default='saved_transformer_models/demo_rl_v5_best_test.pth') + parser.add_argument('--vocab_path', type=str, default='vocab.pkl') + parser.add_argument('--num_clusters', type=int, default=5) + args = parser.parse_args() + + print('Transformer Evaluation') + + # Pipeline for image regions + image_field = ImageDetectionsField(detections_path=args.features_path, max_detections=49, load_in_tmp=False) + + # Pipeline for text + text_field = TextField(init_token='', eos_token='', lower=True, tokenize='spacy', + remove_punctuation=True, nopoints=False) + + # Create the dataset + dataset = COCO(image_field, text_field, 'coco/images/', args.annotation_folder, args.annotation_folder) + _, _, test_dataset = dataset.splits + text_field.vocab = pickle.load(open(args.vocab_path, 'rb')) + + # Model and dataloaders + encoder = TransformerEncoder(3, 0, attention_module=ScaledDotProductAttention, attention_module_kwargs={'m': args.m}) + decoder = TransformerDecoderLayer(len(text_field.vocab), 54, 3, text_field.vocab.stoi['']) + + model = Transformer(text_field.vocab.stoi[''], encoder, decoder, args.num_clusters, len(text_field.vocab), 54, text_field.vocab.stoi[''], 512).to(device) + + data = torch.load(args.model_path) + model.load_state_dict({k.replace('module.',''):v for k,v in data['state_dict'].items()}) + + dict_dataset_test = test_dataset.image_dictionary({'image': image_field, 'text': RawField(), 'add_text':text_field}) + dict_dataloader_test = 
DataLoader(dict_dataset_test, batch_size=args.batch_size, num_workers=args.workers) + + scores = predict_captions(model, dict_dataloader_test, text_field) + print(scores) + print('it costs {} s to test.'.format(time.time() - start_time)) diff --git a/train.sh b/train.sh new file mode 100644 index 0000000..46d538d --- /dev/null +++ b/train.sh @@ -0,0 +1 @@ +python train_transformer.py --exp_name demo --num_clusters 5 diff --git a/train_transformer.py b/train_transformer.py new file mode 100644 index 0000000..8abadf1 --- /dev/null +++ b/train_transformer.py @@ -0,0 +1,485 @@ +import random +import os +# os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7" +from torch.nn.modules.loss import MSELoss +from data import ImageDetectionsField, TextField, RawField +from data import COCO, DataLoader +import evaluation +from evaluation import PTBTokenizer, Cider + +from models.transformer import Transformer, TransformerEncoder, TransformerDecoderLayer, ScaledDotProductAttention + + +import torch +from torch.optim import Adam +from torch.optim.lr_scheduler import LambdaLR +from torch.nn import NLLLoss, MSELoss +from tqdm import tqdm +from torch.utils.tensorboard import SummaryWriter +import argparse + + +import pickle +import numpy as np +import itertools +from shutil import copyfile + +import torch.multiprocessing as mp +import torch.distributed as dist +from torch.utils.data import DistributedSampler + +def evaluate_loss(model, dataloader, loss_fn, text_field, e, device): + + # Validation loss + model.eval() + running_loss = .0 + with tqdm(desc='Epoch %d - validation' % e, unit='it', total=len(dataloader), disable=device!=0) as pbar: + with torch.no_grad(): + for it, (detections, captions) in enumerate(dataloader): + detections, captions = detections.to(device), captions.to(device) + out = model(mode='xe', images=detections, seq=captions) + captions_gt = captions[:, 1:].contiguous() + out = out[:, :-1].contiguous() + + loss = loss_fn[0](out.view(-1, len(text_field.vocab)), captions_gt.view(-1)) + this_loss = loss.item() + running_loss += this_loss + + pbar.set_postfix(loss=running_loss / (it + 1)) + pbar.update() + + val_loss = running_loss / len(dataloader) + return val_loss + +def evaluate_metrics(model, dataloader, text_field, e, device): + import itertools + model.eval() + seq_len = 20 + beam_size = 5 + gen = {} + gts = {} + + with tqdm(desc='Epoch %d - evaluation' % e, unit='it', total=len(dataloader), disable=device!=0) as pbar: + for it, (images, caps_gt, captions) in enumerate(iter(dataloader)): + images = images.to(device) + with torch.no_grad(): + out, _ = model(mode='rl', images=images, max_len=seq_len, eos_idx=text_field.vocab.stoi[''], beam_size=beam_size, out_size=1) + caps_gen = text_field.decode(out, join_words=False) + for i, (gts_i, gen_i) in enumerate(zip(caps_gt, caps_gen)): + gen_i = ' '.join([k for k, g in itertools.groupby(gen_i)]) + gen['%d_%d' % (it, i)] = [gen_i, ] + gts['%d_%d' % (it, i)] = gts_i + pbar.update() + + gts = evaluation.PTBTokenizer.tokenize(gts) + gen = evaluation.PTBTokenizer.tokenize(gen) + scores, _ = evaluation.compute_scores(gts, gen) + return scores + +def train_xe(model, dataloader, optim, text_field, scheduler, loss_fn, e, device): + # Training with cross-entropy + model.train() + scheduler.step() + if device == 0: + print('lr = ', optim.state_dict()['param_groups'][0]['lr']) + + running_loss = .0 + with tqdm(desc='Epoch %d - train' % e, unit='it', total=len(dataloader), disable=device!=0) as pbar: + for it, (detections, captions) in 
+            detections, captions = detections.to(device), captions.to(device)
+            out = model(mode='xe', images=detections, seq=captions)
+            optim.zero_grad()
+            captions_gt = captions[:, 1:].contiguous()
+            out = out[:, :-1].contiguous()
+
+            loss = loss_fn(out.view(-1, len(text_field.vocab)), captions_gt.view(-1))
+            loss.backward()
+            optim.step()
+
+            this_loss = loss.item()
+            running_loss += this_loss
+
+            pbar.set_postfix(loss=running_loss / (it + 1))
+            pbar.update()
+
+        # scheduler.step()
+
+    loss = running_loss / len(dataloader)
+    return loss
+
+
+def train_scst(model, dataloader, optim, cider, text_field, scheduler_rl, e, device):
+    # Training with self-critical sequence training (SCST)
+    # tokenizer_pool = multiprocessing.Pool()
+    running_reward = .0
+    running_reward_baseline = .0
+
+    model.train()
+    scheduler_rl.step()
+    if device == 0:
+        print('lr = ', optim.state_dict()['param_groups'][0]['lr'])
+
+    running_loss = .0
+    seq_len = 20
+    beam_size = 5
+    # kwargs = {
+    #     'text_flag': args.text2text
+    # }
+    with tqdm(desc='Epoch %d - train' % e, unit='it', total=len(dataloader), disable=device != 0) as pbar:
+        for it, (detections, caps_gt, captions) in enumerate(dataloader):
+            detections = detections.to(device)
+            text = captions.to(device)
+            # kwargs['text'] = text
+            outs, log_probs = model(mode='rl', images=detections, max_len=seq_len, eos_idx=text_field.vocab.stoi['<eos>'], beam_size=beam_size, out_size=beam_size)
+            optim.zero_grad()
+
+            # Rewards: CIDEr of each sampled beam caption, baselined by the mean reward over the beam
+            caps_gen = text_field.decode(outs.view(-1, seq_len))
+            caps_gt = list(itertools.chain(*([c, ] * beam_size for c in caps_gt)))
+            # caps_gen, caps_gt = tokenizer_pool.map(evaluation.PTBTokenizer.tokenize, [caps_gen, caps_gt])
+            caps_gen = evaluation.PTBTokenizer.tokenize(caps_gen)
+            caps_gt = evaluation.PTBTokenizer.tokenize(caps_gt)
+            reward = cider.compute_score(caps_gt, caps_gen)[1].astype(np.float32)
+            reward = torch.from_numpy(reward).to(device).view(detections.shape[0], beam_size)
+            reward_baseline = torch.mean(reward, -1, keepdim=True)
+            loss = -torch.mean(log_probs, -1) * (reward - reward_baseline)
+
+            loss = loss.mean()
+            loss.backward()
+            optim.step()
+
+            running_loss += loss.item()
+            running_reward += reward.mean().item()
+            running_reward_baseline += reward_baseline.mean().item()
+            # pbar.set_postfix(loss=running_loss / (it + 1), reward=running_reward / (it + 1),
+            #                  reward_baseline=running_reward_baseline / (it + 1))
+            pbar.update()
+        # scheduler_rl.step()
+
+    loss = running_loss / len(dataloader)
+    reward = running_reward / len(dataloader)
+    reward_baseline = running_reward_baseline / len(dataloader)
+    # tokenizer_pool.close()
+    return loss, reward, reward_baseline
+
+
+def _changeConfig(config, worldSize):
+    # Scale the base learning rates linearly with the number of GPUs
+    batchSize = config.batch_size * worldSize
+    # exponent = math.log2(batchSize)
+    # scale = 3 - exponent / 2
+    # config.xe_base_lr /= (2 ** scale)
+    # config.rl_base_lr /= (2 ** scale)
+    config.xe_base_lr *= worldSize
+    config.rl_base_lr *= worldSize
+
+
+def _generalConfig(rank: int, worldSize: int):
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "2227"  # any free port can be used here
+    torch.autograd.set_detect_anomaly(False)
+    torch.backends.cudnn.benchmark = True
+    random.seed(1234)
+    torch.manual_seed(1234)
+    np.random.seed(1234)
+    torch.cuda.set_device(rank)
+    dist.init_process_group("nccl", world_size=worldSize, rank=rank)
+
+
+def train(rank, worldSize, args):
+    _generalConfig(rank, worldSize)
+    if rank == 0:
+        print('Rank{}: Transformer Training'.format(rank))
+    if rank == 0:
+        writer = SummaryWriter(log_dir=os.path.join(args.logs_folder, args.exp_name))
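+
+    # Build the data pipelines, vocabulary, the model wrapped in DistributedDataParallel,
+    # and the tokenized reference captions used for the CIDEr reward in the SCST stage.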
+    # Pipeline for image regions
+    image_field = ImageDetectionsField(detections_path=args.features_path, max_detections=49, load_in_tmp=False)
+    # Pipeline for text
+    text_field = TextField(init_token='<bos>', eos_token='<eos>', lower=True, tokenize='spacy', remove_punctuation=True, nopoints=False)
+
+    # Create the dataset
+    dataset = COCO(image_field, text_field, 'coco/images/', args.annotation_folder, args.annotation_folder)
+    train_dataset, val_dataset, test_dataset = dataset.splits
+
+    if not os.path.isfile('vocab.pkl'):
+        print("Rank{}: Building vocabulary".format(rank))
+        text_field.build_vocab(train_dataset, val_dataset, min_freq=5)
+        pickle.dump(text_field.vocab, open('vocab.pkl', 'wb'))
+    else:
+        print('Rank{}: Loading vocabulary'.format(rank))
+        text_field.vocab = pickle.load(open('vocab.pkl', 'rb'))
+
+    # Model and dataloaders
+    encoder = TransformerEncoder(3, 0, attention_module=ScaledDotProductAttention, attention_module_kwargs={'m': args.m})
+    decoder = TransformerDecoderLayer(len(text_field.vocab), 54, 3, text_field.vocab.stoi['<pad>'])
+
+    model = Transformer(text_field.vocab.stoi['<bos>'], encoder, decoder, args.num_clusters, len(text_field.vocab), 54, text_field.vocab.stoi['<pad>'], 512)
+    model = torch.nn.parallel.DistributedDataParallel(model.to(rank), device_ids=[rank], output_device=rank, broadcast_buffers=False, find_unused_parameters=True)
+
+    dict_dataset_train = train_dataset.image_dictionary({'image': image_field, 'text': RawField(), 'add_text': text_field})
+    ref_caps_train = train_dataset.text()
+    cider_train = Cider(PTBTokenizer.tokenize(ref_caps_train))
+    dict_dataset_val = val_dataset.image_dictionary({'image': image_field, 'text': RawField(), 'add_text': text_field})
+    dict_dataset_test = test_dataset.image_dictionary({'image': image_field, 'text': RawField(), 'add_text': text_field})
+
+    '''
+    def lambda_lr(s):
+        warm_up = args.warmup
+        s += 1
+        return (model.d_model ** -.5) * min(s ** -.5, s * warm_up ** -1.5)
+    '''
+
+    # Staged learning-rate schedule for the XE stage: linear warm-up over the first
+    # epochs, a constant plateau, then two step decays by a factor of 0.2
+    def lambda_lr(s):
+        print("s:", s)
+        if s <= 3:
+            lr = args.xe_base_lr * s / 4
+        elif s <= 10:
+            lr = args.xe_base_lr
+        elif s <= 12:
+            lr = args.xe_base_lr * 0.2
+        else:
+            lr = args.xe_base_lr * 0.2 * 0.2
+        return lr
+
+    # Step-decay schedule for the SCST (RL) stage, keyed to refine_epoch_rl
+    def lambda_lr_rl(s):
+        refine_epoch = args.refine_epoch_rl
+        print("rl_s:", s)
+        if s <= refine_epoch:
+            lr = args.rl_base_lr
+        elif s <= refine_epoch + 3:
+            lr = args.rl_base_lr * 0.2
+        elif s <= refine_epoch + 6:
+            lr = args.rl_base_lr * 0.2 * 0.2
+        else:
+            lr = args.rl_base_lr * 0.2 * 0.2 * 0.2
+        return lr
+
+    # Initial conditions
+    optim = Adam(model.parameters(), lr=1, betas=(0.9, 0.98))
+    scheduler = LambdaLR(optim, lambda_lr)
+
+    optim_rl = Adam(model.parameters(), lr=1, betas=(0.9, 0.98))
+    scheduler_rl = LambdaLR(optim_rl, lambda_lr_rl)
+
+    loss_fn = NLLLoss(ignore_index=text_field.vocab.stoi['<pad>'])
+    loss_align = MSELoss()
+    loss = (loss_fn, loss_align)
+    use_rl = False
+    best_cider = .0
+    best_test_cider = 0.
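+
+    # Training-state counters; together with the flags above, these are restored from the checkpoint when resuming.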
+    patience = 0
+    start_epoch = 0
+
+    if args.resume_last or args.resume_best:
+        if args.resume_last:
+            fname = 'saved_transformer_models/%s_last.pth' % args.exp_name
+        else:
+            fname = 'saved_transformer_models/%s_best.pth' % args.exp_name
+
+        # fname = 'saved_transformer_models/align_share_K5_init_vlad_last.pth'
+
+        if os.path.exists(fname):
+            data = torch.load(fname)
+            torch.set_rng_state(data['torch_rng_state'])
+            torch.cuda.set_rng_state(data['cuda_rng_state'])
+            np.random.set_state(data['numpy_rng_state'])
+            random.setstate(data['random_rng_state'])
+            model.load_state_dict(data['state_dict'], strict=False)
+            """
+            optim.load_state_dict(data['optimizer'])
+            scheduler.load_state_dict(data['scheduler'])
+            """
+            start_epoch = data['epoch'] + 1
+            best_cider = data['best_cider']
+            best_test_cider = data['best_test_cider']
+            patience = data['patience']
+            use_rl = data['use_rl']
+
+            # Restore the optimizer/scheduler of the stage that was active when the
+            # checkpoint was saved (XE before the switch, SCST afterwards)
+            if not use_rl:
+                optim.load_state_dict(data['optimizer'])
+                scheduler.load_state_dict(data['scheduler'])
+            else:
+                optim_rl.load_state_dict(data['optimizer'])
+                scheduler_rl.load_state_dict(data['scheduler'])
+
+            print('Resuming from epoch %d, validation loss %f, best cider %f, and best_test_cider %f' % (
+                data['epoch'], data['val_loss'], data['best_cider'], data['best_test_cider']))
+            print('patience:', data['patience'])
+
+    print("Training starts")
+    for e in range(start_epoch, start_epoch + 100):
+        trainSampler = DistributedSampler(train_dataset, worldSize, rank)
+        trainSampler.set_epoch(e)
+        dataloader_train = DataLoader(train_dataset, sampler=trainSampler, batch_size=args.batch_size, pin_memory=True, drop_last=False, num_workers=args.workers, persistent_workers=True)
+
+        dataloader_val = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
+
+        dict_trainSampler = DistributedSampler(dict_dataset_train, worldSize, rank)
+        dict_trainSampler.set_epoch(e)
+        dict_dataloader_train = DataLoader(dict_dataset_train, sampler=dict_trainSampler, batch_size=args.batch_size // 5, pin_memory=True, drop_last=False, num_workers=args.workers, persistent_workers=True)
+
+        dict_dataloader_val = DataLoader(dict_dataset_val, batch_size=args.batch_size // 5)
+        dict_dataloader_test = DataLoader(dict_dataset_test, batch_size=args.batch_size // 5)
+
+        if not use_rl:
+            train_loss = train_xe(model, dataloader_train, optim, text_field, scheduler, loss_fn, e, rank)
+            if rank == 0:
+                writer.add_scalar('data/train_loss', train_loss, e)
+        else:
+            train_loss, reward, reward_baseline = train_scst(model, dict_dataloader_train, optim_rl, cider_train, text_field, scheduler_rl, e, rank)
+            if rank == 0:
+                writer.add_scalar('data/train_loss', train_loss, e)
+                writer.add_scalar('data/reward', reward, e)
+                writer.add_scalar('data/reward_baseline', reward_baseline, e)
+
+        # Validation loss
+        val_loss = evaluate_loss(model, dataloader_val, loss, text_field, e, rank)
+        if rank == 0:
+            writer.add_scalar('data/val_loss', val_loss, e)
+
+        # Validation scores
+        scores = evaluate_metrics(model, dict_dataloader_val, text_field, e, rank)
+        val_cider = scores['CIDEr']
+        if rank == 0:
+            print("Validation scores", scores)
+            writer.add_scalar('data/val_cider', val_cider, e)
+            writer.add_scalar('data/val_bleu1', scores['BLEU'][0], e)
+            writer.add_scalar('data/val_bleu4', scores['BLEU'][3], e)
+            writer.add_scalar('data/val_meteor', scores['METEOR'], e)
+            writer.add_scalar('data/val_rouge', scores['ROUGE'], e)
+
+        # Test scores
+        scores = evaluate_metrics(model, dict_dataloader_test, text_field, e, rank)
+        test_cider = scores['CIDEr']
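+
+        # Rank 0 logs the test metrics; afterwards the best/patience counters are updated,
+        # the XE -> SCST switch (or early stop) is decided, and checkpoints are written.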
+        if rank == 0:
+            print("Test scores", scores)
+            writer.add_scalar('data/test_cider', test_cider, e)
+            writer.add_scalar('data/test_bleu1', scores['BLEU'][0], e)
+            writer.add_scalar('data/test_bleu4', scores['BLEU'][3], e)
+            writer.add_scalar('data/test_meteor', scores['METEOR'], e)
+            writer.add_scalar('data/test_rouge', scores['ROUGE'], e)
+
+        # Prepare for next epoch
+        best = False
+        if val_cider >= best_cider:
+            best_cider = val_cider
+            patience = 0
+            best = True
+        else:
+            patience += 1
+
+        best_test = False
+        if test_cider >= best_test_cider:
+            best_test_cider = test_cider
+            best_test = True
+
+        switch_to_rl = False
+        exit_train = False
+
+        if patience == 5:
+            if e < args.xe_least:  # the XE stage trains for at least xe_least epochs
+                if rank == 0:
+                    print('Patience reached before xe_least; continuing XE training (e = {})'.format(e))
+                use_rl = False
+                switch_to_rl = False
+                patience = 0
+            elif not use_rl:
+                use_rl = True
+                switch_to_rl = True
+                patience = 0
+
+                optim_rl = Adam(model.parameters(), lr=1, betas=(0.9, 0.98))
+                scheduler_rl = LambdaLR(optim_rl, lambda_lr_rl)
+
+                for k in range(e - 1):
+                    scheduler_rl.step()
+                if rank == 0:
+                    print("Switching to RL")
+            else:
+                if rank == 0:
+                    print('patience reached.')
+                exit_train = True
+
+        if e == args.xe_most:  # the XE stage runs for no more than xe_most epochs
+            if not use_rl:
+                use_rl = True
+                switch_to_rl = True
+                patience = 0
+
+                optim_rl = Adam(model.parameters(), lr=1, betas=(0.9, 0.98))
+                scheduler_rl = LambdaLR(optim_rl, lambda_lr_rl)
+
+                for k in range(e - 1):
+                    scheduler_rl.step()
+                if rank == 0:
+                    print("Switching to RL")
+
+        if rank == 0:
+            if switch_to_rl and not best:
+                data = torch.load('saved_transformer_models/%s_best.pth' % args.exp_name)
+                torch.set_rng_state(data['torch_rng_state'])
+                torch.cuda.set_rng_state(data['cuda_rng_state'])
+                np.random.set_state(data['numpy_rng_state'])
+                random.setstate(data['random_rng_state'])
+                model.load_state_dict(data['state_dict'])
+                print('Resuming from epoch %d, validation loss %f, best_cider %f, and best test_cider %f' % (
+                    data['epoch'], data['val_loss'], data['best_cider'], data['best_test_cider']))
+
+            torch.save({
+                'torch_rng_state': torch.get_rng_state(),
+                'cuda_rng_state': torch.cuda.get_rng_state(),
+                'numpy_rng_state': np.random.get_state(),
+                'random_rng_state': random.getstate(),
+                'epoch': e,
+                'val_loss': val_loss,
+                'val_cider': val_cider,
+                'state_dict': model.state_dict(),
+                'optimizer': optim.state_dict() if not use_rl else optim_rl.state_dict(),
+                'scheduler': scheduler.state_dict() if not use_rl else scheduler_rl.state_dict(),
+                'patience': patience,
+                'best_cider': best_cider,
+                'best_test_cider': best_test_cider,
+                'use_rl': use_rl,
+            }, 'saved_transformer_models/%s_last.pth' % args.exp_name)
+
+            if best:
+                copyfile('saved_transformer_models/%s_last.pth' % args.exp_name, 'saved_transformer_models/%s_best.pth' % args.exp_name)
+            if best_test:
+                copyfile('saved_transformer_models/%s_last.pth' % args.exp_name, 'saved_transformer_models/%s_best_test.pth' % args.exp_name)
+
+        if exit_train:
+            if rank == 0:
+                writer.close()
+            break
+
+
+if __name__ == '__main__':
+    # device = torch.device('cuda')
+
+    parser = argparse.ArgumentParser(description='Transformer')
+    parser.add_argument('--exp_name', type=str, default='demo')
+    parser.add_argument('--batch_size', type=int, default=50)
+    parser.add_argument('--workers', type=int, default=4)
+    parser.add_argument('--m', type=int, default=40)
+    parser.add_argument('--head', type=int, default=8)
+    parser.add_argument('--warmup', type=int, default=10000)
+    parser.add_argument('--resume_last', action='store_true')
+    parser.add_argument('--resume_best', action='store_true')
+    parser.add_argument('--features_path', type=str, default='./X101-features/X101_grid_feats_coco_trainval.hdf5')
+    parser.add_argument('--annotation_folder', type=str, default='./m2_annotations')
+
+    parser.add_argument('--logs_folder', type=str, default='tensorboard_logs')
+    parser.add_argument('--xe_least', type=int, default=15)
+    parser.add_argument('--xe_most', type=int, default=20)  # 18
+    parser.add_argument('--refine_epoch_rl', type=int, default=28)  # 35
+
+    parser.add_argument('--xe_base_lr', type=float, default=0.0001)
+    parser.add_argument('--rl_base_lr', type=float, default=5e-6)
+    parser.add_argument('--num_clusters', type=int, default=5)
+    parser.add_argument('--text2text', type=int, default=0)
+
+    args = parser.parse_args()
+    print(args)
+
+    # DDP training: spawn one process per GPU; set worldSize to the number of GPUs to use
+    worldSize = 8
+    _changeConfig(args, worldSize)
+    print('\nDistributed config', args)
+    mp.spawn(train, (worldSize, args), worldSize)
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000..0162b87
--- /dev/null
+++ b/utils/__init__.py
@@ -0,0 +1,17 @@
+from .utils import download_from_url
+from .typing import *
+
+
+def get_batch_size(x: TensorOrSequence) -> int:
+    if isinstance(x, torch.Tensor):
+        b_s = x.size(0)
+    else:
+        b_s = x[0].size(0)
+    return b_s
+
+
+def get_device(x: TensorOrSequence) -> torch.device:
+    if isinstance(x, torch.Tensor):
+        device = x.device
+    else:
+        device = x[0].device
+    return device
diff --git a/utils/typing.py b/utils/typing.py
new file mode 100644
index 0000000..3270f86
--- /dev/null
+++ b/utils/typing.py
@@ -0,0 +1,5 @@
+from typing import Union, Sequence, Tuple
+import torch
+
+TensorOrSequence = Union[Sequence[torch.Tensor], torch.Tensor]
+TensorOrNone = Union[torch.Tensor, None]
diff --git a/utils/utils.py b/utils/utils.py
new file mode 100644
index 0000000..ef80d23
--- /dev/null
+++ b/utils/utils.py
@@ -0,0 +1,27 @@
+import requests
+
+
+def download_from_url(url, path):
+    """Download a file, with special handling (from tensor2tensor) for Google Drive links."""
+    if 'drive.google.com' not in url:
+        print('Downloading %s; may take a few minutes' % url)
+        r = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'})
+        with open(path, "wb") as file:
+            file.write(r.content)
+        return
+    print('Downloading from Google Drive; may take a few minutes')
+    confirm_token = None
+    session = requests.Session()
+    response = session.get(url, stream=True)
+    for k, v in response.cookies.items():
+        if k.startswith("download_warning"):
+            confirm_token = v
+
+    if confirm_token:
+        url = url + "&confirm=" + confirm_token
+        response = session.get(url, stream=True)
+
+    chunk_size = 16 * 1024
+    with open(path, "wb") as f:
+        for chunk in response.iter_content(chunk_size):
+            if chunk:
+                f.write(chunk)
diff --git a/vocab.pkl b/vocab.pkl
new file mode 100644
index 0000000..091dbb4
Binary files /dev/null and b/vocab.pkl differ
diff --git a/vocab_idx2word.pkl b/vocab_idx2word.pkl
new file mode 100644
index 0000000..abaa00b
Binary files /dev/null and b/vocab_idx2word.pkl differ
diff --git a/vocab_language/vocab_bert_language.pkl b/vocab_language/vocab_bert_language.pkl
new file mode 100644
index 0000000..1858f6a
Binary files /dev/null and b/vocab_language/vocab_bert_language.pkl differ