7 changes: 4 additions & 3 deletions pythontoolkit/DataAugmentation3D.py
@@ -5,11 +5,12 @@
Fairly basic set of tools for real-time data augmentation on volumetric
data, extended for 3D object augmentation.
"""

import keras.backend as K
import tensorflow as tf
from tensorflow import keras
import tensorflow.keras.backend as K
import numpy as np
import scipy.ndimage
from keras.utils.data_utils import Sequence
from tensorflow.keras.utils import Sequence
from scipy import linalg
from six.moves import range

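For reference, tensorflow.keras.utils.Sequence is the drop-in replacement for the old keras.utils.data_utils.Sequence base class swapped above. A minimal loader built on it looks like the sketch below; the class and array names are illustrative, not part of this module.

import numpy as np
from tensorflow.keras.utils import Sequence

class VolumeSequence(Sequence):
    # Minimal illustrative loader yielding (input, target) volume batches.
    def __init__(self, inputs, targets, batch_size=1):
        self.inputs, self.targets = inputs, targets
        self.batch_size = batch_size

    def __len__(self):
        # Number of batches per epoch.
        return int(np.ceil(len(self.inputs) / self.batch_size))

    def __getitem__(self, idx):
        batch = slice(idx * self.batch_size, (idx + 1) * self.batch_size)
        return self.inputs[batch], self.targets[batch]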
8 changes: 5 additions & 3 deletions pythontoolkit/losses.py
@@ -1,5 +1,7 @@
from keras import backend as K
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import backend as K

# Define root mean squared error loss function.
def rmse(y_true, y_pred):
return K.sqrt(K.mean(K.square(y_pred - y_true), axis=-1))
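As a usage sketch: the loss function is passed straight to compile, and must be registered via custom_objects when a trained model is later reloaded. The tiny model and the file name here are purely illustrative.

from tensorflow.keras import layers, models

model = models.Sequential([layers.Dense(1, input_shape=(4,))])
model.compile(optimizer='adam', loss=rmse)

# When reloading a model trained with this loss, register it by name
# ('model.h5' is a placeholder path):
reloaded = models.load_model('model.h5', custom_objects={'rmse': rmse})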
55 changes: 32 additions & 23 deletions pythontoolkit/networks.py
@@ -1,21 +1,28 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu May 24 10:57:19 2018

@author: claesnl
"""

import tensorflow as tf
from tensorflow import keras
import warnings
from keras.models import Model
from keras.optimizers import Adam
from keras.layers import Conv3D, Conv3DTranspose, Dropout, Input
from keras.layers import Activation, BatchNormalization, concatenate
from keras import regularizers
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Conv3D, Conv3DTranspose, Dropout, Input
from tensorflow.keras.layers import Activation, BatchNormalization, concatenate
from tensorflow.keras import regularizers

warnings.filterwarnings('ignore')

# Define u-net structure.
def unet(X, f, dims_out):
'''
Inputs:
X input tensor
f number of base filters (n_base_filters)
dims_out number of output channels
'''
# Define convolution block:
def conv_block(layer,fsize,dropout,downsample=True):
for i in range(1,3):
layer = Conv3D(fsize, kernel_size=3, kernel_regularizer=regularizers.l2(1e-1),
@@ -30,6 +37,7 @@ def conv_block(layer,fsize,dropout,downsample=True):
downsample = Activation('relu')(downsample)
return layer, downsample

# Define transposed convolution block:
def convt_block(layer, concat, fsize):
layer = Conv3DTranspose(fsize, kernel_size=3, kernel_regularizer=regularizers.l2(1e-1),
kernel_initializer='he_normal', padding='same', strides=2)(layer)
@@ -38,26 +46,27 @@ def convt_block(layer, concat, fsize):
layer = concatenate([layer, concat], axis=-1)
return layer

# Dropout values
dropout = [.1,.1,.2,.3,.2,.2,.1]

# ENCODING
block1, dblock1 = conv_block(X,f,.1)
block2, dblock2 = conv_block(dblock1,f*2**1,.1)
block3, dblock3 = conv_block(dblock2,f*2**2,.2)
block4, dblock4 = conv_block(dblock3,f*2**3,.2)
block5, _ = conv_block(dblock4,f*2**4,.3,downsample=False)
block1, dblock1 = conv_block(X,f,dropout[0])
block2, dblock2 = conv_block(dblock1,f*2**1,dropout[1])
block3, dblock3 = conv_block(dblock2,f*2**2,dropout[2])
block4, _ = conv_block(dblock3,f*2**3,dropout[3],downsample=False)

# DECODING
block7 = convt_block(block5,block4,f*2**3)
block8, _ = conv_block(block7,f*2**3,.3,downsample=False)

block9 = convt_block(block8,block3,f*2**2)
block10, _ = conv_block(block9,f*2**2,.2,downsample=False)
block5 = convt_block(block4,block3,f*2**2)
block6, _ = conv_block(block5,f*2**2,dropout[4],downsample=False)

block11 = convt_block(block10,block2,f*2**1)
block12, _ = conv_block(block11,f*2**1,.2,downsample=False)
block7 = convt_block(block6,block2,f*2**1)
block8, _ = conv_block(block7,f*2**1,dropout[5],downsample=False)

block13 = convt_block(block12,block1,f)
block14, _ = conv_block(block13,f,.1,downsample=False)
block9 = convt_block(block8,block1,f)
block10, _ = conv_block(block9,f,dropout[6],downsample=False)

output = Conv3D(dims_out,kernel_size=3, kernel_regularizer=regularizers.l2(1e-1),
kernel_initializer='he_normal', padding='same',strides=1, activation='relu')(block14)
kernel_initializer='he_normal', padding='same',strides=1, activation='relu')(block10)

return output
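Tying the pieces together, unet() is consumed as in the sketch below, mirroring build_network() in train.py further down; the shapes are that file's config defaults, and unet is assumed imported from this module.

from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model

inputs = Input(shape=(8, 256, 256, 2))    # input_patch_shape + (input_channels,)
outputs = unet(inputs, f=64, dims_out=1)  # n_base_filters=64, one output channel
model = Model(inputs=inputs, outputs=outputs)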

5 changes: 3 additions & 2 deletions pythontoolkit/predict.py
@@ -1,8 +1,10 @@
import tensorflow as tf
from tensorflow import keras
import warnings
warnings.filterwarnings('ignore')
import os
import pickle
from keras.models import load_model, model_from_json
from tensorflow.keras.models import load_model, model_from_json

class CNN():
def __init__(self,model,config=None,custom_objects={}):
@@ -17,7 +19,6 @@ def __init__(self,model,config=None,custom_objects={}):
else:
self.model = load_model(model,custom_objects=custom_objects)


def load_model_w_json(self,model):
modelh5name = os.path.join( os.path.dirname(model), os.path.splitext(os.path.basename(model))[0]+'.h5' )
json_file = open(model,'r')
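A hedged usage sketch for this class: the module paths follow train.py's "from CAAI import networks" import and are an assumption here, as is the model file name; rmse is the custom loss from losses.py above and must be registered when loading.

# Module paths assume the package imports as CAAI; adjust to your install.
from CAAI.losses import rmse
from CAAI.predict import CNN

# Register the custom loss so load_model can deserialize the trained model.
cnn = CNN('model.h5', custom_objects={'rmse': rmse})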
101 changes: 65 additions & 36 deletions pythontoolkit/train.py
@@ -1,45 +1,40 @@
# Import python libraries:
import tensorflow as tf
from tensorflow import keras
import warnings
warnings.filterwarnings('ignore')
import os
import pickle
from glob import glob
from CAAI import networks
import json
from keras.callbacks import ModelCheckpoint, TensorBoard
from keras.layers import Input
from keras.models import Model, load_model, model_from_json
from keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard
from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model, load_model, model_from_json
from tensorflow.keras.optimizers import Adam



"""

TODO:
- Check that the content of an existing CONFIG matches the new run when continuing
- Delete checkpoints when overwriting an already trained model.
- Generate data pickle file


"""

# Define Convolutional Neural Network class.
class CNN():

# Define default configurations, which will be used if no other configurations are defined.
def __init__(self,**kwargs):
self.config = dict()
self.config["model_name"] = 'PROJECT_NAME_WITH_VERSION_NUMBER'
self.config["overwrite"] = False
self.config["input_patch_shape"] = (16,192,240)
self.config["input_patch_shape"] = (8,256,256)
self.config["input_channels"] = 2
self.config["output_channels"] = 1
self.config["batch_size"] = 2
self.config["epochs"] = 100
self.config["batch_size"] = 1
self.config["epochs"] = 1000
self.config["checkpoint_save_rate"] = 10
self.config["initial_epoch"] = 0
self.config["learning_rate"] = 1e-4
self.config["data_folder"] = '' # Path to folder containing data
self.config["data_pickle"] = '' # Path to pickle containing train/validation splits
self.config["data_pickle_kfold"] = None # Set to fold if k-fold training is applied (key will e.g. be train_0 and valid_0)
self.config["pretrained_model"] = None # If transfer learning from other model (not used if resuming training, but keep for model_name's sake)
self.config["augmentation"] = True
self.config["augmentation"] = False
self.config["augmentation_params"] = {
#'rotation_range': [5,5,5],
'shift_range': [0.05,0.05,0.05],
@@ -57,7 +52,7 @@ def __init__(self,**kwargs):

# Config specific for network architecture
self.config["network_architecture"] = 'unet'
self.config['n_base_filters'] = 32
self.config['n_base_filters'] = 64
self.custom_network_architecture = None

# Metrics and loss functions
@@ -80,24 +75,22 @@ def __init__(self,**kwargs):

# Check if model has been trained (and can be overwritten), or if we should resume from checkpoint
self.check_model_existance()

# Setup callbacks
self.callbacks_list = self.setup_callbacks()


def setup_callbacks(self):

# Checkpoints
os.makedirs('checkpoint/{}'.format(self.config['model_name']), exist_ok=True)
checkpoint_file=os.path.join('checkpoint',self.config["model_name"],'e{epoch:02d}_{val_loss:.2f}.h5')
checkpoint = ModelCheckpoint(checkpoint_file, monitor='val_loss', verbose=1, save_best_only=False, mode='min',period=self.config["checkpoint_save_rate"])
checkpoint_file=os.path.join('checkpoint',self.config["model_name"],'e{epoch:02d}.h5')
checkpoint = ModelCheckpoint(checkpoint_file, monitor='val_loss', verbose=1, save_best_only=False, mode='min',save_freq=int(self.config['checkpoint_save_rate']*self.data_loader.n_batches))

# Tensorboard
os.makedirs('logs', exist_ok=True)
TB_file=os.path.join('logs',self.config["model_name"])
TB = TensorBoard(log_dir = TB_file)

return [checkpoint, TB]
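A note on the checkpoint change above: tf.keras replaced ModelCheckpoint's period argument (epochs between saves) with save_freq, which counts batches, hence the multiplication by the number of batches per epoch. Equivalently, as a sketch with checkpoint_save_rate=10:

# period=10 (save every 10 epochs) expressed in batch units; n_batches is
# the data loader's batches-per-epoch count assumed from the code above.
save_freq = 10 * n_batches

Dropping {val_loss:.2f} from the checkpoint filename also matters here: validation metrics only exist at epoch boundaries, so a batch-based save_freq cannot rely on them when formatting the file name.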


def compile_network(self):

@@ -130,6 +123,7 @@ def compile_network(self):
self.model.compile(loss = loss, loss_weights = loss_weights, optimizer = optimizer, metrics=self.metrics)

self.is_compiled = True


def load_model_w_json(self,model):
modelh5name = os.path.join( os.path.dirname(model), os.path.splitext(os.path.basename(model))[0]+'.h5' )
@@ -140,6 +134,7 @@ def load_model_w_json(self,model):
model.load_weights(modelh5name)
return model


def build_network(self,inputs=None):
if not inputs:
inputs = Input(shape=self.config['input_patch_shape']+(self.config['input_channels'],))
@@ -156,6 +151,7 @@ def build_network(self,inputs=None):

return Model(inputs=inputs,outputs=outputs)


def generate_model_name_from_params(self):
# Build full model name
model_name = self.config['model_name']
@@ -169,10 +165,12 @@ def generate_model_name_from_params(self):

return model_name


def get_initial_epoch_from_file(self,f):
last_epoch = f.split('/')[-1].split('_')[0]
assert last_epoch.startswith('e') # check that it is indeed the epoch part of the name that we extract
return int(last_epoch[1:]) # extract only integer part of eXXX
return int(last_epoch[1:-3]) # extract only integer part of eXXX
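The new slice strips both the leading 'e' and the trailing '.h5', since val_loss is no longer embedded in the checkpoint name (e.g. 'e20.h5' -> 20). A splitext-based variant, sketched here as a standalone function, avoids the hard-coded -3 and still handles the old 'e20_0.43.h5' pattern:

import os

def get_initial_epoch_from_file(f):
    # 'checkpoint/model/e20.h5' -> 'e20' -> 20; also parses old 'e20_0.43.h5'.
    stem = os.path.splitext(os.path.basename(f))[0]
    assert stem.startswith('e')
    return int(stem.split('_')[0][1:])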


def check_model_existance(self):
# Check if config file already exists
@@ -198,22 +196,53 @@ def check_model_existance(self):
# else -> model exists but we specified to overwrite, so we do so, without loading from the checkpoint folder.
# Note: the checkpoints should probably be cleared before starting.


def print_config(self):
print(json.dumps(self.config, indent = 4))


def set(self,key,value):
self.config[key] = value

def plot_model(self):
# Compile network if it has not been done:
if not self.is_compiled:
self.compile_network()

tf.keras.utils.plot_model(self.model, show_shapes=True,
to_file='model_fig.png')
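One caveat for plot_model: it needs the pydot package and a system Graphviz install at runtime; model.summary() is the dependency-free fallback. Inside plot_model(), a guarded variant could look like this sketch:

# plot_model requires `pip install pydot` plus Graphviz on the system path;
# fall back to the text summary when they are missing.
try:
    tf.keras.utils.plot_model(self.model, show_shapes=True, to_file='model_fig.png')
except ImportError:
    self.model.summary()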

def train(self):

# Setup callbacks
self.callbacks_list = self.setup_callbacks()

# Compile network if it has not been done:
if not self.is_compiled:
self.compile_network()

print(self.model.summary())

# Check if data generators have been attached
if hasattr(self,'data_loader'):
self.training_generator = self.data_loader.generate( self.config['train_pts'] )
self.validation_generator = self.data_loader.generate( self.config['valid_pts'] )
#self.training_generator = self.data_loader.generate( self.config['train_pts'] )
#self.validation_generator = self.data_loader.generate( self.config['valid_pts'] )

# Updated to TFv2 generator
generator_shape_input = self.config["input_patch_shape"]+tuple([self.config["input_channels"]])
generator_shape_output = self.config["input_patch_shape"]+tuple([self.config["output_channels"]])
self.training_generator = tf.data.Dataset.from_generator(
lambda: self.data_loader.generate( self.config['train_pts'] ),
output_types=(tf.float32,tf.float32),
output_shapes=(tf.TensorShape(generator_shape_input),tf.TensorShape(generator_shape_output)))
self.validation_generator = tf.data.Dataset.from_generator(
lambda: self.data_loader.generate( self.config['valid_pts'] ),
output_types=(tf.float32,tf.float32),
output_shapes=(tf.TensorShape(generator_shape_input),tf.TensorShape(generator_shape_output)))

self.training_generator = self.training_generator.batch(self.config["batch_size"])
self.validation_generator = self.validation_generator.batch(self.config["batch_size"])

else:
print("No data generator was attached.")
exit(-1)
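The TFv2 path above wraps the existing Python generator in a tf.data.Dataset. Note that output_shapes is per example, because batching is applied afterwards with .batch(). A self-contained sketch of the pattern, with the config defaults as shapes and gen as a stand-in for the attached data loader's generator:

import numpy as np
import tensorflow as tf

in_shape = (8, 256, 256, 2)    # input_patch_shape + (input_channels,)
out_shape = (8, 256, 256, 1)   # input_patch_shape + (output_channels,)

def gen():
    # Stand-in for data_loader.generate(): yields (input, target) pairs forever.
    while True:
        yield np.zeros(in_shape, np.float32), np.zeros(out_shape, np.float32)

dataset = tf.data.Dataset.from_generator(
    gen,
    output_types=(tf.float32, tf.float32),
    output_shapes=(tf.TensorShape(in_shape), tf.TensorShape(out_shape)))
dataset = dataset.batch(1)  # batch_size from the config defaults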
@@ -223,14 +252,14 @@ def train(self):
with open('configs/{}.pkl'.format(self.config["model_name"]), 'wb') as file_pi:
pickle.dump(self.config, file_pi)

history = self.model.fit_generator(generator = self.training_generator,
steps_per_epoch = self.data_loader.n_batches,
validation_data = self.validation_generator,
validation_steps = 1,
epochs = self.config['epochs'],
verbose = 1,
callbacks = self.callbacks_list,
initial_epoch = self.config['initial_epoch'] )
history = self.model.fit( self.training_generator,
steps_per_epoch = self.data_loader.n_batches,
validation_data = self.validation_generator,
validation_steps = 1,
epochs = self.config['epochs'],
verbose = 1,
callbacks = self.callbacks_list,
initial_epoch = self.config['initial_epoch'] )

# Save model
self.model.save('{}.h5'.format( self.config['model_name'] ))
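Since fit_generator is deprecated in TF2, Model.fit consumes the batched datasets directly with the same steps_per_epoch / validation_steps semantics; because the underlying generator loops indefinitely, steps_per_epoch is what bounds each epoch. The returned history can be persisted next to the config, as a sketch (the path is a placeholder, and history is the object returned by fit above):

import pickle

# Save the training curves (loss, val_loss, metrics) for later inspection.
with open('configs/history.pkl', 'wb') as file_pi:
    pickle.dump(history.history, file_pi)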
6 changes: 6 additions & 0 deletions reinstall.sh
@@ -0,0 +1,6 @@
# Reinstall the CNN package
# by running ./reinstall.sh from the command line
mkdir -p build
cd build
cmake ..
make install