Skip to content

Commit

Permalink
Few visualization cleanups
Browse files Browse the repository at this point in the history
  • Loading branch information
owencqueen committed Dec 19, 2022
1 parent c356cef commit 0be0e9b
Show file tree
Hide file tree
Showing 12 changed files with 312 additions and 32 deletions.
Binary file removed .DS_Store
Binary file not shown.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
*__pycache__*
*egg-info*
*.ipynb_checkpoints*
Representations
Representations
*.DS_Store*
*.icloud*
Binary file removed experiment_nb/IV_exp_output/dataset.pt
Binary file not shown.
43 changes: 43 additions & 0 deletions formal/explainability/explain_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ def agg_exps(exp_list, add_data_keys = ['Mw', 'AN', 'OHN', '%TMP']):
parser.add_argument('--monomertype', type = str, default = 'G',
help = 'A (acid) or G (glycol). Used only for TMP rank experiment.')
parser.add_argument('--mono', action = 'store_true', help ='Use if using Mono variant for IV')
parser.add_argument('--mw_experiment', action = 'store_true', help ='Run Mw Experiment (partitioning and observing attribution values)')
parser.add_argument('--debug', action = 'store_true')
args = parser.parse_args()

# Load dataset:
Expand Down Expand Up @@ -82,6 +84,8 @@ def agg_exps(exp_list, add_data_keys = ['Mw', 'AN', 'OHN', '%TMP']):
}

exps = []
ref_inds = []
count = 0
for f in tqdm(os.listdir(args.history_loc)):
# Gather all relevant histories:
history = pickle.load(open(os.path.join(args.history_loc, f), 'rb'))
Expand All @@ -90,6 +94,8 @@ def agg_exps(exp_list, add_data_keys = ['Mw', 'AN', 'OHN', '%TMP']):
kfgen = KFold(n_splits=5, shuffle=False).split(history['all_reference_inds'])
split_ref_inds = [k[1] for k in kfgen]

ref_inds.append(history['all_reference_inds'])

# Iterate over the splits, since state dictionaries are separate by kfold splits
for i in range(len(history['model_state_dicts'])):

Expand All @@ -106,6 +112,10 @@ def agg_exps(exp_list, add_data_keys = ['Mw', 'AN', 'OHN', '%TMP']):
)

exps.append(exp_out)

count += 1
if args.debug and (count > 1): # Breaks after 2
break

if args.tmp_experiment:
tmp_importance = []
Expand All @@ -120,6 +130,39 @@ def agg_exps(exp_list, add_data_keys = ['Mw', 'AN', 'OHN', '%TMP']):
if args.save_path_tmp_experiment is not None:
pickle.dump(tmp_importance, open(args.save_path_tmp_experiment, 'wb'))

elif args.mw_experiment:

mw_vals = []
mw_attr = []

for i in range(len(ref_inds)):
# Each ref_inds[i] is a sub-list
for j in range(len(ref_inds[i])):
mw_attr.append(exps[i][-1]['Mw'][0])
mw_vals.append(data.iloc[ref_inds[i][j],:].loc['Mw (PS)'])

# Partition by mw_vals:
mw_inds = np.argsort(mw_vals)
upper_quartile = mw_inds[-int(len(mw_inds) / 4):]
lower_quartile = mw_inds[:int(len(mw_inds) / 4)]

upper_attr = [mw_attr[i] for i in upper_quartile]
lower_attr = [mw_attr[i] for i in lower_quartile]

print('Upper Quartile Attr: {:.4f} +- {:.4f}'.format(np.mean(upper_attr), np.std(upper_attr) / np.sqrt(len(upper_attr))))
print('Lower Quartile Attr: {:.4f} +- {:.4f}'.format(np.mean(lower_attr), np.std(lower_attr) / np.sqrt(len(lower_attr))))

plt.rcParams["font.family"] = "serif"
fig = plt.gcf()
fig.set_size_inches(5, 5)

#plt.hlines(0, xmin=0, xmax=len(name_list) + 1, colors = 'black', linestyles='dashed')
plt.boxplot([lower_attr, upper_attr])
plt.ylabel('Attribution')
plt.xticks([1, 2], ['Lower Quartile', 'Upper Quartile'])
plt.tight_layout()
plt.show()

elif args.tmp_rank_experiment:

ind = 2 if args.monomertype == 'G' else 1
Expand Down
Binary file not shown.
Binary file added formal/explainability/tmp_exp_results/tmp.pickle
Binary file not shown.
116 changes: 116 additions & 0 deletions formal/performance/parity_plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import os, pickle, argparse
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

# Global plot appearance: ggplot theme with a 12 pt serif font.
plt.style.use('ggplot')
font = dict(family='serif', size=12)

matplotlib.rc('font', **font)

def median(arr):
    """Return the *index* (not the value) of the median element of ``arr``.

    The values are ranked with ``np.argsort`` and the index sitting at the
    middle rank is returned; for an even number of elements this is the
    upper of the two middle ranks.
    """
    order = np.argsort(arr)
    middle = len(order) // 2
    return order[middle]

def main(args):
    """Draw parity (predicted vs. actual) plots for the median-R^2 history.

    Every file in ``args.path`` is unpickled as a history dictionary. The
    history whose R^2 is the median over all files is selected and its
    predictions are scattered against the true values, together with a
    dashed y = x reference line. Each figure is saved as a PDF in the
    current working directory and then shown.

    Args:
        args: Parsed command-line namespace with attributes
            ``path`` (directory of pickled histories), ``iv`` and ``tg``
            (booleans choosing which target(s) to plot). When both flags are
            set the histories are read as joint-model histories: column 0 of
            ``all_predictions`` / ``all_y`` is used for IV, column 1 for Tg,
            and ``h['IV'][0]`` / ``h['Tg'][0]`` are used as the R^2 scores.
            Otherwise ``h['r2']`` is used as the score.

    Raises:
        IndexError: If no history was collected for a requested target
            (e.g. ``args.path`` is an empty directory).
    """
    d = args.path
    joint = args.iv and args.tg
    yhat_iv, y_iv = [], []
    yhat_tg, y_tg = [], []
    iv_r2, tg_r2 = [], []

    for f in os.listdir(d):
        # NOTE(review): unpickling executes arbitrary code; only point
        # --path at trusted history files.
        with open(os.path.join(d, f), 'rb') as fh:
            h = pickle.load(fh)

        pred = h['all_predictions']
        y = h['all_y']

        if joint:
            iv_r2.append(h['IV'][0])
            yhat_iv.append([p[0] for p in pred])
            y_iv.append([t[0] for t in y])

            tg_r2.append(h['Tg'][0])
            yhat_tg.append([p[1] for p in pred])
            y_tg.append([t[1] for t in y])
        elif args.iv:
            iv_r2.append(h['r2'])
            yhat_iv.append(pred)
            y_iv.append(y)
        elif args.tg:
            tg_r2.append(h['r2'])
            yhat_tg.append(pred)
            y_tg.append(y)

    i = None
    if joint:
        # Joint model: pick the run that is the median of the summed R^2 so
        # that both parity plots come from the same history.
        tog = np.array(iv_r2) + np.array(tg_r2)
        i = median(tog)

    if args.iv:  # IV parity plot
        if i is None:
            i = median(iv_r2)
        y = y_iv[i]
        yhat = yhat_iv[i]

        # Start from a fresh figure so nothing from an earlier plot leaks in.
        plt.figure()
        plt.plot([min(y), max(y)], [min(y), max(y)], color = 'black', linestyle = '--')
        plt.scatter(y, yhat, color = '#006C93')
        plt.ylabel('Predicted IV', c = 'black')
        plt.xlabel('Actual IV', c = 'black')
        _, rx = plt.xlim()

        plt.text(rx*0.75, 0.05, s = '$R^2$ = {:.4f}'.format(iv_r2[i]))
        plt.xticks(c = 'black')
        plt.yticks(c = 'black')
        if joint:
            plt.savefig('joint_iv_parity.pdf')
        else:
            plt.savefig('iv_parity.pdf')
        plt.show()

    if args.tg:  # Tg parity plot
        if i is None:
            i = median(tg_r2)
        y = y_tg[i]
        yhat = yhat_tg[i]

        # A new figure is required here: with a non-blocking backend
        # plt.show() does not clear the IV plot, and the Tg points would be
        # drawn (and saved) on top of it.
        plt.figure()
        plt.plot([min(y), max(y)], [min(y), max(y)], color = 'black', linestyle = '--')
        plt.scatter(y, yhat, color = '#FF8200')
        plt.ylabel('Predicted $T_g$', c = 'black')
        plt.xlabel('Actual $T_g$', c = 'black')
        _, rx = plt.xlim()
        by, _ = plt.ylim()

        plt.text(rx*0.7, by*0.9, s = '$R^2$ = {:.4f}'.format(tg_r2[i]))
        plt.xticks(c = 'black')
        plt.yticks(c = 'black')

        if joint:
            plt.savefig('joint_tg_parity.pdf')
        else:
            plt.savefig('tg_parity.pdf')
        plt.show()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # --path is mandatory: without it main() would fail later with an opaque
    # TypeError when joining None with a file name.
    parser.add_argument('--path', type=str, required=True,
                        help='Path to directory with history files')
    parser.add_argument('--tg', action='store_true')
    parser.add_argument('--iv', action='store_true')
    args = parser.parse_args()

    main(args)
57 changes: 57 additions & 0 deletions formal/performance/rmse_calc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import pickle, os
import argparse
import numpy as np
from functools import partial
from sklearn.metrics import mean_squared_error

def RMSE(y_true, y_pred, *, sample_weight=None, multioutput='uniform_average'):
    """Root-mean-squared error between ``y_true`` and ``y_pred``.

    Computes the same quantity as ``mean_squared_error(..., squared=False)``
    without using the ``squared`` keyword, which newer scikit-learn releases
    no longer accept: the MSE is taken per output, square-rooted, and the
    per-output RMSEs are then combined.

    Args:
        y_true: Ground-truth target values.
        y_pred: Predicted target values.
        sample_weight: Optional per-sample weights forwarded to
            ``mean_squared_error``.
        multioutput: ``'uniform_average'`` (default) to average the
            per-output RMSEs, ``'raw_values'`` to return them unaveraged, or
            an array of weights used to average them.

    Returns:
        The aggregated RMSE, or an array of per-output RMSEs when
        ``multioutput='raw_values'``.

    Raises:
        ValueError: If ``multioutput`` is an unrecognised string.
    """
    per_output_rmse = np.sqrt(
        mean_squared_error(
            y_true, y_pred, sample_weight=sample_weight, multioutput='raw_values'
        )
    )
    if isinstance(multioutput, str):
        if multioutput == 'raw_values':
            return per_output_rmse
        if multioutput != 'uniform_average':
            raise ValueError('Unsupported multioutput: {!r}'.format(multioutput))
        weights = None
    else:
        weights = multioutput
    return np.average(per_output_rmse, weights=weights)

def main_single(args):
    """Print mean +- standard error of the RMSE over single-target histories.

    Each file in ``args.dir`` is unpickled and scored with ``RMSE`` on its
    ``'all_y'`` and ``'all_predictions'`` entries. Entries that cannot be
    read as a history are skipped.

    Args:
        args: Parsed command-line namespace; only ``args.dir`` is read.
    """
    rmse_scores = []
    for f in os.listdir(args.dir):
        try:
            # NOTE(review): unpickling executes arbitrary code; only use on
            # trusted history files.
            with open(os.path.join(args.dir, f), 'rb') as fh:
                h = pickle.load(fh)
            y_true, y_pred = h['all_y'], h['all_predictions']
        except Exception:
            # Control for stray files (e.g. __pycache__). Unpickling foreign
            # data can raise almost any exception type, hence the broad
            # catch; KeyboardInterrupt/SystemExit still propagate.
            continue

        # Scoring is deliberately outside the try so genuine metric errors
        # surface instead of being silently counted as "stray files".
        rmse_scores.append(RMSE(y_true, y_pred))

    if not rmse_scores:
        print('No readable history files found in {}'.format(args.dir))
        return

    print('Score = {:.4f} +- {:.4f}'.format(np.mean(rmse_scores), np.std(rmse_scores) / np.sqrt(len(rmse_scores))))

def main_joint(args):
    """Print mean +- standard error of the Tg and IV RMSE over joint histories.

    Each file in ``args.dir`` is unpickled; element 1 of every row of
    ``'all_predictions'`` / ``'all_y'`` is scored as Tg and element 0 as IV.
    Entries that cannot be read as a joint history are skipped.

    Args:
        args: Parsed command-line namespace; only ``args.dir`` is read.
    """
    rmse_scores_tg = []
    rmse_scores_iv = []
    for f in os.listdir(args.dir):
        try:
            # NOTE(review): unpickling executes arbitrary code; only use on
            # trusted history files.
            with open(os.path.join(args.dir, f), 'rb') as fh:
                h = pickle.load(fh)

            tgpred = [row[1] for row in h['all_predictions']]
            tgy = [row[1] for row in h['all_y']]
            ivpred = [row[0] for row in h['all_predictions']]
            ivy = [row[0] for row in h['all_y']]
        except Exception:
            # Control for stray files (e.g. __pycache__). Unpickling foreign
            # data can raise almost any exception type, hence the broad
            # catch; KeyboardInterrupt/SystemExit still propagate.
            continue

        # Scoring is deliberately outside the try so genuine metric errors
        # surface instead of being silently counted as "stray files".
        rmse_scores_tg.append(RMSE(tgy, tgpred))
        rmse_scores_iv.append(RMSE(ivy, ivpred))

    if not rmse_scores_tg:
        print('No readable history files found in {}'.format(args.dir))
        return

    print('Tg Score = {:.4f} +- {:.4f}'.format(np.mean(rmse_scores_tg), np.std(rmse_scores_tg) / np.sqrt(len(rmse_scores_tg))))
    print('IV Score = {:.4f} +- {:.4f}'.format(np.mean(rmse_scores_iv), np.std(rmse_scores_iv) / np.sqrt(len(rmse_scores_iv))))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # --dir is mandatory: without it every file is skipped and the script
    # reports a meaningless nan score.
    parser.add_argument('--dir', type = str, required = True,
                        help = 'directory containing histories')
    parser.add_argument('--joint', action = 'store_true')

    args = parser.parse_args()

    # --joint selects the two-target (Tg + IV) report; default is single-target.
    if args.joint:
        main_joint(args)
    else:
        main_single(args)
60 changes: 47 additions & 13 deletions formal/vis/plot_model_comp.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@

FIGSIZE = (6, 4)

def ridgeline(data, overlap=0, fill=True, labels=None, n_points=150, sep = 200, color = None):
def ridgeline(data, overlap=0, fill=True, labels=None, n_points=150, sep = 200, color = None,
r2 = False):
"""
Creates a standard ridgeline plot.
Expand All @@ -66,8 +67,13 @@ def ridgeline(data, overlap=0, fill=True, labels=None, n_points=150, sep = 200,

if overlap > 1 or overlap < 0:
raise ValueError('overlap must be in [0 1]')
xx = np.linspace(np.min(np.concatenate(data)),
np.max(np.concatenate(data)), n_points)

if r2:
xx = np.linspace(0.2,
1, n_points)
else:
xx = np.linspace(np.min(np.concatenate(data)),
np.max(np.concatenate(data)), n_points)
curves = []
ys = []
for i, d in enumerate(data):
Expand All @@ -82,6 +88,9 @@ def ridgeline(data, overlap=0, fill=True, labels=None, n_points=150, sep = 200,
plt.plot(xx, curve+y, c='k', zorder=len(data)-i+1)
if labels:
plt.yticks(ys, labels)

if r2:
plt.xlim(0.15, 1.05)

return ys

Expand Down Expand Up @@ -192,6 +201,12 @@ def filter_csv_2(df, comp = 'iv'):
mae = [OH_mae, OHS_mae, CM_mae, PI_mae, SOAP_mae, MBTR_mae]

return r2, mae

def print_all_stats(scores, lab):
    """Print one line per label with the mean and standard error of its scores.

    ``scores`` and ``lab`` are walked in lockstep; each line shows the label,
    the mean of the matching score collection, and its standard deviation
    divided by the square root of its length.
    """
    for label, values in zip(lab, scores):
        mean = np.mean(values)
        std_err = np.std(values) / np.sqrt(len(values))
        print('\t {}: {:.4f} +- {:.4f}'.format(label, mean, std_err))


def plot_IV(opt = 1):

Expand Down Expand Up @@ -267,6 +282,13 @@ def plot_IV(opt = 1):
r2 = r2O + r2
mae = maeO + mae


print('\nIV')
print('R2')
print_all_stats(r2, lab)
print('\nMAE')
print_all_stats(mae, lab)

# Sort by R2:
args = np.argsort([np.mean(r) for r in r2])

Expand All @@ -282,16 +304,19 @@ def apply_args(L):
plt.rcParams.update({'font.size': 12})
plt.figure(figsize=FIGSIZE)
ridgeline(r2, overlap =0, fill = 'y', sep = 10,
labels = lab, color = c)
labels = lab, color = c, r2 = True)
plt.xlabel('$R^2$')
plt.tight_layout()
plt.show()
#plt.show()
plt.savefig('r2_iv_ind.pdf', format = 'pdf')

plt.figure(figsize=FIGSIZE)
ridgeline(mae, overlap =0, fill = 'y', sep = 250,
labels = lab, color = c)
plt.xlabel('MAE')
plt.xlabel('MAE (dL/g)')
plt.tight_layout()
plt.show()
#plt.show()
plt.savefig('mae_iv_ind.pdf', format = 'pdf')

def plot_Tg(opt = 1):

Expand Down Expand Up @@ -362,8 +387,14 @@ def plot_Tg(opt = 1):
mae = maeO + mae
#r2, mae = r2[1:], mae[1:]

print('\nTg')
print('R2')
print_all_stats(r2, lab)
print('\nMAE')
print_all_stats(mae, lab)

# Sort by R2:
args = np.argsort([np.mean(r) for r in r2])
args = np.argsort([-np.mean(r) for r in mae])

def apply_args(L):
return [L[i] for i in args]
Expand All @@ -377,18 +408,21 @@ def apply_args(L):
plt.rcParams.update({'font.size': 12})
plt.figure(figsize=FIGSIZE)
ridgeline(r2, overlap =0, fill = 'y', sep = 20,
labels = lab, color = c)
labels = lab, color = c, r2 = True)
plt.xlabel('$R^2$')
plt.tight_layout()
plt.show()
#plt.show()
plt.savefig('r2_tg_ind.pdf', format = 'pdf')

plt.figure(figsize=FIGSIZE)
ridgeline(mae, overlap =0, fill = 'y', sep = 0.5,
labels = lab, color = c)
plt.xlabel('MAE')
plt.xlabel('MAE ($^\circ$C)')
plt.tight_layout()
plt.show()
#plt.show()
plt.savefig('mae_tg_ind.pdf', format = 'pdf')

if __name__ == '__main__':
#plot_Tg(opt = 2)
plot_Tg(opt = 2)
plot_IV(opt = 2)

Loading

0 comments on commit 0be0e9b

Please sign in to comment.