From ff86ef4505a412b6201c1dcc192b888f8b029eb7 Mon Sep 17 00:00:00 2001 From: jvivian Date: Fri, 14 Jun 2024 22:42:56 -0700 Subject: [PATCH] Make fixes based on validation data Fixes #87 --- .../streamlit/pages/0_Dynamic_Factor_Model.py | 7 ++- .../streamlit/pages/1_Factor_Analysis.py | 44 ++----------------- 2 files changed, 7 insertions(+), 44 deletions(-) diff --git a/covid19_drdfm/streamlit/pages/0_Dynamic_Factor_Model.py b/covid19_drdfm/streamlit/pages/0_Dynamic_Factor_Model.py index 709885c..e7aa9c2 100644 --- a/covid19_drdfm/streamlit/pages/0_Dynamic_Factor_Model.py +++ b/covid19_drdfm/streamlit/pages/0_Dynamic_Factor_Model.py @@ -53,7 +53,8 @@ def file_uploader(self) -> "DataHandler": st.stop() self.batch_col = st.sidebar.selectbox("Select a batch column (optional):", ["None", *list(self.df.columns)]) if self.batch_col == "None": - self.batch_col = None + self.df["Batch"] = "Batch1" + # self.batch_col = None self.non_batch_cols = [col for col in self.df.columns if col != self.batch_col] return self @@ -84,9 +85,7 @@ def load_data(file) -> pd.DataFrame: return read_function(file) def apply_transforms(self) -> "DataHandler": - options = st.multiselect( - "Select columns to apply transformations:", self.non_batch_cols, format_func=lambda x: f"Transform {x}" - ) + options = st.multiselect("Select columns to apply transformations:", self.non_batch_cols) transforms = {} for i, opt in enumerate(options): if i % 2 == 0: diff --git a/covid19_drdfm/streamlit/pages/1_Factor_Analysis.py b/covid19_drdfm/streamlit/pages/1_Factor_Analysis.py index bce7f69..8d5d075 100644 --- a/covid19_drdfm/streamlit/pages/1_Factor_Analysis.py +++ b/covid19_drdfm/streamlit/pages/1_Factor_Analysis.py @@ -63,68 +63,31 @@ def get_factors(res_dir): # Grab first state to fetch valid variables state_df = pd.read_csv(res_dir / state / "df.csv") cols = [x for x in df.columns if x in state_df.columns] + ["State"] -# st.write(cols) -# Get factors and columns -# factor_vars = [x for x in FACTORS_GROUPED[factor.split("_")[1]] if x in valid_cols] # and '_' in x] -# columns = [*factor_vars, "State", "Time"] - -# Make df from res_dir -dfs = [] -for subdir in res_dir.iterdir(): - if not subdir.is_dir(): - continue - path = subdir / "df.csv" - if not path.exists(): - st.write(f"Skipping {path}, not found") - continue - sub = pd.read_csv(path, index_col=0) - sub["State"] = subdir.stem - dfs.append(sub) - -new = pd.concat(dfs) +new = pd.read_csv(res_dir / state / "df.csv", index_col=0) # Normalize original data for state / valid variables ad = ann.read_h5ad(res_dir / "data.h5ad") factor_map = ad.var["factor"].to_frame() factor_set = factor_map["factor"].unique().to_list() + [x for x in df.columns if "Global" in x] -# st.dataframe(factor_map) factor = st.sidebar.selectbox("Factor", factor_set) -# new = ad.to_df().reset_index() -# new["State"] = ad.obs["State"].to_list() -# new = normalize(new[new.State == state]) + # Normalize factors and add to new dataframe if st.sidebar.checkbox("Invert Factor"): df[factor] = df[factor] * -1 -df = normalize(df[df.State == state]) # .reset_index(drop=True) +df = normalize(df[df.State == state]) df = df[df["State"] == state] df = df[[factor]].join(new, on="Time") -# st.write(factor_map) -# st.dataframe(df.head()) -# st.dataframe(new.head()) - -# Coerce time bullshit to get dates standardized -# df["Time"] = pd.to_datetime(df["Time"]).dt.date -# new["Time"] = pd.to_datetime(new["Time"]).dt.date col_opts = [x for x in df.columns.to_list() if x != "State"] cols = st.multiselect("Variables to plot", col_opts, default=col_opts) with st.expander("Graph Data"): - # factor_cols = factor_map[factor_map["factor"] == factor] - # if factor_cols.empty: - # factor_cols = new.columns - # else: - # factor_cols = factor_cols.index.to_list() - # factor_cols += [factor] - # factor_cols = [x for x in factor_cols if x in df.columns] - # st.write(factor_cols) st.dataframe(df[cols]) df = df[cols].reset_index() # Melt into format for plotting -# melted_df = df.drop(columns="State").melt(id_vars=["Time"], value_name="value") melted_df = df.melt(id_vars=["Time"], value_name="value") melted_df["Label"] = [5 if x == factor else 1 for x in melted_df.variable] @@ -136,6 +99,7 @@ def get_factors(res_dir): results_path = res_dir / state / "results.csv" model_path = res_dir / state / "model.csv" +# Metrics for run values = pd.Series() with open(results_path) as f: for line in f.readlines():