Merge pull request #88 from jvivian/jvivian/issue87
Make fixes based on validation data
jvivian authored Jun 15, 2024
2 parents 0d285dc + ff86ef4 commit ce9419c
Showing 2 changed files with 7 additions and 44 deletions.
7 changes: 3 additions & 4 deletions covid19_drdfm/streamlit/pages/0_Dynamic_Factor_Model.py
@@ -53,7 +53,8 @@ def file_uploader(self) -> "DataHandler":
st.stop()
self.batch_col = st.sidebar.selectbox("Select a batch column (optional):", ["None", *list(self.df.columns)])
if self.batch_col == "None":
self.batch_col = None
self.df["Batch"] = "Batch1"
# self.batch_col = None
self.non_batch_cols = [col for col in self.df.columns if col != self.batch_col]
return self

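The first hunk replaces the self.batch_col = None fallback with a constant "Batch1" label, so every row carries a batch key even when the user picks no batch column. A minimal sketch of that fallback pattern outside Streamlit (the helper name and example frame are illustrative, not from the repository):

from typing import Optional

import pandas as pd

def assign_batch(df: pd.DataFrame, batch_col: Optional[str]) -> pd.DataFrame:
    # With no batch column selected, tag every row with a single placeholder
    # batch so downstream group-by logic still has a key to work with.
    if batch_col is None:
        df = df.copy()
        df["Batch"] = "Batch1"
    return df

print(assign_batch(pd.DataFrame({"cases": [1, 2, 3]}), None))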
@@ -84,9 +85,7 @@ def load_data(file) -> pd.DataFrame:
return read_function(file)

def apply_transforms(self) -> "DataHandler":
options = st.multiselect(
"Select columns to apply transformations:", self.non_batch_cols, format_func=lambda x: f"Transform {x}"
)
options = st.multiselect("Select columns to apply transformations:", self.non_batch_cols)
transforms = {}
for i, opt in enumerate(options):
if i % 2 == 0:
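The second hunk only drops the format_func decoration from the multiselect, so the widget now lists raw column names. A small hedged sketch of the before/after behavior as a standalone Streamlit snippet (the column names are made up for illustration):

import streamlit as st

columns = ["hospitalizations", "vaccinations", "cases"]

# Before: format_func only changes the displayed labels; the returned values
# are still the raw column names.
with_labels = st.multiselect(
    "Select columns to apply transformations:",
    columns,
    format_func=lambda x: f"Transform {x}",
    key="with_labels",
)

# After (what this commit switches to): plain column names in the widget.
plain = st.multiselect("Select columns to apply transformations:", columns, key="plain")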
44 changes: 4 additions & 40 deletions covid19_drdfm/streamlit/pages/1_Factor_Analysis.py
@@ -63,68 +63,31 @@ def get_factors(res_dir):
# Grab first state to fetch valid variables
state_df = pd.read_csv(res_dir / state / "df.csv")
cols = [x for x in df.columns if x in state_df.columns] + ["State"]
# st.write(cols)
# Get factors and columns
# factor_vars = [x for x in FACTORS_GROUPED[factor.split("_")[1]] if x in valid_cols] # and '_' in x]
# columns = [*factor_vars, "State", "Time"]

# Make df from res_dir
dfs = []
for subdir in res_dir.iterdir():
if not subdir.is_dir():
continue
path = subdir / "df.csv"
if not path.exists():
st.write(f"Skipping {path}, not found")
continue
sub = pd.read_csv(path, index_col=0)
sub["State"] = subdir.stem
dfs.append(sub)

new = pd.concat(dfs)
new = pd.read_csv(res_dir / state / "df.csv", index_col=0)

# Normalize original data for state / valid variables
ad = ann.read_h5ad(res_dir / "data.h5ad")
factor_map = ad.var["factor"].to_frame()
factor_set = factor_map["factor"].unique().to_list() + [x for x in df.columns if "Global" in x]
# st.dataframe(factor_map)
factor = st.sidebar.selectbox("Factor", factor_set)
# new = ad.to_df().reset_index()
# new["State"] = ad.obs["State"].to_list()
# new = normalize(new[new.State == state])


# Normalize factors and add to new dataframe
if st.sidebar.checkbox("Invert Factor"):
df[factor] = df[factor] * -1
df = normalize(df[df.State == state]) # .reset_index(drop=True)
df = normalize(df[df.State == state])

df = df[df["State"] == state]
df = df[[factor]].join(new, on="Time")

# st.write(factor_map)
# st.dataframe(df.head())
# st.dataframe(new.head())

# Coerce time bullshit to get dates standardized
# df["Time"] = pd.to_datetime(df["Time"]).dt.date
# new["Time"] = pd.to_datetime(new["Time"]).dt.date
col_opts = [x for x in df.columns.to_list() if x != "State"]
cols = st.multiselect("Variables to plot", col_opts, default=col_opts)
with st.expander("Graph Data"):
# factor_cols = factor_map[factor_map["factor"] == factor]
# if factor_cols.empty:
# factor_cols = new.columns
# else:
# factor_cols = factor_cols.index.to_list()
# factor_cols += [factor]
# factor_cols = [x for x in factor_cols if x in df.columns]
# st.write(factor_cols)
st.dataframe(df[cols])

df = df[cols].reset_index()

# Melt into format for plotting
# melted_df = df.drop(columns="State").melt(id_vars=["Time"], value_name="value")
melted_df = df.melt(id_vars=["Time"], value_name="value")
melted_df["Label"] = [5 if x == factor else 1 for x in melted_df.variable]

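The bulk of this hunk removes the loop that concatenated df.csv from every state subdirectory; the page now reads only the selected state's frame before joining the factor and melting for the plot. A minimal sketch of the simplified flow (the directory layout and column names follow the diff, but the helper names are illustrative):

from pathlib import Path

import pandas as pd

def load_state_frame(res_dir: Path, state: str) -> pd.DataFrame:
    # After the fix: read a single state's results rather than concatenating
    # df.csv from every subdirectory under res_dir.
    return pd.read_csv(res_dir / state / "df.csv", index_col=0)

def melt_for_plot(df: pd.DataFrame, factor: str) -> pd.DataFrame:
    # Long format for plotting; the factor trace gets a heavier line weight.
    melted = df.melt(id_vars=["Time"], value_name="value")
    melted["Label"] = [5 if v == factor else 1 for v in melted["variable"]]
    return melted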
@@ -136,6 +99,7 @@
results_path = res_dir / state / "results.csv"
model_path = res_dir / state / "model.csv"

# Metrics for run
values = pd.Series()
with open(results_path) as f:
for line in f.readlines():

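The final hunk adds a "# Metrics for run" comment above code that reads results.csv line by line into a pd.Series; the loop body is cut off here, so the sketch below assumes one "name,value" pair per line purely for illustration, not as the repository's actual schema:

import pandas as pd

def read_metrics(results_path) -> pd.Series:
    # Assumed format: one "name,value" pair per line (not confirmed by the hunk).
    values = {}
    with open(results_path) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            name, value = line.split(",", 1)
            values[name] = float(value)
    return pd.Series(values)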