diff --git a/hawk/analysis/main.py b/hawk/analysis/main.py index 2d51add..e56aae8 100644 --- a/hawk/analysis/main.py +++ b/hawk/analysis/main.py @@ -63,7 +63,8 @@ def __init__( self.tefs_features_lags = [] if self.tefs_use_contemporary_features: self.tefs_features_lags.append(0) - self.tefs_features_lags.extend(list(range(1, self.tefs_max_lag_features + 1))) + if self.tefs_max_lag_features > 0: + self.tefs_features_lags.extend(list(range(1, self.tefs_max_lag_features + 1))) self.tefs_target_lags = list(range(1, self.tefs_max_lag_target + 1)) diff --git a/hawk/processes/wps_causal.py b/hawk/processes/wps_causal.py index 4f549ae..b92d47b 100644 --- a/hawk/processes/wps_causal.py +++ b/hawk/processes/wps_causal.py @@ -227,12 +227,15 @@ def _handler(self, request, response): tefs_direction = request.inputs["tefs_direction"][0].data tefs_use_contemporary_features = request.inputs["tefs_use_contemporary_features"][0].data - tefs_max_lag_features = int(request.inputs["tefs_max_lag_features"][0].data) + if str(request.inputs["tefs_max_lag_features"][0].data) == "no_lag": + tefs_max_lag_features = 0 + else: + tefs_max_lag_features = int(request.inputs["tefs_max_lag_features"][0].data) tefs_max_lag_target = int(request.inputs["tefs_max_lag_target"][0].data) workdir = Path(self.workdir) - if not tefs_use_contemporary_features and tefs_max_lag_features == "no_lag": + if not tefs_use_contemporary_features and tefs_max_lag_features == 0: raise ValueError("You cannot use no lag features and not use contemporary features in TEFS.") causal_analysis = CausalAnalysis( diff --git a/tests/test_causal_analysis_noLag.py b/tests/test_causal_analysis_noLag.py new file mode 100644 index 0000000..dfc88d7 --- /dev/null +++ b/tests/test_causal_analysis_noLag.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python + +"""Tests for `hawk` package.""" + +import os + +import pandas as pd +import pytest +from click.testing import CliRunner # noqa: F401 + +import hawk # noqa: F401 +from hawk import cli # noqa: F401 +from hawk.analysis import CausalAnalysis + + +@pytest.fixture +def response(): + """Sample pytest fixture. + + See more at: http://doc.pytest.org/en/latest/fixture.html + """ + # import requests + # return requests.get('https://github.com/audreyr/cookiecutter-pypackage') + + +def test_content(response): + """Sample pytest test function with the pytest fixture as an argument.""" + # from bs4 import BeautifulSoup + # assert 'GitHub' in BeautifulSoup(response.content).title.string + + +def test_causal_analysis_noLag(): + df_train = pd.read_csv("hawk/demo/Ticino_train.csv", header=0) + df_test = pd.read_csv("hawk/demo/Ticino_test.csv", header=0) + target_column_name = "target" + pcmci_test_choice = "ParCorr" + pcmci_max_lag = 2 + tefs_direction = "both" + tefs_use_contemporary_features = True + tefs_max_lag_features = "no_lag" + tefs_max_lag_target = 1 + workdir = "tests/output" + + if str(tefs_max_lag_features) == "no_lag": + tefs_max_lag_features = 0 + else: + tefs_max_lag_features = int(tefs_max_lag_features) + + if not tefs_use_contemporary_features and tefs_max_lag_features == 0: + raise ValueError("You cannot use no lag features and not use contemporary features in TEFS.") + + causal_analysis = CausalAnalysis( + df_train, + df_test, + target_column_name, + pcmci_test_choice, + pcmci_max_lag, + tefs_direction, + tefs_use_contemporary_features, + tefs_max_lag_features, + tefs_max_lag_target, + workdir, + response=None, + ) + + causal_analysis.run() + + os.system("rm -r tests/output")