Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parameterize em #9

Merged
merged 2 commits into from
Feb 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file modified .coverage
Binary file not shown.
95 changes: 47 additions & 48 deletions coverage.xml
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
<?xml version="1.0" ?>
<coverage version="7.3.1" timestamp="1705032760490" lines-valid="165" lines-covered="147" line-rate="0.8909" branches-valid="54" branches-covered="49" branch-rate="0.9074" complexity="0">
<coverage version="7.3.1" timestamp="1707084314193" lines-valid="164" lines-covered="146" line-rate="0.8902" branches-valid="54" branches-covered="49" branch-rate="0.9074" complexity="0">
<!-- Generated by coverage.py: https://coverage.readthedocs.io/en/7.3.1 -->
<!-- Based on https://raw.githubusercontent.com/cobertura/web/master/htdocs/xml/coverage-04.dtd -->
<sources>
<source>/Users/jvivian/Library/CloudStorage/[email protected]/My Drive/projects/covid19-drDFM/covid19_drdfm</source>
<source>/home/jvivian/covid19-drDFM/covid19_drdfm</source>
</sources>
<packages>
<package name="." line-rate="0.8909" branch-rate="0.9074" complexity="0">
<package name="." line-rate="0.8902" branch-rate="0.9074" complexity="0">
<classes>
<class name="cli.py" filename="cli.py" complexity="0" line-rate="1" branch-rate="1">
<methods/>
Expand Down Expand Up @@ -37,12 +37,12 @@
<lines>
<line number="1" hits="1"/>
<line number="56" hits="1"/>
<line number="113" hits="1"/>
<line number="128" hits="1"/>
<line number="147" hits="1"/>
<line number="112" hits="1"/>
<line number="127" hits="1"/>
<line number="146" hits="1"/>
</lines>
</class>
<class name="dfm.py" filename="dfm.py" complexity="0" line-rate="0.9143" branch-rate="0.7917">
<class name="dfm.py" filename="dfm.py" complexity="0" line-rate="0.913" branch-rate="0.7917">
<methods/>
<lines>
<line number="7" hits="1"/>
Expand All @@ -61,60 +61,59 @@
<line number="28" hits="1"/>
<line number="31" hits="1"/>
<line number="41" hits="1"/>
<line number="43" hits="1"/>
<line number="42" hits="1"/>
<line number="43" hits="1" branch="true" condition-coverage="100% (2/2)"/>
<line number="44" hits="1"/>
<line number="45" hits="1"/>
<line number="47" hits="1" branch="true" condition-coverage="100% (2/2)"/>
<line number="48" hits="1"/>
<line number="46" hits="1"/>
<line number="49" hits="1"/>
<line number="50" hits="1"/>
<line number="53" hits="1"/>
<line number="62" hits="1"/>
<line number="63" hits="1" branch="true" condition-coverage="100% (2/2)"/>
<line number="64" hits="1"/>
<line number="58" hits="1"/>
<line number="59" hits="1" branch="true" condition-coverage="100% (2/2)"/>
<line number="60" hits="1"/>
<line number="61" hits="1"/>
<line number="62" hits="1" branch="true" condition-coverage="100% (2/2)"/>
<line number="63" hits="1"/>
<line number="65" hits="1"/>
<line number="66" hits="1" branch="true" condition-coverage="100% (2/2)"/>
<line number="67" hits="1"/>
<line number="66" hits="1"/>
<line number="69" hits="1"/>
<line number="70" hits="1"/>
<line number="73" hits="1"/>
<line number="81" hits="1"/>
<line number="82" hits="1"/>
<line number="84" hits="1"/>
<line number="85" hits="1"/>
<line number="86" hits="1"/>
<line number="88" hits="1"/>
<line number="89" hits="1"/>
<line number="90" hits="1" branch="true" condition-coverage="100% (2/2)"/>
<line number="92" hits="1" branch="true" condition-coverage="50% (1/2)" missing-branches="93"/>
<line number="93" hits="0"/>
<line number="94" hits="0"/>
<line number="96" hits="1"/>
<line number="97" hits="1"/>
<line number="98" hits="1"/>
<line number="99" hits="1"/>
<line number="100" hits="0"/>
<line number="101" hits="0" branch="true" condition-coverage="0% (0/2)" missing-branches="102,103"/>
<line number="102" hits="0"/>
<line number="103" hits="0"/>
<line number="86" hits="1" branch="true" condition-coverage="100% (2/2)"/>
<line number="88" hits="1" branch="true" condition-coverage="50% (1/2)" missing-branches="89"/>
<line number="89" hits="0"/>
<line number="90" hits="0"/>
<line number="92" hits="1"/>
<line number="93" hits="1"/>
<line number="94" hits="1"/>
<line number="95" hits="1"/>
<line number="96" hits="0"/>
<line number="97" hits="0" branch="true" condition-coverage="0% (0/2)" missing-branches="98,99"/>
<line number="98" hits="0"/>
<line number="99" hits="0"/>
<line number="101" hits="1"/>
<line number="102" hits="1"/>
<line number="105" hits="1"/>
<line number="106" hits="1"/>
<line number="109" hits="1"/>
<line number="113" hits="1"/>
<line number="114" hits="1"/>
<line number="115" hits="1"/>
<line number="116" hits="1"/>
<line number="117" hits="1"/>
<line number="118" hits="1"/>
<line number="119" hits="1"/>
<line number="120" hits="1"/>
<line number="121" hits="1"/>
<line number="122" hits="1"/>
<line number="125" hits="1"/>
<line number="135" hits="1" branch="true" condition-coverage="50% (1/2)" missing-branches="140"/>
<line number="136" hits="1"/>
<line number="131" hits="1" branch="true" condition-coverage="50% (1/2)" missing-branches="136"/>
<line number="132" hits="1"/>
<line number="133" hits="1"/>
<line number="135" hits="1"/>
<line number="136" hits="1" branch="true" condition-coverage="100% (2/2)"/>
<line number="137" hits="1"/>
<line number="138" hits="1" branch="true" condition-coverage="100% (2/2)"/>
<line number="139" hits="1"/>
<line number="140" hits="1" branch="true" condition-coverage="100% (2/2)"/>
<line number="141" hits="1"/>
<line number="140" hits="1"/>
<line number="141" hits="1" branch="true" condition-coverage="50% (1/2)" missing-branches="exit"/>
<line number="142" hits="1" branch="true" condition-coverage="100% (2/2)"/>
<line number="143" hits="1"/>
<line number="144" hits="1"/>
<line number="145" hits="1" branch="true" condition-coverage="50% (1/2)" missing-branches="exit"/>
<line number="146" hits="1" branch="true" condition-coverage="100% (2/2)"/>
<line number="147" hits="1"/>
</lines>
</class>
<class name="processing.py" filename="processing.py" complexity="0" line-rate="1" branch-rate="1">
Expand Down
6 changes: 2 additions & 4 deletions covid19_drdfm/dfm.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,7 @@ def state_process(df: pd.DataFrame, state: str) -> pd.DataFrame:
pd.DataFrame: Processed DataFrame, ready for model
"""
df = df[df.State == state]
#! Test double-norm
df = normalize(df).fillna(0)
#! TEST REMOVE
const_cols = [x for x in df.columns if is_constant(df[x])]
pprint(f"Constant Columns...dropping\n{const_cols}")
df = df.drop(columns=const_cols).set_index("Time", drop=True)
Expand All @@ -68,7 +66,7 @@ def get_nonstationary_columns(df: pd.DataFrame) -> list[str]:
return non_stationary_columns


def run_model(df: pd.DataFrame, state: str, outdir: Path): # -> sm.tsa.DynamicFactor:
def run_model(df: pd.DataFrame, state: str, outdir: Path, maxiter: int = 10_000): # -> sm.tsa.DynamicFactor:
"""Run DFM for a given state

Args:
Expand All @@ -94,7 +92,7 @@ def run_model(df: pd.DataFrame, state: str, outdir: Path): # -> sm.tsa.DynamicF
try:
factor_multiplicities = {"Global": 2}
model = sm.tsa.DynamicFactorMQ(df, factors=FACTORS, factor_multiplicities=factor_multiplicities)
results = model.fit(disp=10, maxiter=5_000)
results = model.fit(disp=10, maxiter=maxiter)
except Exception as e:
with open(outdir / "failed_convergence.txt", "a") as f:
f.write(f"{state}\t{e}\n")
Expand Down
12 changes: 9 additions & 3 deletions covid19_drdfm/streamlit/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,12 @@ def center_title(text):


def run_parameterized_model(
df: pd.DataFrame, state: str, outdir: Path, columns: list[str], global_multiplier: int = 2
df: pd.DataFrame,
state: str,
outdir: Path,
columns: list[str],
global_multiplier: int = 2,
maxiter: int = 10_000,
) -> sm.tsa.DynamicFactor:
"""Run DFM for a given state

Expand Down Expand Up @@ -134,9 +139,10 @@ def get_data():

# State selections
state_sel = st.multiselect("States", df.State.unique(), default=df.State.unique())
c1, c2 = st.columns([0.7, 0.3])
c1, c2, c3 = st.columns([0.5, 0.25, 0.25])
outdir = c1.text_input("Output Directory", value="./")
mult_sel = c2.slider("Global Multiplier", 0, 4, 2)
maxiter = c3.slider("Max EM Iterations", 1000, 20_000, 10_000, 100)

# Metrics
lengths = [len(selectors[x]) for x in selectors]
Expand Down Expand Up @@ -169,7 +175,7 @@ def get_data():
n = len(state_sel)

for i, state in enumerate(state_sel):
run_parameterized_model(df, state, outdir, columns=columns, global_multiplier=mult_sel)
run_parameterized_model(df, state, outdir, columns=columns, global_multiplier=mult_sel, maxiter=maxiter)
my_bar.progress((i + 1) / n, text=progress_text)

my_bar.empty()
Expand Down
Loading