Skip to content

Commit 632ba93

Browse files
committed
new function to extrapolate lists
1 parent 7206399 commit 632ba93

File tree

2 files changed

+70
-41
lines changed

2 files changed

+70
-41
lines changed

ogcore/parameters.py

Lines changed: 7 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import paramtools
55
import ogcore
66
from ogcore import elliptical_u_est
7-
from ogcore.utils import rate_conversion, extrapolate_arrays
7+
from ogcore.utils import rate_conversion, extrapolate_arrays, extrapolate_nested_list
88
from ogcore.constants import BASELINE_DIR
99

1010
CURRENT_PATH = os.path.abspath(os.path.dirname(__file__))
@@ -223,53 +223,19 @@ def compute_default_params(self):
223223
"mtry_params",
224224
]
225225
for item in tax_params_to_TP:
226-
tax_to_set = getattr(self, item)
226+
tax_to_set_in = getattr(self, item)
227227
try:
228-
tax_to_set = (
229-
tax_to_set.tolist()
230-
) # in case parameters are numpy arrays
231-
except AttributeError: # catches if they are lists already
232-
pass
233-
if len(tax_to_set) == 1 and isinstance(tax_to_set[0], float):
234-
setattr(
235-
self,
236-
item,
237-
[
238-
[[tax_to_set] for i in range(self.S)]
239-
for t in range(self.T)
240-
],
241-
)
242-
elif any(
243-
[
244-
isinstance(tax_to_set[i][j], list)
245-
for i, v in enumerate(tax_to_set)
246-
for j, vv in enumerate(tax_to_set[i])
247-
]
248-
):
249-
if len(tax_to_set) > self.T + self.S:
250-
tax_to_set = tax_to_set[: self.T + self.S]
251-
if len(tax_to_set) < self.T + self.S:
252-
tax_params_to_add = [tax_to_set[-1]] * (
253-
self.T + self.S - len(tax_to_set)
254-
)
255-
tax_to_set.extend(tax_params_to_add)
256-
if len(tax_to_set[0]) > self.S:
257-
for t, v in enumerate(tax_to_set):
258-
tax_to_set[t] = tax_to_set[t][: self.S]
259-
if len(tax_to_set[0]) < self.S:
260-
tax_params_to_add = [tax_to_set[:][-1]] * (
261-
self.S - len(tax_to_set[0])
262-
)
263-
tax_to_set[0].extend(tax_params_to_add)
264-
setattr(self, item, tax_to_set)
265-
else:
228+
len(tax_to_set_in[0][0])
229+
except TypeError:
266230
print(
267231
"please give a "
268232
+ item
269-
+ " that is a single element or nested lists of"
233+
+ " that is a nested lists of"
270234
+ " lists that is three lists deep"
271235
)
272236
assert False
237+
tax_to_set_out = extrapolate_nested_list(tax_to_set_in, dims=(self.T, self.S, len(tax_to_set_in[0][0])))
238+
setattr(self, item, tax_to_set_out)
273239

274240
# Try to deal with size of eta. It may vary by S, J, T, but
275241
# want to allow user to enter one that varies by only S, S and J,

ogcore/utils.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1022,6 +1022,69 @@ def extrapolate_arrays(param_in, dims=None, item="Parameter Name"):
10221022
return param_out
10231023

10241024

1025+
def extrapolate_lists(list_in, dims=(400, 80, 1)):
1026+
try:
1027+
tax_to_set = (
1028+
tax_to_set.tolist()
1029+
) # in case parameters are numpy arrays
1030+
except AttributeError: # catches if they are lists already
1031+
pass
1032+
if len(tax_to_set) == 1 and isinstance(tax_to_set[0], float):
1033+
setattr(
1034+
self,
1035+
item,
1036+
[
1037+
[[tax_to_set] for i in range(self.S)]
1038+
for t in range(self.T)
1039+
],
1040+
)
1041+
elif any(
1042+
[
1043+
isinstance(tax_to_set[i][j], list)
1044+
for i, v in enumerate(tax_to_set)
1045+
for j, vv in enumerate(tax_to_set[i])
1046+
]
1047+
):
1048+
if len(tax_to_set) > self.T + self.S:
1049+
tax_to_set = tax_to_set[: self.T + self.S]
1050+
if len(tax_to_set) < self.T + self.S:
1051+
tax_params_to_add = [tax_to_set[-1]] * (
1052+
self.T + self.S - len(tax_to_set)
1053+
)
1054+
tax_to_set.extend(tax_params_to_add)
1055+
if len(tax_to_set[0]) > self.S:
1056+
for t, v in enumerate(tax_to_set):
1057+
tax_to_set[t] = tax_to_set[t][: self.S]
1058+
if len(tax_to_set[0]) < self.S:
1059+
tax_params_to_add = [tax_to_set[:][-1]] * (
1060+
self.S - len(tax_to_set[0])
1061+
)
1062+
tax_to_set[0].extend(tax_params_to_add)
1063+
1064+
# for t, v in enumerate(tax_to_set):
1065+
# for j, k in enumerate(tax_to_set)
1066+
# tax_params_to_add = [tax_to_set[t][-1]] * (
1067+
# self.S - len(tax_to_set[t])
1068+
# )
1069+
# tax_to_set[t].extend(tax_params_to_add)
1070+
1071+
1072+
1073+
print("TAX PARAMS TO ADD: ", tax_params_to_add)
1074+
print("TAX TO SET sizes before: ", len(tax_to_set), len(tax_to_set[0]), len(tax_to_set[0][0]))
1075+
tax_to_set[0].extend(tax_params_to_add)
1076+
print("TAX TO SET sizes after: ", len(tax_to_set), len(tax_to_set[0]), len(tax_to_set[0][0]))
1077+
1078+
setattr(self, item, tax_to_set)
1079+
else:
1080+
print(
1081+
"please give a "
1082+
+ item
1083+
+ " that is a single element or nested lists of"
1084+
+ " lists that is three lists deep"
1085+
)
1086+
assert False
1087+
10251088
class CustomHttpAdapter(requests.adapters.HTTPAdapter):
10261089
"""
10271090
The UN Data Portal server doesn't support "RFC 5746 secure renegotiation". This causes and error when the client is using OpenSSL 3, which enforces that standard by default.

0 commit comments

Comments
 (0)