[ENH] Added TCN forecaster in aeon/forecasting/deep_learning #2938

Open: wants to merge 51 commits into base: main
51 commits
af94a86
Basedeep forecaster added
lucifer4073 May 22, 2025
7bb161e
Merge upstream main to basedlf
lucifer4073 May 24, 2025
d2ee9ec
init for basedlf added
lucifer4073 May 26, 2025
ab3030c
test file and axis added for basedeepforecaster
lucifer4073 Jun 15, 2025
1f202db
test locally
lucifer4073 Jun 15, 2025
14eb41f
dlf corrected
lucifer4073 Jun 15, 2025
d1a2aab
tf soft dep added
lucifer4073 Jun 22, 2025
865ed14
Merge remote-tracking branch 'upstream/main' into basedlf
lucifer4073 Jun 22, 2025
5fb72c7
tcn network added
lucifer4073 Jul 6, 2025
3434757
tcn_net pytest added
lucifer4073 Jul 6, 2025
a73c5f7
Merge branch 'main' of https://github.com/aeon-toolkit/aeon into tcn_net
lucifer4073 Jul 6, 2025
f2f393d
Merge branch 'tcn_net' of https://github.com/lucifer4073/aeon into tc…
lucifer4073 Jul 6, 2025
c2b6231
Merge branch 'basedlf' of https://github.com/lucifer4073/aeon into tc…
lucifer4073 Jul 6, 2025
c602e39
tcn_network updated with default params
lucifer4073 Jul 6, 2025
ad2fc01
Merge branch 'tcn_net' of https://github.com/lucifer4073/aeon into tc…
lucifer4073 Jul 6, 2025
05a0f35
TCN forecaster added
lucifer4073 Jul 7, 2025
2f3c98b
tcn reshaped
lucifer4073 Jul 7, 2025
dd5b014
Merge branch 'main' of https://github.com/aeon-toolkit/aeon into tcn_fst
lucifer4073 Jul 7, 2025
e630a99
Merge branch 'tcn_net' of https://github.com/lucifer4073/aeon into tc…
lucifer4073 Jul 7, 2025
f6447b1
tcn changed
lucifer4073 Jul 8, 2025
30d862a
base fst changed
lucifer4073 Jul 8, 2025
135a98d
Merge branch 'tcn_net' of https://github.com/lucifer4073/aeon into tc…
lucifer4073 Jul 8, 2025
9b9d266
TCN forecaster updated
lucifer4073 Jul 8, 2025
78b2f3d
test file corrected
lucifer4073 Jul 8, 2025
79fe3e2
Merge branch 'basedlf' of https://github.com/lucifer4073/aeon into tc…
lucifer4073 Jul 8, 2025
49be666
tcn updated
lucifer4073 Jul 8, 2025
7bacdac
tcn updated
lucifer4073 Jul 8, 2025
9a1b878
tcnfst updated with net
lucifer4073 Jul 8, 2025
08dadec
doctest corrected
lucifer4073 Jul 8, 2025
b167479
merge tcn_net
lucifer4073 Jul 8, 2025
d1a7fd0
Merge branch 'main' into basedlf
lucifer4073 Jul 8, 2025
086c5a4
changes made
lucifer4073 Jul 13, 2025
a1f68cd
Merge branch 'main' into tcn_fst
lucifer4073 Jul 13, 2025
b6ccd07
basedelf updated
lucifer4073 Jul 20, 2025
a39fafb
Merge branch 'basedlf' of https://github.com/lucifer4073/aeon into ba…
lucifer4073 Jul 20, 2025
f7fd5bd
Merge branch 'main' into basedlf
lucifer4073 Jul 20, 2025
405fa80
test base chanegd
lucifer4073 Jul 20, 2025
5cb1523
tcn rshaped
lucifer4073 Jul 21, 2025
f7bc502
Merge remote-tracking branch 'upstream/main' into tcn_net
lucifer4073 Jul 21, 2025
758dd38
Merge remote-tracking branch 'upstream/main' into tcn_fst
lucifer4073 Jul 21, 2025
98f42ae
Merge remote-tracking branch 'origin/tcn_net' into tcn_fst
lucifer4073 Jul 21, 2025
9e65cd5
Merge remote-tracking branch 'origin/basedlf' into tcn_fst
lucifer4073 Jul 21, 2025
2ab68c9
tcn fst updated
lucifer4073 Jul 22, 2025
7727f20
Merge branch 'tcn_fst' of https://github.com/lucifer4073/aeon into tc…
lucifer4073 Jul 22, 2025
6893a3d
merged main to tcn_fst
lucifer4073 Aug 17, 2025
4c6b789
tcn forecaster updated with new base class
lucifer4073 Aug 19, 2025
5aae69d
Merge branch 'main' into tcn_fst
lucifer4073 Aug 19, 2025
4aad1b6
workflow corrected
lucifer4073 Aug 19, 2025
7387ea2
Merge branch 'tcn_fst' of https://github.com/lucifer4073/aeon into tc…
lucifer4073 Aug 19, 2025
6c4dca7
excluded forecasting test for tcn
lucifer4073 Aug 19, 2025
9357e2f
Merge branch 'main' into tcn_fst
lucifer4073 Aug 19, 2025
9 changes: 9 additions & 0 deletions aeon/forecasting/deep_learning/__init__.py
@@ -0,0 +1,9 @@
"""Initialization for aeon forecasting deep learning module."""

__all__ = [
    "BaseDeepForecaster",
    "TCNForecaster",
]

from aeon.forecasting.deep_learning._tcn import TCNForecaster
from aeon.forecasting.deep_learning.base import BaseDeepForecaster
251 changes: 251 additions & 0 deletions aeon/forecasting/deep_learning/_tcn.py
@@ -0,0 +1,251 @@
"""TCNForecaster module for deep learning forecasting in aeon."""

from __future__ import annotations

__maintainer__ = []
__all__ = ["TCNForecaster"]

from typing import Any

import numpy as np
from sklearn.utils import check_random_state

from aeon.forecasting.base import DirectForecastingMixin
from aeon.forecasting.deep_learning.base import BaseDeepForecaster
from aeon.networks._tcn import TCNNetwork


class TCNForecaster(BaseDeepForecaster, DirectForecastingMixin):
    """A deep learning forecaster using Temporal Convolutional Network (TCN).

    It leverages the `TCNNetwork` from aeon's network module
Member

The following parameters are missing: save_last_model, save_init_model, best_model_name, and init_model_name. Please check any example in aeon/classification/deep_learning to see how to use them.

    to build the architecture suitable for forecasting tasks.

    Parameters
    ----------
    horizon : int, default=1
        Forecasting horizon, the number of steps ahead to predict.
    window : int, default=10
        The window size for creating input sequences.
    batch_size : int, default=32
        Batch size for training the model.
    n_epochs : int, default=100
        Number of epochs to train the model.
    verbose : int, default=0
        Verbosity mode (0, 1, or 2).
    optimizer : str or tf.keras.optimizers.Optimizer, default='adam'
        Optimizer to use for training.
    loss : str or tf.keras.losses.Loss, default='mse'
        Loss function for training.
    callbacks : list of tf.keras.callbacks.Callback or None, default=None
        List of Keras callbacks to be applied during training.
    random_state : int, default=None
        Seed for random number generators.
    axis : int, default=0
        Axis along which to apply the forecaster.
    last_file_name : str, default="last_model"
        The name of the file of the last model, used for saving models.
    save_best_model : bool, default=False
        Whether to save the best model during training based on validation loss.
    file_path : str, default="./"
        Directory path where models will be saved.
    n_blocks : list of int, default=None
        List specifying the number of output channels for each layer of the
        TCN. The length determines the depth of the network. If None,
        [16, 16, 16] is used.
    kernel_size : int, default=2
        Size of the convolutional kernel in the TCN.
    dropout : float, default=0.2
        Dropout rate applied after each convolutional layer for
        regularization.
    """

    _tags = {
        "python_dependencies": ["tensorflow"],
        "capability:horizon": True,
        "capability:multivariate": True,
        "capability:exogenous": False,
        "capability:univariate": True,
        "algorithm_type": "deeplearning",
        "non_deterministic": True,
        "cant_pickle": True,
    }

    def __init__(
        self,
        horizon=1,
        window=10,
        batch_size=32,
        n_epochs=100,
        verbose=0,
        optimizer="adam",
        loss="mse",
        callbacks=None,
        random_state=None,
        axis=0,
        last_file_name="last_model",
        save_best_model=False,
        file_path="./",
        n_blocks=None,
        kernel_size=2,
        dropout=0.2,
    ):
        super().__init__(
            horizon=horizon,
            window=window,
            verbose=verbose,
            callbacks=callbacks,
            axis=axis,
            last_file_name=last_file_name,
            save_best_model=save_best_model,
            file_path=file_path,
        )
        self.n_blocks = n_blocks
        self.kernel_size = kernel_size
        self.dropout = dropout
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.optimizer = optimizer
        self.loss = loss
        self.random_state = random_state

    def build_model(self, input_shape):
        """Build the TCN model for forecasting.

        Parameters
        ----------
        input_shape : tuple
            Shape of input data, typically (window, num_inputs).

        Returns
        -------
        model : tf.keras.Model
            Keras model with TCN architecture. The model is returned
            uncompiled; compilation happens in ``_fit``.
        """
        import tensorflow as tf

        network = TCNNetwork(
            n_blocks=self.n_blocks if self.n_blocks is not None else [16, 16, 16],
            kernel_size=self.kernel_size,
            dropout=self.dropout,
        )
        input_layer, output = network.build_network(input_shape=input_shape)
        model = tf.keras.Model(inputs=input_layer, outputs=output)
        return model

    def _fit(self, y, exog=None):
Member

Please take a look at this example from deep classification in aeon to see how to handle saving the best, last, and init models and their file names. Please do exactly the same here, especially the loading of the best model at the end, where we use try/except:

def _fit(self, X, y):
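
The try/except fallback mentioned in the review comment above can be sketched without Keras as follows. This is a minimal illustration, not aeon's implementation: `load_best_or_fallback` and `read_text_model` are hypothetical helpers, a text file stands in for a saved model, and catching `FileNotFoundError` is an assumption (the exception a real model loader raises for a missing file may differ).

```python
import os
import tempfile


def load_best_or_fallback(best_path, load_fn, fallback_model):
    """Return the checkpointed model if its file exists, else the fallback.

    Hypothetical helper: ``load_fn`` is any callable that loads a model
    from a path and raises FileNotFoundError when the file is missing.
    """
    try:
        return load_fn(best_path)
    except FileNotFoundError:
        # No checkpoint was written, so keep the model held in memory.
        return fallback_model


def read_text_model(path):
    """Stand-in loader; a real forecaster would load a Keras model here."""
    with open(path, encoding="utf-8") as handle:
        return handle.read()


with tempfile.TemporaryDirectory() as tmp_dir:
    best_path = os.path.join(tmp_dir, "best_model.txt")

    # Case 1: no checkpoint on disk, so the in-memory model is used.
    from_memory = load_best_or_fallback(
        best_path, read_text_model, "in-memory model"
    )

    # Case 2: a checkpoint exists, so it takes precedence.
    with open(best_path, "w", encoding="utf-8") as handle:
        handle.write("checkpointed model")
    from_disk = load_best_or_fallback(
        best_path, read_text_model, "in-memory model"
    )

print(from_memory)  # in-memory model
print(from_disk)  # checkpointed model
```

The point of the pattern is that prediction always has a usable model, whether or not a best-model checkpoint was actually written during training.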

        """Fit the forecaster to training data.

        Parameters
        ----------
        y : np.ndarray or pd.Series
            Target time series to which to fit the forecaster.
        exog : np.ndarray, default=None
            Ignored. Exogenous variables are not supported by this forecaster.

        Returns
        -------
        self : TCNForecaster
            Returns an instance of self.
        """
        import tensorflow as tf

        rng = check_random_state(self.random_state)
        self.random_state_ = rng.randint(0, np.iinfo(np.int32).max)
        tf.keras.utils.set_random_seed(self.random_state_)
        y_inner = y
        num_timepoints, num_channels = y_inner.shape
        num_sequences = num_timepoints - self.window - self.horizon + 1
        if y_inner.shape[0] < self.window + self.horizon:
            raise ValueError(
                f"Data length ({y_inner.shape[0]}) is insufficient for window "
                f"({self.window}) and horizon ({self.horizon})."
            )
        windows_full = np.lib.stride_tricks.sliding_window_view(
            y_inner, window_shape=(self.window, num_channels)
        )
        windows_full = np.squeeze(windows_full, axis=1)
        X_train = windows_full[:num_sequences]
        # print(f"Shape of X_train is {X_train.shape}")
        tail = y_inner[self.window :]
        y_windows = np.lib.stride_tricks.sliding_window_view(
            tail, window_shape=(self.horizon, num_channels)
        )
        y_windows = np.squeeze(y_windows, axis=1)
        y_train = y_windows[:num_sequences]
        # print(f"Shape of y_train is {y_train.shape}")
        input_shape = X_train.shape[1:]
        self.model_ = self.build_model(input_shape)
        self.model_.compile(optimizer=self.optimizer, loss=self.loss)
        callbacks_list = self._prepare_callbacks()
        self.history_ = self.model_.fit(
            X_train,
            y_train,
            batch_size=self.batch_size,
            epochs=self.n_epochs,
            verbose=self.verbose,
            callbacks=callbacks_list,
        )
        self.last_window_ = y_inner[-self.window :]
        return self

    def _predict(self, y=None, exog=None):
        """Make forecasts for y.

        Parameters
        ----------
        y : np.ndarray or pd.Series, default=None
            Series to predict from. If None, uses last fitted window.
        exog : np.ndarray, default=None
            Ignored. Exogenous variables are not supported by this forecaster.

        Returns
        -------
        predictions : np.ndarray
            Predicted values for the specified horizon. Since TCN has single
            horizon capability, returns single step prediction.
        """
        if y is None:
            if not hasattr(self, "last_window_"):
                raise ValueError("No fitted data available for prediction.")
            y_inner = self.last_window_
        else:
            y_inner = y
            if y_inner.ndim == 1:
                y_inner = y_inner.reshape(-1, 1)
            if y_inner.shape[0] < self.window:
                raise ValueError(
                    f"Input data length ({y_inner.shape[0]}) is less than the "
                    f"window size ({self.window})."
                )
            y_inner = y_inner[-self.window :]
        num_channels = y_inner.shape[-1]
        last_window = y_inner.reshape(1, self.window, num_channels)
        pred = self.model_.predict(last_window, verbose=0)
        if num_channels == 1:
            prediction = pred.flatten()[0]
        else:
            prediction = pred[0, :]
        return prediction

    @classmethod
    def _get_test_params(
        cls, parameter_set: str = "default"
    ) -> dict[str, Any] | list[dict[str, Any]]:
        """
        Return testing parameter settings for the estimator.

        Parameters
        ----------
        parameter_set : str, default="default"
            Name of the set of test parameters to return, for use in tests. If no
            special parameters are defined for a value, will return "default" set.

        Returns
        -------
        params : dict or list of dict, default={}
            Parameters to create testing instances of the class.
        """
        param = {
            "n_epochs": 10,
            "batch_size": 4,
            "n_blocks": [8, 8],
            "kernel_size": 2,
            "dropout": 0.1,
        }
        return [param]
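
For review purposes, the windowing that `_fit` performs can be checked in isolation. The following is a minimal sketch using only NumPy; `make_windows` is a hypothetical helper written for this illustration and is not part of aeon. It assumes the series has shape (n_timepoints, n_channels), as `_fit` does, and mirrors the two `sliding_window_view` calls in the diff.

```python
import numpy as np


def make_windows(y, window, horizon):
    """Split a (n_timepoints, n_channels) series into input/target windows.

    Hypothetical helper mirroring the windowing logic in ``_fit``.
    """
    y = np.asarray(y)
    if y.ndim == 1:
        # Treat a 1D series as a single channel.
        y = y.reshape(-1, 1)
    n_timepoints, n_channels = y.shape
    if n_timepoints < window + horizon:
        raise ValueError(
            f"Data length ({n_timepoints}) is insufficient for window "
            f"({window}) and horizon ({horizon})."
        )
    n_sequences = n_timepoints - window - horizon + 1

    # Input windows: shape (n_timepoints - window + 1, 1, window, n_channels)
    # before the singleton axis is squeezed away.
    inputs = np.lib.stride_tricks.sliding_window_view(
        y, window_shape=(window, n_channels)
    )
    inputs = np.squeeze(inputs, axis=1)

    # Target windows start right after the first input window ends.
    targets = np.lib.stride_tricks.sliding_window_view(
        y[window:], window_shape=(horizon, n_channels)
    )
    targets = np.squeeze(targets, axis=1)

    return inputs[:n_sequences], targets[:n_sequences]


series = np.arange(12).reshape(6, 2)  # 6 time points, 2 channels
X_train, y_train = make_windows(series, window=3, horizon=1)
print(X_train.shape, y_train.shape)  # (3, 3, 2) (3, 1, 2)
```

Each row of `X_train` holds `window` consecutive time points, and the matching row of `y_train` holds the `horizon` time points that immediately follow it, so the target array has shape (n_sequences, horizon, n_channels).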