1
+ import argparse
1
2
import numpy as np
2
3
import seaborn as sns
3
4
import matplotlib .pyplot as plt
4
5
5
6
import utils
6
- import config_bayesian as cfg
7
7
8
8
9
- def draw_distributions (filename , type = 'mean' , node_no = 0 ):
9
+ def draw_distributions (filename , type = 'mean' , node_no = 0 , save_plots = False , plot_time = 0.5 ):
10
10
file_desc = utils .get_file_info (filename )
11
11
layer = file_desc ['layer_name' ]
12
12
means , std = utils .load_mean_std_from_file (filename )
@@ -26,12 +26,15 @@ def draw_distributions(filename, type='mean', node_no=0):
26
26
plt .xlabel (f'Value of { type } ' )
27
27
plt .ylabel ('Density' )
28
28
plt .show (block = False )
29
- plt .pause (0.5 )
29
+ plt .pause (plot_time )
30
30
ax .clear ()
31
31
plt .close ()
32
32
33
+ if save_plots :
34
+ raise NotImplementedError
33
35
34
- def draw_lineplot (filename , type = 'mean' , node_no = 0 ):
36
+
37
+ def draw_lineplot (filename , type = 'mean' , node_no = 0 , save_plots = False , plot_time = 5 ):
35
38
file_desc = utils .get_file_info (filename )
36
39
layer = file_desc ['layer_name' ]
37
40
means , stds = utils .load_mean_std_from_file (filename )
@@ -47,6 +50,24 @@ def draw_lineplot(filename, type='mean', node_no=0):
47
50
plt .title (f'Mean value of { type } for node { node_no } of { layer } ' )
48
51
plt .xlabel ('Epoch Number' )
49
52
plt .ylabel (f'Mean of { type } s' )
50
- plt .show ()
53
+ plt .show (block = False )
54
+ plt .pause (plot_time )
55
+ if save_plots :
56
+ plt .savefig ('temp.jpg' )
57
+
58
+ if __name__ == '__main__' :
59
+ parser = argparse .ArgumentParser (description = "Visualize Mean and Variance" )
60
+ parser .add_argument ('--filename' , type = str , help = 'path to log file' , required = True )
61
+ parser .add_argument ('--data_type' , default = 'mean' , type = str , help = 'Draw plots for what? mean or std?' )
62
+ parser .add_argument ('--node_no' , default = 0 , type = int , help = 'Draw plots for which node?' )
63
+ parser .add_argument ('--plot_type' , default = 'lineplot' , type = str , help = 'Which plot to draw? lineplot or distplot?' )
64
+ parser .add_argument ('--plot_time' , default = 1 , type = int , help = 'Pause the plot for how much time?' )
65
+ parser .add_argument ('--save_plots' , default = 0 , type = int , help = 'Save plots? 0 (No) or 1 (Yes)' )
66
+ args = parser .parse_args ()
51
67
52
- # draw_lineplot("checkpoints/MNIST/bayesian/lenet/fc3.txt", 'mean', 3)
68
+ if args .plot_type == 'lineplot' :
69
+ draw_lineplot (args .filename , args .data_type , args .node_no , bool (args .save_plots ), args .plot_time )
70
+ elif args .plot_type == 'distplot' :
71
+ draw_distributions (args .filename , args .data_type , args .node_no , bool (args .save_plots ), args .plot_time )
72
+ else :
73
+ raise NotImplementedError
0 commit comments