diff --git a/main_test_AllinOne.py b/main_test_AllinOne.py new file mode 100644 index 0000000..7665006 --- /dev/null +++ b/main_test_AllinOne.py @@ -0,0 +1,116 @@ +import torch +import argparse +import numpy as np +from modules.tokenizers import Tokenizer +from modules.dataloaders import R2DataLoader +from modules.metrics import compute_scores +from modules.tester_AllinOne import Tester +from modules.loss import compute_loss +from models.histgen_model import HistGenModel + +def parse_agrs(): + parser = argparse.ArgumentParser() + + # Data input settings + parser.add_argument('--image_dir', type=str, default='data/iu_xray/images/', help='the path to the directory containing the data.') + parser.add_argument('--ann_path', type=str, default='data/iu_xray/annotation.json', help='the path to the directory containing the data.') + + # Data loader settings + parser.add_argument('--dataset_name', type=str, default='iu_xray', choices=['iu_xray', 'mimic_cxr', 'wsi_report'], help='the dataset to be used.') + parser.add_argument('--max_seq_length', type=int, default=60, help='the maximum sequence length of the reports.') + parser.add_argument('--threshold', type=int, default=3, help='the cut off frequency for the words.') + parser.add_argument('--num_workers', type=int, default=2, help='the number of workers for dataloader.') + parser.add_argument('--batch_size', type=int, default=16, help='the number of samples for a batch') + + parser.add_argument('--model_name', type=str, default='histgen', choices=['histgen'], help='model used for experiment') + + # Model settings (for visual extractor) + parser.add_argument('--visual_extractor', type=str, default='resnet101', help='the visual extractor to be used.') + parser.add_argument('--visual_extractor_pretrained', type=bool, default=True, help='whether to load the pretrained visual extractor') + # Model settings (for Transformer) + parser.add_argument('--d_model', type=int, default=512, help='the dimension of Transformer.') + 
parser.add_argument('--d_ff', type=int, default=512, help='the dimension of FFN.') + parser.add_argument('--d_vf', type=int, default=2048, help='the dimension of the patch features.') + parser.add_argument('--num_heads', type=int, default=8, help='the number of heads in Transformer.') + parser.add_argument('--num_layers', type=int, default=3, help='the number of layers of Transformer.') + parser.add_argument('--dropout', type=float, default=0.1, help='the dropout rate of Transformer.') + parser.add_argument('--logit_layers', type=int, default=1, help='the number of the logit layer.') + parser.add_argument('--bos_idx', type=int, default=0, help='the index of .') + parser.add_argument('--eos_idx', type=int, default=0, help='the index of .') + parser.add_argument('--pad_idx', type=int, default=0, help='the index of .') + parser.add_argument('--use_bn', type=int, default=0, help='whether to use batch normalization.') + parser.add_argument('--drop_prob_lm', type=float, default=0.5, help='the dropout rate of the output layer.') + # for Cross-modal context module + parser.add_argument('--topk', type=int, default=32, help='the number of k.') + parser.add_argument('--cmm_size', type=int, default=2048, help='the numebr of cmm size.') + parser.add_argument('--cmm_dim', type=int, default=512, help='the dimension of cmm dimension.') + # for Local-global hierachical visual encoder + parser.add_argument("--region_size", type=int, default=256, help="the size of the region for region transformer.") + + # Sample related + parser.add_argument('--sample_method', type=str, default='beam_search', help='the sample methods to sample a report.') + parser.add_argument('--beam_size', type=int, default=3, help='the beam size when beam searching.') + parser.add_argument('--temperature', type=float, default=1.0, help='the temperature when sampling.') + parser.add_argument('--sample_n', type=int, default=1, help='the sample number per image.') + parser.add_argument('--group_size', type=int, 
default=1, help='the group size.') + parser.add_argument('--output_logsoftmax', type=int, default=1, help='whether to output the probabilities.') + parser.add_argument('--decoding_constraint', type=int, default=0, help='whether decoding constraint.') + parser.add_argument('--block_trigrams', type=int, default=1, help='whether to use block trigrams.') + + # Trainer settings + parser.add_argument('--n_gpu', type=int, default=1, help='the number of gpus to be used.') + parser.add_argument('--epochs', type=int, default=100, help='the number of training epochs.') + parser.add_argument('--save_dir', type=str, default='results/iu_xray', help='the patch to save the models.') + parser.add_argument('--record_dir', type=str, default='records/', help='the patch to save the results of experiments') + parser.add_argument('--save_period', type=int, default=1, help='the saving period.') + parser.add_argument('--monitor_mode', type=str, default='max', choices=['min', 'max'], help='whether to max or min the metric.') + parser.add_argument('--monitor_metric', type=str, default='BLEU_4', help='the metric to be monitored.') + parser.add_argument('--early_stop', type=int, default=50, help='the patience of training.') + parser.add_argument('--log_period', type=int, default=1000, help='the logging interval (in batches).') + + # Optimization + parser.add_argument('--optim', type=str, default='Adam', help='the type of the optimizer.') + parser.add_argument('--lr_ve', type=float, default=5e-5, help='the learning rate for the visual extractor.') + parser.add_argument('--lr_ed', type=float, default=1e-4, help='the learning rate for the remaining parameters.') + parser.add_argument('--weight_decay', type=float, default=5e-5, help='the weight decay.') + parser.add_argument('--amsgrad', type=bool, default=True, help='.') + + # Learning Rate Scheduler + parser.add_argument('--lr_scheduler', type=str, default='StepLR', help='the type of the learning rate scheduler.') + 
parser.add_argument('--step_size', type=int, default=50, help='the step size of the learning rate scheduler.') + parser.add_argument('--gamma', type=float, default=0.1, help='the gamma of the learning rate scheduler.') + + # Others + parser.add_argument('--seed', type=int, default=9233, help='.') + parser.add_argument('--resume', type=str, help='whether to resume the training from existing checkpoints.') + parser.add_argument('--load', type=str, help='whether to load a pre-trained model.') + + args = parser.parse_args() + return args + + +def main(): + # parse arguments + args = parse_agrs() + + # fix random seeds + torch.manual_seed(args.seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + np.random.seed(args.seed) + + tokenizer = Tokenizer(args) + test_dataloader = R2DataLoader(args, tokenizer, split='test', shuffle=False) + model = HistGenModel(args, tokenizer) + + # get function handles of loss and metrics + criterion = compute_loss + metrics = compute_scores + + # build trainer and start to train + tester = Tester(model, criterion, metrics, args, test_dataloader) + tester.test() + + +if __name__ == '__main__': + main() diff --git a/main_train_AllinOne.py b/main_train_AllinOne.py new file mode 100644 index 0000000..4d37d65 --- /dev/null +++ b/main_train_AllinOne.py @@ -0,0 +1,127 @@ +import torch +import argparse +import numpy as np +from modules.tokenizers import Tokenizer +from modules.dataloaders import R2DataLoader +from modules.metrics import compute_scores +from modules.optimizers import build_optimizer, build_lr_scheduler +from modules.trainer_AllinOne import Trainer +from modules.loss import compute_loss +from models.histgen_model import HistGenModel + +def parse_agrs(): + parser = argparse.ArgumentParser() + + # Data input settings + parser.add_argument('--image_dir', type=str, default='data/iu_xray/images/', help='the path to the directory containing the data.') + parser.add_argument('--ann_path', type=str, 
default='data/iu_xray/annotation.json', help='the path to the directory containing the data.') + + # Data loader settings + parser.add_argument('--dataset_name', type=str, default='wsi_report', choices=['iu_xray', 'mimic_cxr', 'wsi_report'], help='the dataset to be used.') + parser.add_argument('--max_seq_length', type=int, default=60, help='the maximum sequence length of the reports.') + parser.add_argument('--threshold', type=int, default=3, help='the cut off frequency for the words.') + parser.add_argument('--num_workers', type=int, default=2, help='the number of workers for dataloader.') + parser.add_argument('--batch_size', type=int, default=16, help='the number of samples for a batch') + + parser.add_argument('--model_name', type=str, default='histgen', choices=['histgen'], help='model used for experiment') + + # Model settings (for visual extractor) + parser.add_argument('--visual_extractor', type=str, default='resnet101', help='the visual extractor to be used.') + parser.add_argument('--visual_extractor_pretrained', type=bool, default=True, help='whether to load the pretrained visual extractor') + # Model settings (for Transformer) + parser.add_argument('--d_model', type=int, default=512, help='the dimension of Transformer.') + parser.add_argument('--d_ff', type=int, default=512, help='the dimension of FFN.') + parser.add_argument('--d_vf', type=int, default=2048, help='the dimension of the patch features.') + parser.add_argument('--num_heads', type=int, default=8, help='the number of heads in Transformer.') + parser.add_argument('--num_layers', type=int, default=3, help='the number of layers of Transformer.') + parser.add_argument('--dropout', type=float, default=0.1, help='the dropout rate of Transformer.') + parser.add_argument('--logit_layers', type=int, default=1, help='the number of the logit layer.') + parser.add_argument('--bos_idx', type=int, default=0, help='the index of .') + parser.add_argument('--eos_idx', type=int, default=0, help='the index 
of .') + parser.add_argument('--pad_idx', type=int, default=0, help='the index of .') + parser.add_argument('--use_bn', type=int, default=0, help='whether to use batch normalization.') + parser.add_argument('--drop_prob_lm', type=float, default=0.5, help='the dropout rate of the output layer.') + # for Cross-modal context module + parser.add_argument('--topk', type=int, default=32, help='the number of k.') + parser.add_argument('--cmm_size', type=int, default=2048, help='the numebr of cmm size.') + parser.add_argument('--cmm_dim', type=int, default=512, help='the dimension of cmm dimension.') + # for Local-global hierachical visual encoder + parser.add_argument("--region_size", type=int, default=256, help="the size of the region for region transformer.") + parser.add_argument("--prototype_num", type=int, default=512, help="the number of visual prototypes for cross-modal interaction") + + # Sample related + parser.add_argument('--sample_method', type=str, default='beam_search', help='the sample methods to sample a report.') + parser.add_argument('--beam_size', type=int, default=3, help='the beam size when beam searching.') + parser.add_argument('--temperature', type=float, default=1.0, help='the temperature when sampling.') + parser.add_argument('--sample_n', type=int, default=1, help='the sample number per image.') + parser.add_argument('--group_size', type=int, default=1, help='the group size.') + parser.add_argument('--output_logsoftmax', type=int, default=1, help='whether to output the probabilities.') + parser.add_argument('--decoding_constraint', type=int, default=0, help='whether decoding constraint.') + parser.add_argument('--block_trigrams', type=int, default=1, help='whether to use block trigrams.') + + # Trainer settings + parser.add_argument('--n_gpu', type=int, default=1, help='the number of gpus to be used.') + parser.add_argument('--epochs', type=int, default=100, help='the number of training epochs.') + parser.add_argument('--save_dir', type=str, 
default='results/iu_xray', help='the patch to save the models.') + parser.add_argument('--record_dir', type=str, default='records/', help='the patch to save the results of experiments') + parser.add_argument('--save_period', type=int, default=1, help='the saving period.') + parser.add_argument('--monitor_mode', type=str, default='max', choices=['min', 'max'], help='whether to max or min the metric.') + parser.add_argument('--monitor_metric', type=str, default='BLEU_4', help='the metric to be monitored.') + parser.add_argument('--early_stop', type=int, default=50, help='the patience of training.') + parser.add_argument('--log_period', type=int, default=1000, help='the logging interval (in batches).') + + # Optimization + parser.add_argument('--optim', type=str, default='Adam', help='the type of the optimizer.') + parser.add_argument('--lr_ve', type=float, default=5e-5, help='the learning rate for the visual extractor.') + parser.add_argument('--lr_ed', type=float, default=1e-4, help='the learning rate for the remaining parameters.') + parser.add_argument('--weight_decay', type=float, default=5e-5, help='the weight decay.') + parser.add_argument('--amsgrad', type=bool, default=True, help='.') + + # Learning Rate Scheduler + parser.add_argument('--lr_scheduler', type=str, default='StepLR', help='the type of the learning rate scheduler.') + parser.add_argument('--step_size', type=int, default=50, help='the step size of the learning rate scheduler.') + parser.add_argument('--gamma', type=float, default=0.1, help='the gamma of the learning rate scheduler.') + + # Others + parser.add_argument('--seed', type=int, default=9233, help='.') + parser.add_argument('--resume', type=str, help='whether to resume the training from existing checkpoints.') + + args = parser.parse_args() + return args + + +def main(): + # parse arguments + args = parse_agrs() + + # fix random seeds + torch.manual_seed(args.seed) + torch.backends.cudnn.deterministic = True + 
torch.backends.cudnn.benchmark = False + np.random.seed(args.seed) + + + tokenizer = Tokenizer(args) + + train_dataloader = R2DataLoader(args, tokenizer, split='train', shuffle=True) + val_dataloader = R2DataLoader(args, tokenizer, split='val', shuffle=False) + test_dataloader = R2DataLoader(args, tokenizer, split='test', shuffle=False) + + # build model architecture + model = HistGenModel(args, tokenizer) + + # get function handles of loss and metrics + criterion = compute_loss + metrics = compute_scores + + # build optimizer, learning rate scheduler + optimizer = build_optimizer(args, model) + lr_scheduler = build_lr_scheduler(args, optimizer) + + # build trainer and start to train + trainer = Trainer(model, criterion, metrics, optimizer, args, lr_scheduler, train_dataloader, val_dataloader, test_dataloader) + trainer.train() + + +if __name__ == '__main__': + main() diff --git a/models/__pycache__/histgen_model.cpython-310.pyc b/models/__pycache__/histgen_model.cpython-310.pyc new file mode 100644 index 0000000..ee63f79 Binary files /dev/null and b/models/__pycache__/histgen_model.cpython-310.pyc differ diff --git a/models/histgen_model.py b/models/histgen_model.py new file mode 100644 index 0000000..0a40750 --- /dev/null +++ b/models/histgen_model.py @@ -0,0 +1,34 @@ +import numpy as np +import torch +import torch.nn as nn +from modules.visual_extractor import VisualExtractor +from modules.histgen_module import BaseHistGen + +class HistGenModel(nn.Module): + def __init__(self, args, tokenizer): + super(HistGenModel, self).__init__() + self.args = args + self.tokenizer = tokenizer + self.encoder_decoder = BaseHistGen(args, tokenizer) + self.wsi_mapping = torch.nn.Linear(768, self.args.d_vf) if "ctranspath" in args.image_dir else torch.nn.Linear(1024, self.args.d_vf) + self.forward = self.forward_pathology + self.visual_extractor = VisualExtractor(args) + + def __str__(self): + model_parameters = filter(lambda p: p.requires_grad, self.parameters()) + params = 
sum([np.prod(p.size()) for p in model_parameters]) + return super().__str__() + '\nTrainable parameters: {}'.format(params) + + def forward_pathology(self, images, targets=None, mode='train', update_opts={}): + + att_feats = self.wsi_mapping(images) + fc_feats = torch.mean(att_feats, dim=1) + + if mode == 'train': + output = self.encoder_decoder(fc_feats, att_feats, targets, mode='forward') + return output + elif mode == 'sample': + output, output_probs = self.encoder_decoder(fc_feats, att_feats, mode='sample') + return output + else: + raise ValueError diff --git a/modules/__init__.py b/modules/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/modules/__pycache__/__init__.cpython-310.pyc b/modules/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..593c300 Binary files /dev/null and b/modules/__pycache__/__init__.cpython-310.pyc differ diff --git a/modules/__pycache__/att_model.cpython-310.pyc b/modules/__pycache__/att_model.cpython-310.pyc new file mode 100644 index 0000000..2389228 Binary files /dev/null and b/modules/__pycache__/att_model.cpython-310.pyc differ diff --git a/modules/__pycache__/caption_model.cpython-310.pyc b/modules/__pycache__/caption_model.cpython-310.pyc new file mode 100644 index 0000000..efc8631 Binary files /dev/null and b/modules/__pycache__/caption_model.cpython-310.pyc differ diff --git a/modules/__pycache__/dataloaders.cpython-310.pyc b/modules/__pycache__/dataloaders.cpython-310.pyc new file mode 100644 index 0000000..aa09c9f Binary files /dev/null and b/modules/__pycache__/dataloaders.cpython-310.pyc differ diff --git a/modules/__pycache__/datasets.cpython-310.pyc b/modules/__pycache__/datasets.cpython-310.pyc new file mode 100644 index 0000000..aaee46d Binary files /dev/null and b/modules/__pycache__/datasets.cpython-310.pyc differ diff --git a/modules/__pycache__/histgen_module.cpython-310.pyc b/modules/__pycache__/histgen_module.cpython-310.pyc new file mode 100644 index 0000000..58f4a29 
Binary files /dev/null and b/modules/__pycache__/histgen_module.cpython-310.pyc differ diff --git a/modules/__pycache__/loss.cpython-310.pyc b/modules/__pycache__/loss.cpython-310.pyc new file mode 100644 index 0000000..9b32166 Binary files /dev/null and b/modules/__pycache__/loss.cpython-310.pyc differ diff --git a/modules/__pycache__/metrics.cpython-310.pyc b/modules/__pycache__/metrics.cpython-310.pyc new file mode 100644 index 0000000..a82904c Binary files /dev/null and b/modules/__pycache__/metrics.cpython-310.pyc differ diff --git a/modules/__pycache__/optimizers.cpython-310.pyc b/modules/__pycache__/optimizers.cpython-310.pyc new file mode 100644 index 0000000..71cd390 Binary files /dev/null and b/modules/__pycache__/optimizers.cpython-310.pyc differ diff --git a/modules/__pycache__/tokenizers.cpython-310.pyc b/modules/__pycache__/tokenizers.cpython-310.pyc new file mode 100644 index 0000000..0045310 Binary files /dev/null and b/modules/__pycache__/tokenizers.cpython-310.pyc differ diff --git a/modules/__pycache__/trainer_AllinOne.cpython-310.pyc b/modules/__pycache__/trainer_AllinOne.cpython-310.pyc new file mode 100644 index 0000000..d70d70f Binary files /dev/null and b/modules/__pycache__/trainer_AllinOne.cpython-310.pyc differ diff --git a/modules/__pycache__/utils.cpython-310.pyc b/modules/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000..7dca61d Binary files /dev/null and b/modules/__pycache__/utils.cpython-310.pyc differ diff --git a/modules/__pycache__/visual_extractor.cpython-310.pyc b/modules/__pycache__/visual_extractor.cpython-310.pyc new file mode 100644 index 0000000..7e20a43 Binary files /dev/null and b/modules/__pycache__/visual_extractor.cpython-310.pyc differ diff --git a/modules/att_model.py b/modules/att_model.py new file mode 100644 index 0000000..3c2dedc --- /dev/null +++ b/modules/att_model.py @@ -0,0 +1,321 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import 
print_function + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence + +import modules.utils as utils +from modules.caption_model import CaptionModel + + +def sort_pack_padded_sequence(input, lengths): + sorted_lengths, indices = torch.sort(lengths, descending=True) + tmp = pack_padded_sequence(input[indices], sorted_lengths.cpu(), batch_first=True) + inv_ix = indices.clone() + inv_ix[indices] = torch.arange(0, len(indices)).type_as(inv_ix) + return tmp, inv_ix + + +def pad_unsort_packed_sequence(input, inv_ix): + tmp, _ = pad_packed_sequence(input, batch_first=True) + tmp = tmp[inv_ix] + return tmp + + +def pack_wrapper(module, att_feats, att_masks): + if att_masks is not None: + packed, inv_ix = sort_pack_padded_sequence(att_feats, att_masks.data.long().sum(1)) + return pad_unsort_packed_sequence(PackedSequence(module(packed[0]), packed[1]), inv_ix) + else: + return module(att_feats) + + +class AttModel(CaptionModel): + def __init__(self, args, tokenizer): + super(AttModel, self).__init__() + self.args = args + self.tokenizer = tokenizer + self.vocab_size = len(tokenizer.idx2token) + self.input_encoding_size = args.d_model + self.rnn_size = args.d_ff + self.num_layers = args.num_layers + self.drop_prob_lm = args.drop_prob_lm + self.max_seq_length = args.max_seq_length + self.att_feat_size = args.d_vf + self.att_hid_size = args.d_model + + self.bos_idx = args.bos_idx + self.eos_idx = args.eos_idx + self.pad_idx = args.pad_idx + + self.use_bn = args.use_bn + + self.embed = lambda x: x + self.fc_embed = lambda x: x + self.att_embed = nn.Sequential(*( + ((nn.BatchNorm1d(self.att_feat_size),) if self.use_bn else ()) + + (nn.Linear(self.att_feat_size, self.input_encoding_size), + nn.ReLU(), + nn.Dropout(self.drop_prob_lm)) + + ((nn.BatchNorm1d(self.input_encoding_size),) if self.use_bn == 2 else ()))) + + def clip_att(self, att_feats, att_masks): + # Clip the 
length of att_masks and att_feats to the maximum length + if att_masks is not None: + max_len = att_masks.data.long().sum(1).max() + att_feats = att_feats[:, :max_len].contiguous() + att_masks = att_masks[:, :max_len].contiguous() + return att_feats, att_masks + + def _prepare_feature(self, fc_feats, att_feats, att_masks): + att_feats, att_masks = self.clip_att(att_feats, att_masks) + + # embed fc and att feats + fc_feats = self.fc_embed(fc_feats) + att_feats = pack_wrapper(self.att_embed, att_feats, att_masks) + + # Project the attention feats first to reduce memory and computation comsumptions. + p_att_feats = self.ctx2att(att_feats) + + return fc_feats, att_feats, p_att_feats, att_masks + + def get_logprobs_state(self, it, fc_feats, att_feats, p_att_feats, att_masks, state, output_logsoftmax=1): + # 'it' contains a word index + xt = self.embed(it) + + output, state = self.core(xt, fc_feats, att_feats, p_att_feats, state, att_masks) + if output_logsoftmax: + logprobs = F.log_softmax(self.logit(output), dim=1) + else: + logprobs = self.logit(output) + + return logprobs, state + + def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}): + beam_size = opt.get('beam_size', 10) + group_size = opt.get('group_size', 1) + sample_n = opt.get('sample_n', 10) + # when sample_n == beam_size then each beam is a sample. + assert sample_n == 1 or sample_n == beam_size // group_size, 'when beam search, sample_n == 1 or beam search' + batch_size = fc_feats.size(0) + + p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks) + + assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. 
can be dealt with in future if needed' + seq = fc_feats.new_full((batch_size * sample_n, self.max_seq_length), self.pad_idx, dtype=torch.long) + seqLogprobs = fc_feats.new_zeros(batch_size * sample_n, self.max_seq_length, self.vocab_size + 1) + # lets process every image independently for now, for simplicity + + self.done_beams = [[] for _ in range(batch_size)] + + state = self.init_hidden(batch_size) + + # first step, feed bos + it = fc_feats.new_full([batch_size], self.bos_idx, dtype=torch.long) + logprobs, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state) + + p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = utils.repeat_tensors(beam_size, + [p_fc_feats, p_att_feats, + pp_att_feats, p_att_masks] + ) + self.done_beams = self.beam_search(state, logprobs, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, opt=opt) + for k in range(batch_size): + if sample_n == beam_size: + for _n in range(sample_n): + seq_len = self.done_beams[k][_n]['seq'].shape[0] + seq[k * sample_n + _n, :seq_len] = self.done_beams[k][_n]['seq'] + seqLogprobs[k * sample_n + _n, :seq_len] = self.done_beams[k][_n]['logps'] + else: + seq_len = self.done_beams[k][0]['seq'].shape[0] + seq[k, :seq_len] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score + seqLogprobs[k, :seq_len] = self.done_beams[k][0]['logps'] + # return the samples and their log likelihoods + return seq, seqLogprobs + + def _sample(self, fc_feats, att_feats, att_masks=None, update_opts={}): + opt = self.args.__dict__ + opt.update(**update_opts) + + sample_method = opt.get('sample_method', 'greedy') + beam_size = opt.get('beam_size', 1) + temperature = opt.get('temperature', 1.0) + sample_n = int(opt.get('sample_n', 1)) + group_size = opt.get('group_size', 1) + output_logsoftmax = opt.get('output_logsoftmax', 1) + decoding_constraint = opt.get('decoding_constraint', 0) + block_trigrams = opt.get('block_trigrams', 0) + if beam_size > 1 and sample_method in 
['greedy', 'beam_search']: + return self._sample_beam(fc_feats, att_feats, att_masks, opt) + if group_size > 1: + return self._diverse_sample(fc_feats, att_feats, att_masks, opt) + + batch_size = fc_feats.size(0) + state = self.init_hidden(batch_size * sample_n) + + p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks) + + if sample_n > 1: + p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = utils.repeat_tensors(sample_n, + [p_fc_feats, p_att_feats, + pp_att_feats, p_att_masks] + ) + + trigrams = [] # will be a list of batch_size dictionaries + + seq = fc_feats.new_full((batch_size * sample_n, self.max_seq_length), self.pad_idx, dtype=torch.long) + seqLogprobs = fc_feats.new_zeros(batch_size * sample_n, self.max_seq_length, self.vocab_size + 1) + for t in range(self.max_seq_length + 1): + if t == 0: # input + it = fc_feats.new_full([batch_size * sample_n], self.bos_idx, dtype=torch.long) + + logprobs, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state, + output_logsoftmax=output_logsoftmax) + + if decoding_constraint and t > 0: + tmp = logprobs.new_zeros(logprobs.size()) + tmp.scatter_(1, seq[:, t - 1].data.unsqueeze(1), float('-inf')) + logprobs = logprobs + tmp + + # Mess with trigrams + # Copy from https://github.com/lukemelas/image-paragraph-captioning + if block_trigrams and t >= 3: + # Store trigram generated at last step + prev_two_batch = seq[:, t - 3:t - 1] + for i in range(batch_size): # = seq.size(0) + prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item()) + current = seq[i][t - 1] + if t == 3: # initialize + trigrams.append({prev_two: [current]}) # {LongTensor: list containing 1 int} + elif t > 3: + if prev_two in trigrams[i]: # add to list + trigrams[i][prev_two].append(current) + else: # create list + trigrams[i][prev_two] = [current] + # Block used trigrams at next step + prev_two_batch = seq[:, t - 2:t] + mask = 
torch.zeros(logprobs.size(), requires_grad=False).cuda() # batch_size x vocab_size + for i in range(batch_size): + prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item()) + if prev_two in trigrams[i]: + for j in trigrams[i][prev_two]: + mask[i, j] += 1 + # Apply mask to log probs + # logprobs = logprobs - (mask * 1e9) + alpha = 2.0 # = 4 + logprobs = logprobs + (mask * -0.693 * alpha) # ln(1/2) * alpha (alpha -> infty works best) + + # sample the next word + if t == self.max_seq_length: # skip if we achieve maximum length + break + it, sampleLogprobs = self.sample_next_word(logprobs, sample_method, temperature) + + # stop when all finished + if t == 0: + unfinished = it != self.eos_idx + else: + it[~unfinished] = self.pad_idx # This allows eos_idx not being overwritten to 0 + logprobs = logprobs * unfinished.unsqueeze(1).float() + unfinished = unfinished * (it != self.eos_idx) + seq[:, t] = it + seqLogprobs[:, t] = logprobs + # quit loop if all sequences have finished + if unfinished.sum() == 0: + break + + return seq, seqLogprobs + + def _diverse_sample(self, fc_feats, att_feats, att_masks=None, opt={}): + + sample_method = opt.get('sample_method', 'greedy') + beam_size = opt.get('beam_size', 1) + temperature = opt.get('temperature', 1.0) + group_size = opt.get('group_size', 1) + diversity_lambda = opt.get('diversity_lambda', 0.5) + decoding_constraint = opt.get('decoding_constraint', 0) + block_trigrams = opt.get('block_trigrams', 0) + + batch_size = fc_feats.size(0) + state = self.init_hidden(batch_size) + + p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks) + + trigrams_table = [[] for _ in range(group_size)] # will be a list of batch_size dictionaries + + seq_table = [fc_feats.new_full((batch_size, self.max_seq_length), self.pad_idx, dtype=torch.long) for _ in + range(group_size)] + seqLogprobs_table = [fc_feats.new_zeros(batch_size, self.max_seq_length) for _ in range(group_size)] + 
state_table = [self.init_hidden(batch_size) for _ in range(group_size)] + + for tt in range(self.max_seq_length + group_size): + for divm in range(group_size): + t = tt - divm + seq = seq_table[divm] + seqLogprobs = seqLogprobs_table[divm] + trigrams = trigrams_table[divm] + if t >= 0 and t <= self.max_seq_length - 1: + if t == 0: # input + it = fc_feats.new_full([batch_size], self.bos_idx, dtype=torch.long) + else: + it = seq[:, t - 1] # changed + + logprobs, state_table[divm] = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, + p_att_masks, state_table[divm]) # changed + logprobs = F.log_softmax(logprobs / temperature, dim=-1) + + # Add diversity + if divm > 0: + unaug_logprobs = logprobs.clone() + for prev_choice in range(divm): + prev_decisions = seq_table[prev_choice][:, t] + logprobs[:, prev_decisions] = logprobs[:, prev_decisions] - diversity_lambda + + if decoding_constraint and t > 0: + tmp = logprobs.new_zeros(logprobs.size()) + tmp.scatter_(1, seq[:, t - 1].data.unsqueeze(1), float('-inf')) + logprobs = logprobs + tmp + + # Mess with trigrams + if block_trigrams and t >= 3: + # Store trigram generated at last step + prev_two_batch = seq[:, t - 3:t - 1] + for i in range(batch_size): # = seq.size(0) + prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item()) + current = seq[i][t - 1] + if t == 3: # initialize + trigrams.append({prev_two: [current]}) # {LongTensor: list containing 1 int} + elif t > 3: + if prev_two in trigrams[i]: # add to list + trigrams[i][prev_two].append(current) + else: # create list + trigrams[i][prev_two] = [current] + # Block used trigrams at next step + prev_two_batch = seq[:, t - 2:t] + mask = torch.zeros(logprobs.size(), requires_grad=False).cuda() # batch_size x vocab_size + for i in range(batch_size): + prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item()) + if prev_two in trigrams[i]: + for j in trigrams[i][prev_two]: + mask[i, j] += 1 + # Apply mask to log probs + # logprobs = 
logprobs - (mask * 1e9) + alpha = 2.0 # = 4 + logprobs = logprobs + (mask * -0.693 * alpha) # ln(1/2) * alpha (alpha -> infty works best) + + it, sampleLogprobs = self.sample_next_word(logprobs, sample_method, 1) + + # stop when all finished + if t == 0: + unfinished = it != self.eos_idx + else: + unfinished = seq[:, t - 1] != self.pad_idx & seq[:, t - 1] != self.eos_idx + it[~unfinished] = self.pad_idx + unfinished = unfinished & (it != self.eos_idx) # changed + seq[:, t] = it + seqLogprobs[:, t] = sampleLogprobs.view(-1) + + return torch.stack(seq_table, 1).reshape(batch_size * group_size, -1), torch.stack(seqLogprobs_table, + 1).reshape( + batch_size * group_size, -1) diff --git a/modules/caption_model.py b/modules/caption_model.py new file mode 100644 index 0000000..b633ec2 --- /dev/null +++ b/modules/caption_model.py @@ -0,0 +1,401 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import modules.utils as utils + + +class CaptionModel(nn.Module): + def __init__(self): + super(CaptionModel, self).__init__() + + # implements beam search + # calls beam_step and returns the final set of beams + # augments log-probabilities with diversity terms when number of groups > 1 + + def forward(self, *args, **kwargs): + mode = kwargs.get('mode', 'forward') + if 'mode' in kwargs: + del kwargs['mode'] + return getattr(self, '_' + mode)(*args, **kwargs) + + def beam_search(self, init_state, init_logprobs, *args, **kwargs): + + # function computes the similarity score to be augmented + def add_diversity(beam_seq_table, logprobs, t, divm, diversity_lambda, bdash): + local_time = t - divm + unaug_logprobs = logprobs.clone() + batch_size = beam_seq_table[0].shape[0] + + if divm > 0: + change = logprobs.new_zeros(batch_size, logprobs.shape[-1]) + for prev_choice in range(divm): + prev_decisions = beam_seq_table[prev_choice][:, :, local_time] # 
Nxb + for prev_labels in range(bdash): + change.scatter_add_(1, prev_decisions[:, prev_labels].unsqueeze(-1), + change.new_ones(batch_size, 1)) + + if local_time == 0: + logprobs = logprobs - change * diversity_lambda + else: + logprobs = logprobs - self.repeat_tensor(bdash, change) * diversity_lambda + + return logprobs, unaug_logprobs + + # does one step of classical beam search + + def beam_step(logprobs, unaug_logprobs, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprobs_sum, state): + # INPUTS: + # logprobs: probabilities augmented after diversity N*bxV + # beam_size: obvious + # t : time instant + # beam_seq : tensor contanining the beams + # beam_seq_logprobs: tensor contanining the beam logprobs + # beam_logprobs_sum: tensor contanining joint logprobs + # OUPUTS: + # beam_seq : tensor containing the word indices of the decoded captions Nxbxl + # beam_seq_logprobs : log-probability of each decision made, NxbxlxV + # beam_logprobs_sum : joint log-probability of each beam Nxb + + batch_size = beam_logprobs_sum.shape[0] + vocab_size = logprobs.shape[-1] + logprobs = logprobs.reshape(batch_size, -1, vocab_size) # NxbxV + if t == 0: + assert logprobs.shape[1] == 1 + beam_logprobs_sum = beam_logprobs_sum[:, :1] + candidate_logprobs = beam_logprobs_sum.unsqueeze(-1) + logprobs # beam_logprobs_sum Nxb logprobs is NxbxV + ys, ix = torch.sort(candidate_logprobs.reshape(candidate_logprobs.shape[0], -1), -1, True) + ys, ix = ys[:, :beam_size], ix[:, :beam_size] + beam_ix = ix // vocab_size # Nxb which beam + selected_ix = ix % vocab_size # Nxb # which world + state_ix = (beam_ix + torch.arange(batch_size).type_as(beam_ix).unsqueeze(-1) * logprobs.shape[1]).reshape( + -1) # N*b which in Nxb beams + + if t > 0: + # gather according to beam_ix + assert (beam_seq.gather(1, beam_ix.unsqueeze(-1).expand_as(beam_seq)) == + beam_seq.reshape(-1, beam_seq.shape[-1])[state_ix].view_as(beam_seq)).all() + beam_seq = beam_seq.gather(1, beam_ix.unsqueeze(-1).expand_as(beam_seq)) 
+ + beam_seq_logprobs = beam_seq_logprobs.gather(1, beam_ix.unsqueeze(-1).unsqueeze(-1).expand_as( + beam_seq_logprobs)) + + beam_seq = torch.cat([beam_seq, selected_ix.unsqueeze(-1)], -1) # beam_seq Nxbxl + beam_logprobs_sum = beam_logprobs_sum.gather(1, beam_ix) + \ + logprobs.reshape(batch_size, -1).gather(1, ix) + assert (beam_logprobs_sum == ys).all() + _tmp_beam_logprobs = unaug_logprobs[state_ix].reshape(batch_size, -1, vocab_size) + beam_logprobs = unaug_logprobs.reshape(batch_size, -1, vocab_size).gather(1, + beam_ix.unsqueeze(-1).expand(-1, + -1, + vocab_size)) # NxbxV + assert (_tmp_beam_logprobs == beam_logprobs).all() + beam_seq_logprobs = torch.cat([ + beam_seq_logprobs, + beam_logprobs.reshape(batch_size, -1, 1, vocab_size)], 2) + + new_state = [None for _ in state] + for _ix in range(len(new_state)): + # copy over state in previous beam q to new beam at vix + new_state[_ix] = state[_ix][:, state_ix] + state = new_state + return beam_seq, beam_seq_logprobs, beam_logprobs_sum, state + + # Start diverse_beam_search + opt = kwargs['opt'] + temperature = opt.get('temperature', 1) # This should not affect beam search, but will affect dbs + beam_size = opt.get('beam_size', 10) + group_size = opt.get('group_size', 1) + diversity_lambda = opt.get('diversity_lambda', 0.5) + decoding_constraint = opt.get('decoding_constraint', 0) + suppress_UNK = opt.get('suppress_UNK', 0) + length_penalty = utils.penalty_builder(opt.get('length_penalty', '')) + bdash = beam_size // group_size # beam per group + + batch_size = init_logprobs.shape[0] + device = init_logprobs.device + # INITIALIZATIONS + beam_seq_table = [torch.LongTensor(batch_size, bdash, 0).to(device) for _ in range(group_size)] + beam_seq_logprobs_table = [torch.FloatTensor(batch_size, bdash, 0, self.vocab_size + 1).to(device) for _ in + range(group_size)] + beam_logprobs_sum_table = [torch.zeros(batch_size, bdash).to(device) for _ in range(group_size)] + + # logprobs # logprobs predicted in last time step, 
shape (beam_size, vocab_size+1) + done_beams_table = [[[] for __ in range(group_size)] for _ in range(batch_size)] + state_table = [[_.clone() for _ in init_state] for _ in range(group_size)] + logprobs_table = [init_logprobs.clone() for _ in range(group_size)] + # END INIT + + # Chunk elements in the args + args = list(args) + args = utils.split_tensors(group_size, args) # For each arg, turn (Bbg)x... to (Bb)x(g)x... + if self.__class__.__name__ == 'AttEnsemble': + args = [[[args[j][i][k] for i in range(len(self.models))] for j in range(len(args))] for k in + range(group_size)] # group_name, arg_name, model_name + else: + args = [[args[i][j] for i in range(len(args))] for j in range(group_size)] + + for t in range(self.max_seq_length + group_size - 1): + for divm in range(group_size): + if t >= divm and t <= self.max_seq_length + divm - 1: + # add diversity + logprobs = logprobs_table[divm] + # suppress previous word + if decoding_constraint and t - divm > 0: + logprobs.scatter_(1, beam_seq_table[divm][:, :, t - divm - 1].reshape(-1, 1).to(device), + float('-inf')) + # suppress UNK tokens in the decoding + if suppress_UNK and hasattr(self, 'vocab') and self.vocab[str(logprobs.size(1) - 1)] == 'UNK': + logprobs[:, logprobs.size(1) - 1] = logprobs[:, logprobs.size(1) - 1] - 1000 + # diversity is added here + # the function directly modifies the logprobs values and hence, we need to return + # the unaugmented ones for sorting the candidates in the end. # for historical + # reasons :-) + logprobs, unaug_logprobs = add_diversity(beam_seq_table, logprobs, t, divm, diversity_lambda, bdash) + + # infer new beams + beam_seq_table[divm], \ + beam_seq_logprobs_table[divm], \ + beam_logprobs_sum_table[divm], \ + state_table[divm] = beam_step(logprobs, + unaug_logprobs, + bdash, + t - divm, + beam_seq_table[divm], + beam_seq_logprobs_table[divm], + beam_logprobs_sum_table[divm], + state_table[divm]) + + # if time's up... 
or if end token is reached then copy beams + for b in range(batch_size): + is_end = beam_seq_table[divm][b, :, t - divm] == self.eos_idx + assert beam_seq_table[divm].shape[-1] == t - divm + 1 + if t == self.max_seq_length + divm - 1: + is_end.fill_(1) + for vix in range(bdash): + if is_end[vix]: + final_beam = { + 'seq': beam_seq_table[divm][b, vix].clone(), + 'logps': beam_seq_logprobs_table[divm][b, vix].clone(), + 'unaug_p': beam_seq_logprobs_table[divm][b, vix].sum().item(), + 'p': beam_logprobs_sum_table[divm][b, vix].item() + } + final_beam['p'] = length_penalty(t - divm + 1, final_beam['p']) + done_beams_table[b][divm].append(final_beam) + beam_logprobs_sum_table[divm][b, is_end] -= 1000 + + # move the current group one step forward in time + + it = beam_seq_table[divm][:, :, t - divm].reshape(-1) + logprobs_table[divm], state_table[divm] = self.get_logprobs_state(it.cuda(), *( + args[divm] + [state_table[divm]])) + logprobs_table[divm] = F.log_softmax(logprobs_table[divm] / temperature, dim=-1) + + # all beams are sorted by their log-probabilities + done_beams_table = [[sorted(done_beams_table[b][i], key=lambda x: -x['p'])[:bdash] for i in range(group_size)] + for b in range(batch_size)] + done_beams = [sum(_, []) for _ in done_beams_table] + return done_beams + + def old_beam_search(self, init_state, init_logprobs, *args, **kwargs): + + # function computes the similarity score to be augmented + def add_diversity(beam_seq_table, logprobsf, t, divm, diversity_lambda, bdash): + local_time = t - divm + unaug_logprobsf = logprobsf.clone() + for prev_choice in range(divm): + prev_decisions = beam_seq_table[prev_choice][local_time] + for sub_beam in range(bdash): + for prev_labels in range(bdash): + logprobsf[sub_beam][prev_decisions[prev_labels]] = logprobsf[sub_beam][prev_decisions[ + prev_labels]] - diversity_lambda + return unaug_logprobsf + + # does one step of classical beam search + + def beam_step(logprobsf, unaug_logprobsf, beam_size, t, beam_seq, 
beam_seq_logprobs, beam_logprobs_sum, state): + # INPUTS: + # logprobsf: probabilities augmented after diversity + # beam_size: obvious + # t : time instant + # beam_seq : tensor contanining the beams + # beam_seq_logprobs: tensor contanining the beam logprobs + # beam_logprobs_sum: tensor contanining joint logprobs + # OUPUTS: + # beam_seq : tensor containing the word indices of the decoded captions + # beam_seq_logprobs : log-probability of each decision made, same size as beam_seq + # beam_logprobs_sum : joint log-probability of each beam + + ys, ix = torch.sort(logprobsf, 1, True) + candidates = [] + cols = min(beam_size, ys.size(1)) + rows = beam_size + if t == 0: + rows = 1 + for c in range(cols): # for each column (word, essentially) + for q in range(rows): # for each beam expansion + # compute logprob of expanding beam q with word in (sorted) position c + local_logprob = ys[q, c].item() + candidate_logprob = beam_logprobs_sum[q] + local_logprob + # local_unaug_logprob = unaug_logprobsf[q,ix[q,c]] + candidates.append({'c': ix[q, c], 'q': q, 'p': candidate_logprob, 'r': unaug_logprobsf[q]}) + candidates = sorted(candidates, key=lambda x: -x['p']) + + new_state = [_.clone() for _ in state] + # beam_seq_prev, beam_seq_logprobs_prev + if t >= 1: + # we''ll need these as reference when we fork beams around + beam_seq_prev = beam_seq[:t].clone() + beam_seq_logprobs_prev = beam_seq_logprobs[:t].clone() + for vix in range(beam_size): + v = candidates[vix] + # fork beam index q into index vix + if t >= 1: + beam_seq[:t, vix] = beam_seq_prev[:, v['q']] + beam_seq_logprobs[:t, vix] = beam_seq_logprobs_prev[:, v['q']] + # rearrange recurrent states + for state_ix in range(len(new_state)): + # copy over state in previous beam q to new beam at vix + new_state[state_ix][:, vix] = state[state_ix][:, v['q']] # dimension one is time step + # append new end terminal at the end of this beam + beam_seq[t, vix] = v['c'] # c'th word is the continuation + beam_seq_logprobs[t, vix] 
= v['r'] # the raw logprob here + beam_logprobs_sum[vix] = v['p'] # the new (sum) logprob along this beam + state = new_state + return beam_seq, beam_seq_logprobs, beam_logprobs_sum, state, candidates + + # Start diverse_beam_search + opt = kwargs['opt'] + temperature = opt.get('temperature', 1) # This should not affect beam search, but will affect dbs + beam_size = opt.get('beam_size', 10) + group_size = opt.get('group_size', 1) + diversity_lambda = opt.get('diversity_lambda', 0.5) + decoding_constraint = opt.get('decoding_constraint', 0) + suppress_UNK = opt.get('suppress_UNK', 0) + length_penalty = utils.penalty_builder(opt.get('length_penalty', '')) + bdash = beam_size // group_size # beam per group + + # INITIALIZATIONS + beam_seq_table = [torch.LongTensor(self.max_seq_length, bdash).zero_() for _ in range(group_size)] + beam_seq_logprobs_table = [torch.FloatTensor(self.max_seq_length, bdash, self.vocab_size + 1).zero_() for _ in + range(group_size)] + beam_logprobs_sum_table = [torch.zeros(bdash) for _ in range(group_size)] + + # logprobs # logprobs predicted in last time step, shape (beam_size, vocab_size+1) + done_beams_table = [[] for _ in range(group_size)] + # state_table = [list(torch.unbind(_)) for _ in torch.stack(init_state).chunk(group_size, 2)] + state_table = list(zip(*[_.chunk(group_size, 1) for _ in init_state])) + logprobs_table = list(init_logprobs.chunk(group_size, 0)) + # END INIT + + # Chunk elements in the args + args = list(args) + if self.__class__.__name__ == 'AttEnsemble': + args = [[_.chunk(group_size) if _ is not None else [None] * group_size for _ in args_] for args_ in + args] # arg_name, model_name, group_name + args = [[[args[j][i][k] for i in range(len(self.models))] for j in range(len(args))] for k in + range(group_size)] # group_name, arg_name, model_name + else: + args = [_.chunk(group_size) if _ is not None else [None] * group_size for _ in args] + args = [[args[i][j] for i in range(len(args))] for j in range(group_size)] + 
+ for t in range(self.max_seq_length + group_size - 1): + for divm in range(group_size): + if t >= divm and t <= self.max_seq_length + divm - 1: + # add diversity + logprobsf = logprobs_table[divm].float() + # suppress previous word + if decoding_constraint and t - divm > 0: + logprobsf.scatter_(1, beam_seq_table[divm][t - divm - 1].unsqueeze(1).cuda(), float('-inf')) + # suppress UNK tokens in the decoding + if suppress_UNK and hasattr(self, 'vocab') and self.vocab[str(logprobsf.size(1) - 1)] == 'UNK': + logprobsf[:, logprobsf.size(1) - 1] = logprobsf[:, logprobsf.size(1) - 1] - 1000 + # diversity is added here + # the function directly modifies the logprobsf values and hence, we need to return + # the unaugmented ones for sorting the candidates in the end. # for historical + # reasons :-) + unaug_logprobsf = add_diversity(beam_seq_table, logprobsf, t, divm, diversity_lambda, bdash) + + # infer new beams + beam_seq_table[divm], \ + beam_seq_logprobs_table[divm], \ + beam_logprobs_sum_table[divm], \ + state_table[divm], \ + candidates_divm = beam_step(logprobsf, + unaug_logprobsf, + bdash, + t - divm, + beam_seq_table[divm], + beam_seq_logprobs_table[divm], + beam_logprobs_sum_table[divm], + state_table[divm]) + + # if time's up... 
or if end token is reached then copy beams + for vix in range(bdash): + if beam_seq_table[divm][t - divm, vix] == self.eos_idx or t == self.max_seq_length + divm - 1: + final_beam = { + 'seq': beam_seq_table[divm][:, vix].clone(), + 'logps': beam_seq_logprobs_table[divm][:, vix].clone(), + 'unaug_p': beam_seq_logprobs_table[divm][:, vix].sum().item(), + 'p': beam_logprobs_sum_table[divm][vix].item() + } + final_beam['p'] = length_penalty(t - divm + 1, final_beam['p']) + done_beams_table[divm].append(final_beam) + # don't continue beams from finished sequences + beam_logprobs_sum_table[divm][vix] = -1000 + + # move the current group one step forward in time + + it = beam_seq_table[divm][t - divm] + logprobs_table[divm], state_table[divm] = self.get_logprobs_state(it.cuda(), *( + args[divm] + [state_table[divm]])) + logprobs_table[divm] = F.log_softmax(logprobs_table[divm] / temperature, dim=-1) + + # all beams are sorted by their log-probabilities + done_beams_table = [sorted(done_beams_table[i], key=lambda x: -x['p'])[:bdash] for i in range(group_size)] + done_beams = sum(done_beams_table, []) + return done_beams + + def sample_next_word(self, logprobs, sample_method, temperature): + if sample_method == 'greedy': + sampleLogprobs, it = torch.max(logprobs.data, 1) + it = it.view(-1).long() + elif sample_method == 'gumbel': # gumbel softmax + def sample_gumbel(shape, eps=1e-20): + U = torch.rand(shape).cuda() + return -torch.log(-torch.log(U + eps) + eps) + + def gumbel_softmax_sample(logits, temperature): + y = logits + sample_gumbel(logits.size()) + return F.log_softmax(y / temperature, dim=-1) + + _logprobs = gumbel_softmax_sample(logprobs, temperature) + _, it = torch.max(_logprobs.data, 1) + sampleLogprobs = logprobs.gather(1, it.unsqueeze(1)) # gather the logprobs at sampled positions + else: + logprobs = logprobs / temperature + if sample_method.startswith('top'): # topk sampling + top_num = float(sample_method[3:]) + if 0 < top_num < 1: + # nucleus sampling 
from # The Curious Case of Neural Text Degeneration + probs = F.softmax(logprobs, dim=1) + sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=1) + _cumsum = sorted_probs.cumsum(1) + mask = _cumsum < top_num + mask = torch.cat([torch.ones_like(mask[:, :1]), mask[:, :-1]], 1) + sorted_probs = sorted_probs * mask.float() + sorted_probs = sorted_probs / sorted_probs.sum(1, keepdim=True) + logprobs.scatter_(1, sorted_indices, sorted_probs.log()) + else: + the_k = int(top_num) + tmp = torch.empty_like(logprobs).fill_(float('-inf')) + topk, indices = torch.topk(logprobs, the_k, dim=1) + tmp = tmp.scatter(1, indices, topk) + logprobs = tmp + it = torch.distributions.Categorical(logits=logprobs.detach()).sample() + sampleLogprobs = logprobs.gather(1, it.unsqueeze(1)) # gather the logprobs at sampled positions + return it, sampleLogprobs diff --git a/modules/dataloaders.py b/modules/dataloaders.py new file mode 100644 index 0000000..142433a --- /dev/null +++ b/modules/dataloaders.py @@ -0,0 +1,63 @@ +import torch +import numpy as np +from torchvision import transforms +from torch.utils.data import DataLoader +from .datasets import PathologySingleImageDataset + +class R2DataLoader(DataLoader): + def __init__(self, args, tokenizer, split, shuffle): + self.args = args + self.dataset_name = args.dataset_name + self.batch_size = args.batch_size + self.shuffle = shuffle + self.num_workers = args.num_workers + self.tokenizer = tokenizer + self.split = split + + if split == 'train': + self.transform = transforms.Compose([ + transforms.Resize(256), + transforms.RandomCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.485, 0.456, 0.406), + (0.229, 0.224, 0.225))]) + else: + self.transform = transforms.Compose([ + transforms.Resize((224, 224)), + transforms.ToTensor(), + transforms.Normalize((0.485, 0.456, 0.406), + (0.229, 0.224, 0.225))]) + + if self.dataset_name == 'wsi_report': + self.dataset = 
PathologySingleImageDataset(self.args, self.tokenizer, self.split, transform=self.transform) + else: + raise ValueError + + self.init_kwargs = { + 'dataset': self.dataset, + 'batch_size': self.batch_size, + 'shuffle': self.shuffle, + 'collate_fn': self.collate_fn, + 'num_workers': self.num_workers + } + super().__init__(**self.init_kwargs) + + @staticmethod + def collate_fn(data): + #* image_ids & report_ids are the same thing + images_id, images, reports_ids, reports_masks, seq_lengths = zip(*data) #* data is a list of tuples + images = torch.stack(images, 0) + max_seq_length = max(seq_lengths) #* Calculate the max_seq_length of the batch + + targets = np.zeros((len(reports_ids), max_seq_length), dtype=int) #* len(reports_ids) is the batch size + targets_masks = np.zeros((len(reports_ids), max_seq_length), dtype=int) + + for i, report_ids in enumerate(reports_ids): + targets[i, :len(report_ids)] = report_ids #* Fill the targets with the report_ids + + for i, report_masks in enumerate(reports_masks): + targets_masks[i, :len(report_masks)] = report_masks #* Fill the targets_masks with the report_masks + + return images_id, images, torch.LongTensor(targets), torch.FloatTensor(targets_masks) #* Now we have the input and the label + diff --git a/modules/datasets.py b/modules/datasets.py new file mode 100644 index 0000000..62c233d --- /dev/null +++ b/modules/datasets.py @@ -0,0 +1,38 @@ +import os +import json +import torch +from PIL import Image +from torch.utils.data import Dataset + + +class BaseDataset(Dataset): + def __init__(self, args, tokenizer, split, transform=None): + self.image_dir = args.image_dir + self.ann_path = args.ann_path + self.max_seq_length = args.max_seq_length + self.split = split + self.tokenizer = tokenizer + self.transform = transform + self.ann = json.loads(open(self.ann_path, 'r').read()) + + self.examples = self.ann[self.split] + for i in range(len(self.examples)): + self.examples[i]['ids'] = 
tokenizer(self.examples[i]['report'])[:self.max_seq_length] + #* Below is the code to generate the mask for the report + #* such a mask is used to indicate the positions of actual tokens versus padding positions in a sequence. + self.examples[i]['mask'] = [1] * len(self.examples[i]['ids']) + + def __len__(self): + return len(self.examples) + +class PathologySingleImageDataset(BaseDataset): + def __getitem__(self, idx): + example = self.examples[idx] + image_id = example['id'] + image_path = os.path.join(self.image_dir, image_id + '.pt') + image = torch.load(image_path) + report_ids = example['ids'] + report_masks = example['mask'] + seq_length = len(report_ids) + sample = (image_id, image, report_ids, report_masks, seq_length) + return sample \ No newline at end of file diff --git a/modules/fc_model.py b/modules/fc_model.py new file mode 100644 index 0000000..08a1e9a --- /dev/null +++ b/modules/fc_model.py @@ -0,0 +1,204 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import * +from . 
import utils + +from .caption_model import CaptionModel + +class LSTMCore(nn.Module): + def __init__(self, opt): + super(LSTMCore, self).__init__() + self.input_encoding_size = opt.input_encoding_size + self.rnn_size = opt.rnn_size + self.drop_prob_lm = opt.drop_prob_lm + + # Build a LSTM + self.i2h = nn.Linear(self.input_encoding_size, 5 * self.rnn_size) + self.h2h = nn.Linear(self.rnn_size, 5 * self.rnn_size) + self.dropout = nn.Dropout(self.drop_prob_lm) + + def forward(self, xt, state): + + all_input_sums = self.i2h(xt) + self.h2h(state[0][-1]) + sigmoid_chunk = all_input_sums.narrow(1, 0, 3 * self.rnn_size) + sigmoid_chunk = torch.sigmoid(sigmoid_chunk) + in_gate = sigmoid_chunk.narrow(1, 0, self.rnn_size) + forget_gate = sigmoid_chunk.narrow(1, self.rnn_size, self.rnn_size) + out_gate = sigmoid_chunk.narrow(1, self.rnn_size * 2, self.rnn_size) + + in_transform = torch.max(\ + all_input_sums.narrow(1, 3 * self.rnn_size, self.rnn_size), + all_input_sums.narrow(1, 4 * self.rnn_size, self.rnn_size)) + next_c = forget_gate * state[1][-1] + in_gate * in_transform + next_h = out_gate * torch.tanh(next_c) + + output = self.dropout(next_h) + state = (next_h.unsqueeze(0), next_c.unsqueeze(0)) + return output, state + +class FCModel(CaptionModel): + def __init__(self, opt): + super(FCModel, self).__init__() + self.vocab_size = opt.vocab_size + self.input_encoding_size = opt.input_encoding_size + self.rnn_type = opt.rnn_type + self.rnn_size = opt.rnn_size + self.num_layers = opt.num_layers + self.drop_prob_lm = opt.drop_prob_lm + self.seq_length = opt.seq_length + self.fc_feat_size = opt.fc_feat_size + + self.ss_prob = 0.0 # Schedule sampling probability + + self.img_embed = nn.Linear(self.fc_feat_size, self.input_encoding_size) + self.core = LSTMCore(opt) + self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size) + self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1) + + self.init_weights() + + def init_weights(self): + initrange = 0.1 + 
self.embed.weight.data.uniform_(-initrange, initrange) + self.logit.bias.data.fill_(0) + self.logit.weight.data.uniform_(-initrange, initrange) + + def init_hidden(self, bsz): + weight = self.logit.weight + if self.rnn_type == 'lstm': + return (weight.new_zeros(self.num_layers, bsz, self.rnn_size), + weight.new_zeros(self.num_layers, bsz, self.rnn_size)) + else: + return weight.new_zeros(self.num_layers, bsz, self.rnn_size) + + def _forward(self, fc_feats, att_feats, seq, att_masks=None): + batch_size = fc_feats.size(0) + seq_per_img = seq.shape[0] // batch_size + state = self.init_hidden(batch_size*seq_per_img) + outputs = [] + + if seq_per_img > 1: + fc_feats = utils.repeat_tensors(seq_per_img, fc_feats) + + for i in range(seq.size(1) + 1): + if i == 0: + xt = self.img_embed(fc_feats) + else: + if self.training and i >= 2 and self.ss_prob > 0.0: # otherwiste no need to sample + sample_prob = fc_feats.data.new(batch_size*seq_per_img).uniform_(0, 1) + sample_mask = sample_prob < self.ss_prob + if sample_mask.sum() == 0: + it = seq[:, i-1].clone() + else: + sample_ind = sample_mask.nonzero().view(-1) + it = seq[:, i-1].data.clone() + #prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1) + #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1)) + prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1) + it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind)) + else: + it = seq[:, i-1].clone() + # break if all the sequences end + if i >= 2 and seq[:, i-1].sum() == 0: + break + xt = self.embed(it) + + output, state = self.core(xt, state) + output = F.log_softmax(self.logit(output), dim=1) + outputs.append(output) + + return torch.cat([_.unsqueeze(1) for _ in outputs[1:]], 1).contiguous() + + def get_logprobs_state(self, it, state): + # 'it' is contains a word index + xt = self.embed(it) + + output, state = self.core(xt, state) + 
logprobs = F.log_softmax(self.logit(output), dim=1) + + return logprobs, state + + def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}): + beam_size = opt.get('beam_size', 10) + batch_size = fc_feats.size(0) + + assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed' + seq = torch.LongTensor(self.seq_length, batch_size).zero_() + seqLogprobs = torch.FloatTensor(self.seq_length, batch_size, self.vocab_size + 1) + # lets process every image independently for now, for simplicity + + self.done_beams = [[] for _ in range(batch_size)] + for k in range(batch_size): + state = self.init_hidden(beam_size) + for t in range(2): + if t == 0: + xt = self.img_embed(fc_feats[k:k+1]).expand(beam_size, self.input_encoding_size) + elif t == 1: # input + it = fc_feats.data.new(beam_size).long().zero_() + xt = self.embed(it) + + output, state = self.core(xt, state) + logprobs = F.log_softmax(self.logit(output), dim=1) + + self.done_beams[k] = self.beam_search(state, logprobs, opt=opt) + seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score + seqLogprobs[:, k] = self.done_beams[k][0]['logps'] + # return the samples and their log likelihoods + return seq.transpose(0, 1), seqLogprobs.transpose(0, 1) + + def _sample(self, fc_feats, att_feats, att_masks=None, opt={}): + sample_method = opt.get('sample_method', 'greedy') + beam_size = opt.get('beam_size', 1) + temperature = opt.get('temperature', 1.0) + if beam_size > 1 and sample_method in ['greedy', 'beam_search']: + return self._sample_beam(fc_feats, att_feats, opt) + + batch_size = fc_feats.size(0) + state = self.init_hidden(batch_size) + seq = fc_feats.new_zeros(batch_size, self.seq_length, dtype=torch.long) + seqLogprobs = fc_feats.new_zeros(batch_size, self.seq_length, self.vocab_size + 1) + for t in range(self.seq_length + 2): + if t == 0: + xt = self.img_embed(fc_feats) + 
else: + if t == 1: # input + it = fc_feats.data.new(batch_size).long().zero_() + xt = self.embed(it) + + output, state = self.core(xt, state) + logprobs = F.log_softmax(self.logit(output), dim=1) + + # sample the next_word + if t == self.seq_length + 1: # skip if we achieve maximum length + break + if sample_method == 'greedy': + sampleLogprobs, it = torch.max(logprobs.data, 1) + it = it.view(-1).long() + else: + if temperature == 1.0: + prob_prev = torch.exp(logprobs.data).cpu() # fetch prev distribution: shape Nx(M+1) + else: + # scale logprobs by temperature + prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu() + it = torch.multinomial(prob_prev, 1).to(logprobs.device) + sampleLogprobs = logprobs.gather(1, it) # gather the logprobs at sampled positions + it = it.view(-1).long() # and flatten indices for downstream processing + + if t >= 1: + # stop when all finished + if t == 1: + unfinished = it > 0 + else: + unfinished = unfinished & (it > 0) + it = it * unfinished.type_as(it) + seq[:,t-1] = it #seq[t] the input of t+2 time step + seqLogprobs[:,t-1] = sampleLogprobs.view(-1) + if unfinished.sum() == 0: + break + + return seq, seqLogprobs diff --git a/modules/histgen_module.py b/modules/histgen_module.py new file mode 100644 index 0000000..f531957 --- /dev/null +++ b/modules/histgen_module.py @@ -0,0 +1,639 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import copy +import math + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +# import sys +# sys.path.append('..') + +from .att_model import pack_wrapper, AttModel + +def transform_tokens2regions(hidden_states, num_regions, region_size): + + # transform sequence into regions + patch_hidden_states = torch.reshape(hidden_states, (hidden_states.size(0), num_regions, region_size, hidden_states.size(-1))) + # squash regions into sequence into a single axis (samples * 
regions, region_size, hidden_size) + hidden_states_reshape = patch_hidden_states.contiguous().view(hidden_states.size(0) * num_regions, + region_size, patch_hidden_states.size(-1)) + + return hidden_states_reshape + +def transform_masks2regions(mask, num_regions, region_size): + + # transform sequence mask into regions + patch_mask = torch.reshape(mask, (mask.size(0), num_regions, region_size)) + # squash regions into sequence into a single axis (samples * regions, region_size) + mask_reshape = patch_mask.contiguous().view(mask.size(0) * num_regions, 1, region_size) + + return mask_reshape + +def transform_sentences2tokens(seg_hidden_states, num_sentences, max_sentence_length): + # transform squashed sequence into segments + hidden_states = seg_hidden_states.contiguous().view(seg_hidden_states.size(0) // num_sentences, num_sentences, + max_sentence_length, seg_hidden_states.size(-1)) + # transform segments into sequence + hidden_states = hidden_states.contiguous().view(hidden_states.size(0), num_sentences * max_sentence_length, + hidden_states.size(-1)) + return hidden_states + +class TransformerLayer(nn.Module): + def __init__(self, size, self_attn, feed_forward, dropout): + super(TransformerLayer, self).__init__() + self.self_attn = self_attn + self.feed_forward = feed_forward + self.sublayer = clones(SublayerConnection(size, dropout), 2) + self.size = size + + def forward(self, x, mask): + x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask)) + return self.sublayer[1](x, self.feed_forward) + +class HATLayer(nn.Module): + def __init__(self, heads = 8, d_model = 512, d_ff = 512, region_size = 256, use_region_encoder = True, use_WSI_encoder = False, dropout = 0.1, max_patch = 100000, first_layer = False): + super().__init__() + self.region_size = region_size + self.max_patch = max_patch + self.max_region = int(np.ceil(self.max_patch / self.region_size)) + self.first_layer = first_layer # if is the first HAT layer, interpolate global token and add region 
level positional encodings + + self.d_model = d_model + self.heads = heads + self.dropout = dropout + self.d_ff = d_ff + + self.global_token = nn.Parameter(torch.randn(1, 1, self.d_model)) + self.region_position_embeddings = PositionalEncoding(self.d_model, self.dropout, max_len = 100000) + + self.use_region_encoder = use_region_encoder + self.use_WSI_encoder = use_WSI_encoder + + self.attn = MultiHeadedAttention(self.heads, self.d_model, self.dropout) + self.ff = PositionwiseFeedForward(self.d_model, self.d_ff, self.dropout) + c = copy.deepcopy + if self.use_region_encoder: + self.region_encoder = TransformerLayer(self.d_model, c(self.attn), c(self.ff), self.dropout) + if self.use_WSI_encoder: + self.WSI_encoder = TransformerLayer(self.d_model, c(self.attn), c(self.ff), self.dropout) + self.position_embeddings = nn.Embedding(self.max_region + 1, self.d_model, padding_idx = 0) #! Usage? + + def forward(self, x, mask, num_regions): + + assert self.use_region_encoder == True or self.use_WSI_encoder == True, "One of the encoders needs to be used" + + # num_regions = int(np.ceil(x.shape[1] / (self.region_size - 1))) + + if self.first_layer: # add global token + x, mask = self.interpolate_global_token(x, mask) + + if self.use_region_encoder: + region_inputs = transform_tokens2regions(x, num_regions, self.region_size) + region_masks = transform_masks2regions(mask, num_regions, self.region_size) + region_inputs = self.region_position_embeddings(region_inputs) + + outputs = self.region_encoder(region_inputs, region_masks) + else: + outputs = x + + if self.use_WSI_encoder: + + assert self.use_region_encoder == True, "Region encoder needs to be used before WSI encoder" + region_global_tokens = outputs[:, ::self.region_size].clone() + region_attention_mask = mask[:, :, ::self.region_size].clone() + #* Relative Positional Encoding + region_global_tokens = region_global_tokens.view(1, region_global_tokens.size(0), region_global_tokens.size(2)) + region_global_tokens = 
self.region_position_embeddings(region_global_tokens) + #* + WSI_outputs = self.WSI_encoder(region_global_tokens, region_attention_mask) + WSI_outputs = WSI_outputs.view(WSI_outputs.size(1), 1, WSI_outputs.size(2)) + #* replace region representation tokens + outputs[:, ::self.region_size] = WSI_outputs + #* Map the outputs to the original shape + outputs = outputs.view(x.size(0), num_regions * self.region_size, outputs.size(-1)) + else: + outputs = outputs.view(x.size(0), num_regions * self.region_size, outputs.size(-1)) + + outputs_mask = mask + + return outputs, outputs_mask + + def interpolate_global_token(self, hidden_states, mask): + batch_size, seq_len, hidden_dim = hidden_states.size() + num_regions = int(np.ceil(seq_len / (self.region_size - 1))) + + #* Calculate the total size after division and padding + total_size = num_regions * self.region_size + + #* Calculate the padding size + padding_size = total_size - seq_len - num_regions + + #* Pad the sequence and mask if needed + if padding_size > 0: + hidden_padding = torch.zeros(batch_size, padding_size, hidden_dim, device=hidden_states.device) + hidden_states = torch.cat([hidden_states, hidden_padding], dim=1) + + mask_padding = torch.zeros(batch_size, 1, padding_size, device=mask.device) + mask = torch.cat([mask, mask_padding], dim=2) + + #* Add the global token at the end of each region + global_token = self.global_token.repeat(batch_size, num_regions, 1) + + hidden_states_with_global = torch.cat([hidden_states.view(batch_size, num_regions, self.region_size - 1, hidden_dim), + global_token.unsqueeze(2)], dim=2) + hidden_states_with_global = hidden_states_with_global.view(batch_size, num_regions * self.region_size, hidden_dim) + + #* Update mask for global tokens (1 for global tokens) + global_token_mask = torch.ones(batch_size, 1, num_regions, 1, device=mask.device) + mask_with_global = torch.cat([mask.view(batch_size, 1, num_regions, self.region_size - 1), + global_token_mask], dim=3) + mask_with_global 
= mask_with_global.view(batch_size, 1, num_regions * self.region_size) + + return hidden_states_with_global, mask_with_global + +class HATEncoder(nn.Module): + def __init__(self, encoder_layout = None): + super().__init__() + self.encoder_layout = encoder_layout + self.layer = nn.ModuleList([HATLayer(heads = encoder_layout['num_heads'], d_model = encoder_layout['d_model'], d_ff = encoder_layout['d_ff'], + region_size = encoder_layout['region_size'], use_region_encoder = encoder_layout[str(idx)]['region_encoder'], + use_WSI_encoder = encoder_layout[str(idx)]['WSI_encoder'], dropout = encoder_layout['dropout'], + first_layer = encoder_layout[str(idx)]['first_layer']) for idx in range(int(encoder_layout['num_layers']))]) + self.norm = LayerNorm(encoder_layout['d_model']) + self.pooler = HATPooler(encoder_layout, pooling = encoder_layout['pooling']) + self.region_size = encoder_layout['region_size'] + + def forward(self, x, mask): + + num_regions = int(np.ceil(x.shape[1] / (self.region_size - 1))) + + for idx, layer in enumerate(self.layer): + x, mask = layer(x, mask, num_regions) + x = self.norm(x) + + #* Take out the global token of each region and send to the decoder + x = x.view(num_regions, self.region_size, x.size(-1)) + if self.encoder_layout['pooling'] == 'None': + output = x[:, ::self.region_size].view(1, num_regions, x.size(-1)) + return output + else: + output = self.pooler(x) + output = output.unsqueeze(0) + return output + +class AttentivePooling(nn.Module): + def __init__(self, encoder_layout): + super().__init__() + self.attn_dropout = encoder_layout['dropout'] + self.lin_proj = nn.Linear(encoder_layout['d_model'], encoder_layout['d_model']) + self.v = nn.Linear(encoder_layout['d_model'], 1, bias=False) + + def forward(self, inputs): + lin_out = self.lin_proj(inputs) + attention_weights = torch.tanh(self.v(lin_out)).squeeze(-1) + attention_weights_normalized = torch.softmax(attention_weights, -1) + return 
torch.sum(attention_weights_normalized.unsqueeze(-1) * inputs, 1) + +class HATPooler(nn.Module): + def __init__(self, encoder_layout, pooling = 'max'): + super().__init__() + self.dense = nn.Linear(encoder_layout['d_model'], encoder_layout['d_model']) + self.pooling = pooling + if self.pooling == 'attentive': + self.attentive_pooling = AttentivePooling(encoder_layout) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + if self.pooling == 'attentive': + pooled_output = self.attentive_pooling(hidden_states) + else: + pooled_output = torch.max(hidden_states, dim=1)[0] + pooled_output = self.dense(pooled_output) + pooled_output = self.activation(pooled_output) + return pooled_output + +def clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for _ in range(N)]) + + +def subsequent_mask(size): + attn_shape = (1, size, size) + subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8') + return torch.from_numpy(subsequent_mask) == 0 + + +def attention(query, key, value, mask=None, dropout=None): + d_k = query.size(-1) + scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k) + if mask is not None: + scores = scores.masked_fill(mask == 0, float('-inf')) + p_attn = F.softmax(scores, dim=-1) + if dropout is not None: + p_attn = dropout(p_attn) + return torch.matmul(p_attn, value), p_attn + + +def memory_querying_responding(query, key, value, mask=None, dropout=None, topk=32): + d_k = query.size(-1) + scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k) + if mask is not None: + scores = scores.masked_fill(mask == 0, float('-inf')) + selected_scores, idx = scores.topk(topk) + dummy_value = value.unsqueeze(2).expand(idx.size(0), idx.size(1), idx.size(2), value.size(-2), value.size(-1)) + dummy_idx = idx.unsqueeze(-1).expand(idx.size(0), idx.size(1), idx.size(2), idx.size(3), value.size(-1)) + selected_value = torch.gather(dummy_value, 3, dummy_idx) + p_attn = F.softmax(selected_scores, dim=-1) + if dropout is 
not None: + p_attn = dropout(p_attn) + return torch.matmul(p_attn.unsqueeze(3), selected_value).squeeze(3), p_attn + + +class Transformer(nn.Module): + def __init__(self, encoder, decoder, src_embed, tgt_embed, cmn, d_model, num_heads): + super(Transformer, self).__init__() + self.encoder = encoder + self.decoder = decoder + self.src_embed = src_embed + self.tgt_embed = tgt_embed + self.cmn = cmn + + self.d_model = d_model + self.num_heads = num_heads + + def forward(self, src, tgt, src_mask, tgt_mask, memory_matrix): + return self.decode(self.encode(src, src_mask), src_mask, tgt, tgt_mask, memory_matrix=memory_matrix) + + def encode(self, src, src_mask): + return self.encoder(self.src_embed(src), src_mask) + + def decode(self, memory, src_mask, tgt, tgt_mask, past=None, memory_matrix=None): + embeddings = self.tgt_embed(tgt) + + #* Memory querying and responding for textual features + dummy_memory_matrix = memory_matrix.unsqueeze(0).expand(embeddings.size(0), memory_matrix.size(0), memory_matrix.size(1)) + responses = self.cmn(embeddings, dummy_memory_matrix, dummy_memory_matrix) + embeddings = embeddings + responses + + return self.decoder(embeddings, memory, src_mask, tgt_mask, past=past) + + +class Encoder(nn.Module): + def __init__(self, layer, N): + super(Encoder, self).__init__() + self.layers = clones(layer, N) + self.norm = LayerNorm(layer.size) + + def forward(self, x, mask): + for layer in self.layers: + x = layer(x, mask) + return self.norm(x) + + +class LayerNorm(nn.Module): + def __init__(self, features, eps=1e-6): + super(LayerNorm, self).__init__() + self.a_2 = nn.Parameter(torch.ones(features)) + self.b_2 = nn.Parameter(torch.zeros(features)) + self.eps = eps + + def forward(self, x): + mean = x.mean(-1, keepdim=True) + std = x.std(-1, keepdim=True) + return self.a_2 * (x - mean) / (std + self.eps) + self.b_2 + + +class SublayerConnection(nn.Module): + def __init__(self, size, dropout): + super(SublayerConnection, self).__init__() + self.norm = 
LayerNorm(size) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, sublayer): + _x = sublayer(self.norm(x)) + if type(_x) is tuple: + return x + self.dropout(_x[0]), _x[1] + return x + self.dropout(_x) + +class EncoderLayer(nn.Module): + def __init__(self, size, self_attn, feed_forward, dropout): + super(EncoderLayer, self).__init__() + self.self_attn = self_attn + self.feed_forward = feed_forward + self.sublayer = clones(SublayerConnection(size, dropout), 2) + self.size = size + + def forward(self, x, mask): + x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask)) + return self.sublayer[1](x, self.feed_forward) + + +class Decoder(nn.Module): + def __init__(self, layer, N): + super(Decoder, self).__init__() + self.layers = clones(layer, N) + self.norm = LayerNorm(layer.size) + + def forward(self, x, memory, src_mask, tgt_mask, past=None): + if past is not None: + present = [[], []] + x = x[:, -1:] + tgt_mask = tgt_mask[:, -1:] if tgt_mask is not None else None + past = list(zip(past[0].split(2, dim=0), past[1].split(2, dim=0))) + else: + past = [None] * len(self.layers) + for i, (layer, layer_past) in enumerate(zip(self.layers, past)): + x = layer(x, memory, src_mask, tgt_mask, + layer_past) + if layer_past is not None: + present[0].append(x[1][0]) + present[1].append(x[1][1]) + x = x[0] + if past[0] is None: + return self.norm(x) + else: + return self.norm(x), [torch.cat(present[0], 0), torch.cat(present[1], 0)] + + +class DecoderLayer(nn.Module): + def __init__(self, size, self_attn, src_attn, feed_forward, dropout): + super(DecoderLayer, self).__init__() + self.size = size + self.self_attn = self_attn + self.src_attn = src_attn + self.feed_forward = feed_forward + self.sublayer = clones(SublayerConnection(size, dropout), 3) + + def forward(self, x, memory, src_mask, tgt_mask, layer_past=None): + m = memory + src_mask = m.new_ones(m.shape[:2], dtype=torch.long) + src_mask = src_mask.unsqueeze(-2) + if layer_past is None: + x = 
self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask)) + x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask)) + return self.sublayer[2](x, self.feed_forward) + else: + present = [None, None] + x, present[0] = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask, layer_past[0])) + x, present[1] = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask, layer_past[1])) + return self.sublayer[2](x, self.feed_forward), present + + +class MultiThreadMemory(nn.Module): + def __init__(self, h, d_model, dropout=0.1, topk=32): + super(MultiThreadMemory, self).__init__() + assert d_model % h == 0 + self.d_k = d_model // h + self.h = h + self.linears = clones(nn.Linear(d_model, d_model), 4) + self.attn = None + self.dropout = nn.Dropout(p=dropout) + self.topk = topk + + def forward(self, query, key, value, mask=None, layer_past=None): + if mask is not None: + mask = mask.unsqueeze(1) + nbatches = query.size(0) + + if layer_past is not None and layer_past.shape[2] == key.shape[1] > 1: + query = self.linears[0](query) + key, value = layer_past[0], layer_past[1] + present = torch.stack([key, value]) + else: + query, key, value = \ + [l(x) for l, x in zip(self.linears, (query, key, value))] + if layer_past is not None and not (layer_past.shape[2] == key.shape[1] > 1): + past_key, past_value = layer_past[0], layer_past[1] + key = torch.cat((past_key, key), dim=1) + value = torch.cat((past_value, value), dim=1) + present = torch.stack([key, value]) + + query, key, value = \ + [x.view(nbatches, -1, self.h, self.d_k).transpose(1, 2) + for x in [query, key, value]] + + x, self.attn = memory_querying_responding(query, key, value, mask=mask, dropout=self.dropout, topk=self.topk) + + x = x.transpose(1, 2).contiguous() \ + .view(nbatches, -1, self.h * self.d_k) + if layer_past is not None: + return self.linears[-1](x), present + else: + return self.linears[-1](x) + + +class MultiHeadedAttention(nn.Module): + def __init__(self, h, d_model, 
dropout=0.1): + super(MultiHeadedAttention, self).__init__() + assert d_model % h == 0 + self.d_k = d_model // h + self.h = h + self.linears = clones(nn.Linear(d_model, d_model), 4) + self.attn = None + self.dropout = nn.Dropout(p=dropout) + + def forward(self, query, key, value, mask=None, layer_past=None): + if mask is not None: + mask = mask.unsqueeze(1) + nbatches = query.size(0) + if layer_past is not None and layer_past.shape[2] == key.shape[1] > 1: + query = self.linears[0](query) + key, value = layer_past[0], layer_past[1] + present = torch.stack([key, value]) + else: + query, key, value = \ + [l(x) for l, x in zip(self.linears, (query, key, value))] + + if layer_past is not None and not (layer_past.shape[2] == key.shape[1] > 1): + past_key, past_value = layer_past[0], layer_past[1] + key = torch.cat((past_key, key), dim=1) + value = torch.cat((past_value, value), dim=1) + present = torch.stack([key, value]) + + query, key, value = \ + [x.view(nbatches, -1, self.h, self.d_k).transpose(1, 2) + for x in [query, key, value]] + + x, self.attn = attention(query, key, value, mask=mask, + dropout=self.dropout) + x = x.transpose(1, 2).contiguous() \ + .view(nbatches, -1, self.h * self.d_k) + if layer_past is not None: + return self.linears[-1](x), present + else: + return self.linears[-1](x) + + +class PositionwiseFeedForward(nn.Module): + def __init__(self, d_model, d_ff, dropout=0.1): + super(PositionwiseFeedForward, self).__init__() + self.w_1 = nn.Linear(d_model, d_ff) + self.w_2 = nn.Linear(d_ff, d_model) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + return self.w_2(self.dropout(F.relu(self.w_1(x)))) + + +class Embeddings(nn.Module): + def __init__(self, d_model, vocab): + super(Embeddings, self).__init__() + self.lut = nn.Embedding(vocab, d_model) + self.d_model = d_model + + def forward(self, x): + return self.lut(x) * math.sqrt(self.d_model) + + +class PositionalEncoding(nn.Module): + def __init__(self, d_model, dropout, max_len=70000): + 
super(PositionalEncoding, self).__init__() + self.dropout = nn.Dropout(p=dropout) + + pe = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len).unsqueeze(1).float() + div_term = torch.exp(torch.arange(0, d_model, 2).float() * + -(math.log(10000.0) / d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + self.register_buffer('pe', pe) + + def forward(self, x): + x = x + self.pe[:, :x.size(1)] + return self.dropout(x) + + +class BaseHistGen(AttModel): + def make_model(self, tgt_vocab, cmn, encoder_layout = None): + c = copy.deepcopy + attn = MultiHeadedAttention(self.num_heads, self.d_model) + ff = PositionwiseFeedForward(self.d_model, self.d_ff, self.dropout) + position = PositionalEncoding(self.d_model, self.dropout) + model = Transformer( + HATEncoder(encoder_layout), + Decoder(DecoderLayer(self.d_model, c(attn), c(attn), c(ff), self.dropout), self.num_layers), + lambda x: x, + nn.Sequential(Embeddings(self.d_model, tgt_vocab), c(position)), cmn, self.d_model, self.num_heads) + + for p in model.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + return model + + def __init__(self, args, tokenizer): + super(BaseHistGen, self).__init__(args, tokenizer) + self.args = args + self.num_layers = args.num_layers + self.d_model = args.d_model + self.d_ff = args.d_ff + self.num_heads = args.num_heads + self.dropout = args.dropout + self.topk = args.topk + self.K = args.prototype_num + + tgt_vocab = self.vocab_size + 1 + + self.cmn = MultiThreadMemory(args.num_heads, args.d_model, topk=args.topk) + + self.region_size = args.region_size + self.encoder_layout = { + 'num_heads': self.num_heads, + 'd_model': self.d_model, + 'd_ff': self.d_ff, + 'region_size': self.region_size, + 'dropout': self.dropout, + 'pooling': 'attentive', + 'num_layers': 2, + '0': { + 'region_encoder': True, + 'WSI_encoder': True, + 'first_layer': True + }, + '1': { + 'region_encoder': True, + 
'WSI_encoder': False, + 'first_layer': False + }, + } + + self.model = self.make_model(tgt_vocab, self.cmn, self.encoder_layout) + self.logit = nn.Linear(args.d_model, tgt_vocab) + + self.memory_matrix = nn.Parameter(torch.FloatTensor(args.cmm_size, args.cmm_dim)) + nn.init.normal_(self.memory_matrix, 0, 1 / args.cmm_dim) + + self.attn_mem = MultiHeadedAttention(self.num_heads, self.d_model) + + def init_hidden(self, bsz): + return [] + + def _prepare_feature(self, fc_feats, att_feats, att_masks): + att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks) + memory = self.model.encode(att_feats, att_masks) + + return fc_feats[..., :1], att_feats[..., :1], memory, att_masks + + def _prepare_feature_forward(self, att_feats, att_masks=None, seq=None): + + att_feats, att_masks = self.clip_att(att_feats, att_masks) + att_feats = pack_wrapper(self.att_embed, att_feats, att_masks) + + if att_masks is None: + att_masks = att_feats.new_ones(att_feats.shape[:2], dtype=torch.long) + + #* Memory querying and responding for visual features + dummy_memory_matrix = self.memory_matrix.unsqueeze(0).expand(att_feats.size(0), self.memory_matrix.size(0), self.memory_matrix.size(1)) + indices = torch.linspace(0, att_feats.shape[1] - 1, steps=self.K).long() + M = att_feats[:, indices, :] + responses = self.cmn(M, dummy_memory_matrix, dummy_memory_matrix) + response_mask = responses.new_ones(responses.shape[:2], dtype=torch.long) + response_mask = response_mask.unsqueeze(-2) + att_feats = att_feats + self.attn_mem(att_feats, responses, responses, response_mask) + #* + + att_masks = att_masks.unsqueeze(-2) + if seq is not None: + seq = seq[:, :-1] + seq_mask = (seq.data > 0) + seq_mask[:, 0] += True + + seq_mask = seq_mask.unsqueeze(-2) + seq_mask = seq_mask & subsequent_mask(seq.size(-1)).to(seq_mask) + else: + seq_mask = None + + return att_feats, seq, att_masks, seq_mask + + def _forward(self, fc_feats, att_feats, seq, att_masks=None): + att_feats, seq, 
att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks, seq) + out = self.model(att_feats, seq, att_masks, seq_mask, memory_matrix=self.memory_matrix) + outputs = F.log_softmax(self.logit(out), dim=-1) + + return outputs + + def _save_attns(self, start=False): + if start: + self.attention_weights = [] + self.attention_weights.append([layer.src_attn.attn.cpu().numpy() for layer in self.model.decoder.layers]) + + def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask): + if len(state) == 0: + ys = it.unsqueeze(1) + past = [fc_feats_ph.new_zeros(self.num_layers * 2, fc_feats_ph.shape[0], 0, self.d_model), + fc_feats_ph.new_zeros(self.num_layers * 2, fc_feats_ph.shape[0], 0, self.d_model)] + else: + ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1) + past = state[1:] + out, past = self.model.decode(memory, mask, ys, subsequent_mask(ys.size(1)).to(memory.device), past=past, + memory_matrix=self.memory_matrix) + + if not self.training: + self._save_attns(start=len(state) == 0) + return out[:, -1], [ys.unsqueeze(0)] + past + \ No newline at end of file diff --git a/modules/loss.py b/modules/loss.py new file mode 100644 index 0000000..f6dc8cc --- /dev/null +++ b/modules/loss.py @@ -0,0 +1,21 @@ +import torch +import torch.nn as nn + + +class LanguageModelCriterion(nn.Module): + def __init__(self): + super(LanguageModelCriterion, self).__init__() + + def forward(self, input, target, mask): + # truncate to the same size + target = target[:, :input.size(1)] + mask = mask[:, :input.size(1)] + output = -input.gather(2, target.long().unsqueeze(2)).squeeze(2) * mask + output = torch.sum(output) / torch.sum(mask) + return output + + +def compute_loss(output, reports_ids, reports_masks): + criterion = LanguageModelCriterion() + loss = criterion(output, reports_ids[:, 1:], reports_masks[:, 1:]).mean() + return loss diff --git a/modules/metrics.py b/modules/metrics.py new file mode 100644 index 0000000..dc46386 --- /dev/null +++ b/modules/metrics.py 
@@ -0,0 +1,65 @@ +from sklearn.metrics import roc_auc_score, f1_score, recall_score, precision_score + +from pycocoevalcap.bleu.bleu import Bleu +from pycocoevalcap.meteor import Meteor +from pycocoevalcap.rouge import Rouge + + +def compute_scores(gts, res): + """ + Performs the MS COCO evaluation using the Python 3 implementation (https://github.com/salaniz/pycocoevalcap) + + :param gts: Dictionary with the image ids and their gold captions, + :param res: Dictionary with the image ids and their generated captions + :return: Evaluation score (the mean of the scores of all the instances) for each measure + """ + + # Set up scorers + scorers = [ + (Bleu(4), ["BLEU_1", "BLEU_2", "BLEU_3", "BLEU_4"]), + (Meteor(), "METEOR"), + (Rouge(), "ROUGE_L") + ] + eval_res = {} + # Compute score for each metric + for scorer, method in scorers: + try: + score, scores = scorer.compute_score(gts, res, verbose=0) + except TypeError: + score, scores = scorer.compute_score(gts, res) + if type(method) == list: + for sc, m in zip(score, method): + eval_res[m] = sc + else: + eval_res[method] = score + return eval_res + + +def compute_mlc(gt, pred, label_set): + res_mlc = {} + avg_aucroc = 0 + for i, label in enumerate(label_set): + res_mlc['AUCROC_' + label] = roc_auc_score(gt[:, i], pred[:, i]) + avg_aucroc += res_mlc['AUCROC_' + label] + res_mlc['AVG_AUCROC'] = avg_aucroc / len(label_set) + + res_mlc['F1_MACRO'] = f1_score(gt, pred, average="macro") + res_mlc['F1_MICRO'] = f1_score(gt, pred, average="micro") + res_mlc['RECALL_MACRO'] = recall_score(gt, pred, average="macro") + res_mlc['RECALL_MICRO'] = recall_score(gt, pred, average="micro") + res_mlc['PRECISION_MACRO'] = precision_score(gt, pred, average="macro") + res_mlc['PRECISION_MICRO'] = precision_score(gt, pred, average="micro") + + return res_mlc + + +class MetricWrapper(object): + def __init__(self, label_set): + self.label_set = label_set + + def __call__(self, gts, res, gts_mlc, res_mlc): + eval_res = compute_scores(gts, 
res) + eval_res_mlc = compute_mlc(gts_mlc, res_mlc, self.label_set) + + eval_res.update(**eval_res_mlc) + return eval_res diff --git a/modules/optimizers.py b/modules/optimizers.py new file mode 100644 index 0000000..510566e --- /dev/null +++ b/modules/optimizers.py @@ -0,0 +1,18 @@ +import torch + + +def build_optimizer(args, model): + ve_params = list(map(id, model.visual_extractor.parameters())) + ed_params = filter(lambda x: id(x) not in ve_params, model.parameters()) + optimizer = getattr(torch.optim, args.optim)( + [{'params': model.visual_extractor.parameters(), 'lr': args.lr_ve}, + {'params': ed_params, 'lr': args.lr_ed}], + weight_decay=args.weight_decay, + amsgrad=args.amsgrad + ) + return optimizer + + +def build_lr_scheduler(args, optimizer): + lr_scheduler = getattr(torch.optim.lr_scheduler, args.lr_scheduler)(optimizer, args.step_size, args.gamma) + return lr_scheduler diff --git a/modules/tester_AllinOne.py b/modules/tester_AllinOne.py new file mode 100644 index 0000000..fe24bbd --- /dev/null +++ b/modules/tester_AllinOne.py @@ -0,0 +1,139 @@ +import logging +import os +from abc import abstractmethod + +import cv2 +import pandas as pd +import torch + +from modules.utils import generate_heatmap +from tqdm import tqdm +import logging + +class BaseTester(object): + def __init__(self, model, criterion, metric_ftns, args): + self.args = args + + logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', + datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO) + self.logger = logging.getLogger(__name__) + + # setup GPU device if available, move model into configured device + self.device, device_ids = self._prepare_device(args.n_gpu) + self.model = model.to(self.device) + if len(device_ids) > 1: + self.model = torch.nn.DataParallel(model, device_ids=device_ids) + + self.criterion = criterion + self.metric_ftns = metric_ftns + + self.epochs = self.args.epochs + self.save_dir = self.args.save_dir + if not os.path.exists(self.save_dir): + 
os.makedirs(self.save_dir) + + self._load_checkpoint(args.load) + + @abstractmethod + def test(self): + raise NotImplementedError + + @abstractmethod + def plot(self): + raise NotImplementedError + + def _prepare_device(self, n_gpu_use): + n_gpu = torch.cuda.device_count() + if n_gpu_use > 0 and n_gpu == 0: + self.logger.warning( + "Warning: There\'s no GPU available on this machine," "training will be performed on CPU.") + n_gpu_use = 0 + if n_gpu_use > n_gpu: + self.logger.warning( + "Warning: The number of GPU\'s configured to use is {}, but only {} are available " "on this machine.".format( + n_gpu_use, n_gpu)) + n_gpu_use = n_gpu + device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu') + list_ids = list(range(n_gpu_use)) + return device, list_ids + + def _load_checkpoint(self, load_path): + load_path = str(load_path) + self.logger.info("Loading checkpoint: {} ...".format(load_path)) + checkpoint = torch.load(load_path) + self.model.load_state_dict(checkpoint['state_dict']) + + +class Tester(BaseTester): + def __init__(self, model, criterion, metric_ftns, args, test_dataloader): + super(Tester, self).__init__(model, criterion, metric_ftns, args) + self.test_dataloader = test_dataloader + + def test(self): + self.logger.info('Start to evaluate in the test set.') + log = dict() + self.model.eval() + with torch.no_grad(): + test_gts, test_res, test_ids = [], [], [] + for batch_idx, (images_id, images, reports_ids, reports_masks) in tqdm(enumerate(self.test_dataloader)): + images_id, images, reports_ids, reports_masks = images_id[0], images.to(self.device), reports_ids.to( + self.device), reports_masks.to(self.device) + output = self.model(images, mode='sample') + reports = self.model.tokenizer.decode_batch(output.cpu().numpy()) + ground_truths = self.model.tokenizer.decode_batch(reports_ids[:, 1:].cpu().numpy()) + test_res.extend(reports) + test_gts.extend(ground_truths) + test_ids.append(images_id) + + test_met = self.metric_ftns({i: [gt] for i, gt in 
enumerate(test_gts)}, + {i: [re] for i, re in enumerate(test_res)}) + log.update(**{'test_' + k: v for k, v in test_met.items()}) + print(log) + + # Convert to pandas DataFrame + test_res_df = pd.DataFrame(test_res, columns=['Generated Reports']) + test_gts_df = pd.DataFrame(test_gts, columns=['Ground Truths']) + + # Create DataFrame for IDs + test_ids_df = pd.DataFrame(test_ids, columns=['Case ID']) + + # Merge the DataFrames + merged_df = pd.concat([test_ids_df, test_res_df, test_gts_df], axis=1) + + # Save the merged DataFrame to a CSV file + merged_df.to_csv(os.path.join(self.save_dir, "gen_vs_gt.csv"), index=False) + test_res_df.to_csv(os.path.join(self.save_dir, "res.csv"), index=False) + test_gts_df.to_csv(os.path.join(self.save_dir, "gts.csv"), index=False) + + return log + + + def plot(self): + assert self.args.batch_size == 1 and self.args.beam_size == 1 + self.logger.info('Start to plot attention weights in the test set.') + os.makedirs(os.path.join(self.save_dir, "attentions"), exist_ok=True) + mean = torch.tensor((0.485, 0.456, 0.406)) + std = torch.tensor((0.229, 0.224, 0.225)) + mean = mean[:, None, None] + std = std[:, None, None] + + self.model.eval() + with torch.no_grad(): + for batch_idx, (images_id, images, reports_ids, reports_masks) in tqdm(enumerate(self.test_dataloader)): + images, reports_ids, reports_masks = images.to(self.device), reports_ids.to( + self.device), reports_masks.to(self.device) + output = self.model(images, mode='sample') + image = torch.clamp((images[0].cpu() * std + mean) * 255, 0, 255).int().cpu().numpy() + report = self.model.tokenizer.decode_batch(output.cpu().numpy())[0].split() + attention_weights = [layer.src_attn.attn.cpu().numpy()[:, :, :-1].mean(0).mean(0) for layer in + self.model.encoder_decoder.model.decoder.layers] + for layer_idx, attns in enumerate(attention_weights): + assert len(attns) == len(report) + for word_idx, (attn, word) in enumerate(zip(attns, report)): + os.makedirs(os.path.join(self.save_dir, 
"attentions", "{:04d}".format(batch_idx), + "layer_{}".format(layer_idx)), exist_ok=True) + + heatmap = generate_heatmap(image, attn) + cv2.imwrite(os.path.join(self.save_dir, "attentions", "{:04d}".format(batch_idx), + "layer_{}".format(layer_idx), "{:04d}_{}.png".format(word_idx, word)), + heatmap) diff --git a/modules/tokenizers.py b/modules/tokenizers.py new file mode 100644 index 0000000..df4ffac --- /dev/null +++ b/modules/tokenizers.py @@ -0,0 +1,112 @@ +import json +import re +from collections import Counter + + +class Tokenizer(object): + def __init__(self, args): + self.ann_path = args.ann_path + self.threshold = args.threshold + self.dataset_name = args.dataset_name + if self.dataset_name == 'iu_xray': + self.clean_report = self.clean_report_iu_xray + elif self.dataset_name == 'wsi_report': + self.clean_report = self.clean_report_pathology + else: + self.clean_report = self.clean_report_mimic_cxr + self.ann = json.loads(open(self.ann_path, 'r').read()) + self.token2idx, self.idx2token = self.create_vocabulary() + + def create_vocabulary(self): + total_tokens = [] + + for example in self.ann['train']: + tokens = self.clean_report(example['report']).split() + for token in tokens: + total_tokens.append(token) + + counter = Counter(total_tokens) + vocab = [k for k, v in counter.items() if v >= self.threshold] + [''] + vocab.sort() + token2idx, idx2token = {}, {} + for idx, token in enumerate(vocab): + token2idx[token] = idx + 1 + idx2token[idx + 1] = token + return token2idx, idx2token + + def clean_report_iu_xray(self, report): + report_cleaner = lambda t: t.replace('..', '.').replace('..', '.').replace('..', '.').replace('1. ', '') \ + .replace('. 2. ', '. ').replace('. 3. ', '. ').replace('. 4. ', '. ').replace('. 5. ', '. ') \ + .replace(' 2. ', '. ').replace(' 3. ', '. ').replace(' 4. ', '. ').replace(' 5. ', '. ') \ + .strip().lower().split('. ') + sent_cleaner = lambda t: re.sub('[.,?;*!%^&_+():-\[\]{}]', '', t.replace('"', '').replace('/', ''). 
+ replace('\\', '').replace("'", '').strip().lower()) + tokens = [sent_cleaner(sent) for sent in report_cleaner(report) if sent_cleaner(sent) != []] + report = ' . '.join(tokens) + ' .' + return report + + def clean_report_mimic_cxr(self, report): + report_cleaner = lambda t: t.replace('\n', ' ').replace('__', '_').replace('__', '_').replace('__', '_') \ + .replace('__', '_').replace('__', '_').replace('__', '_').replace('__', '_').replace(' ', ' ') \ + .replace(' ', ' ').replace(' ', ' ').replace(' ', ' ').replace(' ', ' ').replace(' ', ' ') \ + .replace('..', '.').replace('..', '.').replace('..', '.').replace('..', '.').replace('..', '.') \ + .replace('..', '.').replace('..', '.').replace('..', '.').replace('1. ', '').replace('. 2. ', '. ') \ + .replace('. 3. ', '. ').replace('. 4. ', '. ').replace('. 5. ', '. ').replace(' 2. ', '. ') \ + .replace(' 3. ', '. ').replace(' 4. ', '. ').replace(' 5. ', '. ') \ + .strip().lower().split('. ') + sent_cleaner = lambda t: re.sub('[.,?;*!%^&_+():-\[\]{}]', '', t.replace('"', '').replace('/', '') + .replace('\\', '').replace("'", '').strip().lower()) + tokens = [sent_cleaner(sent) for sent in report_cleaner(report) if sent_cleaner(sent) != []] + report = ' . '.join(tokens) + ' .' + return report + + def clean_report_pathology(self, report): + report_cleaner = lambda t: t.replace('\n', ' ').replace('__', '_').replace('__', '_').replace('__', '_') \ + .replace('__', '_').replace('__', '_').replace('__', '_').replace('__', '_').replace(' ', ' ') \ + .replace(' ', ' ').replace(' ', ' ').replace(' ', ' ').replace(' ', ' ').replace(' ', ' ') \ + .replace('..', '.').replace('..', '.').replace('..', '.').replace('..', '.').replace('..', '.') \ + .replace('..', '.').replace('..', '.').replace('..', '.').replace('1. ', '').replace('. 2. ', '. ') \ + .replace('. 3. ', '. ').replace('. 4. ', '. ').replace('. 5. ', '. ').replace(' 2. ', '. ') \ + .replace(' 3. ', '. ').replace(' 4. ', '. ').replace(' 5. ', '. 
') \ + .strip().lower().split('. ') + sent_cleaner = lambda t: re.sub('[.,?;*!%^&_+():-\[\]{}]', '', t.replace('"', '').replace('/', '') + .replace('\\', '').replace("'", '').strip().lower()) + tokens = [sent_cleaner(sent) for sent in report_cleaner(report) if sent_cleaner(sent) != []] + report = ' . '.join(tokens) + ' .' + return report + + def get_token_by_id(self, id): + return self.idx2token[id] + + def get_id_by_token(self, token): + if token not in self.token2idx: + return self.token2idx[''] + return self.token2idx[token] + + def get_vocab_size(self): + return len(self.token2idx) + + def __call__(self, report): + tokens = self.clean_report(report).split() + ids = [] + for token in tokens: + ids.append(self.get_id_by_token(token)) + ids = [0] + ids + [0] + return ids + + def decode(self, ids): + txt = '' + for i, idx in enumerate(ids): + if idx > 0: + if i >= 1: + txt += ' ' + txt += self.idx2token[idx] + else: + break + return txt + + def decode_batch(self, ids_batch): + out = [] + for ids in ids_batch: + out.append(self.decode(ids)) + return out \ No newline at end of file diff --git a/modules/trainer_AllinOne.py b/modules/trainer_AllinOne.py new file mode 100644 index 0000000..e86120d --- /dev/null +++ b/modules/trainer_AllinOne.py @@ -0,0 +1,256 @@ +import os +from abc import abstractmethod + +import time +import torch +import pandas as pd +from numpy import inf + +import logging + + +class BaseTrainer(object): + def __init__(self, model, criterion, metric_ftns, optimizer, args): + self.args = args + + logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', + datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO) + self.logger = logging.getLogger(__name__) + + # setup GPU device if available, move model into configured device + self.device, device_ids = self._prepare_device(args.n_gpu) + self.model = model.to(self.device) + if len(device_ids) > 1: + self.model = torch.nn.DataParallel(model, device_ids=device_ids) + + self.criterion = 
criterion + self.metric_ftns = metric_ftns + self.optimizer = optimizer + + self.epochs = self.args.epochs + self.save_period = self.args.save_period + + self.mnt_mode = args.monitor_mode + self.mnt_metric = 'val_' + args.monitor_metric + self.mnt_metric_test = 'test_' + args.monitor_metric + assert self.mnt_mode in ['min', 'max'] + + self.mnt_best = inf if self.mnt_mode == 'min' else -inf + self.early_stop = getattr(self.args, 'early_stop', inf) + + self.start_epoch = 1 + self.checkpoint_dir = args.save_dir + + if not os.path.exists(self.checkpoint_dir): + os.makedirs(self.checkpoint_dir) + + if args.resume is not None: + self._resume_checkpoint(args.resume) + + self.best_recorder = {'val': {self.mnt_metric: self.mnt_best}, + 'test': {self.mnt_metric_test: self.mnt_best}} + + @abstractmethod + def _train_epoch(self, epoch, model_name): + raise NotImplementedError + + def train(self): + not_improved_count = 0 + for epoch in range(self.start_epoch, self.epochs + 1): + result = self._train_epoch(epoch, self.args.model_name) + + # save logged informations into log dict + log = {'epoch': epoch} + log.update(result) + self._record_best(log) + + # print logged informations to the screen + for key, value in log.items(): + self.logger.info('\t{:15s}: {}'.format(str(key), value)) + + # evaluate model performance according to configured metric, save best checkpoint as model_best + best = False + if self.mnt_mode != 'off': + try: + # check whether model performance improved or not, according to specified metric(mnt_metric) + improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \ + (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best) + except KeyError: + self.logger.warning( + "Warning: Metric '{}' is not found. 
" "Model performance monitoring is disabled.".format( + self.mnt_metric)) + self.mnt_mode = 'off' + improved = False + + if improved: + self.mnt_best = log[self.mnt_metric] + not_improved_count = 0 + best = True + else: + not_improved_count += 1 + + if not_improved_count > self.early_stop: + self.logger.info("Validation performance didn\'t improve for {} epochs. " "Training stops.".format( + self.early_stop)) + break + + if epoch % self.save_period == 0: + self._save_checkpoint(epoch, save_best=best) + self._print_best() + self._print_best_to_file() + + def _print_best_to_file(self): + crt_time = time.asctime(time.localtime(time.time())) + self.best_recorder['val']['time'] = crt_time + self.best_recorder['test']['time'] = crt_time + self.best_recorder['val']['seed'] = self.args.seed + self.best_recorder['test']['seed'] = self.args.seed + self.best_recorder['val']['best_model_from'] = 'val' + self.best_recorder['test']['best_model_from'] = 'test' + + if not os.path.exists(self.args.record_dir): + os.makedirs(self.args.record_dir) + record_path = os.path.join(self.args.record_dir, self.args.dataset_name+'.csv') + if not os.path.exists(record_path): + record_table = pd.DataFrame() + else: + record_table = pd.read_csv(record_path) + #* Code update since pandas is new version + record_table = pd.concat([record_table, pd.DataFrame([self.best_recorder['val']])], ignore_index=True) + record_table = pd.concat([record_table, pd.DataFrame([self.best_recorder['test']])], ignore_index=True) + record_table.to_csv(record_path, index=False) + + def _prepare_device(self, n_gpu_use): + n_gpu = torch.cuda.device_count() + if n_gpu_use > 0 and n_gpu == 0: + self.logger.warning( + "Warning: There\'s no GPU available on this machine," "training will be performed on CPU.") + n_gpu_use = 0 + if n_gpu_use > n_gpu: + self.logger.warning( + "Warning: The number of GPU\'s configured to use is {}, but only {} are available " "on this machine.".format( + n_gpu_use, n_gpu)) + n_gpu_use = n_gpu + 
device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu') + list_ids = list(range(n_gpu_use)) + return device, list_ids + + def _save_checkpoint(self, epoch, save_best=False): + state = { + 'epoch': epoch, + 'state_dict': self.model.state_dict(), + 'optimizer': self.optimizer.state_dict(), + 'monitor_best': self.mnt_best + } + filename = os.path.join(self.checkpoint_dir, 'current_checkpoint.pth') + torch.save(state, filename) + self.logger.info("Saving checkpoint: {} ...".format(filename)) + if save_best: + best_path = os.path.join(self.checkpoint_dir, 'model_best.pth') + torch.save(state, best_path) + self.logger.info("Saving current best: model_best.pth ...") + + def _resume_checkpoint(self, resume_path): + resume_path = str(resume_path) + self.logger.info("Loading checkpoint: {} ...".format(resume_path)) + checkpoint = torch.load(resume_path) + self.start_epoch = checkpoint['epoch'] + 1 + self.mnt_best = checkpoint['monitor_best'] + self.model.load_state_dict(checkpoint['state_dict']) + self.optimizer.load_state_dict(checkpoint['optimizer']) + + self.logger.info("Checkpoint loaded. 
Resume training from epoch {}".format(self.start_epoch)) + + def _record_best(self, log): + improved_val = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.best_recorder['val'][ + self.mnt_metric]) or \ + (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.best_recorder['val'][self.mnt_metric]) + if improved_val: + self.best_recorder['val'].update(log) + + improved_test = (self.mnt_mode == 'min' and log[self.mnt_metric_test] <= self.best_recorder['test'][ + self.mnt_metric_test]) or \ + (self.mnt_mode == 'max' and log[self.mnt_metric_test] >= self.best_recorder['test'][ + self.mnt_metric_test]) + if improved_test: + self.best_recorder['test'].update(log) + + def _print_best(self): + self.logger.info('Best results (w.r.t {}) in validation set:'.format(self.args.monitor_metric)) + for key, value in self.best_recorder['val'].items(): + self.logger.info('\t{:15s}: {}'.format(str(key), value)) + + self.logger.info('Best results (w.r.t {}) in test set:'.format(self.args.monitor_metric)) + for key, value in self.best_recorder['test'].items(): + self.logger.info('\t{:15s}: {}'.format(str(key), value)) + + +class Trainer(BaseTrainer): + def __init__(self, model, criterion, metric_ftns, optimizer, args, lr_scheduler, train_dataloader, val_dataloader, + test_dataloader): + super(Trainer, self).__init__(model, criterion, metric_ftns, optimizer, args) + self.lr_scheduler = lr_scheduler + self.train_dataloader = train_dataloader + self.val_dataloader = val_dataloader + self.test_dataloader = test_dataloader + self.args = args + + def _train_epoch(self, epoch, model_name): + + self.logger.info('[{}/{}] Start to train in the training set.'.format(epoch, self.epochs)) + + train_loss = 0 + self.model.train() + for batch_idx, (images_id, images, reports_ids, reports_masks) in enumerate(self.train_dataloader): + images, reports_ids, reports_masks = images.to(self.device), reports_ids.to(self.device), reports_masks.to( + self.device) + output = self.model(images, 
reports_ids, mode = 'train') + loss = self.criterion(output, reports_ids, reports_masks) + train_loss += loss.item() + self.optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_value_(self.model.parameters(), 0.1) + self.optimizer.step() + + if batch_idx % self.args.log_period == 0: + self.logger.info('[{}/{}] Step: {}/{}, Training Loss: {:.5f}.' + .format(epoch, self.epochs, batch_idx, len(self.train_dataloader), + train_loss / (batch_idx + 1))) + + log = {'train_loss': train_loss / len(self.train_dataloader)} + + self.logger.info('[{}/{}] Start to evaluate in the validation set.'.format(epoch, self.epochs)) + self.model.eval() + with torch.no_grad(): + val_gts, val_res = [], [] + for batch_idx, (images_id, images, reports_ids, reports_masks) in enumerate(self.val_dataloader): + images, reports_ids, reports_masks = images.to(self.device), reports_ids.to( + self.device), reports_masks.to(self.device) + output = self.model(images, mode='sample') + reports = self.model.tokenizer.decode_batch(output.cpu().numpy()) + ground_truths = self.model.tokenizer.decode_batch(reports_ids[:, 1:].cpu().numpy()) + val_res.extend(reports) + val_gts.extend(ground_truths) + val_met = self.metric_ftns({i: [gt] for i, gt in enumerate(val_gts)}, + {i: [re] for i, re in enumerate(val_res)}) + log.update(**{'val_' + k: v for k, v in val_met.items()}) + + self.logger.info('[{}/{}] Start to evaluate in the test set.'.format(epoch, self.epochs)) + self.model.eval() + with torch.no_grad(): + test_gts, test_res = [], [] + for batch_idx, (images_id, images, reports_ids, reports_masks) in enumerate(self.test_dataloader): + images, reports_ids, reports_masks = images.to(self.device), reports_ids.to( + self.device), reports_masks.to(self.device) + output = self.model(images, mode='sample') + reports = self.model.tokenizer.decode_batch(output.cpu().numpy()) + ground_truths = self.model.tokenizer.decode_batch(reports_ids[:, 1:].cpu().numpy()) + test_res.extend(reports) + 
def penalty_builder(penalty_config):
    """Build a length-penalty function f(length, logprobs) from a config
    string such as 'wu_0.9' or 'avg_1.0'; '' means no penalty."""
    if penalty_config == '':
        return lambda x, y: y
    pen_type, alpha = penalty_config.split('_')
    alpha = float(alpha)
    if pen_type == 'wu':
        return lambda x, y: length_wu(x, y, alpha)
    if pen_type == 'avg':
        return lambda x, y: length_average(x, y, alpha)
    # Previously fell through and returned None, crashing later at call time.
    raise ValueError('unknown length-penalty type: {}'.format(pen_type))


def length_wu(length, logprobs, alpha=0.):
    """NMT length re-ranking score from "Google's Neural Machine Translation
    System" (Wu et al., 2016)."""
    modifier = (((5 + length) ** alpha) /
                ((5 + 1) ** alpha))
    return logprobs / modifier


def length_average(length, logprobs, alpha=0.):
    """Average log-probability per token; `alpha` is accepted for interface
    uniformity with length_wu but unused."""
    return logprobs / length


def split_tensors(n, x):
    """Inverse of repeat_tensors: split a leading dim of size B*n into a
    tuple of n tensors of size B. Recurses into lists/tuples; None becomes
    n Nones; any other type is returned unchanged."""
    if torch.is_tensor(x):
        assert x.shape[0] % n == 0
        x = x.reshape(x.shape[0] // n, n, *x.shape[1:]).unbind(1)
    elif type(x) is list or type(x) is tuple:
        x = [split_tensors(n, _) for _ in x]
    elif x is None:
        x = [None] * n
    return x


def repeat_tensors(n, x):
    """For a tensor of size Bx..., repeat each row n times to get Bnx...;
    for collections, repeat elementwise (nested)."""
    if torch.is_tensor(x):
        x = x.unsqueeze(1)  # Bx1x...
        x = x.expand(-1, n, *([-1] * len(x.shape[2:])))  # Bxnx...
        x = x.reshape(x.shape[0] * n, *x.shape[2:])  # Bnx...
    elif type(x) is list or type(x) is tuple:
        x = [repeat_tensors(n, _) for _ in x]
    return x


def generate_heatmap(image, weights):
    """Overlay attention `weights` on a CHW image; returns an HWC blend.

    `weights` is a flat vector assumed to be a square attention map
    (side = sqrt(len)) — TODO confirm with callers. Requires cv2.
    """
    image = image.transpose(1, 2, 0)
    height, width, _ = image.shape
    side = int(weights.shape[0] ** 0.5)
    weights = weights.reshape(side, side)
    # Min-max normalize to [0, 255] before colormapping.
    weights = weights - np.min(weights)
    weights = weights / np.max(weights)
    weights = cv2.resize(weights, (width, height))
    weights = np.uint8(255 * weights)
    heatmap = cv2.applyColorMap(weights, cv2.COLORMAP_JET)
    result = heatmap * 0.5 + image * 0.5
    return result


class VisualExtractor(nn.Module):
    """Pretrained torchvision CNN backbone (default resnet101) truncated
    before avgpool/fc; returns (patch_feats, avg_feats)."""

    def __init__(self, args):
        super(VisualExtractor, self).__init__()
        self.visual_extractor = args.visual_extractor
        self.pretrained = args.visual_extractor_pretrained
        model = getattr(models, self.visual_extractor)(pretrained=self.pretrained)
        modules = list(model.children())[:-2]  # drop avgpool + classifier head
        self.model = nn.Sequential(*modules)
        self.avg_fnt = torch.nn.AvgPool2d(kernel_size=7, stride=1, padding=0)

    def forward(self, images):
        # patch_feats: (B, C, H', W') -> returned as (B, H'*W', C);
        # avg_feats: (B, C). The reshape keeps batch dim 1 correct even
        # after squeeze().
        patch_feats = self.model(images)
        avg_feats = self.avg_fnt(patch_feats).squeeze().reshape(-1, patch_feats.size(1))
        batch_size, feat_size, _, _ = patch_feats.shape
        patch_feats = patch_feats.reshape(batch_size, feat_size, -1).permute(0, 2, 1)
        return patch_feats, avg_feats
+--- + +Modified the code to work with Python 3.
+ +### Requirements +* Python 3.x +* Java 1.8 +* pycocotools + +--- + +### Tested on +* Windows 10, Python 3.5. + +--- +### To fix Windows JVM memory error:
+Add the following in System Variables
+    Variable name : _JAVA_OPTIONS
+    Variable value : -Xmx1024M
+ +--- +Original code : https://github.com/tylin/coco-caption
diff --git a/pycocoevalcap/__init__.py b/pycocoevalcap/__init__.py new file mode 100644 index 0000000..680063e --- /dev/null +++ b/pycocoevalcap/__init__.py @@ -0,0 +1 @@ +__author__ = 'tylin' \ No newline at end of file diff --git a/pycocoevalcap/__pycache__/__init__.cpython-310.pyc b/pycocoevalcap/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..efd59a5 Binary files /dev/null and b/pycocoevalcap/__pycache__/__init__.cpython-310.pyc differ diff --git a/pycocoevalcap/__pycache__/__init__.cpython-38.pyc b/pycocoevalcap/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000..c62562c Binary files /dev/null and b/pycocoevalcap/__pycache__/__init__.cpython-38.pyc differ diff --git a/pycocoevalcap/bleu/LICENSE b/pycocoevalcap/bleu/LICENSE new file mode 100644 index 0000000..9ccf677 --- /dev/null +++ b/pycocoevalcap/bleu/LICENSE @@ -0,0 +1,19 @@ +Copyright (c) 2015 Xinlei Chen, Hao Fang, Tsung-Yi Lin, and Ramakrishna Vedantam + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/pycocoevalcap/bleu/__init__.py b/pycocoevalcap/bleu/__init__.py new file mode 100644 index 0000000..680063e --- /dev/null +++ b/pycocoevalcap/bleu/__init__.py @@ -0,0 +1 @@ +__author__ = 'tylin' \ No newline at end of file diff --git a/pycocoevalcap/bleu/__pycache__/__init__.cpython-310.pyc b/pycocoevalcap/bleu/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..011295b Binary files /dev/null and b/pycocoevalcap/bleu/__pycache__/__init__.cpython-310.pyc differ diff --git a/pycocoevalcap/bleu/__pycache__/__init__.cpython-38.pyc b/pycocoevalcap/bleu/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000..103100d Binary files /dev/null and b/pycocoevalcap/bleu/__pycache__/__init__.cpython-38.pyc differ diff --git a/pycocoevalcap/bleu/__pycache__/bleu.cpython-310.pyc b/pycocoevalcap/bleu/__pycache__/bleu.cpython-310.pyc new file mode 100644 index 0000000..12ea7b9 Binary files /dev/null and b/pycocoevalcap/bleu/__pycache__/bleu.cpython-310.pyc differ diff --git a/pycocoevalcap/bleu/__pycache__/bleu.cpython-38.pyc b/pycocoevalcap/bleu/__pycache__/bleu.cpython-38.pyc new file mode 100644 index 0000000..54dc0f6 Binary files /dev/null and b/pycocoevalcap/bleu/__pycache__/bleu.cpython-38.pyc differ diff --git a/pycocoevalcap/bleu/__pycache__/bleu_scorer.cpython-310.pyc b/pycocoevalcap/bleu/__pycache__/bleu_scorer.cpython-310.pyc new file mode 100644 index 0000000..7738428 Binary files /dev/null and b/pycocoevalcap/bleu/__pycache__/bleu_scorer.cpython-310.pyc differ diff --git a/pycocoevalcap/bleu/__pycache__/bleu_scorer.cpython-38.pyc b/pycocoevalcap/bleu/__pycache__/bleu_scorer.cpython-38.pyc new file mode 100644 index 0000000..ecc929f Binary files /dev/null and 
class Bleu:
    """Corpus-level BLEU wrapper around BleuScorer (BLEU-1..n)."""

    def __init__(self, n=4):
        # Default: compute BLEU up to 4-grams.
        self._n = n
        self._hypo_for_image = {}
        self.ref_for_image = {}

    def compute_score(self, gts, res, score_option='closest', verbose=1):
        """Score predictions against references.

        gts/res: dicts image_id -> list of sentences (each res value must
        hold exactly one hypothesis).
        score_option: one of {shortest, closest, average} for the effective
        reference length.
        Returns (bleu_scores, per_sentence_scores).
        """
        assert gts.keys() == res.keys()

        scorer = BleuScorer(n=self._n)
        for img_id in gts.keys():
            hypo = res[img_id]
            ref = gts[img_id]

            # Sanity check.
            assert type(hypo) is list
            assert len(hypo) == 1
            assert type(ref) is list
            #assert(len(ref) >= 1)

            scorer += (hypo[0], ref)

        score, scores = scorer.compute_score(option=score_option, verbose=verbose)
        return score, scores

    def method(self):
        return "Bleu"
def precook(s, n=4, out=False):
    """Count 1..n-grams of `s`; returns (num_words, {ngram_tuple: count}).

    `out` is accepted for call-site compatibility but unused.
    """
    words = s.split()
    counts = defaultdict(int)
    for k in range(1, n + 1):
        for i in range(len(words) - k + 1):
            ngram = tuple(words[i:i + k])
            counts[ngram] += 1
    return (len(words), counts)


def cook_refs(refs, eff=None, n=4):
    """Precook a list of reference sentences for one segment.

    Returns (reflen, maxcounts) where maxcounts clips each n-gram to its
    maximum count over the references. For eff in {shortest, average} the
    effective reference length is reduced here; for "closest" (or None) the
    full length list is kept and resolved per hypothesis later.
    """
    reflen = []
    maxcounts = {}
    for ref in refs:
        rl, counts = precook(ref, n)
        reflen.append(rl)
        for (ngram, count) in counts.items():
            maxcounts[ngram] = max(maxcounts.get(ngram, 0), count)

    if eff == "shortest":
        reflen = min(reflen)
    elif eff == "average":
        reflen = float(sum(reflen)) / len(reflen)

    return (reflen, maxcounts)


def cook_test(test, refs, eff=None, n=4):
    """Precook a hypothesis against cooked references.

    Returns a dict with testlen, reflen, per-order guess counts and clipped
    correct counts.
    """
    reflen, refmaxcounts = refs
    testlen, counts = precook(test, n, True)

    result = {}
    if eff == "closest":
        result["reflen"] = min((abs(l - testlen), l) for l in reflen)[1]
    else:  # "average", "shortest", or None (resolved later)
        result["reflen"] = reflen

    result["testlen"] = testlen
    result["guess"] = [max(0, testlen - k + 1) for k in range(1, n + 1)]
    result['correct'] = [0] * n
    for (ngram, count) in counts.items():
        result["correct"][len(ngram) - 1] += min(refmaxcounts.get(ngram, 0), count)
    return result


class BleuScorer(object):
    """Incremental BLEU scorer: feed (hypothesis, references) pairs via +=,
    then call compute_score()."""

    # special_reflen is used in oracle mode (proportional effective ref len).
    __slots__ = "n", "crefs", "ctest", "_score", "_ratio", "_testlen", "_reflen", "special_reflen"

    def copy(self):
        """Copy the accumulated refs/tests (score cache reset)."""
        new = BleuScorer(n=self.n)
        new.ctest = copy.copy(self.ctest)
        new.crefs = copy.copy(self.crefs)
        new._score = None
        return new

    def __init__(self, test=None, refs=None, n=4, special_reflen=None):
        self.n = n
        self.crefs = []
        self.ctest = []
        self.cook_append(test, refs)
        self.special_reflen = special_reflen

    def cook_append(self, test, refs):
        """Called by the constructor and __iadd__ to append one pair."""
        if refs is not None:
            self.crefs.append(cook_refs(refs))
            if test is not None:
                cooked_test = cook_test(test, self.crefs[-1])
                self.ctest.append(cooked_test)
            else:
                self.ctest.append(None)  # crefs and ctest lengths must match
        self._score = None  # need to recompute

    def ratio(self, option=None):
        self.compute_score(option=option)
        return self._ratio

    def fscore(self, option=None):
        """Overall BLEU score list (BLEU-1..n).

        Fix: score_ratio() called self.fscore(...) but no such method was
        defined, raising AttributeError; defined here backward-compatibly.
        """
        self.compute_score(option=option)
        return self._score

    def score_ratio(self, option=None):
        """Return a (bleu_scores, len_ratio) pair."""
        return (self.fscore(option=option), self.ratio(option=option))

    def score_ratio_str(self, option=None):
        return "%.4f (%.2f)" % self.score_ratio(option)

    def reflen(self, option=None):
        self.compute_score(option=option)
        return self._reflen

    def testlen(self, option=None):
        self.compute_score(option=option)
        return self._testlen

    def retest(self, new_test):
        """Replace the hypotheses, keeping the cooked references."""
        if type(new_test) is str:
            new_test = [new_test]
        assert len(new_test) == len(self.crefs), new_test
        self.ctest = []
        for t, rs in zip(new_test, self.crefs):
            self.ctest.append(cook_test(t, rs))
        self._score = None
        return self

    def rescore(self, new_test):
        """Replace test(s) with new test(s) and return the new score."""
        return self.retest(new_test).compute_score()

    def size(self):
        assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
        return len(self.crefs)

    def __iadd__(self, other):
        """Add a (test, refs) tuple or merge another BleuScorer."""
        if type(other) is tuple:
            self.cook_append(other[0], other[1])
        else:
            assert self.compatible(other), "incompatible BLEUs."
            self.ctest.extend(other.ctest)
            self.crefs.extend(other.crefs)
            self._score = None
        return self

    def compatible(self, other):
        return isinstance(other, BleuScorer) and self.n == other.n

    def single_reflen(self, option="average"):
        return self._single_reflen(self.crefs[0][0], option)

    def _single_reflen(self, reflens, option=None, testlen=None):
        """Resolve an effective reference length from a list of lengths."""
        if option == "shortest":
            reflen = min(reflens)
        elif option == "average":
            reflen = float(sum(reflens)) / len(reflens)
        elif option == "closest":
            reflen = min((abs(l - testlen), l) for l in reflens)[1]
        else:
            assert False, "unsupported reflen option %s" % option
        return reflen

    def recompute_score(self, option=None, verbose=0):
        self._score = None
        return self.compute_score(option, verbose)

    def compute_score(self, option=None, verbose=0):
        """Compute corpus BLEU-1..n plus per-sentence scores.

        NOTE(review): when cached, this returns only the score list instead
        of the (score, per_sentence) tuple — preserved as-is since callers
        invoke it once per scorer; confirm before relying on repeat calls.
        """
        n = self.n
        small = 1e-9
        tiny = 1e-15  # so that if guess is 0 we still return 0
        bleu_list = [[] for _ in range(n)]

        if self._score is not None:
            return self._score

        if option is None:
            option = "average" if len(self.crefs) == 1 else "closest"

        self._testlen = 0
        self._reflen = 0
        totalcomps = {'testlen': 0, 'reflen': 0, 'guess': [0] * n, 'correct': [0] * n}

        # Per-sentence accumulation.
        for comps in self.ctest:
            testlen = comps['testlen']
            self._testlen += testlen

            if self.special_reflen is None:
                reflen = self._single_reflen(comps['reflen'], option, testlen)
            else:
                reflen = self.special_reflen

            self._reflen += reflen

            for key in ['guess', 'correct']:
                for k in range(n):
                    totalcomps[key][k] += comps[key][k]

            # Per-sentence BLEU (geometric mean of modified precisions).
            bleu = 1.
            for k in range(n):
                bleu *= (float(comps['correct'][k]) + tiny) \
                        / (float(comps['guess'][k]) + small)
                bleu_list[k].append(bleu ** (1. / (k + 1)))
            ratio = (testlen + tiny) / (reflen + small)  # avoid zero division
            if ratio < 1:
                # Brevity penalty.
                for k in range(n):
                    bleu_list[k][-1] *= math.exp(1 - 1 / ratio)

            if verbose > 1:
                print(comps, reflen)

        totalcomps['reflen'] = self._reflen
        totalcomps['testlen'] = self._testlen

        bleus = []
        bleu = 1.
        for k in range(n):
            bleu *= float(totalcomps['correct'][k] + tiny) \
                    / (totalcomps['guess'][k] + small)
            bleus.append(bleu ** (1. / (k + 1)))
        ratio = (self._testlen + tiny) / (self._reflen + small)  # avoid zero division
        # Fix: ratio() read self._ratio but nothing ever assigned it.
        self._ratio = ratio
        if ratio < 1:
            for k in range(n):
                bleus[k] *= math.exp(1 - 1 / ratio)

        if verbose > 0:
            print(totalcomps)
            print("ratio:", ratio)

        self._score = bleus
        return self._score, bleu_list
class Cider:
    """Main class to compute the CIDEr metric (Vedantam et al., 2015,
    http://arxiv.org/abs/1411.5726)."""

    def __init__(self, test=None, refs=None, n=4, sigma=6.0):
        # Sum over 1..n-grams.
        self._n = n
        # Standard deviation of the gaussian length penalty.
        self._sigma = sigma

    def compute_score(self, gts, res):
        """Compute corpus CIDEr.

        Fix: the original docstring's parameter lines were garbled.
        :param gts: dict image_id -> list of reference sentences
        :param res: dict image_id -> single-element list with the hypothesis
        :return: (mean_score, per_image_scores)
        """
        assert gts.keys() == res.keys()

        scorer = CiderScorer(n=self._n, sigma=self._sigma)
        for img_id in gts.keys():
            hypo = res[img_id]
            ref = gts[img_id]

            # Sanity check.
            assert type(hypo) is list
            assert len(hypo) == 1
            assert type(ref) is list
            assert len(ref) > 0

            scorer += (hypo[0], ref)

        (score, scores) = scorer.compute_score()
        return score, scores

    def method(self):
        return "CIDEr"
def precook(s, n=4, out=False):
    """Map sentence `s` to a term-frequency dict of 1..n-grams.

    `out` is accepted for call-site compatibility but unused.
    """
    words = s.split()
    counts = defaultdict(int)
    for k in range(1, n + 1):
        for i in range(len(words) - k + 1):
            ngram = tuple(words[i:i + k])
            counts[ngram] += 1
    return counts


def cook_refs(refs, n=4):
    """Precook every reference sentence of one image (list of count dicts)."""
    return [precook(ref, n) for ref in refs]


def cook_test(test, n=4):
    """Precook a hypothesis sentence (n-gram count dict)."""
    return precook(test, n, True)


class CiderScorer(object):
    """CIDEr scorer: accumulate (test, refs) pairs via +=, then call
    compute_score()."""

    def copy(self):
        """Shallow-copy the accumulated test/ref counts."""
        new = CiderScorer(n=self.n)
        new.ctest = copy.copy(self.ctest)
        new.crefs = copy.copy(self.crefs)
        return new

    def __init__(self, test=None, refs=None, n=4, sigma=6.0):
        self.n = n
        self.sigma = sigma
        self.crefs = []
        self.ctest = []
        self.document_frequency = defaultdict(float)
        self.cook_append(test, refs)
        self.ref_len = None

    def cook_append(self, test, refs):
        """Called by the constructor and __iadd__ to append one pair."""
        if refs is not None:
            self.crefs.append(cook_refs(refs))
            if test is not None:
                self.ctest.append(cook_test(test))
            else:
                self.ctest.append(None)  # crefs and ctest lengths must match

    def size(self):
        assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
        return len(self.crefs)

    def __iadd__(self, other):
        """Add a (test, refs) tuple or merge another CiderScorer."""
        if type(other) is tuple:
            self.cook_append(other[0], other[1])
        else:
            self.ctest.extend(other.ctest)
            self.crefs.extend(other.crefs)
        return self

    def compute_doc_freq(self):
        """Count, for each n-gram, the number of images whose references
        contain it (document frequency for the idf term)."""
        for refs in self.crefs:
            # refs: k reference captions of one image.
            for ngram in set([ngram for ref in refs for (ngram, count) in ref.items()]):
                self.document_frequency[ngram] += 1

    def compute_cider(self):
        def counts2vec(cnts):
            """Map n-gram counts to per-order tf-idf vectors.

            Returns (vec, norm, length): list of dicts per n-gram order,
            their L2 norms, and a length term (counts bigram frequency, as
            in the reference implementation).
            """
            vec = [defaultdict(float) for _ in range(self.n)]
            length = 0
            norm = [0.0 for _ in range(self.n)]
            for (ngram, term_freq) in cnts.items():
                # Give an n-gram count 1 if it is absent from the ref corpus.
                df = np.log(max(1.0, self.document_frequency[ngram]))
                idx = len(ngram) - 1
                # tf * idf (ref_len is log of the corpus size).
                vec[idx][ngram] = float(term_freq) * (self.ref_len - df)
                norm[idx] += pow(vec[idx][ngram], 2)
                if idx == 1:
                    length += term_freq
            norm = [np.sqrt(x) for x in norm]
            return vec, norm, length

        def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref):
            """Clipped cosine similarity per n-gram order, with a gaussian
            length penalty; returns an array of per-order scores."""
            delta = float(length_hyp - length_ref)
            val = np.array([0.0 for _ in range(self.n)])
            for k in range(self.n):
                for (ngram, count) in vec_hyp[k].items():
                    # vrama91: added clipping
                    val[k] += min(vec_hyp[k][ngram], vec_ref[k][ngram]) * vec_ref[k][ngram]
                if (norm_hyp[k] != 0) and (norm_ref[k] != 0):
                    val[k] /= (norm_hyp[k] * norm_ref[k])
                assert (not math.isnan(val[k]))
                # vrama91: length based gaussian penalty
                val[k] *= np.e ** (-(delta ** 2) / (2 * self.sigma ** 2))
            return val

        # Log of the corpus size (idf reference length).
        self.ref_len = np.log(float(len(self.crefs)))

        scores = []
        for test, refs in zip(self.ctest, self.crefs):
            vec, norm, length = counts2vec(test)
            score = np.array([0.0 for _ in range(self.n)])
            for ref in refs:
                vec_ref, norm_ref, length_ref = counts2vec(ref)
                score += sim(vec, vec_ref, norm, norm_ref, length, length_ref)
            # Mean over n-gram orders, averaged over references, scaled by 10.
            score_avg = np.mean(score)
            score_avg /= len(refs)
            score_avg *= 10.0
            scores.append(score_avg)
        return scores

    def compute_score(self, option=None, verbose=0):
        """Compute corpus CIDEr: (mean_score, per_image_scores).

        NOTE(review): the tail of this method was truncated in the source;
        reconstructed per the standard pycocoevalcap implementation —
        confirm against upstream.
        """
        # Compute idf from the reference corpus.
        self.compute_doc_freq()
        # Sanity: no n-gram can occur in more documents than we have tests.
        assert (len(self.ctest) >= max(self.document_frequency.values()))
        score = self.compute_cider()
        return np.mean(np.array(score)), np.array(score)
class COCOEvalCap:
    """Driver for the standard COCO caption metrics.

    Tokenizes ground-truth and result captions with the PTB tokenizer, then
    runs BLEU-1..4, METEOR, ROUGE-L and CIDEr, storing corpus-level scores in
    ``self.eval`` and per-image scores in ``self.imgToEval`` / ``self.evalImgs``.
    """

    def __init__(self, coco, cocoRes):
        self.evalImgs = []
        self.eval = {}
        self.imgToEval = {}
        self.coco = coco
        self.cocoRes = cocoRes
        self.params = {'image_id': cocoRes.getImgIds()}

    def evaluate(self):
        """Tokenize, score every metric, and populate the eval dicts."""
        imgIds = self.params['image_id']
        gts = {}
        res = {}
        for imgId in imgIds:
            gts[imgId] = self.coco.imgToAnns[imgId]
            res[imgId] = self.cocoRes.imgToAnns[imgId]

        # =================================================
        # Tokenize with the PTB tokenizer (lowercase, strip punctuation)
        # =================================================
        print('tokenization...')
        tokenizer = PTBTokenizer()
        gts = tokenizer.tokenize(gts)
        res = tokenizer.tokenize(res)

        # =================================================
        # Set up scorers; BLEU returns a list of 4 scores, hence the
        # list-valued method name
        # =================================================
        print('setting up scorers...')
        scorers = [
            (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
            (Meteor(), "METEOR"),
            (Rouge(), "ROUGE_L"),
            (Cider(), "CIDEr")
        ]

        # =================================================
        # Compute scores
        # (fix: dropped the unused local `eval = {}` that shadowed the builtin)
        # =================================================
        for scorer, method in scorers:
            print('computing %s score...' % (scorer.method()))
            score, scores = scorer.compute_score(gts, res)
            if type(method) == list:
                for sc, scs, m in zip(score, scores, method):
                    self.setEval(sc, m)
                    self.setImgToEvalImgs(scs, imgIds, m)
                    print("%s: %0.3f" % (m, sc))
            else:
                self.setEval(score, method)
                self.setImgToEvalImgs(scores, imgIds, method)
                print("%s: %0.3f" % (method, score))
        self.setEvalImgs()

    def setEval(self, score, method):
        """Record a corpus-level score under the metric name."""
        self.eval[method] = score

    def setImgToEvalImgs(self, scores, imgIds, method):
        """Record per-image scores, creating the per-image dict on first use."""
        for imgId, score in zip(imgIds, scores):
            if imgId not in self.imgToEval:
                self.imgToEval[imgId] = {}
            self.imgToEval[imgId]["image_id"] = imgId
            self.imgToEval[imgId][method] = score

    def setEvalImgs(self):
        """Flatten imgToEval into a list of per-image score dicts.

        (fix: comprehension variable no longer shadows the `eval` builtin)
        """
        self.evalImgs = [img_scores for _, img_scores in self.imgToEval.items()]
+ +The views and conclusions contained in the software and documentation are those +of the authors and should not be interpreted as representing official policies, +either expressed or implied, of the FreeBSD Project. \ No newline at end of file diff --git a/pycocoevalcap/meteor/__init__.py b/pycocoevalcap/meteor/__init__.py new file mode 100644 index 0000000..349338d --- /dev/null +++ b/pycocoevalcap/meteor/__init__.py @@ -0,0 +1 @@ +from .meteor import * \ No newline at end of file diff --git a/pycocoevalcap/meteor/__pycache__/__init__.cpython-310.pyc b/pycocoevalcap/meteor/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..c8eba98 Binary files /dev/null and b/pycocoevalcap/meteor/__pycache__/__init__.cpython-310.pyc differ diff --git a/pycocoevalcap/meteor/__pycache__/__init__.cpython-38.pyc b/pycocoevalcap/meteor/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000..5956821 Binary files /dev/null and b/pycocoevalcap/meteor/__pycache__/__init__.cpython-38.pyc differ diff --git a/pycocoevalcap/meteor/__pycache__/meteor.cpython-310.pyc b/pycocoevalcap/meteor/__pycache__/meteor.cpython-310.pyc new file mode 100644 index 0000000..9eadd2b Binary files /dev/null and b/pycocoevalcap/meteor/__pycache__/meteor.cpython-310.pyc differ diff --git a/pycocoevalcap/meteor/__pycache__/meteor.cpython-38.pyc b/pycocoevalcap/meteor/__pycache__/meteor.cpython-38.pyc new file mode 100644 index 0000000..1a18556 Binary files /dev/null and b/pycocoevalcap/meteor/__pycache__/meteor.cpython-38.pyc differ diff --git a/pycocoevalcap/meteor/data/paraphrase-en.gz b/pycocoevalcap/meteor/data/paraphrase-en.gz new file mode 100644 index 0000000..88033c8 Binary files /dev/null and b/pycocoevalcap/meteor/data/paraphrase-en.gz differ diff --git a/pycocoevalcap/meteor/meteor-1.5.jar b/pycocoevalcap/meteor/meteor-1.5.jar new file mode 100644 index 0000000..a833bc0 Binary files /dev/null and b/pycocoevalcap/meteor/meteor-1.5.jar differ diff --git 
# Assumes meteor-1.5.jar is in the same directory as meteor.py. Change as needed.
METEOR_JAR = 'meteor-1.5.jar'


class Meteor:
    """Python wrapper around the METEOR 1.5 jar.

    A single long-lived java subprocess is driven over its stdin/stdout using
    the `-stdio` protocol; a lock serializes pipe access across threads.
    """

    def __init__(self):
        self.meteor_cmd = ['java', '-jar', '-Xmx2G', METEOR_JAR,
                           '-', '-', '-stdio', '-l', 'en', '-norm']
        # Line-buffered text pipes so each protocol line flushes immediately.
        self.meteor_p = subprocess.Popen(self.meteor_cmd,
                                         cwd=os.path.dirname(os.path.abspath(__file__)),
                                         stdin=subprocess.PIPE,
                                         stdout=subprocess.PIPE,
                                         stderr=subprocess.PIPE,
                                         universal_newlines=True,
                                         bufsize=1)
        # Used to guarantee thread safety
        self.lock = threading.Lock()

    def method(self):
        return "METEOR"

    def compute_score(self, gts, res):
        """Score every image, returning (corpus_score, per_image_scores)."""
        assert gts.keys() == res.keys()
        imgIds = gts.keys()
        scores = []

        eval_line = 'EVAL'
        with self.lock:
            for img_id in imgIds:
                assert len(res[img_id]) == 1
                stat = self._stat(res[img_id][0], gts[img_id])
                eval_line += ' ||| {}'.format(stat)

            self.meteor_p.stdin.write('{}\n'.format(eval_line))
            # One per-segment score line per image, then the aggregate line.
            for _ in range(0, len(imgIds)):
                scores.append(float(self.meteor_p.stdout.readline().strip()))
            score = float(self.meteor_p.stdout.readline().strip())

        return score, scores

    def _stat(self, hypothesis_str, reference_list):
        # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words
        hypothesis_str = hypothesis_str.replace('|||', '').replace('  ', ' ')
        score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str))
        self.meteor_p.stdin.write('{}\n'.format(score_line))
        return self.meteor_p.stdout.readline().strip()

    def _score(self, hypothesis_str, reference_list):
        with self.lock:
            # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words
            hypothesis_str = hypothesis_str.replace('|||', '').replace('  ', ' ')
            score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str))
            self.meteor_p.stdin.write('{}\n'.format(score_line))
            stats = self.meteor_p.stdout.readline().strip()
            # EVAL ||| stats
            eval_line = 'EVAL ||| {}'.format(stats)
            self.meteor_p.stdin.write('{}\n'.format(eval_line))
            score = float(self.meteor_p.stdout.readline().strip())
            # bug fix: the jar emits two values (average then all), read twice
            # thanks for Andrej for pointing this out
            score = float(self.meteor_p.stdout.readline().strip())
        return score

    def __del__(self):
        self.lock.acquire()
        self.meteor_p.stdin.close()
        self.meteor_p.kill()
        self.meteor_p.wait()
        self.lock.release()
def my_lcs(string, sub):
    """
    Calculates longest common subsequence for a pair of tokenized strings
    :param string : list of str : tokens from a string split using whitespace
    :param sub : list of str : shorter string, also split using whitespace
    :returns: length (list of int): length of the longest common subsequence between the two strings

    Note: my_lcs only gives length of the longest common subsequence, not the actual LCS
    """
    # Keep the shorter sequence as `sub` so the DP table is (long+1) x (short+1).
    if len(string) < len(sub):
        sub, string = string, sub

    table = [[0] * (len(sub) + 1) for _ in range(len(string) + 1)]

    for j in range(1, len(sub) + 1):
        for i in range(1, len(string) + 1):
            if string[i - 1] == sub[j - 1]:
                table[i][j] = table[i - 1][j - 1] + 1
            else:
                table[i][j] = max(table[i - 1][j], table[i][j - 1])

    return table[len(string)][len(sub)]


class Rouge():
    '''
    Class for computing ROUGE-L score for a set of candidate sentences for the MS COCO test set
    '''

    def __init__(self):
        # vrama91: updated the value below based on discussion with Hovey
        self.beta = 1.2

    def calc_score(self, candidate, refs):
        """
        Compute ROUGE-L score given one candidate and references for an image
        :param candidate: str : candidate sentence to be evaluated
        :param refs: list of str : COCO reference sentences for the particular image to be evaluated
        :returns score: int (ROUGE-L score for the candidate evaluated against references)
        """
        assert len(candidate) == 1
        assert len(refs) > 0

        token_c = candidate[0].split(" ")

        precisions = []
        recalls = []
        for reference in refs:
            token_r = reference.split(" ")
            lcs_len = my_lcs(token_r, token_c)
            precisions.append(lcs_len / float(len(token_c)))
            recalls.append(lcs_len / float(len(token_r)))

        prec_max = max(precisions)
        rec_max = max(recalls)

        # F-beta combination of the best precision and best recall.
        if prec_max != 0 and rec_max != 0:
            return ((1 + self.beta ** 2) * prec_max * rec_max) / \
                float(rec_max + self.beta ** 2 * prec_max)
        return 0.0

    def compute_score(self, gts, res):
        """
        Computes Rouge-L score given a set of reference and candidate sentences for the dataset
        Invoked by evaluate_captions.py
        :param gts: dict : reference sentences keyed by image id
        :param res: dict : candidate sentences keyed by image id
        :returns: average_score: float (mean ROUGE-L score computed by averaging scores for all the images)
        """
        assert gts.keys() == res.keys()

        per_image = []
        for img_id in gts.keys():
            hypo = res[img_id]
            ref = gts[img_id]

            per_image.append(self.calc_score(hypo, ref))

            # Sanity check.
            assert type(hypo) is list
            assert len(hypo) == 1
            assert type(ref) is list
            assert len(ref) > 0

        scores = np.array(per_image)
        return np.mean(scores), scores

    def method(self):
        return "Rouge"
# path to the stanford corenlp jar
STANFORD_CORENLP_3_4_1_JAR = 'stanford-corenlp-3.4.1.jar'

# punctuations to be removed from the sentences
PUNCTUATIONS = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-",
                ".", "?", "!", ",", ":", "-", "--", "...", ";"]


class PTBTokenizer:
    """Python wrapper of Stanford PTBTokenizer.

    Writes all captions to a temp file, shells out to the CoreNLP jar to
    tokenize and lowercase them, strips PTB punctuation tokens, and maps the
    tokenized lines back to their image ids.
    """

    def tokenize(self, captions_for_image):
        cmd = ['java', '-cp', STANFORD_CORENLP_3_4_1_JAR,
               'edu.stanford.nlp.process.PTBTokenizer',
               '-preserveLines', '-lowerCase']

        # ======================================================
        # prepare data for PTB Tokenizer: one caption per line, ids repeated
        # in the same order so output lines can be zipped back to images
        # ======================================================
        tokenized = {}
        image_id = [k for k, v in captions_for_image.items() for _ in range(len(v))]
        sentences = '\n'.join(c['caption'].replace('\n', ' ')
                              for k, v in captions_for_image.items() for c in v)

        # ======================================================
        # save sentences to temporary file (in the jar's directory so the
        # relative classpath resolves)
        # ======================================================
        path_to_jar_dirname = os.path.dirname(os.path.abspath(__file__))
        tmp_file = tempfile.NamedTemporaryFile(delete=False, dir=path_to_jar_dirname)
        tmp_file.write(sentences.encode('utf-8'))
        tmp_file.close()

        # ======================================================
        # tokenize sentence
        # ======================================================
        cmd.append(os.path.basename(tmp_file.name))
        p_tokenizer = subprocess.Popen(cmd,
                                       cwd=path_to_jar_dirname,
                                       stdout=subprocess.PIPE,
                                       universal_newlines=True,
                                       bufsize=1)
        token_lines = p_tokenizer.communicate(input=sentences.rstrip())[0]
        lines = token_lines.split('\n')
        # remove temp file
        os.remove(tmp_file.name)

        # ======================================================
        # create dictionary for tokenized captions, dropping punctuation tokens
        # ======================================================
        for k, line in zip(image_id, lines):
            if k not in tokenized:
                tokenized[k] = []
            caption = ' '.join(w for w in line.rstrip().split(' ')
                               if w not in PUNCTUATIONS)
            tokenized[k].append(caption)

        return tokenized
+0.0565972308607382,7,1.7531182136246148,0.234481201475219,0.1247045621842245,0.0798780169504361,0.0884258241027818,0.1852337988558269,0.2337987391444647,0.1199131800264235,0.0735864641578995,0.0496197646508178,0.0866359047435237,0.181475992799469,Mon Dec 11 09:29:19 2023,456789,val +0.0542733229083079,5,2.083609720374735,0.2747025419145447,0.1330704013893907,0.0799617243227591,0.1025656516993313,0.1883960663529787,0.2716523196259372,0.1319292803828068,0.0775681835507563,0.0512411693625575,0.1013688048114297,0.1870354825231361,Mon Dec 11 09:29:19 2023,456789,test +0.197166156301937,27,1.1079945886109184,0.4244391307699703,0.3090431639451085,0.2423957546204503,0.1894009203856905,0.3521960407473721,0.4009060380314914,0.2841284173800583,0.2190111372838425,0.1757862328608941,0.1764497575104328,0.3325274321503063,Tue Dec 12 02:29:17 2023,456789,val +0.1956206203837293,29,1.1047769019196128,0.4223476916255783,0.3068395870920382,0.240370792105686,0.188037515941433,0.3495742349462193,0.4022363551225975,0.2857704741917389,0.2207094570009163,0.1774982291927353,0.1768635030762715,0.3329119984088415,Tue Dec 12 02:29:17 2023,456789,test +0.197166156301937,27,1.1079945886109184,0.4244391307699703,0.3090431639451085,0.2423957546204503,0.1894009203856905,0.3521960407473721,0.4009060380314914,0.2841284173800583,0.2190111372838425,0.1757862328608941,0.1764497575104328,0.3325274321503063,Thu Dec 14 08:40:50 2023,456789,val +0.1956206203837293,29,1.1047769019196128,0.4223476916255783,0.3068395870920382,0.240370792105686,0.188037515941433,0.3495742349462193,0.4022363551225975,0.2857704741917389,0.2207094570009163,0.1774982291927353,0.1768635030762715,0.3329119984088415,Thu Dec 14 08:40:50 2023,456789,test +0.1967466234586174,18,1.295177956860481,0.4145136304254905,0.3040821823525579,0.2401606215777973,0.1848409133338149,0.3441992368421986,0.3954849584217644,0.2799873648163856,0.2155637175739691,0.1731946524426595,0.1732096183849676,0.3288924264799064,Thu Dec 14 08:44:11 2023,456789,val 
+0.1950044311610086,27,1.2332070106265025,0.4155292539255785,0.3035581604882437,0.2388234074072361,0.1845993109721334,0.343366983685113,0.400999707347253,0.2850393299412866,0.2202829241513563,0.1775110891707733,0.1749057616353648,0.3330149473643515,Thu Dec 14 08:44:11 2023,456789,test +0.060040199869837,16,1.5168610070874462,0.2506848723391778,0.1326437936760262,0.0848510073589915,0.096700257599535,0.1951658387127617,0.2404661725771694,0.1210259750038084,0.073255807196248,0.04899427349722,0.0909664810849335,0.1854744453757367,Thu Dec 14 09:03:42 2023,456789,val +0.056886542219987,13,1.6080467405647854,0.2433788827806373,0.1295074926676683,0.0815129348399083,0.0956655985940219,0.1884060460316854,0.2396699212123639,0.1241656921029956,0.0754315095535209,0.0512134924882707,0.0927961232350113,0.1847334676303497,Thu Dec 14 09:03:42 2023,456789,test +0.1476491234722433,30,1.2803659960803162,0.3681614494825755,0.2496676191708725,0.1868677121211032,0.1583365719142697,0.2899778046148705,0.3491020761409826,0.2320230933372004,0.1714702317191226,0.1345703975246285,0.14931546169131,0.2807924346447732,Fri Dec 15 02:36:07 2023,456789,val +0.1453315538687899,27,1.2836386874028674,0.3663475801900839,0.247157720721968,0.1843732040235789,0.1574388604933622,0.2890172892669032,0.3533814806162692,0.2346746685805182,0.1734270270582129,0.1360372033000209,0.1502464354880587,0.2804982408731449,Fri Dec 15 02:36:07 2023,456789,test +0.065148954086487,25,1.3035310646058853,0.3320354385187206,0.1688318623437607,0.098021656230627,0.1287958160055633,0.195517931401169,0.3240679094540593,0.1638329336775331,0.0950280581611942,0.0630777463290065,0.1223377970925459,0.1902164505904283,Fri Dec 15 02:37:50 2023,456789,val +0.065148954086487,25,1.3035310646058853,0.3320354385187206,0.1688318623437607,0.098021656230627,0.1287958160055633,0.195517931401169,0.3240679094540593,0.1638329336775331,0.0950280581611942,0.0630777463290065,0.1223377970925459,0.1902164505904283,Fri Dec 15 02:37:50 2023,456789,test 
+0.0994188428417495,24,1.5952179212830429,0.2693634187713682,0.1747817061164394,0.1274691093669434,0.1195304072729161,0.2354976886077685,0.2496445435232596,0.1565046348793855,0.11155382450581,0.0855086052673248,0.1106527019952549,0.2260206870019921,Fri Dec 15 10:05:58 2023,456789,val +0.0994188428417495,24,1.5952179212830429,0.2693634187713682,0.1747817061164394,0.1274691093669434,0.1195304072729161,0.2354976886077685,0.2496445435232596,0.1565046348793855,0.11155382450581,0.0855086052673248,0.1106527019952549,0.2260206870019921,Fri Dec 15 10:05:58 2023,456789,test +0.0137906922621098,1,5.318818585092132,0.1290351105331592,0.0566005556029909,0.0297090494768309,0.045456274948152,0.1202042537223721,0.1307204161248367,0.0572498984131191,0.0302019497624291,0.0140803268101268,0.0463318030545665,0.1205845489451919,Fri Dec 15 10:36:37 2023,456789,val +0.0137906922621098,1,5.318818585092132,0.1290351105331592,0.0566005556029909,0.0297090494768309,0.045456274948152,0.1202042537223721,0.1307204161248367,0.0572498984131191,0.0302019497624291,0.0140803268101268,0.0463318030545665,0.1205845489451919,Fri Dec 15 10:36:37 2023,456789,test +0.0333749451687922,29,3.318666747317048,0.2269830949284755,0.0989654665588618,0.054491416344425,0.0841764305351416,0.1700945180674377,0.2230949284785406,0.0961379652693095,0.0522183589022928,0.031548937308867,0.0830327424019916,0.1661774533117477,Sat Dec 16 20:21:46 2023,456789,val +0.0333749451687922,29,3.318666747317048,0.2269830949284755,0.0989654665588618,0.054491416344425,0.0841764305351416,0.1700945180674377,0.2230949284785406,0.0961379652693095,0.0522183589022928,0.031548937308867,0.0830327424019916,0.1661774533117477,Sat Dec 16 20:21:46 2023,456789,test +0.1869331458411932,26,1.1759753394980947,0.4133859242036598,0.2968345731726112,0.2308426678398804,0.1828817315556557,0.3373444038667384,0.3970819103032409,0.2794755842071384,0.213651319747385,0.1704495617909113,0.174470411242392,0.3272929180097892,Sun Dec 17 17:10:15 2023,456789,val 
+0.1869331458411932,26,1.1759753394980947,0.4133859242036598,0.2968345731726112,0.2308426678398804,0.1828817315556557,0.3373444038667384,0.3970819103032409,0.2794755842071384,0.213651319747385,0.1704495617909113,0.174470411242392,0.3272929180097892,Sun Dec 17 17:10:15 2023,456789,test +0.0534177614838337,9,1.6598566367214616,0.2468943924366418,0.1337487466309148,0.0811284039995229,0.0925864700681617,0.1980719257358018,0.2430533159261127,0.1269420004154547,0.0749808848481584,0.0490232973391119,0.0894671713512215,0.1928147820830542,Sun Dec 17 17:13:00 2023,456789,val +0.0534177614838337,9,1.6598566367214616,0.2468943924366418,0.1337487466309148,0.0811284039995229,0.0925864700681617,0.1980719257358018,0.2430533159261127,0.1269420004154547,0.0749808848481584,0.0490232973391119,0.0894671713512215,0.1928147820830542,Sun Dec 17 17:13:00 2023,456789,test +0.038797408510287,22,2.601390434998258,0.1924961160431919,0.0958521161124184,0.0579195195606253,0.0727162152489481,0.1655309365977341,0.1869013787286473,0.0926413936491233,0.0554389388066611,0.0370952031903531,0.0697611075760868,0.1635801812547266,Mon Dec 18 09:23:27 2023,456789,val +0.0386400127583674,27,2.592122323617281,0.1903568388982425,0.0946688959433261,0.0572715668903259,0.0718232461637202,0.1641423364759565,0.188450080983613,0.0941106803852767,0.057261522395924,0.0387299183929068,0.0703872040884986,0.1649139254080215,Mon Dec 18 09:23:27 2023,456789,test +0.0487428716499928,14,2.672423162864538,0.2349252128505975,0.1164349068787612,0.0701079870604823,0.0975767810989772,0.1707991445080603,0.2310556831500515,0.1125942783095425,0.0666929326475395,0.0456039285270212,0.0945433931466266,0.1644909432144186,Mon Dec 18 09:49:59 2023,456789,val +0.0487428716499928,14,2.672423162864538,0.2349252128505975,0.1164349068787612,0.0701079870604823,0.0975767810989772,0.1707991445080603,0.2310556831500515,0.1125942783095425,0.0666929326475395,0.0456039285270212,0.0945433931466266,0.1644909432144186,Mon Dec 18 09:49:59 
2023,456789,test +0.147245825194109,29,1.2886341865818165,0.3664413423525019,0.2476671744621921,0.1854810937795084,0.156918422656776,0.2904213209834204,0.3625362821563567,0.2449606677125462,0.182967074735263,0.1447226985347869,0.1541386542122942,0.2930470037508452,Mon Dec 18 13:24:23 2023,456789,val +0.147245825194109,29,1.2886341865818165,0.3664413423525019,0.2476671744621921,0.1854810937795084,0.156918422656776,0.2904213209834204,0.3625362821563567,0.2449606677125462,0.182967074735263,0.1447226985347869,0.1541386542122942,0.2930470037508452,Mon Dec 18 13:24:23 2023,456789,test +0.0615186500875259,23,1.294445733286644,0.3156192629876821,0.1581493183267628,0.091344673809964,0.1171025577968654,0.1741685243751697,0.3062758429472092,0.1489671565303766,0.0838780252543883,0.0551790561651596,0.1144456828733844,0.1724201618452019,Mon Dec 18 13:28:16 2023,456789,val +0.060738138305101,25,1.2877339267986978,0.3142246708196781,0.1566952161851935,0.0903017774175832,0.1178449693035493,0.1742484418125262,0.3068338084361721,0.1497733026555823,0.084575071558877,0.0558243076224226,0.116799368386232,0.1745058088783778,Mon Dec 18 13:28:16 2023,456789,test +0.1241615870904293,19,2.523397505045984,0.3303487407729403,0.2193899373212639,0.1610206903785966,0.1351926080399716,0.2775466673102026,0.31985243617074,0.2058945233085987,0.1474432373099739,0.1111979143105764,0.1291720509646541,0.2685014684684035,Mon Dec 18 14:53:02 2023,456789,val +0.1239761698899667,28,2.4943487692541977,0.3315191448389393,0.2201183884256235,0.1612568632594893,0.1356218514488508,0.2776455232446769,0.3231543173103867,0.2083115273453578,0.1489786195249158,0.1121314037972429,0.1309479525478316,0.2706443935618,Mon Dec 18 14:53:02 2023,456789,test +6.105889543283395e-12,30,5.893813529079695,0.0991417425227555,0.0027952659347896,4.697837058130164e-09,0.0177913697038573,0.1000995439899167,0.0983615084525344,0.0040982976699354,6.062962204567044e-09,7.393313587239485e-12,0.017681687771083,0.0993720228019383,Mon Dec 18 
14:58:46 2023,456789,val +6.105889543283395e-12,30,5.893813529079695,0.0991417425227555,0.0027952659347896,4.697837058130164e-09,0.0177913697038573,0.1000995439899167,0.0983615084525344,0.0040982976699354,6.062962204567044e-09,7.393313587239485e-12,0.017681687771083,0.0993720228019383,Mon Dec 18 14:58:46 2023,456789,test +0.1442735054061815,21,1.2083694930400295,0.3725944155630978,0.2491894746961097,0.184833369372863,0.1585621522597866,0.2895413777917824,0.358644522954339,0.2358989309173154,0.1721304169919719,0.1323345157935913,0.1494998138974566,0.281844932556489,Mon Dec 18 15:05:51 2023,456789,val +0.1406899700186021,25,1.1882719588036057,0.3691554972962443,0.2456340853589753,0.181177807410118,0.1562368695921505,0.287362015286867,0.3652178054104498,0.2419561977944214,0.1772363941545284,0.1366299762600823,0.1531253004049468,0.2875710374051048,Mon Dec 18 15:05:51 2023,456789,test +0.1500167631401781,23,1.1230120820009062,0.3738159103293388,0.2515431831218599,0.1890596024785009,0.1562017811232584,0.2907074693683364,0.3632111765263292,0.2363665054858522,0.1735572516371867,0.1347456420639498,0.1487827203695474,0.2799198954432507,Mon Dec 18 16:05:34 2023,456789,val +0.149395377336946,22,1.1298144175192664,0.3759919379911395,0.2522465190681104,0.1889821193601999,0.1573948400714074,0.2911234872244414,0.3631286495886628,0.2385197962191785,0.1756331981157187,0.1367266976648178,0.1495219293638301,0.2820133945846119,Mon Dec 18 16:05:34 2023,456789,test +0.0666946445699319,10,2.84710609075772,0.2991490745341977,0.1573386410864431,0.0965104517936288,0.1129960234517074,0.1980318952094478,0.2974780261549808,0.1525076803039296,0.091617226121232,0.062164057161642,0.1128167371633308,0.1909899218961606,Mon Dec 18 22:51:20 2023,456789,val 
+0.0666946445699319,10,2.84710609075772,0.2991490745341977,0.1573386410864431,0.0965104517936288,0.1129960234517074,0.1980318952094478,0.2974780261549808,0.1525076803039296,0.091617226121232,0.062164057161642,0.1128167371633308,0.1909899218961606,Mon Dec 18 22:51:20 2023,456789,test +0.0194500967810023,4,3.95976626991218,0.2057321196358897,0.0792058838315584,0.0375849058118577,0.0717824276550382,0.1636637049656784,0.2010299089726907,0.0766626701538134,0.0365390225836461,0.0194042401343388,0.0700679908415549,0.1602310569849931,Mon Dec 18 23:43:41 2023,456789,val +0.0194500967810023,4,3.95976626991218,0.2057321196358897,0.0792058838315584,0.0375849058118577,0.0717824276550382,0.1636637049656784,0.2010299089726907,0.0766626701538134,0.0365390225836461,0.0194042401343388,0.0700679908415549,0.1602310569849931,Mon Dec 18 23:43:41 2023,456789,test +0.054492455959975,12,2.671354571790937,0.2399441725980955,0.1224107128613137,0.0767669733275702,0.0918801707556527,0.1867308500074335,0.2323716642067093,0.1166527450158466,0.070072059658216,0.0470884153618849,0.0878180834022126,0.1808275319224386,Tue Dec 19 08:05:15 2023,456789,val +0.0504423576787822,9,2.768772794929259,0.2245348625197892,0.1163946075499199,0.0721891253790657,0.0864483035464336,0.1843504527219997,0.2244442545229495,0.114822639245958,0.069690569122429,0.0477776560663813,0.0858071610227128,0.1845377732965727,Tue Dec 19 08:05:15 2023,456789,test +0.1845701685576257,22,1.081126516044833,0.4099150246956111,0.2932334738856169,0.2278311465115302,0.1790032340323166,0.3371726576629122,0.3903882004684814,0.2737182713196575,0.2089350542037304,0.1667365538927272,0.168965441170099,0.3241529612083373,Tue Dec 19 10:49:00 2023,456789,val +0.1843071101944085,28,1.0599167319638376,0.4116908601903723,0.2939475160519089,0.2280633611485444,0.1794687838556872,0.3368179089139218,0.3962400193623782,0.2788021420110711,0.2131557911619683,0.170223265315175,0.1725775508944826,0.3277573630025199,Tue Dec 19 10:49:00 2023,456789,test 
+0.1370747267165643,21,1.41579957686665,0.3393638597670423,0.2271635402581954,0.1713527466459861,0.1429837350901612,0.2745247140340499,0.33306012373499,0.2173238658071225,0.1597379878065883,0.124708231154997,0.1383325077409394,0.270609677101576,Tue Dec 19 14:06:33 2023,456789,val +0.1369860181996828,29,1.3869648441606928,0.3418006015297591,0.2281760546123414,0.1716509124432287,0.1447831121082461,0.277542275687928,0.3348873788141977,0.2195891886999078,0.1613694748869686,0.125850953957806,0.1392874367898691,0.2739043327219881,Tue Dec 19 14:06:33 2023,456789,test +0.0803127213476108,7,1.851152973106647,0.2649202463128318,0.1541170233525209,0.1064976957404945,0.0982921911370242,0.2137127181771639,0.2619895505330779,0.1479490071071303,0.0993792006378829,0.0728566674642063,0.0971940861503126,0.2099033119121345,Tue Dec 19 14:08:54 2023,456789,val +0.0803127213476108,7,1.851152973106647,0.2649202463128318,0.1541170233525209,0.1064976957404945,0.0982921911370242,0.2137127181771639,0.2619895505330779,0.1479490071071303,0.0993792006378829,0.0728566674642063,0.0971940861503126,0.2099033119121345,Tue Dec 19 14:08:54 2023,456789,test +0.176333080040024,21,1.1572505920593747,0.4020109696246347,0.2849444564180032,0.2196344092803012,0.1774890704587969,0.329423241903132,0.3878984906694172,0.2718907804899657,0.2073810823950162,0.1652796882904949,0.1692751232342376,0.3201011405482821,Tue Dec 19 15:20:18 2023,456789,val +0.1742971981518699,28,1.129352570481729,0.3967697728512523,0.2808093868520377,0.216679622062981,0.1754032163196337,0.3246644507067778,0.3909350727974597,0.2755436587744013,0.2109247932282425,0.1684501732532893,0.1713020840459195,0.3231417740221218,Tue Dec 19 15:20:18 2023,456789,test +0.0543546611516549,16,2.6406956591409654,0.2581293908222438,0.1327022298833063,0.0804945103431506,0.0995383656066214,0.195929372602577,0.2519848159879962,0.1271420848003283,0.0751559489839935,0.0494442798694877,0.097676221456085,0.1911417065736758,Tue Dec 19 16:21:32 2023,456789,val 
+0.0506584330290082,19,2.6201409005374687,0.2532443856261191,0.1280933581194718,0.0762212485363994,0.0968762459202943,0.1950699027336145,0.2474100252971625,0.1263238118795636,0.0762556855335533,0.0510920353274819,0.0943196555797608,0.1919676452172629,Tue Dec 19 16:21:32 2023,456789,test +0.1031446478235171,30,2.6583033559604106,0.310675035017774,0.1994436500851593,0.1404657652120984,0.1244073290163715,0.2599531699527131,0.3085327289260418,0.1951408716331072,0.1357945440114564,0.0982964764507269,0.1226635830970376,0.259224563614075,Tue Dec 19 20:13:13 2023,456789,val +0.1021870129957478,21,2.6801170847234554,0.3095654984780283,0.1979757871481349,0.1393156253389928,0.1238488149790084,0.2604867723866166,0.3106690728503736,0.1974916762029476,0.13781712079363,0.1002434433690301,0.1236741936044779,0.2621979648533448,Tue Dec 19 20:13:13 2023,456789,test +0.0388811258996502,30,2.6809009456733346,0.1655991379401052,0.0940540319900079,0.0591153241467493,0.0647871793049599,0.1494532802534661,0.1678937891569226,0.0937163337189377,0.058208090747178,0.0379371374918191,0.0652128005855594,0.1529994516093929,Tue Dec 19 20:37:02 2023,456789,val +0.0388811258996502,30,2.6809009456733346,0.1655991379401052,0.0940540319900079,0.0591153241467493,0.0647871793049599,0.1494532802534661,0.1678937891569226,0.0937163337189377,0.058208090747178,0.0379371374918191,0.0652128005855594,0.1529994516093929,Tue Dec 19 20:37:02 2023,456789,test +1.8560997546198534e-12,30,5.968021112858707,0.0823250975292583,0.0006556971311169,1.3112218635290302e-09,0.0167955397821563,0.0875353201678896,0.0836618985695704,0.0009347942069866,1.660930020002577e-09,2.216193505018257e-12,0.0171421804227067,0.0885615854218911,Tue Dec 19 21:24:17 2023,456789,val +1.8560997546198534e-12,30,5.968021112858707,0.0823250975292583,0.0006556971311169,1.3112218635290302e-09,0.0167955397821563,0.0875353201678896,0.0836618985695704,0.0009347942069866,1.660930020002577e-09,2.216193505018257e-12,0.0171421804227067,0.0885615854218911,Tue 
Dec 19 21:24:17 2023,456789,test +0.1255337032612625,27,2.606364480928265,0.3411032782400032,0.2249941360371165,0.1642023381165037,0.1386842081437333,0.2852625698552443,0.3200574952791257,0.2068109292343047,0.1477786154073463,0.1106429809984212,0.1281852022552312,0.267084889272527,Wed Dec 20 07:27:32 2023,456789,val +0.1225245785284492,22,2.6214547171190827,0.3372218290756179,0.2210860935693442,0.1606657257043728,0.1368433445047689,0.2828941562254672,0.3199799615138195,0.2079283486531462,0.1494105608130835,0.1126537522640321,0.1282679879482544,0.2688251437831211,Wed Dec 20 07:27:32 2023,456789,test +6.105889543283395e-12,30,5.893813529079695,0.0991417425227555,0.0027952659347896,4.697837058130164e-09,0.0177913697038573,0.1000995439899167,0.0983615084525344,0.0040982976699354,6.062962204567044e-09,7.393313587239485e-12,0.017681687771083,0.0993720228019383,Thu Dec 21 08:23:27 2023,456789,val +6.105889543283395e-12,30,5.893813529079695,0.0991417425227555,0.0027952659347896,4.697837058130164e-09,0.0177913697038573,0.1000995439899167,0.0983615084525344,0.0040982976699354,6.062962204567044e-09,7.393313587239485e-12,0.017681687771083,0.0993720228019383,Thu Dec 21 08:23:27 2023,456789,test +0.0272586860918983,28,2.576580900122425,0.2241352405721687,0.0997840454704674,0.0517011846658588,0.07583640558284,0.170120033108495,0.2218725617685276,0.0970825196819267,0.0484252461927446,0.0236931031869374,0.0748523586890084,0.1675136561888593,Mon Jan 1 14:42:57 2024,456789,val +0.0257310676610758,21,2.587573056151637,0.2536540962288653,0.1026410546066249,0.0484784956710873,0.0868763024186674,0.1681240838240294,0.2488426527958355,0.0992716641533301,0.046902039788468,0.0245120371630192,0.0858221843342276,0.1653984626162888,Mon Jan 1 14:42:57 2024,456789,test 
+0.0454148572336609,8,3.0237782491838514,0.261755526657994,0.1264290678460636,0.0712033353117583,0.0910804489978357,0.1811150522758468,0.2530559167750292,0.1192192345582948,0.0657628351466826,0.0409982604022008,0.0878069938668167,0.1743744640318531,Mon Jan 1 15:19:33 2024,456789,val +0.0447681623445012,7,3.0973853413669588,0.2629518855656663,0.1255814031058836,0.0703108950295383,0.090845217984122,0.1836120033008002,0.2554486345903737,0.1194027670188386,0.0659029868791564,0.04120161972758,0.0879492108380011,0.1769290388906068,Mon Jan 1 15:19:33 2024,456789,test +0.0307235741018944,16,1.3030683907894576,0.2616254876462904,0.1173900924176187,0.0583239242515302,0.104358868718527,0.1830287101172581,0.2584655396618951,0.1139259623825978,0.054846994628915,0.0278760201395715,0.1035274095768083,0.1813236246504508,Mon Jan 1 19:32:33 2024,456789,val +0.0307235741018944,16,1.3030683907894576,0.2616254876462904,0.1173900924176187,0.0583239242515302,0.104358868718527,0.1830287101172581,0.2584655396618951,0.1139259623825978,0.054846994628915,0.0278760201395715,0.1035274095768083,0.1813236246504508,Mon Jan 1 19:32:33 2024,456789,test +0.045785773014968,25,1.224863422729356,0.253680104031206,0.1177925427279998,0.0680553564174579,0.0964385203616208,0.1803966907051743,0.2484135240572139,0.1137296424136021,0.0642036573831182,0.041915545351601,0.0945013062082387,0.1758362255546988,Mon Jan 1 20:39:28 2024,456789,val +0.045785773014968,25,1.224863422729356,0.253680104031206,0.1177925427279998,0.0680553564174579,0.0964385203616208,0.1803966907051743,0.2484135240572139,0.1137296424136021,0.0642036573831182,0.041915545351601,0.0945013062082387,0.1758362255546988,Mon Jan 1 20:39:28 2024,456789,test +0.0457685120413343,30,1.2046893341818,0.2556566970090994,0.1183215273794969,0.0681524652159038,0.0971051734566472,0.184105811454325,0.2501430429128706,0.1141680405379804,0.0640469054801252,0.0414486222691845,0.0949513148271337,0.1794768272504384,Mon Jan 1 20:57:10 2024,456789,val 
+0.0457685120413343,30,1.2046893341818,0.2556566970090994,0.1183215273794969,0.0681524652159038,0.0971051734566472,0.184105811454325,0.2501430429128706,0.1141680405379804,0.0640469054801252,0.0414486222691845,0.0949513148271337,0.1794768272504384,Mon Jan 1 20:57:10 2024,456789,test +0.0451582596008639,6,3.0867278004944403,0.2769310793237935,0.1244683412841645,0.0691290234318203,0.1135059433121272,0.1928023009304045,0.2717295188556531,0.1204382654920773,0.0651091321846289,0.0409371265587452,0.1115423535611913,0.1909474242680005,Mon Jan 1 21:25:18 2024,456789,val +0.0451582596008639,6,3.0867278004944403,0.2769310793237935,0.1244683412841645,0.0691290234318203,0.1135059433121272,0.1928023009304045,0.2717295188556531,0.1204382654920773,0.0651091321846289,0.0409371265587452,0.1115423535611913,0.1909474242680005,Mon Jan 1 21:25:18 2024,456789,test +0.0437907642875378,7,2.87096480844041,0.2786215864759391,0.1250965871216463,0.0683245176469906,0.1104825793916817,0.1908776674033969,0.2740702210663163,0.1211194707343802,0.0645858234589619,0.0400168564644639,0.1085192677118951,0.1892162007335911,Mon Jan 1 21:49:45 2024,456789,val +0.0437907642875378,7,2.87096480844041,0.2786215864759391,0.1250965871216463,0.0683245176469906,0.1104825793916817,0.1908776674033969,0.2740702210663163,0.1211194707343802,0.0645858234589619,0.0400168564644639,0.1085192677118951,0.1892162007335911,Mon Jan 1 21:49:45 2024,456789,test +0.0460281374614173,10,2.132844601298824,0.2557737321196325,0.1187312234145275,0.0685361093522126,0.0965447734927661,0.184503625510579,0.250455136540959,0.1146847307170909,0.0646462364568886,0.0420348723344981,0.0946919122607818,0.1797463209702345,Tue Jan 2 00:38:45 2024,456789,val +0.0460281374614173,10,2.132844601298824,0.2557737321196325,0.1187312234145275,0.0685361093522126,0.0965447734927661,0.184503625510579,0.250455136540959,0.1146847307170909,0.0646462364568886,0.0420348723344981,0.0946919122607818,0.1797463209702345,Tue Jan 2 00:38:45 2024,456789,test 
+0.0459162370197961,18,1.6160881505971496,0.2564369310793205,0.1186013856962273,0.0683798522320268,0.0977493113244218,0.181269147890911,0.2506631989596846,0.1144306374420402,0.0644393246097349,0.041946083336683,0.0956913888870511,0.1765698828473361,Tue Jan 2 00:41:23 2024,456789,val +0.0458222878870521,23,1.5766171952745889,0.2538361508452502,0.1178429138120607,0.0681277329868823,0.0964179745061633,0.1804750991898051,0.2486085825747692,0.1138890345612116,0.0643468741555868,0.0420217982499197,0.0945253174245212,0.1759147121613667,Tue Jan 2 00:41:23 2024,456789,test +0.0458870225539365,24,1.2388587589333286,0.254122236670998,0.1181073031386425,0.0682560916056603,0.0966068459559823,0.1806739762772854,0.2489986996098797,0.1142220079352694,0.064569266505156,0.0421427361736232,0.0947963005497044,0.1761408862519936,Tue Jan 2 02:04:13 2024,456789,val +0.0458870225539365,24,1.2388587589333286,0.254122236670998,0.1181073031386425,0.0682560916056603,0.0966068459559823,0.1806739762772854,0.2489986996098797,0.1142220079352694,0.064569266505156,0.0421427361736232,0.0947963005497044,0.1761408862519936,Tue Jan 2 02:04:13 2024,456789,test +0.0457696763444301,8,1.7390844722306271,0.2558777633289953,0.1185570869979222,0.0682428891477197,0.0970101411824032,0.1842854948556642,0.2503771131339369,0.1143796954984135,0.0641822769165541,0.0415760968143352,0.0950801483461337,0.1795814760593291,Tue Jan 2 02:19:42 2024,456789,val +0.0457696763444301,8,1.7390844722306271,0.2558777633289953,0.1185570869979222,0.0682428891477197,0.0970101411824032,0.1842854948556642,0.2503771131339369,0.1143796954984135,0.0641822769165541,0.0415760968143352,0.0950801483461337,0.1795814760593291,Tue Jan 2 02:19:42 2024,456789,test +0.0458796791606591,26,1.594034348366242,0.2562808842652762,0.1185510977043592,0.0683072711034656,0.097767786713749,0.1811907394062802,0.2504681404421294,0.1142710000028752,0.0642956611759628,0.0418394863609267,0.0956542834364961,0.1764783151395567,Tue Jan 2 02:34:05 2024,456789,val 
+0.0458796791606591,26,1.594034348366242,0.2562808842652762,0.1185510977043592,0.0683072711034656,0.097767786713749,0.1811907394062802,0.2504681404421294,0.1142710000028752,0.0642956611759628,0.0418394863609267,0.0956542834364961,0.1764783151395567,Tue Jan 2 02:34:05 2024,456789,test +0.0452949714312369,14,1.699042799508366,0.253680104031206,0.1166126685338633,0.0673226009988978,0.0963125166333418,0.1821015757141304,0.2484135240572139,0.1125036831271927,0.0635479394810006,0.0415221680743996,0.0942128811802272,0.1775351456393822,Tue Jan 2 03:24:37 2024,456789,val +0.0452949714312369,14,1.699042799508366,0.253680104031206,0.1166126685338633,0.0673226009988978,0.0963125166333418,0.1821015757141304,0.2484135240572139,0.1125036831271927,0.0635479394810006,0.0415221680743996,0.0942128811802272,0.1775351456393822,Tue Jan 2 03:24:37 2024,456789,test +0.0464698496354176,5,3.0746031335042727,0.2553055916774999,0.1185659630441967,0.0688024595090549,0.0975702876968352,0.1845030860787665,0.2501040312093595,0.1148906125451345,0.0653310912708242,0.0427682199321449,0.0959215824043626,0.1796193968602361,Tue Jan 2 06:42:21 2024,456789,val +0.0464698496354176,5,3.0746031335042727,0.2553055916774999,0.1185659630441967,0.0688024595090549,0.0975702876968352,0.1845030860787665,0.2501040312093595,0.1148906125451345,0.0653310912708242,0.0427682199321449,0.0959215824043626,0.1796193968602361,Tue Jan 2 06:42:21 2024,456789,test +0.0073796178531873,15,1.3147350513656055,0.1964889466840026,0.0469480145400297,0.0152808422375059,0.0811564760030229,0.1335911969529094,0.1927438231469415,0.0462254013593456,0.0157957865111361,0.0079340829901037,0.0815531611963464,0.1338228302212229,Tue Jan 2 06:43:20 2024,456789,val +0.0073731806913014,30,1.199275309418574,0.1960338101430403,0.0468661457582884,0.0152630724010138,0.0812237686336682,0.1333579839596437,0.1924317295188531,0.0463518462447877,0.015937615108751,0.0079874526780009,0.0817469309730068,0.1335873704012188,Tue Jan 2 06:43:20 2024,456789,test 
+0.003625782245084,4,4.809684177213899,0.1283745123537044,0.0348921101668894,0.012257511290956,0.0310409436097302,0.1213293174895824,0.1294928478543546,0.0365407075293186,0.0127141254631526,0.0034259644498899,0.0318343211208828,0.1217734135019891,Tue Jan 2 07:31:08 2024,456789,val +0.003625782245084,4,4.809684177213899,0.1283745123537044,0.0348921101668894,0.012257511290956,0.0310409436097302,0.1213293174895824,0.1294928478543546,0.0365407075293186,0.0127141254631526,0.0034259644498899,0.0318343211208828,0.1217734135019891,Tue Jan 2 07:31:08 2024,456789,test +0.0225510062011262,5,2.2113424255446046,0.2268010403120906,0.0953216768479352,0.047412668474279,0.0931199664488868,0.1514624461826974,0.2252665799739892,0.094702312449906,0.0470463138747092,0.022357999417305,0.0936797657300946,0.1512823129068211,Tue Jan 2 07:58:30 2024,456789,val +0.0225510062011262,5,2.2113424255446046,0.2268010403120906,0.0953216768479352,0.047412668474279,0.0931199664488868,0.1514624461826974,0.2252665799739892,0.094702312449906,0.0470463138747092,0.022357999417305,0.0936797657300946,0.1512823129068211,Tue Jan 2 07:58:30 2024,456789,test +0.0459130247719115,9,2.7782988241826447,0.2539921976592945,0.1180488221883235,0.0682203168426336,0.0965560614324393,0.1805762959844763,0.248751625487643,0.1140507866923812,0.0644632143860018,0.0421388887368817,0.0946353651555276,0.1759801176669234,Tue Jan 2 13:15:19 2024,456789,val +0.0459130247719115,9,2.7782988241826447,0.2539921976592945,0.1180488221883235,0.0682203168426336,0.0965560614324393,0.1805762959844763,0.248751625487643,0.1140507866923812,0.0644632143860018,0.0421388887368817,0.0946353651555276,0.1759801176669234,Tue Jan 2 13:15:19 2024,456789,test +0.0464747875133674,8,2.8427601690042,0.2478933680103998,0.1170409781957937,0.068535251443348,0.0944026829712295,0.1851763460947434,0.2430039011703479,0.113810380954045,0.0654200403470458,0.0430725834538907,0.0927744371964068,0.1804441522801655,Tue Jan 2 14:13:04 2024,456789,val 
+0.0464747875133674,8,2.8427601690042,0.2478933680103998,0.1170409781957937,0.068535251443348,0.0944026829712295,0.1851763460947434,0.2430039011703479,0.113810380954045,0.0654200403470458,0.0430725834538907,0.0927744371964068,0.1804441522801655,Tue Jan 2 14:13:04 2024,456789,test +0.0994188428417495,24,1.5952179212830429,0.2693634187713682,0.1747817061164394,0.1274691093669434,0.1195304072729161,0.2354976886077685,0.2496445435232596,0.1565046348793855,0.11155382450581,0.0855086052673248,0.1106527019952549,0.2260206870019921,Thu Jan 4 11:38:53 2024,456789,val +0.0994188428417495,24,1.5952179212830429,0.2693634187713682,0.1747817061164394,0.1274691093669434,0.1195304072729161,0.2354976886077685,0.2496445435232596,0.1565046348793855,0.11155382450581,0.0855086052673248,0.1106527019952549,0.2260206870019921,Thu Jan 4 11:38:53 2024,456789,test +0.0394560281090867,30,3.1862685663531134,0.2383355006501919,0.1027594501091738,0.0592303218982046,0.0892394097132374,0.1677472343784154,0.2335890767230138,0.099628281766973,0.0562457101832107,0.0361636555429078,0.0871248519690922,0.1640184698846089,Thu Jan 4 21:49:28 2024,456789,val +0.0392554199566281,20,3.213299367484795,0.2444343302990865,0.108150845555549,0.0607002483457466,0.0910971409689517,0.1807038884219759,0.2398179453836119,0.1050311553104008,0.0578559334152436,0.0363030354214102,0.0890799776537196,0.1774314753491287,Thu Jan 4 21:49:28 2024,456789,test +0.041197262742757,21,3.201987003299598,0.2417945383615053,0.1124304663944054,0.0633536445211857,0.091876927006545,0.180656508100245,0.2362548764629358,0.1085786081679759,0.0595364165947774,0.0373922855859654,0.0899118296150398,0.1764098096095462,Thu Jan 4 22:23:00 2024,456789,val +0.041197262742757,21,3.201987003299598,0.2417945383615053,0.1124304663944054,0.0633536445211857,0.091876927006545,0.180656508100245,0.2362548764629358,0.1085786081679759,0.0595364165947774,0.0373922855859654,0.0899118296150398,0.1764098096095462,Thu Jan 4 22:23:00 2024,456789,test 
+0.012659688806496,7,1.8585701727626849,0.2026137841352379,0.0627178184629656,0.0256150444910104,0.0706455594227085,0.1441109686943186,0.197061118335498,0.0592883807529717,0.0245956978216092,0.0121430132114287,0.0691191610602494,0.1422491890923745,Sat Jan 6 07:56:57 2024,456789,val +0.010765412775618,6,2.0053069898870777,0.1891157347204136,0.0602019710976954,0.0229393395161712,0.0785310589368792,0.1487037066658719,0.1918465539661873,0.0631192993050476,0.0270645657634828,0.0141302255378558,0.0809419562284611,0.1496591054203035,Sat Jan 6 07:56:57 2024,456789,test +0.0677776063087146,14,1.376197649456648,0.2764601807113496,0.150557840463341,0.0959804951387144,0.105847658747751,0.2003946460116721,0.2702658248826311,0.141918404740677,0.0868055945395642,0.0592187094823942,0.1019359007969743,0.1946108638748818,Sun Jan 7 09:32:14 2024,456789,val +0.0654118488121307,18,1.2882999341416577,0.2509553662845866,0.1376437280361042,0.0900167599986719,0.1016047490503871,0.1923943432895639,0.2374597114642773,0.127443428829887,0.0824568558213179,0.0593731797825525,0.0954547646795058,0.1852724517805865,Sun Jan 7 09:32:14 2024,456789,test +0.1895679552830471,22,1.1546318171617318,0.4257709275026868,0.3054789709778319,0.2361757275131287,0.1859505357603936,0.3495296055751685,0.4008101050917379,0.2833463154719041,0.2168345823321041,0.1728020333532612,0.1735841971403758,0.3315691065263819,Mon Jan 8 10:22:27 2024,456789,val +0.1895679552830471,22,1.1546318171617318,0.4257709275026868,0.3054789709778319,0.2361757275131287,0.1859505357603936,0.3495296055751685,0.4008101050917379,0.2833463154719041,0.2168345823321041,0.1728020333532612,0.1735841971403758,0.3315691065263819,Mon Jan 8 10:22:27 2024,456789,test +0.1854480330519426,30,1.5837523580416164,0.4099394065646508,0.2951411867132874,0.2296434112759729,0.1797438150398729,0.3393186780512545,0.3988103564692152,0.2801139013700867,0.2138749297756441,0.1699310028911202,0.1727359390182834,0.331209495313559,Tue Jan 9 08:53:09 2024,456789,val 
+0.1833169884492602,23,1.5934872530759234,0.4100057803466762,0.2940795610113794,0.2278132640163521,0.1794204816371101,0.3389191656763772,0.4020767441604236,0.2828888407069284,0.2163787780008037,0.1723391922042541,0.1743595288561575,0.3339618846587842,Tue Jan 9 08:53:09 2024,456789,test +0.0390662774751945,15,3.085446947582444,0.2408062418725586,0.1074924443860586,0.0598605071925622,0.0894621728215985,0.1833677334196221,0.2372951885565638,0.1055310373360786,0.0578486326523472,0.0372391583711055,0.0879208562585449,0.1797857322794967,Tue Jan 9 17:51:28 2024,456789,val +0.0390662774751945,15,3.085446947582444,0.2408062418725586,0.1074924443860586,0.0598605071925622,0.0894621728215985,0.1833677334196221,0.2372951885565638,0.1055310373360786,0.0578486326523472,0.0372391583711055,0.0879208562585449,0.1797857322794967,Tue Jan 9 17:51:28 2024,456789,test +0.0402343679992864,28,3.1931566445372503,0.2359037711313363,0.1069155814526652,0.0610754263267073,0.0881760819310576,0.1764652428781531,0.2300910273081894,0.1029965393031631,0.0572355277306617,0.0363947387093531,0.0857819486940624,0.1725695124885203,Tue Jan 9 17:54:25 2024,456789,val +0.0402343679992864,28,3.1931566445372503,0.2359037711313363,0.1069155814526652,0.0610754263267073,0.0881760819310576,0.1764652428781531,0.2300910273081894,0.1029965393031631,0.0572355277306617,0.0363947387093531,0.0857819486940624,0.1725695124885203,Tue Jan 9 17:54:25 2024,456789,test +0.1722675067672823,15,1.7208031031796018,0.3873105417988918,0.2759421772711643,0.2136578813958943,0.1690235776813292,0.3229407030966867,0.3767674761522255,0.2644542115830829,0.2018532767023788,0.1604677026976572,0.1635955043621824,0.3161911028145679,Wed Jan 10 20:48:37 2024,456789,val +0.1722189861350389,18,1.6922274332430567,0.3867724951936896,0.2758300225783304,0.213719783403078,0.1685867967190334,0.3232264554213595,0.3803540896153764,0.2698231587043696,0.2075360318757931,0.1661250251301058,0.166369508234862,0.3204588241925328,Wed Jan 10 20:48:37 
2024,456789,test +0.1800161732142454,33,1.64794299621571,0.3987173050601714,0.2876560154729609,0.223611788679478,0.1745002845840889,0.3341980159154957,0.3817978921190557,0.2689784412025534,0.2053744457386753,0.1629989679378021,0.1678849120235985,0.3205949663562036,Thu Jan 11 10:24:01 2024,456789,val +0.1769522691717872,25,1.652649973700809,0.3951423154350136,0.284112167493926,0.2202658225183404,0.1728798164406487,0.3323380070088462,0.3835394655333926,0.2708486051876414,0.2072294044616425,0.1648375033532527,0.1684754504062519,0.3221560586016022,Thu Jan 11 10:24:01 2024,456789,test +0.1893825766807425,16,1.6454334900458918,0.4156443702147558,0.3010799187320908,0.2346284711247342,0.1838936327403839,0.3465848919266567,0.4031373292222754,0.2865242277253098,0.2199567513120452,0.1756983945153955,0.1773623725248826,0.33519095249776,Fri Jan 12 04:31:29 2024,456789,val +0.1848292154487148,15,1.6566942404999223,0.4108856372214939,0.2958909352126092,0.2296349510772571,0.1809492489743522,0.3412014353399775,0.4047412354344898,0.289521205593821,0.2233442535801843,0.1790767970046786,0.1776956728764784,0.336296152467113,Fri Jan 12 04:31:29 2024,456789,test +0.1965498475985183,24,1.5991850097873217,0.4253093522356456,0.3098910672800771,0.2423586595722438,0.1884117341343142,0.3545299543835447,0.4065312867876934,0.2900864650937491,0.2235888800796212,0.1794154181616473,0.179109302449887,0.3385557815905521,Fri Jan 12 04:54:36 2024,456789,val +0.1952219600142736,30,1.5920350937141283,0.4234492602561952,0.3083467371632561,0.2409330040768186,0.1872767099981874,0.3526866983183989,0.4111751253501944,0.2938390076183674,0.2267840642620414,0.18228290576713,0.1806248139916089,0.3406203948083564,Fri Jan 12 04:54:36 2024,456789,test +0.1132774498124881,29,1.5885623058507652,0.3295474567312209,0.2090227658153882,0.1496157952737883,0.1351177118363027,0.2575726909862487,0.3203225365586945,0.1967548704322341,0.1366253521997074,0.1011357839521532,0.128936101223533,0.2534423769630654,Fri Jan 12 05:51:51 
2024,456789,val +0.1132774498124881,29,1.5885623058507652,0.3295474567312209,0.2090227658153882,0.1496157952737883,0.1351177118363027,0.2575726909862487,0.3203225365586945,0.1967548704322341,0.1366253521997074,0.1011357839521532,0.128936101223533,0.2534423769630654,Fri Jan 12 05:51:51 2024,456789,test +0.2016708441481354,26,1.6173178177176744,0.4294463999427622,0.3153025164430465,0.2478184467317396,0.1904349546305809,0.3592865214702794,0.4062237022781922,0.2892203602500015,0.2222930366602876,0.1776200693756742,0.17805188056887,0.3388520089019849,Sat Jan 13 09:24:06 2024,456789,val +0.2006259456224443,35,1.612828048950083,0.4278176135059729,0.3137156132500835,0.2465127448317615,0.189605236105966,0.3575739554522817,0.4070677631650607,0.2896557382103424,0.2228312698001959,0.1783507144214693,0.1783611030261803,0.3394224712997829,Sat Jan 13 09:24:06 2024,456789,test +0.1971642016119933,27,1.6348801513731732,0.4221967654945679,0.308699397369679,0.2423874186108029,0.1873813761746,0.3524722235034286,0.4045118307679009,0.2886512257284772,0.2226198237335624,0.1785806132733241,0.1775582622929678,0.3373219065123344,Sat Jan 13 11:56:05 2024,456789,val +0.1959511548420797,30,1.6335313076708071,0.4215631851345677,0.307510611222716,0.2411046142000582,0.1867370407080089,0.3513371298581166,0.4054211987771746,0.2896325837638149,0.2235859900597737,0.1794220065429669,0.1782112182190665,0.3376299025814531,Sat Jan 13 11:56:05 2024,456789,test +0.1766944482294604,28,1.155449237152498,0.4025247729462778,0.2859596956950559,0.2204176876196948,0.1760012958284986,0.328051260354705,0.3834583331488563,0.2645355838889126,0.200146257017177,0.1588804163936584,0.1635976508997293,0.3110297819859123,Sun Jan 14 01:29:54 2024,456789,val +0.1728704371052166,17,1.2315315243149825,0.3984746473203263,0.2823237895441968,0.21661716214299,0.1743753578015599,0.3259103994517834,0.3841555382179568,0.2671675680719294,0.2034948911207503,0.1626073529468017,0.1656832493823925,0.3158930020457863,Sun Jan 14 01:29:54 
2024,456789,test +0.1920613605331065,30,1.613751195388106,0.4225698079226418,0.3053798248808434,0.2378510387636711,0.1852960154497894,0.3487678606011491,0.4022958161972436,0.2859483892763756,0.2195768008930747,0.1751446763715023,0.1750871715696569,0.3347306315036027,Sun Jan 14 07:13:59 2024,456789,val +0.1899799883899898,26,1.6169384957747763,0.4204944737980816,0.30310914745992,0.2355929113194784,0.1844333852407765,0.3476991854379127,0.4036915686285349,0.2867655824594863,0.2199791680627419,0.1752875976608934,0.1755769140226974,0.3358481413729531,Sun Jan 14 07:13:59 2024,456789,test +0.1932049658049393,19,1.6459132102687488,0.418334176417811,0.3042512757549553,0.2380493113172827,0.1843457808596927,0.3470144230286481,0.404716893360156,0.2867856836712124,0.2197837094321253,0.1754541822284109,0.1766979195896708,0.3373786977377307,Mon Jan 15 07:39:11 2024,456789,val +0.1924758886646524,33,1.6195457535365934,0.4193874562560722,0.3046875249552083,0.2377189291414584,0.1845670047721518,0.347641577903319,0.4067792064438138,0.289022121495015,0.2222381412825655,0.1777443016990785,0.1774619315387756,0.3390965727942517,Mon Jan 15 07:39:11 2024,456789,test +0.1993215300456555,37,1.612402267844909,0.4304082129154168,0.313756225354385,0.2455829675967209,0.1898134216469935,0.355761580001066,0.4097025525622525,0.2935591664321547,0.2263521755943996,0.1812501318114711,0.1784561147416624,0.3405993997703646,Mon Jan 15 08:41:16 2024,456789,val +0.1975071802864783,35,1.610485027214786,0.4291014021520639,0.3121047690666265,0.2438224052814563,0.1890325459103117,0.3546629788403369,0.4094862065965325,0.2935363132130585,0.2265350106901457,0.1816001367350277,0.1784691227348917,0.3407077608325425,Mon Jan 15 08:41:16 2024,456789,test +0.1933468329254713,22,1.5955444779997916,0.4199953457433103,0.3050806959440995,0.2385299943185583,0.1859248375716904,0.3505272952997275,0.4048145485792417,0.289057378075242,0.2226896399217993,0.1782933040151778,0.1774009816897228,0.3367420413169875,Mon Jan 15 
09:41:59 2024,456789,val +0.1928006874676236,24,1.5908372792329335,0.418535098650698,0.3039872415357573,0.2377084273430874,0.1852060772622194,0.3494468319554391,0.4052896019693263,0.2901705325787992,0.2239269374702638,0.1795452451397541,0.1783100159535628,0.3381746523361855,Mon Jan 15 09:41:59 2024,456789,test +0.1163500078134047,31,2.9176030974152107,0.3226408661612834,0.2177985727025335,0.1571556388176403,0.1324585927421973,0.2674164279822154,0.3191274705994266,0.2119422035623724,0.1510954718664248,0.1108865224501665,0.1308016667351617,0.2646066255431548,Tue Jan 16 04:29:49 2024,456789,val +0.1151486193104934,26,2.920389707889336,0.3211046035066163,0.2162995052319697,0.1558819982112352,0.1318466375988958,0.2669849514779717,0.3208775169424473,0.2134703524477858,0.1523443523678372,0.1118243797163041,0.1314167181392924,0.2660032331500122,Tue Jan 16 04:29:49 2024,456789,test +0.1656666849245744,24,2.153675915670953,0.3941670933726937,0.2773770526110028,0.2102478310077934,0.1695285839946426,0.3248838262979353,0.3831828707068923,0.2663212513708795,0.2003011519742146,0.1565591576231608,0.1623894477277345,0.3147435189018528,Tue Jan 16 12:39:40 2024,456789,val +0.1637244697037911,19,2.1689922089709066,0.3924548208091566,0.274875577411643,0.2081396614648745,0.1680879668658668,0.3228520097603173,0.3818347544501385,0.265618135216264,0.2003761273523002,0.1570525718897495,0.162341199560276,0.3156213598934125,Tue Jan 16 12:39:40 2024,456789,test +0.049448415155385,26,2.573289893717736,0.2748241874609645,0.1373034603305925,0.0790677885059086,0.098884932324965,0.1932497669615804,0.2615592351127927,0.1262494722911582,0.0711197521576496,0.0434255719105283,0.0936461738906148,0.1837463522316675,Sun Jan 21 07:10:48 2024,456789,val +0.049448415155385,26,2.573289893717736,0.2748241874609645,0.1373034603305925,0.0790677885059086,0.098884932324965,0.1932497669615804,0.2615592351127927,0.1262494722911582,0.0711197521576496,0.0434255719105283,0.0936461738906148,0.1837463522316675,Sun Jan 21 
07:10:48 2024,456789,test +0.0365292545382168,5,2.2622990843649804,0.2124887124790346,0.1003086454922369,0.0574319432432276,0.0811924291539347,0.1597740344108888,0.2107197093732849,0.0976063862028806,0.0541818838092851,0.0330744807819391,0.0789532535101956,0.1582153816161014,Sun Jan 21 13:31:20 2024,456789,val +0.0365292545382168,5,2.2622990843649804,0.2124887124790346,0.1003086454922369,0.0574319432432276,0.0811924291539347,0.1597740344108888,0.2107197093732849,0.0976063862028806,0.0541818838092851,0.0330744807819391,0.0789532535101956,0.1582153816161014,Sun Jan 21 13:31:20 2024,456789,test +0.0832986608128533,30,3.041216642831989,0.2662295346514202,0.1665165322491746,0.1153358728322493,0.106625145195597,0.221803963805843,0.2587407450848887,0.1597345427919344,0.1082270187662969,0.0757847800443463,0.102976892738773,0.2184364002872991,Sun Jan 21 17:21:14 2024,456789,val +0.0832986608128533,30,3.041216642831989,0.2662295346514202,0.1665165322491746,0.1153358728322493,0.106625145195597,0.221803963805843,0.2587407450848887,0.1597345427919344,0.1082270187662969,0.0757847800443463,0.102976892738773,0.2184364002872991,Sun Jan 21 17:21:14 2024,456789,test +0.067945253388372,28,2.473010689460991,0.2385598221459548,0.1382050546216192,0.0932818912846819,0.0994230722751356,0.2008120595297716,0.2389516330992057,0.1385907538842819,0.0903025045853273,0.0627137262861368,0.0992752123480886,0.2004688437702076,Sun Jan 21 19:20:19 2024,456789,val +0.0665674593401427,30,2.4726671584250868,0.2382327528892861,0.1367828343565307,0.0917330579441558,0.0985477958011899,0.199813857386265,0.2403268938184588,0.1388652725937323,0.0902745934219205,0.0627698280896946,0.0997706599540146,0.2008522953400196,Sun Jan 21 19:20:19 2024,456789,test +0.0931410739450857,26,2.2569886844300013,0.2828318068530036,0.175554091566582,0.1236366209929575,0.1184029810123881,0.2337275714302712,0.270733117496503,0.1636942766230018,0.1117072176594325,0.0817006074475287,0.1121230088996185,0.2263089573296754,Mon Jan 22 
02:32:48 2024,456789,val +0.0923146596075177,27,2.256110493205149,0.2806743364407732,0.1738874177782678,0.1224518136235766,0.1177862055839632,0.2329801690967512,0.2714412916045762,0.1645031508529706,0.1123317637920865,0.0821901778498594,0.112765884887815,0.2266610910641991,Mon Jan 22 02:32:48 2024,456789,test +0.0751240209280265,27,3.8135595367676416,0.2798572598278214,0.1714142351629332,0.1124715852787066,0.105283135806052,0.2382627981047684,0.273472649999544,0.1646813346435603,0.1061376458459884,0.0697404012411755,0.1013180513758229,0.2333421165623164,Mon Jan 22 07:30:42 2024,456789,val +0.0728485920407549,22,3.820860662680453,0.2781310010344503,0.1691504839833641,0.1100131850303654,0.1043161173368861,0.235559792366571,0.2742968580953015,0.1656612743865062,0.1073404182638982,0.0710012128895343,0.101607956617687,0.2336377845190934,Mon Jan 22 07:30:42 2024,456789,test +0.0608791540550838,24,3.944519735523954,0.2469358590069414,0.1451890233149313,0.0935839111645882,0.0914212715356625,0.2098826004305627,0.2364798672419267,0.1346979104679793,0.0840678286813978,0.05308310990537,0.0859187849076245,0.2045203193987662,Tue Jan 30 16:19:43 2024,456789,val +0.0601081421648775,30,3.9399126780901215,0.2453145177385802,0.1442325854870148,0.0927377266832365,0.0905090369136635,0.2089630796155349,0.2370148639476066,0.1351321411830296,0.0847531016988984,0.05386845193874,0.085959513854081,0.2046992268595535,Tue Jan 30 16:19:43 2024,456789,test +0.0608791540550838,24,3.944519735523954,0.2469358590069414,0.1451890233149313,0.0935839111645882,0.0914212715356625,0.2098826004305627,0.2364798672419267,0.1346979104679793,0.0840678286813978,0.05308310990537,0.0859187849076245,0.2045203193987662,Sat Feb 3 09:11:58 2024,456789,val +0.0601081421648775,30,3.9399126780901215,0.2453145177385802,0.1442325854870148,0.0927377266832365,0.0905090369136635,0.2089630796155349,0.2370148639476066,0.1351321411830296,0.0847531016988984,0.05386845193874,0.085959513854081,0.2046992268595535,Sat Feb 3 09:11:58 
2024,456789,test +0.0675822207930474,12,1.8285817304536487,0.2861091762160233,0.152523919635802,0.0962538109108903,0.1082595001081577,0.2041086695190062,0.2876461346118997,0.1498280731308161,0.0911963337584602,0.0614927787867049,0.107805273911507,0.2011804801077635,Wed Mar 6 19:17:41 2024,456789,val +0.0675822207930474,12,1.8285817304536487,0.2861091762160233,0.152523919635802,0.0962538109108903,0.1082595001081577,0.2041086695190062,0.2876461346118997,0.1498280731308161,0.0911963337584602,0.0614927787867049,0.107805273911507,0.2011804801077635,Wed Mar 6 19:17:41 2024,456789,test +0.0461074576989291,8,2.004684211913937,0.2544993498049381,0.118477398284799,0.0685838977417521,0.0953404050244289,0.1843535820163173,0.2493368010403088,0.1144856268940758,0.0646547113828999,0.0421483088007805,0.0935811756540631,0.1797647377159858,Wed Mar 6 19:24:02 2024,456789,val +0.0461074576989291,8,2.004684211913937,0.2544993498049381,0.118477398284799,0.0685838977417521,0.0953404050244289,0.1843535820163173,0.2493368010403088,0.1144856268940758,0.0646547113828999,0.0421483088007805,0.0935811756540631,0.1797647377159858,Wed Mar 6 19:24:02 2024,456789,test +0.0620349296047563,8,1.991828693248942,0.2656680479451986,0.1403796326183569,0.0883284272507563,0.102552133861856,0.1975920149225402,0.2600158818045261,0.1355392339413597,0.084557915379441,0.0587252571071732,0.1007135098489949,0.1955264068043994,Thu Mar 7 07:36:27 2024,456789,val +0.055377627592833,7,2.0594369889913984,0.2421486122649633,0.124263127783201,0.0781103649954232,0.0977373616373919,0.1857679282091528,0.2509115517960482,0.1317251413304328,0.0837268319657806,0.0592827173364581,0.1006907053694778,0.193848566348427,Thu Mar 7 07:36:27 2024,456789,test +0.0453964562904908,12,1.401348980109756,0.254122236670998,0.1169285438651287,0.067523794038083,0.0964877224186474,0.1823396179829076,0.2489986996098797,0.112998499564942,0.0639143719347054,0.0417497911195076,0.0945075650405822,0.1778267252355657,Thu Mar 7 07:38:47 2024,456789,val 
+0.0453964562904908,12,1.401348980109756,0.254122236670998,0.1169285438651287,0.067523794038083,0.0964877224186474,0.1823396179829076,0.2489986996098797,0.112998499564942,0.0639143719347054,0.0417497911195076,0.0945075650405822,0.1778267252355657,Thu Mar 7 07:38:47 2024,456789,test +0.0454224535953271,17,1.7880993422329117,0.2539921976592945,0.1168700796085364,0.0674880226872587,0.0964150599759419,0.182228856588987,0.248751625487643,0.1128266440594742,0.0638080913007787,0.0417456683001681,0.09435818406512,0.1776659566504955,Thu Mar 7 07:57:50 2024,456789,val +0.0454224535953271,17,1.7880993422329117,0.2539921976592945,0.1168700796085364,0.0674880226872587,0.0964150599759419,0.182228856588987,0.248751625487643,0.1128266440594742,0.0638080913007787,0.0417456683001681,0.09435818406512,0.1776659566504955,Thu Mar 7 07:57:50 2024,456789,test +0.0459766119750063,10,1.6020072621810495,0.2538101430429095,0.1179782526504932,0.0682592559967794,0.0964285159717658,0.1805208679138608,0.2484785435630656,0.1137875586114811,0.0642532076096663,0.0419398047694579,0.0944554259584617,0.1759016310602554,Thu Mar 7 08:22:43 2024,456789,val +0.0459766119750063,10,1.6020072621810495,0.2538101430429095,0.1179782526504932,0.0682592559967794,0.0964285159717658,0.1805208679138608,0.2484785435630656,0.1137875586114811,0.0642532076096663,0.0419398047694579,0.0944554259584617,0.1759016310602554,Thu Mar 7 08:22:43 2024,456789,test +0.0605858104424772,8,1.986871244732208,0.2561737460001794,0.1360708947882275,0.0860212795320674,0.1003870430682985,0.1924168613915124,0.2546160340962535,0.1342293100252579,0.0838635958812386,0.0575827361415855,0.0993037737768311,0.1920884264804549,Thu Mar 7 19:46:02 2024,456789,val +0.0547977217694287,7,2.053388431646482,0.237763865867636,0.1243613132598344,0.0783327717571664,0.0943248820510119,0.1816947423719679,0.2451157770834566,0.129903677702636,0.0830041150556185,0.0585611140899908,0.0963458630639889,0.1892605326901023,Thu Mar 7 19:46:02 2024,456789,test 
+0.0674468248377443,11,1.867585077988388,0.2455101392741912,0.137850846928382,0.09181084699613,0.1002784579004408,0.19590376395772,0.2467945955780064,0.136522421484674,0.0878824806041819,0.0620521401524644,0.1004937828307917,0.1974823774498073,Thu Mar 7 20:39:53 2024,456789,val +0.0634075175280432,18,1.7624659056912124,0.2329467542875306,0.1302584845002121,0.0863667957476178,0.099460299162904,0.189552962019166,0.2348864721063459,0.1324108342060952,0.0879952028040995,0.0641506901549218,0.0996557457146562,0.1909396576511333,Thu Mar 7 20:39:53 2024,456789,test +0.0461334627430497,40,1.745108318618064,0.2543693107932346,0.118418914033361,0.0685481281797738,0.0952769003161693,0.1842559017235081,0.2490897269180721,0.1143144605698449,0.0645486615921167,0.042144696045626,0.0933983494253615,0.1796039691309157,Thu Mar 7 21:22:32 2024,456789,val +0.0461334627430497,40,1.745108318618064,0.2543693107932346,0.118418914033361,0.0685481281797738,0.0952769003161693,0.1842559017235081,0.2490897269180721,0.1143144605698449,0.0645486615921167,0.042144696045626,0.0933983494253615,0.1796039691309157,Thu Mar 7 21:22:32 2024,456789,test +0.0441502955260203,4,2.433862712306924,0.2436540962288654,0.1133994785435669,0.065818611534272,0.0926949030328476,0.1792890112572704,0.2388816644993467,0.1096823369841658,0.0620417714546582,0.0402769502005985,0.0909684619635823,0.1751807584105997,Thu Mar 7 21:51:29 2024,456789,val +0.0441502955260203,4,2.433862712306924,0.2436540962288654,0.1133994785435669,0.065818611534272,0.0926949030328476,0.1792890112572704,0.2388816644993467,0.1096823369841658,0.0620417714546582,0.0402769502005985,0.0909684619635823,0.1751807584105997,Thu Mar 7 21:51:29 2024,456789,test +0.0554972814798414,10,1.877601398513573,0.2261647324456764,0.1205955754707293,0.0777240695365257,0.0919368558377465,0.1761959930382465,0.2349783816759763,0.1228139993497922,0.0758425051178474,0.0520849452966258,0.094165844493486,0.1821366443780601,Fri Mar 8 09:35:08 2024,456789,val 
+0.0554972814798414,10,1.877601398513573,0.2261647324456764,0.1205955754707293,0.0777240695365257,0.0919368558377465,0.1761959930382465,0.2349783816759763,0.1228139993497922,0.0758425051178474,0.0520849452966258,0.094165844493486,0.1821366443780601,Fri Mar 8 09:35:08 2024,456789,test +0.0460062746327701,17,1.7720311121925498,0.2540572171651462,0.1181628178376815,0.0683832941942736,0.095166803981946,0.1841155397475401,0.248751625487643,0.1139934747356504,0.0642890096599489,0.0419211132949035,0.0932796720819013,0.1794731581198023,Fri Mar 8 10:02:16 2024,456789,val +0.0458870225539365,39,1.7393758887248247,0.254122236670998,0.1181073031386425,0.0682560916056603,0.0966068459559823,0.1806739762772854,0.2489986996098797,0.1142220079352694,0.064569266505156,0.0421427361736232,0.0947963005497044,0.1761408862519936,Fri Mar 8 10:02:16 2024,456789,test +0.0620849195193176,11,1.8557853316537,0.2408127621992549,0.1314613376514274,0.0859577768519815,0.0968954394635108,0.191754732147623,0.2398696711962179,0.1312532969481677,0.0849635228465307,0.0599916346617659,0.0960353271804474,0.1927818798496836,Fri Mar 8 10:57:51 2024,456789,val +0.0620849195193176,11,1.8557853316537,0.2408127621992549,0.1314613376514274,0.0859577768519815,0.0968954394635108,0.191754732147623,0.2398696711962179,0.1312532969481677,0.0849635228465307,0.0599916346617659,0.0960353271804474,0.1927818798496836,Fri Mar 8 10:57:51 2024,456789,test +0.04600064988706096,12,1.8497342001888974,0.2546814044213231,0.11840681008760644,0.06843770038886718,0.09673874625233088,0.18119591182630765,0.24912873862158322,0.11423750837693362,0.06457510793350847,0.04214559555630855,0.09484797768627365,0.1763340912362437,Fri Mar 8 11:15:57 2024,456789,val +0.04588702255393656,40,1.7490093688927446,0.254122236670998,0.11810730313864252,0.0682560916056603,0.09660684595598233,0.18067397627728543,0.2492180168908319,0.11441594356190243,0.06469158628564463,0.042226961587848326,0.09473393504014992,0.1759818794219504,Fri Mar 8 11:15:57 
2024,456789,test diff --git a/test_wsi_report.sh b/test_wsi_report.sh
new file mode 100644
index 0000000..2efc9de
--- /dev/null
+++ b/test_wsi_report.sh
@@ -0,0 +1,27 @@
+model='histgen'
+max_length=100
+epochs=40
+region_size=96
+prototype_num=512
+
+# NOTE(review): the original script passed --step_size 1 twice; the duplicate
+# was removed (argparse keeps only the last occurrence anyway).
+python main_test_AllinOne.py \
+    --image_dir /path/to/feature \
+    --ann_path /path/to/json \
+    --dataset_name wsi_report \
+    --model_name $model \
+    --max_seq_length $max_length \
+    --threshold 10 \
+    --batch_size 1 \
+    --epochs $epochs \
+    --step_size 1 \
+    --topk 512 \
+    --cmm_size 2048 \
+    --cmm_dim 512 \
+    --region_size $region_size \
+    --prototype_num $prototype_num \
+    --save_dir /path/to/storage \
+    --gamma 0.8 \
+    --seed 42 \
+    --log_period 1000 \
+    --load /path/to/checkpoint \
+    --beam_size 3
diff --git a/train_wsi_report.sh b/train_wsi_report.sh
new file mode 100644
index 0000000..192a63b
--- /dev/null
+++ b/train_wsi_report.sh
@@ -0,0 +1,29 @@
+model='histgen'
+max_length=100
+epochs=40
+region_size=96
+prototype_num=512
+
+# NOTE(review): the original script passed both --step_size 3 and a later
+# --step_size 1; argparse uses the last occurrence, so the effective value
+# was 1. Consolidated to a single --step_size 1 to preserve behavior —
+# confirm 1 (not 3) is the intended LR-scheduler step size.
+python main_train_AllinOne.py \
+    --image_dir /path/to/feature \
+    --ann_path /path/to/json \
+    --dataset_name wsi_report \
+    --model_name $model \
+    --max_seq_length $max_length \
+    --num_layers 3 \
+    --threshold 10 \
+    --batch_size 1 \
+    --epochs $epochs \
+    --lr_ve 1e-4 \
+    --lr_ed 1e-4 \
+    --step_size 1 \
+    --topk 512 \
+    --cmm_size 2048 \
+    --cmm_dim 512 \
+    --region_size $region_size \
+    --prototype_num $prototype_num \
+    --save_dir /path/to/storage \
+    --gamma 0.8 \
+    --seed 456789 \
+    --log_period 1000 \
+    --beam_size 3