Skip to content

Commit

Permalink
Merge pull request #9 from jvivian/Parameterize-EM
Browse files Browse the repository at this point in the history
Parameterize em
  • Loading branch information
jvivian authored Feb 4, 2024
2 parents 3e1158b + 6c9f0c6 commit 816a937
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 55 deletions.
Binary file modified .coverage
Binary file not shown.
95 changes: 47 additions & 48 deletions coverage.xml
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
<?xml version="1.0" ?>
<coverage version="7.3.1" timestamp="1705032760490" lines-valid="165" lines-covered="147" line-rate="0.8909" branches-valid="54" branches-covered="49" branch-rate="0.9074" complexity="0">
<coverage version="7.3.1" timestamp="1707084314193" lines-valid="164" lines-covered="146" line-rate="0.8902" branches-valid="54" branches-covered="49" branch-rate="0.9074" complexity="0">
<!-- Generated by coverage.py: https://coverage.readthedocs.io/en/7.3.1 -->
<!-- Based on https://raw.githubusercontent.com/cobertura/web/master/htdocs/xml/coverage-04.dtd -->
<sources>
<source>/Users/jvivian/Library/CloudStorage/[email protected]/My Drive/projects/covid19-drDFM/covid19_drdfm</source>
<source>/home/jvivian/covid19-drDFM/covid19_drdfm</source>
</sources>
<packages>
<package name="." line-rate="0.8909" branch-rate="0.9074" complexity="0">
<package name="." line-rate="0.8902" branch-rate="0.9074" complexity="0">
<classes>
<class name="cli.py" filename="cli.py" complexity="0" line-rate="1" branch-rate="1">
<methods/>
Expand Down Expand Up @@ -37,12 +37,12 @@
<lines>
<line number="1" hits="1"/>
<line number="56" hits="1"/>
<line number="113" hits="1"/>
<line number="128" hits="1"/>
<line number="147" hits="1"/>
<line number="112" hits="1"/>
<line number="127" hits="1"/>
<line number="146" hits="1"/>
</lines>
</class>
<class name="dfm.py" filename="dfm.py" complexity="0" line-rate="0.9143" branch-rate="0.7917">
<class name="dfm.py" filename="dfm.py" complexity="0" line-rate="0.913" branch-rate="0.7917">
<methods/>
<lines>
<line number="7" hits="1"/>
Expand All @@ -61,60 +61,59 @@
<line number="28" hits="1"/>
<line number="31" hits="1"/>
<line number="41" hits="1"/>
<line number="43" hits="1"/>
<line number="42" hits="1"/>
<line number="43" hits="1" branch="true" condition-coverage="100% (2/2)"/>
<line number="44" hits="1"/>
<line number="45" hits="1"/>
<line number="47" hits="1" branch="true" condition-coverage="100% (2/2)"/>
<line number="48" hits="1"/>
<line number="46" hits="1"/>
<line number="49" hits="1"/>
<line number="50" hits="1"/>
<line number="53" hits="1"/>
<line number="62" hits="1"/>
<line number="63" hits="1" branch="true" condition-coverage="100% (2/2)"/>
<line number="64" hits="1"/>
<line number="58" hits="1"/>
<line number="59" hits="1" branch="true" condition-coverage="100% (2/2)"/>
<line number="60" hits="1"/>
<line number="61" hits="1"/>
<line number="62" hits="1" branch="true" condition-coverage="100% (2/2)"/>
<line number="63" hits="1"/>
<line number="65" hits="1"/>
<line number="66" hits="1" branch="true" condition-coverage="100% (2/2)"/>
<line number="67" hits="1"/>
<line number="66" hits="1"/>
<line number="69" hits="1"/>
<line number="70" hits="1"/>
<line number="73" hits="1"/>
<line number="81" hits="1"/>
<line number="82" hits="1"/>
<line number="84" hits="1"/>
<line number="85" hits="1"/>
<line number="86" hits="1"/>
<line number="88" hits="1"/>
<line number="89" hits="1"/>
<line number="90" hits="1" branch="true" condition-coverage="100% (2/2)"/>
<line number="92" hits="1" branch="true" condition-coverage="50% (1/2)" missing-branches="93"/>
<line number="93" hits="0"/>
<line number="94" hits="0"/>
<line number="96" hits="1"/>
<line number="97" hits="1"/>
<line number="98" hits="1"/>
<line number="99" hits="1"/>
<line number="100" hits="0"/>
<line number="101" hits="0" branch="true" condition-coverage="0% (0/2)" missing-branches="102,103"/>
<line number="102" hits="0"/>
<line number="103" hits="0"/>
<line number="86" hits="1" branch="true" condition-coverage="100% (2/2)"/>
<line number="88" hits="1" branch="true" condition-coverage="50% (1/2)" missing-branches="89"/>
<line number="89" hits="0"/>
<line number="90" hits="0"/>
<line number="92" hits="1"/>
<line number="93" hits="1"/>
<line number="94" hits="1"/>
<line number="95" hits="1"/>
<line number="96" hits="0"/>
<line number="97" hits="0" branch="true" condition-coverage="0% (0/2)" missing-branches="98,99"/>
<line number="98" hits="0"/>
<line number="99" hits="0"/>
<line number="101" hits="1"/>
<line number="102" hits="1"/>
<line number="105" hits="1"/>
<line number="106" hits="1"/>
<line number="109" hits="1"/>
<line number="113" hits="1"/>
<line number="114" hits="1"/>
<line number="115" hits="1"/>
<line number="116" hits="1"/>
<line number="117" hits="1"/>
<line number="118" hits="1"/>
<line number="119" hits="1"/>
<line number="120" hits="1"/>
<line number="121" hits="1"/>
<line number="122" hits="1"/>
<line number="125" hits="1"/>
<line number="135" hits="1" branch="true" condition-coverage="50% (1/2)" missing-branches="140"/>
<line number="136" hits="1"/>
<line number="131" hits="1" branch="true" condition-coverage="50% (1/2)" missing-branches="136"/>
<line number="132" hits="1"/>
<line number="133" hits="1"/>
<line number="135" hits="1"/>
<line number="136" hits="1" branch="true" condition-coverage="100% (2/2)"/>
<line number="137" hits="1"/>
<line number="138" hits="1" branch="true" condition-coverage="100% (2/2)"/>
<line number="139" hits="1"/>
<line number="140" hits="1" branch="true" condition-coverage="100% (2/2)"/>
<line number="141" hits="1"/>
<line number="140" hits="1"/>
<line number="141" hits="1" branch="true" condition-coverage="50% (1/2)" missing-branches="exit"/>
<line number="142" hits="1" branch="true" condition-coverage="100% (2/2)"/>
<line number="143" hits="1"/>
<line number="144" hits="1"/>
<line number="145" hits="1" branch="true" condition-coverage="50% (1/2)" missing-branches="exit"/>
<line number="146" hits="1" branch="true" condition-coverage="100% (2/2)"/>
<line number="147" hits="1"/>
</lines>
</class>
<class name="processing.py" filename="processing.py" complexity="0" line-rate="1" branch-rate="1">
Expand Down
6 changes: 2 additions & 4 deletions covid19_drdfm/dfm.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,7 @@ def state_process(df: pd.DataFrame, state: str) -> pd.DataFrame:
pd.DataFrame: Processed DataFrame, ready for model
"""
df = df[df.State == state]
#! Test double-norm
df = normalize(df).fillna(0)
#! TEST REMOVE
const_cols = [x for x in df.columns if is_constant(df[x])]
pprint(f"Constant Columns...dropping\n{const_cols}")
df = df.drop(columns=const_cols).set_index("Time", drop=True)
Expand All @@ -68,7 +66,7 @@ def get_nonstationary_columns(df: pd.DataFrame) -> list[str]:
return non_stationary_columns


def run_model(df: pd.DataFrame, state: str, outdir: Path): # -> sm.tsa.DynamicFactor:
def run_model(df: pd.DataFrame, state: str, outdir: Path, maxiter: int = 10_000): # -> sm.tsa.DynamicFactor:
"""Run DFM for a given state
Args:
Expand All @@ -94,7 +92,7 @@ def run_model(df: pd.DataFrame, state: str, outdir: Path): # -> sm.tsa.DynamicF
try:
factor_multiplicities = {"Global": 2}
model = sm.tsa.DynamicFactorMQ(df, factors=FACTORS, factor_multiplicities=factor_multiplicities)
results = model.fit(disp=10, maxiter=5_000)
results = model.fit(disp=10, maxiter=maxiter)
except Exception as e:
with open(outdir / "failed_convergence.txt", "a") as f:
f.write(f"{state}\t{e}\n")
Expand Down
12 changes: 9 additions & 3 deletions covid19_drdfm/streamlit/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,12 @@ def center_title(text):


def run_parameterized_model(
df: pd.DataFrame, state: str, outdir: Path, columns: list[str], global_multiplier: int = 2
df: pd.DataFrame,
state: str,
outdir: Path,
columns: list[str],
global_multiplier: int = 2,
maxiter: int = 10_000,
) -> sm.tsa.DynamicFactor:
"""Run DFM for a given state
Expand Down Expand Up @@ -134,9 +139,10 @@ def get_data():

# State selections
state_sel = st.multiselect("States", df.State.unique(), default=df.State.unique())
c1, c2 = st.columns([0.7, 0.3])
c1, c2, c3 = st.columns([0.5, 0.25, 0.25])
outdir = c1.text_input("Output Directory", value="./")
mult_sel = c2.slider("Global Multiplier", 0, 4, 2)
maxiter = c3.slider("Max EM Iterations", 1000, 20_000, 10_000, 100)

# Metrics
lengths = [len(selectors[x]) for x in selectors]
Expand Down Expand Up @@ -169,7 +175,7 @@ def get_data():
n = len(state_sel)

for i, state in enumerate(state_sel):
run_parameterized_model(df, state, outdir, columns=columns, global_multiplier=mult_sel)
run_parameterized_model(df, state, outdir, columns=columns, global_multiplier=mult_sel, maxiter=maxiter)
my_bar.progress((i + 1) / n, text=progress_text)

my_bar.empty()
Expand Down

0 comments on commit 816a937

Please sign in to comment.