Commit

initial commit
Guo Zhengrui committed Dec 18, 2024
1 parent 0410485 commit f19def7
Showing 81 changed files with 6,571 additions and 6 deletions.
18 changes: 17 additions & 1 deletion README.md
@@ -36,6 +36,14 @@ Overview of the proposed HistGen framework: (a) local-global hierarchical encode
- [Issues](#issues)
- [License and Usage](#license-and-usage)

## News
- **2024-12-18**: The tokenizer for HistGen is uploaded, unlocking better decoding capability. Check `modules.tokenizers` for details.
- **2024-12-18**: The ground truth reports are further cleaned and re-uploaded. Check HuggingFace Datasets for more details.
- **2024-12-18**: Baseline models are uploaded.
- **2024-11-12**: The HistGen WSI-report dataset is available on HuggingFace Datasets! (Also the annotation files!)
- **2024-08-10**: Code for feature extraction (CLAM) is uploaded.
- **2024-06-17**: Our paper is accepted by MICCAI2024! 🎉

## TO-DO
- [x] Release the source code for training and testing HistGen
- [x] Release the diagnostic report data
@@ -57,7 +65,8 @@ conda env create -f requirements.yml
zip -FF DINOv2_Features.zip --out WSI_Feature_DINOv2.zip
unzip WSI_Feature_DINOv2.zip
```
Also, the paired diagnostic reports can be found from the above link with name *annotations.json*.
~~Also, the paired diagnostic reports can be found from the above link with name *annotations.json*.~~
🌟 **Update**: The ground truth reports have been further cleaned and re-uploaded. You can find the cleaned reports in *annotations712_update.json*, which provides smoother and more precise report descriptions. The data are now split into train:val:test sets at a 7:1:2 ratio.
<!-- Our curated dataset could be downloaded from [here](https://hkustconnect-my.sharepoint.com/:f:/g/personal/zguobc_connect_ust_hk/EhmtBBT0n2lKtiCQt97eqcEBvO9WwNM3TL9x-7-kg_liuA). -->
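The 7:1:2 split can be checked directly once the annotation file is downloaded. A minimal sketch, assuming the file follows the common R2Gen-style schema (one list of cases per split keyed by split name — the miniature dict below is a stand-in, not the real data):

```python
import json

# Hypothetical miniature of annotations712_update.json; the real file's exact
# schema is an assumption (R2Gen-style: one list of cases per split).
annotations = {
    "train": [{"id": f"slide_{i:03d}", "report": "dummy report"} for i in range(7)],
    "val":   [{"id": f"slide_{i:03d}", "report": "dummy report"} for i in range(7, 8)],
    "test":  [{"id": f"slide_{i:03d}", "report": "dummy report"} for i in range(8, 10)],
}
# In practice you would load the downloaded file instead:
# annotations = json.load(open("annotations712_update.json"))

sizes = {split: len(cases) for split, cases in annotations.items()}
total = sum(sizes.values())
ratios = {split: round(n / total, 2) for split, n in sizes.items()}
print(sizes)   # {'train': 7, 'val': 1, 'test': 2}
print(ratios)  # {'train': 0.7, 'val': 0.1, 'test': 0.2}
```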

The structure of this folder is as follows:
@@ -150,6 +159,13 @@ sh train_wsi_report.sh
```
Before running the script, please set the paths and other hyperparameters in `train_wsi_report.sh`. Note that **--image_dir** should be the path to the **dinov2_vitl** directory, and **--ann_path** should be the path to the **annotations.json** file.
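For illustration, a minimal sketch of what the invocation inside `train_wsi_report.sh` might look like. The flags `--image_dir`, `--ann_path`, and `--model_name` appear on this page, and the `--batch_size`/`--num_workers` defaults come from `parse_agrs`; the placeholder paths and the assumption that the script wraps `main_train_AllinOne.py` are ours:

```shell
#!/bin/bash
# Hypothetical contents of train_wsi_report.sh -- paths are placeholders.
python main_train_AllinOne.py \
    --model_name histgen \
    --image_dir /path/to/dinov2_vitl \
    --ann_path /path/to/annotations.json \
    --batch_size 16 \
    --num_workers 2
```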

🌟 **Update:** We have included the baseline models from our paper for training: R2Gen, R2GenCMN, Show&Tell, Transformer, M2Transformer, and UpDownAttn. Note that these models were not originally designed to process WSIs, so the enormous number of patches in a WSI can incur prohibitive computational overhead. We therefore apply a simple image token selection mechanism before the patch tokens are processed by these models; "uniform sampling", "cross attention", and "kmeans clustering" are provided, and the number of selected tokens can be set in `train_wsi_report_baselines.sh`. To train one of these baseline models, run the following commands:
```
cd HistGen
conda activate histgen
sh train_wsi_report_baselines.sh
```
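The repo's exact token selection implementations are not shown on this page; below is a hedged, NumPy-only sketch of two of the named strategies, "uniform sampling" and "kmeans clustering", operating on an `(N, D)` matrix of patch tokens (the function names and the naive Lloyd's loop are our assumptions, not the repo's code):

```python
import numpy as np

def uniform_sample(tokens: np.ndarray, k: int) -> np.ndarray:
    """Keep k tokens at evenly spaced positions: (N, D) -> (k, D)."""
    idx = np.linspace(0, len(tokens) - 1, num=k).astype(int)
    return tokens[idx]

def kmeans_centroids(tokens: np.ndarray, k: int, iters: int = 10, seed: int = 0) -> np.ndarray:
    """Summarize N tokens by k cluster centroids via a naive Lloyd's loop."""
    rng = np.random.default_rng(seed)
    centroids = tokens[rng.choice(len(tokens), size=k, replace=False)]
    for _ in range(iters):
        # Squared distances via ||x - c||^2 = ||x||^2 - 2 x.c + ||c||^2
        d = ((tokens ** 2).sum(1)[:, None]
             - 2.0 * tokens @ centroids.T
             + (centroids ** 2).sum(1)[None, :])
        assign = d.argmin(1)
        for j in range(k):
            members = tokens[assign == j]
            if len(members):
                centroids[j] = members.mean(0)
    return centroids

# A WSI yields thousands of patch tokens; both strategies shrink them to k.
patches = np.random.default_rng(1).normal(size=(2048, 64)).astype(np.float32)
reduced = uniform_sample(patches, 256)
cents = kmeans_centroids(patches, 16)
print(reduced.shape, cents.shape)  # (256, 64) (16, 64)
```

Uniform sampling is the cheapest of the three but ignores content; clustering keeps one representative per region of feature space at the cost of the assignment loop.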

### Inference
To generate reports for the WSIs in the test set, run the following commands:
```
33 changes: 29 additions & 4 deletions main_train_AllinOne.py
@@ -1,13 +1,21 @@
import torch
import argparse
import numpy as np
from modules.tokenizers import Tokenizer
from modules.tokenizers import Tokenizer, ModernTokenizer
from modules.dataloaders import R2DataLoader
from modules.metrics import compute_scores
from modules.optimizers import build_optimizer, build_lr_scheduler
from modules.trainer_AllinOne import Trainer
from modules.loss import compute_loss
from models.histgen_model import HistGenModel
#* Baselines
from models.r2gen import R2GenModel
from models.r2gen_cmn import BaseCMNModel
from models.M2Transformer import M2Transformer
from models.PlainTransformer import PlainTransformer
from models.ShowTellModel import ShowTell
from models.UpDownAttn import UpDownAttn
#*

def parse_agrs():
parser = argparse.ArgumentParser()
@@ -23,7 +31,7 @@ def parse_agrs():
parser.add_argument('--num_workers', type=int, default=2, help='the number of workers for dataloader.')
parser.add_argument('--batch_size', type=int, default=16, help='the number of samples for a batch')

parser.add_argument('--model_name', type=str, default='histgen', choices=['histgen'], help='model used for experiment')
parser.add_argument('--model_name', type=str, default='histgen', choices=['histgen', 'r2gen', 'r2gen_cmn', 'm2transformer', 'transformer', 'showtell', 'updown'], help='model used for experiment')

# Model settings (for visual extractor)
parser.add_argument('--visual_extractor', type=str, default='resnet101', help='the visual extractor to be used.')
@@ -101,14 +109,31 @@ def main():
np.random.seed(args.seed)


tokenizer = Tokenizer(args)
# tokenizer = Tokenizer(args)
tokenizer = ModernTokenizer(args)

train_dataloader = R2DataLoader(args, tokenizer, split='train', shuffle=True)
val_dataloader = R2DataLoader(args, tokenizer, split='val', shuffle=False)
test_dataloader = R2DataLoader(args, tokenizer, split='test', shuffle=False)


# build model architecture
model = HistGenModel(args, tokenizer)
if args.model_name == 'r2gen':
model = R2GenModel(args, tokenizer)
elif args.model_name == 'r2gen_cmn':
model = BaseCMNModel(args, tokenizer)
elif args.model_name == 'm2transformer':
model = M2Transformer(args, tokenizer)
elif args.model_name == 'transformer':
model = PlainTransformer(args, tokenizer)
elif args.model_name == 'showtell':
model = ShowTell(args, tokenizer)
elif args.model_name == 'updown':
model = UpDownAttn(args, tokenizer)
elif args.model_name == 'histgen':
model = HistGenModel(args, tokenizer)
else:
raise ValueError('Invalid model name')

# get function handles of loss and metrics
criterion = compute_loss
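As a design note, the if/elif dispatch added in `main()` could equivalently be written as a name-to-class registry, which keeps the valid names and the constructors in one place. A self-contained sketch (the two classes below are dummy stand-ins, not the repo's real models, which take the same `(args, tokenizer)` signature):

```python
# Dummy stand-ins so this sketch runs on its own; in the repo each model
# class is imported from models/ and takes (args, tokenizer).
class HistGenModel:
    def __init__(self, args, tokenizer):
        self.model_name = "histgen"

class R2GenModel:
    def __init__(self, args, tokenizer):
        self.model_name = "r2gen"

MODEL_REGISTRY = {
    "histgen": HistGenModel,
    "r2gen": R2GenModel,
    # ... remaining baselines register the same way
}

def build_model(name, args, tokenizer):
    if name not in MODEL_REGISTRY:
        raise ValueError(f"Invalid model name: {name!r}")
    return MODEL_REGISTRY[name](args, tokenizer)

model = build_model("r2gen", args=None, tokenizer=None)
print(model.model_name)  # r2gen
```

The registry keys can also feed `argparse`'s `choices=` directly, so the CLI and the dispatch never drift apart.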
2 changes: 2 additions & 0 deletions models/M2T_modules/__init__.py
@@ -0,0 +1,2 @@
from .transformer import Transformer
from .captioning_model import CaptioningModel
(Six binary files changed; contents not shown. Remaining files truncated in this view.)
