-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy pathskyline_creator.py
executable file
·86 lines (57 loc) · 2.08 KB
/
skyline_creator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
#!/usr/bin/python
from collections import OrderedDict
import matplotlib.pyplot as plt
import sys
import numpy as np
from scipy.interpolate import spline
import matplotlib.ticker as ticker
def parse_csv(csv_file):
    """Parse a Tracer-style CSV export into an ordered time -> stats mapping.

    The first line is treated as a header and skipped. Every following
    non-blank line must contain comma-separated numeric fields laid out as::

        time, mean, median, hpd lower 95, hpd upper 95

    Args:
        csv_file: Path to the CSV file.

    Returns:
        OrderedDict mapping each time value (float) to the list of fields
        that follow the mean, i.e. ``[median, hpd_lower, hpd_upper]`` for the
        five-column layout above. Rows keep their file order. An empty file
        (or a header-only file) yields an empty mapping.

    Raises:
        ValueError: If a non-blank data line contains a non-numeric field.
    """
    csv_data = OrderedDict()
    with open(csv_file) as csv_fh:
        # Skip the header row; the default avoids StopIteration on an empty file.
        next(csv_fh, None)
        for line in csv_fh:
            if not line.strip():
                # Tolerate blank lines such as extra newlines at end of file.
                continue
            fields = [float(x) for x in line.split(",")]
            # Time is the key; drop the mean and keep median and HPD bounds.
            csv_data[fields[0]] = fields[2:]
    return csv_data
def skyline_plot(csv_data, output_file):
    """Draw a smoothed skyline plot of the median with its HPD band.

    Points are placed at evenly spaced x positions and labelled with their
    real time values, so uneven time intervals do not stretch the x axis.
    The median is drawn as a dashed black curve. The lower/upper HPD bounds
    are drawn as blue curves with a translucent blue fill between them.

    Args:
        csv_data: Mapping as returned by ``parse_csv``: time -> list whose
            first three items are median, lower HPD bound and upper HPD bound.
        output_file: Output path without extension; ``.svg`` is appended.

    Raises:
        ValueError: If ``csv_data`` is empty.
    """
    # scipy.interpolate.spline was removed in SciPy 1.3; make_interp_spline
    # is the supported interpolating-spline replacement.
    from scipy.interpolate import make_interp_spline

    if not csv_data:
        raise ValueError("csv_data is empty; nothing to plot")

    x_data = np.arange(len(csv_data))
    median_data = np.array([x[0] for x in csv_data.values()])
    lower_hpd = np.array([x[1] for x in csv_data.values()])
    higher_hpd = np.array([x[2] for x in csv_data.values()])

    # A degree-k spline needs at least k + 1 points, so lower the degree
    # for short series; with a single point there is nothing to smooth.
    degree = min(3, len(x_data) - 1)
    if degree < 1:
        xnew = x_data
        smooth_median, smooth_lower, smooth_higher = (
            median_data, lower_hpd, higher_hpd)
    else:
        xnew = np.linspace(x_data.min(), x_data.max(), 200)
        smooth_median = make_interp_spline(x_data, median_data, k=degree)(xnew)
        smooth_lower = make_interp_spline(x_data, lower_hpd, k=degree)(xnew)
        smooth_higher = make_interp_spline(x_data, higher_hpd, k=degree)(xnew)

    fig, ax = plt.subplots()
    try:
        ax.plot(xnew, smooth_median, "--", color="black")
        ax.plot(xnew, smooth_lower, color="blue")
        ax.plot(xnew, smooth_higher, color="blue")
        ax.fill_between(xnew, smooth_higher, smooth_lower,
                        facecolor="blue", alpha=0.3)
        ax.set_xticks(x_data)
        ax.set_xticklabels(["%.2E" % x for x in csv_data.keys()],
                           rotation=45, ha="right")
        ax.set_xlabel("Time")
        ax.set_ylabel("Ne")
        fig.tight_layout()
        fig.savefig("%s.svg" % (output_file))
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
def main(argv=None):
    """Command-line entry point: read a Tracer-style CSV and write an SVG plot.

    Usage: skyline_creator.py CSV_FILE OUTPUT_FILE

    Args:
        argv: Argument list excluding the program name. When None, argparse
            reads ``sys.argv[1:]``.

    Raises:
        SystemExit: With code 2 when the arguments are missing or invalid
            (argparse prints a usage message first).
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Create a skyline plot (SVG) from a Tracer-style CSV file.")
    parser.add_argument("csv_file", help="input CSV file")
    parser.add_argument(
        "output_file",
        help="output path without extension; '.svg' is appended")
    args = parser.parse_args(argv)

    csv_data = parse_csv(args.csv_file)
    skyline_plot(csv_data, args.output_file)


# Only run when executed as a script, so importing this module has no side effects.
if __name__ == "__main__":
    main()