diff --git a/.coverage b/.coverage index 4ed3ac3..dee71bc 100644 Binary files a/.coverage and b/.coverage differ diff --git a/coverage.xml b/coverage.xml index 0e8833b..2bc0a41 100644 --- a/coverage.xml +++ b/coverage.xml @@ -1,12 +1,12 @@ - + /home/jvivian/covid19-drDFM/covid19_drdfm - + @@ -32,7 +32,7 @@ - + @@ -51,48 +51,45 @@ - - + + - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - - - - - - - - - - - - + + + @@ -107,43 +104,38 @@ - - - - - - - - - - - - - + + + + + + + + + - - - - - - - - - - + + + + + + + + + + + + + + - + - - - - - - - - - + + + + diff --git a/tests/test_dfm.py b/tests/test_dfm.py index ed74d6e..608b437 100644 --- a/tests/test_dfm.py +++ b/tests/test_dfm.py @@ -2,12 +2,12 @@ from pathlib import Path from covid19_drdfm.dfm import run_model -from covid19_drdfm.processing import run - +from covid19_drdfm.processing import get_df +# TODO: output should go in a directory instead of dumping shit everywhere def test_run_model(): - raw = run() - run_model(raw, "NY", "./") + df = get_df() + run_model(df, "NY", Path("./")) assert Path("./model.csv").exists() assert Path("./results.csv").exists() os.remove("./model.csv") diff --git a/tests/test_processing.py b/tests/test_processing.py index 3c9b698..8309344 100644 --- a/tests/test_processing.py +++ b/tests/test_processing.py @@ -1,5 +1,7 @@ import pandas as pd import pytest +from functools import reduce + from covid19_drdfm.processing import ( add_datetime, @@ -7,10 +9,19 @@ adjust_pandemic_response, get_df, get_govt_fund_dist, - run, + DATA_DIR, + ROOT_DIR, ) +@pytest.fixture +def raw_data() -> pd.DataFrame: + with open(DATA_DIR / "df_paths.txt") as f: + paths = [ROOT_DIR / x.strip() for x in f.readlines()] + dfs = [pd.read_csv(x) for x in paths] + return reduce(lambda x, y: pd.merge(x, y, on=["State", "Year", "Period"], how="left"), dfs).fillna(0) + + # Fixture to load test data @pytest.fixture def sample_data() -> pd.DataFrame: @@ -39,14 +50,14 @@ def test_adjust_pandemic_response(sample_data): assert df[r].sum() == out[r].sum() -def test_fix_datetime(sample_data): - input_df = sample_data.copy() +def test_fix_datetime(raw_data): + input_df = raw_data.copy() output_df = add_datetime(input_df) assert isinstance(output_df["Time"][0], pd.Timestamp) def test_run(): - df = run() + df = get_df() expected_columns = ["State", "Supply_1", "Demand_1", "Pandemic_Response_13", "Time"] assert all(col in df.columns for col in expected_columns)