Skip to content

Commit

Permalink
more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jdebacker committed Jul 4, 2024
1 parent 4615306 commit 7916656
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 5 deletions.
7 changes: 2 additions & 5 deletions ogcore/pensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,18 +445,15 @@ def deriv_DB(w, e, per_rmn, p):
Change in DB pension benefits for another unit of labor supply
"""

if per_rmn < (p.S - p.S_ret + 1):
if per_rmn < (p.S - p.retire + 1):
d_theta = np.zeros(p.S)
else:
d_theta_empty = np.zeros(p.S)
d_theta = deriv_DB_loop(
w,
e,
p.S,
p.S_ret,
p.retire,
per_rmn,
p.g_y,
d_theta_empty,
p.last_career_yrs,
p.rep_rate_py,
p.yr_contr,
Expand Down
51 changes: 51 additions & 0 deletions tests/test_pensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,3 +222,54 @@ def test_deriv_PS_loop(args, deriv_PS_loop_expected):
)

assert np.allclose(deriv_PS_loop, deriv_PS_loop_expected)


#############non-zero d_theta: case 1############
p = Specifications()
p.S = 7
p.retire = 4
p.last_career_yrs = 3
p.yr_contr = p.retire
p.rep_rate_py = 0.2
p.g_y = 0.03
n_ddb1 = np.array([0.4, 0.45, 0.4, 0.42, 0.3, 0.2, 0.2])
w_ddb1 = np.array([1.2, 1.1, 1.21, 1, 1.01, 0.99, 0.8])
e_ddb1 = np.array([1.1, 1.11, 0.9, 0.87, 0.87, 0.7, 0.6])
per_rmn = n_ddb1.shape[0]
d_theta_empty = np.zeros_like(per_rmn)
deriv_DB_expected1 = np.array(
[0.352, 0.3256, 0.2904, 0.232, 0.0, 0.0, 0.0])
args_ddb1 = (w_ddb1, e_ddb1, per_rmn, p)

#############non-zero d_theta: case 2############
p2 = Specifications()
p2.S = 7
p2.retire = 5
p2.last_career_yrs = 2
p2.yr_contr = p2.retire
p2.rep_rate_py = 0.2
p2.g_y = 0.03
n_ddb2 = np.array([0.45, 0.4, 0.42, 0.3, 0.2, 0.2])
w_ddb1 = np.array([1.1, 1.21, 1, 1.01, 0.99, 0.8])
e_ddb1 = np.array([1.11, 0.9, 0.87, 0.87, 0.7, 0.6])
per_rmn = n_ddb2.shape[0]
d_theta_empty = np.zeros_like(per_rmn)
deriv_DB_expected2 = np.array(
[0.6105, 0.5445, 0.435, 0.43935, 0.0, 0.0])
args_ddb2 = (w_ddb1, e_ddb1, per_rmn, p2)

test_data = [(args_ddb1, deriv_DB_expected1),
(args_ddb2, deriv_DB_expected2)]


@pytest.mark.parametrize('args,deriv_DB_expected', test_data,
ids=['non-zero d_theta: case 1',
'non-zero d_theta: case 2'])
def test_deriv_DB(args, deriv_DB_expected):
"""
Test of the pensions.deriv_DB() function.
"""
(w, e, per_rmn, p) = args
deriv_DB = pensions.deriv_DB(w, e, per_rmn, p)

assert (np.allclose(deriv_DB, deriv_DB_expected))

0 comments on commit 7916656

Please sign in to comment.