[ENH] Add predict to v2 models #1984
Open: phoeenniixx wants to merge 18 commits into sktime:main from phoeenniixx:predict-v2
Commits (18), all by phoeenniixx:
cff05b9  initial design
1371c68  Merge branch 'main' into predict-v2
13c12b9  preliminary design
7cfb26c  Merge branch 'main' into predict-v2
35e2447  add predict to all v2 models
882caba  update base model
c53f881  update base model
f8d06fe  update test_integration
a1aad82  Merge branch 'main' into predict-v2
e7fcc77  Merge branch 'main' into predict-v2
5024fcb  hadnle kwargs
e9c88c3  Merge branch 'main' into predict-v2
8e048cd  add tests for checkpoints and predict modes
5072689  update docstring
64f7e7d  add examples
ec89f50  add examples
c1e3426  add comments
ba74f1e  update notebook
New file (+310 lines):

from pathlib import Path
import pickle
from typing import Any, Optional, Union

from lightning import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.core.datamodule import LightningDataModule
import torch
from torch.utils.data import DataLoader
import yaml

from pytorch_forecasting.data import TimeSeries
from pytorch_forecasting.models.base._base_object import _BasePtForecasterV2
class Base_pkg(_BasePtForecasterV2):
    """
    Base model package class acting as a high-level wrapper for the Lightning workflow.

    This class simplifies the user experience by managing model, datamodule, and
    trainer configurations, and providing streamlined ``fit`` and ``predict`` methods.

    Parameters
    ----------
    model_cfg : Union[dict, str, Path], optional
        Model configs used to initialise the model. Required if not loading
        from a checkpoint. Defaults to ``{}``.
    trainer_cfg : Union[dict, str, Path], optional
        Configs used to initialise ``lightning.Trainer``. Defaults to ``{}``.
    datamodule_cfg : Union[dict, str, Path], optional
        Configs used to initialise a ``LightningDataModule``.

        - If dict, the keys and values are used as configuration parameters.
        - If str or Path, it should be a path to a ``.yaml``/``.yml`` or ``.pkl``
          file containing the serialized configuration dictionary. Required for
          reproducibility when loading a model for inference. Defaults to ``{}``.

    ckpt_path : Union[str, Path], optional
        Path to the checkpoint from which to load the model. If provided,
        ``model_cfg`` is ignored. Defaults to None.
    """

    def __init__(
        self,
        model_cfg: Optional[Union[dict[str, Any], str, Path]] = None,
        trainer_cfg: Optional[Union[dict[str, Any], str, Path]] = None,
        datamodule_cfg: Optional[Union[dict[str, Any], str, Path]] = None,
        ckpt_path: Optional[Union[str, Path]] = None,
    ):
        self.ckpt_path = Path(ckpt_path) if ckpt_path else None
        self.model_cfg = self._load_config(
            model_cfg, ckpt_path=self.ckpt_path, auto_file_name="model_cfg.pkl"
        )
        self.datamodule_cfg = self._load_config(
            datamodule_cfg,
            ckpt_path=self.ckpt_path,
            auto_file_name="datamodule_cfg.pkl",
        )
        self.trainer_cfg = self._load_config(trainer_cfg)
        self.metadata = self._load_config(
            None, ckpt_path=self.ckpt_path, auto_file_name="metadata.pkl"
        )

        self.model = None
        self.trainer = None
        self.datamodule = None
        if self.ckpt_path:
            self._build_model(metadata=self.metadata, **self.model_cfg)
    @staticmethod
    def _load_config(
        config: Union[dict, str, Path, None],
        ckpt_path: Optional[Union[str, Path]] = None,
        auto_file_name: Optional[str] = None,
    ) -> dict:
        """
        Load configuration from a dictionary, YAML file, or pickle file.
        """
        if config is None:
            if ckpt_path and auto_file_name:
                path = Path(ckpt_path).parent / auto_file_name
                if path.exists():
                    with open(path, "rb") as f:
                        return pickle.load(f)  # noqa: S301
            return {}

        if isinstance(config, dict):
            return config

        path = Path(config)
        if not path.exists():
            raise FileNotFoundError(f"Configuration file not found: {path}")

        suffix = path.suffix.lower()

        if suffix in [".yaml", ".yml"]:
            with open(path) as f:
                return yaml.safe_load(f) or {}
        elif suffix == ".pkl":
            with open(path, "rb") as f:
                return pickle.load(f)  # noqa: S301
        else:
            raise ValueError(
                f"Unsupported config format: {suffix}. Use .yaml, .yml, or .pkl"
            )
    @classmethod
    def get_cls(cls):
        """Get the underlying model class."""
        raise NotImplementedError("Subclasses must implement `get_cls`.")

    @classmethod
    def get_datamodule_cls(cls):
        """Get the underlying DataModule class."""
        raise NotImplementedError("Subclasses must implement `get_datamodule_cls`.")
    @classmethod
    def get_test_dataset_from(cls, **kwargs):
        """
        Create and return D1 ``TimeSeries`` dataset objects for testing.
        """
        from pytorch_forecasting.tests._data_scenarios import (
            data_with_covariates_v2,
            make_datasets_v2,
        )

        raw_data = data_with_covariates_v2()
        datasets_info = make_datasets_v2(raw_data, **kwargs)

        return {
            "train": datasets_info["training_dataset"],
            "predict": datasets_info["validation_dataset"],
        }
    def _build_model(self, metadata: dict, **kwargs):
        """Instantiate the model, either from a checkpoint or from config."""
        model_cls = self.get_cls()
        if self.ckpt_path:
            self.model = model_cls.load_from_checkpoint(
                self.ckpt_path, metadata=metadata, **kwargs
            )
        elif self.model_cfg:
            self.model = model_cls(**self.model_cfg, metadata=metadata)
        else:
            self.model = None
    def _build_datamodule(self, data: TimeSeries) -> LightningDataModule:
        """Construct a DataModule from a D1 layer object."""
        if not self.datamodule_cfg:
            raise ValueError("`datamodule_cfg` must be provided to build a datamodule.")
        datamodule_cls = self.get_datamodule_cls()
        return datamodule_cls(data, **self.datamodule_cfg)
    def _load_dataloader(
        self, data: Union[TimeSeries, LightningDataModule, DataLoader]
    ) -> DataLoader:
        """Convert supported data input types into a DataLoader for prediction."""
        if isinstance(data, TimeSeries):  # D1 layer
            dm = self._build_datamodule(data)
            dm.setup(stage="predict")
            return dm.predict_dataloader()
        elif isinstance(data, LightningDataModule):  # D2 layer
            data.setup(stage="predict")
            return data.predict_dataloader()
        elif isinstance(data, DataLoader):
            return data
        else:
            raise TypeError(
                f"Unsupported data type for prediction: {type(data).__name__}. "
                "Expected TimeSeries, LightningDataModule, or DataLoader."
            )
    def _save_artifact(self, output_dir: Path):
        """Save all configuration artifacts."""
        output_dir.mkdir(parents=True, exist_ok=True)

        with open(output_dir / "datamodule_cfg.pkl", "wb") as f:
            pickle.dump(self.datamodule_cfg, f)

        with open(output_dir / "model_cfg.pkl", "wb") as f:
            pickle.dump(self.model_cfg, f)

        if self.datamodule is not None and hasattr(self.datamodule, "metadata"):
            with open(output_dir / "metadata.pkl", "wb") as f:
                pickle.dump(self.datamodule.metadata, f)
    def fit(
        self,
        data: Union[TimeSeries, LightningDataModule],
        # todo: we should create a base data_module for different data_modules
        save_ckpt: bool = True,
        ckpt_dir: Union[str, Path] = "checkpoints",
        ckpt_kwargs: Optional[dict[str, Any]] = None,
        **trainer_fit_kwargs,
    ):
        """
        Fit the model to the training data.

        Parameters
        ----------
        data : Union[TimeSeries, LightningDataModule]
            The data to fit on (D1 or D2 layer). This object is responsible
            for providing both training and validation data.
        save_ckpt : bool, default=True
            If True, save the best model checkpoint and the ``datamodule_cfg``.
        ckpt_dir : Union[str, Path], default="checkpoints"
            Directory to save artifacts.
        ckpt_kwargs : dict, optional
            Keyword arguments passed to ``ModelCheckpoint``.
        **trainer_fit_kwargs :
            Additional keyword arguments passed to ``trainer.fit()``.

        Returns
        -------
        Optional[Path]
            The path to the best model checkpoint if ``save_ckpt=True``, else None.
        """
        if isinstance(data, TimeSeries):
            self.datamodule = self._build_datamodule(data)
        else:
            self.datamodule = data
        self.datamodule.setup(stage="fit")

        if self.model is None:
            if not self.model_cfg:
                raise RuntimeError(
                    "`model_cfg` must be provided to train from scratch."
                )
            metadata = self.datamodule.metadata
            self._build_model(metadata)

        callbacks = self.trainer_cfg.get("callbacks", []).copy()
        checkpoint_cb = None
        if save_ckpt:
            ckpt_dir = Path(ckpt_dir)
            ckpt_dir.mkdir(parents=True, exist_ok=True)
            default_ckpt_kwargs = {
                "dirpath": ckpt_dir,
                "filename": "best-{epoch}-{step}",
                "save_top_k": 1,
                "monitor": "val_loss",
                "mode": "min",
            }
            if ckpt_kwargs:
                default_ckpt_kwargs.update(ckpt_kwargs)
            checkpoint_cb = ModelCheckpoint(**default_ckpt_kwargs)
            callbacks.append(checkpoint_cb)
        trainer_init_cfg = self.trainer_cfg.copy()
        trainer_init_cfg.pop("callbacks", None)

        self.trainer = Trainer(**trainer_init_cfg, callbacks=callbacks)

        self.trainer.fit(self.model, datamodule=self.datamodule, **trainer_fit_kwargs)
        if save_ckpt and checkpoint_cb:
            best_model_path = Path(checkpoint_cb.best_model_path)
            self._save_artifact(best_model_path.parent)
            print(f"Artifacts saved in: {best_model_path.parent}")
            return best_model_path
        return None
    def predict(
        self,
        data: Union[TimeSeries, LightningDataModule, DataLoader],
        output_dir: Optional[Union[str, Path]] = None,
        **kwargs,
    ) -> Union[dict[str, torch.Tensor], None]:
        """
        Generate predictions by wrapping the model's predict method.

        This method prepares the data by resolving it into a DataLoader and then
        delegates the prediction task to the underlying model's ``.predict()`` method.

        Parameters
        ----------
        data : Union[TimeSeries, LightningDataModule, DataLoader]
            The data to predict on (D1, D2, or DataLoader).
        output_dir : Union[str, Path], optional
            If provided, predictions are pickled to ``output_dir/predictions.pkl``
            instead of being returned. Defaults to None.
        **kwargs :
            Additional keyword arguments passed directly to the model's ``.predict()``
            method, such as ``mode``, ``return_info``, and any ``trainer_kwargs``.

        Returns
        -------
        Union[dict[str, torch.Tensor], None]
            A dictionary of prediction tensors, or None if ``output_dir``
            is specified.
        """
        if self.model is None:
            raise RuntimeError(
                "Model is not initialized. Provide `model_cfg` or `ckpt_path`."
            )

        dataloader = self._load_dataloader(data)
        predictions = self.model.predict(dataloader, **kwargs)

        if output_dir:
            output_path = Path(output_dir)
            output_path.mkdir(parents=True, exist_ok=True)
            output_file = output_path / "predictions.pkl"
            with open(output_file, "wb") as f:
                pickle.dump(predictions, f)
            print(f"Predictions saved to {output_file}")
            return None

        return predictions
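For context when reviewing, a minimal sketch of the intended end-to-end workflow. `MyModel_pkg`, `train_data`, and all config values are placeholders for illustration, not part of this diff: `MyModel_pkg` stands for a hypothetical concrete subclass implementing `get_cls` / `get_datamodule_cls`, and `train_data` is assumed to be a D1-layer `TimeSeries`.

```python
# Hypothetical subclass and config values, for illustration only.
pkg = MyModel_pkg(
    model_cfg={"hidden_size": 16},      # forwarded to the model class
    trainer_cfg={"max_epochs": 3},      # forwarded to lightning.Trainer
    datamodule_cfg={"batch_size": 32},  # forwarded to the DataModule class
)

# fit() resolves the data into a datamodule, builds the model from
# model_cfg plus the datamodule's metadata, trains, and by default saves
# the best checkpoint together with the config pickles next to it.
best_ckpt = pkg.fit(train_data, ckpt_dir="checkpoints")
```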
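Similarly, a sketch of the prediction entry point, assuming the underlying model's `.predict()` accepts `mode` as the docstring describes; `predict_data` is again an illustrative D1/D2/DataLoader object:

```python
# predict() accepts a D1 TimeSeries, a LightningDataModule, or a raw
# DataLoader; extra kwargs are forwarded to the model's .predict().
preds = pkg.predict(predict_data, mode="prediction")

# With output_dir set, predictions are pickled to disk and None is returned.
pkg.predict(predict_data, output_dir="predictions_out")
```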
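And a sketch of the reload-for-inference path that `ckpt_path` plus the auto-saved pickles enable; the checkpoint filename and YAML path are illustrative:

```python
# With only ckpt_path, __init__ looks for model_cfg.pkl, datamodule_cfg.pkl,
# and metadata.pkl next to the checkpoint and rebuilds the model via
# load_from_checkpoint.
pkg = MyModel_pkg(ckpt_path="checkpoints/best-epoch=2-step=120.ckpt")
preds = pkg.predict(predict_data)

# Configs can also be given as file paths; _load_config accepts
# .yaml/.yml and .pkl files in addition to plain dicts.
pkg2 = MyModel_pkg(
    model_cfg={"hidden_size": 16},
    trainer_cfg="configs/trainer.yaml",  # illustrative path
    datamodule_cfg={"batch_size": 32},
)
```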
Review comment: minor formatting issue: please have newlines around bullet point lists
Reply: Thanks for pointing it out. I will make the changes to the PR soon.