Skip to content

Commit badac89

Browse files
committed
implement forecasting start date poc
1 parent 334c8b1 commit badac89

File tree

3 files changed

+30
-5
lines changed

3 files changed

+30
-5
lines changed

evadb/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Architecture of EvaDB Database System
22

3-
<img src="https://raw.githubusercontent.com/georgia-tech-db/evadb/master/docs/images/evadb/eva-arch.png" alt="EvaDB Architecture Diagram" width="500">
3+
<img src="https://raw.githubusercontent.com/georgia-tech-db/evadb/master/docs/images/evadb/evadb-arch.png" alt="EvaDB Architecture Diagram" width="500">
44

55
* `server` - Code for launching server that sends client commands to command handler.
66
* `parser` - Converts SQL queries to statements (e.g., CREATE, SELECT, INSERT, and LOAD statements).

evadb/executor/create_function_executor.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import pickle
2020
import re
2121
import time
22+
from datetime import datetime
2223
from pathlib import Path
2324
from typing import Dict, List
2425

@@ -233,6 +234,8 @@ def handle_sklearn_function(self):
233234

234235
def convert_to_numeric(self, x):
235236
x = re.sub("[^0-9.,]", "", str(x))
237+
if x is None or x == '':
238+
return x
236239
locale.setlocale(locale.LC_ALL, "")
237240
x = float(locale.atof(x))
238241
if x.is_integer():
@@ -422,6 +425,8 @@ def handle_forecasting_function(self):
422425
) # shortens longer frequencies like Q-DEC
423426
season_length = season_dict[new_freq] if new_freq in season_dict else 1
424427

428+
start = arg_map.get("start", None)
429+
425430
"""
426431
Neuralforecast implementation
427432
"""
@@ -683,6 +688,8 @@ def get_optuna_config(trial):
683688
FunctionMetadataCatalogEntry("horizon", horizon),
684689
FunctionMetadataCatalogEntry("library", library),
685690
FunctionMetadataCatalogEntry("conf", conf),
691+
FunctionMetadataCatalogEntry("start", start),
692+
FunctionMetadataCatalogEntry("frequency", new_freq),
686693
]
687694

688695
return (

evadb/functions/forecast.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import os
1818
import pickle
19+
import numpy as np
1920

2021
import pandas as pd
2122

@@ -39,6 +40,8 @@ def setup(
3940
horizon: int,
4041
library: str,
4142
conf: int,
43+
start: str,
44+
frequency: str,
4245
):
4346
self.library = library
4447
if "neuralforecast" in self.library:
@@ -62,6 +65,8 @@ def setup(
6265
self.conf = conf
6366
self.hypers = None
6467
self.rmse = None
68+
self.start = start
69+
self.frequency = frequency
6570
if os.path.isfile(model_path + "_rmse"):
6671
with open(model_path + "_rmse", "r") as f:
6772
self.rmse = float(f.readline())
@@ -70,12 +75,25 @@ def setup(
7075

7176
def forward(self, data) -> pd.DataFrame:
7277
log_str = ""
78+
79+
dates = None
80+
if self.start:
81+
# dates = pd.DataFrame(columns=['unique_id', 'ds'])
82+
# dates = dates.assign(unique_id=1)
83+
dates = pd.date_range(start=self.start, periods=self.horizon, freq=self.frequency).to_numpy()
84+
7385
if self.library == "statsforecast":
74-
forecast_df = self.model.predict(
75-
h=self.horizon, level=[self.conf]
76-
).reset_index()
86+
if dates is None:
87+
forecast_df = self.model.predict(
88+
h=self.horizon, level=[self.conf]
89+
).reset_index()
90+
else:
91+
forecast_df = self.model.fitted_[0, 0].forward(y=dates, h=self.horizon, level=[self.conf])
7792
else:
78-
forecast_df = self.model.predict().reset_index()
93+
if dates is None:
94+
forecast_df = self.model.predict().reset_index()
95+
else:
96+
forecast_df = self.model.predict(futr_df=dates).reset_index()
7997

8098
# Feedback
8199
if len(data) == 0 or list(list(data.iloc[0]))[0] is True:

0 commit comments

Comments
 (0)