diff --git a/.coverage b/.coverage index 859debd..157772b 100644 Binary files a/.coverage and b/.coverage differ diff --git a/coverage.xml b/coverage.xml index 53bb6cb..7405ca9 100644 --- a/coverage.xml +++ b/coverage.xml @@ -1,12 +1,12 @@ - + - /Users/jvivian/Library/CloudStorage/GoogleDrive-jtvivian@gmail.com/My Drive/projects/covid19-drDFM/covid19_drdfm + /home/jvivian/covid19-drDFM/covid19_drdfm - + @@ -37,12 +37,12 @@ - - - + + + - + @@ -61,60 +61,59 @@ - + + + - - + - - - - - + + + + + + - - + - - + + + - - - - - - - - - - - - - - - + + + + + + + + + + + + + + - - + + + + - - - - - - + + + + + + - - + + - - - - diff --git a/covid19_drdfm/dfm.py b/covid19_drdfm/dfm.py index 51f46e0..986903b 100644 --- a/covid19_drdfm/dfm.py +++ b/covid19_drdfm/dfm.py @@ -39,9 +39,7 @@ def state_process(df: pd.DataFrame, state: str) -> pd.DataFrame: pd.DataFrame: Processed DataFrame, ready for model """ df = df[df.State == state] - #! Test double-norm df = normalize(df).fillna(0) - #! TEST REMOVE const_cols = [x for x in df.columns if is_constant(df[x])] pprint(f"Constant Columns...dropping\n{const_cols}") df = df.drop(columns=const_cols).set_index("Time", drop=True) @@ -68,7 +66,7 @@ def get_nonstationary_columns(df: pd.DataFrame) -> list[str]: return non_stationary_columns -def run_model(df: pd.DataFrame, state: str, outdir: Path): # -> sm.tsa.DynamicFactor: +def run_model(df: pd.DataFrame, state: str, outdir: Path, maxiter: int = 10_000): # -> sm.tsa.DynamicFactor: """Run DFM for a given state Args: @@ -94,7 +92,7 @@ def run_model(df: pd.DataFrame, state: str, outdir: Path): # -> sm.tsa.DynamicF try: factor_multiplicities = {"Global": 2} model = sm.tsa.DynamicFactorMQ(df, factors=FACTORS, factor_multiplicities=factor_multiplicities) - results = model.fit(disp=10, maxiter=5_000) + results = model.fit(disp=10, maxiter=maxiter) except Exception as e: with open(outdir / "failed_convergence.txt", "a") as f: f.write(f"{state}\t{e}\n") diff --git a/covid19_drdfm/streamlit/runner.py b/covid19_drdfm/streamlit/runner.py index 89ca996..fc768e8 100644 --- a/covid19_drdfm/streamlit/runner.py +++ b/covid19_drdfm/streamlit/runner.py @@ -31,7 +31,12 @@ def center_title(text): def run_parameterized_model( - df: pd.DataFrame, state: str, outdir: Path, columns: list[str], global_multiplier: int = 2 + df: pd.DataFrame, + state: str, + outdir: Path, + columns: list[str], + global_multiplier: int = 2, + maxiter: int = 10_000, ) -> sm.tsa.DynamicFactor: """Run DFM for a given state @@ -134,9 +139,10 @@ def get_data(): # State selections state_sel = st.multiselect("States", df.State.unique(), default=df.State.unique()) - c1, c2 = st.columns([0.7, 0.3]) + c1, c2, c3 = st.columns([0.5, 0.25, 0.25]) outdir = c1.text_input("Output Directory", value="./") mult_sel = c2.slider("Global Multiplier", 0, 4, 2) + maxiter = c3.slider("Max EM Iterations", 1000, 20_000, 10_000, 100) # Metrics lengths = [len(selectors[x]) for x in selectors] @@ -169,7 +175,7 @@ def get_data(): n = len(state_sel) for i, state in enumerate(state_sel): - run_parameterized_model(df, state, outdir, columns=columns, global_multiplier=mult_sel) + run_parameterized_model(df, state, outdir, columns=columns, global_multiplier=mult_sel, maxiter=maxiter) my_bar.progress((i + 1) / n, text=progress_text) my_bar.empty()