Skip to content

Commit

Permalink
attempt fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
americast committed Nov 11, 2023
1 parent 823e3fe commit c94f1f6
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 29 deletions.
2 changes: 1 addition & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@ jobs:
- checkout
- run:
name: Install EvaDB package from GitHub repo and run tests
no_output_timeout: 40m # 40 minute timeout
no_output_timeout: 30m # 30 minute timeout
command: |
python -m venv test_evadb
source test_evadb/bin/activate
Expand Down
6 changes: 1 addition & 5 deletions evadb/executor/create_function_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,6 @@ def handle_forecasting_function(self):
AutoNHITS,
AutoPatchTST,
AutoTFT,
AutoTimesNet,
)

# from neuralforecast.auto import AutoAutoformer as AutoAFormer
Expand All @@ -436,7 +435,6 @@ def handle_forecasting_function(self):
FEDformer,
Informer,
PatchTST,
TimesNet,
)

# from neuralforecast.models import Autoformer as AFormer
Expand All @@ -456,8 +454,6 @@ def handle_forecasting_function(self):
# "AutoAFormer": AutoAFormer,
"Informer": Informer,
"AutoInformer": AutoInformer,
"TimesNet": TimesNet,
"AutoTimesNet": AutoTimesNet,
"TFT": TFT,
"AutoTFT": AutoTFT,
}
Expand Down Expand Up @@ -546,7 +542,7 @@ def get_optuna_config(trial):
raise FunctionIODefinitionError(err_msg)

model = StatsForecast(
[model_here(season_length=season_length)], freq=new_freq, n_jobs=-1
[model_here(season_length=season_length)], freq=new_freq
)

data["ds"] = pd.to_datetime(data["ds"])
Expand Down
46 changes: 23 additions & 23 deletions test/integration_tests/long/test_model_forecasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,55 +80,55 @@ def tearDownClass(cls):
@forecast_skip_marker
def test_forecast(self):
create_predict_udf = """
CREATE FUNCTION AirForecast FROM
(SELECT unique_id, ds, y FROM AirData)
CREATE FUNCTION AirPanelForecast FROM
(SELECT unique_id, ds, y, trend FROM AirDataPanel)
TYPE Forecasting
HORIZON 12
PREDICT 'y';
PREDICT 'y'
LIBRARY 'neuralforecast'
AUTO 'false'
FREQUENCY 'M';
"""
execute_query_fetch_all(self.evadb, create_predict_udf)

predict_query = """
SELECT AirForecast() order by y;
SELECT AirPanelForecast() order by y;
"""
result = execute_query_fetch_all(self.evadb, predict_query)
self.assertEqual(len(result), 12)
self.assertEqual(len(result), 24)
self.assertEqual(
result.columns,
[
"airforecast.unique_id",
"airforecast.ds",
"airforecast.y",
"airforecast.y-lo",
"airforecast.y-hi",
"airpanelforecast.unique_id",
"airpanelforecast.ds",
"airpanelforecast.y",
"airpanelforecast.y-lo",
"airpanelforecast.y-hi",
],
)

create_predict_udf = """
CREATE FUNCTION AirPanelForecast FROM
(SELECT unique_id, ds, y, trend FROM AirDataPanel)
CREATE FUNCTION AirForecast FROM
(SELECT unique_id, ds, y FROM AirData)
TYPE Forecasting
HORIZON 12
PREDICT 'y'
LIBRARY 'neuralforecast'
AUTO 'false'
FREQUENCY 'M';
PREDICT 'y';
"""
execute_query_fetch_all(self.evadb, create_predict_udf)

predict_query = """
SELECT AirPanelForecast() order by y;
SELECT AirForecast() order by y;
"""
result = execute_query_fetch_all(self.evadb, predict_query)
self.assertEqual(len(result), 24)
self.assertEqual(len(result), 12)
self.assertEqual(
result.columns,
[
"airpanelforecast.unique_id",
"airpanelforecast.ds",
"airpanelforecast.y",
"airpanelforecast.y-lo",
"airpanelforecast.y-hi",
"airforecast.unique_id",
"airforecast.ds",
"airforecast.y",
"airforecast.y-lo",
"airforecast.y-hi",
],
)

Expand Down

0 comments on commit c94f1f6

Please sign in to comment.