Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
259 changes: 259 additions & 0 deletions prototype/experiment.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,259 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from dataclasses import dataclass\n",
"from typing import Union, Optional\n",
"import numpy as np\n",
"import pandas as pd"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"@dataclass\n",
"class TimeConfig:\n",
" \"\"\"Container for storing and accessing time configuration.\"\"\"\n",
" start_year: str\n",
" end_year: str\n",
" target_start: str = '11-01'\n",
" target_end: str = '12-01'\n",
" freq: Optional[str] = '2M' # daily/monthly\n",
" # tfreq: Union[int, str] = 7\n",
" \n",
" def get_dates_load(self):\n",
" \"\"\"Return index which lines up with the bins\n",
" Leap days are removed.\n",
"\n",
" (see the function \"timeseries_tofit_bins\" in line 628 of functions_pp)\n",
"\n",
" return pd.DateTimeIndex\n",
" \"\"\"\n",
"\n",
" return\n",
"\n",
" def get_resample_bins(self):\n",
" \"\"\"Return bins to aggregate to n day/ n month means\n",
" This function could work with the precursor and target variables.\n",
" \"\"\"\n",
"\n",
" return\n",
"\n",
" def split_train_test_groups(self):\n",
" \"\"\"\n",
" \"\"\"\n",
" return"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tc = TimeConfig(start_year, end_year, target_start, target_end, freq)\n",
"tc.dates2bin # time passed to data loader\n",
"tc.labelbins # same axis as dates2bin, but has labels to group by n day/ n month means\n",
"tc.traintestgroups # smart splitting of train/test data and avoid overlapping\n",
"tc.df_split # aggregated time axis pandas dataframe with columns train=true and RV_mask"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"@dataclass\n",
"class TrainTestSet:\n",
" \"\"\"Base class for various traintest methods.\"\"\"\n",
" df: pd.DataFrame\n",
" \n",
" def get_test_labels(self):\n",
" return self.df.query('traintest==1')\n",
" \n",
" def get_training_labels(self):\n",
" return self.df.query('traintest==2')\n",
"\n",
"class LeaveOutN(TrainTestSet):\n",
" \"\"\"TrainTestSet based on a leave-n-out sampling method.\"\"\"\n",
" def __init__(self, timeconfig, n, max_lag=None):\n",
" times = timeconfig.datetimes\n",
" df = pd.DataFrame(index=times, columns=[\"traintest\"])\n",
" df['traintest'] = np.random.randint(0, 3, len(df))\n",
" self.df = df\n",
"\n",
"class Random(TrainTestSet):\n",
" \"\"\"TrainTestSet based on a random sampling method.\"\"\"\n",
" n: int\n",
"\n",
"class Ranstrat(TrainTestSet):\n",
" \"\"\"TrainTestSet based on a random stratified sampling method.\"\"\"\n",
" n: int\n",
"\n",
"class split(TrainTestSet):\n",
" \"\"\"TrainTestSet based on a simple split method.\"\"\"\n",
" n: int\n",
"\n",
"class TimeSeriesSplit(TrainTestSet):\n",
" \"\"\"TrainTestSet based on a \"one-step-ahead\" method.\"\"\"\n",
" n: int\n",
"\n",
"class RepeatedKfold(TrainTestSet):\n",
" \"\"\"TrainTestSet based on a repeated k-fold with different randomizations.\"\"\"\n",
" n_repeats: int\n",
" n_folds: int"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>traintest</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>1979-01-03</th>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1979-01-09</th>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1979-01-16</th>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1979-01-18</th>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1979-01-19</th>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2020-12-08</th>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2020-12-14</th>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2020-12-24</th>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2020-12-29</th>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2020-12-30</th>\n",
" <td>2</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5082 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" traintest\n",
"1979-01-03 2\n",
"1979-01-09 2\n",
"1979-01-16 2\n",
"1979-01-18 2\n",
"1979-01-19 2\n",
"... ...\n",
"2020-12-08 2\n",
"2020-12-14 2\n",
"2020-12-24 2\n",
"2020-12-29 2\n",
"2020-12-30 2\n",
"\n",
"[5082 rows x 1 columns]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tc = TimeConfig(start_year = '1979', end_year = '2021', freq='D')\n",
"ttset = LeaveOutN(tc, n=5)\n",
"ttset.get_training_labels()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class CrossValidator():\n",
" \"\"\"Perform cross-validation of training/testing data\"\"\"\n",
" def __init__(self):"
]
}
],
"metadata": {
"interpreter": {
"hash": "08d8a5f855c7a1e3d1e204794a733da246b168cd162c0e6afd9e496498bf62ad"
},
"kernelspec": {
"display_name": "Python 3.8.12 ('eucp')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.12"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
89 changes: 89 additions & 0 deletions prototype/prototype.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import S2Sinit\n",
"import climpp\n",
"import pandas as pd"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Initialize the workflow."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# define parameters for time config\n",
"start_year= 1979\n",
"end_year = 2020\n",
"target_start = '11-01'\n",
"target_end = '12-01'\n",
"freq = '2M' # two month means, with pandas\n",
"# define parameters for cross-validator\n",
"max_lag = 5\n",
"k = 10\n",
"\n",
"# this function only allows single target period\n",
"# data on the same coordinate\n",
"initializer = S2Sinit(start_year, end_year, target_start,\n",
" target_end, freq)\n",
" \n",
"# ------------------ optional procedure\n",
"# call cross validator\n",
"initializer.cv.kfold(max_lag, k) # only specify parameters for cross-validation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Preprocessing."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# define parameters for preprocessing\n",
"path2target = 'path/target_var'\n",
"path2sst = 'path/sst'\n",
"path2z500 = 'path/z500'\n",
"# preprocess recipe for target variable\n",
"preprocess_steps_target = ['detrend', 'deseasonalize', 'standardize']\n",
"# preprocess recipe for precursors\n",
"preprocess_steps_sst = ['detrend', 'deseasonalize', 'standardize']\n",
"preprocess_steps_z500 = ['detrend', 'deseasonalize']\n",
"\n",
"# preprocess target\n",
"target_instance = climpp(initializer, path2target, preprocess_steps_target)\n",
"target_instance = climpp(initializer, path2target, preprocess_steps_target)\n",
"# preprocess precursor sst\n",
"sst_instance = climpp(initializer, path2sst, preprocess_steps_sst)\n",
"z500_instance = climpp(initializer, path2z500, preprocess_steps_z500)\n",
"\n",
"# execute preprocessing steps\n"
]
}
],
"metadata": {
"language_info": {
"name": "python"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}