
Commit ed3bd02

[FIX] Reproducible samples for PERMBU (#337)
1 parent 491c2de commit ed3bd02

File tree

4 files changed (+42, -34 lines changed):
hierarchicalforecast/core.py
hierarchicalforecast/probabilistic_methods.py
nbs/src/core.ipynb
nbs/src/probabilistic_methods.ipynb

Diff for: hierarchicalforecast/core.py (+13 -8)

@@ -50,6 +50,7 @@ def _reverse_engineer_sigmah(
     id_col: str = "unique_id",
     time_col: str = "ds",
     target_col: str = "y",
+    num_samples: int = 200,
 ) -> np.ndarray:
     """
     This function assumes that the model creates prediction intervals

@@ -81,7 +82,7 @@ def _reverse_engineer_sigmah(
         sign = -1 if "lo" in pi_col else 1
         level_cols = re.findall("[\d]+[.,\d]+|[\d]*[.][\d]+|[\d]+", pi_col)
         level_col = float(level_cols[-1])
-        z = norm.ppf(0.5 + level_col / 200)
+        z = norm.ppf(0.5 + level_col / num_samples)
         sigmah = Y_hat_df[pi_col].to_numpy().reshape(n_series, -1)
         sigmah = sign * (sigmah - y_hat) / z

@@ -430,18 +431,22 @@ def reconcile(
                 reconciler_args["y_hat_insample"] = y_hat_insample

             if level is not None:
+                reconciler_args["intervals_method"] = intervals_method
+                reconciler_args["num_samples"] = 200
+                reconciler_args["seed"] = seed
+
                 if intervals_method in ["normality", "permbu"]:
                     sigmah = _reverse_engineer_sigmah(
-                        Y_hat_df=Y_hat_nw, y_hat=y_hat, model_name=model_name
+                        Y_hat_df=Y_hat_nw,
+                        y_hat=y_hat,
+                        model_name=model_name,
+                        id_col=id_col,
+                        time_col=time_col,
+                        target_col=target_col,
+                        num_samples=reconciler_args["num_samples"],
                     )
                     reconciler_args["sigmah"] = sigmah

-                reconciler_args["intervals_method"] = intervals_method
-                reconciler_args["num_samples"] = (
-                    200  # TODO: solve duplicated num_samples
-                )
-                reconciler_args["seed"] = seed
-
             # Mean and Probabilistic reconciliation
             kwargs_ls = [
                 key
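For context on the _reverse_engineer_sigmah hunks above: the function recovers the forecast scale sigmah from a normality-based prediction-interval column by inverting y_hat ± z * sigmah, where z = norm.ppf(0.5 + level / 200); with the new num_samples argument left at its default of 200 the divisor is unchanged. A minimal illustrative sketch (the values and variable names below are made up, not library code):

import numpy as np
from scipy.stats import norm

# Illustrative inputs, not taken from the library.
y_hat = np.array([100.0, 110.0, 120.0])   # point forecasts
sigmah_true = np.array([5.0, 6.0, 7.0])   # scale the 'model' used internally
level = 90.0                              # a 90% prediction interval

# Forward direction: how a normality-based upper bound is typically built.
z = norm.ppf(0.5 + level / 200)           # 0.5 + 90/200 = 0.95 -> z ~= 1.645
hi_col = y_hat + z * sigmah_true          # e.g. a 'Model-hi-90' column

# Reverse direction, mirroring the idea behind _reverse_engineer_sigmah.
sigmah_recovered = (hi_col - y_hat) / z
assert np.allclose(sigmah_recovered, sigmah_true)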

Diff for: hierarchicalforecast/probabilistic_methods.py (+9 -9)

@@ -85,14 +85,14 @@ def get_samples(self, num_samples: int):
         **Returns:**<br>
         `samples`: Coherent samples of size (`base`, `horizon`, `num_samples`).
         """
-        state = np.random.RandomState(self.seed)
+        rng = np.random.default_rng(self.seed)
         n_series, n_horizon = self.y_hat.shape
         samples = np.empty(shape=(num_samples, n_series, n_horizon))
         for t in range(n_horizon):
             with warnings.catch_warnings():
                 # Avoid 'RuntimeWarning: covariance is not positive-semidefinite.'
                 # By definition the multivariate distribution is not full-rank
-                partial_samples = state.multivariate_normal(
+                partial_samples = rng.multivariate_normal(
                     mean=self.SP @ self.y_hat[:, t],
                     cov=self.cov_rec[t],
                     size=num_samples,

@@ -194,8 +194,8 @@ def get_samples(self, num_samples: int):
         # removing nas from residuals
         residuals = residuals[:, np.isnan(residuals).sum(axis=0) == 0]
         sample_idx = np.arange(residuals.shape[1] - h)
-        state = np.random.RandomState(self.seed)
-        samples_idx = state.choice(sample_idx, size=num_samples)
+        rng = np.random.default_rng(self.seed)
+        samples_idx = rng.choice(sample_idx, size=num_samples)
         samples = [self.y_hat + residuals[:, idx : (idx + h)] for idx in samples_idx]
         SP = self.S @ self.P
         samples = np.apply_along_axis(

@@ -382,21 +382,21 @@ def get_samples(self, num_samples: Optional[int] = None):
             num_samples = residuals.shape[1]

         # Expand residuals to match num_samples [(a,b),T] -> [(a,b),num_samples]
+        rng = np.random.default_rng(self.seed)
         if num_samples > residuals.shape[1]:
-            residuals_idxs = np.random.choice(residuals.shape[1], size=num_samples)
+            residuals_idxs = rng.choice(residuals.shape[1], size=num_samples)
         else:
-            residuals_idxs = np.random.choice(
+            residuals_idxs = rng.choice(
                 residuals.shape[1], size=num_samples, replace=False
             )
         residuals = residuals[:, residuals_idxs]
         rank_permutations = self._obtain_ranks(residuals)

-        state = np.random.RandomState(self.seed)
         n_series, n_horizon = self.y_hat.shape

         base_samples = np.array(
             [
-                state.normal(loc=m, scale=s, size=num_samples)
+                rng.normal(loc=m, scale=s, size=num_samples)
                 for m, s in zip(self.y_hat.flatten(), self.sigmah.flatten())
             ]
         )

@@ -432,7 +432,7 @@ def get_samples(self, num_samples: Optional[int] = None):
         parent_samples = np.einsum("ab,bhs->ahs", Agg, children_samples)
         random_permutation = np.array(
             [
-                np.random.permutation(np.arange(num_samples))
+                rng.permutation(np.arange(num_samples))
                 for serie in range(len(parent_samples))
             ]
         )
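The hunks above replace np.random.RandomState and the module-level np.random.* calls with one seeded np.random.default_rng Generator per get_samples call, so every draw comes from the same reproducible stream and is isolated from global state. A small standalone sketch of the behaviour this relies on (illustrative only, not library code):

import numpy as np

# Two Generators built from the same seed produce identical draws,
# and perturbing the legacy global state does not affect them.
rng_a = np.random.default_rng(123)
rng_b = np.random.default_rng(123)

np.random.seed(0)                                    # unrelated global state
draws_a = rng_a.normal(loc=0.0, scale=1.0, size=5)
np.random.seed(999)                                  # changed again, still irrelevant
draws_b = rng_b.normal(loc=0.0, scale=1.0, size=5)

assert np.array_equal(draws_a, draws_b)              # reproducible per seed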

Diff for: nbs/src/core.ipynb (+11 -8)

@@ -150,7 +150,8 @@
 "                              model_name: str,\n",
 "                              id_col: str = \"unique_id\",\n",
 "                              time_col: str = \"ds\",\n",
-"                              target_col: str = \"y\") -> np.ndarray:\n",
+"                              target_col: str = \"y\",\n",
+"                              num_samples: int = 200) -> np.ndarray:\n",
 "    \"\"\"\n",
 "    This function assumes that the model creates prediction intervals\n",
 "    under a normality with the following the Equation:\n",

@@ -179,7 +180,7 @@
 "        sign = -1 if 'lo' in pi_col else 1\n",
 "        level_cols = re.findall('[\\d]+[.,\\d]+|[\\d]*[.][\\d]+|[\\d]+', pi_col)\n",
 "        level_col = float(level_cols[-1])\n",
-"        z = norm.ppf(0.5 + level_col / 200)\n",
+"        z = norm.ppf(0.5 + level_col / num_samples)\n",
 "        sigmah = Y_hat_df[pi_col].to_numpy().reshape(n_series,-1)\n",
 "        sigmah = sign * (sigmah - y_hat) / z\n",
 "\n",

@@ -476,14 +477,17 @@
 "                reconciler_args['y_hat_insample'] = y_hat_insample\n",
 "\n",
 "            if level is not None:\n",
+"                reconciler_args['intervals_method'] = intervals_method\n",
+"                reconciler_args['num_samples'] = 200\n",
+"                reconciler_args['seed'] = seed\n",
+"\n",
 "                if intervals_method in ['normality', 'permbu']:\n",
 "                    sigmah = _reverse_engineer_sigmah(Y_hat_df=Y_hat_nw,\n",
-"                                                      y_hat=y_hat, model_name=model_name)\n",
+"                                                      y_hat=y_hat, model_name=model_name,\n",
+"                                                      id_col=id_col, time_col=time_col,\n",
+"                                                      target_col=target_col, num_samples=reconciler_args['num_samples'])\n",
 "                    reconciler_args['sigmah'] = sigmah\n",
 "\n",
-"                reconciler_args['intervals_method'] = intervals_method\n",
-"                reconciler_args['num_samples'] = 200 # TODO: solve duplicated num_samples\n",
-"                reconciler_args['seed'] = seed\n",
 "\n",
 "            # Mean and Probabilistic reconciliation\n",
 "            kwargs_ls = [key for key in signature(reconciler.fit_predict).parameters if key in reconciler_args.keys()]\n",

@@ -513,11 +517,10 @@
 "                if num_samples > 0:\n",
 "                    samples = reconciler.sample(num_samples=num_samples)\n",
 "                    self.sample_names[recmodel_name] = [f'{recmodel_name}-sample-{i}' for i in range(num_samples)]\n",
-"                    samples = np.reshape(samples, (len(Y_tilde_nw),-1)) \n",
+"                    samples = np.reshape(samples, (len(Y_tilde_nw),-1))\n",
 "                    y_tilde = dict(zip(self.sample_names[recmodel_name], samples.T))\n",
 "                    Y_tilde_nw = Y_tilde_nw.with_columns(**y_tilde)\n",
 "        \n",
-"\n",
 "            end = time.time()\n",
 "            self.execution_times[f'{model_name}/{reconcile_fn_name}'] = (end - start)\n",
 "\n",

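As a side note on the kwargs_ls context line kept above: reconcile only forwards the entries of reconciler_args that the reconciler's fit_predict actually declares, so setting intervals_method, num_samples and seed up front is harmless for reconcilers that do not accept them. A hedged sketch of that filtering pattern with a made-up DummyReconciler (not the library's class):

from inspect import signature

class DummyReconciler:
    # Hypothetical stand-in: fit_predict accepts only a subset of the args.
    def fit_predict(self, S, y_hat, sigmah=None, num_samples=200, seed=None):
        return {"num_samples": num_samples, "seed": seed}

reconciler = DummyReconciler()
reconciler_args = {
    "S": "S-matrix", "y_hat": "forecasts", "sigmah": "scales",
    "intervals_method": "permbu", "num_samples": 200, "seed": 0,
}

# Keep only the keys fit_predict declares, mirroring the kwargs_ls line above.
kwargs_ls = [k for k in signature(reconciler.fit_predict).parameters if k in reconciler_args]
kwargs = {k: reconciler_args[k] for k in kwargs_ls}
print(reconciler.fit_predict(**kwargs))   # {'num_samples': 200, 'seed': 0}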
Diff for: nbs/src/probabilistic_methods.ipynb (+9 -9)

@@ -139,14 +139,14 @@
 "        **Returns:**<br>\n",
 "        `samples`: Coherent samples of size (`base`, `horizon`, `num_samples`).\n",
 "        \"\"\"\n",
-"        state = np.random.RandomState(self.seed)\n",
+"        rng = np.random.default_rng(self.seed)\n",
 "        n_series, n_horizon = self.y_hat.shape\n",
 "        samples = np.empty(shape=(num_samples, n_series, n_horizon))\n",
 "        for t in range(n_horizon):\n",
 "            with warnings.catch_warnings():\n",
 "                # Avoid 'RuntimeWarning: covariance is not positive-semidefinite.'\n",
 "                # By definition the multivariate distribution is not full-rank\n",
-"                partial_samples = state.multivariate_normal(mean=self.SP @ self.y_hat[:,t],\n",
+"                partial_samples = rng.multivariate_normal(mean=self.SP @ self.y_hat[:,t],\n",
 "                                                            cov=self.cov_rec[t], size=num_samples)\n",
 "            samples[:,:,t] = partial_samples\n",
 "\n",

@@ -273,8 +273,8 @@
 "        #removing nas from residuals\n",
 "        residuals = residuals[:, np.isnan(residuals).sum(axis=0) == 0]\n",
 "        sample_idx = np.arange(residuals.shape[1] - h)\n",
-"        state = np.random.RandomState(self.seed)\n",
-"        samples_idx = state.choice(sample_idx, size=num_samples)\n",
+"        rng = np.random.default_rng(self.seed)\n",
+"        samples_idx = rng.choice(sample_idx, size=num_samples)\n",
 "        samples = [self.y_hat + residuals[:, idx:(idx + h)] for idx in samples_idx]\n",
 "        SP = self.S @ self.P\n",
 "        samples = np.apply_along_axis(lambda path: np.matmul(SP, path),\n",

@@ -488,19 +488,19 @@
 "            num_samples = residuals.shape[1]\n",
 "\n",
 "        # Expand residuals to match num_samples [(a,b),T] -> [(a,b),num_samples]\n",
+"        rng = np.random.default_rng(self.seed)\n",
 "        if num_samples > residuals.shape[1]:\n",
-"            residuals_idxs = np.random.choice(residuals.shape[1], size=num_samples)\n",
+"            residuals_idxs = rng.choice(residuals.shape[1], size=num_samples)\n",
 "        else:\n",
-"            residuals_idxs = np.random.choice(residuals.shape[1], size=num_samples,\n",
+"            residuals_idxs = rng.choice(residuals.shape[1], size=num_samples,\n",
 "                                              replace=False)\n",
 "        residuals = residuals[:,residuals_idxs]\n",
 "        rank_permutations = self._obtain_ranks(residuals)\n",
 "\n",
-"        state = np.random.RandomState(self.seed)\n",
 "        n_series, n_horizon = self.y_hat.shape\n",
 "\n",
 "        base_samples = np.array([\n",
-"            state.normal(loc=m, scale=s, size=num_samples) for m, s in \\\n",
+"            rng.normal(loc=m, scale=s, size=num_samples) for m, s in \\\n",
 "            zip(self.y_hat.flatten(), self.sigmah.flatten())\n",
 "        ])\n",
 "        base_samples = base_samples.reshape(n_series, n_horizon, num_samples)\n",

@@ -536,7 +536,7 @@
 "        # and randomly shuffle parent predictions after aggregation\n",
 "        parent_samples = np.einsum('ab,bhs->ahs', Agg, children_samples)\n",
 "        random_permutation = np.array([\n",
-"            np.random.permutation(np.arange(num_samples)) \\\n",
+"            rng.permutation(np.arange(num_samples)) \\\n",
 "            for serie in range(len(parent_samples))\n",
 "        ])\n",
 "        parent_samples = self._permutate_predictions(\n",

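The last hunk routes PERMBU's per-series shuffle of the aggregated parent samples through the same seeded Generator as every other draw. Below is a rough, self-contained sketch of such a shuffle, using np.take_along_axis in place of the library's _permutate_predictions helper (shapes and names are illustrative, not library code):

import numpy as np

rng = np.random.default_rng(0)
n_parents, horizon, num_samples = 2, 3, 5
parent_samples = rng.normal(size=(n_parents, horizon, num_samples))

# One independent permutation of the sample axis per parent series,
# drawn from the same seeded Generator used for all other sampling.
random_permutation = np.array(
    [rng.permutation(np.arange(num_samples)) for _ in range(len(parent_samples))]
)
shuffled = np.take_along_axis(parent_samples, random_permutation[:, None, :], axis=-1)

print(shuffled.shape)   # (2, 3, 5); rerunning with seed 0 gives identical output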