Commit

Initial version. Main file for the energy optimization model.
Raphael Y. de Camargo committed Mar 21, 2023
1 parent 76730a9 commit 8e9029e
387 changes: 387 additions & 0 deletions src/energy_model.py
@@ -0,0 +1,387 @@
import argparse
import csv
import datetime
import os.path
import pickle

import numpy as np
from scipy.special import softmax
from sklearn.metrics import confusion_matrix, accuracy_score, balanced_accuracy_score, cohen_kappa_score, \
    classification_report


def print_matrix(text, matrix):
    print(text)
    for i in range(len(matrix)):
        for j in range(len(matrix[0])):
            print("{:.2f} ".format(matrix[i][j]), end="")
        print()


# 1. Determine transition probabilities between stages
def find_transition_prob(data):
    unique, counts = np.unique(data, return_counts=True)
    print("Counts {} for labels {}".format(counts, unique))
    data_prev = data[:-1]
    data_next = data[1:]
    n_labels = len(unique)
    trans_prob = np.zeros((n_labels, n_labels))
    for label in unique:
        transitions = data_next[np.where(data_prev == label)]
        trans_count = np.bincount(transitions)
        # Pad with zeros for labels that never follow the current one
        trans_prob[label] = np.append(trans_count, np.zeros(n_labels - len(trans_count)))
        trans_prob[label] /= np.sum(trans_prob[label])

    return trans_prob
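
# Toy illustration of find_transition_prob (an added sketch, not part of the
# pipeline): for the hypnogram [0, 0, 1, 1, 0] the consecutive pairs are
# (0, 0), (0, 1), (1, 1), (1, 0), so both rows come out as [0.5, 0.5]:
#   find_transition_prob(np.array([0, 0, 1, 1, 0]))
#   -> array([[0.5, 0.5],
#             [0.5, 0.5]])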

# 2. Finds the energy of the state of each epoch in the hypnogram.
# The energy depends on the probability of each state, as defined in the paper.
def find_energy_state_vec(alpha, conf_real_pred, e_pred, e_trans, pred, pred_opt, trans_prob, pred_prob):
    error_t1 = np.zeros(len(pred))
    error_t2 = np.zeros(len(pred))
    error_pc = np.zeros(len(pred))
    error_pr = np.zeros(len(pred))

    for i in range(len(pred)):
        state = pred_opt[i]
        # Probability of each state transition, from the state series of the training set
        if i > 0:
            p_trans = trans_prob[pred_opt[i - 1]][state]
            error_t1[i] = (1 + e_trans) / (p_trans + e_trans) - 1
        if i < len(pred) - 1:
            p_trans = trans_prob[state][pred_opt[i + 1]]
            error_t2[i] = (1 + e_trans) / (p_trans + e_trans) - 1
        # Probability of each state, from the confusion matrix of the validation set
        p_pred = conf_real_pred[state][pred[i]]
        error_pc[i] = (1 + e_pred) / (p_pred + e_pred) - 1
        # Probability of each state, from the NN output probabilities
        p_out = pred_prob[i][state]
        error_pr[i] = (1 + e_pred) / (p_out + e_pred) - 1

    energy_vec = alpha[0] * (error_t1 + error_t2) + alpha[1] * (error_pc + error_pr)
    return energy_vec
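
# Numeric sketch of the penalty term above: with epsilon = 0.1 the error
# (1 + eps) / (p + eps) - 1 is 0 for p = 1 (a fully expected state) and
# (1.1 / 0.1) - 1 = 10 for p = 0 (an impossible state), so improbable
# states receive large energies.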

# 3. Finds the change in energy when the state of a single epoch is replaced
def find_delta_energy(curr_energy, alpha, conf_real_pred, e_pred, e_trans, i,
                      pred, pred_opt, state, trans_prob, pred_prob):
    # Probability of each state transition, from the state series of the training set
    error_t, error_t2 = 0, 0
    if i > 0:
        p_trans = trans_prob[pred_opt[i - 1]][state]
        error_t = (1 + e_trans) / (p_trans + e_trans) - 1
    if i < len(pred) - 1:
        p_trans = trans_prob[state][pred_opt[i + 1]]
        error_t2 = (1 + e_trans) / (p_trans + e_trans) - 1
    # Probability of each state, from the confusion matrix of the validation set
    p_pred = conf_real_pred[state][pred[i]]
    error_p1 = (1 + e_pred) / (p_pred + e_pred) - 1
    # Probability of each state, from the NN output probabilities
    p_out = pred_prob[i][state]
    error_p2 = (1 + e_pred) / (p_out + e_pred) - 1

    # delta holds the energy changes for [previous epoch, this epoch, next epoch].
    # It must be a float array: an integer array would silently truncate the deltas.
    delta = np.zeros(3)
    delta[1] = alpha[0] * (error_t + error_t2) + alpha[1] * (error_p1 + error_p2) - curr_energy

    # Evaluate the energy change of the previous and next positions
    if i > 0:
        p_transP = trans_prob[pred_opt[i - 1]][pred_opt[i]]
        p_transN = trans_prob[pred_opt[i - 1]][state]
        delta[0] = alpha[0] * ((1 + e_trans) / (p_transN + e_trans) - (1 + e_trans) / (p_transP + e_trans))
    if i < len(pred) - 1:
        p_transP = trans_prob[pred_opt[i]][pred_opt[i + 1]]
        p_transN = trans_prob[state][pred_opt[i + 1]]
        delta[2] = alpha[0] * ((1 + e_trans) / (p_transN + e_trans) - (1 + e_trans) / (p_transP + e_trans))

    return delta
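
# Numeric sketch of a neighbor update: with e_trans = 0.1 and alpha[0] = 1,
# replacing a state so that the incoming transition probability drops from
# 0.9 to 0.1 changes the previous epoch's energy by
# (1.1 / 0.2) - (1.1 / 1.0) = 5.5 - 1.1 = 4.4.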


# 4. Optimize the sleep stage series using the energy function
# trans_prob: transition probability [prev_state][next_state]
# conf_real_pred: confusion matrix [real][pred]
def optimize_energy(trans_prob, conf_real_pred, pred, pred_prob, target, betas,
                    e_trans=0.1, e_pred=0.1, alpha=None, n_steps=100):
    if alpha is None:
        alpha = [1.0, 1.0]
    n_states = len(trans_prob)
    pred_opt = pred.copy()

    # Find the energy of each position
    energy_list = find_energy_state_vec(alpha, conf_real_pred, e_pred, e_trans, pred, pred_opt, trans_prob, pred_prob)

    # Change one position on each step
    for step in range(n_steps):
        if step > 0 and step % 1000 == 0 and len(target) == len(pred_opt):
            print("Step {:5} Accuracy: {:.2f}%".format(
                step, 100 * len(np.equal(target, pred_opt).nonzero()[0]) / len(target)))

        # Select the position to update. The selection probability is proportional to the position energy
        pos = np.random.choice(len(energy_list), 1, p=softmax(betas[step] * energy_list))[0]

        # Select the new state for the position. States that lower the energy are more likely to be selected
        delta_energy = []
        for state in range(n_states):
            delta_energy.append(find_delta_energy(energy_list[pos], alpha, conf_real_pred, e_pred, e_trans, pos,
                                                  pred, pred_opt, state, trans_prob, pred_prob))
        delta_sum = np.sum(delta_energy, axis=1)
        # best_state = np.argmin(delta_sum)  # greedy alternative
        best_state = np.random.choice(len(delta_sum), 1, p=softmax(-betas[step] * delta_sum))[0]

        # Update the state of the selected position and the energies of its neighbors
        pred_opt[pos] = best_state
        energy_list[pos] += delta_energy[best_state][1]
        if pos > 0:
            energy_list[pos - 1] += delta_energy[best_state][0]
        if pos < len(energy_list) - 1:
            energy_list[pos + 1] += delta_energy[best_state][2]

    return pred_opt
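
# Minimal usage sketch for optimize_energy with toy, hypothetical values
# (a 2-state model, 3 epochs, constant beta for 10 steps):
#   trans = np.array([[0.9, 0.1], [0.2, 0.8]])
#   conf = np.array([[0.8, 0.2], [0.3, 0.7]])
#   pred = np.array([0, 1, 0])
#   prob = np.array([[0.6, 0.4], [0.3, 0.7], [0.55, 0.45]])
#   opt = optimize_energy(trans, conf, pred, prob, [], np.ones(10), n_steps=10)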

# 5. Optimize the sleep stage series of multiple subjects in a dataset
def perform_optimizations_dataset(opt_params, part, model, dataset, writefile):
    pickle_input = "../data/input_" + dataset + "_" + model + "_part_" + str(part) + ".pkl"
    with open(pickle_input, 'rb') as file:
        data_per_subj = pickle.load(file)

    # Concatenate validation data from all subjects
    y_true_valid = []
    y_pred_valid = []
    y_prob_valid = []
    for index, row in data_per_subj[data_per_subj['train_test'] == 'valid'].iterrows():
        y_true_valid.append(row.y_true)
        y_pred_valid.append(row.y_pred)
        y_prob_valid.append(row.y_prob)
    y_true_valid = np.hstack(y_true_valid)
    y_pred_valid = np.hstack(y_pred_valid)
    y_prob_valid = np.vstack(y_prob_valid)

    # Determine the transition probability matrix from the training data
    y_true_train = []
    for index, row in data_per_subj[data_per_subj['train_test'] == 'train'].iterrows():
        y_true_train.append(row.y_true)
    y_true_train = np.hstack(y_true_train)
    trans_prob_train = find_transition_prob(y_true_train)

    # Output files
    csv_file = '../output/sleep-results-' + dataset + '.csv'
    npz_file = "../output/" + "energy-" + model + "-" + str(part) + "-" + dataset + "-opt.npz"

    curr_time = '{date:%Y-%m-%d %H:%M:%S}'.format(date=datetime.datetime.now())
    targets = ['0', '1', '2', '3', '4']

    # Create the confusion matrix from the validation data and normalize each row,
    # so that conf_mat_valid[real][pred] approximates P(pred | real)
    conf_mat_valid = confusion_matrix(y_true_valid, y_pred_valid)
    conf_mat_valid = conf_mat_valid.astype('float') / conf_mat_valid.sum(axis=1)[:, np.newaxis]

    # Print initial metrics
    print_matrix('Confusion Matrix (Validation Set)', conf_mat_valid)
    acc_valid = accuracy_score(y_true_valid, y_pred_valid) * 100
    bal_valid = balanced_accuracy_score(y_true_valid, y_pred_valid) * 100
    print("Accuracy: {:.2f}% (Validation)".format(acc_valid))
    print("Balanced: {:.2f}% (Validation)".format(bal_valid))
    print_matrix('Transition Matrix (Training Set)', trans_prob_train)

    # Optimization model parameters (alpha, beta, and epsilons). Check the paper for details.
    a_max_list, b_max, ep_max, et_max = [[0.5, 0.5]], 1.0, 0.1, 0.1
    if opt_params:
        a_max_list = [[0, 1], [0.25, 0.75], [0.5, 0.5], [0.75, 0.25], [1, 0]]

    # Number of optimization steps per beta value and the schedule of beta (1/temperature) values
    n_steps = 5 * 1200
    sched_steps = [1, 2, 4, 8]
    beta_sched = []
    for b in sched_steps:
        beta_sched.append(np.ones(n_steps) * b)
    beta_sched = np.array(beta_sched).flatten()

    if 'recording' not in data_per_subj.columns:
        data_per_subj['recording'] = np.zeros(len(data_per_subj), dtype='int')

    y_test_list = {'pred': [], 'pred-opt': [], 'true': []}

    print('\n\n===== Evaluating Optimizations =====')
    metrics = {}
    for index, row in data_per_subj[data_per_subj['train_test'] == 'test'].iterrows():
        subj, rec = row['subjects'], row['recording']
        for a_max in a_max_list:
            print(subj, rec)
            subj_data = data_per_subj.loc[data_per_subj['subjects'] == subj].loc[data_per_subj['recording'] == rec]
            if subj_data.size == 0:
                continue
            y_pred_test = subj_data.y_pred.values[0]
            y_prob_test = subj_data.y_prob.values[0]
            y_true_test = subj_data.y_true.values[0]
            # Apply the energy optimization to a single subject, running one step
            # per entry in the beta schedule so the full annealing schedule is used
            y_pred_opt_test = optimize_energy(trans_prob_train, conf_mat_valid, y_pred_test, y_prob_test,
                                              [], b_max * beta_sched, et_max, ep_max, a_max, len(beta_sched))

            metrics = evaluate_metrics(metrics, targets, y_pred_opt_test, y_pred_test, y_true_test, subj, rec, a_max)
            if writefile:
                write_results_file(opt_params, part, model, csv_file, metrics, curr_time,
                                   a_max, ep_max, et_max, targets, y_pred_opt_test, y_pred_test, y_true_test)

            y_test_list['pred-opt'].append(y_pred_opt_test)
            y_test_list['pred'].append(y_pred_test)
            y_test_list['true'].append(y_true_test)

    # Write results to file
    if writefile:
        np.savez(npz_file, conf_mat_valid=conf_mat_valid, conf_mat_test=metrics['cm_test'],
                 conf_mat_opt=metrics['cm_opt'], trans_prob_train=trans_prob_train,
                 subj=metrics['subj'], rec=metrics['rec'], alpha=metrics['alpha'],
                 y_pred_opt_test_list=y_test_list['pred-opt'], y_pred_test_list=y_test_list['pred'],
                 y_true_test=y_test_list['true'])

    print("\n================================================================")
    print("Finished - model: " + model + " cross-validation: " + str(part) + " dataset: " + dataset)
    print("mean_acc =", np.mean(metrics['acc']), "mean_acc_opt =", np.mean(metrics['acc_opt']))
    print("std_acc =", np.std(metrics['acc']), "std_acc_opt =", np.std(metrics['acc_opt']))
    print("mean_bal =", np.mean(metrics['bal']), "mean_bal_opt =", np.mean(metrics['bal_opt']))
    print("std_bal =", np.std(metrics['bal']), "std_bal_opt =", np.std(metrics['bal_opt']))
    print("================================================================ \n")


def write_results_file(optparams, part, model, csv_file, metr, curtime, a_max, ep_max, et_max, targets,
                       y_pred_opt_test, y_pred_test, y_true_test, pos=-1):
    # Write the header row if the file does not exist yet
    if not os.path.exists(csv_file):
        row = ['type', 'part', 'model', 'optparams', 'et_max', 'ep_max', 'a_max',
               'acc', 'acc_opt', 'bal', 'bal_opt', 'coh', 'coh_opt', 'timestamp']
        for t in targets:
            row += ['acc-' + t, 'acc-opt-' + t, 'prec-' + t, 'prec-opt-' + t, 'rec-' + t, 'rec-opt-' + t,
                    'f1-' + t, 'f1-opt-' + t, 'sup-' + t, 'sup-opt-' + t]
        row += ['subject', 'record']
        with open(csv_file, 'a+', newline='') as write_obj:
            csv.writer(write_obj).writerow(row)
            write_obj.flush()

    report = classification_report(y_true_test, y_pred_test,
                                   target_names=targets, output_dict=True, labels=list(map(int, targets)))
    report_opt = classification_report(y_true_test, y_pred_opt_test,
                                       target_names=targets, output_dict=True, labels=list(map(int, targets)))
    results = ['energy', part, model, optparams, et_max, ep_max, a_max,
               metr['acc'][pos], metr['acc_opt'][pos], metr['bal'][pos], metr['bal_opt'][pos],
               metr['coh'][pos], metr['coh_opt'][pos], curtime]
    for t in targets:
        # Per-class accuracies may be NaN when a class is absent from this recording
        acc_t = metr['acc_class_test'][pos][int(t)]
        acc_opt_t = metr['acc_class_opt'][pos][int(t)]
        if np.isnan(acc_t):
            acc_t = 0
        if np.isnan(acc_opt_t):
            acc_opt_t = 0
        results += [acc_t, acc_opt_t,
                    report[t]['precision'], report_opt[t]['precision'],
                    report[t]['recall'], report_opt[t]['recall'],
                    report[t]['f1-score'], report_opt[t]['f1-score'],
                    report[t]['support'], report_opt[t]['support']]
    results += [metr['subj'][pos], metr['rec'][pos]]
    with open(csv_file, 'a+', newline='') as write_obj:
        csv.writer(write_obj).writerow(results)
        write_obj.flush()


# acc, acc_class_opt, acc_class_test, acc_opt, bal, bal_opt, cm_opt, cm_test, coh, coh_opt
def evaluate_metrics(metrics, targets, y_pred_opt_test, y_pred_test, y_true_test, subj=-1, rec=-1, alpha=None,
                     print_res=True):
    if alpha is None:
        alpha = [1.0, 1.0]
    if len(metrics) == 0:
        metrics = {'acc': [], 'acc_opt': [], 'bal': [], 'bal_opt': [], 'coh': [], 'coh_opt': [],
                   'cm_test': [], 'cm_opt': [], 'acc_class_test': [], 'acc_class_opt': [],
                   'y_true_counts': [], 'y_pred_counts': [], 'y_opt_counts': [], 'subj': [], 'rec': [], 'alpha': []}

    labels = list(map(int, targets))
    metrics['acc'].append(accuracy_score(y_true_test, y_pred_test))
    metrics['acc_opt'].append(accuracy_score(y_true_test, y_pred_opt_test))
    metrics['bal'].append(balanced_accuracy_score(y_true_test, y_pred_test))
    metrics['bal_opt'].append(balanced_accuracy_score(y_true_test, y_pred_opt_test))
    metrics['coh'].append(cohen_kappa_score(y_true_test, y_pred_test))
    metrics['coh_opt'].append(cohen_kappa_score(y_true_test, y_pred_opt_test))
    metrics['cm_test'].append(confusion_matrix(y_true_test, y_pred_test, normalize='true', labels=labels))
    metrics['cm_opt'].append(confusion_matrix(y_true_test, y_pred_opt_test, normalize='true', labels=labels))
    metrics['subj'].append(subj)
    metrics['rec'].append(rec)
    metrics['alpha'].append(alpha)

    metrics['acc_class_test'].append(metrics['cm_test'][-1].diagonal() / metrics['cm_test'][-1].sum(axis=1))
    metrics['acc_class_opt'].append(metrics['cm_opt'][-1].diagonal() / metrics['cm_opt'][-1].sum(axis=1))
    metrics['y_true_counts'].append(get_targets_counts(targets, y_true_test))
    metrics['y_pred_counts'].append(get_targets_counts(targets, y_pred_test))
    metrics['y_opt_counts'].append(get_targets_counts(targets, y_pred_opt_test))

    if print_res:
        print("Accuracy: {:.2f}% (Test Set)".format(metrics['acc'][-1] * 100))
        print("Accuracy: {:.2f}% (Optimized Test Set)".format(metrics['acc_opt'][-1] * 100))
        print("Balanced Accuracy: {:.2f}% (Test Set)".format(metrics['bal'][-1] * 100))
        print("Balanced Accuracy: {:.2f}% (Optimized Test Set)".format(metrics['bal_opt'][-1] * 100))
        print("Cohen Kappa: {:.2f}% (Test Set)".format(metrics['coh'][-1] * 100))
        print("Cohen Kappa: {:.2f}% (Optimized Test Set)".format(metrics['coh_opt'][-1] * 100))
        print("Counts y_true: ")
        print(metrics['y_true_counts'][-1] / sum(metrics['y_true_counts'][-1]))
        print("Counts y_pred_test: ")
        print(metrics['y_pred_counts'][-1] / sum(metrics['y_pred_counts'][-1]))
        print("Counts y_pred_opt_test: ")
        print(metrics['y_opt_counts'][-1] / sum(metrics['y_opt_counts'][-1]))
        # print(classification_report(y_true_test, y_pred_test, target_names=targets, labels=labels))
        # print(classification_report(y_true_test, y_pred_opt_test, target_names=targets, labels=labels))
        # print_matrix('Confusion Matrix (Test Set)', metrics['cm_test'][-1])
        # print_matrix('Confusion Matrix (Optimized Set)', metrics['cm_opt'][-1])

    return metrics


def get_targets_counts(targets, y_true_test):
    values, counts = np.unique(y_true_test, return_counts=True)
    # Pad the counts with zeros for classes that are absent from the series
    if len(values) < len(targets):
        counts_new = np.zeros(len(targets))
        for t, c in zip(values, counts):
            counts_new[t] = c
        counts = counts_new
    return counts
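
# Toy illustration: get_targets_counts(['0', '1', '2'], np.array([0, 0, 2]))
# returns array([2., 0., 1.]), padding the absent class 1 with a zero count.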


def main(cmd_args):
    opt_params = cmd_args.optparams
    writefile = cmd_args.writefile

    if cmd_args.part == -1:
        part = [0, 1, 2, 3, 4]
    else:
        part = [cmd_args.part]

    if cmd_args.model == 'all':
        model = ['stager', 'usleep']
    else:
        model = [cmd_args.model]

    if cmd_args.dataset == 'all':
        dataset = ['edf', 'dreamer']
    else:
        dataset = [cmd_args.dataset]

    for d in dataset:
        for m in model:
            for p in part:
                perform_optimizations_dataset(opt_params, p, m, d, writefile)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Energy model')

    parser.add_argument('--optparams', type=int, default=0,
                        help='Set 1 to optimize the energy model parameters and 0 otherwise.')
    parser.add_argument('--part', type=int, default=0,
                        help='Define the CV slice, between 0 and 4. Use -1 for all parts.')
    parser.add_argument('--model', type=str, default='usleep',
                        help='Name of the neural network model.', choices=['stager', 'usleep', 'all'])
    parser.add_argument('--writefile', type=int, default=1,
                        help='Set 1 to write results to file and 0 otherwise.')
    parser.add_argument('--dataset', type=str, default='edf',
                        help='Name of the dataset to optimize.', choices=['edf', 'dreamer', 'all'])

    args = parser.parse_args()
    main(args)
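
# Example invocation (assuming the pickled inputs exist under ../data):
#   python energy_model.py --model usleep --dataset edf --part 0 --writefile 0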
