|
7 | 7 |
|
8 | 8 | class DiagnosticsMixin:
|
9 | 9 |
|
10 |
| - def _heat_map(self, density = 50, outfile = False): |
| 10 | + def _heat_map(self, distance = "SK", density = 50, bounds = [], vrange = [0,1], outfile = False): |
11 | 11 | """ create a heat map showing total variation distance as mean and standard deviation are varied
|
12 |
| - :param mean_bounds: tuple or list of bounds for the mean |
13 |
| - :param sd_bounds: tuple or list of bounds for the standard deviation |
| 12 | + :param distance: string containing distance type, either 'TV' or 'SK' |
14 | 13 | :param density: the density of the points in the heat map
|
15 |
| - :outfile: location to be saved |
| 14 | + :param bounds: list of bounds for the mean and the standard deviation |
| 15 | + :param vrange: list containing min and max values for heatmap color scaling |
| 16 | + :param outfile: location to be saved |
16 | 17 | """
|
17 | 18 | sns.set_theme(font_scale=0.6)
|
18 | 19 | replicates = np.sort(self.replicates)
|
19 |
| - mean_bounds = (self.estimate - 3*self.se, self.estimate + 3*self.se) |
20 |
| - sd_bounds = (self.se/5, 4*self.se) |
| 20 | + |
| 21 | + if bounds: |
| 22 | + mean_bounds = bounds[0] |
| 23 | + sd_bounds = bounds[1] |
| 24 | + else: |
| 25 | + mean_bounds = (self.estimate - 3*self.se, self.estimate + 3*self.se) |
| 26 | + sd_bounds = (self.se/5, 4*self.se) |
21 | 27 |
|
22 | 28 | pdf_from_kde = helpers.get_kde(replicates, self.best_bandwidth_value)
|
23 | 29 | sigma_range, mu_range = np.linspace(min(sd_bounds), max(sd_bounds), density * 2), np.linspace(min(mean_bounds), max(mean_bounds), density)
|
24 | 30 | sigma_label, mu_label = ['%.3f' % sigma for sigma in sigma_range], ['%.3f' % mu for mu in mu_range]
|
25 |
| - |
26 |
| - tvd_table = [[helpers.get_tvd(lambda x: norm.pdf(x, loc = mu, scale = sigma), pdf_from_kde) for sigma in sigma_range] for mu in mu_range] |
27 |
| - df = pd.DataFrame(tvd_table, columns = sigma_label, index = mu_label) |
28 |
| - ax = sns.heatmap(df, vmin = 0, vmax = 1) |
| 31 | + |
| 32 | + if distance == "TV": |
| 33 | + table = [[helpers.get_tvd(lambda x: norm.pdf(x, loc = mu, scale = sigma), pdf_from_kde) for sigma in sigma_range] for mu in mu_range] |
| 34 | + elif distance == "SK": |
| 35 | + table = [[helpers.get_sk_dist(replicates, norm(loc = mu, scale = sigma)) for sigma in sigma_range] for mu in mu_range] |
| 36 | + |
| 37 | + df = pd.DataFrame(table, columns = sigma_label, index = mu_label) |
| 38 | + |
| 39 | + ax = sns.heatmap(df, vmin = vrange[0], vmax = vrange[1]) |
29 | 40 | ax.invert_yaxis()
|
30 | 41 | ax.set_xlabel('σ')
|
31 | 42 | ax.set_ylabel('µ')
|
32 |
| - plt.title('TVD with varied µ and σ') |
| 43 | + plt.title(f'{distance} with varied µ and σ') |
33 | 44 |
|
34 | 45 | if outfile:
|
35 | 46 | plt.savefig(outfile, dpi=600)
|
|
0 commit comments