-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Guo Zhengrui
committed
Mar 9, 2024
0 parents
commit 9e390ed
Showing
71 changed files
with
3,890 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
import torch | ||
import argparse | ||
import numpy as np | ||
from modules.tokenizers import Tokenizer | ||
from modules.dataloaders import R2DataLoader | ||
from modules.metrics import compute_scores | ||
from modules.tester_AllinOne import Tester | ||
from modules.loss import compute_loss | ||
from models.histgen_model import HistGenModel | ||
|
||
def parse_agrs(argv=None):
    """Build and parse the command-line configuration for testing/inference.

    Args:
        argv: Optional list of argument strings. When None (the default) the
            arguments are read from ``sys.argv`` exactly as before, so
            existing callers are unaffected; passing an explicit list makes
            the function usable programmatically (and testable).

    Returns:
        argparse.Namespace holding every option declared below.
    """

    def str2bool(value):
        # ``type=bool`` is a classic argparse pitfall: bool('False') is True,
        # so any non-empty string used to enable the flag. Parse the usual
        # textual spellings explicitly instead.
        if isinstance(value, bool):
            return value
        lowered = value.lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError('boolean value expected, got %r' % value)

    parser = argparse.ArgumentParser()

    # Data input settings
    parser.add_argument('--image_dir', type=str, default='data/iu_xray/images/', help='the path to the directory containing the data.')
    parser.add_argument('--ann_path', type=str, default='data/iu_xray/annotation.json', help='the path to the annotation file.')

    # Data loader settings
    parser.add_argument('--dataset_name', type=str, default='iu_xray', choices=['iu_xray', 'mimic_cxr', 'wsi_report'], help='the dataset to be used.')
    parser.add_argument('--max_seq_length', type=int, default=60, help='the maximum sequence length of the reports.')
    parser.add_argument('--threshold', type=int, default=3, help='the cut off frequency for the words.')
    parser.add_argument('--num_workers', type=int, default=2, help='the number of workers for dataloader.')
    parser.add_argument('--batch_size', type=int, default=16, help='the number of samples for a batch')

    parser.add_argument('--model_name', type=str, default='histgen', choices=['histgen'], help='model used for experiment')

    # Model settings (for visual extractor)
    parser.add_argument('--visual_extractor', type=str, default='resnet101', help='the visual extractor to be used.')
    parser.add_argument('--visual_extractor_pretrained', type=str2bool, default=True, help='whether to load the pretrained visual extractor')
    # Model settings (for Transformer)
    parser.add_argument('--d_model', type=int, default=512, help='the dimension of Transformer.')
    parser.add_argument('--d_ff', type=int, default=512, help='the dimension of FFN.')
    parser.add_argument('--d_vf', type=int, default=2048, help='the dimension of the patch features.')
    parser.add_argument('--num_heads', type=int, default=8, help='the number of heads in Transformer.')
    parser.add_argument('--num_layers', type=int, default=3, help='the number of layers of Transformer.')
    parser.add_argument('--dropout', type=float, default=0.1, help='the dropout rate of Transformer.')
    parser.add_argument('--logit_layers', type=int, default=1, help='the number of the logit layer.')
    parser.add_argument('--bos_idx', type=int, default=0, help='the index of <bos>.')
    parser.add_argument('--eos_idx', type=int, default=0, help='the index of <eos>.')
    parser.add_argument('--pad_idx', type=int, default=0, help='the index of <pad>.')
    parser.add_argument('--use_bn', type=int, default=0, help='whether to use batch normalization.')
    parser.add_argument('--drop_prob_lm', type=float, default=0.5, help='the dropout rate of the output layer.')
    # for Cross-modal context module
    parser.add_argument('--topk', type=int, default=32, help='the number of k.')
    parser.add_argument('--cmm_size', type=int, default=2048, help='the number of cmm size.')
    parser.add_argument('--cmm_dim', type=int, default=512, help='the dimension of cmm dimension.')
    # for Local-global hierarchical visual encoder
    parser.add_argument("--region_size", type=int, default=256, help="the size of the region for region transformer.")

    # Sample related
    parser.add_argument('--sample_method', type=str, default='beam_search', help='the sample methods to sample a report.')
    parser.add_argument('--beam_size', type=int, default=3, help='the beam size when beam searching.')
    parser.add_argument('--temperature', type=float, default=1.0, help='the temperature when sampling.')
    parser.add_argument('--sample_n', type=int, default=1, help='the sample number per image.')
    parser.add_argument('--group_size', type=int, default=1, help='the group size.')
    parser.add_argument('--output_logsoftmax', type=int, default=1, help='whether to output the probabilities.')
    parser.add_argument('--decoding_constraint', type=int, default=0, help='whether decoding constraint.')
    parser.add_argument('--block_trigrams', type=int, default=1, help='whether to use block trigrams.')

    # Trainer settings
    parser.add_argument('--n_gpu', type=int, default=1, help='the number of gpus to be used.')
    parser.add_argument('--epochs', type=int, default=100, help='the number of training epochs.')
    parser.add_argument('--save_dir', type=str, default='results/iu_xray', help='the path to save the models.')
    parser.add_argument('--record_dir', type=str, default='records/', help='the path to save the results of experiments')
    parser.add_argument('--save_period', type=int, default=1, help='the saving period.')
    parser.add_argument('--monitor_mode', type=str, default='max', choices=['min', 'max'], help='whether to max or min the metric.')
    parser.add_argument('--monitor_metric', type=str, default='BLEU_4', help='the metric to be monitored.')
    parser.add_argument('--early_stop', type=int, default=50, help='the patience of training.')
    parser.add_argument('--log_period', type=int, default=1000, help='the logging interval (in batches).')

    # Optimization
    parser.add_argument('--optim', type=str, default='Adam', help='the type of the optimizer.')
    parser.add_argument('--lr_ve', type=float, default=5e-5, help='the learning rate for the visual extractor.')
    parser.add_argument('--lr_ed', type=float, default=1e-4, help='the learning rate for the remaining parameters.')
    parser.add_argument('--weight_decay', type=float, default=5e-5, help='the weight decay.')
    parser.add_argument('--amsgrad', type=str2bool, default=True, help='whether to use the AMSGrad variant of Adam.')

    # Learning Rate Scheduler
    parser.add_argument('--lr_scheduler', type=str, default='StepLR', help='the type of the learning rate scheduler.')
    parser.add_argument('--step_size', type=int, default=50, help='the step size of the learning rate scheduler.')
    parser.add_argument('--gamma', type=float, default=0.1, help='the gamma of the learning rate scheduler.')

    # Others
    parser.add_argument('--seed', type=int, default=9233, help='the random seed.')
    parser.add_argument('--resume', type=str, help='whether to resume the training from existing checkpoints.')
    parser.add_argument('--load', type=str, help='whether to load a pre-trained model.')

    # parse_args(None) falls back to sys.argv[1:], preserving old behavior.
    args = parser.parse_args(argv)
    return args
|
||
|
||
def main():
    """Script entry point: parse the config, seed RNGs, build the evaluation
    pipeline, and run testing."""
    cfg = parse_agrs()

    # Reproducibility: seed torch and numpy, pin cuDNN to deterministic
    # kernels and disable its autotuner.
    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Tokenizer, test-split loader, and the report-generation model.
    tok = Tokenizer(cfg)
    loader = R2DataLoader(cfg, tok, split='test', shuffle=False)
    net = HistGenModel(cfg, tok)

    # Function handles for the loss and evaluation metrics.
    loss_fn, score_fn = compute_loss, compute_scores

    # Run evaluation on the test split.
    Tester(net, loss_fn, score_fn, cfg, loader).test()


if __name__ == '__main__':
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
import torch | ||
import argparse | ||
import numpy as np | ||
from modules.tokenizers import Tokenizer | ||
from modules.dataloaders import R2DataLoader | ||
from modules.metrics import compute_scores | ||
from modules.optimizers import build_optimizer, build_lr_scheduler | ||
from modules.trainer_AllinOne import Trainer | ||
from modules.loss import compute_loss | ||
from models.histgen_model import HistGenModel | ||
|
||
def parse_agrs(argv=None):
    """Build and parse the command-line configuration for training.

    Args:
        argv: Optional list of argument strings. When None (the default) the
            arguments are read from ``sys.argv`` exactly as before, so
            existing callers are unaffected; passing an explicit list makes
            the function usable programmatically (and testable).

    Returns:
        argparse.Namespace holding every option declared below.
    """

    def str2bool(value):
        # ``type=bool`` is a classic argparse pitfall: bool('False') is True,
        # so any non-empty string used to enable the flag. Parse the usual
        # textual spellings explicitly instead.
        if isinstance(value, bool):
            return value
        lowered = value.lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError('boolean value expected, got %r' % value)

    parser = argparse.ArgumentParser()

    # Data input settings
    parser.add_argument('--image_dir', type=str, default='data/iu_xray/images/', help='the path to the directory containing the data.')
    parser.add_argument('--ann_path', type=str, default='data/iu_xray/annotation.json', help='the path to the annotation file.')

    # Data loader settings
    parser.add_argument('--dataset_name', type=str, default='wsi_report', choices=['iu_xray', 'mimic_cxr', 'wsi_report'], help='the dataset to be used.')
    parser.add_argument('--max_seq_length', type=int, default=60, help='the maximum sequence length of the reports.')
    parser.add_argument('--threshold', type=int, default=3, help='the cut off frequency for the words.')
    parser.add_argument('--num_workers', type=int, default=2, help='the number of workers for dataloader.')
    parser.add_argument('--batch_size', type=int, default=16, help='the number of samples for a batch')

    parser.add_argument('--model_name', type=str, default='histgen', choices=['histgen'], help='model used for experiment')

    # Model settings (for visual extractor)
    parser.add_argument('--visual_extractor', type=str, default='resnet101', help='the visual extractor to be used.')
    parser.add_argument('--visual_extractor_pretrained', type=str2bool, default=True, help='whether to load the pretrained visual extractor')
    # Model settings (for Transformer)
    parser.add_argument('--d_model', type=int, default=512, help='the dimension of Transformer.')
    parser.add_argument('--d_ff', type=int, default=512, help='the dimension of FFN.')
    parser.add_argument('--d_vf', type=int, default=2048, help='the dimension of the patch features.')
    parser.add_argument('--num_heads', type=int, default=8, help='the number of heads in Transformer.')
    parser.add_argument('--num_layers', type=int, default=3, help='the number of layers of Transformer.')
    parser.add_argument('--dropout', type=float, default=0.1, help='the dropout rate of Transformer.')
    parser.add_argument('--logit_layers', type=int, default=1, help='the number of the logit layer.')
    parser.add_argument('--bos_idx', type=int, default=0, help='the index of <bos>.')
    parser.add_argument('--eos_idx', type=int, default=0, help='the index of <eos>.')
    parser.add_argument('--pad_idx', type=int, default=0, help='the index of <pad>.')
    parser.add_argument('--use_bn', type=int, default=0, help='whether to use batch normalization.')
    parser.add_argument('--drop_prob_lm', type=float, default=0.5, help='the dropout rate of the output layer.')
    # for Cross-modal context module
    parser.add_argument('--topk', type=int, default=32, help='the number of k.')
    parser.add_argument('--cmm_size', type=int, default=2048, help='the number of cmm size.')
    parser.add_argument('--cmm_dim', type=int, default=512, help='the dimension of cmm dimension.')
    # for Local-global hierarchical visual encoder
    parser.add_argument("--region_size", type=int, default=256, help="the size of the region for region transformer.")
    parser.add_argument("--prototype_num", type=int, default=512, help="the number of visual prototypes for cross-modal interaction")

    # Sample related
    parser.add_argument('--sample_method', type=str, default='beam_search', help='the sample methods to sample a report.')
    parser.add_argument('--beam_size', type=int, default=3, help='the beam size when beam searching.')
    parser.add_argument('--temperature', type=float, default=1.0, help='the temperature when sampling.')
    parser.add_argument('--sample_n', type=int, default=1, help='the sample number per image.')
    parser.add_argument('--group_size', type=int, default=1, help='the group size.')
    parser.add_argument('--output_logsoftmax', type=int, default=1, help='whether to output the probabilities.')
    parser.add_argument('--decoding_constraint', type=int, default=0, help='whether decoding constraint.')
    parser.add_argument('--block_trigrams', type=int, default=1, help='whether to use block trigrams.')

    # Trainer settings
    parser.add_argument('--n_gpu', type=int, default=1, help='the number of gpus to be used.')
    parser.add_argument('--epochs', type=int, default=100, help='the number of training epochs.')
    parser.add_argument('--save_dir', type=str, default='results/iu_xray', help='the path to save the models.')
    parser.add_argument('--record_dir', type=str, default='records/', help='the path to save the results of experiments')
    parser.add_argument('--save_period', type=int, default=1, help='the saving period.')
    parser.add_argument('--monitor_mode', type=str, default='max', choices=['min', 'max'], help='whether to max or min the metric.')
    parser.add_argument('--monitor_metric', type=str, default='BLEU_4', help='the metric to be monitored.')
    parser.add_argument('--early_stop', type=int, default=50, help='the patience of training.')
    parser.add_argument('--log_period', type=int, default=1000, help='the logging interval (in batches).')

    # Optimization
    parser.add_argument('--optim', type=str, default='Adam', help='the type of the optimizer.')
    parser.add_argument('--lr_ve', type=float, default=5e-5, help='the learning rate for the visual extractor.')
    parser.add_argument('--lr_ed', type=float, default=1e-4, help='the learning rate for the remaining parameters.')
    parser.add_argument('--weight_decay', type=float, default=5e-5, help='the weight decay.')
    parser.add_argument('--amsgrad', type=str2bool, default=True, help='whether to use the AMSGrad variant of Adam.')

    # Learning Rate Scheduler
    parser.add_argument('--lr_scheduler', type=str, default='StepLR', help='the type of the learning rate scheduler.')
    parser.add_argument('--step_size', type=int, default=50, help='the step size of the learning rate scheduler.')
    parser.add_argument('--gamma', type=float, default=0.1, help='the gamma of the learning rate scheduler.')

    # Others
    parser.add_argument('--seed', type=int, default=9233, help='the random seed.')
    parser.add_argument('--resume', type=str, help='whether to resume the training from existing checkpoints.')

    # parse_args(None) falls back to sys.argv[1:], preserving old behavior.
    args = parser.parse_args(argv)
    return args
|
||
|
||
def main():
    """Script entry point: parse the config, seed RNGs, build the training
    pipeline, and train."""
    cfg = parse_agrs()

    # Reproducibility: seed torch and numpy, pin cuDNN to deterministic
    # kernels and disable its autotuner.
    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    tok = Tokenizer(cfg)

    # One loader per split; only the training split is shuffled.
    loaders = {
        split: R2DataLoader(cfg, tok, split=split, shuffle=(split == 'train'))
        for split in ('train', 'val', 'test')
    }

    # Build the model architecture.
    net = HistGenModel(cfg, tok)

    # Function handles for the loss and evaluation metrics.
    loss_fn, score_fn = compute_loss, compute_scores

    # Optimizer and learning-rate scheduler.
    opt = build_optimizer(cfg, net)
    sched = build_lr_scheduler(cfg, opt)

    # Build the trainer and start training.
    trainer = Trainer(net, loss_fn, score_fn, opt, cfg, sched,
                      loaders['train'], loaders['val'], loaders['test'])
    trainer.train()


if __name__ == '__main__':
    main()
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import numpy as np | ||
import torch | ||
import torch.nn as nn | ||
from modules.visual_extractor import VisualExtractor | ||
from modules.histgen_module import BaseHistGen | ||
|
||
class HistGenModel(nn.Module):
    """Report-generation model for pathology whole-slide images.

    Projects pre-extracted patch features into the decoder's visual space and
    delegates sequence generation to the ``BaseHistGen`` encoder-decoder.
    """

    def __init__(self, args, tokenizer):
        super(HistGenModel, self).__init__()
        self.args = args
        self.tokenizer = tokenizer
        self.encoder_decoder = BaseHistGen(args, tokenizer)
        # Input feature width is inferred from the feature directory name:
        # CTransPath features are 768-d, otherwise 1024-d is assumed.
        in_dim = 768 if "ctranspath" in args.image_dir else 1024
        self.wsi_mapping = torch.nn.Linear(in_dim, self.args.d_vf)
        # Route nn.Module's forward to the pathology-specific implementation.
        self.forward = self.forward_pathology
        self.visual_extractor = VisualExtractor(args)

    def __str__(self):
        """Return the default module repr plus the trainable-parameter count."""
        model_parameters = filter(lambda p: p.requires_grad, self.parameters())
        params = sum(np.prod(p.size()) for p in model_parameters)
        return super().__str__() + '\nTrainable parameters: {}'.format(params)

    def forward_pathology(self, images, targets=None, mode='train', update_opts=None):
        """Run the encoder-decoder on a batch of WSI patch features.

        Args:
            images: patch-feature tensor; assumed shape (batch, n_patches,
                feat_dim) since dim=1 is mean-pooled below — TODO confirm.
            targets: ground-truth report token ids, required when
                ``mode='train'``.
            mode: 'train' for teacher-forced logits, 'sample' for decoding.
            update_opts: optional decoding-option overrides. NOTE: the
                original default was a mutable ``{}`` (shared across calls);
                ``None`` is used instead. The argument is currently unused.

        Returns:
            Decoder output for 'train', sampled token ids for 'sample'.

        Raises:
            ValueError: if ``mode`` is neither 'train' nor 'sample'.
        """
        att_feats = self.wsi_mapping(images)
        # Global slide descriptor: mean over the patch dimension.
        fc_feats = torch.mean(att_feats, dim=1)

        if mode == 'train':
            output = self.encoder_decoder(fc_feats, att_feats, targets, mode='forward')
            return output
        elif mode == 'sample':
            output, output_probs = self.encoder_decoder(fc_feats, att_feats, mode='sample')
            return output
        else:
            raise ValueError("mode must be 'train' or 'sample', got {!r}".format(mode))
Empty file.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Oops, something went wrong.