Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
jdebacker committed Mar 19, 2024
1 parent e9f0d7c commit 75dceb8
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 12 deletions.
34 changes: 27 additions & 7 deletions ogcore/parameter_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,12 @@ def plot_imm_rates(


def plot_mort_rates(
p_list, labels=[""], years=[DEFAULT_START_YEAR], survival_rates=False, include_title=False, path=None
p_list,
labels=[""],
years=[DEFAULT_START_YEAR],
survival_rates=False,
include_title=False,
path=None,
):
"""
Create a plot of mortality rates from OG-Core parameterization.
Expand All @@ -88,7 +93,11 @@ def plot_mort_rates(
t = y - p0.start_year
for i, p in enumerate(p_list):
if survival_rates:
plt.plot(age_per, np.cumprod(1 - p.rho[t, :]), label=labels[i] + " " + str(y))
plt.plot(
age_per,
np.cumprod(1 - p.rho[t, :]),
label=labels[i] + " " + str(y),
)
else:
plt.plot(age_per, p.rho[t, :], label=labels[i] + " " + str(y))
plt.xlabel(r"Age $s$ (model periods)")
Expand Down Expand Up @@ -193,7 +202,9 @@ def plot_population(p, years_to_plot=["SS"], include_title=False, path=None):
plt.savefig(fig_path, dpi=300)


def plot_ability_profiles(p, p2=None, t=None, log_scale=False, include_title=False, path=None):
def plot_ability_profiles(
p, p2=None, t=None, log_scale=False, include_title=False, path=None
):
"""
Create a plot of earnings ability profiles.
Expand Down Expand Up @@ -308,7 +319,13 @@ def plot_elliptical_u(p, plot_MU=True, include_title=False, path=None):
plt.savefig(fig_path, dpi=300)


def plot_chi_n(p_list, labels=[""], years_to_plot=[DEFAULT_START_YEAR], include_title=False, path=None):
def plot_chi_n(
p_list,
labels=[""],
years_to_plot=[DEFAULT_START_YEAR],
include_title=False,
path=None,
):
"""
Create a plot of showing the values of the chi_n parameters.
Expand All @@ -328,7 +345,11 @@ def plot_chi_n(p_list, labels=[""], years_to_plot=[DEFAULT_START_YEAR], include_
fig, ax = plt.subplots()
for y in years_to_plot:
for i, p in enumerate(p_list):
plt.plot(age, p.chi_n[y - p.start_year, :], label=labels[i] + " " + str(y))
plt.plot(
age,
p.chi_n[y - p.start_year, :],
label=labels[i] + " " + str(y),
)
if include_title:
plt.title("Utility Weight on the Disutility of Labor Supply")
plt.xlabel("Age, $s$")
Expand Down Expand Up @@ -375,8 +396,7 @@ def plot_fert_rates(
for i, fert_rates in enumerate(fert_rates_list):
plt.plot(fert_rates[i, :], label=labels[i] + " " + str(y))
if include_title:
plt.title('Fertility rates by age ($f_{s}$)',
fontsize=20)
plt.title("Fertility rates by age ($f_{s}$)", fontsize=20)
plt.xlabel(r"Age $s$")
plt.ylabel(r"Fertility rate $f_{s}$")
plt.legend(loc="upper right")
Expand Down
12 changes: 7 additions & 5 deletions tests/test_parameter_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ def test_plot_mort_rates():


def test_plot_surv_rates():
fig = parameter_plots.plot_mort_rates([base_params], survival_rates=True, include_title=True)
fig = parameter_plots.plot_mort_rates(
[base_params], survival_rates=True, include_title=True
)
assert fig


Expand All @@ -83,7 +85,9 @@ def test_plot_mort_rates_save_fig(tmpdir):


def test_plot_surv_rates_save_fig(tmpdir):
parameter_plots.plot_mort_rates([base_params], survival_rates=True, path=tmpdir)
parameter_plots.plot_mort_rates(
[base_params], survival_rates=True, path=tmpdir
)
img = mpimg.imread(os.path.join(tmpdir, "survival_rates.png"))

assert isinstance(img, np.ndarray)
Expand All @@ -105,9 +109,7 @@ def test_plot_pop_growth_rates_save_fig(tmpdir):

def test_plot_ability_profiles():
p = Specifications()
fig = parameter_plots.plot_ability_profiles(
p, p2=p, include_title=True
)
fig = parameter_plots.plot_ability_profiles(p, p2=p, include_title=True)
assert fig


Expand Down

0 comments on commit 75dceb8

Please sign in to comment.