From d485aae01f182052f90607fd8961c2b05edccda8 Mon Sep 17 00:00:00 2001 From: Patricio Cerda Mardini Date: Tue, 14 May 2024 19:06:50 +0200 Subject: [PATCH] hotfix & version bump: 24.5.1.1 --- dataprep_ml/__init__.py | 2 +- dataprep_ml/splitters.py | 11 +++++++---- pyproject.toml | 2 +- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/dataprep_ml/__init__.py b/dataprep_ml/__init__.py index 0a164d3..0b4cf02 100644 --- a/dataprep_ml/__init__.py +++ b/dataprep_ml/__init__.py @@ -1,6 +1,6 @@ from dataprep_ml.base import StatisticalAnalysis, DataAnalysis -__version__ = '24.5.1.0' +__version__ = '24.5.1.1' __name__ = "dataprep_ml" diff --git a/dataprep_ml/splitters.py b/dataprep_ml/splitters.py index 6daefb4..8495d4c 100644 --- a/dataprep_ml/splitters.py +++ b/dataprep_ml/splitters.py @@ -58,11 +58,14 @@ def splitter( train, dev, test = simple_split(data, pct_train, pct_dev, pct_test) # Final assertions for time series - if min(len(train), len(dev)) < tss.get('window', 1): - raise Exception(f"Dataset size is too small for the specified window size ({tss.get('window', 1)})") + window = tss.get('window', 1) if tss.get('window', 1) else 1 + horizon = tss.get('horizon', 1) if tss.get('horizon', 1) else 1 - if min(len(train), len(dev), len(test)) < tss.get('horizon', 1): - raise Exception(f"Dataset size is too small for the specified horizon size ({tss.get('horizon', 1)})") + if min(len(train), len(dev)) < window: + raise Exception(f"Dataset size is too small for the specified window size ({window})") + + if min(len(train), len(dev), len(test)) < horizon: + raise Exception(f"Dataset size is too small for the specified horizon size ({horizon})") return {"train": train, "test": test, "dev": dev, "stratified_on": stratify_on} diff --git a/pyproject.toml b/pyproject.toml index 1a0257b..95c4e39 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "dataprep-ml" -version = "24.5.1.0" +version = "24.5.1.1" description = "Automated dataframe analysis for Machine Learning pipelines." authors = ["MindsDB Inc. "] license = "GPL-3.0"