diff --git a/ogcore/pensions.py b/ogcore/pensions.py index ce1f80815..47c65139f 100644 --- a/ogcore/pensions.py +++ b/ogcore/pensions.py @@ -445,18 +445,15 @@ def deriv_DB(w, e, per_rmn, p): Change in DB pension benefits for another unit of labor supply """ - if per_rmn < (p.S - p.S_ret + 1): + if per_rmn < (p.S - p.retire + 1): d_theta = np.zeros(p.S) else: - d_theta_empty = np.zeros(p.S) d_theta = deriv_DB_loop( w, e, p.S, - p.S_ret, + p.retire, per_rmn, - p.g_y, - d_theta_empty, p.last_career_yrs, p.rep_rate_py, p.yr_contr, diff --git a/tests/test_pensions.py b/tests/test_pensions.py index 0a88d1018..fcfe5b5e8 100644 --- a/tests/test_pensions.py +++ b/tests/test_pensions.py @@ -222,3 +222,54 @@ def test_deriv_PS_loop(args, deriv_PS_loop_expected): ) assert np.allclose(deriv_PS_loop, deriv_PS_loop_expected) + + +#############non-zero d_theta: case 1############ +p = Specifications() +p.S = 7 +p.retire = 4 +p.last_career_yrs = 3 +p.yr_contr = p.retire +p.rep_rate_py = 0.2 +p.g_y = 0.03 +n_ddb1 = np.array([0.4, 0.45, 0.4, 0.42, 0.3, 0.2, 0.2]) +w_ddb1 = np.array([1.2, 1.1, 1.21, 1, 1.01, 0.99, 0.8]) +e_ddb1 = np.array([1.1, 1.11, 0.9, 0.87, 0.87, 0.7, 0.6]) +per_rmn = n_ddb1.shape[0] +d_theta_empty = np.zeros_like(per_rmn) +deriv_DB_expected1 = np.array( + [0.352, 0.3256, 0.2904, 0.232, 0.0, 0.0, 0.0]) +args_ddb1 = (w_ddb1, e_ddb1, per_rmn, p) + +#############non-zero d_theta: case 2############ +p2 = Specifications() +p2.S = 7 +p2.retire = 5 +p2.last_career_yrs = 2 +p2.yr_contr = p2.retire +p2.rep_rate_py = 0.2 +p2.g_y = 0.03 +n_ddb2 = np.array([0.45, 0.4, 0.42, 0.3, 0.2, 0.2]) +w_ddb1 = np.array([1.1, 1.21, 1, 1.01, 0.99, 0.8]) +e_ddb1 = np.array([1.11, 0.9, 0.87, 0.87, 0.7, 0.6]) +per_rmn = n_ddb2.shape[0] +d_theta_empty = np.zeros_like(per_rmn) +deriv_DB_expected2 = np.array( + [0.6105, 0.5445, 0.435, 0.43935, 0.0, 0.0]) +args_ddb2 = (w_ddb1, e_ddb1, per_rmn, p2) + +test_data = [(args_ddb1, deriv_DB_expected1), + (args_ddb2, deriv_DB_expected2)] + + +@pytest.mark.parametrize('args,deriv_DB_expected', test_data, + ids=['non-zero d_theta: case 1', + 'non-zero d_theta: case 2']) +def test_deriv_DB(args, deriv_DB_expected): + """ + Test of the pensions.deriv_DB() function. + """ + (w, e, per_rmn, p) = args + deriv_DB = pensions.deriv_DB(w, e, per_rmn, p) + + assert (np.allclose(deriv_DB, deriv_DB_expected)) \ No newline at end of file