Skip to content

Commit

Permalink
cleaned code for pub
Browse files Browse the repository at this point in the history
  • Loading branch information
cheungatm committed May 21, 2022
1 parent 73d7395 commit 7ada413
Show file tree
Hide file tree
Showing 10 changed files with 86 additions and 52 deletions.
2 changes: 1 addition & 1 deletion ich-fl/config/config_fed_client.json
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"path": "ich_trainer.ICHTrainer",
"args": {
"lr": 0.0003,
"epochs": 2
"epochs": 3
}
}
},
Expand Down
4 changes: 2 additions & 2 deletions ich-fl/config/config_fed_server.json
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@
"id": "scatter_and_gather",
"name": "ScatterAndGather",
"args": {
"min_clients" : 2,
"num_rounds" : 2,
"min_clients" : 4,
"num_rounds" : 10,
"start_round": 0,
"wait_time_after_min_received": 10,
"aggregator_id": "aggregator",
Expand Down
19 changes: 16 additions & 3 deletions ich-fl/custom/fl_dataset_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@
import pydicom
import os

"""
This defines the Dataset class for training and validation. It includes pre-processing of
DICOM images by windowing and stacking the CT scans.
"""

class IntracranialDataset(Dataset):
def __init__(self, csv_file, path, train, test, img_size=(512,512,1)):
self.csv = csv_file
Expand All @@ -15,7 +20,7 @@ def __init__(self, csv_file, path, train, test, img_size=(512,512,1)):
self.img_size = img_size
self.all_image_names = self.csv[:]['Image']
self.all_labels = np.array(self.csv.drop(['Image', 'all_diagnoses'], axis=1))
self.ratio_of_data_to_train = 0.25
self.ratio_of_data_to_train = 0.85
self.train_ratio = int(self.ratio_of_data_to_train * len(self.csv))
self.valid_ratio = len(self.csv) - self.train_ratio

Expand Down Expand Up @@ -51,7 +56,7 @@ def __init__(self, csv_file, path, train, test, img_size=(512,512,1)):
transforms.ToTensor()
])

## Calculate the proportion of different labels to weight the loss
## Calculate the proportion of different labels to weight the loss for imbalanced dataset
total_n_samples = np.array(self.labels).shape[0]
print(f"\ntotal samples: {total_n_samples}")
n_per_subtype = np.sum(self.labels, axis = 0)
Expand Down Expand Up @@ -88,8 +93,13 @@ def correct_dcm(dcm):
dcm.RescaleIntercept = -1000
return(dcm)


def window_image(dcm, window_center, window_width):

"""
window_image() "windows" each dicom CT scan by varying the contrast, similar to the workflow of a radiologist.
Thanks to https://github.com/appian42/kaggle-rsna-intracranial-hemorrhage/blob/master/src/utils/misc.py and
https://www.kaggle.com/code/dcstang/see-like-a-radiologist-with-systematic-windowing for inspiration.
"""
if (dcm.BitsStored == 12) and (dcm.PixelRepresentation == 0) and (int(dcm.RescaleIntercept) > -100):
correct_dcm(dcm)

Expand All @@ -101,6 +111,9 @@ def window_image(dcm, window_center, window_width):
return img

def bsb_window(dcm):
"""
This stacks the windowed scans into an RGB image.
"""
brain_img = window_image(dcm, 40, 80)
subdural_img = window_image(dcm, 80, 200)
soft_img = window_image(dcm, 40, 380)
Expand Down
59 changes: 25 additions & 34 deletions ich-fl/custom/ich_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,30 +12,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This is the training script for federated learning. It uses the nvflare Executor and Shareable classes
to distribute the model and initiate training at each site.
"""

import enum
import os.path
from random import shuffle
from ssl import AlertDescription
import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data.dataloader import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor, Normalize, Compose
from torchvision import models as tvmodels
# From train.py --probably can trim this down
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
from fl_dataset_class import IntracranialDataset
from torch.utils.data import DataLoader
from torch.optim import lr_scheduler
from datetime import date
from sklearn import metrics
from tqdm import tqdm

#
from nvflare.apis.dxo import from_shareable, DXO, DataKind, MetaKey
from nvflare.apis.executor import Executor
Expand All @@ -52,12 +49,12 @@

class ICHTrainer(Executor):

def __init__(self, lr=0.0003, epochs=2, train_task_name=AppConstants.TASK_TRAIN,
def __init__(self, lr=0.0003, epochs=3, train_task_name=AppConstants.TASK_TRAIN,
submit_model_task_name=AppConstants.TASK_SUBMIT_MODEL, exclude_vars=None):
"""
Args:
lr (float, optional): Learning rate. Defaults to 0.01
epochs (int, optional): Epochs. Defaults to 5
lr (float, optional): Learning rate.
epochs (int, optional): Epochs for local training.
train_task_name (str, optional): Task name for train task. Defaults to "train".
submit_model_task_name (str, optional): Task name for submit model. Defaults to "submit_model".
exclude_vars (list): List of variables to exclude during model loading.
Expand All @@ -71,17 +68,26 @@ def __init__(self, lr=0.0003, epochs=2, train_task_name=AppConstants.TASK_TRAIN,
self._submit_model_task_name = submit_model_task_name
self._exclude_vars = exclude_vars

# Training setup
# Define the model
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self.model = tvmodels.resnext101_32x8d(pretrained=True, progress=True)
self.model.fc = nn.Linear(2048,6)

# Set parameters for torchvision model
requires_grad = False
if requires_grad == False:
for param in self.model.parameters():
param.requires_grad = False
# to train the hidden layers
elif self.requires_grad == True:
for param in self.model.parameters():
param.requires_grad = True

self.model.fc = nn.Linear(2048,6) # Final fully connected layer with 6 classes for bleed subtypes
self.model.to(self.device)
self.optimizer = Adam(self.model.parameters(), lr=lr)
batch_size = 16
#self.scheduler = lr_scheduler.StepLR(self.optimizer, step_size = 3, gamma=0.1)


# Point to the relevent test label data and DICOM files
# Point to the relevent test label data and DICOM files. Relative path defaults to client folder.
train_csv = pd.read_csv('./input/labels.csv')
train_csv = train_csv.sample(frac=1, random_state=23)

Expand Down Expand Up @@ -115,7 +121,7 @@ def local_train(self, fl_ctx, weights, abort_signal):
# Basic training
self.model.train()

# Initialize variables to output
# Initialize variables to output. This is optional but allows clients to monitor the progress of training after each epoch.
running_train_loss = []
running_val_loss = []
running_train_acc = []
Expand All @@ -128,6 +134,8 @@ def local_train(self, fl_ctx, weights, abort_signal):
running_val_prc = []
#
local_output_dir = self.create_output_dir(fl_ctx)

#
for epoch in range(self._epochs):
print(f'Epoch {epoch+1} of {self._epochs}')
# running_loss = 0.0
Expand Down Expand Up @@ -157,12 +165,6 @@ def local_train(self, fl_ctx, weights, abort_signal):
train_epoch_labels += labels.tolist()
train_epoch_preds += sigmoid_preds.tolist()

# running_loss += (cost.cpu().detach().numpy()/images.size()[0])
# if i % 3000 == 0:
# self.log_info(fl_ctx, f"Epoch: {epoch}/{self._epochs}, Iteration: {i}, "
# f"Loss: {running_loss/3000}")
# running_loss = 0.0

# Divide total loss added by num_batches
train_epoch_loss = train_running_batch_loss / counter

Expand Down Expand Up @@ -191,7 +193,6 @@ def local_train(self, fl_ctx, weights, abort_signal):
val_fpr, val_tpr, _ = metrics.roc_curve(flat_val_label, flat_val_pred)
train_roc_auc = round(metrics.auc(train_fpr, train_tpr), 6)
val_roc_auc = round(metrics.auc(val_fpr, val_tpr), 6)

# Caclulate PRC AUC
train_precision, train_recall, train_thresholds = metrics.precision_recall_curve(flat_train_label, flat_train_pred)
train_prc_auc = round(metrics.auc(train_recall, train_precision), 6)
Expand Down Expand Up @@ -282,19 +283,9 @@ def plot_metrics(self, fl_ctx, output_dir, train_data, val_data, title_string):
axs.set_title(f"{title_string}")
axs.legend(loc='center left')
print(f"plotting {title_string}...")
#run_dir = fl_ctx.get_engine().get_workspace().get_run_dir(fl_ctx.get_prop(ReservedKey.RUN_NUM))
#local_output_dir = os.path.join(run_dir, PTConstants.OutputMetricsDir)
fig.savefig(f'{output_dir}/{title_string}.png')
return

#def calculate_metrics():
# F1
# Acc
# ROC AUC
# PRC AUC
# return


def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:
try:

Expand Down
17 changes: 12 additions & 5 deletions ich-fl/custom/ich_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""
ich_validator.py uses the Executor and Shareable classes of nvflare to
run cross-site validation, which validates all local models and the global FL model
on data from each site. The results of cross-site validation will be saved as a .json
file in the FL server for comparison.
"""

import torch
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, ToTensor, Normalize
import os

import torch.nn as nn
import torch.optim as optim
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from fl_dataset_class import IntracranialDataset
from torch.utils.data import DataLoader
from datetime import date
from sklearn import metrics
from tqdm import tqdm
#
from nvflare.apis.dxo import from_shareable, DataKind, DXO
from nvflare.apis.executor import Executor
from nvflare.apis.fl_constant import ReturnCode
Expand Down Expand Up @@ -77,7 +80,7 @@ def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort
model_owner = shareable.get_header(AppConstants.MODEL_OWNER, "?")
weights = {k: torch.as_tensor(v, device=self.device) for k, v in dxo.data.items()}

# Get validation accuracy
# Get validation accuracy for each hemorrhage subtype
validation_results = self.do_validation(weights, abort_signal)
any_results = validation_results['any']
epidural_results = validation_results['epidural']
Expand All @@ -103,6 +106,10 @@ def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort
return make_reply(ReturnCode.TASK_UNKNOWN)

def do_validation(self, weights, abort_signal):
"""
do_validation() is a customizable function that will return a dict() of any performance
metric.
"""
self.model.load_state_dict(weights)
self.model.eval()

Expand Down
3 changes: 3 additions & 0 deletions ich-fl/custom/pt_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Defines the naming schema of variables and paths used throughout the codebase.
"""

class PTConstants:
PTServerName = "server"
Expand Down
4 changes: 3 additions & 1 deletion ich-fl/custom/pt_model_locator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,11 @@ class PTModelLocator(ModelLocator):
def __init__(self, exclude_vars=None, model=None):
super(PTModelLocator, self).__init__()


# Define the model used for federated learning. Most users should only need to customize the following
# two lines of code.
self.model = tvmodels.resnext101_32x8d(pretrained=True, progress=True)
self.model.fc = nn.Linear(2048,6)
#

self.exclude_vars = exclude_vars

Expand Down
6 changes: 3 additions & 3 deletions src/dataset_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@ def __init__(self, path, train, test, img_size=(512,512,1)):
self.master_labels = []
self.master_image_names = []

# Loop over each site within all_sites directory
# Loop over each site within directory at path defined in train or inference
site_dirs = os.listdir(f'{self.path}')
for site in site_dirs:

# Define path for each site's data
site_path = os.path.join(self.path, site)

if os.path.isfile(site_path):
if os.path.isfile(site_path): #skip anything that is a file and not directory
continue

print(f'\nAccessing data from: {site_path}')
Expand Down Expand Up @@ -77,7 +77,7 @@ def __init__(self, path, train, test, img_size=(512,512,1)):
elif self.test == True and self.train == False:
print(f"Number of test images: {len(site_image_names)}")
image_paths = list(site_image_paths)
labels = list(self.all_labels)
labels = list(site_labels)
image_names = list(site_image_names)
self.transform = transforms.Compose([
transforms.ToTensor()
Expand Down
7 changes: 6 additions & 1 deletion src/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,12 @@
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# initialize model
model_inf = models.model_fxn(pretrained = False, requires_grad = False).to(device)


#
model_inf = torch.nn.DataParallel(model_inf).to(device)


# load model checkpoint
checkpoint = torch.load('../output/model.pt')
# load model weights state_dict
Expand All @@ -41,7 +45,8 @@
test_loader = DataLoader(
test_data,
batch_size = 512,
shuffle=False
shuffle=False,
num_workers=8
)

samples = []
Expand Down
17 changes: 15 additions & 2 deletions src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,17 @@
print(f'\nBCE Loss weights: {loss_weights}')
criterion = nn.BCEWithLogitsLoss(pos_weight=loss_weights)

# Run from checkpoint
run_checkpoint = False
if run_checkpoint:
checkpoint = torch.load('../output/model.pt')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
checkpoint_epoch = checkpoint['epoch']
loss = checkpoint['loss']
else:
checkpoint_epoch = 0


#train data loader
train_loader = DataLoader(
Expand Down Expand Up @@ -114,14 +125,16 @@ def plot_roc_prc(training_roc, validation_roc, training_prc, validation_prc):
fig.savefig(f'../output/{now.strftime("%b-%d-%Y-%H-%M")}_{epochs}_b{batch_size}_ROCPRC.png')

## Train
for epoch in range(epochs):
for epoch in range(checkpoint_epoch, epochs, 1):
print(f"\n{datetime.now()}")
print(f"\nEpoch {epoch+1} of {epochs}")

# Actual training and validation
train_results = train(model, train_loader, optimizer, criterion, train_data, device)
valid_results = validate(model, valid_loader, criterion, valid_data, device)
if epoch == 0:

# initialize
if epoch == 0 or epoch == checkpoint_epoch:
best_val_loss = valid_results['val_loss']

#Save values for output
Expand Down

0 comments on commit 7ada413

Please sign in to comment.