Skip to content

Commit

Permalink
refactor(model): transformers style model call
Browse files Browse the repository at this point in the history
  • Loading branch information
LongxingTan authored Jan 17, 2025
1 parent b7658d2 commit 817456a
Show file tree
Hide file tree
Showing 38 changed files with 893 additions and 857 deletions.
12 changes: 7 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
**[Documentation](https://time-series-prediction.readthedocs.io)** | **[Tutorials](https://time-series-prediction.readthedocs.io/en/latest/tutorials.html)** | **[Release Notes](https://time-series-prediction.readthedocs.io/en/latest/CHANGELOG.html)** | **[中文](https://github.com/LongxingTan/Time-series-prediction/blob/master/README_CN.md)**

**TFTS** (TensorFlow Time Series) is an easy-to-use time series package, supporting the classical and latest deep learning methods in TensorFlow or Keras.
- Support sota performance for time series task (prediction, classification, anomaly detection)
- Support sota models for time series tasks (prediction, classification, anomaly detection)
- Provide advanced deep learning models for industry, research and competition
- Documentation lives at [time-series-prediction.readthedocs.io](https://time-series-prediction.readthedocs.io)

Expand All @@ -55,18 +55,19 @@ pip install tfts

```python
import matplotlib.pyplot as plt
import tensorflow as tf
import tfts
from tfts import AutoModel, AutoConfig, KerasTrainer

train_length = 24
predict_sequence_length = 8
(x_train, y_train), (x_valid, y_valid) = tfts.get_data("sine", train_length, predict_sequence_length, test_size=0.2)

model_name_or_path = 'seq2seq' # 'wavenet', 'transformer'
model_name_or_path = 'seq2seq' # 'wavenet', 'transformer', 'rnn', 'tcn', 'bert', 'dlinear', 'nbeats', 'informer', 'autoformer'
config = AutoConfig.for_model(model_name_or_path)
model = AutoModel.from_config(config, predict_sequence_length=predict_sequence_length)
trainer = KerasTrainer(model)
trainer.train((x_train, y_train), (x_valid, y_valid), epochs=15)
trainer = KerasTrainer(model, optimizer=tf.keras.optimizers.Adam(0.0007))
trainer.train((x_train, y_train), (x_valid, y_valid), epochs=30)

pred = trainer.predict(x_valid)
trainer.plot(history=x_valid, true=y_valid, pred=pred)
Expand Down Expand Up @@ -202,11 +203,12 @@ model = AutoModel.from_config(config, predict_sequence_length=7)
- tcn
- bert
- nbeats
- dlinear
- seq2seq
- wavenet
- transformer
- informer
- dlinear
- autoformer

</details>

Expand Down
28 changes: 15 additions & 13 deletions README_CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@
青山遮不住,毕竟东流去。江晚正愁余,山深闻鹧鸪。<br>
**东流TFTS** (TensorFlow Time Series) 是一个高效易用的时间序列框架,基于TensorFlow/ Keras。

- 为多种时间序列任务提供SOTA的深度学习模型,预测、分类、异常检测
- 提供经典与前沿的深度学习模型,用于工业、科研、竞赛
- 查阅[英文文档](https://time-series-prediction.readthedocs.io),快速入门。欢迎移步[时序讨论区](https://github.com/LongxingTan/Time-series-prediction/discussions)
- 为多种时间序列任务(单步与多步预测、分类、异常检测等)提供SOTA的深度学习模型
- 提供经典与前沿的深度学习模型,可用于工业、科研、竞赛
- 查阅[文档](https://time-series-prediction.readthedocs.io),快速入门。欢迎移步[时序讨论区](https://github.com/LongxingTan/Time-series-prediction/discussions)


## 快速使用
Expand All @@ -57,22 +57,23 @@ pip install tfts

```python
import matplotlib.pyplot as plt
import tensorflow as tf
import tfts
from tfts import AutoModel, KerasTrainer, Trainer, AutoConfig
from tfts import AutoModel, AutoConfig, KerasTrainer

train_length = 24
predict_sequence_length = 8

# 其中,train是包含(x_train, y_train)的tuple, valid包含(x_valid, y_valid)
train, valid = tfts.get_data('sine', train_length, predict_sequence_length, test_size=0.2)
config = AutoConfig.for_model("seq2seq") # 'wavenet', 'transformer'
model = AutoModel.from_config(config, predict_sequence_length)
(x_train, y_train), (x_valid, y_valid) = tfts.get_data("sine", train_length, predict_sequence_length, test_size=0.2)

trainer = KerasTrainer(model)
trainer.train(train, valid, epochs=15)
model_name_or_path = 'seq2seq' # 'wavenet', 'transformer', 'rnn', 'tcn', 'bert', 'dlinear', 'nbeats', 'informer', 'autoformer'
config = AutoConfig.for_model(model_name_or_path)
model = AutoModel.from_config(config, predict_sequence_length=predict_sequence_length)
trainer = KerasTrainer(model, optimizer=tf.keras.optimizers.Adam(0.0007))
trainer.train((x_train, y_train), (x_valid, y_valid), epochs=30)

pred = trainer.predict(valid[0])
trainer.plot(history=valid[0], true=valid[1], pred=pred)
pred = trainer.predict(x_valid)
trainer.plot(history=x_valid, true=y_valid, pred=pred)
plt.show()
```

Expand Down Expand Up @@ -203,12 +204,13 @@ model = AutoModel.from_config(config, predict_sequence_length=7)
- rnn
- tcn
- bert
- dlinear
- nbeats
- seq2seq
- wavenet
- transformer
- informer
- dlinear
- autoformer

</details>

Expand Down
8 changes: 3 additions & 5 deletions docs/source/quick-start.rst
Original file line number Diff line number Diff line change
Expand Up @@ -50,21 +50,20 @@ The general setup for training and testing a model is
import tfts
from tfts import AutoConfig, AutoModel, KerasTrainer
# load data
train_length = 36
predict_sequence_length = 12
train, valid = tfts.get_data('sine', train_length, predict_sequence_length)
# build model
# build model: 'seq2seq', 'wavenet', 'transformer', 'rnn', 'tcn', 'bert'
model_name_or_path = 'seq2seq'
config = AutoConfig.for_model(model_name_or_path)
model = AutoModel.from_config(config, predict_sequence_length=predict_sequence_length)
# train
opt = tf.keras.optimizers.Adam(0.003)
opt = tf.keras.optimizers.Adam(0.001)
loss_fn = tf.keras.losses.MeanSquaredError()
trainer = KerasTrainer(model, loss_fn=loss_fn, optimizer=opt)
trainer.train(train, valid, epochs=10, batch_size=32)
trainer.train(train, valid, epochs=30, batch_size=32)
# test
trainer.predict(valid[0])
Expand All @@ -78,7 +77,6 @@ The general setup for training and testing a model is
Before training, ensure your raw data is preprocessed into a 3D format with the shape `(batch_size, train_steps, features)`. Perform any necessary data cleaning, normalization, or transformation steps to ensure the data is ready for training.



3.2 Train the Model
~~~~~~~~~~~~~~~~~~~~~~~~~~
When training the model, use appropriate loss functions, optimizers, and hyperparameters to achieve the best results.
Expand Down
9 changes: 9 additions & 0 deletions docs/source/tricks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@ Tricks

.. _tricks:

.. note::

Time series is a typical No Free lunch scenario


Use tfts in competition flexible
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand All @@ -21,6 +26,10 @@ There is no free launch, and it's impossible to forecast the future. So we shoul

skip connect. skip connect from ResNet is a special and common target transformation, tfts provides some basic skip connect in model config. If you want try more skip connect, please use ``AutoModel`` to make custom model.

* feature engineering

feature engineering is a art.

* different temporal scale

we can train different models from different scale
Expand Down
4 changes: 2 additions & 2 deletions tests/test_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def test_demo(self):
model = AutoModel.from_config(config, predict_sequence_length=predict_sequence_length)

trainer = Trainer(model)
trainer.train((x_train, y_train), (x_valid, y_valid), epochs=2)
trainer.train((x_train, y_train), (x_valid, y_valid), epochs=1)

pred = trainer.predict(x_valid)
# trainer.plot(history=x_valid, true=y_valid, pred=pred)
Expand All @@ -40,7 +40,7 @@ def test_demo2(self):
model = AutoModel.from_config(config=config, predict_sequence_length=predict_sequence_length)
print(x_train.shape, y_train.shape, x_valid.shape, y_valid.shape)

trainer = Trainer(model, optimizer=tf.keras.optimizers.legacy.Adam(0.003))
trainer = Trainer(model, optimizer=tf.keras.optimizers.legacy.Adam(0.001))
trainer.train((x_train, y_train), epochs=2)

pred = trainer.predict(x_valid)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_examples/test_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class args(object):
use_model = "rnn"
train_length = 10
predict_sequence_length = 5
epochs = 2
epochs = 1
batch_size = 32
learning_rate = 0.003

Expand Down
12 changes: 6 additions & 6 deletions tests/test_examples/test_tfts_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ def setUp(self):
self.test_models = ["seq2seq", "wavenet", "transformer", "rnn", "tcn", "bert", "informer"]

def test_encoder_array(self):
train_length = 49
predict_sequence_length = 10
train_length = 32
predict_sequence_length = 9
n_feature = 2
x_train = np.random.rand(1, train_length, n_feature)
y_train = np.random.rand(1, predict_sequence_length, 1)
Expand All @@ -31,8 +31,8 @@ def test_encoder_array(self):
trainer.train(train_dataset=(x_train, y_train), valid_dataset=(x_valid, y_valid), epochs=1)

def test_encoder_decoder_array(self):
train_length = 49
predict_sequence_length = 10
train_length = 32
predict_sequence_length = 9
n_encoder_feature = 2
n_decoder_feature = 3
x_train = {
Expand All @@ -55,8 +55,8 @@ def test_encoder_decoder_array(self):
trainer.train((x_train, y_train), (x_valid, y_valid), epochs=1)

def test_encoder_decoder_array2(self):
train_length = 49
predict_sequence_length = 10
train_length = 32
predict_sequence_length = 9
n_encoder_feature = 2
n_decoder_feature = 3

Expand Down
1 change: 0 additions & 1 deletion tests/test_models/test_informer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import tensorflow as tf
from tensorflow.keras.layers import LayerNormalization

import tfts
from tfts import AutoConfig, AutoModel, KerasTrainer, Trainer
from tfts.layers.attention_layer import Attention, ProbAttention
from tfts.models.informer import Decoder, DecoderLayer, DistilConv, Encoder, EncoderLayer, Informer
Expand Down
2 changes: 1 addition & 1 deletion tests/test_models/test_rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def test_train(self):
config = AutoConfig.for_model("rnn")
model = AutoModel.from_config(config, predict_sequence_length=8)
trainer = KerasTrainer(model, optimizer=tf.keras.optimizers.legacy.Adam(0.003))
trainer.train(train, valid, epochs=2)
trainer.train(train, valid, epochs=1)
y_test = trainer.predict(valid[0])
self.assertEqual(y_test.shape, valid[1].shape)

Expand Down
12 changes: 1 addition & 11 deletions tests/test_models/test_seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import tfts
from tfts import AutoModel, KerasTrainer, Trainer
from tfts.models.seq2seq import DecoderV1, DecoderV2, DecoderV3, Encoder, Seq2seq
from tfts.models.seq2seq import DecoderV1, DecoderV2, Encoder, Seq2seq


class Seq2seqTest(unittest.TestCase):
Expand Down Expand Up @@ -33,16 +33,6 @@ def test_decoder2(self):
y = layer(x, init_input, init_state)
self.assertEqual(y.shape, (2, predict_sequence_length, 1))

def test_decoder3(self):
rnn_type = "gru"
rnn_size = 32
layer = DecoderV3(rnn_type=rnn_type, rnn_size=rnn_size, predict_sequence_length=1)
x = tf.random.normal([2, 11, 1])
init_input = tf.random.normal([2, 1])
init_state = tf.random.normal([2, rnn_size])
y = layer(x, init_input, init_state)
self.assertEqual(y.shape, (2, 11, 1))

def test_model(self):
predict_sequence_length = 8
model = Seq2seq(predict_sequence_length=predict_sequence_length)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_models/test_tcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_train(self):
config = AutoConfig.for_model("tcn")
model = AutoModel.from_config(config=config, predict_sequence_length=8)
trainer = KerasTrainer(model, optimizer=tf.keras.optimizers.legacy.Adam(0.003))
trainer.train(train, valid, epochs=2)
trainer.train(train, valid, epochs=1)
y_test = trainer.predict(valid[0])
self.assertEqual(y_test.shape, valid[1].shape)

Expand Down
4 changes: 2 additions & 2 deletions tests/test_models/test_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import tfts
from tfts import AutoConfig, AutoModel, KerasTrainer, Trainer
from tfts.models.transformer import Decoder, Decoder2, Encoder, Transformer
from tfts.models.transformer import Decoder, Encoder, Transformer


class TransformerTest(unittest.TestCase):
Expand Down Expand Up @@ -70,6 +70,6 @@ def test_train(self):
config = AutoConfig.for_model("rnn")
model = AutoModel.from_config(config, predict_sequence_length=8)
trainer = KerasTrainer(model, optimizer=tf.keras.optimizers.legacy.Adam(0.003))
trainer.train(train, valid, epochs=2)
trainer.train(train, valid, epochs=1)
y_test = trainer.predict(valid[0])
self.assertEqual(y_test.shape, valid[1].shape)
Empty file.
2 changes: 1 addition & 1 deletion tests/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def test_trainer_2gpu(self):
class KerasTrainerTest(unittest.TestCase):
def setUp(self):
self.fit_config = {
"epochs": 2,
"epochs": 1,
"batch_size": 1,
}

Expand Down
1 change: 0 additions & 1 deletion tfts/layers/autoformer_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ def call(self, x: tf.Tensor):
- The moving average tensor, which is a smoothed version of the input tensor.
"""
moving_mean = self.moving_avg(x)
print(x.shape, moving_mean.shape)
trend = x - moving_mean
return trend, moving_mean

Expand Down
8 changes: 4 additions & 4 deletions tfts/layers/dense_layer.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""Layer for :py:class:`~tfts.models.wavenet` :py:class:`~tfts.models.transformer`"""

from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
from typing import Optional, Tuple

import tensorflow as tf
from tensorflow.keras import activations, constraints, initializers, regularizers
from tensorflow.keras.layers import BatchNormalization, Dense, Dropout
from tensorflow.keras.layers import Dense, Dropout


class DenseTemp(tf.keras.layers.Layer):
Expand Down Expand Up @@ -86,7 +86,7 @@ def __init__(self, hidden_size: int, filter_size: int, relu_dropout: float = 0.0
def build(self, input_shape: Tuple[Optional[int], ...]):
self.filter_dense_layer = Dense(self.filter_size, use_bias=True, activation="relu")
self.output_dense_layer = Dense(self.hidden_size, use_bias=True)
self.drop = Dropout(self.relu_dropout)
# self.drop = Dropout(self.relu_dropout)
super(FeedForwardNetwork, self).build(input_shape)

def call(self, x):
Expand All @@ -103,7 +103,7 @@ def call(self, x):
_description_
"""
output = self.filter_dense_layer(x)
output = self.drop(output)
# output = self.drop(output)
output = self.output_dense_layer(output)
return output

Expand Down
24 changes: 19 additions & 5 deletions tfts/layers/embed_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import tensorflow as tf
from tensorflow.keras.layers import GRU, LSTM, Conv1D, Dense, Dropout, Embedding, LayerNormalization, SpatialDropout1D

from .position_layer import PositionalEmbedding, PositionalEncoding
from .position_layer import PositionalEmbedding, PositionalEncoding, RelativePositionEmbedding


class TokenEmbedding(tf.keras.layers.Layer):
Expand Down Expand Up @@ -106,8 +106,13 @@ def call(self, x, **kwargs):
return


class PatchEmbedding(tf.keras.layers.Layer):
def __init__(self):
super().__init__()


class DataEmbedding(tf.keras.layers.Layer):
def __init__(self, embed_size: int, dropout: float = 0.0):
def __init__(self, embed_size: int, dropout: float = 0.0, position_embedding_type: Optional[str] = None):
"""
Data Embedding layer.
Expand All @@ -118,7 +123,14 @@ def __init__(self, embed_size: int, dropout: float = 0.0):
super(DataEmbedding, self).__init__()
self.embed_size = embed_size
self.value_embedding = TokenEmbedding(embed_size)
self.positional_embedding = PositionalEncoding()
if position_embedding_type == "positional encoding":
self.positional_embedding = PositionalEncoding()
elif position_embedding_type == "positional embedding":
self.positional_embedding = PositionalEmbedding()
elif position_embedding_type == "relative encoding":
self.positional_embedding = RelativePositionEmbedding()
else:
self.positional_embedding = None
self.dropout = Dropout(dropout)

def build(self, input_shape: Tuple[Optional[int], ...]):
Expand All @@ -135,8 +147,10 @@ def call(self, x):
tf.Tensor: Output tensor of shape (batch_size, seq_length, embed_size).
"""
ve = self.value_embedding(x)
pe = self.positional_embedding(ve)
return self.dropout(ve + pe)
if self.positional_embedding is not None:
pe = self.positional_embedding(ve)
return self.dropout(ve + pe)
return ve

def get_config(self):
config = {"embed_size": self.embed_size}
Expand Down
Loading

0 comments on commit 817456a

Please sign in to comment.