Skip to content

Commit a66f5af

Browse files
committed
Add argparse
1 parent 752dd65 commit a66f5af

File tree

2 files changed

+29
-6
lines changed

2 files changed

+29
-6
lines changed

utils.py

+2
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,8 @@ def get_file_info(filename):
9090
while True:
9191
desc = file.readline()
9292
str_array = file.readline()
93+
if desc=="":
94+
break
9395
_, _, _, epoch_no = desc.strip().split('$')
9496
epoch_no = int(epoch_no)
9597
if epoch_no==0:

visualize_mean_var.py

+27-6
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1+
import argparse
12
import numpy as np
23
import seaborn as sns
34
import matplotlib.pyplot as plt
45

56
import utils
6-
import config_bayesian as cfg
77

88

9-
def draw_distributions(filename, type='mean', node_no=0):
9+
def draw_distributions(filename, type='mean', node_no=0, save_plots=False, plot_time=0.5):
1010
file_desc = utils.get_file_info(filename)
1111
layer = file_desc['layer_name']
1212
means, std = utils.load_mean_std_from_file(filename)
@@ -26,12 +26,15 @@ def draw_distributions(filename, type='mean', node_no=0):
2626
plt.xlabel(f'Value of {type}')
2727
plt.ylabel('Density')
2828
plt.show(block=False)
29-
plt.pause(0.5)
29+
plt.pause(plot_time)
3030
ax.clear()
3131
plt.close()
3232

33+
if save_plots:
34+
raise NotImplementedError
3335

34-
def draw_lineplot(filename, type='mean', node_no=0):
36+
37+
def draw_lineplot(filename, type='mean', node_no=0, save_plots=False, plot_time=5):
3538
file_desc = utils.get_file_info(filename)
3639
layer = file_desc['layer_name']
3740
means, stds = utils.load_mean_std_from_file(filename)
@@ -47,6 +50,24 @@ def draw_lineplot(filename, type='mean', node_no=0):
4750
plt.title(f'Mean value of {type} for node {node_no} of {layer}')
4851
plt.xlabel('Epoch Number')
4952
plt.ylabel(f'Mean of {type}s')
50-
plt.show()
53+
plt.show(block=False)
54+
plt.pause(plot_time)
55+
if save_plots:
56+
plt.savefig('temp.jpg')
57+
58+
if __name__ == '__main__':
59+
parser = argparse.ArgumentParser(description = "Visualize Mean and Variance")
60+
parser.add_argument('--filename', type=str, help='path to log file', required=True)
61+
parser.add_argument('--data_type', default='mean', type=str, help='Draw plots for what? mean or std?')
62+
parser.add_argument('--node_no', default=0, type=int, help='Draw plots for which node?')
63+
parser.add_argument('--plot_type', default='lineplot', type=str, help='Which plot to draw? lineplot or distplot?')
64+
parser.add_argument('--plot_time', default=1, type=int, help='Pause the plot for how much time?')
65+
parser.add_argument('--save_plots', default=0, type=int, help='Save plots? 0 (No) or 1 (Yes)')
66+
args = parser.parse_args()
5167

52-
# draw_lineplot("checkpoints/MNIST/bayesian/lenet/fc3.txt", 'mean', 3)
68+
if args.plot_type=='lineplot':
69+
draw_lineplot(args.filename, args.data_type, args.node_no, bool(args.save_plots), args.plot_time)
70+
elif args.plot_type=='distplot':
71+
draw_distributions(args.filename, args.data_type, args.node_no, bool(args.save_plots), args.plot_time)
72+
else:
73+
raise NotImplementedError

0 commit comments

Comments
 (0)