Skip to content

Commit

Permalink
fix: task
Browse files Browse the repository at this point in the history
  • Loading branch information
LongxingTan committed Jan 13, 2025
1 parent 2bf4a15 commit 66b6eab
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 25 deletions.
7 changes: 3 additions & 4 deletions examples/run_anomaly.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ def run_train(args):

trainer = KerasTrainer(model)
trainer.train((x_test, y_test), (x_test, y_test), epochs=args.epochs)
model.save_weights(args.output_dir)
# model.save_weights(args.output_dir)
trainer.save_model(args.output_dir)
return


Expand All @@ -87,9 +88,7 @@ def run_inference(args):
config = AutoConfig.for_model(args.use_model)
config.train_sequence_length = args.train_length

model = AutoModelForAnomaly.from_pretrained(
weights_dir=args.output_dir, predict_sequence_length=args.predict_sequence_length
)
model = AutoModelForAnomaly.from_pretrained(weights_dir=args.output_dir)
det = model.detect(x_test, y_test)
return sig, det

Expand Down
35 changes: 27 additions & 8 deletions tfts/models/auto_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,10 @@ def from_pretrained(cls, weights_dir, predict_sequence_length: int = 1):
if not os.path.exists(config_path):
raise FileNotFoundError(f"Config file not found at {config_path}")

config = BaseConfig.from_json(config_path) # Load config from JSON
model = cls.from_config(config, predict_sequence_length=predict_sequence_length)
model.load_weights(os.path.join(weights_dir, "weights.h5")) # Load weights
# config = BaseConfig.from_json(config_path) # Load config from JSON
# model = cls.from_config(config, predict_sequence_length=predict_sequence_length)
# model.load_weights(os.path.join(weights_dir, "weights.h5")) # Load weights
model = tf.keras.models.load_model(weights_dir)
return model

def save_pretrained(self):
Expand All @@ -102,9 +103,13 @@ def save_pretrained(self):
class AutoModelForPrediction(AutoModel):
"""tfts model for prediction"""

def __call__(self, x):
def __call__(
self,
x: Union[tf.data.Dataset, Tuple[np.ndarray], Tuple[pd.DataFrame], List[np.ndarray], List[pd.DataFrame]],
return_dict: Optional[bool] = None,
):

model_output = super().__call__(x)
model_output = super().__call__(x, return_dict=return_dict)

if self.config.skip_connect_circle:
x_mean = x[:, -self.predict_sequence_length :, 0:1]
Expand All @@ -121,8 +126,9 @@ class AutoModelForClassification(AutoModel):
def __call__(
self,
x: Union[tf.data.Dataset, Tuple[np.ndarray], Tuple[pd.DataFrame], List[np.ndarray], List[pd.DataFrame]],
return_dict: Optional[bool] = None,
):
return super().__call__(x)
return super().__call__(x, return_dict=return_dict)


class AutoModelForAnomaly(AutoModel):
Expand All @@ -141,15 +147,27 @@ def detect(
dist = self.head(model_output, labels)
return dist

@classmethod
def from_pretrained(cls, weights_dir: str):
model = tf.keras.models.load_model(weights_dir)
logger.info(f"Load model from {weights_dir}")
config_path = os.path.join(weights_dir, "config.json")
if not os.path.exists(config_path):
raise FileNotFoundError(f"Config file not found at {config_path}")

config = BaseConfig.from_json(config_path) # Load config from JSON
return cls(model, config)


class AutoModelForSegmentation(AutoModel):
"""tfts model for time series segmentation"""

def __call__(
self,
x: Union[tf.data.Dataset, Tuple[np.ndarray], Tuple[pd.DataFrame], List[np.ndarray], List[pd.DataFrame]],
return_dict: Optional[bool] = None,
):
model_output = self.model(x)
model_output = self.model(x, return_dict=return_dict)
return model_output


Expand All @@ -159,6 +177,7 @@ class AutoModelForUncertainty(AutoModel):
def __call__(
self,
x: Union[tf.data.Dataset, Tuple[np.ndarray], Tuple[pd.DataFrame], List[np.ndarray], List[pd.DataFrame]],
return_dict: Optional[bool] = None,
):
model_output = self.model(x)
model_output = self.model(x, return_dict=return_dict)
return model_output
32 changes: 19 additions & 13 deletions tfts/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@

from collections.abc import Iterable
import logging
import os
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, Union

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.layers import Input

from reference.gan import config

from .constants import TFTS_HUB_CACHE

__all__ = ["Trainer", "KerasTrainer", "Seq2seqKerasTrainer"]
Expand Down Expand Up @@ -124,10 +127,10 @@ def train(
else:
no_improve_epochs += 1
if no_improve_epochs >= stop_no_improve_epochs:
logging.info("Tried the best, no improved and stop training")
logger.info("Tried the best, no improved and stop training")
break

logging.info(log_str)
logger.info(log_str)

# self.export_model(model_dir, only_pb=True) # save the model

Expand Down Expand Up @@ -166,7 +169,7 @@ def train_step(self, x_train, y_train):
lr = self.learning_rate
self.optimizer.lr.assign(lr)
self.global_step.assign_add(1)
# logging.info('Step: {}, Loss: {}'.format(self.global_step.numpy(), loss))
# logger.info('Step: {}, Loss: {}'.format(self.global_step.numpy(), loss))
return y_pred, loss

def valid_loop(self, valid_loader):
Expand Down Expand Up @@ -208,11 +211,11 @@ def predict(self, test_loader):
def export_model(self, model_dir, only_pb=True):
# save the model
tf.saved_model.save(self.model, model_dir)
logging.info(f"Protobuf model successfully saved in {model_dir}")
logger.info(f"Protobuf model successfully saved in {model_dir}")

if not only_pb:
self.model.save_weights(f"{model_dir}.ckpt")
logging.info(f"Model weights successfully saved in {model_dir}.ckpt")
logger.info(f"Model weights successfully saved in {model_dir}.ckpt")


class KerasTrainer(object):
Expand All @@ -235,6 +238,7 @@ def __init__(
run_eagerly: it depends on which one is much faster
"""
self.model = model
self.config = model.config if hasattr(model, "config") else None
self.loss_fn = loss_fn
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
Expand Down Expand Up @@ -269,7 +273,7 @@ def train(
callbacks.append(checkpoint)
if "callbacks" in kwargs:
callbacks += kwargs.get("callbacks")
logging.info("callback", callbacks)
logger.info("callback", callbacks)

# if self.strategy is None:
# self.strategy = tf.distribute.OneDeviceStrategy(device="/gpu:0")
Expand Down Expand Up @@ -309,7 +313,7 @@ def train(

trainable_params = np.sum([tf.keras.backend.count_params(w) for w in self.model.trainable_weights])
# print(self.model.summary())
logging.info(f"Trainable parameters: {trainable_params}")
logger.info(f"Trainable parameters: {trainable_params}")
self.model.compile(
loss=self.loss_fn, optimizer=self.optimizer, metrics=callback_metrics, run_eagerly=self.run_eagerly
)
Expand Down Expand Up @@ -345,20 +349,22 @@ def predict(self, x_test, batch_size: int = 1):
def get_model(self):
return self.model

def save_model(self, model_dir, only_pb: bool = True, checkpoint_dir: Optional[str] = None):
def save_model(self, model_dir, save_weights_only: bool = True, checkpoint_dir: Optional[str] = None):
# save the model, checkpoint_dir if you use Checkpoint callback to save your best weights
if checkpoint_dir is not None:
logging.info("checkpoint Loaded", checkpoint_dir)
logger.info("checkpoint Loaded", checkpoint_dir)
self.model.load_weights(checkpoint_dir)
else:
logging.info("No checkpoint Loaded")
logger.info("No checkpoint Loaded")

self.model.save(model_dir)
logging.info("protobuf model successfully saved in {}".format(model_dir))
if self.config is not None:
self.config.to_json(os.path.join(model_dir, "config.json"))
logger.info("protobuf model successfully saved in {}".format(model_dir))

if not only_pb:
if not save_weights_only:
self.model.save_weights("{}.ckpt".format(model_dir))
logging.info("model weights successfully saved in {}.ckpt".format(model_dir))
logger.info("model weights successfully saved in {}.ckpt".format(model_dir))
return

def plot(self, history, true, pred):
Expand Down

0 comments on commit 66b6eab

Please sign in to comment.