@@ -110,6 +110,10 @@ def get_crossings(self, alpha = 0.05, outfile = None, **kwargs):
110
110
:param alpha: 1 - alpha = confidence level for confidence bands
111
111
:param outfile: path to output figure displaying algorithm
112
112
"""
113
+ plt_set = {'fontsize' : 18 , 'legend_fontsize' : 20 , 'labelsize' : 28 , 'linecolor' : '#f9665e' , 'bandcolor' : '#a8d9ed' , 'linewidth' : 2 , 'dpi' : 100 }
114
+ for key , value in kwargs .items ():
115
+ plt_set [key ] = value
116
+
113
117
rep = self .replicates / self .se
114
118
est = self .estimate / self .se
115
119
num_rep , crossings = len (rep ), 0
@@ -167,8 +171,11 @@ def left_upper_cb(x):
167
171
self .crossings = crossings
168
172
169
173
if not outfile == None :
170
- helpers .plot_min_crossings (outfile , optimal_path , self .crossings , alpha , rep , est ,
171
- 1 , upper_cb , lower_cb , left_upper_cb , ** kwargs )
174
+ fig = helpers .plot_min_crossings (optimal_path , self .crossings , alpha , rep , est ,
175
+ 1 , upper_cb , lower_cb , left_upper_cb , plt_set )
176
+ if type (outfile ) == str :
177
+ fig .savefig (outfile , transparent = True , dpi = plt_set ['dpi' ])
178
+ return fig
172
179
173
180
def pp_plot (self , confidence_band = True , alpha = 0.05 , outfile = None , ** kwargs ):
174
181
""" create the pp plot
@@ -180,6 +187,7 @@ def pp_plot(self, confidence_band = True, alpha = 0.05, outfile = None, **kwargs
180
187
plt_set = {'fontsize' : 18 , 'legend_fontsize' : 20 , 'labelsize' : 28 , 'pointsize' : 7 , 'pointcolor' : '#f9665e' , 'bandcolor' : '#a8d9ed' , 'dpi' : 100 }
181
188
for key , value in kwargs .items ():
182
189
plt_set [key ] = value
190
+
183
191
num_replicates = len (self .replicates )
184
192
185
193
replicates_eval_normcdf = norm .cdf (self .replicates , self .estimate , self .se )
@@ -197,28 +205,27 @@ def pp_plot(self, confidence_band = True, alpha = 0.05, outfile = None, **kwargs
197
205
'Neg. distance = %.3f' % self .neg_dist ))
198
206
props = dict (boxstyle = 'round, pad = 0.75, rounding_size = 0.3' , facecolor = 'white' , alpha = 0.86 )
199
207
200
- plt .rcParams . update ({ 'font.size' : plt_set ['fontsize' ]} )
201
- plt .rcParams . update ({ 'legend.fontsize' : plt_set ['legend_fontsize' ]} )
202
- plt .rcParams . update ({ 'axes.labelsize' : plt_set ['labelsize' ]} )
203
-
204
- plt . figure ( figsize = ( 10 , 10 ))
208
+ plt .rc ( 'font' , size = plt_set ['fontsize' ])
209
+ plt .rc ( 'legend' , fontsize = plt_set ['legend_fontsize' ])
210
+ plt .rc ( 'axes' , labelsize = plt_set ['labelsize' ])
211
+ fig , ax = plt . subplots ( figsize = ( 10 , 10 ))
212
+
205
213
if confidence_band == True :
206
- plt .fill_between (dkw_xgrid , dkw_lbound , dkw_ubound , color = plt_set ['bandcolor' ], label = 'Confidence band' , alpha = 0.35 )
207
- plt .scatter (replicates_eval_normcdf , replicate_ecdf , s = plt_set ['pointsize' ],
214
+ ax .fill_between (dkw_xgrid , dkw_lbound , dkw_ubound , color = plt_set ['bandcolor' ], label = 'Confidence band' , alpha = 0.35 )
215
+ ax .scatter (replicates_eval_normcdf , replicate_ecdf , s = plt_set ['pointsize' ],
208
216
c = plt_set ['pointcolor' ], label = 'Bootstrap replicates' )
209
- plt . xlabel ("CDF of normal distribution" )
210
- plt . ylabel ("CDF of bootstrap distribution" )
211
- plt .legend (edgecolor = 'k' , loc = 'upper left' )
212
- plt .axline ((0 , 0 ), (1 , 1 ), color = "black" , linestyle = (0 , (5 , 5 )))
213
- plt .text (0.52 , 0.06 , plot_data , fontsize = plt_set ['legend_fontsize' ], \
217
+ ax . set_xlabel ("CDF of normal distribution" )
218
+ ax . set_ylabel ("CDF of bootstrap distribution" )
219
+ ax .legend (edgecolor = 'k' , loc = 'upper left' )
220
+ ax .axline ((0 , 0 ), (1 , 1 ), color = "black" , linestyle = (0 , (5 , 5 )))
221
+ ax .text (0.52 , 0.06 , plot_data , fontsize = plt_set ['legend_fontsize' ], \
214
222
verticalalignment = 'bottom' , horizontalalignment = 'left' , bbox = props )
215
- plt . ylim (0 , 1 )
216
- plt . xlim (0 , 1 )
223
+ ax . set_ylim (0 , 1 )
224
+ ax . set_xlim (0 , 1 )
217
225
218
226
if not outfile == None :
219
- plt .savefig (outfile , transparent = True , dpi = plt_set ['dpi' ])
220
- plt .clf ()
221
- mpl .rcParams .update (mpl .rcParamsDefault )
227
+ fig .savefig (outfile , transparent = True , dpi = plt_set ['dpi' ])
228
+ return fig
222
229
223
230
224
231
def density_plot (self , bounds = None , bandwidth = None , outfile = None , ** kwargs ):
@@ -230,6 +237,7 @@ def density_plot(self, bounds = None, bandwidth = None, outfile = None, **kwargs
230
237
plt_set = {'fontsize' : 18 , 'legend_fontsize' : 20 , 'labelsize' : 28 , 'linecolor' : '#f9665e' , 'linewidth' : 1 , 'dpi' : 100 }
231
238
for key , value in kwargs .items ():
232
239
plt_set [key ] = value
240
+
233
241
if not bandwidth :
234
242
bandwidth = self .best_bandwidth_value
235
243
if bounds != None :
@@ -238,25 +246,24 @@ def density_plot(self, bounds = None, bandwidth = None, outfile = None, **kwargs
238
246
lbound , ubound = self .replicates [0 ] - 2 * self .best_bandwidth_value , self .replicates [- 1 ] + 2 * self .best_bandwidth_value
239
247
240
248
pdf_from_kde = helpers .get_kde (self .replicates , bandwidth )
241
-
242
- xgrid = np .linspace (lbound , ubound , len (self .replicates ) * 100 )
249
+ xgrid = np .linspace (lbound , ubound , plt_set ['dpi' ] * 2 )
243
250
density = [pdf_from_kde (x ) for x in xgrid ]
244
251
245
- plt .rcParams .update ({'font.size' : plt_set ['fontsize' ]})
246
- plt .rcParams .update ({'legend.fontsize' : plt_set ['legend_fontsize' ]})
247
- plt .rcParams .update ({'axes.labelsize' : plt_set ['labelsize' ]})
252
+ plt .rc ('font' , size = plt_set ['fontsize' ])
253
+ plt .rc ('legend' , fontsize = plt_set ['legend_fontsize' ])
254
+ plt .rc ('axes' , labelsize = plt_set ['labelsize' ])
255
+ fig , ax = plt .subplots ()
248
256
249
- plt . xlim (lbound , ubound )
250
- plt . xlabel ('Value of object of interest' )
251
- plt . ylabel ('Density' )
252
- plt .plot (xgrid , density , linewidth = plt_set ['linewidth' ], color = plt_set ['linecolor' ])
253
- plt .plot ([self .replicates [0 ], self .replicates [- 1 ]], [0.0001 , 0.0001 ], '|k' , markeredgewidth = 1 , label = 'Range of bootstrap replicates' )
254
- plt .legend (loc = 'best' , fontsize = 'x-small' , markerscale = 0.75 )
257
+ ax . set_xlim (lbound , ubound )
258
+ ax . set_xlabel ('Value of object of interest' )
259
+ ax . set_ylabel ('Density' )
260
+ ax .plot (xgrid , density , linewidth = plt_set ['linewidth' ], color = plt_set ['linecolor' ])
261
+ ax .plot ([self .replicates [0 ], self .replicates [- 1 ]], [0.0001 , 0.0001 ], '|k' , markeredgewidth = 1 , label = 'Range of bootstrap replicates' )
262
+ ax .legend (loc = 'best' , fontsize = 'x-small' , markerscale = 0.75 )
255
263
256
264
if not outfile == None :
257
- plt .savefig (outfile , transparent = True , dpi = plt_set ['dpi' ])
258
- plt .clf ()
259
- mpl .rcParams .update (mpl .rcParamsDefault )
265
+ fig .savefig (outfile , transparent = True , dpi = plt_set ['dpi' ])
266
+ return fig
260
267
261
268
def get_tv_min (self , init_values = None , optimization_bounds = None , bounds_of_integration = np .inf ):
262
269
"""
0 commit comments