Skip to content

Commit 1f327a7

Browse files
rcalvo12MosesStewartjmshapir
authored
Pull Request for #16: Make heatmaps compatible with SK (#17)
* #16 Make heatmaps compatible with SK * Update src/BootstrapReport/_diagnostics.py Co-authored-by: jmshapir <[email protected]> * Update src/BootstrapReport/_diagnostics.py Co-authored-by: jmshapir <[email protected]> --------- Co-authored-by: MosesStewart <[email protected]> Co-authored-by: jmshapir <[email protected]>
1 parent 3fb6719 commit 1f327a7

File tree

1 file changed

+22
-11
lines changed

1 file changed

+22
-11
lines changed

src/BootstrapReport/_diagnostics.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,29 +7,40 @@
77

88
class DiagnosticsMixin:
99

10-
def _heat_map(self, density = 50, outfile = False):
10+
def _heat_map(self, distance = "SK", density = 50, bounds = [], vrange = [0,1], outfile = False):
1111
""" create a heat map showing total variation distance as mean and standard deviation are varied
12-
:param mean_bounds: tuple or list of bounds for the mean
13-
:param sd_bounds: tuple or list of bounds for the standard deviation
12+
:param distance: string containing distance type, either 'TV' or 'SK'
1413
:param density: the density of the points in the heat map
15-
:outfile: location to be saved
14+
:param bounds: list of bounds for the mean and the standard deviation
15+
:param vrange: list containing min and max values for heatmap color scaling
16+
:param outfile: location to be saved
1617
"""
1718
sns.set_theme(font_scale=0.6)
1819
replicates = np.sort(self.replicates)
19-
mean_bounds = (self.estimate - 3*self.se, self.estimate + 3*self.se)
20-
sd_bounds = (self.se/5, 4*self.se)
20+
21+
if bounds:
22+
mean_bounds = bounds[0]
23+
sd_bounds = bounds[1]
24+
else:
25+
mean_bounds = (self.estimate - 3*self.se, self.estimate + 3*self.se)
26+
sd_bounds = (self.se/5, 4*self.se)
2127

2228
pdf_from_kde = helpers.get_kde(replicates, self.best_bandwidth_value)
2329
sigma_range, mu_range = np.linspace(min(sd_bounds), max(sd_bounds), density * 2), np.linspace(min(mean_bounds), max(mean_bounds), density)
2430
sigma_label, mu_label = ['%.3f' % sigma for sigma in sigma_range], ['%.3f' % mu for mu in mu_range]
25-
26-
tvd_table = [[helpers.get_tvd(lambda x: norm.pdf(x, loc = mu, scale = sigma), pdf_from_kde) for sigma in sigma_range] for mu in mu_range]
27-
df = pd.DataFrame(tvd_table, columns = sigma_label, index = mu_label)
28-
ax = sns.heatmap(df, vmin = 0, vmax = 1)
31+
32+
if distance == "TV":
33+
table = [[helpers.get_tvd(lambda x: norm.pdf(x, loc = mu, scale = sigma), pdf_from_kde) for sigma in sigma_range] for mu in mu_range]
34+
elif distance == "SK":
35+
table = [[helpers.get_sk_dist(replicates, norm(loc = mu, scale = sigma)) for sigma in sigma_range] for mu in mu_range]
36+
37+
df = pd.DataFrame(table, columns = sigma_label, index = mu_label)
38+
39+
ax = sns.heatmap(df, vmin = vrange[0], vmax = vrange[1])
2940
ax.invert_yaxis()
3041
ax.set_xlabel('σ')
3142
ax.set_ylabel('µ')
32-
plt.title('TVD with varied µ and σ')
43+
plt.title(f'{distance} with varied µ and σ')
3344

3445
if outfile:
3546
plt.savefig(outfile, dpi=600)

0 commit comments

Comments
 (0)