Merge pull request #1 from emilyzfliu/1214update
merge to local main branch
Showing 8 changed files with 426 additions and 36 deletions.
@@ -0,0 +1,60 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import glob
import os
import uuid

import tensorflow as tf

from make_model import sample_net
from file_reader import make_TFRecord_from_nii


class Client:
    def __init__(self, clientid):
        self.client_id = clientid
        shape = self.load_input_output()
        self.model = sample_net(shape)
        self.save_prior()

    def load_input_output(self):
        # Build a TFRecord from the nifti files on disk and open it as a
        # dataset. re.search in file_reader expects regular expressions, not
        # glob patterns, so plain substrings are passed to match filenames.
        record_file = 'tfrecord_' + self.client_id + '.tfrec'
        shape = make_TFRecord_from_nii('data', '_imgs', '_labels', record_file)
        # Note: the records hold raw byte strings, so they must be parsed
        # (and batched) before training; see the parsing sketch below.
        self.dataset = tf.data.TFRecordDataset(record_file)
        return shape

    def save_prior(self):
        self.model.save_weights('prior-' + self.client_id + '.h5', save_format='h5')

    def train(self):
        loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        self.model.compile(optimizer='adam', loss=loss, metrics=['accuracy'])
        self.model.fit(self.dataset, epochs=1, verbose=2)

    def load_consolidated(self):
        # Load the most recently written consolidated weights from the server
        # directory (relative path, matching save_weights below).
        list_of_files = glob.glob('server/*')
        latest_file = max(list_of_files, key=os.path.getctime)
        self.model.load_weights(latest_file)
        return latest_file

    def save_weights(self):
        filename = 'server/consolidated-' + self.client_id + '-' + str(uuid.uuid4()) + '.h5'
        self.model.save_weights(filename, save_format='h5')


# example usage

a = Client('A')
try:
    a.load_consolidated()
except ValueError:  # no consolidated weights on the server yet
    pass
a.train()
a.save_weights()
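As written, self.dataset yields serialized Example protos rather than tensors, so model.fit needs a decoding step first. A minimal parsing sketch, assuming the 'image_raw'/'label_raw' byte features and int16 volumes that make_TFRecord_from_nii writes (parse_example and its shape argument are illustrative names, not part of the repository):

def parse_example(serialized, shape):
    # decode one serialized tf.train.Example back into (volume, label) tensors
    features = tf.io.parse_single_example(serialized, {
        'image_raw': tf.io.FixedLenFeature([], tf.string),
        'label_raw': tf.io.FixedLenFeature([], tf.string),
    })
    volume = tf.reshape(tf.io.decode_raw(features['image_raw'], tf.int16), shape)
    label = tf.reshape(tf.io.decode_raw(features['label_raw'], tf.int16), shape)
    return volume, label

# e.g.: dataset = dataset.map(lambda s: parse_example(s, shape)).batch(1)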
@@ -0,0 +1,35 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
@author: Aakanksha Rana, Emi Z Liu
"""
import numpy as np


def distributed_weight_consolidation(model_weights, model_priors):
    # model_weights is a list of client-model weight lists,
    # [model1, model2, model3, ...]; model_priors holds the matching priors.
    # Each weight list alternates mean and standard-deviation arrays, so
    # even indices are means and odd indices are standard deviations.
    num_layers = len(model_weights[0]) // 2
    num_datasets = len(model_weights)
    consolidated_model = model_weights[0]  # overwritten layer by layer below
    mean_idx = list(range(0, len(model_weights[0]), 2))
    std_idx = list(range(1, len(model_weights[0]), 2))
    ep = 1e-5  # small constant to avoid division by zero
    for i in range(num_layers):
        num_1 = 0
        num_2 = 0
        den_1 = 0
        den_2 = 0
        for m in range(num_datasets):
            model = model_weights[m]
            prior = model_priors[m]
            mu_s = model[mean_idx[i]]   # posterior mean
            mu_o = prior[mean_idx[i]]   # prior mean
            sig_s = model[std_idx[i]]   # posterior std
            sig_o = prior[std_idx[i]]   # prior std
            d1 = np.power(sig_s, 2) + ep
            d2 = np.power(sig_o, 2) + ep
            num_1 += mu_s / d1
            num_2 += mu_o / d2
            den_1 += 1.0 / d1
            if m != num_datasets - 1:  # prior sum runs over one fewer term
                den_2 += 1.0 / d2
        consolidated_model[mean_idx[i]] = (num_1 - num_2) / (den_1 - den_2)
        consolidated_model[std_idx[i]] = 1.0 / (den_1 - den_2)
    return consolidated_model
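Reading the loop back as a formula, each layer's consolidated mean is (sum_s mu_s/sigma_s^2 - sum_o mu_o/sigma_o^2) / (sum_s 1/sigma_s^2 - sum_o 1/sigma_o^2), with the prior sum skipping the last client. A small runnable illustration with made-up two-client, single-layer weight lists (all arrays and values here are invented for the example):

import numpy as np

# each weight list alternates [mean, std]; one layer per client here
w1 = [np.array([0.5, 1.0]), np.array([0.1, 0.2])]
w2 = [np.array([0.7, 0.9]), np.array([0.2, 0.1])]
p1 = [np.zeros(2), np.ones(2)]  # prior means and stds
p2 = [np.zeros(2), np.ones(2)]

consolidated = distributed_weight_consolidation([w1, w2], [p1, p2])
print(consolidated)  # [consolidated mean array, consolidated uncertainty array]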
@@ -0,0 +1,124 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Dec 14 15:33:42 2020
@author: Emi Z Liu
"""

# from https://github.com/corticometrics/neuroimage-tensorflow genTFrecord.py

# Creates a .tfrecord file from a directory of nifti images.
# This assumes your niftis are sorted into subdirs by directory, and a regex
# can be written to match volume filenames and label filenames.
#
# USAGE
#   python ./genTFrecord.py <data-dir> <input-vol-regex> <label-vol-regex>
# EXAMPLE:
#   python ./genTFrecord.py ./buckner40 'norm' 'aseg' buckner40.tfrecords
#
# Based off of this:
#   http://warmspringwinds.github.io/tensorflow/tf-slim/2016/12/21/tfrecords-guide/

import os
import re

import nibabel as nib
import numpy as np
import tensorflow as tf


def make_TFRecord_from_nii(data_dir, v_regex, l_regex, outfile):
    # Writes matched volume/label pairs to a TFRecord and returns the input
    # shape, which the caller uses to build the model.

    def _bytes_feature(value):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

    def _int64_feature(value):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

    def select_hipp(x):
        # binarize the label volume: keep only the hippocampus label (17)
        x[x != 17] = 0
        x[x == 17] = 1
        return x

    def crop_brain(x):
        # take a volume zoomed in on the hippocampus area
        return x[90:130, 90:130, 90:130]

    def preproc_brain(x):
        x = select_hipp(x)
        x = crop_brain(x)
        return x

    def listfiles(folder):
        for root, folders, files in os.walk(folder):
            for filename in folders + files:
                yield os.path.join(root, filename)

    def gen_filename_pairs(data_dir, v_re, l_re):
        unfiltered_filelist = list(listfiles(data_dir))
        input_list = [item for item in unfiltered_filelist if re.search(v_re, item)]
        label_list = [item for item in unfiltered_filelist if re.search(l_re, item)]
        print("input_list size: ", len(input_list))
        print("label_list size: ", len(label_list))
        if len(input_list) != len(label_list):
            raise Exception("input_list size and label_list size don't match")
        sample_img = nib.load(input_list[0]).get_fdata()
        return (zip(input_list, label_list), sample_img)

    # args were parsed from sys.argv in the original script; they are now
    # passed in as parameters

    # Generate a list of (volume_filename, label_filename) tuples
    filename_pairs, sample_img = gen_filename_pairs(data_dir, v_regex, l_regex)

    writer = tf.io.TFRecordWriter(outfile)  # tf.python_io.TFRecordWriter in TF1
    for v_filename, l_filename in filename_pairs:
        print("Processing:")
        print("  volume: ", v_filename)
        print("  label: ", l_filename)

        # The volume: load the nifti, convert to numpy, crop, then serialize.
        # get_fdata() replaces the deprecated get_data(), and tobytes() the
        # deprecated tostring().
        v_nii = nib.load(v_filename)
        v_np = v_nii.get_fdata().astype('int16')
        v_np = crop_brain(v_np)
        v_raw = v_np.tobytes()

        # The label: load, convert, binarize and crop, then serialize
        l_nii = nib.load(l_filename)
        l_np = l_nii.get_fdata().astype('int16')
        l_np = preproc_brain(l_np)
        l_raw = l_np.tobytes()

        # Dimensions
        x_dim, y_dim, z_dim = v_np.shape
        print("DIMS: " + str(x_dim) + " " + str(y_dim) + " " + str(z_dim))

        data_point = tf.train.Example(features=tf.train.Features(feature={
            'image_raw': _bytes_feature(v_raw),
            'label_raw': _bytes_feature(l_raw)}))
        writer.write(data_point.SerializeToString())

    writer.close()
    # The written volumes are cropped, so return the cropped shape
    return crop_brain(sample_img).shape
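A short usage sketch, assuming a data/ directory whose volume and label nifti filenames contain '_imgs' and '_labels' (the substrings the Client class above passes in; 'example.tfrec' is an arbitrary output name):

shape = make_TFRecord_from_nii('data', '_imgs', '_labels', 'example.tfrec')
print("input shape:", shape)  # cropped volume shape, e.g. (40, 40, 40)
# the records can then be opened with tf.data.TFRecordDataset('example.tfrec')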