initial commit
Guo Zhengrui committed Mar 9, 2024
0 parents commit 9e390ed
Showing 71 changed files with 3,890 additions and 0 deletions.
116 changes: 116 additions & 0 deletions main_test_AllinOne.py
@@ -0,0 +1,116 @@
import torch
import argparse
import numpy as np
from modules.tokenizers import Tokenizer
from modules.dataloaders import R2DataLoader
from modules.metrics import compute_scores
from modules.tester_AllinOne import Tester
from modules.loss import compute_loss
from models.histgen_model import HistGenModel

def parse_args():
    parser = argparse.ArgumentParser()

    # Data input settings
    parser.add_argument('--image_dir', type=str, default='data/iu_xray/images/', help='the path to the directory containing the images/features.')
    parser.add_argument('--ann_path', type=str, default='data/iu_xray/annotation.json', help='the path to the annotation file.')

    # Data loader settings
    parser.add_argument('--dataset_name', type=str, default='iu_xray', choices=['iu_xray', 'mimic_cxr', 'wsi_report'], help='the dataset to be used.')
    parser.add_argument('--max_seq_length', type=int, default=60, help='the maximum sequence length of the reports.')
    parser.add_argument('--threshold', type=int, default=3, help='the cut-off frequency for the words.')
    parser.add_argument('--num_workers', type=int, default=2, help='the number of workers for the dataloader.')
    parser.add_argument('--batch_size', type=int, default=16, help='the number of samples in a batch.')

    parser.add_argument('--model_name', type=str, default='histgen', choices=['histgen'], help='the model used for the experiment.')

    # Model settings (for visual extractor)
    parser.add_argument('--visual_extractor', type=str, default='resnet101', help='the visual extractor to be used.')
    # note: argparse treats any non-empty string as True for type=bool, so this flag cannot be disabled from the command line
    parser.add_argument('--visual_extractor_pretrained', type=bool, default=True, help='whether to load the pretrained visual extractor.')
    # Model settings (for Transformer)
    parser.add_argument('--d_model', type=int, default=512, help='the dimension of the Transformer.')
    parser.add_argument('--d_ff', type=int, default=512, help='the dimension of the FFN.')
    parser.add_argument('--d_vf', type=int, default=2048, help='the dimension of the patch features.')
    parser.add_argument('--num_heads', type=int, default=8, help='the number of attention heads in the Transformer.')
    parser.add_argument('--num_layers', type=int, default=3, help='the number of layers of the Transformer.')
    parser.add_argument('--dropout', type=float, default=0.1, help='the dropout rate of the Transformer.')
    parser.add_argument('--logit_layers', type=int, default=1, help='the number of logit layers.')
    parser.add_argument('--bos_idx', type=int, default=0, help='the index of <bos>.')
    parser.add_argument('--eos_idx', type=int, default=0, help='the index of <eos>.')
    parser.add_argument('--pad_idx', type=int, default=0, help='the index of <pad>.')
    parser.add_argument('--use_bn', type=int, default=0, help='whether to use batch normalization.')
    parser.add_argument('--drop_prob_lm', type=float, default=0.5, help='the dropout rate of the output layer.')
    # for the cross-modal context module
    parser.add_argument('--topk', type=int, default=32, help='the k used for top-k retrieval in the cross-modal module.')
    parser.add_argument('--cmm_size', type=int, default=2048, help='the number of cross-modal memory slots.')
    parser.add_argument('--cmm_dim', type=int, default=512, help='the dimension of the cross-modal memory.')
    # for the local-global hierarchical visual encoder
    parser.add_argument('--region_size', type=int, default=256, help='the size of the region for the region transformer.')
    # mirrors main_train_AllinOne.py; assumed to be read when the model is built for testing
    parser.add_argument('--prototype_num', type=int, default=512, help='the number of visual prototypes for cross-modal interaction.')

    # Sample related
    parser.add_argument('--sample_method', type=str, default='beam_search', help='the sampling method used to generate a report.')
    parser.add_argument('--beam_size', type=int, default=3, help='the beam size for beam search.')
    parser.add_argument('--temperature', type=float, default=1.0, help='the temperature when sampling.')
    parser.add_argument('--sample_n', type=int, default=1, help='the number of samples per image.')
    parser.add_argument('--group_size', type=int, default=1, help='the group size.')
    parser.add_argument('--output_logsoftmax', type=int, default=1, help='whether to output the probabilities.')
    parser.add_argument('--decoding_constraint', type=int, default=0, help='whether to apply a decoding constraint.')
    parser.add_argument('--block_trigrams', type=int, default=1, help='whether to block repeated trigrams.')

    # Trainer settings
    parser.add_argument('--n_gpu', type=int, default=1, help='the number of GPUs to be used.')
    parser.add_argument('--epochs', type=int, default=100, help='the number of training epochs.')
    parser.add_argument('--save_dir', type=str, default='results/iu_xray', help='the path to save the models.')
    parser.add_argument('--record_dir', type=str, default='records/', help='the path to save the experiment records.')
    parser.add_argument('--save_period', type=int, default=1, help='the saving period.')
    parser.add_argument('--monitor_mode', type=str, default='max', choices=['min', 'max'], help='whether to maximize or minimize the monitored metric.')
    parser.add_argument('--monitor_metric', type=str, default='BLEU_4', help='the metric to be monitored.')
    parser.add_argument('--early_stop', type=int, default=50, help='the patience (in epochs) for early stopping.')
    parser.add_argument('--log_period', type=int, default=1000, help='the logging interval (in batches).')

    # Optimization
    parser.add_argument('--optim', type=str, default='Adam', help='the type of the optimizer.')
    parser.add_argument('--lr_ve', type=float, default=5e-5, help='the learning rate for the visual extractor.')
    parser.add_argument('--lr_ed', type=float, default=1e-4, help='the learning rate for the remaining parameters.')
    parser.add_argument('--weight_decay', type=float, default=5e-5, help='the weight decay.')
    parser.add_argument('--amsgrad', type=bool, default=True, help='whether to use the AMSGrad variant of Adam.')

    # Learning Rate Scheduler
    parser.add_argument('--lr_scheduler', type=str, default='StepLR', help='the type of the learning rate scheduler.')
    parser.add_argument('--step_size', type=int, default=50, help='the step size of the learning rate scheduler.')
    parser.add_argument('--gamma', type=float, default=0.1, help='the gamma of the learning rate scheduler.')

    # Others
    parser.add_argument('--seed', type=int, default=9233, help='the random seed.')
    parser.add_argument('--resume', type=str, help='the checkpoint path to resume training from.')
    parser.add_argument('--load', type=str, help='the path of a pre-trained model to load for testing.')

    args = parser.parse_args()
    return args


def main():
    # parse arguments
    args = parse_args()

    # fix random seeds for reproducibility (deterministic cuDNN also disables benchmark autotuning)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(args.seed)

    # build the tokenizer, the test dataloader, and the model
    tokenizer = Tokenizer(args)
    test_dataloader = R2DataLoader(args, tokenizer, split='test', shuffle=False)
    model = HistGenModel(args, tokenizer)

    # get function handles of loss and metrics
    criterion = compute_loss
    metrics = compute_scores

    # build the tester and start testing
    tester = Tester(model, criterion, metrics, args, test_dataloader)
    tester.test()


if __name__ == '__main__':
    main()
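A typical test invocation might look like the following; the feature directory, annotation file, and checkpoint path are placeholders rather than files shipped with this commit:

python main_test_AllinOne.py --image_dir data/wsi_feats/ --ann_path data/annotation.json --dataset_name wsi_report --load results/wsi_report/model_best.pth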
127 changes: 127 additions & 0 deletions main_train_AllinOne.py
@@ -0,0 +1,127 @@
import torch
import argparse
import numpy as np
from modules.tokenizers import Tokenizer
from modules.dataloaders import R2DataLoader
from modules.metrics import compute_scores
from modules.optimizers import build_optimizer, build_lr_scheduler
from modules.trainer_AllinOne import Trainer
from modules.loss import compute_loss
from models.histgen_model import HistGenModel

def parse_args():
    parser = argparse.ArgumentParser()

    # Data input settings
    parser.add_argument('--image_dir', type=str, default='data/iu_xray/images/', help='the path to the directory containing the images/features.')
    parser.add_argument('--ann_path', type=str, default='data/iu_xray/annotation.json', help='the path to the annotation file.')

    # Data loader settings
    parser.add_argument('--dataset_name', type=str, default='wsi_report', choices=['iu_xray', 'mimic_cxr', 'wsi_report'], help='the dataset to be used.')
    parser.add_argument('--max_seq_length', type=int, default=60, help='the maximum sequence length of the reports.')
    parser.add_argument('--threshold', type=int, default=3, help='the cut-off frequency for the words.')
    parser.add_argument('--num_workers', type=int, default=2, help='the number of workers for the dataloader.')
    parser.add_argument('--batch_size', type=int, default=16, help='the number of samples in a batch.')

    parser.add_argument('--model_name', type=str, default='histgen', choices=['histgen'], help='the model used for the experiment.')

    # Model settings (for visual extractor)
    parser.add_argument('--visual_extractor', type=str, default='resnet101', help='the visual extractor to be used.')
    # note: argparse treats any non-empty string as True for type=bool, so this flag cannot be disabled from the command line
    parser.add_argument('--visual_extractor_pretrained', type=bool, default=True, help='whether to load the pretrained visual extractor.')
    # Model settings (for Transformer)
    parser.add_argument('--d_model', type=int, default=512, help='the dimension of the Transformer.')
    parser.add_argument('--d_ff', type=int, default=512, help='the dimension of the FFN.')
    parser.add_argument('--d_vf', type=int, default=2048, help='the dimension of the patch features.')
    parser.add_argument('--num_heads', type=int, default=8, help='the number of attention heads in the Transformer.')
    parser.add_argument('--num_layers', type=int, default=3, help='the number of layers of the Transformer.')
    parser.add_argument('--dropout', type=float, default=0.1, help='the dropout rate of the Transformer.')
    parser.add_argument('--logit_layers', type=int, default=1, help='the number of logit layers.')
    parser.add_argument('--bos_idx', type=int, default=0, help='the index of <bos>.')
    parser.add_argument('--eos_idx', type=int, default=0, help='the index of <eos>.')
    parser.add_argument('--pad_idx', type=int, default=0, help='the index of <pad>.')
    parser.add_argument('--use_bn', type=int, default=0, help='whether to use batch normalization.')
    parser.add_argument('--drop_prob_lm', type=float, default=0.5, help='the dropout rate of the output layer.')
    # for the cross-modal context module
    parser.add_argument('--topk', type=int, default=32, help='the k used for top-k retrieval in the cross-modal module.')
    parser.add_argument('--cmm_size', type=int, default=2048, help='the number of cross-modal memory slots.')
    parser.add_argument('--cmm_dim', type=int, default=512, help='the dimension of the cross-modal memory.')
    # for the local-global hierarchical visual encoder
    parser.add_argument('--region_size', type=int, default=256, help='the size of the region for the region transformer.')
    parser.add_argument('--prototype_num', type=int, default=512, help='the number of visual prototypes for cross-modal interaction.')

    # Sample related
    parser.add_argument('--sample_method', type=str, default='beam_search', help='the sampling method used to generate a report.')
    parser.add_argument('--beam_size', type=int, default=3, help='the beam size for beam search.')
    parser.add_argument('--temperature', type=float, default=1.0, help='the temperature when sampling.')
    parser.add_argument('--sample_n', type=int, default=1, help='the number of samples per image.')
    parser.add_argument('--group_size', type=int, default=1, help='the group size.')
    parser.add_argument('--output_logsoftmax', type=int, default=1, help='whether to output the probabilities.')
    parser.add_argument('--decoding_constraint', type=int, default=0, help='whether to apply a decoding constraint.')
    parser.add_argument('--block_trigrams', type=int, default=1, help='whether to block repeated trigrams.')

    # Trainer settings
    parser.add_argument('--n_gpu', type=int, default=1, help='the number of GPUs to be used.')
    parser.add_argument('--epochs', type=int, default=100, help='the number of training epochs.')
    parser.add_argument('--save_dir', type=str, default='results/iu_xray', help='the path to save the models.')
    parser.add_argument('--record_dir', type=str, default='records/', help='the path to save the experiment records.')
    parser.add_argument('--save_period', type=int, default=1, help='the saving period.')
    parser.add_argument('--monitor_mode', type=str, default='max', choices=['min', 'max'], help='whether to maximize or minimize the monitored metric.')
    parser.add_argument('--monitor_metric', type=str, default='BLEU_4', help='the metric to be monitored.')
    parser.add_argument('--early_stop', type=int, default=50, help='the patience (in epochs) for early stopping.')
    parser.add_argument('--log_period', type=int, default=1000, help='the logging interval (in batches).')

    # Optimization
    parser.add_argument('--optim', type=str, default='Adam', help='the type of the optimizer.')
    parser.add_argument('--lr_ve', type=float, default=5e-5, help='the learning rate for the visual extractor.')
    parser.add_argument('--lr_ed', type=float, default=1e-4, help='the learning rate for the remaining parameters.')
    parser.add_argument('--weight_decay', type=float, default=5e-5, help='the weight decay.')
    parser.add_argument('--amsgrad', type=bool, default=True, help='whether to use the AMSGrad variant of Adam.')

    # Learning Rate Scheduler
    parser.add_argument('--lr_scheduler', type=str, default='StepLR', help='the type of the learning rate scheduler.')
    parser.add_argument('--step_size', type=int, default=50, help='the step size of the learning rate scheduler.')
    parser.add_argument('--gamma', type=float, default=0.1, help='the gamma of the learning rate scheduler.')

    # Others
    parser.add_argument('--seed', type=int, default=9233, help='the random seed.')
    parser.add_argument('--resume', type=str, help='the checkpoint path to resume training from.')

    args = parser.parse_args()
    return args


def main():
    # parse arguments
    args = parse_args()

    # fix random seeds for reproducibility
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(args.seed)

    # build the tokenizer and the dataloaders
    tokenizer = Tokenizer(args)
    train_dataloader = R2DataLoader(args, tokenizer, split='train', shuffle=True)
    val_dataloader = R2DataLoader(args, tokenizer, split='val', shuffle=False)
    test_dataloader = R2DataLoader(args, tokenizer, split='test', shuffle=False)

    # build model architecture
    model = HistGenModel(args, tokenizer)

    # get function handles of loss and metrics
    criterion = compute_loss
    metrics = compute_scores

    # build optimizer and learning rate scheduler
    optimizer = build_optimizer(args, model)
    lr_scheduler = build_lr_scheduler(args, optimizer)

    # build the trainer and start training
    trainer = Trainer(model, criterion, metrics, optimizer, args, lr_scheduler, train_dataloader, val_dataloader, test_dataloader)
    trainer.train()


if __name__ == '__main__':
    main()
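By analogy, a training run could be launched as below; the data paths and save directory are placeholders, and all hyperparameters fall back to the defaults above:

python main_train_AllinOne.py --image_dir data/wsi_feats/ --ann_path data/annotation.json --dataset_name wsi_report --save_dir results/wsi_report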
Binary file added models/__pycache__/histgen_model.cpython-310.pyc
34 changes: 34 additions & 0 deletions models/histgen_model.py
@@ -0,0 +1,34 @@
import numpy as np
import torch
import torch.nn as nn
from modules.visual_extractor import VisualExtractor
from modules.histgen_module import BaseHistGen

class HistGenModel(nn.Module):
    def __init__(self, args, tokenizer):
        super(HistGenModel, self).__init__()
        self.args = args
        self.tokenizer = tokenizer
        self.encoder_decoder = BaseHistGen(args, tokenizer)
        # CTransPath patch features are 768-d; features from other extractors are assumed to be 1024-d
        self.wsi_mapping = torch.nn.Linear(768, self.args.d_vf) if 'ctranspath' in args.image_dir else torch.nn.Linear(1024, self.args.d_vf)
        self.forward = self.forward_pathology
        self.visual_extractor = VisualExtractor(args)

    def __str__(self):
        model_parameters = filter(lambda p: p.requires_grad, self.parameters())
        params = sum([np.prod(p.size()) for p in model_parameters])
        return super().__str__() + '\nTrainable parameters: {}'.format(params)

    def forward_pathology(self, images, targets=None, mode='train', update_opts={}):
        # project patch features to the decoder dimension and mean-pool them into a global feature
        att_feats = self.wsi_mapping(images)
        fc_feats = torch.mean(att_feats, dim=1)

        if mode == 'train':
            output = self.encoder_decoder(fc_feats, att_feats, targets, mode='forward')
            return output
        elif mode == 'sample':
            output, output_probs = self.encoder_decoder(fc_feats, att_feats, mode='sample')
            return output
        else:
            raise ValueError("mode must be 'train' or 'sample', got {!r}".format(mode))
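A minimal sketch of how this model is driven by the trainer and tester, assuming 1024-d patch features (the non-ctranspath branch of wsi_mapping) and illustrative shapes; args and tokenizer are assumed to be built as in main_train_AllinOne.py:

feats = torch.randn(1, 4096, 1024)            # a bag of 4096 patch features from one WSI: (batch, n_patches, feat_dim)
reports = torch.randint(1, 100, (1, 60))      # hypothetical tokenized report ids: (batch, max_seq_length)
model = HistGenModel(args, tokenizer)
train_out = model(feats, reports, mode='train')   # forward() is aliased to forward_pathology
sampled = model(feats, mode='sample')             # decoded report token ids via beam search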
Empty file added modules/__init__.py
Binary file added modules/__pycache__/__init__.cpython-310.pyc
Binary file added modules/__pycache__/att_model.cpython-310.pyc
Binary file added modules/__pycache__/caption_model.cpython-310.pyc
Binary file added modules/__pycache__/dataloaders.cpython-310.pyc
Binary file added modules/__pycache__/datasets.cpython-310.pyc
Binary file added modules/__pycache__/loss.cpython-310.pyc
Binary file added modules/__pycache__/metrics.cpython-310.pyc
Binary file added modules/__pycache__/optimizers.cpython-310.pyc
Binary file added modules/__pycache__/tokenizers.cpython-310.pyc
Binary file added modules/__pycache__/utils.cpython-310.pyc
