Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Return a forecasting plot with the output #1346

Open
wants to merge 30 commits into
base: staging
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
4fba6a9
Add feedback for flat predictions
americast Oct 5, 2023
c15dba3
Merge remote-tracking branch 'origin/staging' into feedback
americast Oct 20, 2023
dbe6d02
update suggestion handle to bool
americast Oct 20, 2023
c118203
wip: add confidence
americast Oct 24, 2023
d208bf9
confidence done
americast Oct 24, 2023
883c598
add docs
americast Oct 24, 2023
4f6ea4a
fix tests
americast Oct 24, 2023
70088d2
metrics for statsforecast
americast Oct 27, 2023
6192554
neuralforecast rmse
americast Oct 29, 2023
ed8de98
reveal hyperparams
americast Nov 1, 2023
a83c30d
update docs
americast Nov 1, 2023
0171926
Add more DL models
americast Nov 2, 2023
fc48003
metrics optional
americast Nov 3, 2023
0e9f5cd
Merge remote-tracking branch 'origin/staging' into feedback
americast Nov 3, 2023
4ddd107
Merge remote-tracking branch 'origin/staging' into feedback
americast Nov 7, 2023
56a401a
fix binder error
americast Nov 7, 2023
30b0fb1
update docs
americast Nov 7, 2023
199c6bb
change setup()
americast Nov 8, 2023
f3f02be
fix metrics logic
americast Nov 8, 2023
d5d90ef
Merge branch 'feedback' into return_plot
americast Nov 9, 2023
4ef2e70
plot as a new column
americast Nov 10, 2023
9e3ec86
updated binder
americast Nov 10, 2023
eb88e9b
update docs
americast Nov 10, 2023
3ca9ce1
Merge remote-tracking branch 'origin/staging' into plot_old
americast Nov 17, 2023
8d2d5ef
Fix xgboost and test errors
americast Nov 17, 2023
2d38124
update test
americast Nov 17, 2023
be2c080
Merge remote-tracking branch 'origin/staging' into return_plot
americast Dec 3, 2023
fbb4519
Merge branch 'return_plot' of github.com:georgia-tech-db/evadb into r…
americast Dec 3, 2023
7197e78
fix unit test
americast Dec 3, 2023
dadcb13
fix long test
americast Dec 3, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@ jobs:
- checkout
- run:
name: Install EvaDB package from GitHub repo and run tests
no_output_timeout: 30m # 30 minute timeout
no_output_timeout: 40m # 40 minute timeout
command: |
python -m venv test_evadb
source test_evadb/bin/activate
Expand Down
2 changes: 1 addition & 1 deletion docs/source/reference/ai/model-forecasting.rst
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ EvaDB's default forecast framework is `statsforecast <https://nixtla.github.io/s

.. note::

`Forecasting` function also logs suggestions. Logged information, such as metrics and suggestions, is sent to STDOUT by default. If you wish not to print it, please send `FALSE` as an optional argument while calling the function. Eg. `SELECT Forecast(FALSE);`
`Forecasting` function also logs suggestions. Logged information, such as metrics and suggestions, is sent to STDOUT by default. A figure is also plotted and is saved in a binary format supported by OpenCV in the `plot` column of the output table. It maybe rendered using the `cv2.imdecode` function. If you wish not to obtain the logged information, please send `FALSE` as an optional argument while calling the function. Eg. `SELECT Forecast(FALSE);`


Below is an example query specifying the above parameters:
Expand Down
13 changes: 6 additions & 7 deletions evadb/binder/statement_binder.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from evadb.binder.statement_binder_context import StatementBinderContext
from evadb.catalog.catalog_type import ColumnType, TableType
from evadb.catalog.catalog_utils import is_document_table
from evadb.catalog.sql_config import RESTRICTED_COL_NAMES
from evadb.expression.abstract_expression import AbstractExpression, ExpressionType
from evadb.expression.function_expression import FunctionExpression
from evadb.expression.tuple_value_expression import TupleValueExpression
Expand Down Expand Up @@ -137,6 +136,12 @@ def _bind_create_function_statement(self, node: CreateFunctionStatement):
None,
None,
),
ColumnDefinition(
"plot",
ColumnType.ANY,
None,
None,
),
]
)
else:
Expand Down Expand Up @@ -211,12 +216,6 @@ def _bind_delete_statement(self, node: DeleteTableStatement):

@bind.register(CreateTableStatement)
def _bind_create_statement(self, node: CreateTableStatement):
# we don't allow certain keywords in the column_names
for col in node.column_list:
assert (
col.name.lower() not in RESTRICTED_COL_NAMES
), f"EvaDB does not allow to create a table with column name {col.name}"

if node.query is not None:
self.bind(node.query)

Expand Down
6 changes: 5 additions & 1 deletion evadb/executor/create_function_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,7 @@ def get_optuna_config(trial):
raise FunctionIODefinitionError(err_msg)

model = StatsForecast(
[model_here(season_length=season_length)], freq=new_freq
[model_here(season_length=season_length)], freq=new_freq, n_jobs=-1
)

data["ds"] = pd.to_datetime(data["ds"])
Expand Down Expand Up @@ -668,6 +668,8 @@ def get_optuna_config(trial):
model_path = os.path.join(model_dir, existing_model_files[-1])
io_list = self._resolve_function_io(None)
data["ds"] = data.ds.astype(str)
last_ds = list(data["ds"])[-2 * horizon :]
last_y = list(data["y"])[-2 * horizon :]
metadata_here = [
FunctionMetadataCatalogEntry("model_name", arg_map["model"]),
FunctionMetadataCatalogEntry("model_path", model_path),
Expand All @@ -683,6 +685,8 @@ def get_optuna_config(trial):
FunctionMetadataCatalogEntry("horizon", horizon),
FunctionMetadataCatalogEntry("library", library),
FunctionMetadataCatalogEntry("conf", conf),
FunctionMetadataCatalogEntry("last_ds", last_ds),
FunctionMetadataCatalogEntry("last_y", last_y),
]

return (
Expand Down
8 changes: 6 additions & 2 deletions evadb/expression/function_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,12 @@ def evaluate(self, batch: Batch, **kwargs) -> Batch:

# process outcomes only if output is not empty
if outcomes.frames.empty is False:
outcomes = outcomes.project(self.projection_columns)
outcomes.modify_column_alias(self.alias)
if self._function().name == "ForecastModel":
outcomes = outcomes.project(self.projection_columns, forecast=True)
outcomes.modify_column_alias(self.alias, forecast=True)
else:
outcomes = outcomes.project(self.projection_columns)
outcomes.modify_column_alias(self.alias)

# record the number of function calls
self._stats.num_calls += len(batch)
Expand Down
68 changes: 66 additions & 2 deletions evadb/functions/forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
import os
import pickle

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from evadb.functions.abstract.abstract_function import AbstractFunction
Expand All @@ -39,6 +42,8 @@ def setup(
horizon: int,
library: str,
conf: int,
last_ds: list,
last_y: list,
):
self.library = library
if "neuralforecast" in self.library:
Expand Down Expand Up @@ -67,6 +72,8 @@ def setup(
self.rmse = float(f.readline())
if "arima" in model_name.lower():
self.hypers = "p,d,q: " + f.readline()
self.last_ds = last_ds
self.last_y = last_y

def forward(self, data) -> pd.DataFrame:
log_str = ""
Expand All @@ -79,7 +86,7 @@ def forward(self, data) -> pd.DataFrame:

# Feedback
if len(data) == 0 or list(list(data.iloc[0]))[0] is True:
# Suggestions
## Suggestions
suggestion_list = []
# 1: Flat predictions
if self.library == "statsforecast":
Expand All @@ -95,12 +102,69 @@ def forward(self, data) -> pd.DataFrame:
for suggestion in set(suggestion_list):
log_str += "\nSUGGESTION: " + self.suggestion_dict[suggestion]

# Metrics
## Metrics
if self.rmse is not None:
log_str += "\nMean normalized RMSE: " + str(self.rmse)
if self.hypers is not None:
log_str += "\nHyperparameters: " + self.hypers

## Plot figure

pred_plt = self.last_y + list(
forecast_df[
self.model_name
if self.library == "statsforecast"
else self.model_name + "-median"
]
)
pred_plt_lo = self.last_y + list(
forecast_df[self.model_name + "-lo-" + str(self.conf)]
)
pred_plt_hi = self.last_y + list(
forecast_df[self.model_name + "-hi-" + str(self.conf)]
)

plt.plot(pred_plt, label="Prediction")
plt.fill_between(
x=range(len(pred_plt)), y1=pred_plt_lo, y2=pred_plt_hi, alpha=0.3
)
plt.plot(self.last_y, label="Actual")
plt.xlabel("Time")
plt.ylabel("Value")
xtick_strs = self.last_ds + list(forecast_df["ds"])
num_to_keep_args = list(
range(0, len(xtick_strs), int((len(xtick_strs) - 2) / 8))
) + [len(xtick_strs) - 1]
xtick_strs = [
x if i in num_to_keep_args else "" for i, x in enumerate(xtick_strs)
]
plt.xticks(range(len(pred_plt)), xtick_strs, rotation=85)
plt.legend()
plt.tight_layout()

# convert plt figure to opencv, inspired from https://copyprogramming.com/howto/convert-matplotlib-figure-to-cv2-image-a-complete-guide-with-examples#converting-matplotlib-figure-to-cv2-image
# convert figure to canvas
canvas = plt.get_current_fig_manager().canvas

# render the canvas
canvas.draw()

# convert canvas to image
img = np.fromstring(canvas.tostring_rgb(), dtype="uint8")
img = img.reshape(canvas.get_width_height()[::-1] + (3,))

# convert image to cv2 format
cv2_img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

# Conver to bytes
_, buffer = cv2.imencode(".jpg", cv2_img)
img_bytes = buffer.tobytes()

# Add to dataframe as a plot
forecast_df["plot"] = [img_bytes] + [None] * (len(forecast_df) - 1)

log_str += "\nA plot has been saved in the 'plot' column of the output table. It maybe rendered using the cv2.imdecode function."

print(log_str)

forecast_df = forecast_df.rename(
Expand Down
21 changes: 16 additions & 5 deletions evadb/models/storage/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,15 +235,17 @@ def update_indices(self, indices: List, other: Batch):
def file_paths(self) -> Iterable:
yield from self._frames["file_path"]

def project(self, cols: None) -> Batch:
def project(self, cols: None, forecast: bool = False) -> Batch:
"""
Takes as input the column list, returns the projection.
We do a copy for now.
"""
cols = cols or []
verified_cols = [c for c in cols if c in self._frames]
unknown_cols = list(set(cols) - set(verified_cols))
assert len(unknown_cols) == 0, unknown_cols
assert len(unknown_cols) == 0 or (
forecast is True and unknown_cols == ["plot"]
), unknown_cols
return Batch(self._frames[verified_cols])

@classmethod
Expand Down Expand Up @@ -405,14 +407,20 @@ def reset_index(self):
"""Resets the index of the data frame in the batch"""
self._frames.reset_index(drop=True, inplace=True)

def modify_column_alias(self, alias: Union[Alias, str]) -> None:
def modify_column_alias(
self, alias: Union[Alias, str], forecast: bool = False
) -> None:
# a, b, c -> table1.a, table1.b, table1.c
# t1.a -> t2.a
if isinstance(alias, str):
alias = Alias(alias)
new_col_names = []
if len(alias.col_names):
if len(self.columns) != len(alias.col_names):
if (len(self.columns) != len(alias.col_names) and forecast is False) or (
forecast is True
and (self.columns != alias.col_names)
and list(set(alias.col_names) - set(self.columns)) != ["plot"]
):
err_msg = (
f"Expected {len(alias.col_names)} columns {alias.col_names},"
f"got {len(self.columns)} columns {self.columns}."
Expand All @@ -431,7 +439,10 @@ def modify_column_alias(self, alias: Union[Alias, str]) -> None:
else:
new_col_names.append("{}.{}".format(alias.alias_name, col_name))

self._frames.columns = new_col_names
if forecast and list(set(alias.col_names) - set(self.columns)) == ["plot"]:
self._frames.columns = new_col_names[:-1]
else:
self._frames.columns = new_col_names

def drop_column_alias(self) -> None:
# table1.a, table1.b, table1.c -> a, b, c
Expand Down
4 changes: 3 additions & 1 deletion test/integration_tests/long/test_model_forecasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def test_forecast(self):
"airforecast.y",
"airforecast.y-lo",
"airforecast.y-hi",
"airforecast.plot",
],
)

Expand All @@ -122,7 +123,7 @@ def test_forecast_neuralforecast(self):
execute_query_fetch_all(self.evadb, create_predict_udf)

predict_query = """
SELECT AirPanelForecast() order by y;
SELECT AirPanelForecast(FALSE) order by y;
"""
result = execute_query_fetch_all(self.evadb, predict_query)
self.assertEqual(len(result), 24)
Expand Down Expand Up @@ -167,6 +168,7 @@ def test_forecast_with_column_rename(self):
"homeforecast.ma",
"homeforecast.ma-lo",
"homeforecast.ma-hi",
"homeforecast.plot",
],
)

Expand Down
12 changes: 12 additions & 0 deletions test/unit_tests/binder/test_statement_binder.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,11 @@ def test_bind_create_function_should_bind_forecast_with_default_columns(self):
type=ColumnType.FLOAT,
array_type=None,
)
plot_col_obj = ColumnCatalogEntry(
name="plot",
type=ColumnType.ANY,
array_type=None,
)
create_function_statement.query.target_list = [
TupleValueExpression(
name=id_col_obj.name, table_alias="a", col_object=id_col_obj
Expand Down Expand Up @@ -522,6 +527,7 @@ def test_bind_create_function_should_bind_forecast_with_default_columns(self):
y_col_obj,
y_lo_col_obj,
y_hi_col_obj,
plot_col_obj,
)
]
)
Expand Down Expand Up @@ -560,6 +566,11 @@ def test_bind_create_function_should_bind_forecast_with_renaming_columns(self):
type=ColumnType.FLOAT,
array_type=None,
)
plot_col_obj = ColumnCatalogEntry(
name="plot",
type=ColumnType.ANY,
array_type=None,
)
create_function_statement.query.target_list = [
TupleValueExpression(
name=id_col_obj.name, table_alias="a", col_object=id_col_obj
Expand Down Expand Up @@ -601,6 +612,7 @@ def test_bind_create_function_should_bind_forecast_with_renaming_columns(self):
y_col_obj,
y_lo_col_obj,
y_hi_col_obj,
plot_col_obj,
)
]
)
Expand Down