diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..337c774
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,10 @@
+/X101-features
+/X152-features
+/m2_annotations
+evaluation/spice/*
+*.pyc
+*.jar
+/saved_transformer_models
+/tensorboard_logs
+/visualization
+/.vscode
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..1f0dcad
--- /dev/null
+++ b/README.md
@@ -0,0 +1,112 @@
+# $\mathcal{S}^2$ Transformer for Image Captioning
+
+[](https://www.python.org/)
+[](https://github.com/zchoi/S2-Transformer/blob/main/LICENSE)
+[](https://pytorch.org/)
+
+This repository contains the official code implementation for the paper [_S2 Transformer for Image Captioning_ (IJCAI 2022)](https://openaccess.thecvf.com/content/CVPR2021/papers/Zhang_RSTNet_Captioning_With_Adaptive_Attention_on_Visual_and_Non-Visual_Words_CVPR_2021_paper.pdf).
+
+
+
+
+
+## Table of Contents
+- [Environment setup](#environment-setup)
+- [Data Preparation](#data-preparation)
+- [Training](#training)
+- [Evaluation](#evaluation)
+- [Reference and Citation](#reference-and-citation)
+- [Acknowledgements](#acknowledgements)
+
+## Environment setup
+
+Clone this repository and create the conda environment using the `environment.yaml` file (the environment name is set by the `name:` field in that file; adjust the `conda activate` command below if it differs from `m2release`):
+```
+conda env create -f environment.yaml
+conda activate m2release
+```
+
+Then download the spaCy `en_core_web_md` model by running:
+```
+python -m spacy download en_core_web_md
+```
+
+**Note:** Python 3 is required to run our code. If you run into network problems, download the ```en_core_web_md``` model from [here](https://drive.google.com/file/d/1jf6ecYDzIomaGt3HgOqO_7rEL6oiTjgN/view?usp=sharing), unzip it, and place it in ```/your/anaconda/path/envs/m2release/lib/python*/site-packages/```.
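+
+A quick optional check (not part of the original instructions) that the model is visible inside the environment:
+```
+import spacy
+
+nlp = spacy.load('en_core_web_md')        # raises OSError if the model was not installed correctly
+print(nlp('sanity check').vector.shape)   # expected: (300,)
+```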
+
+
+## Data Preparation
+
+* **Annotation**. Download the annotation file [annotation.zip](https://drive.google.com/file/d/1Zc2P3-MIBg3JcHT1qKeYuQt9CnQcY5XJ/view?usp=sharing) [1]. Extract it and put it in the project root directory.
+* **Feature**. Download the processed image features [ResNeXt-101](https://stduestceducn-my.sharepoint.com/:f:/g/personal/zhn_std_uestc_edu_cn/EssZY4Xdb0JErCk0A1Yx3vUBaRbXau88scRvYw4r1ZuwPg?e=f2QFGp) and [ResNeXt-152](https://stduestceducn-my.sharepoint.com/:f:/g/personal/zhn_std_uestc_edu_cn/EssZY4Xdb0JErCk0A1Yx3vUBaRbXau88scRvYw4r1ZuwPg?e=f2QFGp) [2] and put them in the project root directory (a quick sanity check for the downloaded feature files is sketched below).
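+
+The training code reads grid features from HDF5 files keyed per image id (the key pattern `%d_grids` is taken from `data/field.py` in this repository; the file name below is only a placeholder for whichever `.hdf5` file you downloaded):
+```
+import h5py
+
+with h5py.File('X101-features/X101_grid_feats_coco_trainval.hdf5', 'r') as f:  # placeholder path
+    key = next(k for k in f.keys() if k.endswith('_grids'))
+    print(key, f[key].shape)   # e.g. a (num_grids, 2048) feature matrix
+```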
+
+
+
+## Training
+Run `python train_transformer.py` using the following arguments:
+
+| Argument | Possible values |
+|------|------|
+| `--exp_name` | Experiment name|
+| `--batch_size` | Batch size (default: 50) |
+| `--workers` | Number of data-loading workers; a larger value speeds up training in the XE stage |
+| `--head` | Number of heads (default: 8) |
+| `--resume_last` | If used, the training will be resumed from the last checkpoint. |
+| `--resume_best` | If used, the training will be resumed from the best checkpoint. |
+| `--features_path` | Path to visual features file (h5py)|
+| `--annotation_folder` | Path to annotations |
+| `--num_clusters` | Number of pseudo regions |
+
+For example, to train the model, run the following command:
+```
+python train_transformer.py --exp_name S2 --batch_size 50 --m 40 --head 8 --features_path /path/to/features --num_clusters 5
+```
+or just run:
+```
+bash train.sh
+```
+**Note:** We use `torch.distributed` to train our model; set `worldSize` in [train_transformer.py]() to choose how many GPUs are used for training. A minimal illustrative sketch follows.
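+
+The sketch below shows how a fixed `worldSize` is typically turned into one process per GPU with `torch.distributed`. Everything except the `worldSize` name is an assumption for illustration; the authoritative setup lives in `train_transformer.py`:
+```
+import os
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+worldSize = 2  # number of GPUs to use
+
+def worker(rank):
+    # rendezvous settings are illustrative defaults
+    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
+    os.environ.setdefault('MASTER_PORT', '29500')
+    dist.init_process_group('nccl', rank=rank, world_size=worldSize)
+    torch.cuda.set_device(rank)
+    # ... build the model, wrap it with torch.nn.parallel.DistributedDataParallel, train ...
+    dist.destroy_process_group()
+
+if __name__ == '__main__':
+    mp.spawn(worker, nprocs=worldSize)
+```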
+
+## Evaluation
+### Offline Evaluation
+Run `python test_transformer.py` to evaluate the model, for example:
+```
+python test_transformer.py --batch_size 10 --features_path /path/to/features --model_path /path/to/saved_transformer_models/ckpt --num_clusters 5
+```
+
+**Note:** We removed the ```SPICE``` metric from training-time evaluation because it is time-consuming. To add it back when evaluating the model, download this [file](https://drive.google.com/file/d/1vEVsbEFjDstmSvoWhu4UdKaJjX1jJXpR/view?usp=sharing) and put it in ```/path/to/evaluation/```, then uncomment the SPICE-related code in [`evaluation/__init__.py`]() (see the sketch below).
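+
+For reference, enabling SPICE roughly amounts to the following change in `evaluation/__init__.py` (a sketch based on the file included later in this diff, not a drop-in patch):
+```
+from .spice import Spice          # uncomment once the SPICE files are in /path/to/evaluation/
+
+def compute_scores(gts, gen):
+    metrics = (Bleu(), Meteor(), Rouge(), Cider(), Spice())
+    ...
+```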
+
+We provide a pretrained model [here](https://drive.google.com/file/d/1Y133r4Wd9ediS1Jqlwc1qtL15vCK_Mik/view?usp=sharing); evaluating it should reproduce the following results (second row):
+
+| Model | B@1 | B@4 | M | R | C | S |
+|:---------: |:-------: |:-: |:---------------: |:--------------------------: |:-------: | :-------:|
+| Our Paper (ResNeXt-101) | 81.1 | 39.6 | 29.6 | 59.1 | 133.5 | 23.2|
+| Reproduced Model (ResNeXt-101) | 81.2 | 39.9 | 29.6 | 59.1 | 133.7 | 23.3|
+
+
+
+### Online Evaluation
+We also report the performance of our model on the online COCO test server, using an ensemble of four S2 models. The detailed online test code is available in this [repo](https://github.com/zhangxuying1004/RSTNet).
+
+## Reference and Citation
+### Reference
+[1] Marcella Cornia, Matteo Stefanini, Lorenzo Baraldi, and Rita Cucchiara. Meshed-memory transformer for image captioning. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, 2020.
+
+[2] Xuying Zhang, Xiaoshuai Sun, Yunpeng Luo, Jiayi Ji, Yiyi Zhou, Yongjian Wu, Feiyue Huang, and Rongrong Ji. RSTNet: Captioning with adaptive attention on visual and non-visual words. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 15465–15474, 2021.
+### Citation
+```
+@inproceedings{S2,
+ author = {Pengpeng Zeng and
+ Haonan Zhang and
+ Jingkuan Song and
+ Lianli Gao},
+ title = {S2 Transformer for Image Captioning},
+ booktitle = {IJCAI},
+ % pages = {????--????}
+ year = {2022}
+}
+```
+## Acknowledgements
+Thanks to Zhang _et al._ [2] for releasing the visual features (ResNeXt-101 and ResNeXt-152); our implementation is also based on their [repo](https://github.com/zhangxuying1004/RSTNet).
+
+Thanks also to [M2 Transformer](https://github.com/aimagelab/meshed-memory-transformer) for the original annotations and to [grid-feats-vqa](https://github.com/facebookresearch/grid-feats-vqa) for the effective visual representations.
+
+
diff --git a/data/__init__.py b/data/__init__.py
new file mode 100644
index 0000000..202e709
--- /dev/null
+++ b/data/__init__.py
@@ -0,0 +1,7 @@
+from .field import RawField, Merge, ImageDetectionsField, TextField
+from .dataset import COCO
+from torch.utils.data import DataLoader as TorchDataLoader
+
+class DataLoader(TorchDataLoader):
+ def __init__(self, dataset, *args, **kwargs):
+ super(DataLoader, self).__init__(dataset, *args, collate_fn=dataset.collate_fn(), **kwargs)
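+
+
+# Usage sketch (illustrative; the actual wiring is done by the training/test scripts):
+#   dataset = COCO(image_field, text_field, '/path/to/images', 'm2_annotations', 'm2_annotations')
+#   train, val, test = dataset.splits
+#   loader = DataLoader(train, batch_size=50, shuffle=True)
+# The subclass simply forwards each dataset's own collate_fn to torch's DataLoader.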
diff --git a/data/dataset.py b/data/dataset.py
new file mode 100644
index 0000000..264684f
--- /dev/null
+++ b/data/dataset.py
@@ -0,0 +1,284 @@
+import os
+import numpy as np
+import itertools
+import collections
+import torch
+from .example import Example
+from .utils import nostdout
+from pycocotools.coco import COCO as pyCOCO
+
+
+class Dataset(object):
+ def __init__(self, examples, fields):
+ self.examples = examples
+ self.fields = dict(fields)
+
+ def collate(self, batch):
+ if len(self.fields) == 1:
+ batch = [batch, ]
+ else:
+ batch = list(zip(*batch))
+
+ tensors = []
+ for field, data in zip(self.fields.values(), batch):
+ tensor = field.process(data)
+ if isinstance(tensor, collections.Sequence) and any(isinstance(t, torch.Tensor) for t in tensor):
+ tensors.extend(tensor)
+ else:
+ tensors.append(tensor)
+
+ if len(tensors) > 1:
+ return tensors
+ else:
+ return tensors[0]
+
+ def collate_fn(self):
+ return self.collate
+
+ def text(self):
+ return [x.text for x in self.examples]
+
+ def __getitem__(self, i):
+ example = self.examples[i]
+ data = []
+ for field_name, field in self.fields.items():
+ data.append(field.preprocess(getattr(example, field_name)))
+
+ if len(data) == 1:
+ data = data[0]
+ return data
+
+ def __len__(self):
+ return len(self.examples)
+
+
+class ValueDataset(Dataset):
+ def __init__(self, examples, fields, dictionary):
+ super(ValueDataset, self).__init__(examples, fields)
+ self.dictionary = dictionary
+
+ def valuecollate(self, batch):
+ value_batch_flattened = list(itertools.chain(*batch))
+ value_tensors_flattened = super(ValueDataset, self).collate_fn()(value_batch_flattened)
+
+ lengths = [0, ] + list(itertools.accumulate([len(x) for x in batch]))
+ if isinstance(value_tensors_flattened, collections.Sequence) \
+ and any(isinstance(t, torch.Tensor) for t in value_tensors_flattened):
+ value_tensors = [[vt[s:e] for (s, e) in zip(lengths[:-1], lengths[1:])] for vt in value_tensors_flattened]
+ else:
+ value_tensors = [value_tensors_flattened[s:e] for (s, e) in zip(lengths[:-1], lengths[1:])]
+
+ return value_tensors
+
+ def collate_fn(self):
+ return self.valuecollate
+
+ def __getitem__(self, i):
+ if i not in self.dictionary:
+ raise IndexError
+
+ values_data = []
+ for idx in self.dictionary[i]:
+ value_data = super(ValueDataset, self).__getitem__(idx)
+ values_data.append(value_data)
+ return values_data
+
+ def __len__(self):
+ return len(self.dictionary)
+
+
+class DictionaryDataset(Dataset):
+ def __init__(self, examples, fields, key_fields):
+ if not isinstance(key_fields, (tuple, list)):
+ key_fields = (key_fields,)
+ for field in key_fields:
+ assert (field in fields)
+ caption_field = fields.pop('add_text')
+ dictionary = collections.defaultdict(list)
+ # {'image': image_field}
+ key_fields = {k: fields[k] for k in key_fields}
+        # {'text': RawField()}
+ value_fields = {k: fields[k] for k in fields.keys() if k not in key_fields}
+ key_examples = []
+ key_dict = dict()
+ value_examples = []
+
+ for i, e in enumerate(examples):
+ key_example = Example.fromdict({k: getattr(e, k) for k in key_fields})
+ value_example = Example.fromdict({v: getattr(e, v) for v in value_fields})
+ if key_example not in key_dict:
+ key_dict[key_example] = len(key_examples)
+ key_examples.append(key_example)
+
+ value_examples.append(value_example)
+ dictionary[key_dict[key_example]].append(i)
+
+ self.key_dataset = Dataset(key_examples, key_fields)
+ self.value_dataset = ValueDataset(value_examples, value_fields, dictionary)
+ self.not_key_dataset = Dataset(value_examples, {'text': caption_field})
+
+ super(DictionaryDataset, self).__init__(examples, fields)
+
+ def dictcollate(self, batch):
+ key_batch, value_batch, not_key_batch = list(zip(*batch))
+ key_tensors = self.key_dataset.collate_fn()(key_batch)
+ value_tensors = self.value_dataset.collate_fn()(value_batch)
+ not_key_tensors = self.not_key_dataset.collate_fn()(not_key_batch)
+ return key_tensors, value_tensors, not_key_tensors
+ def collate_fn(self):
+ return self.dictcollate
+
+
+ def __getitem__(self, i):
+ return self.key_dataset[i], self.value_dataset[i], self.not_key_dataset[i]
+
+ def __len__(self):
+ return len(self.key_dataset)
+
+
+def unique(sequence):
+ seen = set()
+ if isinstance(sequence[0], list):
+ return [x for x in sequence if not (tuple(x) in seen or seen.add(tuple(x)))]
+ else:
+ return [x for x in sequence if not (x in seen or seen.add(x))]
+
+
+class PairedDataset(Dataset):
+ def __init__(self, examples, fields):
+ assert ('image' in fields)
+ assert ('text' in fields)
+ super(PairedDataset, self).__init__(examples, fields)
+ self.image_field = self.fields['image']
+ self.text_field = self.fields['text']
+
+ def image_set(self):
+ img_list = [e.image for e in self.examples]
+ image_set = unique(img_list)
+ examples = [Example.fromdict({'image': i}) for i in image_set]
+ dataset = Dataset(examples, {'image': self.image_field})
+ return dataset
+
+ def text_set(self):
+ text_list = [e.text for e in self.examples]
+ text_list = unique(text_list)
+ examples = [Example.fromdict({'text': t}) for t in text_list]
+ dataset = Dataset(examples, {'text': self.text_field})
+ return dataset
+
+ def image_dictionary(self, fields=None):
+ if not fields:
+ fields = self.fields
+ dataset = DictionaryDataset(self.examples, fields, key_fields='image')
+ return dataset
+
+ def text_dictionary(self, fields=None):
+ if not fields:
+ fields = self.fields
+ dataset = DictionaryDataset(self.examples, fields, key_fields='text')
+ return dataset
+
+ @property
+ def splits(self):
+ raise NotImplementedError
+
+
+class COCO(PairedDataset):
+ def __init__(self, image_field, text_field, img_root, ann_root, id_root=None, use_restval=True,
+ cut_validation=False):
+ roots = {}
+ roots['train'] = {
+ 'img': os.path.join(img_root, 'train2014'),
+ 'cap': os.path.join(ann_root, 'captions_train2014.json')
+ }
+ roots['val'] = {
+ 'img': os.path.join(img_root, 'val2014'),
+ 'cap': os.path.join(ann_root, 'captions_val2014.json')
+ }
+ roots['test'] = {
+ 'img': os.path.join(img_root, 'val2014'),
+ 'cap': os.path.join(ann_root, 'captions_val2014.json')
+ }
+ roots['trainrestval'] = {
+ 'img': (roots['train']['img'], roots['val']['img']),
+ 'cap': (roots['train']['cap'], roots['val']['cap'])
+ }
+
+ if id_root is not None:
+ ids = {}
+ ids['train'] = np.load(os.path.join(id_root, 'coco_train_ids.npy'))
+ ids['val'] = np.load(os.path.join(id_root, 'coco_dev_ids.npy'))
+ if cut_validation:
+ ids['val'] = ids['val'][:5000]
+ ids['test'] = np.load(os.path.join(id_root, 'coco_test_ids.npy'))
+ ids['trainrestval'] = (
+ ids['train'],
+ np.load(os.path.join(id_root, 'coco_restval_ids.npy')))
+
+ if use_restval:
+ roots['train'] = roots['trainrestval']
+ ids['train'] = ids['trainrestval']
+ else:
+ ids = None
+
+ with nostdout():
+ self.train_examples, self.val_examples, self.test_examples = self.get_samples(roots, ids)
+ examples = self.train_examples + self.val_examples + self.test_examples
+ super(COCO, self).__init__(examples, {'image': image_field, 'text': text_field})
+
+ @property
+ def splits(self):
+ train_split = PairedDataset(self.train_examples, self.fields)
+ val_split = PairedDataset(self.val_examples, self.fields)
+ test_split = PairedDataset(self.test_examples, self.fields)
+ return train_split, val_split, test_split
+
+ @classmethod
+ def get_samples(cls, roots, ids_dataset=None):
+ train_samples = []
+ val_samples = []
+ test_samples = []
+
+ for split in ['train', 'val', 'test']:
+ if isinstance(roots[split]['cap'], tuple):
+ coco_dataset = (pyCOCO(roots[split]['cap'][0]), pyCOCO(roots[split]['cap'][1]))
+ root = roots[split]['img']
+ else:
+ coco_dataset = (pyCOCO(roots[split]['cap']),)
+ root = (roots[split]['img'],)
+
+ if ids_dataset is None:
+                ids = list(coco_dataset[0].anns.keys())  # coco_dataset is always a tuple here
+ else:
+ ids = ids_dataset[split]
+
+ if isinstance(ids, tuple):
+ bp = len(ids[0])
+ ids = list(ids[0]) + list(ids[1])
+ else:
+ bp = len(ids)
+
+ for index in range(len(ids)):
+ if index < bp:
+ coco = coco_dataset[0]
+ img_root = root[0]
+ else:
+ coco = coco_dataset[1]
+ img_root = root[1]
+
+ ann_id = ids[index]
+ caption = coco.anns[ann_id]['caption']
+ img_id = coco.anns[ann_id]['image_id']
+ filename = coco.loadImgs(img_id)[0]['file_name']
+
+ example = Example.fromdict({'image': os.path.join(img_root, filename), 'text': caption})
+
+ if split == 'train':
+ train_samples.append(example)
+ elif split == 'val':
+ val_samples.append(example)
+ elif split == 'test':
+ test_samples.append(example)
+
+ return train_samples, val_samples, test_samples
+
diff --git a/data/example.py b/data/example.py
new file mode 100644
index 0000000..d46c07f
--- /dev/null
+++ b/data/example.py
@@ -0,0 +1,27 @@
+
+class Example(object):
+ """Defines a single training or test example.
+ Stores each column of the example as an attribute.
+ """
+ @classmethod
+ def fromdict(cls, data):
+ ex = cls(data)
+ return ex
+
+ def __init__(self, data):
+ for key, val in data.items():
+ super(Example, self).__setattr__(key, val)
+
+ def __setattr__(self, key, value):
+ raise AttributeError
+
+ def __hash__(self):
+ return hash(tuple(x for x in self.__dict__.values()))
+
+ def __eq__(self, other):
+ this = tuple(x for x in self.__dict__.values())
+ other = tuple(x for x in other.__dict__.values())
+ return this == other
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
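+
+
+# Illustrative behaviour (added for clarity, not executed anywhere):
+#   e = Example.fromdict({'image': 'COCO_val2014_000000000042.jpg', 'text': 'a caption'})
+#   e.image, e.text        # attribute access
+#   e.image = 'other.jpg'  # raises AttributeError: examples are immutable once built
+#   hash(e), e == other    # hashable and comparable, so DictionaryDataset can group captions per image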
diff --git a/data/field.py b/data/field.py
new file mode 100644
index 0000000..5173096
--- /dev/null
+++ b/data/field.py
@@ -0,0 +1,329 @@
+# coding: utf8
+from collections import Counter, OrderedDict
+from torch.utils.data.dataloader import default_collate
+from itertools import chain
+import six
+import torch
+import numpy as np
+import h5py
+import os
+import warnings
+import shutil
+
+from .dataset import Dataset
+from .vocab import Vocab
+from .utils import get_tokenizer
+
+
+class RawField(object):
+ """ Defines a general datatype.
+
+ Every dataset consists of one or more types of data. For instance,
+ a machine translation dataset contains paired examples of text, while
+ an image captioning dataset contains images and texts.
+ Each of these types of data is represented by a RawField object.
+ An RawField object does not assume any property of the data type and
+ it holds parameters relating to how a datatype should be processed.
+
+ Attributes:
+ preprocessing: The Pipeline that will be applied to examples
+ using this field before creating an example.
+ Default: None.
+ postprocessing: A Pipeline that will be applied to a list of examples
+ using this field before assigning to a batch.
+ Function signature: (batch(list)) -> object
+ Default: None.
+ """
+
+ def __init__(self, preprocessing=None, postprocessing=None):
+ self.preprocessing = preprocessing
+ self.postprocessing = postprocessing
+
+ def preprocess(self, x):
+ """ Preprocess an example if the `preprocessing` Pipeline is provided. """
+ if self.preprocessing is not None:
+ return self.preprocessing(x)
+ else:
+ return x
+
+ def process(self, batch, *args, **kwargs):
+ """ Process a list of examples to create a batch.
+
+ Postprocess the batch with user-provided Pipeline.
+
+ Args:
+ batch (list(object)): A list of object from a batch of examples.
+ Returns:
+ object: Processed object given the input and custom
+ postprocessing Pipeline.
+ """
+ if self.postprocessing is not None:
+ batch = self.postprocessing(batch)
+ return default_collate(batch)
+
+
+class Merge(RawField):
+ def __init__(self, *fields):
+ super(Merge, self).__init__()
+ self.fields = fields
+
+ def preprocess(self, x):
+ return tuple(f.preprocess(x) for f in self.fields)
+
+ def process(self, batch, *args, **kwargs):
+ if len(self.fields) == 1:
+ batch = [batch, ]
+ else:
+ batch = list(zip(*batch))
+
+ out = list(f.process(b, *args, **kwargs) for f, b in zip(self.fields, batch))
+ return out
+
+class ImageDetectionsField(RawField):
+ def __init__(self, preprocessing=None, postprocessing=None, detections_path=None, max_detections=100,
+ sort_by_prob=False, load_in_tmp=True):
+ self.max_detections = max_detections
+ self.detections_path = detections_path
+ self.sort_by_prob = sort_by_prob
+
+ tmp_detections_path = os.path.join('/tmp', os.path.basename(detections_path))
+
+ if load_in_tmp:
+ if not os.path.isfile(tmp_detections_path):
+ if shutil.disk_usage("/tmp")[-1] < os.path.getsize(detections_path):
+                    warnings.warn('Loading from %s, because /tmp does not have enough space.' % detections_path)
+ else:
+ warnings.warn("Copying detection file to /tmp")
+ shutil.copyfile(detections_path, tmp_detections_path)
+ warnings.warn("Done.")
+ self.detections_path = tmp_detections_path
+ else:
+ self.detections_path = tmp_detections_path
+
+ super(ImageDetectionsField, self).__init__(preprocessing, postprocessing)
+
+ def preprocess(self, x, avoid_precomp=False):
+ image_id = int(x.split('_')[-1].split('.')[0])
+ try:
+ f = h5py.File(self.detections_path, 'r')
+# precomp_data = f['%d_features' % image_id][()]
+ precomp_data = f['%d_grids' % image_id][()]
+ if self.sort_by_prob:
+ precomp_data = precomp_data[np.argsort(np.max(f['%d_cls_prob' % image_id][()], -1))[::-1]]
+ except KeyError:
+ warnings.warn('Could not find detections for %d' % image_id)
+ precomp_data = np.random.rand(10,2048)
+
+ delta = self.max_detections - precomp_data.shape[0]
+ if delta > 0:
+ precomp_data = np.concatenate([precomp_data, np.zeros((delta, precomp_data.shape[1]))], axis=0)
+ elif delta < 0:
+ precomp_data = precomp_data[:self.max_detections]
+
+ return precomp_data.astype(np.float32)
+
+
+class TextField(RawField):
+ vocab_cls = Vocab
+ # Dictionary mapping PyTorch tensor dtypes to the appropriate Python
+ # numeric type.
+ dtypes = {
+ torch.float32: float,
+ torch.float: float,
+ torch.float64: float,
+ torch.double: float,
+ torch.float16: float,
+ torch.half: float,
+
+ torch.uint8: int,
+ torch.int8: int,
+ torch.int16: int,
+ torch.short: int,
+ torch.int32: int,
+ torch.int: int,
+ torch.int64: int,
+ torch.long: int,
+ }
+ punctuations = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-", \
+ ".", "?", "!", ",", ":", "-", "--", "...", ";"]
+
+ def __init__(self, use_vocab=True, init_token=None, eos_token=None, fix_length=None, dtype=torch.long,
+ preprocessing=None, postprocessing=None, lower=False, tokenize=(lambda s: s.split()),
+                 remove_punctuation=False, include_lengths=False, batch_first=True, pad_token="<pad>",
+                 unk_token="<unk>", pad_first=False, truncate_first=False, vectors=None, nopoints=True):
+ self.use_vocab = use_vocab
+ self.init_token = init_token
+ self.eos_token = eos_token
+ self.fix_length = fix_length
+ self.dtype = dtype
+ self.lower = lower
+ self.tokenize = get_tokenizer(tokenize)
+ self.remove_punctuation = remove_punctuation
+ self.include_lengths = include_lengths
+ self.batch_first = batch_first
+ self.pad_token = pad_token
+ self.unk_token = unk_token
+ self.pad_first = pad_first
+ self.truncate_first = truncate_first
+ self.vocab = None
+ self.vectors = vectors
+ if nopoints:
+ self.punctuations.append("..")
+
+ super(TextField, self).__init__(preprocessing, postprocessing)
+
+ def preprocess(self, x):
+ if six.PY2 and isinstance(x, six.string_types) and not isinstance(x, six.text_type):
+ x = six.text_type(x, encoding='utf-8')
+ if self.lower:
+ x = six.text_type.lower(x)
+ x = self.tokenize(x.rstrip('\n'))
+ if self.remove_punctuation:
+ x = [w for w in x if w not in self.punctuations]
+ if self.preprocessing is not None:
+ return self.preprocessing(x)
+ else:
+ return x
+
+ def process(self, batch, device=None):
+ padded = self.pad(batch)
+ tensor = self.numericalize(padded, device=device)
+ return tensor
+
+ def build_vocab(self, *args, **kwargs):
+ counter = Counter()
+ sources = []
+ for arg in args:
+ if isinstance(arg, Dataset):
+ sources += [getattr(arg, name) for name, field in arg.fields.items() if field is self]
+ else:
+ sources.append(arg)
+
+ for data in sources:
+ for x in data:
+ x = self.preprocess(x)
+ try:
+ counter.update(x)
+ except TypeError:
+ counter.update(chain.from_iterable(x))
+
+ specials = list(OrderedDict.fromkeys([
+ tok for tok in [self.unk_token, self.pad_token, self.init_token,
+ self.eos_token]
+ if tok is not None]))
+ self.vocab = self.vocab_cls(counter, specials=specials, **kwargs)
+
+ def pad(self, minibatch):
+ """Pad a batch of examples using this field.
+ Pads to self.fix_length if provided, otherwise pads to the length of
+ the longest example in the batch. Prepends self.init_token and appends
+ self.eos_token if those attributes are not None. Returns a tuple of the
+ padded list and a list containing lengths of each example if
+ `self.include_lengths` is `True`, else just
+ returns the padded list.
+ """
+ minibatch = list(minibatch)
+ if self.fix_length is None:
+ max_len = max(len(x) for x in minibatch)
+ else:
+ max_len = self.fix_length + (
+ self.init_token, self.eos_token).count(None) - 2
+ padded, lengths = [], []
+ for x in minibatch:
+ if self.pad_first:
+ padded.append(
+ [self.pad_token] * max(0, max_len - len(x)) +
+ ([] if self.init_token is None else [self.init_token]) +
+ list(x[-max_len:] if self.truncate_first else x[:max_len]) +
+ ([] if self.eos_token is None else [self.eos_token]))
+ else:
+ padded.append(
+ ([] if self.init_token is None else [self.init_token]) +
+ list(x[-max_len:] if self.truncate_first else x[:max_len]) +
+ ([] if self.eos_token is None else [self.eos_token]) +
+ [self.pad_token] * max(0, max_len - len(x)))
+ lengths.append(len(padded[-1]) - max(0, max_len - len(x)))
+ if self.include_lengths:
+ return padded, lengths
+ return padded
+
+ def numericalize(self, arr, device=None):
+ """Turn a batch of examples that use this field into a list of Variables.
+ If the field has include_lengths=True, a tensor of lengths will be
+ included in the return value.
+ Arguments:
+ arr (List[List[str]], or tuple of (List[List[str]], List[int])):
+ List of tokenized and padded examples, or tuple of List of
+ tokenized and padded examples and List of lengths of each
+ example if self.include_lengths is True.
+ device (str or torch.device): A string or instance of `torch.device`
+ specifying which device the Variables are going to be created on.
+ If left as default, the tensors will be created on cpu. Default: None.
+ """
+ if self.include_lengths and not isinstance(arr, tuple):
+ raise ValueError("Field has include_lengths set to True, but "
+ "input data is not a tuple of "
+ "(data batch, batch lengths).")
+ if isinstance(arr, tuple):
+ arr, lengths = arr
+ lengths = torch.tensor(lengths, dtype=self.dtype, device=device)
+
+ if self.use_vocab:
+ arr = [[self.vocab.stoi[x] for x in ex] for ex in arr]
+
+ if self.postprocessing is not None:
+ arr = self.postprocessing(arr, self.vocab)
+
+ var = torch.tensor(arr, dtype=self.dtype, device=device)
+ else:
+ if self.vectors:
+ arr = [[self.vectors[x] for x in ex] for ex in arr]
+ if self.dtype not in self.dtypes:
+ raise ValueError(
+ "Specified Field dtype {} can not be used with "
+ "use_vocab=False because we do not know how to numericalize it. "
+ "Please raise an issue at "
+ "https://github.com/pytorch/text/issues".format(self.dtype))
+ numericalization_func = self.dtypes[self.dtype]
+ # It doesn't make sense to explictly coerce to a numeric type if
+ # the data is sequential, since it's unclear how to coerce padding tokens
+ # to a numeric type.
+ arr = [numericalization_func(x) if isinstance(x, six.string_types)
+ else x for x in arr]
+
+ if self.postprocessing is not None:
+ arr = self.postprocessing(arr, None)
+
+ var = torch.cat([torch.cat([a.unsqueeze(0) for a in ar]).unsqueeze(0) for ar in arr])
+
+ # var = torch.tensor(arr, dtype=self.dtype, device=device)
+ if not self.batch_first:
+ var.t_()
+ var = var.contiguous()
+
+ if self.include_lengths:
+ return var, lengths
+ return var
+
+ def decode(self, word_idxs, join_words=True):
+ if isinstance(word_idxs, list) and len(word_idxs) == 0:
+ return self.decode([word_idxs, ], join_words)[0]
+ if isinstance(word_idxs, list) and isinstance(word_idxs[0], int):
+ return self.decode([word_idxs, ], join_words)[0]
+ elif isinstance(word_idxs, np.ndarray) and word_idxs.ndim == 1:
+ return self.decode(word_idxs.reshape((1, -1)), join_words)[0]
+ elif isinstance(word_idxs, torch.Tensor) and word_idxs.ndimension() == 1:
+ return self.decode(word_idxs.unsqueeze(0), join_words)[0]
+
+ captions = []
+ for wis in word_idxs:
+ caption = []
+ for wi in wis:
+ word = self.vocab.itos[int(wi)]
+ if word == self.eos_token:
+ break
+ caption.append(word)
+ if join_words:
+ caption = ' '.join(caption)
+ captions.append(caption)
+ return captions
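+
+
+# End-to-end sketch of the text pipeline (illustrative; the argument values are assumptions):
+#   field = TextField(init_token='<bos>', eos_token='<eos>', lower=True, tokenize='spacy',
+#                     remove_punctuation=True, nopoints=False)
+#   field.build_vocab(train_dataset, min_freq=5)
+#   tokens = field.preprocess('A man riding a wave on a surfboard.')
+#   batch  = field.process([tokens])   # LongTensor of shape (1, sequence_length)
+#   back   = field.decode(batch)       # list with one decoded caption string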
diff --git a/data/utils.py b/data/utils.py
new file mode 100644
index 0000000..e1dc21d
--- /dev/null
+++ b/data/utils.py
@@ -0,0 +1,91 @@
+import contextlib, sys
+
+class DummyFile(object):
+ def write(self, x): pass
+
+@contextlib.contextmanager
+def nostdout():
+ save_stdout = sys.stdout
+ sys.stdout = DummyFile()
+ yield
+ sys.stdout = save_stdout
+
+def reporthook(t):
+ """https://github.com/tqdm/tqdm"""
+ last_b = [0]
+
+ def inner(b=1, bsize=1, tsize=None):
+ """
+        b: int, optional
+ Number of blocks just transferred [default: 1].
+ bsize: int, optional
+ Size of each block (in tqdm units) [default: 1].
+ tsize: int, optional
+ Total size (in tqdm units). If [default: None] remains unchanged.
+ """
+ if tsize is not None:
+ t.total = tsize
+ t.update((b - last_b[0]) * bsize)
+ last_b[0] = b
+ return inner
+
+import revtok
+import spacy
+spacy_en = spacy.load('en_core_web_md')
+
+
+def _getTokenizerrevtok(x):
+ return revtok.tokenize(x, decap=True)
+
+def _getTokenizerspacy(s):
+ return [tok.text for tok in spacy_en.tokenizer(s)]
+
+
+def get_tokenizer(tokenizer):
+ if callable(tokenizer):
+ return tokenizer
+ if tokenizer == "spacy":
+ try:
+ return _getTokenizerspacy
+ except ImportError:
+ print("Please install SpaCy and the SpaCy English tokenizer. "
+ "See the docs at https://spacy.io for more information.")
+ raise
+ except AttributeError:
+ print("Please install SpaCy and the SpaCy English tokenizer. "
+ "See the docs at https://spacy.io for more information.")
+ raise
+ elif tokenizer == "moses":
+ try:
+ from nltk.tokenize.moses import MosesTokenizer
+ moses_tokenizer = MosesTokenizer()
+ return moses_tokenizer.tokenize
+ except ImportError:
+ print("Please install NLTK. "
+ "See the docs at http://nltk.org for more information.")
+ raise
+ except LookupError:
+ print("Please install the necessary NLTK corpora. "
+ "See the docs at http://nltk.org for more information.")
+ raise
+ elif tokenizer == 'revtok':
+ try:
+ import revtok
+ return revtok.tokenize
+ except ImportError:
+ print("Please install revtok.")
+ raise
+ elif tokenizer == 'subword':
+ try:
+ return _getTokenizerrevtok
+ except ImportError:
+ print("Please install revtok.")
+ raise
+ raise ValueError("Requested tokenizer {}, valid choices are a "
+ "callable that takes a single string as input, "
+ "\"revtok\" for the revtok reversible tokenizer, "
+ "\"subword\" for the revtok caps-aware tokenizer, "
+ "\"spacy\" for the SpaCy English tokenizer, or "
+ "\"moses\" for the NLTK port of the Moses tokenization "
+ "script.".format(tokenizer))
+
diff --git a/data/vocab.py b/data/vocab.py
new file mode 100644
index 0000000..7614248
--- /dev/null
+++ b/data/vocab.py
@@ -0,0 +1,372 @@
+from __future__ import unicode_literals
+import array
+from collections import defaultdict
+from functools import partial
+import io
+import logging
+import os
+import zipfile
+
+import six
+from six.moves.urllib.request import urlretrieve
+import torch
+from tqdm import tqdm
+import tarfile
+
+from .utils import reporthook
+
+logger = logging.getLogger(__name__)
+
+
+class Vocab(object):
+ """Defines a vocabulary object that will be used to numericalize a field.
+
+ Attributes:
+ freqs: A collections.Counter object holding the frequencies of tokens
+ in the data used to build the Vocab.
+ stoi: A collections.defaultdict instance mapping token strings to
+ numerical identifiers.
+ itos: A list of token strings indexed by their numerical identifiers.
+ """
+    def __init__(self, counter, max_size=None, min_freq=1, specials=['<unk>'],
+ vectors=None, unk_init=None, vectors_cache=None):
+ """Create a Vocab object from a collections.Counter.
+
+ Arguments:
+ counter: collections.Counter object holding the frequencies of
+ each value found in the data.
+ max_size: The maximum size of the vocabulary, or None for no
+ maximum. Default: None.
+ min_freq: The minimum frequency needed to include a token in the
+ vocabulary. Values less than 1 will be set to 1. Default: 1.
+            specials: The list of special tokens (e.g., padding or eos) that
+                will be prepended to the vocabulary in addition to an <unk>
+                token. Default: ['<unk>']
+ vectors: One of either the available pretrained vectors
+ or custom pretrained vectors (see Vocab.load_vectors);
+ or a list of aforementioned vectors
+ unk_init (callback): by default, initialize out-of-vocabulary word vectors
+ to zero vectors; can be any function that takes in a Tensor and
+ returns a Tensor of the same size. Default: torch.Tensor.zero_
+ vectors_cache: directory for cached vectors. Default: '.vector_cache'
+ """
+ self.freqs = counter
+ counter = counter.copy()
+ min_freq = max(min_freq, 1)
+
+ self.itos = list(specials)
+ # frequencies of special tokens are not counted when building vocabulary
+ # in frequency order
+ for tok in specials:
+ del counter[tok]
+
+ max_size = None if max_size is None else max_size + len(self.itos)
+
+ # sort by frequency, then alphabetically
+ words_and_frequencies = sorted(counter.items(), key=lambda tup: tup[0])
+ words_and_frequencies.sort(key=lambda tup: tup[1], reverse=True)
+
+ for word, freq in words_and_frequencies:
+ if freq < min_freq or len(self.itos) == max_size:
+ break
+ self.itos.append(word)
+
+ self.stoi = defaultdict(_default_unk_index)
+ # stoi is simply a reverse dict for itos
+ self.stoi.update({tok: i for i, tok in enumerate(self.itos)})
+
+ self.vectors = None
+ if vectors is not None:
+ self.load_vectors(vectors, unk_init=unk_init, cache=vectors_cache)
+ else:
+ assert unk_init is None and vectors_cache is None
+
+ def __eq__(self, other):
+ if self.freqs != other.freqs:
+ return False
+ if self.stoi != other.stoi:
+ return False
+ if self.itos != other.itos:
+ return False
+ if self.vectors != other.vectors:
+ return False
+ return True
+
+ def __len__(self):
+ return len(self.itos)
+
+ def extend(self, v, sort=False):
+ words = sorted(v.itos) if sort else v.itos
+ for w in words:
+ if w not in self.stoi:
+ self.itos.append(w)
+ self.stoi[w] = len(self.itos) - 1
+
+ def load_vectors(self, vectors, **kwargs):
+ """
+ Arguments:
+ vectors: one of or a list containing instantiations of the
+ GloVe, CharNGram, or Vectors classes. Alternatively, one
+ of or a list of available pretrained vectors:
+ charngram.100d
+ fasttext.en.300d
+ fasttext.simple.300d
+ glove.42B.300d
+ glove.840B.300d
+ glove.twitter.27B.25d
+ glove.twitter.27B.50d
+ glove.twitter.27B.100d
+ glove.twitter.27B.200d
+ glove.6B.50d
+ glove.6B.100d
+ glove.6B.200d
+ glove.6B.300d
+ Remaining keyword arguments: Passed to the constructor of Vectors classes.
+ """
+ if not isinstance(vectors, list):
+ vectors = [vectors]
+ for idx, vector in enumerate(vectors):
+ if six.PY2 and isinstance(vector, str):
+ vector = six.text_type(vector)
+ if isinstance(vector, six.string_types):
+ # Convert the string pretrained vector identifier
+ # to a Vectors object
+ if vector not in pretrained_aliases:
+ raise ValueError(
+ "Got string input vector {}, but allowed pretrained "
+ "vectors are {}".format(
+ vector, list(pretrained_aliases.keys())))
+ vectors[idx] = pretrained_aliases[vector](**kwargs)
+ elif not isinstance(vector, Vectors):
+ raise ValueError(
+ "Got input vectors of type {}, expected str or "
+ "Vectors object".format(type(vector)))
+
+ tot_dim = sum(v.dim for v in vectors)
+ self.vectors = torch.Tensor(len(self), tot_dim)
+ for i, token in enumerate(self.itos):
+ start_dim = 0
+ for v in vectors:
+ end_dim = start_dim + v.dim
+ self.vectors[i][start_dim:end_dim] = v[token.strip()]
+ start_dim = end_dim
+ assert(start_dim == tot_dim)
+
+ def set_vectors(self, stoi, vectors, dim, unk_init=torch.Tensor.zero_):
+ """
+ Set the vectors for the Vocab instance from a collection of Tensors.
+
+ Arguments:
+ stoi: A dictionary of string to the index of the associated vector
+ in the `vectors` input argument.
+ vectors: An indexed iterable (or other structure supporting __getitem__) that
+ given an input index, returns a FloatTensor representing the vector
+ for the token associated with the index. For example,
+ vector[stoi["string"]] should return the vector for "string".
+ dim: The dimensionality of the vectors.
+ unk_init (callback): by default, initialize out-of-vocabulary word vectors
+ to zero vectors; can be any function that takes in a Tensor and
+ returns a Tensor of the same size. Default: torch.Tensor.zero_
+ """
+ self.vectors = torch.Tensor(len(self), dim)
+ for i, token in enumerate(self.itos):
+ wv_index = stoi.get(token, None)
+ if wv_index is not None:
+ self.vectors[i] = vectors[wv_index]
+ else:
+ self.vectors[i] = unk_init(self.vectors[i])
+
+
+class Vectors(object):
+
+ def __init__(self, name, cache=None,
+ url=None, unk_init=None):
+ """
+ Arguments:
+ name: name of the file that contains the vectors
+ cache: directory for cached vectors
+ url: url for download if vectors not found in cache
+ unk_init (callback): by default, initalize out-of-vocabulary word vectors
+ to zero vectors; can be any function that takes in a Tensor and
+ returns a Tensor of the same size
+ """
+ cache = '.vector_cache' if cache is None else cache
+ self.unk_init = torch.Tensor.zero_ if unk_init is None else unk_init
+ self.cache(name, cache, url=url)
+
+ def __getitem__(self, token):
+ if token in self.stoi:
+ return self.vectors[self.stoi[token]]
+ else:
+ return self.unk_init(torch.Tensor(self.dim)) # self.unk_init(torch.Tensor(1, self.dim))
+
+ def cache(self, name, cache, url=None):
+ if os.path.isfile(name):
+ path = name
+ path_pt = os.path.join(cache, os.path.basename(name)) + '.pt'
+ else:
+ path = os.path.join(cache, name)
+ path_pt = path + '.pt'
+
+ if not os.path.isfile(path_pt):
+ if not os.path.isfile(path) and url:
+ logger.info('Downloading vectors from {}'.format(url))
+ if not os.path.exists(cache):
+ os.makedirs(cache)
+ dest = os.path.join(cache, os.path.basename(url))
+ if not os.path.isfile(dest):
+ with tqdm(unit='B', unit_scale=True, miniters=1, desc=dest) as t:
+ try:
+ urlretrieve(url, dest, reporthook=reporthook(t))
+ except KeyboardInterrupt as e: # remove the partial zip file
+ os.remove(dest)
+ raise e
+ logger.info('Extracting vectors into {}'.format(cache))
+ ext = os.path.splitext(dest)[1][1:]
+ if ext == 'zip':
+ with zipfile.ZipFile(dest, "r") as zf:
+ zf.extractall(cache)
+ elif ext == 'gz':
+ with tarfile.open(dest, 'r:gz') as tar:
+ tar.extractall(path=cache)
+ if not os.path.isfile(path):
+ raise RuntimeError('no vectors found at {}'.format(path))
+
+ # str call is necessary for Python 2/3 compatibility, since
+ # argument must be Python 2 str (Python 3 bytes) or
+ # Python 3 str (Python 2 unicode)
+ itos, vectors, dim = [], array.array(str('d')), None
+
+ # Try to read the whole file with utf-8 encoding.
+ binary_lines = False
+ try:
+ with io.open(path, encoding="utf8") as f:
+ lines = [line for line in f]
+ # If there are malformed lines, read in binary mode
+ # and manually decode each word from utf-8
+ except:
+ logger.warning("Could not read {} as UTF8 file, "
+ "reading file as bytes and skipping "
+ "words with malformed UTF8.".format(path))
+ with open(path, 'rb') as f:
+ lines = [line for line in f]
+ binary_lines = True
+
+ logger.info("Loading vectors from {}".format(path))
+ for line in tqdm(lines, total=len(lines)):
+ # Explicitly splitting on " " is important, so we don't
+ # get rid of Unicode non-breaking spaces in the vectors.
+ entries = line.rstrip().split(b" " if binary_lines else " ")
+
+ word, entries = entries[0], entries[1:]
+ if dim is None and len(entries) > 1:
+ dim = len(entries)
+ elif len(entries) == 1:
+ logger.warning("Skipping token {} with 1-dimensional "
+ "vector {}; likely a header".format(word, entries))
+ continue
+ elif dim != len(entries):
+ raise RuntimeError(
+ "Vector for token {} has {} dimensions, but previously "
+ "read vectors have {} dimensions. All vectors must have "
+ "the same number of dimensions.".format(word, len(entries), dim))
+
+ if binary_lines:
+ try:
+ if isinstance(word, six.binary_type):
+ word = word.decode('utf-8')
+ except:
+ logger.info("Skipping non-UTF8 token {}".format(repr(word)))
+ continue
+ vectors.extend(float(x) for x in entries)
+ itos.append(word)
+
+ self.itos = itos
+ self.stoi = {word: i for i, word in enumerate(itos)}
+ self.vectors = torch.Tensor(vectors).view(-1, dim)
+ self.dim = dim
+ logger.info('Saving vectors to {}'.format(path_pt))
+ if not os.path.exists(cache):
+ os.makedirs(cache)
+ torch.save((self.itos, self.stoi, self.vectors, self.dim), path_pt)
+ else:
+ logger.info('Loading vectors from {}'.format(path_pt))
+ self.itos, self.stoi, self.vectors, self.dim = torch.load(path_pt)
+
+
+class GloVe(Vectors):
+ url = {
+ '42B': 'http://nlp.stanford.edu/data/glove.42B.300d.zip',
+ '840B': 'http://nlp.stanford.edu/data/glove.840B.300d.zip',
+ 'twitter.27B': 'http://nlp.stanford.edu/data/glove.twitter.27B.zip',
+ '6B': 'http://nlp.stanford.edu/data/glove.6B.zip',
+ }
+
+ def __init__(self, name='840B', dim=300, **kwargs):
+ url = self.url[name]
+ name = 'glove.{}.{}d.txt'.format(name, str(dim))
+ super(GloVe, self).__init__(name, url=url, **kwargs)
+
+
+class FastText(Vectors):
+
+ url_base = 'https://s3-us-west-1.amazonaws.com/fasttext-vectors/wiki.{}.vec'
+
+ def __init__(self, language="en", **kwargs):
+ url = self.url_base.format(language)
+ name = os.path.basename(url)
+ super(FastText, self).__init__(name, url=url, **kwargs)
+
+
+class CharNGram(Vectors):
+
+ name = 'charNgram.txt'
+ url = ('http://www.logos.t.u-tokyo.ac.jp/~hassy/publications/arxiv2016jmt/'
+ 'jmt_pre-trained_embeddings.tar.gz')
+
+ def __init__(self, **kwargs):
+ super(CharNGram, self).__init__(self.name, url=self.url, **kwargs)
+
+ def __getitem__(self, token):
+ vector = torch.Tensor(1, self.dim).zero_()
+        if token == "<unk>":
+ return self.unk_init(vector)
+ # These literals need to be coerced to unicode for Python 2 compatibility
+ # when we try to join them with read ngrams from the files.
+ chars = ['#BEGIN#'] + list(token) + ['#END#']
+ num_vectors = 0
+ for n in [2, 3, 4]:
+ end = len(chars) - n + 1
+ grams = [chars[i:(i + n)] for i in range(end)]
+ for gram in grams:
+ gram_key = '{}gram-{}'.format(n, ''.join(gram))
+ if gram_key in self.stoi:
+ vector += self.vectors[self.stoi[gram_key]]
+ num_vectors += 1
+ if num_vectors > 0:
+ vector /= num_vectors
+ else:
+ vector = self.unk_init(vector)
+ return vector
+
+
+def _default_unk_index():
+ return 0
+
+
+pretrained_aliases = {
+ "charngram.100d": partial(CharNGram),
+ "fasttext.en.300d": partial(FastText, language="en"),
+ "fasttext.simple.300d": partial(FastText, language="simple"),
+ "glove.42B.300d": partial(GloVe, name="42B", dim="300"),
+ "glove.840B.300d": partial(GloVe, name="840B", dim="300"),
+ "glove.twitter.27B.25d": partial(GloVe, name="twitter.27B", dim="25"),
+ "glove.twitter.27B.50d": partial(GloVe, name="twitter.27B", dim="50"),
+ "glove.twitter.27B.100d": partial(GloVe, name="twitter.27B", dim="100"),
+ "glove.twitter.27B.200d": partial(GloVe, name="twitter.27B", dim="200"),
+ "glove.6B.50d": partial(GloVe, name="6B", dim="50"),
+ "glove.6B.100d": partial(GloVe, name="6B", dim="100"),
+ "glove.6B.200d": partial(GloVe, name="6B", dim="200"),
+ "glove.6B.300d": partial(GloVe, name="6B", dim="300")
+}
+"""Mapping from string name to factory function"""
diff --git a/environment.yaml b/environment.yaml
new file mode 100644
index 0000000..3990d3a
--- /dev/null
+++ b/environment.yaml
@@ -0,0 +1,194 @@
+name: ic
+channels:
+ - defaults
+dependencies:
+ - _libgcc_mutex=0.1=main
+ - _openmp_mutex=4.5=1_gnu
+ - argcomplete=1.12.3=pyhd3eb1b0_0
+ - backcall=0.2.0=pyhd3eb1b0_0
+ - ca-certificates=2021.7.5=h06a4308_1
+ - certifi=2021.5.30=py37h06a4308_0
+ - debugpy=1.4.1=py37h295c915_0
+ - decorator=5.0.9=pyhd3eb1b0_0
+ - entrypoints=0.3=py37_0
+ - importlib-metadata=3.10.0=py37h06a4308_0
+ - importlib_metadata=3.10.0=hd3eb1b0_0
+ - ipykernel=6.2.0=py37h06a4308_1
+ - ipython=7.26.0=py37hb070fc8_0
+ - ipython_genutils=0.2.0=pyhd3eb1b0_1
+ - jedi=0.18.0=py37h06a4308_1
+ - jupyter_client=7.0.1=pyhd3eb1b0_0
+ - jupyter_core=4.7.1=py37h06a4308_0
+ - ld_impl_linux-64=2.35.1=h7274673_9
+ - libffi=3.3=he6710b0_2
+ - libgcc-ng=9.3.0=h5101ec6_17
+ - libgomp=9.3.0=h5101ec6_17
+ - libsodium=1.0.18=h7b6447c_0
+ - libstdcxx-ng=9.3.0=hd4cf53a_17
+ - matplotlib-inline=0.1.2=pyhd3eb1b0_2
+ - ncurses=6.2=he6710b0_1
+ - nest-asyncio=1.5.1=pyhd3eb1b0_0
+ - openssl=1.1.1l=h7f8727e_0
+ - parso=0.8.2=pyhd3eb1b0_0
+ - pexpect=4.8.0=pyhd3eb1b0_3
+ - pickleshare=0.7.5=pyhd3eb1b0_1003
+ - pip=21.0.1=py37h06a4308_0
+ - prompt-toolkit=3.0.17=pyhca03da5_0
+ - ptyprocess=0.7.0=pyhd3eb1b0_2
+ - pygments=2.10.0=pyhd3eb1b0_0
+ - python=3.7.11=h12debd9_0
+ - python-dateutil=2.8.2=pyhd3eb1b0_0
+ - pyzmq=22.2.1=py37h295c915_1
+ - readline=8.1=h27cfd23_0
+ - setuptools=52.0.0=py37h06a4308_0
+ - six=1.16.0=pyhd3eb1b0_0
+ - sqlite=3.36.0=hc218d9a_0
+ - tk=8.6.10=hbc83047_0
+ - tornado=6.1=py37h27cfd23_0
+ - traitlets=5.0.5=pyhd3eb1b0_0
+ - typing_extensions=3.10.0.0=pyhca03da5_0
+ - wcwidth=0.2.5=pyhd3eb1b0_0
+ - wheel=0.37.0=pyhd3eb1b0_1
+ - xz=5.2.5=h7b6447c_0
+ - zeromq=4.3.4=h2531618_0
+ - zipp=3.5.0=pyhd3eb1b0_0
+ - zlib=1.2.11=h7b6447c_3
+ - pip:
+ - absl-py==0.13.0
+ - addict==2.4.0
+ - attrs==21.2.0
+ - av==8.1.0
+ - bert-serving-client==1.10.0
+ - bert-serving-server==1.10.0
+ - blessings==1.7
+ - blis==0.7.4
+ - boto3==1.18.48
+ - botocore==1.21.48
+ - cached-property==1.5.2
+ - cachetools==4.2.2
+ - catalogue==1.0.0
+ - charset-normalizer==2.0.4
+ - click==7.1.2
+ - cycler==0.10.0
+ - cymem==2.0.5
+ - cython==0.29.24
+ - distlib==0.3.4
+ - easydict==1.9
+ - einops==0.4.0
+ - en-core-web-md==2.3.1
+ - en-core-web-sm==3.1.0
+ - en-vectors-web-lg==2.1.0
+ - filelock==3.6.0
+ - future==0.18.2
+ - gensim==4.1.2
+ - google-auth==1.35.0
+ - google-auth-oauthlib==0.4.6
+ - gpustat==0.6.0
+ - gputil==1.4.0
+ - grad-cam==1.3.7
+ - grpcio==1.40.0
+ - h5py==3.4.0
+ - huggingface-hub==0.0.17
+ - idna==3.2
+ - imageio==2.13.3
+ - iniconfig==1.1.1
+ - jinja2==3.0.1
+ - jmespath==0.10.0
+ - joblib==1.0.1
+ - jsonschema==4.1.0
+ - kiwisolver==1.3.2
+ - lmdb==1.2.1
+ - lxml==4.6.3
+ - lz4==3.1.3
+ - markdown==3.3.4
+ - markupsafe==2.0.1
+ - matplotlib==3.4.3
+ - middle==0.2.4
+ - mime==0.1.0
+ - mmcv==1.4.0
+ - msgpack==1.0.3
+ - murmurhash==1.0.5
+ - nbformat==5.1.3
+ - networkx==2.6.3
+ - nltk==3.6.7
+ - nose==1.3.7
+ - numpy==1.21.2
+ - nvidia-ml-py3==7.352.0
+ - oauthlib==3.1.1
+ - opencv-python==4.5.5.62
+ - packaging==21.0
+ - pandas==1.3.3
+ - pathy==0.6.0
+ - pillow==8.3.2
+ - plac==1.1.3
+ - platformdirs==2.5.1
+ - plotly==5.3.1
+ - pluggy==1.0.0
+ - preshed==3.0.5
+ - progressbar==2.5
+ - protobuf==3.17.3
+ - psutil==5.8.0
+ - py==1.11.0
+ - pyasn1==0.4.8
+ - pyasn1-modules==0.2.8
+ - pycocotools==2.0.2
+ - pydantic==1.8.2
+ - pyemd==0.5.1
+ - pyparsing==2.4.7
+ - pyrsistent==0.18.0
+ - pytest==7.1.0
+ - python-docx==0.8.11
+ - pytorch-transformers==1.2.0
+ - pytz==2021.1
+ - pywavelets==1.2.0
+ - pyyaml==5.4.1
+ - regex==2021.8.28
+ - requests==2.26.0
+ - requests-oauthlib==1.3.0
+ - revtok==0.0.3
+ - rsa==4.7.2
+ - s3transfer==0.5.0
+ - sacremoses==0.0.45
+ - scikit-image==0.19.0
+ - scikit-learn==1.0
+ - scipy==1.7.1
+ - seaborn==0.11.2
+ - sentencepiece==0.1.96
+ - sklearn==0.0
+ - smart-open==5.2.1
+ - spacy==2.3.7
+ - spacy-legacy==3.0.8
+ - spherecluster==0.1.7
+ - srsly==1.0.5
+ - tenacity==8.0.1
+ - tensorboard==2.6.0
+ - tensorboard-data-server==0.6.1
+ - tensorboard-logger==0.1.0
+ - tensorboard-plugin-wit==1.8.0
+ - tensorboardx==2.5
+ - termcolor==1.1.0
+ - thinc==7.4.5
+ - thop==0.0.31-2005241907
+ - threadpoolctl==3.0.0
+ - tifffile==2021.11.2
+ - timm==0.4.12
+ - tokenizers==0.10.3
+ - tomli==2.0.1
+ - torch==1.7.1+cu110
+ - torch-tb-profiler==0.2.1
+ - torchsummary==1.5.1
+ - torchtext==0.8.0
+ - torchvision==0.8.2
+ - tqdm==4.62.2
+ - transformers==4.10.2
+ - ttach==0.0.3
+ - typer==0.3.2
+ - typing-extensions==3.10.0.2
+ - ujson==5.1.0
+ - urllib3==1.26.6
+ - virtualenv==20.14.0
+ - visualize==0.5.1
+ - wasabi==0.8.2
+ - werkzeug==2.0.1
+ - yapf==0.32.0
+prefix: /home/zhanghaonan/anaconda3/envs/ic
diff --git a/evaluation/__init__.py b/evaluation/__init__.py
new file mode 100644
index 0000000..150e32d
--- /dev/null
+++ b/evaluation/__init__.py
@@ -0,0 +1,17 @@
+from .bleu import Bleu
+from .meteor import Meteor
+from .rouge import Rouge
+from .cider import Cider
+# from .spice import Spice
+from .tokenizer import PTBTokenizer
+
+def compute_scores(gts, gen):
+ metrics = (Bleu(), Meteor(), Rouge(), Cider())
+ all_score = {}
+ all_scores = {}
+ for metric in metrics:
+ score, scores = metric.compute_score(gts, gen)
+ all_score[str(metric)] = score
+ all_scores[str(metric)] = scores
+
+ return all_score, all_scores
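+
+
+# Illustrative call (captions should already be tokenized, e.g. with PTBTokenizer; values are assumptions):
+#   gts = {0: ['a man riding a wave on a surfboard', 'a surfer rides a wave']}
+#   gen = {0: ['a man on a surfboard riding a wave']}
+#   score, scores = compute_scores(gts, gen)
+#   score is keyed by str(metric), e.g. {'BLEU': [b1, b2, b3, b4], 'CIDEr': c, ...}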
diff --git a/evaluation/bleu/__init__.py b/evaluation/bleu/__init__.py
new file mode 100644
index 0000000..8000b20
--- /dev/null
+++ b/evaluation/bleu/__init__.py
@@ -0,0 +1 @@
+from .bleu import Bleu
\ No newline at end of file
diff --git a/evaluation/bleu/bleu.py b/evaluation/bleu/bleu.py
new file mode 100644
index 0000000..6d3139d
--- /dev/null
+++ b/evaluation/bleu/bleu.py
@@ -0,0 +1,46 @@
+#!/usr/bin/env python
+#
+# File Name : bleu.py
+#
+# Description : Wrapper for BLEU scorer.
+#
+# Creation Date : 06-01-2015
+# Last Modified : Thu 19 Mar 2015 09:13:28 PM PDT
+# Authors : Hao Fang and Tsung-Yi Lin
+
+from .bleu_scorer import BleuScorer
+
+
+class Bleu:
+ def __init__(self, n=4):
+        # default: compute BLEU score up to 4-grams
+ self._n = n
+ self._hypo_for_image = {}
+ self.ref_for_image = {}
+
+ def compute_score(self, gts, res):
+
+ assert(gts.keys() == res.keys())
+ imgIds = gts.keys()
+
+ bleu_scorer = BleuScorer(n=self._n)
+ for id in imgIds:
+ hypo = res[id]
+ ref = gts[id]
+
+ # Sanity check.
+ assert(type(hypo) is list)
+ assert(len(hypo) == 1)
+ assert(type(ref) is list)
+ assert(len(ref) >= 1)
+
+ bleu_scorer += (hypo[0], ref)
+
+ # score, scores = bleu_scorer.compute_score(option='shortest')
+ score, scores = bleu_scorer.compute_score(option='closest', verbose=0)
+ # score, scores = bleu_scorer.compute_score(option='average', verbose=1)
+
+ return score, scores
+
+ def __str__(self):
+ return 'BLEU'
diff --git a/evaluation/bleu/bleu_scorer.py b/evaluation/bleu/bleu_scorer.py
new file mode 100644
index 0000000..8047f46
--- /dev/null
+++ b/evaluation/bleu/bleu_scorer.py
@@ -0,0 +1,272 @@
+#!/usr/bin/env python
+
+# bleu_scorer.py
+# David Chiang
+
+# Copyright (c) 2004-2006 University of Maryland. All rights
+# reserved. Do not redistribute without permission from the
+# author. Not for commercial use.
+
+# Modified by:
+# Hao Fang
+# Tsung-Yi Lin
+
+''' Provides:
+cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test().
+cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked().
+'''
+
+import copy
+import sys, math, re
+from collections import defaultdict
+
+
+def precook(s, n=4, out=False):
+ """Takes a string as input and returns an object that can be given to
+ either cook_refs or cook_test. This is optional: cook_refs and cook_test
+ can take string arguments as well."""
+ words = s.split()
+ counts = defaultdict(int)
+ for k in range(1, n + 1):
+ for i in range(len(words) - k + 1):
+ ngram = tuple(words[i:i + k])
+ counts[ngram] += 1
+ return (len(words), counts)
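+
+# Worked example (added for clarity):
+#   precook('the cat sat on the mat', n=2) returns the sentence length 6 together with the n-gram counts
+#   {('the',): 2, ('cat',): 1, ('sat',): 1, ('on',): 1, ('mat',): 1,
+#    ('the', 'cat'): 1, ('cat', 'sat'): 1, ('sat', 'on'): 1, ('on', 'the'): 1, ('the', 'mat'): 1}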
+
+
+def cook_refs(refs, eff=None, n=4): ## lhuang: oracle will call with "average"
+ '''Takes a list of reference sentences for a single segment
+ and returns an object that encapsulates everything that BLEU
+ needs to know about them.'''
+
+ reflen = []
+ maxcounts = {}
+ for ref in refs:
+ rl, counts = precook(ref, n)
+ reflen.append(rl)
+ for (ngram, count) in counts.items():
+ maxcounts[ngram] = max(maxcounts.get(ngram, 0), count)
+
+ # Calculate effective reference sentence length.
+ if eff == "shortest":
+ reflen = min(reflen)
+ elif eff == "average":
+ reflen = float(sum(reflen)) / len(reflen)
+
+    ## lhuang: N.B.: leave reflen computation to the very end!!
+
+ ## lhuang: N.B.: in case of "closest", keep a list of reflens!! (bad design)
+
+ return (reflen, maxcounts)
+
+
+def cook_test(test, ref_tuple, eff=None, n=4):
+ '''Takes a test sentence and returns an object that
+ encapsulates everything that BLEU needs to know about it.'''
+
+ testlen, counts = precook(test, n, True)
+ reflen, refmaxcounts = ref_tuple
+
+ result = {}
+
+ # Calculate effective reference sentence length.
+
+ if eff == "closest":
+ result["reflen"] = min((abs(l - testlen), l) for l in reflen)[1]
+ else: ## i.e., "average" or "shortest" or None
+ result["reflen"] = reflen
+
+ result["testlen"] = testlen
+
+ result["guess"] = [max(0, testlen - k + 1) for k in range(1, n + 1)]
+
+ result['correct'] = [0] * n
+ for (ngram, count) in counts.items():
+ result["correct"][len(ngram) - 1] += min(refmaxcounts.get(ngram, 0), count)
+
+ return result
+
+
+class BleuScorer(object):
+ """Bleu scorer.
+ """
+
+ __slots__ = "n", "crefs", "ctest", "_score", "_ratio", "_testlen", "_reflen", "special_reflen"
+
+ # special_reflen is used in oracle (proportional effective ref len for a node).
+
+ def copy(self):
+ ''' copy the refs.'''
+ new = BleuScorer(n=self.n)
+ new.ctest = copy.copy(self.ctest)
+ new.crefs = copy.copy(self.crefs)
+ new._score = None
+ return new
+
+ def __init__(self, test=None, refs=None, n=4, special_reflen=None):
+ ''' singular instance '''
+
+ self.n = n
+ self.crefs = []
+ self.ctest = []
+ self.cook_append(test, refs)
+ self.special_reflen = special_reflen
+
+ def cook_append(self, test, refs):
+ '''called by constructor and __iadd__ to avoid creating new instances.'''
+
+ if refs is not None:
+ self.crefs.append(cook_refs(refs))
+ if test is not None:
+ cooked_test = cook_test(test, self.crefs[-1])
+ self.ctest.append(cooked_test) ## N.B.: -1
+ else:
+ self.ctest.append(None) # lens of crefs and ctest have to match
+
+ self._score = None ## need to recompute
+
+ def ratio(self, option=None):
+ self.compute_score(option=option)
+ return self._ratio
+
+ def score_ratio(self, option=None):
+ '''
+ return (bleu, len_ratio) pair
+ '''
+
+ return self.fscore(option=option), self.ratio(option=option)
+
+ def score_ratio_str(self, option=None):
+ return "%.4f (%.2f)" % self.score_ratio(option)
+
+ def reflen(self, option=None):
+ self.compute_score(option=option)
+ return self._reflen
+
+ def testlen(self, option=None):
+ self.compute_score(option=option)
+ return self._testlen
+
+ def retest(self, new_test):
+ if type(new_test) is str:
+ new_test = [new_test]
+ assert len(new_test) == len(self.crefs), new_test
+ self.ctest = []
+ for t, rs in zip(new_test, self.crefs):
+ self.ctest.append(cook_test(t, rs))
+ self._score = None
+
+ return self
+
+ def rescore(self, new_test):
+ ''' replace test(s) with new test(s), and returns the new score.'''
+
+ return self.retest(new_test).compute_score()
+
+ def size(self):
+ assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
+ return len(self.crefs)
+
+ def __iadd__(self, other):
+ '''add an instance (e.g., from another sentence).'''
+
+ if type(other) is tuple:
+ ## avoid creating new BleuScorer instances
+ self.cook_append(other[0], other[1])
+ else:
+ assert self.compatible(other), "incompatible BLEUs."
+ self.ctest.extend(other.ctest)
+ self.crefs.extend(other.crefs)
+ self._score = None ## need to recompute
+
+ return self
+
+ def compatible(self, other):
+ return isinstance(other, BleuScorer) and self.n == other.n
+
+ def single_reflen(self, option="average"):
+ return self._single_reflen(self.crefs[0][0], option)
+
+ def _single_reflen(self, reflens, option=None, testlen=None):
+
+ if option == "shortest":
+ reflen = min(reflens)
+ elif option == "average":
+ reflen = float(sum(reflens)) / len(reflens)
+ elif option == "closest":
+ reflen = min((abs(l - testlen), l) for l in reflens)[1]
+ else:
+ assert False, "unsupported reflen option %s" % option
+
+ return reflen
+
+ def recompute_score(self, option=None, verbose=0):
+ self._score = None
+ return self.compute_score(option, verbose)
+
+ def compute_score(self, option=None, verbose=0):
+ n = self.n
+ small = 1e-9
+ tiny = 1e-15 ## so that if guess is 0 still return 0
+ bleu_list = [[] for _ in range(n)]
+
+ if self._score is not None:
+ return self._score
+
+ if option is None:
+ option = "average" if len(self.crefs) == 1 else "closest"
+
+ self._testlen = 0
+ self._reflen = 0
+ totalcomps = {'testlen': 0, 'reflen': 0, 'guess': [0] * n, 'correct': [0] * n}
+
+ # for each sentence
+ for comps in self.ctest:
+ testlen = comps['testlen']
+ self._testlen += testlen
+
+ if self.special_reflen is None: ## need computation
+ reflen = self._single_reflen(comps['reflen'], option, testlen)
+ else:
+ reflen = self.special_reflen
+
+ self._reflen += reflen
+
+ for key in ['guess', 'correct']:
+ for k in range(n):
+ totalcomps[key][k] += comps[key][k]
+
+ # append per image bleu score
+ bleu = 1.
+ for k in range(n):
+ bleu *= (float(comps['correct'][k]) + tiny) \
+ / (float(comps['guess'][k]) + small)
+ bleu_list[k].append(bleu ** (1. / (k + 1)))
+ ratio = (testlen + tiny) / (reflen + small) ## N.B.: avoid zero division
+ if ratio < 1:
+ for k in range(n):
+ bleu_list[k][-1] *= math.exp(1 - 1 / ratio)
+
+ if verbose > 1:
+ print(comps, reflen)
+
+ totalcomps['reflen'] = self._reflen
+ totalcomps['testlen'] = self._testlen
+
+ bleus = []
+ bleu = 1.
+ for k in range(n):
+ bleu *= float(totalcomps['correct'][k] + tiny) \
+ / (totalcomps['guess'][k] + small)
+ bleus.append(bleu ** (1. / (k + 1)))
+ ratio = (self._testlen + tiny) / (self._reflen + small) ## N.B.: avoid zero division
+ if ratio < 1:
+ for k in range(n):
+ bleus[k] *= math.exp(1 - 1 / ratio)
+
+ if verbose > 0:
+ print(totalcomps)
+ print("ratio:", ratio)
+
+ self._score = bleus
+ return self._score, bleu_list
diff --git a/evaluation/cider/__init__.py b/evaluation/cider/__init__.py
new file mode 100644
index 0000000..aaa32ec
--- /dev/null
+++ b/evaluation/cider/__init__.py
@@ -0,0 +1 @@
+from .cider import Cider
\ No newline at end of file
diff --git a/evaluation/cider/cider.py b/evaluation/cider/cider.py
new file mode 100644
index 0000000..57ae309
--- /dev/null
+++ b/evaluation/cider/cider.py
@@ -0,0 +1,42 @@
+# Filename: cider.py
+#
+# Description: Describes the class to compute the CIDEr (Consensus-Based Image Description Evaluation) Metric
+# by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726)
+#
+# Creation Date: Sun Feb 8 14:16:54 2015
+#
+# Authors: Ramakrishna Vedantam and Tsung-Yi Lin
+
+from .cider_scorer import CiderScorer
+
+class Cider:
+ """
+ Main class to compute the CIDEr metric.
+
+ """
+ def __init__(self, gts=None, n=4, sigma=6.0):
+ # set cider to sum over 1 to 4-grams
+ self._n = n
+ # set the standard deviation parameter for gaussian penalty
+ self._sigma = sigma
+ self.doc_frequency = None
+ self.ref_len = None
+ if gts is not None:
+ tmp_cider = CiderScorer(gts, n=self._n, sigma=self._sigma)
+ self.doc_frequency = tmp_cider.doc_frequency
+ self.ref_len = tmp_cider.ref_len
+
+ def compute_score(self, gts, res):
+ """
+ Main function to compute CIDEr score
+ :param gts (dict) : dictionary with image id as key and a list of reference sentences as value
+ res (dict) : dictionary with image id as key and a list containing the candidate sentence as value
+ :return: cider (float), scores (np.ndarray) : corpus-level and per-image CIDEr scores
+ """
+ assert(gts.keys() == res.keys())
+ cider_scorer = CiderScorer(gts, test=res, n=self._n, sigma=self._sigma, doc_frequency=self.doc_frequency,
+ ref_len=self.ref_len)
+ return cider_scorer.compute_score()
+
+ def __str__(self):
+ return 'CIDEr'
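+
+# Minimal usage sketch (illustrative; the image ids and captions below are made up):
+#   gts = {'img1': ['a man riding a motorcycle', 'a person on a motorbike'],
+#          'img2': ['a woman cutting a cake', 'a lady slicing a cake']}
+#   res = {'img1': ['a man rides a motorcycle'], 'img2': ['a woman cuts a cake']}
+#   corpus_score, per_image_scores = Cider().compute_score(gts, res)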
diff --git a/evaluation/cider/cider_scorer.py b/evaluation/cider/cider_scorer.py
new file mode 100644
index 0000000..37243ef
--- /dev/null
+++ b/evaluation/cider/cider_scorer.py
@@ -0,0 +1,167 @@
+#!/usr/bin/env python
+# Tsung-Yi Lin
+# Ramakrishna Vedantam
+
+import copy
+from collections import defaultdict
+import numpy as np
+import math
+
+def precook(s, n=4):
+ """
+ Takes a string as input and returns its n-gram counts, in the form consumed
+ by cook_refs and cook_test.
+ :param s: string : sentence to be converted into ngrams
+ :param n: int : number of ngrams for which representation is calculated
+ :return: term frequency vector for occurring ngrams
+ """
+ words = s.split()
+ counts = defaultdict(int)
+ for k in range(1,n+1):
+ for i in range(len(words)-k+1):
+ ngram = tuple(words[i:i+k])
+ counts[ngram] += 1
+ return counts
+
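+# Worked example (illustrative, not used by the scorer): for the toy sentence
+# "a dog on a mat" with n=2, precook returns unigram counts such as ('a',): 2 and
+# ('dog',): 1, and bigram counts such as ('a', 'dog'): 1 and ('on', 'a'): 1.
+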
+def cook_refs(refs, n=4): ## lhuang: oracle will call with "average"
+ '''Takes a list of reference sentences for a single segment
+ and returns an object that encapsulates everything that the CIDEr
+ scorer needs to know about them.
+ :param refs: list of string : reference sentences for some image
+ :param n: int : number of ngrams for which (ngram) representation is calculated
+ :return: result (list of dict)
+ '''
+ return [precook(ref, n) for ref in refs]
+
+def cook_test(test, n=4):
+ '''Takes a test sentence and returns an object that
+ encapsulates everything that the CIDEr scorer needs to know about it.
+ :param test: string : hypothesis sentence for some image
+ :param n: int : number of ngrams for which (ngram) representation is calculated
+ :return: result (dict)
+ '''
+ return precook(test, n)
+
+class CiderScorer(object):
+ """CIDEr scorer.
+ """
+
+ def __init__(self, refs, test=None, n=4, sigma=6.0, doc_frequency=None, ref_len=None):
+ ''' singular instance '''
+ self.n = n
+ self.sigma = sigma
+ self.crefs = []
+ self.ctest = []
+ self.doc_frequency = defaultdict(float)
+ self.ref_len = None
+
+ for k in refs.keys():
+ self.crefs.append(cook_refs(refs[k]))
+ if test is not None:
+ self.ctest.append(cook_test(test[k][0])) ## N.B.: -1
+ else:
+ self.ctest.append(None) # lens of crefs and ctest have to match
+
+ if doc_frequency is None and ref_len is None:
+ # compute idf
+ self.compute_doc_freq()
+ # compute log reference length
+ self.ref_len = np.log(float(len(self.crefs)))
+ else:
+ self.doc_frequency = doc_frequency
+ self.ref_len = ref_len
+
+ def compute_doc_freq(self):
+ '''
+ Compute document frequency for the reference data.
+ This will be used later to compute idf (inverse document frequency).
+ The document frequency is stored in the object
+ :return: None
+ '''
+ for refs in self.crefs:
+ # refs, k ref captions of one image
+ for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]):
+ self.doc_frequency[ngram] += 1
+ # maxcounts[ngram] = max(maxcounts.get(ngram,0), count)
+
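+ # Note: compute_cider's counts2vec turns these document frequencies into idf
+ # weights of the form log(N) - log(df), where N is the number of reference
+ # images and df is clipped to at least 1.
+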
+ def compute_cider(self):
+ def counts2vec(cnts):
+ """
+ Maps n-gram counts to a vector of tf-idf weights.
+ Returns vec, an array of dictionaries mapping each n-gram to its tf-idf weight;
+ the n-th entry of the array holds the n-grams of length n+1.
+ :param cnts:
+ :return: vec (array of dict), norm (array of float), length (int)
+ """
+ vec = [defaultdict(float) for _ in range(self.n)]
+ length = 0
+ norm = [0.0 for _ in range(self.n)]
+ for (ngram,term_freq) in cnts.items():
+ # clip document frequency to at least 1 for ngrams unseen in the reference corpus
+ df = np.log(max(1.0, self.doc_frequency[ngram]))
+ # ngram index
+ n = len(ngram)-1
+ # tf (term_freq) * idf (precomputed idf) for n-grams
+ vec[n][ngram] = float(term_freq)*(self.ref_len - df)
+ # compute norm for the vector. the norm will be used for computing similarity
+ norm[n] += pow(vec[n][ngram], 2)
+
+ if n == 1:
+ length += term_freq
+ norm = [np.sqrt(n) for n in norm]
+ return vec, norm, length
+
+ def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref):
+ '''
+ Compute the cosine similarity of two vectors.
+ :param vec_hyp: array of dictionary for vector corresponding to hypothesis
+ :param vec_ref: array of dictionary for vector corresponding to reference
+ :param norm_hyp: array of float for vector corresponding to hypothesis
+ :param norm_ref: array of float for vector corresponding to reference
+ :param length_hyp: int containing length of hypothesis
+ :param length_ref: int containing length of reference
+ :return: array of score for each n-grams cosine similarity
+ '''
+ delta = float(length_hyp - length_ref)
+ # measure cosine similarity
+ val = np.array([0.0 for _ in range(self.n)])
+ for n in range(self.n):
+ # ngram
+ for (ngram,count) in vec_hyp[n].items():
+ # vrama91 : added clipping
+ val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram]
+
+ if (norm_hyp[n] != 0) and (norm_ref[n] != 0):
+ val[n] /= (norm_hyp[n]*norm_ref[n])
+
+ assert(not math.isnan(val[n]))
+ # vrama91: added a length based gaussian penalty
+ val[n] *= np.e**(-(delta**2)/(2*self.sigma**2))
+ return val
+
+ scores = []
+ for test, refs in zip(self.ctest, self.crefs):
+ # compute vector for test captions
+ vec, norm, length = counts2vec(test)
+ # compute vector for ref captions
+ score = np.array([0.0 for _ in range(self.n)])
+ for ref in refs:
+ vec_ref, norm_ref, length_ref = counts2vec(ref)
+ score += sim(vec, vec_ref, norm, norm_ref, length, length_ref)
+ # change by vrama91 - mean of ngram scores, instead of sum
+ score_avg = np.mean(score)
+ # divide by number of references
+ score_avg /= len(refs)
+ # multiply score by 10
+ score_avg *= 10.0
+ # append score of an image to the score list
+ scores.append(score_avg)
+ return scores
+
+ def compute_score(self):
+ # compute cider score
+ score = self.compute_cider()
+ # debug
+ # print score
+ return np.mean(np.array(score)), np.array(score)
\ No newline at end of file
diff --git a/evaluation/meteor/__init__.py b/evaluation/meteor/__init__.py
new file mode 100644
index 0000000..35b4a26
--- /dev/null
+++ b/evaluation/meteor/__init__.py
@@ -0,0 +1 @@
+from .meteor import Meteor
\ No newline at end of file
diff --git a/evaluation/meteor/data/paraphrase-en.gz b/evaluation/meteor/data/paraphrase-en.gz
new file mode 100644
index 0000000..88033c8
Binary files /dev/null and b/evaluation/meteor/data/paraphrase-en.gz differ
diff --git a/evaluation/meteor/meteor.py b/evaluation/meteor/meteor.py
new file mode 100644
index 0000000..e41fb18
--- /dev/null
+++ b/evaluation/meteor/meteor.py
@@ -0,0 +1,75 @@
+# Python wrapper for METEOR implementation, by Xinlei Chen
+# Acknowledgements to Michael Denkowski for the generous discussion and help
+
+import os
+import subprocess
+import threading
+import tarfile
+from utils import download_from_url
+
+METEOR_GZ_URL = 'http://aimagelab.ing.unimore.it/speaksee/data/meteor.tgz'
+METEOR_JAR = 'meteor-1.5.jar'
+
+class Meteor:
+ def __init__(self):
+ base_path = os.path.dirname(os.path.abspath(__file__))
+ jar_path = os.path.join(base_path, METEOR_JAR)
+ gz_path = os.path.join(base_path, os.path.basename(METEOR_GZ_URL))
+ if not os.path.isfile(jar_path):
+ if not os.path.isfile(gz_path):
+ download_from_url(METEOR_GZ_URL, gz_path)
+ tar = tarfile.open(gz_path, "r")
+ tar.extractall(path=os.path.dirname(os.path.abspath(__file__)))
+ tar.close()
+ os.remove(gz_path)
+
+ self.meteor_cmd = ['java', '-jar', '-Xmx2G', METEOR_JAR, \
+ '-', '-', '-stdio', '-l', 'en', '-norm']
+ self.meteor_p = subprocess.Popen(self.meteor_cmd, \
+ cwd=os.path.dirname(os.path.abspath(__file__)), \
+ stdin=subprocess.PIPE, \
+ stdout=subprocess.PIPE, \
+ stderr=subprocess.PIPE)
+ # Used to guarantee thread safety
+ self.lock = threading.Lock()
+
+ def compute_score(self, gts, res):
+ assert(gts.keys() == res.keys())
+ imgIds = gts.keys()
+ scores = []
+
+ eval_line = 'EVAL'
+ self.lock.acquire()
+ for i in imgIds:
+ assert(len(res[i]) == 1)
+ stat = self._stat(res[i][0], gts[i])
+ eval_line += ' ||| {}'.format(stat)
+
+ self.meteor_p.stdin.write('{}\n'.format(eval_line).encode())
+ self.meteor_p.stdin.flush()
+ for i in range(0,len(imgIds)):
+ scores.append(float(self.meteor_p.stdout.readline().strip()))
+ score = float(self.meteor_p.stdout.readline().strip())
+ self.lock.release()
+
+ return score, scores
+
+ def _stat(self, hypothesis_str, reference_list):
+ # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words
+ hypothesis_str = hypothesis_str.replace('|||', '').replace('  ', ' ')
+ score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str))
+ self.meteor_p.stdin.write('{}\n'.format(score_line).encode())
+ self.meteor_p.stdin.flush()
+ raw = self.meteor_p.stdout.readline().decode().strip()
+ numbers = [str(int(float(n))) for n in raw.split()]
+ return ' '.join(numbers)
+
+ def __del__(self):
+ self.lock.acquire()
+ self.meteor_p.stdin.close()
+ self.meteor_p.kill()
+ self.meteor_p.wait()
+ self.lock.release()
+
+ def __str__(self):
+ return 'METEOR'
diff --git a/evaluation/rouge/__init__.py b/evaluation/rouge/__init__.py
new file mode 100644
index 0000000..59397f8
--- /dev/null
+++ b/evaluation/rouge/__init__.py
@@ -0,0 +1 @@
+from .rouge import Rouge
\ No newline at end of file
diff --git a/evaluation/rouge/rouge.py b/evaluation/rouge/rouge.py
new file mode 100644
index 0000000..06529f3
--- /dev/null
+++ b/evaluation/rouge/rouge.py
@@ -0,0 +1,108 @@
+#!/usr/bin/env python
+#
+# File Name : rouge.py
+#
+# Description : Computes ROUGE-L metric as described by Lin and Hovy (2004)
+#
+# Creation Date : 2015-01-07 06:03
+# Author : Ramakrishna Vedantam
+
+import numpy as np
+import pdb
+
+
+def my_lcs(string, sub):
+ """
+ Calculates longest common subsequence for a pair of tokenized strings
+ :param string : list of str : tokens from a string split using whitespace
+ :param sub : list of str : shorter string, also split using whitespace
+ :returns: length (int): length of the longest common subsequence between the two strings
+
+ Note: my_lcs only gives length of the longest common subsequence, not the actual LCS
+ """
+ if (len(string) < len(sub)):
+ sub, string = string, sub
+
+ lengths = [[0 for i in range(0, len(sub) + 1)] for j in range(0, len(string) + 1)]
+
+ for j in range(1, len(sub) + 1):
+ for i in range(1, len(string) + 1):
+ if (string[i - 1] == sub[j - 1]):
+ lengths[i][j] = lengths[i - 1][j - 1] + 1
+ else:
+ lengths[i][j] = max(lengths[i - 1][j], lengths[i][j - 1])
+
+ return lengths[len(string)][len(sub)]
+
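+# Worked example (illustrative): my_lcs("a man is riding a horse".split(),
+# "a man rides a horse".split()) returns 4, since the tokens "a man a horse"
+# form the longest common subsequence.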
+
+class Rouge():
+ '''
+ Class for computing ROUGE-L score for a set of candidate sentences for the MS COCO test set
+
+ '''
+
+ def __init__(self):
+ # vrama91: updated the value below based on discussion with Hovy
+ self.beta = 1.2
+
+ def calc_score(self, candidate, refs):
+ """
+ Compute ROUGE-L score given one candidate and references for an image
+ :param candidate: list of str : single candidate sentence to be evaluated
+ :param refs: list of str : COCO reference sentences for the particular image to be evaluated
+ :returns score: float (ROUGE-L score for the candidate evaluated against references)
+ """
+ assert (len(candidate) == 1)
+ assert (len(refs) > 0)
+ prec = []
+ rec = []
+
+ # split into tokens
+ token_c = candidate[0].split(" ")
+
+ for reference in refs:
+ # split into tokens
+ token_r = reference.split(" ")
+ # compute the longest common subsequence
+ lcs = my_lcs(token_r, token_c)
+ prec.append(lcs / float(len(token_c)))
+ rec.append(lcs / float(len(token_r)))
+
+ prec_max = max(prec)
+ rec_max = max(rec)
+
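+ # F-measure with beta = 1.2: (1 + beta^2) * P * R / (R + beta^2 * P),
+ # which weights recall slightly more than precision.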
+ if (prec_max != 0 and rec_max != 0):
+ score = ((1 + self.beta ** 2) * prec_max * rec_max) / float(rec_max + self.beta ** 2 * prec_max)
+ else:
+ score = 0.0
+ return score
+
+ def compute_score(self, gts, res):
+ """
+ Computes Rouge-L score given a set of reference and candidate sentences for the dataset
+ Invoked by evaluate_captions.py
+ :param gts: dict : reference MS-COCO sentences with "image name" key and "tokenized sentences" as values
+ :param res: dict : candidate / test sentences with "image name" key and "tokenized sentences" as values
+ :returns: average_score: float (mean ROUGE-L score computed by averaging scores for all the images)
+ """
+ assert (gts.keys() == res.keys())
+ imgIds = gts.keys()
+
+ score = []
+ for id in imgIds:
+ hypo = res[id]
+ ref = gts[id]
+
+ score.append(self.calc_score(hypo, ref))
+
+ # Sanity check.
+ assert (type(hypo) is list)
+ assert (len(hypo) == 1)
+ assert (type(ref) is list)
+ assert (len(ref) > 0)
+
+ average_score = np.mean(np.array(score))
+ return average_score, np.array(score)
+
+ def __str__(self):
+ return 'ROUGE'
diff --git a/evaluation/tokenizer.py b/evaluation/tokenizer.py
new file mode 100644
index 0000000..73fb49c
--- /dev/null
+++ b/evaluation/tokenizer.py
@@ -0,0 +1,63 @@
+#!/usr/bin/env python
+#
+# File Name : ptbtokenizer.py
+#
+# Description : Performs PTB tokenization and removes punctuation.
+#
+# Creation Date : 29-12-2014
+# Last Modified : Thu Mar 19 09:53:35 2015
+# Authors : Hao Fang and Tsung-Yi Lin
+
+import os
+import subprocess
+import tempfile
+
+class PTBTokenizer(object):
+ """Python wrapper of Stanford PTBTokenizer"""
+
+ corenlp_jar = 'stanford-corenlp-3.4.1.jar'
+ punctuations = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-", \
+ ".", "?", "!", ",", ":", "-", "--", "...", ";"]
+
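+ # Usage sketch (illustrative, assuming the CoreNLP jar sits next to this file):
+ # PTBTokenizer.tokenize({'0': ['A man, riding a horse!']}) is expected to return
+ # {'0': ['a man riding a horse']}, i.e. lower-cased tokens with punctuation removed.
+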
+ @classmethod
+ def tokenize(cls, corpus):
+ cmd = ['java', '-cp', cls.corenlp_jar, \
+ 'edu.stanford.nlp.process.PTBTokenizer', \
+ '-preserveLines', '-lowerCase']
+
+ if isinstance(corpus, list) or isinstance(corpus, tuple):
+ if isinstance(corpus[0], list) or isinstance(corpus[0], tuple):
+ corpus = {i:c for i, c in enumerate(corpus)}
+ else:
+ corpus = {i: [c, ] for i, c in enumerate(corpus)}
+
+ # prepare data for PTB Tokenizer
+ tokenized_corpus = {}
+ image_id = [k for k, v in list(corpus.items()) for _ in range(len(v))]
+ sentences = '\n'.join([c.replace('\n', ' ') for k, v in corpus.items() for c in v])
+
+ # save sentences to temporary file
+ path_to_jar_dirname=os.path.dirname(os.path.abspath(__file__))
+ tmp_file = tempfile.NamedTemporaryFile(delete=False, dir=path_to_jar_dirname)
+ tmp_file.write(sentences.encode())
+ tmp_file.close()
+
+ # tokenize sentence
+ cmd.append(os.path.basename(tmp_file.name))
+ p_tokenizer = subprocess.Popen(cmd, cwd=path_to_jar_dirname, \
+ stdout=subprocess.PIPE, stderr=open(os.devnull, 'w'))
+ token_lines = p_tokenizer.communicate(input=sentences.rstrip())[0]
+ token_lines = token_lines.decode()
+ lines = token_lines.split('\n')
+ # remove temp file
+ os.remove(tmp_file.name)
+
+ # create dictionary for tokenized captions
+ for k, line in zip(image_id, lines):
+ if k not in tokenized_corpus:
+ tokenized_corpus[k] = []
+ tokenized_caption = ' '.join([w for w in line.rstrip().split(' ') \
+ if w not in cls.punctuations])
+ tokenized_corpus[k].append(tokenized_caption)
+
+ return tokenized_corpus
\ No newline at end of file
diff --git a/framework.png b/framework.png
new file mode 100644
index 0000000..7d40153
Binary files /dev/null and b/framework.png differ
diff --git a/models/__init__.py b/models/__init__.py
new file mode 100644
index 0000000..943ce68
--- /dev/null
+++ b/models/__init__.py
@@ -0,0 +1,2 @@
+from .transformer import Transformer
+from .captioning_model import CaptioningModel
diff --git a/models/beam_search/__init__.py b/models/beam_search/__init__.py
new file mode 100644
index 0000000..21ac612
--- /dev/null
+++ b/models/beam_search/__init__.py
@@ -0,0 +1 @@
+from .beam_search import BeamSearch
diff --git a/models/beam_search/beam_search.py b/models/beam_search/beam_search.py
new file mode 100644
index 0000000..bfe99b4
--- /dev/null
+++ b/models/beam_search/beam_search.py
@@ -0,0 +1,145 @@
+import torch
+import utils
+
+
+class BeamSearch(object):
+ def __init__(self, model, max_len: int, eos_idx: int, beam_size: int):
+ self.model = model
+ self.max_len = max_len
+ self.eos_idx = eos_idx
+ self.beam_size = beam_size
+ self.b_s = None
+ self.device = None
+ self.seq_mask = None
+ self.seq_logprob = None
+ self.outputs = None
+ self.log_probs = None
+ self.selected_words = None
+ self.all_log_probs = None
+
+ def _expand_state(self, selected_beam, cur_beam_size):
+ def fn(s):
+ shape = [int(sh) for sh in s.shape]
+ beam = selected_beam
+ for _ in shape[1:]:
+ beam = beam.unsqueeze(-1)
+ s = torch.gather(s.view(*([self.b_s, cur_beam_size] + shape[1:])), 1,
+ beam.expand(*([self.b_s, self.beam_size] + shape[1:])))
+ s = s.view(*([-1, ] + shape[1:]))
+ return s
+
+ return fn
+
+ def _expand_visual(self, visual: utils.TensorOrSequence, cur_beam_size: int, selected_beam: torch.Tensor):
+ if isinstance(visual, torch.Tensor):
+ visual_shape = visual.shape
+ visual_exp_shape = (self.b_s, cur_beam_size) + visual_shape[1:]
+ visual_red_shape = (self.b_s * self.beam_size,) + visual_shape[1:]
+ selected_beam_red_size = (self.b_s, self.beam_size) + tuple(1 for _ in range(len(visual_exp_shape) - 2))
+ selected_beam_exp_size = (self.b_s, self.beam_size) + visual_exp_shape[2:]
+ visual_exp = visual.view(visual_exp_shape)
+ selected_beam_exp = selected_beam.view(selected_beam_red_size).expand(selected_beam_exp_size)
+ visual = torch.gather(visual_exp, 1, selected_beam_exp).view(visual_red_shape)
+ else:
+ new_visual = []
+ for im in visual:
+ visual_shape = im.shape
+ visual_exp_shape = (self.b_s, cur_beam_size) + visual_shape[1:]
+ visual_red_shape = (self.b_s * self.beam_size,) + visual_shape[1:]
+ selected_beam_red_size = (self.b_s, self.beam_size) + tuple(1 for _ in range(len(visual_exp_shape) - 2))
+ selected_beam_exp_size = (self.b_s, self.beam_size) + visual_exp_shape[2:]
+ visual_exp = im.view(visual_exp_shape)
+ selected_beam_exp = selected_beam.view(selected_beam_red_size).expand(selected_beam_exp_size)
+ new_im = torch.gather(visual_exp, 1, selected_beam_exp).view(visual_red_shape)
+ new_visual.append(new_im)
+ visual = tuple(new_visual)
+ return visual
+
+ def apply(self, visual: utils.TensorOrSequence, out_size=1, return_probs=False, **kwargs):
+ self.b_s = utils.get_batch_size(visual)
+ self.device = utils.get_device(visual)
+ self.seq_mask = torch.ones((self.b_s, self.beam_size, 1), device=self.device)
+ self.seq_logprob = torch.zeros((self.b_s, 1, 1), device=self.device)
+ self.log_probs = []
+ self.selected_words = None
+ if return_probs:
+ self.all_log_probs = []
+
+ outputs = []
+ with self.model.statefulness(self.b_s):
+ for t in range(self.max_len):
+ visual, outputs = self.iter(t, visual, outputs, return_probs, **kwargs)
+
+ # Sort result
+ seq_logprob, sort_idxs = torch.sort(self.seq_logprob, 1, descending=True)
+ outputs = torch.cat(outputs, -1)
+ outputs = torch.gather(outputs, 1, sort_idxs.expand(self.b_s, self.beam_size, self.max_len))
+ log_probs = torch.cat(self.log_probs, -1)
+ log_probs = torch.gather(log_probs, 1, sort_idxs.expand(self.b_s, self.beam_size, self.max_len))
+ if return_probs:
+ all_log_probs = torch.cat(self.all_log_probs, 2)
+ all_log_probs = torch.gather(all_log_probs, 1, sort_idxs.unsqueeze(-1).expand(self.b_s, self.beam_size,
+ self.max_len,
+ all_log_probs.shape[-1]))
+ outputs = outputs.contiguous()[:, :out_size]
+ log_probs = log_probs.contiguous()[:, :out_size]
+ if out_size == 1:
+ outputs = outputs.squeeze(1)
+ log_probs = log_probs.squeeze(1)
+
+ if return_probs:
+ return outputs, log_probs, all_log_probs
+ else:
+ return outputs, log_probs
+
+ def select(self, t, candidate_logprob, **kwargs):
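+ # The (b_s, cur_beam_size, vocab_size) candidates are flattened so a single
+ # top-k over beam_size * vocab_size picks the best beam/word pairs jointly;
+ # iter() recovers the beam via // vocab_size and the word via % vocab_size.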
+ selected_logprob, selected_idx = torch.sort(candidate_logprob.view(self.b_s, -1), -1, descending=True)
+ selected_logprob, selected_idx = selected_logprob[:, :self.beam_size], selected_idx[:, :self.beam_size]
+ return selected_idx, selected_logprob
+
+ def iter(self, t: int, visual: utils.TensorOrSequence, outputs, return_probs, **kwargs):
+ cur_beam_size = 1 if t == 0 else self.beam_size
+
+ word_logprob = self.model.step(t, self.selected_words, visual, None, mode='feedback', **kwargs)
+ word_logprob = word_logprob.view(self.b_s, cur_beam_size, -1)
+ candidate_logprob = self.seq_logprob + word_logprob
+
+ # Mask sequence if it reaches EOS
+ if t > 0:
+ mask = (self.selected_words.view(self.b_s, cur_beam_size) != self.eos_idx).float().unsqueeze(-1)
+ self.seq_mask = self.seq_mask * mask
+ word_logprob = word_logprob * self.seq_mask.expand_as(word_logprob)
+ old_seq_logprob = self.seq_logprob.expand_as(candidate_logprob).contiguous()
+ old_seq_logprob[:, :, 1:] = -999
+ candidate_logprob = self.seq_mask * candidate_logprob + old_seq_logprob * (1 - self.seq_mask)
+
+ selected_idx, selected_logprob = self.select(t, candidate_logprob, **kwargs)
+ # selected_beam = selected_idx / candidate_logprob.shape[-1] # old float division
+ # selected_words = selected_idx - selected_beam * candidate_logprob.shape[-1] # old subtraction-based remainder
+ selected_beam = selected_idx // candidate_logprob.shape[-1] # integer division: index of the source beam
+ selected_words = selected_idx % candidate_logprob.shape[-1] # remainder: word index within the vocabulary
+
+ self.model.apply_to_states(self._expand_state(selected_beam, cur_beam_size))
+ visual = self._expand_visual(visual, cur_beam_size, selected_beam)
+
+ self.seq_logprob = selected_logprob.unsqueeze(-1)
+ self.seq_mask = torch.gather(self.seq_mask, 1, selected_beam.unsqueeze(-1))
+ outputs = list(torch.gather(o, 1, selected_beam.unsqueeze(-1)) for o in outputs)
+ outputs.append(selected_words.unsqueeze(-1))
+
+ if return_probs:
+ if t == 0:
+ self.all_log_probs.append(word_logprob.expand((self.b_s, self.beam_size, -1)).unsqueeze(2))
+ else:
+ self.all_log_probs.append(word_logprob.unsqueeze(2))
+
+ this_word_logprob = torch.gather(word_logprob, 1,
+ selected_beam.unsqueeze(-1).expand(self.b_s, self.beam_size,
+ word_logprob.shape[-1]))
+ this_word_logprob = torch.gather(this_word_logprob, 2, selected_words.unsqueeze(-1))
+ self.log_probs = list(
+ torch.gather(o, 1, selected_beam.unsqueeze(-1).expand(self.b_s, self.beam_size, 1)) for o in self.log_probs)
+ self.log_probs.append(this_word_logprob)
+ self.selected_words = selected_words.view(-1, 1)
+
+ return visual, outputs
diff --git a/models/captioning_model.py b/models/captioning_model.py
new file mode 100644
index 0000000..c170ac0
--- /dev/null
+++ b/models/captioning_model.py
@@ -0,0 +1,70 @@
+import torch
+from torch import distributions
+import utils
+from models.containers import Module
+from models.beam_search import *
+
+
+class CaptioningModel(Module):
+ def __init__(self):
+ super(CaptioningModel, self).__init__()
+
+ def init_weights(self):
+ raise NotImplementedError
+
+ def step(self, t, prev_output, visual, seq, mode='teacher_forcing', **kwargs):
+ raise NotImplementedError
+
+ def forward(self, images, seq, *args):
+ device = images.device
+ b_s = images.size(0)
+ seq_len = seq.size(1)
+ state = self.init_state(b_s, device)
+ out = None
+
+ outputs = []
+ for t in range(seq_len):
+ out, state = self.step(t, state, out, images, seq, *args, mode='teacher_forcing')
+ outputs.append(out)
+
+ outputs = torch.cat([o.unsqueeze(1) for o in outputs], 1)
+ return outputs
+
+ def test(self, visual: utils.TensorOrSequence, max_len: int, eos_idx: int, **kwargs) -> utils.Tuple[torch.Tensor, torch.Tensor]:
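+ # Greedy decoding: at each step the arg-max word is fed back as the next input,
+ # and a sequence's log-probs are zeroed (via mask) once it has emitted eos_idx.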
+ b_s = utils.get_batch_size(visual)
+ device = utils.get_device(visual)
+ outputs = []
+ log_probs = []
+
+ mask = torch.ones((b_s,), device=device)
+ with self.statefulness(b_s):
+ out = None
+ for t in range(max_len):
+ log_probs_t = self.step(t, out, visual, None, mode='feedback', **kwargs)
+ out = torch.max(log_probs_t, -1)[1]
+ mask = mask * (out.squeeze(-1) != eos_idx).float()
+ log_probs.append(log_probs_t * mask.unsqueeze(-1).unsqueeze(-1))
+ outputs.append(out)
+
+ return torch.cat(outputs, 1), torch.cat(log_probs, 1)
+
+ def sample_rl(self, visual: utils.TensorOrSequence, max_len: int, **kwargs) -> utils.Tuple[torch.Tensor, torch.Tensor]:
+ b_s = utils.get_batch_size(visual)
+ outputs = []
+ log_probs = []
+
+ with self.statefulness(b_s):
+ out = None
+ for t in range(max_len):
+ out = self.step(t, out, visual, None, mode='feedback', **kwargs)
+ distr = distributions.Categorical(logits=out[:, 0])
+ out = distr.sample().unsqueeze(1)
+ outputs.append(out)
+ log_probs.append(distr.log_prob(out).unsqueeze(1))
+
+ return torch.cat(outputs, 1), torch.cat(log_probs, 1)
+
+ def beam_search(self, visual: utils.TensorOrSequence, max_len: int, eos_idx: int, beam_size: int, out_size=1,
+ return_probs=False, **kwargs):
+ bs = BeamSearch(self, max_len, eos_idx, beam_size)
+ return bs.apply(visual, out_size, return_probs, **kwargs)
diff --git a/models/containers.py b/models/containers.py
new file mode 100644
index 0000000..52296cd
--- /dev/null
+++ b/models/containers.py
@@ -0,0 +1,80 @@
+from contextlib import contextmanager
+from torch import nn
+from utils.typing import *
+
+
+class Module(nn.Module):
+ def __init__(self):
+ super(Module, self).__init__()
+ self._is_stateful = False
+ self._state_names = []
+ self._state_defaults = dict()
+
+ def register_state(self, name: str, default: TensorOrNone):
+ self._state_names.append(name)
+ if default is None:
+ self._state_defaults[name] = None
+ else:
+ self._state_defaults[name] = default.clone().detach()
+ self.register_buffer(name, default)
+
+ def states(self):
+ for name in self._state_names:
+ yield self._buffers[name]
+ for m in self.children():
+ if isinstance(m, Module):
+ yield from m.states()
+
+ def apply_to_states(self, fn):
+ for name in self._state_names:
+ self._buffers[name] = fn(self._buffers[name])
+ for m in self.children():
+ if isinstance(m, Module):
+ m.apply_to_states(fn)
+
+ def _init_states(self, batch_size: int):
+ for name in self._state_names:
+ if self._state_defaults[name] is None:
+ self._buffers[name] = None
+ else:
+ self._buffers[name] = self._state_defaults[name].clone().detach().to(self._buffers[name].device)
+ self._buffers[name] = self._buffers[name].unsqueeze(0)
+ self._buffers[name] = self._buffers[name].expand([batch_size, ] + list(self._buffers[name].shape[1:]))
+ self._buffers[name] = self._buffers[name].contiguous()
+
+ def _reset_states(self):
+ for name in self._state_names:
+ if self._state_defaults[name] is None:
+ self._buffers[name] = None
+ else:
+ self._buffers[name] = self._state_defaults[name].clone().detach().to(self._buffers[name].device)
+
+ def enable_statefulness(self, batch_size: int):
+ for m in self.children():
+ if isinstance(m, Module):
+ m.enable_statefulness(batch_size)
+ self._init_states(batch_size)
+ self._is_stateful = True
+
+ def disable_statefulness(self):
+ for m in self.children():
+ if isinstance(m, Module):
+ m.disable_statefulness()
+ self._reset_states()
+ self._is_stateful = False
+
+ @contextmanager
+ def statefulness(self, batch_size: int):
+ self.enable_statefulness(batch_size)
+ try:
+ yield
+ finally:
+ self.disable_statefulness()
+
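+# Usage sketch (illustrative): wrapping decoding in `with model.statefulness(b_s):`
+# initializes the registered state buffers for the batch and resets them on exit;
+# BeamSearch.apply and CaptioningModel.test rely on this for the decoder's
+# running keys/values.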
+
+class ModuleList(nn.ModuleList, Module):
+ pass
+
+
+class ModuleDict(nn.ModuleDict, Module):
+ pass
diff --git a/models/transformer/__init__.py b/models/transformer/__init__.py
new file mode 100644
index 0000000..c52d903
--- /dev/null
+++ b/models/transformer/__init__.py
@@ -0,0 +1,4 @@
+from .transformer import *
+from .encoders import *
+from .decoders import *
+from .attention import *
diff --git a/models/transformer/attention.py b/models/transformer/attention.py
new file mode 100644
index 0000000..e4fc0e6
--- /dev/null
+++ b/models/transformer/attention.py
@@ -0,0 +1,184 @@
+import numpy as np
+import torch
+from torch import nn
+from models.containers import Module
+
+
+class ScaledDotProductAttention(nn.Module):
+ '''
+ Scaled dot-product attention
+ '''
+
+ def __init__(self, d_model, d_k, d_v, h, dropout=.1, comment=None):
+ '''
+ :param d_model: Output dimensionality of the model
+ :param d_k: Dimensionality of queries and keys
+ :param d_v: Dimensionality of values
+ :param h: Number of heads
+ '''
+ super(ScaledDotProductAttention, self).__init__()
+ self.fc_q = nn.Linear(d_model, h * d_k)
+ self.fc_k = nn.Linear(d_model, h * d_k)
+ self.fc_v = nn.Linear(d_model, h * d_v)
+ self.fc_o = nn.Linear(h * d_v, d_model)
+ self.dropout = nn.Dropout(dropout)
+
+ self.d_model = d_model
+ self.d_k = d_k
+ self.d_v = d_v
+ self.h = h
+
+ self.init_weights()
+
+ self.comment = comment
+
+ def init_weights(self):
+ nn.init.xavier_uniform_(self.fc_q.weight)
+ nn.init.xavier_uniform_(self.fc_k.weight)
+ nn.init.xavier_uniform_(self.fc_v.weight)
+ nn.init.xavier_uniform_(self.fc_o.weight)
+ nn.init.constant_(self.fc_q.bias, 0)
+ nn.init.constant_(self.fc_k.bias, 0)
+ nn.init.constant_(self.fc_v.bias, 0)
+ nn.init.constant_(self.fc_o.bias, 0)
+
+ def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):
+ '''
+ Computes
+ :param queries: Queries (b_s, nq, d_model)
+ :param keys: Keys (b_s, nk, d_model)
+ :param values: Values (b_s, nk, d_model)
+ :param attention_mask: Mask over attention values (b_s, h, nq, nk). True indicates masking.
+ :param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk).
+ :return:
+ '''
+
+ b_s, nq = queries.shape[:2]
+ nk = keys.shape[1]
+
+ q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3) # (b_s, h, nq, d_k)
+ k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1) # (b_s, h, d_k, nk)
+ v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3) # (b_s, h, nk, d_v)
+
+ att = torch.matmul(q, k) / np.sqrt(self.d_k) # (b_s, h, nq, nk)
+ if attention_weights is not None:
+ att = att * attention_weights
+ if attention_mask is not None:
+ att = att.masked_fill(attention_mask, -np.inf)
+ att = torch.softmax(att, -1)
+ att = self.dropout(att)
+
+ out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v) # (b_s, nq, h*d_v)
+ out = self.fc_o(out) # (b_s, nq, d_model)
+ return out
+
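+# Shape sketch (illustrative, using this repo's typical sizes): with b_s=50,
+# nq=nk=49 grid features, h=8 and d_k=d_v=64, q is (50, 8, 49, 64),
+# k is (50, 8, 64, 49), att is (50, 8, 49, 49) and the output is (50, 49, 512).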
+
+class ScaledDotProductAttentionMemory(nn.Module):
+ '''
+ Scaled dot-product attention with memory
+ '''
+
+ def __init__(self, d_model, d_k, d_v, h, m):
+ '''
+ :param d_model: Output dimensionality of the model
+ :param d_k: Dimensionality of queries and keys
+ :param d_v: Dimensionality of values
+ :param h: Number of heads
+ :param m: Number of memory slots
+ '''
+ super(ScaledDotProductAttentionMemory, self).__init__()
+ self.fc_q = nn.Linear(d_model, h * d_k)
+ self.fc_k = nn.Linear(d_model, h * d_k)
+ self.fc_v = nn.Linear(d_model, h * d_v)
+ self.fc_o = nn.Linear(h * d_v, d_model)
+ self.m_k = nn.Parameter(torch.FloatTensor(1, m, h * d_k))
+ self.m_v = nn.Parameter(torch.FloatTensor(1, m, h * d_v))
+
+ self.d_model = d_model
+ self.d_k = d_k
+ self.d_v = d_v
+ self.h = h
+ self.m = m
+
+ self.init_weights()
+
+ def init_weights(self):
+ nn.init.xavier_uniform_(self.fc_q.weight)
+ nn.init.xavier_uniform_(self.fc_k.weight)
+ nn.init.xavier_uniform_(self.fc_v.weight)
+ nn.init.xavier_uniform_(self.fc_o.weight)
+ nn.init.normal_(self.m_k, 0, 1 / self.d_k)
+ nn.init.normal_(self.m_v, 0, 1 / self.m)
+ nn.init.constant_(self.fc_q.bias, 0)
+ nn.init.constant_(self.fc_k.bias, 0)
+ nn.init.constant_(self.fc_v.bias, 0)
+ nn.init.constant_(self.fc_o.bias, 0)
+
+ def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):
+ '''
+ Computes
+ :param queries: Queries (b_s, nq, d_model)
+ :param keys: Keys (b_s, nk, d_model)
+ :param values: Values (b_s, nk, d_model)
+ :param attention_mask: Mask over attention values (b_s, h, nq, nk). True indicates masking.
+ :param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk).
+ :return:
+ '''
+ b_s, nq = queries.shape[:2]
+ nk = keys.shape[1]
+
+ m_k = np.sqrt(self.d_k) * self.m_k.expand(b_s, self.m, self.h * self.d_k)
+ m_v = np.sqrt(self.m) * self.m_v.expand(b_s, self.m, self.h * self.d_v)
+
+ q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3) # (b_s, h, nq, d_k)
+ k = torch.cat([self.fc_k(keys), m_k], 1).view(b_s, nk + self.m, self.h, self.d_k).permute(0, 2, 3, 1) # (b_s, h, d_k, nk)
+ v = torch.cat([self.fc_v(values), m_v], 1).view(b_s, nk + self.m, self.h, self.d_v).permute(0, 2, 1, 3) # (b_s, h, nk, d_v)
+
+ att = torch.matmul(q, k) / np.sqrt(self.d_k) # (b_s, h, nq, nk)
+ if attention_weights is not None:
+ att = torch.cat([att[:, :, :, :nk] * attention_weights, att[:, :, :, nk:]], -1)
+ if attention_mask is not None:
+ att[:, :, :, :nk] = att[:, :, :, :nk].masked_fill(attention_mask, -np.inf)
+ att = torch.softmax(att, -1)
+ out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v) # (b_s, nq, h*d_v)
+ out = self.fc_o(out) # (b_s, nq, d_model)
+ return out
+
+
+class MultiHeadAttention(Module):
+ '''
+ Multi-head attention layer with Dropout and Layer Normalization.
+ '''
+
+ def __init__(self, d_model, d_k, d_v, h, dropout=.1, identity_map_reordering=False, can_be_stateful=False,
+ attention_module=None, attention_module_kwargs=None, comment=None):
+ super(MultiHeadAttention, self).__init__()
+ self.identity_map_reordering = identity_map_reordering
+ self.attention = ScaledDotProductAttention(d_model=d_model, d_k=d_k, d_v=d_v, h=h, comment=comment)
+ self.dropout = nn.Dropout(p=dropout)
+ self.layer_norm = nn.LayerNorm(d_model)
+
+ self.can_be_stateful = can_be_stateful
+ if self.can_be_stateful:
+ self.register_state('running_keys', torch.zeros((0, d_model)))
+ self.register_state('running_values', torch.zeros((0, d_model)))
+
+ def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):
+ if self.can_be_stateful and self._is_stateful:
+ self.running_keys = torch.cat([self.running_keys, keys], 1)
+ keys = self.running_keys
+
+ self.running_values = torch.cat([self.running_values, values], 1)
+ values = self.running_values
+
+ if self.identity_map_reordering:
+ q_norm = self.layer_norm(queries)
+ k_norm = self.layer_norm(keys)
+ v_norm = self.layer_norm(values)
+ out = self.attention(q_norm, k_norm, v_norm, attention_mask, attention_weights)
+ out = queries + self.dropout(torch.relu(out))
+ else:
+ out = self.attention(queries, keys, values, attention_mask, attention_weights)
+ out = self.dropout(out)
+ out = self.layer_norm(queries + out)
+ return out
diff --git a/models/transformer/decoders.py b/models/transformer/decoders.py
new file mode 100644
index 0000000..4d57413
--- /dev/null
+++ b/models/transformer/decoders.py
@@ -0,0 +1,84 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from models.transformer.attention import MultiHeadAttention
+from models.transformer.utils import sinusoid_encoding_table, PositionWiseFeedForward
+from models.containers import Module, ModuleList
+
+
+class DecoderLayer(Module):
+ def __init__(self, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, self_att_module=None,
+ enc_att_module=None, self_att_module_kwargs=None, enc_att_module_kwargs=None):
+ super(DecoderLayer, self).__init__()
+ self.self_att = MultiHeadAttention(d_model, d_k, d_v, h, dropout, can_be_stateful=True,
+ attention_module=self_att_module,
+ attention_module_kwargs=self_att_module_kwargs)
+ self.enc_att = MultiHeadAttention(d_model, d_k, d_v, h, dropout, can_be_stateful=False,
+ attention_module=enc_att_module,
+ attention_module_kwargs=enc_att_module_kwargs)
+
+ self.dropout1 = nn.Dropout(dropout)
+ self.lnorm1 = nn.LayerNorm(d_model)
+
+ self.dropout2 = nn.Dropout(dropout)
+ self.lnorm2 = nn.LayerNorm(d_model)
+ self.pwff = PositionWiseFeedForward(d_model, d_ff, dropout)
+
+ def forward(self, input, enc_output, mask_pad, mask_self_att, mask_enc_att):
+ # masked self-attention + Add & Norm
+ self_att = self.self_att(input, input, input, mask_self_att)
+ self_att = self.lnorm1(input + self.dropout1(self_att))
+ self_att = self_att * mask_pad
+ # cross-attention over image (encoder) features + Add & Norm
+ enc_att = self.enc_att(self_att, enc_output, enc_output, mask_enc_att)
+ enc_att = self.lnorm2(self_att + self.dropout2(enc_att))
+ enc_att = enc_att * mask_pad
+
+ ff = self.pwff(enc_att)
+ ff = ff * mask_pad
+ return ff
+
+class TransformerDecoderLayer(Module):
+ def __init__(self, vocab_size, max_len, N_dec, padding_idx, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1,
+ self_att_module=None, enc_att_module=None, self_att_module_kwargs=None, enc_att_module_kwargs=None):
+ super(TransformerDecoderLayer, self).__init__()
+ self.d_model = d_model
+ self.word_emb = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx)
+ self.pos_emb = nn.Embedding.from_pretrained(sinusoid_encoding_table(max_len + 1, d_model, 0), freeze=True)
+ self.layers = ModuleList(
+ [DecoderLayer(d_model, d_k, d_v, h, d_ff, dropout, self_att_module=self_att_module,
+ enc_att_module=enc_att_module, self_att_module_kwargs=self_att_module_kwargs,
+ enc_att_module_kwargs=enc_att_module_kwargs) for _ in range(N_dec)])
+ self.fc = nn.Linear(d_model, vocab_size, bias=False)
+ self.max_len = max_len
+ self.padding_idx = padding_idx
+ self.N = N_dec
+
+ self.register_state('running_mask_self_attention', torch.zeros((1, 1, 0)).byte())
+ self.register_state('running_seq', torch.zeros((1,)).long())
+
+ def forward(self, input, encoder_output, mask_encoder):
+ # input (b_s, seq_len)
+ b_s, seq_len = input.shape[:2]
+ mask_queries = (input != self.padding_idx).unsqueeze(-1).float() # (b_s, seq_len, 1)
+ mask_self_attention = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.uint8, device=input.device),
+ diagonal=1)
+ mask_self_attention = mask_self_attention.unsqueeze(0).unsqueeze(0) # (1, 1, seq_len, seq_len)
+ mask_self_attention = mask_self_attention + (input == self.padding_idx).unsqueeze(1).unsqueeze(1).byte()
+ mask_self_attention = mask_self_attention.gt(0) # (b_s, 1, seq_len, seq_len)
+ if self._is_stateful:
+ self.running_mask_self_attention = torch.cat([self.running_mask_self_attention.type_as(mask_self_attention), mask_self_attention], -1)
+ mask_self_attention = self.running_mask_self_attention
+
+ seq = torch.arange(1, seq_len + 1).view(1, -1).expand(b_s, -1).to(input.device) # (b_s, seq_len)
+ seq = seq.masked_fill(mask_queries.squeeze(-1) == 0, 0)
+ if self._is_stateful:
+ self.running_seq.add_(1)
+ seq = self.running_seq
+
+ out = self.word_emb(input) + self.pos_emb(seq)
+
+ for i, l in enumerate(self.layers):
+ out = l(out, encoder_output, mask_queries, mask_self_attention, mask_encoder)
+
+ out = self.fc(out)
+ return F.log_softmax(out, dim=-1)
\ No newline at end of file
diff --git a/models/transformer/encoders.py b/models/transformer/encoders.py
new file mode 100644
index 0000000..1bcad8d
--- /dev/null
+++ b/models/transformer/encoders.py
@@ -0,0 +1,84 @@
+from torch.nn import functional as F
+from torch.nn.modules.activation import LeakyReLU
+from models.transformer.utils import PositionWiseFeedForward
+import torch
+from torch import nn
+from models.transformer.attention import MultiHeadAttention
+
+
+class SR(nn.Module):
+ def __init__(self, N, d_model=512):
+ super(SR, self).__init__()
+ self.MLP = nn.Sequential(
+ nn.Linear(N*d_model, N*d_model),
+ nn.LeakyReLU(),
+ nn.Linear(N*d_model, d_model),
+ nn.LeakyReLU()
+ )
+ def forward(self, x, layers, attention_mask = None, attention_weights = None):
+ out = x
+ outs = []
+ for l in layers:
+ out = l(out, out, out, attention_mask, attention_weights)
+ outs.append(out)
+ outs = self.MLP(torch.cat(outs, -1))
+ out = 0.2 * outs + out
+ return out
+
+class EncoderLayer(nn.Module):
+ def __init__(self, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, identity_map_reordering=False,
+ attention_module=None, attention_module_kwargs=None):
+ super(EncoderLayer, self).__init__()
+ self.identity_map_reordering = identity_map_reordering
+ self.mhatt = MultiHeadAttention(d_model, d_k, d_v, h, dropout, identity_map_reordering=identity_map_reordering,
+ attention_module=attention_module,
+ attention_module_kwargs=attention_module_kwargs)
+ self.dropout = nn.Dropout(dropout)
+ self.lnorm = nn.LayerNorm(d_model)
+ self.pwff = PositionWiseFeedForward(d_model, d_ff, dropout, identity_map_reordering=identity_map_reordering)
+
+ def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):
+
+ att = self.mhatt(queries, keys, values, attention_mask, attention_weights)
+ att = self.lnorm(queries + self.dropout(att))
+ ff = self.pwff(att)
+ return ff
+
+
+class MultiLevelEncoder(nn.Module):
+ def __init__(self, N, padding_idx, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1,
+ identity_map_reordering=False, attention_module=None, attention_module_kwargs=None):
+ super(MultiLevelEncoder, self).__init__()
+ self.d_model = d_model
+ self.dropout = dropout
+ self.layers = nn.ModuleList([EncoderLayer(d_model, d_k, d_v, h, d_ff, dropout,
+ identity_map_reordering=identity_map_reordering,
+ attention_module=attention_module, # ScaledDotProductAttention
+ attention_module_kwargs=attention_module_kwargs) # {'m': args.m}
+ for _ in range(N)])
+ self.SR = SR(N, d_model)
+ self.padding_idx = padding_idx
+
+ def forward(self, input, attention_weights=None):
+
+ # input (b_s, seq_len, d_in)
+ attention_mask = (torch.sum(input, -1) == self.padding_idx).unsqueeze(1).unsqueeze(1) # (b_s, 1, 1, seq_len)
+ out = self.SR(input, self.layers, attention_mask, attention_weights)
+
+ return out, attention_mask
+
+
+class TransformerEncoder(MultiLevelEncoder):
+ def __init__(self, N, padding_idx, d_in=2048, **kwargs):
+ super(TransformerEncoder, self).__init__(N, padding_idx, **kwargs)
+ self.fc = nn.Linear(d_in, self.d_model)
+ self.dropout = nn.Dropout(p=self.dropout)
+ self.layer_norm = nn.LayerNorm(self.d_model)
+
+ def forward(self, input, attention_weights=None):
+ mask = (torch.sum(input, dim=-1) == 0).unsqueeze(-1)
+ out = F.relu(self.fc(input))
+ out = self.dropout(out)
+ out = self.layer_norm(out)
+ out = out.masked_fill(mask, 0)
+ return super(TransformerEncoder, self).forward(out, attention_weights=attention_weights)
diff --git a/models/transformer/transformer.py b/models/transformer/transformer.py
new file mode 100644
index 0000000..1d1e40a
--- /dev/null
+++ b/models/transformer/transformer.py
@@ -0,0 +1,155 @@
+from matplotlib import image
+import torch
+import torch.nn.functional as F
+from torch import nn
+import copy
+from models.containers import ModuleList
+from models.transformer.utils import sinusoid_encoding_table
+from models.beam_search import *
+from ..captioning_model import CaptioningModel
+
+
+class SP(nn.Module):
+ """SP layer implementation
+
+ Args:
+ num_regions : int
+ The number of pseudo regions
+ dim : int
+ Dimension of each pseudo-region descriptor
+ alpha : float
+ Initialization parameter; a larger value gives a harder assignment.
+ normalize_input : bool
+ If True, descriptor-wise L2 normalization is applied to the input.
+ """
+ def __init__(self, num_regions=64, dim=128, alpha=100.0, normalize_input=True):
+ super().__init__()
+ self.num_regions = num_regions
+ self.dim = dim
+ self.alpha = alpha
+ self.normalize_input = normalize_input
+ self.conv = nn.Conv2d(dim, num_regions, kernel_size=(1, 1), bias=True)
+ self.centroids = nn.Parameter(torch.rand(num_regions, dim))
+ self.init_weights()
+ def init_weights(self):
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+ def forward(self, grids):
+
+ N, C = grids.shape[0], grids.shape[-1]
+
+ grids = grids.view(N, 7, 7, -1).permute(0,3,1,2).contiguous()
+
+ if self.normalize_input:
+ grids = F.normalize(grids, p=2, dim=1) # across descriptor dim
+
+ soft_assign = self.conv(grids).view(N, self.num_regions, -1)
+ soft_assign = F.softmax(soft_assign, dim=1)
+
+ x_flatten = grids.view(N, C, -1)
+
+ residual = x_flatten.expand(self.num_regions, -1, -1, -1).permute(1, 0, 2, 3).contiguous() - \
+ self.centroids.expand(x_flatten.size(-1), -1, -1).permute(1, 2, 0).contiguous().unsqueeze(0)
+
+ residual *= soft_assign.unsqueeze(2)
+ p = residual.sum(dim=-1)
+
+ p = F.normalize(p, p=2, dim=2) # intra-normalization
+ p = p.view(grids.size(0), -1)
+ p = F.normalize(p, p=2, dim=1) # L2 normalize
+
+ return p
+
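+# Shape sketch (illustrative): for grid features of shape (N, 49, 2048) the soft
+# assignment is (N, num_regions, 49) and the returned descriptor p has shape
+# (N, num_regions * 2048), which Transformer.forward reshapes into
+# (N, num_clusters, 2048) pseudo-region features.
+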
+class Transformer(CaptioningModel):
+ def __init__(self, bos_idx, encoder, decoder, num_clusters, vocab_size, max_len, padding_idx, text_d_model=512):
+ super(Transformer, self).__init__()
+ self.bos_idx = bos_idx
+ self.encoder = encoder
+ self.decoder = decoder
+ self.text_d_model = text_d_model
+ self.num_clusters = num_clusters
+ self.padding_idx = padding_idx
+ self.word_emb = nn.Embedding(vocab_size, text_d_model, padding_idx=padding_idx)
+ self.pos_emb = nn.Embedding.from_pretrained(sinusoid_encoding_table(max_len + 1, text_d_model, 0), freeze=True)
+
+ self.SP = SP(num_regions=self.num_clusters, dim=2048)
+
+ self.softmax = nn.Softmax(dim=-1)
+ self.register_state('enc_output', None)
+ self.register_state('mask_enc', None)
+ self.init_weights()
+
+ @property
+ def d_model(self):
+ return self.decoder.d_model
+
+ def init_weights(self):
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+
+ def forward(self, mode, images, seq=None, max_len=None, eos_idx=None, beam_size=None, out_size=1, return_probs=False):
+ '''
+ images: torch.Size([50, 49, 2048])
+ seq: torch.Size([50, 27])
+ '''
+ if mode == 'xe':
+ bs, _, vis_dim = images.size()
+ # Grid feature
+ grid_enc_output, grid_mask_enc = self.encoder(images)
+
+ # Pseudo-region feature
+ pseudo_region = self.SP(images).view(bs, -1, vis_dim) # (N, num_clusters*2048) -> (N, num_clusters, 2048)
+ pseudo_region_enc_output, pseudo_region_mask_enc = self.encoder(pseudo_region)
+
+ output, mask = torch.cat([grid_enc_output, pseudo_region_enc_output],dim=1), torch.cat([grid_mask_enc, pseudo_region_mask_enc], dim=-1)
+ dec_output = self.decoder(seq, output, mask)
+
+ return dec_output
+
+ elif mode == 'rl':
+ bs = BeamSearch(self, max_len, eos_idx, beam_size)
+ return bs.apply(images, out_size, return_probs)
+
+ def init_state(self, b_s, device):
+ return [torch.zeros((b_s, 0), dtype=torch.long, device=device),
+ None, None]
+
+ def step(self, t, prev_output, visual, seq, mode='teacher_forcing', **kwargs):
+ it = None
+ if mode == 'teacher_forcing':
+ raise NotImplementedError
+ elif mode == 'feedback':
+ if t == 0:
+ grid_enc_output, grid_mask_enc = self.encoder(visual)
+ bs, _, vis_dim = visual.size()
+ pseudo_region = self.SP(visual).view(bs, -1, vis_dim)
+ pseudo_region_enc_output, pseudo_region_mask_enc = self.encoder(pseudo_region)
+ self.enc_output, self.mask_enc = torch.cat([grid_enc_output, pseudo_region_enc_output],dim=1), torch.cat([grid_mask_enc, pseudo_region_mask_enc], dim=-1)
+
+ if isinstance(visual, torch.Tensor):
+ it = visual.data.new_full((visual.shape[0], 1), self.bos_idx).long() # self.bos_idx: index of '<bos>'
+ else:
+ it = visual[0].data.new_full((visual[0].shape[0], 1), self.bos_idx).long()
+ else:
+ it = prev_output
+ return self.decoder(it, self.enc_output, self.mask_enc)
+
+
+class TransformerEnsemble(CaptioningModel):
+ def __init__(self, model: Transformer, weight_files):
+ super(TransformerEnsemble, self).__init__()
+ self.n = len(weight_files)
+ self.models = ModuleList([copy.deepcopy(model) for _ in range(self.n)])
+ for i in range(self.n):
+ state_dict_i = torch.load(weight_files[i])['state_dict']
+ self.models[i].load_state_dict(state_dict_i)
+
+ def step(self, t, prev_output, visual, seq, mode='teacher_forcing', **kwargs):
+ out_ensemble = []
+ for i in range(self.n):
+ out_i = self.models[i].step(t, prev_output, visual, seq, mode, **kwargs)
+ out_ensemble.append(out_i.unsqueeze(0))
+
+ return torch.mean(torch.cat(out_ensemble, 0), dim=0)
diff --git a/models/transformer/utils.py b/models/transformer/utils.py
new file mode 100644
index 0000000..dc01168
--- /dev/null
+++ b/models/transformer/utils.py
@@ -0,0 +1,50 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+
+def position_embedding(input, d_model):
+ input = input.view(-1, 1)
+ dim = torch.arange(d_model // 2, dtype=torch.float32, device=input.device).view(1, -1)
+ sin = torch.sin(input / 10000 ** (2 * dim / d_model))
+ cos = torch.cos(input / 10000 ** (2 * dim / d_model))
+
+ out = torch.zeros((input.shape[0], d_model), device=input.device)
+ out[:, ::2] = sin
+ out[:, 1::2] = cos
+ return out
+
+
+def sinusoid_encoding_table(max_len, d_model, padding_idx=None):
+ pos = torch.arange(max_len, dtype=torch.float32)
+ out = position_embedding(pos, d_model)
+
+ if padding_idx is not None:
+ out[padding_idx] = 0
+ return out
+
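+# Illustrative usage (values taken from how the decoder builds its positional
+# embedding): sinusoid_encoding_table(54 + 1, 512, padding_idx=0) returns a
+# (55, 512) table with row 0 zeroed out for padding.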
+
+class PositionWiseFeedForward(nn.Module):
+ '''
+ Position-wise feed forward layer
+ '''
+
+ def __init__(self, d_model=512, d_ff=2048, dropout=.1, identity_map_reordering=False):
+ super(PositionWiseFeedForward, self).__init__()
+ self.identity_map_reordering = identity_map_reordering
+ self.fc1 = nn.Linear(d_model, d_ff)
+ self.fc2 = nn.Linear(d_ff, d_model)
+ self.dropout = nn.Dropout(p=dropout)
+ self.dropout_2 = nn.Dropout(p=dropout)
+ self.layer_norm = nn.LayerNorm(d_model)
+
+ def forward(self, input):
+ if self.identity_map_reordering:
+ out = self.layer_norm(input)
+ out = self.fc2(self.dropout_2(F.relu(self.fc1(out))))
+ out = input + self.dropout(torch.relu(out))
+ else:
+ out = self.fc2(self.dropout_2(F.relu(self.fc1(input))))
+ out = self.dropout(out)
+ out = self.layer_norm(input + out)
+ return out
diff --git a/test_transformer.py b/test_transformer.py
new file mode 100644
index 0000000..90b2aeb
--- /dev/null
+++ b/test_transformer.py
@@ -0,0 +1,93 @@
+import random
+import os
+from data import ImageDetectionsField, TextField, RawField
+from data import COCO, DataLoader
+import evaluation
+from models.transformer import Transformer, TransformerEncoder, TransformerDecoderLayer, ScaledDotProductAttention
+from visualize import visualize_grid_attention_v2
+import torch
+from tqdm import tqdm
+import argparse
+import pickle
+import numpy as np
+import time
+
+random.seed(1234)
+torch.manual_seed(1234)
+np.random.seed(1234)
+
+
+def predict_captions(model, dataloader, text_field):
+ import itertools
+ model.eval()
+ seq_len = 20
+ beam_size = 5
+ gen = {}
+ gts = {}
+ with tqdm(desc='Evaluation', unit='it', total=len(dataloader)) as pbar:
+ for it, (images, caps_gt, _) in enumerate(iter(dataloader)):
+ images = images.to(device)
+ with torch.no_grad():
+ out, _ = model(mode='rl', images=images, max_len=seq_len, eos_idx=text_field.vocab.stoi['<eos>'], beam_size=beam_size, out_size=1)
+ # print(out.size(), att_map.size())
+ caps_gen = text_field.decode(out, join_words=False)
+ for i, (gts_i, gen_i) in enumerate(zip(caps_gt, caps_gen)):
+ gen_i = ' '.join([k for k, g in itertools.groupby(gen_i)])
+ gen['%d_%d' % (it, i)] = [gen_i.strip(), ]
+ gts['%d_%d' % (it, i)] = gts_i
+ pbar.update()
+
+ gts = evaluation.PTBTokenizer.tokenize(gts)
+ gen = evaluation.PTBTokenizer.tokenize(gen)
+ scores, _ = evaluation.compute_scores(gts, gen)
+
+ return scores
+
+
+if __name__ == '__main__':
+ start_time = time.time()
+ device = torch.device('cuda')
+
+ parser = argparse.ArgumentParser(description='Transformer')
+ parser.add_argument('--batch_size', type=int, default=10)
+ parser.add_argument('--workers', type=int, default=4)
+ parser.add_argument('--m', type=int, default=40)
+
+ parser.add_argument('--features_path', type=str, default='/home/zhanghaonan/RSTNet-master/X101-features/X101_grid_feats_coco_trainval.hdf5')
+ parser.add_argument('--annotation_folder', type=str, default='/home/zhanghaonan/RSTNet-master/m2_annotations')
+
+ # the path of tested model and vocabulary
+ parser.add_argument('--model_path', type=str, default='saved_transformer_models/demo_rl_v5_best_test.pth')
+ parser.add_argument('--vocab_path', type=str, default='vocab.pkl')
+ parser.add_argument('--num_clusters', type=int, default=5)
+ args = parser.parse_args()
+
+ print('Transformer Evaluation')
+
+ # Pipeline for image regions
+ image_field = ImageDetectionsField(detections_path=args.features_path, max_detections=49, load_in_tmp=False)
+
+ # Pipeline for text
+ text_field = TextField(init_token='<bos>', eos_token='<eos>', lower=True, tokenize='spacy',
+ remove_punctuation=True, nopoints=False)
+
+ # Create the dataset
+ dataset = COCO(image_field, text_field, 'coco/images/', args.annotation_folder, args.annotation_folder)
+ _, _, test_dataset = dataset.splits
+ text_field.vocab = pickle.load(open(args.vocab_path, 'rb'))
+
+ # Model and dataloaders
+ encoder = TransformerEncoder(3, 0, attention_module=ScaledDotProductAttention, attention_module_kwargs={'m': args.m})
+ decoder = TransformerDecoderLayer(len(text_field.vocab), 54, 3, text_field.vocab.stoi['<pad>'])
+
+ model = Transformer(text_field.vocab.stoi['<bos>'], encoder, decoder, args.num_clusters, len(text_field.vocab), 54, text_field.vocab.stoi['<pad>'], 512).to(device)
+
+ data = torch.load(args.model_path)
+ model.load_state_dict({k.replace('module.',''):v for k,v in data['state_dict'].items()})
+
+ dict_dataset_test = test_dataset.image_dictionary({'image': image_field, 'text': RawField(), 'add_text':text_field})
+ dict_dataloader_test = DataLoader(dict_dataset_test, batch_size=args.batch_size, num_workers=args.workers)
+
+ scores = predict_captions(model, dict_dataloader_test, text_field)
+ print(scores)
+ print('it costs {} s to test.'.format(time.time() - start_time))
diff --git a/train.sh b/train.sh
new file mode 100644
index 0000000..46d538d
--- /dev/null
+++ b/train.sh
@@ -0,0 +1 @@
+python train_transformer.py --exp_name demo --num_clusters 5
diff --git a/train_transformer.py b/train_transformer.py
new file mode 100644
index 0000000..8abadf1
--- /dev/null
+++ b/train_transformer.py
@@ -0,0 +1,485 @@
+import random
+import os
+# os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7"
+from torch.nn.modules.loss import MSELoss
+from data import ImageDetectionsField, TextField, RawField
+from data import COCO, DataLoader
+import evaluation
+from evaluation import PTBTokenizer, Cider
+
+from models.transformer import Transformer, TransformerEncoder, TransformerDecoderLayer, ScaledDotProductAttention
+
+
+import torch
+from torch.optim import Adam
+from torch.optim.lr_scheduler import LambdaLR
+from torch.nn import NLLLoss, MSELoss
+from tqdm import tqdm
+from torch.utils.tensorboard import SummaryWriter
+import argparse
+
+
+import pickle
+import numpy as np
+import itertools
+from shutil import copyfile
+
+import torch.multiprocessing as mp
+import torch.distributed as dist
+from torch.utils.data import DistributedSampler
+
+def evaluate_loss(model, dataloader, loss_fn, text_field, e, device):
+
+ # Validation loss
+ model.eval()
+ running_loss = .0
+ with tqdm(desc='Epoch %d - validation' % e, unit='it', total=len(dataloader), disable=device!=0) as pbar:
+ with torch.no_grad():
+ for it, (detections, captions) in enumerate(dataloader):
+ detections, captions = detections.to(device), captions.to(device)
+ out = model(mode='xe', images=detections, seq=captions)
+ captions_gt = captions[:, 1:].contiguous()
+ out = out[:, :-1].contiguous()
+
+ loss = loss_fn[0](out.view(-1, len(text_field.vocab)), captions_gt.view(-1))
+ this_loss = loss.item()
+ running_loss += this_loss
+
+ pbar.set_postfix(loss=running_loss / (it + 1))
+ pbar.update()
+
+ val_loss = running_loss / len(dataloader)
+ return val_loss
+
+def evaluate_metrics(model, dataloader, text_field, e, device):
+ import itertools
+ model.eval()
+ seq_len = 20
+ beam_size = 5
+ gen = {}
+ gts = {}
+
+ with tqdm(desc='Epoch %d - evaluation' % e, unit='it', total=len(dataloader), disable=device!=0) as pbar:
+ for it, (images, caps_gt, captions) in enumerate(iter(dataloader)):
+ images = images.to(device)
+ with torch.no_grad():
+ out, _ = model(mode='rl', images=images, max_len=seq_len, eos_idx=text_field.vocab.stoi['<eos>'], beam_size=beam_size, out_size=1)
+ caps_gen = text_field.decode(out, join_words=False)
+ for i, (gts_i, gen_i) in enumerate(zip(caps_gt, caps_gen)):
+ gen_i = ' '.join([k for k, g in itertools.groupby(gen_i)])
+ gen['%d_%d' % (it, i)] = [gen_i, ]
+ gts['%d_%d' % (it, i)] = gts_i
+ pbar.update()
+
+ gts = evaluation.PTBTokenizer.tokenize(gts)
+ gen = evaluation.PTBTokenizer.tokenize(gen)
+ scores, _ = evaluation.compute_scores(gts, gen)
+ return scores
+
+def train_xe(model, dataloader, optim, text_field, scheduler, loss_fn, e, device):
+ # Training with cross-entropy
+ model.train()
+ scheduler.step()
+ if device == 0:
+ print('lr = ', optim.state_dict()['param_groups'][0]['lr'])
+
+ running_loss = .0
+ with tqdm(desc='Epoch %d - train' % e, unit='it', total=len(dataloader), disable=device!=0) as pbar:
+ for it, (detections, captions) in enumerate(dataloader):
+ detections, captions = detections.to(device), captions.to(device)
+ out = model(mode='xe', images=detections, seq=captions)
+ optim.zero_grad()
+ captions_gt = captions[:, 1:].contiguous()
+ out = out[:, :-1].contiguous()
+
+ loss = loss_fn(out.view(-1, len(text_field.vocab)), captions_gt.view(-1))
+
+ loss.backward()
+
+ optim.step()
+ this_loss = loss.item()
+ running_loss += this_loss
+
+ pbar.set_postfix(loss=running_loss / (it + 1))
+ pbar.update()
+
+ # scheduler.step()
+
+ loss = running_loss / len(dataloader)
+ return loss
+
+
+def train_scst(model, dataloader, optim, cider, text_field, scheduler_rl, e, device):
+ # Training with self-critical
+ # tokenizer_pool = multiprocessing.Pool()
+ running_reward = .0
+ running_reward_baseline = .0
+
+ model.train()
+ scheduler_rl.step()
+ if device == 0:
+ print('lr = ', optim.state_dict()['param_groups'][0]['lr'])
+
+ running_loss = .0
+ seq_len = 20
+ beam_size = 5
+ # kwargs = {
+ # 'text_flag': args.text2text
+ # }
+ with tqdm(desc='Epoch %d - train' % e, unit='it', total=len(dataloader), disable=device!=0) as pbar:
+ for it, (detections, caps_gt, captions) in enumerate(dataloader):
+ detections = detections.to(device)
+ text = captions.to(device)
+ # kwargs['text'] = text
+ outs, log_probs = model(mode='rl', images=detections, max_len=seq_len, eos_idx=text_field.vocab.stoi['<eos>'], beam_size=beam_size, out_size=beam_size)
+ optim.zero_grad()
+ # Rewards
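+ # Each image yields beam_size candidate captions; the CIDEr score of each
+ # candidate is its reward, and the per-image mean reward over the beam is the
+ # self-critical baseline, so only above-average candidates are reinforced.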
+ caps_gen = text_field.decode(outs.view(-1, seq_len))
+ caps_gt = list(itertools.chain(*([c, ] * beam_size for c in caps_gt)))
+ # caps_gen, caps_gt = tokenizer_pool.map(evaluation.PTBTokenizer.tokenize, [caps_gen, caps_gt])
+ caps_gen = evaluation.PTBTokenizer.tokenize(caps_gen)
+ caps_gt = evaluation.PTBTokenizer.tokenize(caps_gt)
+ reward = cider.compute_score(caps_gt, caps_gen)[1].astype(np.float32)
+ reward = torch.from_numpy(reward).to(device).view(detections.shape[0], beam_size)
+ reward_baseline = torch.mean(reward, -1, keepdim=True)
+ loss = -torch.mean(log_probs, -1) * (reward - reward_baseline)
+
+ loss = loss.mean()
+ loss.backward()
+ optim.step()
+
+ running_loss += loss.item()
+ running_reward += reward.mean().item()
+ running_reward_baseline += reward_baseline.mean().item()
+ # pbar.set_postfix(loss=running_loss / (it + 1), reward=running_reward / (it + 1),
+ # reward_baseline=running_reward_baseline / (it + 1))
+ pbar.update()
+ # scheduler_rl.step()
+
+ loss = running_loss / len(dataloader)
+ reward = running_reward / len(dataloader)
+ reward_baseline = running_reward_baseline / len(dataloader)
+ # tokenizer_pool.close()
+ return loss, reward, reward_baseline
+
+
+def _changeConfig(config, worldSize):
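+    # Linear scaling rule: the effective batch size grows with the number of
+    # GPUs, so both base learning rates are scaled by worldSize.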
+ batchSize = config.batch_size * worldSize
+ # exponent = math.log2(batchSize)
+ # scale = 3 - exponent / 2
+ # config.xe_base_lr /= (2 ** scale)
+ # config.rl_base_lr /= (2 ** scale)
+ config.xe_base_lr *= worldSize
+ config.rl_base_lr *= worldSize
+
+def _generalConfig(rank: int, worldSize: int):
+ os.environ["MASTER_ADDR"] = "localhost"
+ os.environ["MASTER_PORT"] = "2227" #任意一个没被占用的端口号
+ torch.autograd.set_detect_anomaly(False)
+ torch.backends.cudnn.benchmark = True
+ random.seed(1234)
+ torch.manual_seed(1234)
+ np.random.seed(1234)
+ torch.cuda.set_device(rank)
+ dist.init_process_group("nccl", world_size=worldSize, rank=rank)
+
+
+def train(rank, worldSize, args):
+ _generalConfig(rank, worldSize)
+ if rank == 0:
+ print('Rank{}: Transformer Training'.format(rank))
+ if rank == 0:
+ writer = SummaryWriter(log_dir=os.path.join(args.logs_folder, args.exp_name))
+
+ # Pipeline for image regions
+ image_field = ImageDetectionsField(detections_path=args.features_path, max_detections=49, load_in_tmp=False)
+ # Pipeline for text
+ text_field = TextField(init_token='<bos>', eos_token='<eos>', lower=True, tokenize='spacy', remove_punctuation=True, nopoints=False)
+
+ # Create the dataset
+ dataset = COCO(image_field, text_field, 'coco/images/', args.annotation_folder, args.annotation_folder)
+ train_dataset, val_dataset, test_dataset = dataset.splits
+
+ if not os.path.isfile('vocab.pkl'):
+ print("Rank{}: Building vocabulary".format(rank))
+ text_field.build_vocab(train_dataset, val_dataset, min_freq=5)
+ pickle.dump(text_field.vocab, open('vocab.pkl', 'wb'))
+ else:
+ print('Rank{}: Loading from vocabulary'.format(rank))
+ text_field.vocab = pickle.load(open('vocab.pkl', 'rb'))
+
+ # Model and dataloaders
+ encoder = TransformerEncoder(3, 0, attention_module=ScaledDotProductAttention, attention_module_kwargs={'m': args.m})
+ decoder = TransformerDecoderLayer(len(text_field.vocab), 54, 3, text_field.vocab.stoi['<pad>'])
+
+ model = Transformer(text_field.vocab.stoi['<bos>'], encoder, decoder, args.num_clusters, len(text_field.vocab), 54, text_field.vocab.stoi['<pad>'], 512)
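+ # Wrap with DDP; find_unused_parameters=True tolerates sub-modules that are
+ # skipped in some forward passes (e.g. the XE and beam-search paths differ).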
+ model = torch.nn.parallel.DistributedDataParallel(model.to(rank), device_ids=[rank], output_device=rank, broadcast_buffers=False, find_unused_parameters=True)
+
+ dict_dataset_train = train_dataset.image_dictionary({'image': image_field, 'text': RawField(), 'add_text':text_field})
+ ref_caps_train = train_dataset.text()
+ cider_train = Cider(PTBTokenizer.tokenize(ref_caps_train))
+ dict_dataset_val = val_dataset.image_dictionary({'image': image_field, 'text': RawField(), 'add_text':text_field})
+ dict_dataset_test = test_dataset.image_dictionary({'image': image_field, 'text': RawField(), 'add_text':text_field})
+
+ '''
+ def lambda_lr(s):
+ warm_up = args.warmup
+ s += 1
+ return (model.d_model ** -.5) * min(s ** -.5, s * warm_up ** -1.5)
+ '''
+
+ def lambda_lr(s):
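+ # Piece-wise XE schedule: linear warm-up over the first 3 epochs, the base
+ # LR until epoch 10, then two successive 5x decays.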
+ print("s:", s)
+ if s <= 3:
+ lr = args.xe_base_lr * s / 4
+ elif s <= 10:
+ lr = args.xe_base_lr
+ elif s <= 12:
+ lr = args.xe_base_lr * 0.2
+ else:
+ lr = args.xe_base_lr * 0.2 * 0.2
+ return lr
+
+ def lambda_lr_rl(s):
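+ # SCST schedule: keep rl_base_lr until refine_epoch_rl, then decay by 5x
+ # after every further 3 epochs (at most three decays).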
+ refine_epoch = args.refine_epoch_rl
+ print("rl_s:", s)
+ if s <= refine_epoch:
+ lr = args.rl_base_lr
+ elif s <= refine_epoch + 3:
+ lr = args.rl_base_lr * 0.2
+ elif s <= refine_epoch + 6:
+ lr = args.rl_base_lr * 0.2 * 0.2
+ else:
+ lr = args.rl_base_lr * 0.2 * 0.2 * 0.2
+ return lr
+
+ # Initial conditions
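+ # Adam is created with lr=1 so that the multiplier returned by
+ # lambda_lr / lambda_lr_rl is used directly as the absolute learning rate.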
+ optim = Adam(model.parameters(), lr=1, betas=(0.9, 0.98))
+ scheduler = LambdaLR(optim, lambda_lr)
+
+ optim_rl = Adam(model.parameters(), lr=1, betas=(0.9, 0.98))
+ scheduler_rl = LambdaLR(optim_rl, lambda_lr_rl)
+
+ loss_fn = NLLLoss(ignore_index=text_field.vocab.stoi['<pad>'])
+ loss_align = MSELoss()
+ loss = (loss_fn, loss_align)
+ use_rl = False
+ best_cider = .0
+ best_test_cider = 0.
+ patience = 0
+ start_epoch = 0
+
+ if args.resume_last or args.resume_best:
+ if args.resume_last:
+ fname = 'saved_transformer_models/%s_last.pth' % args.exp_name
+ else:
+ fname = 'saved_transformer_models/%s_best.pth' % args.exp_name
+
+ # fname = 'saved_transformer_models/align_share_K5_init_vlad_last.pth'
+
+ if os.path.exists(fname):
+ data = torch.load(fname)
+ torch.set_rng_state(data['torch_rng_state'])
+ torch.cuda.set_rng_state(data['cuda_rng_state'])
+ np.random.set_state(data['numpy_rng_state'])
+ random.setstate(data['random_rng_state'])
+ model.load_state_dict(data['state_dict'], strict=False)
+ """
+ optim.load_state_dict(data['optimizer'])
+ scheduler.load_state_dict(data['scheduler'])
+ """
+ start_epoch = data['epoch'] + 1
+ best_cider = data['best_cider']
+ best_test_cider = data['best_test_cider']
+ patience = data['patience']
+ use_rl = data['use_rl']
+
+ # The checkpoint stores only the optimizer/scheduler of the stage that was
+ # active when it was saved (see torch.save below), so load it accordingly.
+ if not use_rl:
+ optim.load_state_dict(data['optimizer'])
+ scheduler.load_state_dict(data['scheduler'])
+ else:
+ optim_rl.load_state_dict(data['optimizer'])
+ scheduler_rl.load_state_dict(data['scheduler'])
+
+ print('Resuming from epoch %d, validation loss %f, best cider %f, and best_test_cider %f' % (
+ data['epoch'], data['val_loss'], data['best_cider'], data['best_test_cider']))
+ print('patience:', data['patience'])
+
+ print("Training starts")
+ for e in range(start_epoch, start_epoch + 100):
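+ # Re-create the distributed samplers each epoch and seed them with the epoch
+ # index, so every rank gets a disjoint shard and the shuffle changes per epoch.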
+ trainSampler = DistributedSampler(train_dataset, worldSize, rank)
+ trainSampler.set_epoch(e)
+ dataloader_train = DataLoader(train_dataset, sampler=trainSampler, batch_size=args.batch_size, pin_memory=True, drop_last=False, num_workers=args.workers, persistent_workers=True)
+
+ dataloader_val = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
+
+ dict_trainSampler = DistributedSampler(dict_dataset_train, worldSize, rank)
+ dict_trainSampler.set_epoch(e)
+ dict_dataloader_train = DataLoader(dict_dataset_train, sampler=dict_trainSampler, batch_size=args.batch_size // 5, pin_memory=True, drop_last=False, num_workers=args.workers, persistent_workers=True)
+
+ dict_dataloader_val = DataLoader(dict_dataset_val, batch_size=args.batch_size // 5)
+ dict_dataloader_test = DataLoader(dict_dataset_test, batch_size=args.batch_size // 5)
+
+ if not use_rl:
+ train_loss = train_xe(model, dataloader_train, optim, text_field, scheduler, loss_fn, e, rank)
+ if rank == 0:
+ writer.add_scalar('data/train_loss', train_loss, e)
+ else:
+ train_loss, reward, reward_baseline = train_scst(model, dict_dataloader_train, optim_rl, cider_train, text_field, scheduler_rl, e, rank)
+ if rank == 0:
+ writer.add_scalar('data/train_loss', train_loss, e)
+ writer.add_scalar('data/reward', reward, e)
+ writer.add_scalar('data/reward_baseline', reward_baseline, e)
+
+ # Validation loss
+ val_loss = evaluate_loss(model, dataloader_val, loss, text_field, e, rank)
+ if rank == 0:
+ writer.add_scalar('data/val_loss', val_loss, e)
+
+ # Validation scores
+ scores = evaluate_metrics(model, dict_dataloader_val, text_field, e, rank)
+ val_cider = scores['CIDEr']
+ if rank == 0:
+ print("Validation scores", scores)
+ writer.add_scalar('data/val_cider', val_cider, e)
+ writer.add_scalar('data/val_bleu1', scores['BLEU'][0], e)
+ writer.add_scalar('data/val_bleu4', scores['BLEU'][3], e)
+ writer.add_scalar('data/val_meteor', scores['METEOR'], e)
+ writer.add_scalar('data/val_rouge', scores['ROUGE'], e)
+
+ # Test scores
+ scores = evaluate_metrics(model, dict_dataloader_test, text_field, e, rank)
+ test_cider = scores['CIDEr']
+ if rank == 0:
+ print("Test scores", scores)
+ writer.add_scalar('data/test_cider', test_cider, e)
+ writer.add_scalar('data/test_bleu1', scores['BLEU'][0], e)
+ writer.add_scalar('data/test_bleu4', scores['BLEU'][3], e)
+ writer.add_scalar('data/test_meteor', scores['METEOR'], e)
+ writer.add_scalar('data/test_rouge', scores['ROUGE'], e)
+
+ # Prepare for next epoch
+ best = False
+ if val_cider >= best_cider:
+ best_cider = val_cider
+ patience = 0
+ best = True
+ else:
+ patience += 1
+
+ best_test = False
+ if test_cider >= best_test_cider:
+ best_test_cider = test_cider
+ best_test = True
+
+ switch_to_rl = False
+ exit_train = False
+
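+ # Patience-based stage control: after 5 epochs without a CIDEr improvement,
+ # keep training with XE if fewer than xe_least epochs have run, switch to
+ # SCST if still in the XE stage, otherwise stop training.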
+ if patience == 5:
+ if e < args.xe_least: # xe stage train 15 epoches at least
+ if rank == 0:
+ print('special treatment, e = {}'.format(e))
+ use_rl = False
+ switch_to_rl = False
+ patience = 0
+ elif not use_rl:
+ use_rl = True
+ switch_to_rl = True
+ patience = 0
+
+ optim_rl = Adam(model.parameters(), lr=1, betas=(0.9, 0.98))
+ scheduler_rl = LambdaLR(optim_rl, lambda_lr_rl)
+
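+ # Step the fresh RL scheduler e-1 times so its epoch counter matches the
+ # number of epochs already completed.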
+ for k in range(e-1):
+ scheduler_rl.step()
+ if rank == 0:
+ print("Switching to RL")
+ else:
+ if rank == 0:
+ print('patience reached.')
+ exit_train = True
+
+ if e == args.xe_most: # the XE stage lasts at most xe_most (default 20) epochs
+ if not use_rl:
+ use_rl = True
+ switch_to_rl = True
+ patience = 0
+
+ optim_rl = Adam(model.parameters(), lr=1, betas=(0.9, 0.98))
+ scheduler_rl = LambdaLR(optim_rl, lambda_lr_rl)
+
+ for k in range(e-1):
+ scheduler_rl.step()
+ if rank == 0:
+ print("Switching to RL")
+ if rank == 0:
+ if switch_to_rl and not best:
+ data = torch.load('saved_transformer_models/%s_best.pth' % args.exp_name)
+ torch.set_rng_state(data['torch_rng_state'])
+ torch.cuda.set_rng_state(data['cuda_rng_state'])
+ np.random.set_state(data['numpy_rng_state'])
+ random.setstate(data['random_rng_state'])
+ model.load_state_dict(data['state_dict'])
+ print('Resuming from epoch %d, validation loss %f, best_cider %f, and best test_cider %f' % (
+ data['epoch'], data['val_loss'], data['best_cider'], data['best_test_cider']))
+
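+ # The 'last' checkpoint stores the RNG states plus the optimizer/scheduler of
+ # the currently active stage, so training can be resumed deterministically.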
+ torch.save({
+ 'torch_rng_state': torch.get_rng_state(),
+ 'cuda_rng_state': torch.cuda.get_rng_state(),
+ 'numpy_rng_state': np.random.get_state(),
+ 'random_rng_state': random.getstate(),
+ 'epoch': e,
+ 'val_loss': val_loss,
+ 'val_cider': val_cider,
+ 'state_dict': model.state_dict(),
+ 'optimizer': optim.state_dict() if not use_rl else optim_rl.state_dict(),
+ 'scheduler': scheduler.state_dict() if not use_rl else scheduler_rl.state_dict(),
+ 'patience': patience,
+ 'best_cider': best_cider,
+ 'best_test_cider': best_test_cider,
+ 'use_rl': use_rl,
+ }, 'saved_transformer_models/%s_last.pth' % args.exp_name)
+
+ if best:
+ copyfile('saved_transformer_models/%s_last.pth' % args.exp_name, 'saved_transformer_models/%s_best.pth' % args.exp_name)
+ if best_test:
+ copyfile('saved_transformer_models/%s_last.pth' % args.exp_name, 'saved_transformer_models/%s_best_test.pth' % args.exp_name)
+
+ if exit_train:
+ if rank==0:
+ writer.close()
+ break
+
+
+if __name__ == '__main__':
+ # device = torch.device('cuda')
+
+ parser = argparse.ArgumentParser(description='Transformer')
+ parser.add_argument('--exp_name', type=str, default='demo')
+ parser.add_argument('--batch_size', type=int, default=50)
+ parser.add_argument('--workers', type=int, default=4)
+ parser.add_argument('--m', type=int, default=40)
+ parser.add_argument('--head', type=int, default=8)
+ parser.add_argument('--warmup', type=int, default=10000)
+ parser.add_argument('--resume_last', action='store_true')
+ parser.add_argument('--resume_best', action='store_true')
+ parser.add_argument('--features_path', type=str, default='./X101-features/X101_grid_feats_coco_trainval.hdf5')
+ parser.add_argument('--annotation_folder', type=str, default='./m2_annotations')
+
+ parser.add_argument('--logs_folder', type=str, default='tensorboard_logs')
+ parser.add_argument('--xe_least', type=int, default=15)
+ parser.add_argument('--xe_most', type=int, default=20) # 18
+ parser.add_argument('--refine_epoch_rl', type=int, default=28) # 35
+
+ parser.add_argument('--xe_base_lr', type=float, default=0.0001)
+ parser.add_argument('--rl_base_lr', type=float, default=5e-6)
+ parser.add_argument('--num_clusters', type=int, default=5)
+ parser.add_argument('--text2text', type=int, default=0)
+
+ args = parser.parse_args()
+ print(args)
+ ## DDP Training
+ worldSize = 8
+ _changeConfig(args, worldSize)
+ print('\nDistribute config', args)
+ mp.spawn(train, (worldSize, args), worldSize)
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000..0162b87
--- /dev/null
+++ b/utils/__init__.py
@@ -0,0 +1,17 @@
+from .utils import download_from_url
+from .typing import *
+
+def get_batch_size(x: TensorOrSequence) -> int:
+ if isinstance(x, torch.Tensor):
+ b_s = x.size(0)
+ else:
+ b_s = x[0].size(0)
+ return b_s
+
+
+def get_device(x: TensorOrSequence) -> torch.device:
+ # Return the device of a tensor, or of the first tensor in a sequence.
+ if isinstance(x, torch.Tensor):
+ device = x.device
+ else:
+ device = x[0].device
+ return device
diff --git a/utils/typing.py b/utils/typing.py
new file mode 100644
index 0000000..3270f86
--- /dev/null
+++ b/utils/typing.py
@@ -0,0 +1,5 @@
+from typing import Union, Sequence, Tuple
+import torch
+
+TensorOrSequence = Union[Sequence[torch.Tensor], torch.Tensor]
+TensorOrNone = Union[torch.Tensor, None]
diff --git a/utils/utils.py b/utils/utils.py
new file mode 100644
index 0000000..ef80d23
--- /dev/null
+++ b/utils/utils.py
@@ -0,0 +1,27 @@
+import requests
+
+def download_from_url(url, path):
+ """Download file, with logic (from tensor2tensor) for Google Drive"""
+ if 'drive.google.com' not in url:
+ print('Downloading %s; may take a few minutes' % url)
+ r = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'})
+ with open(path, "wb") as file:
+ file.write(r.content)
+ return
+ print('Downloading from Google Drive; may take a few minutes')
+ confirm_token = None
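+ # Large files on Google Drive return a confirmation page first; the
+ # 'download_warning' cookie carries the confirm token needed to proceed.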
+ session = requests.Session()
+ response = session.get(url, stream=True)
+ for k, v in response.cookies.items():
+ if k.startswith("download_warning"):
+ confirm_token = v
+
+ if confirm_token:
+ url = url + "&confirm=" + confirm_token
+ response = session.get(url, stream=True)
+
+ chunk_size = 16 * 1024
+ with open(path, "wb") as f:
+ for chunk in response.iter_content(chunk_size):
+ if chunk:
+ f.write(chunk)
diff --git a/vocab.pkl b/vocab.pkl
new file mode 100644
index 0000000..091dbb4
Binary files /dev/null and b/vocab.pkl differ
diff --git a/vocab_idx2word.pkl b/vocab_idx2word.pkl
new file mode 100644
index 0000000..abaa00b
Binary files /dev/null and b/vocab_idx2word.pkl differ
diff --git a/vocab_language/vocab_bert_language.pkl b/vocab_language/vocab_bert_language.pkl
new file mode 100644
index 0000000..1858f6a
Binary files /dev/null and b/vocab_language/vocab_bert_language.pkl differ