From 2e6f6fa397b74e936caef87210ea8f7f5f14298c Mon Sep 17 00:00:00 2001 From: jvivian Date: Sat, 8 Jun 2024 21:56:41 -0700 Subject: [PATCH 1/2] Dynamically generate Dashboard based on `AnnData` object Fixes #78 --- covid19_drdfm/data/processed/Test_Data.csv | 37 +++++++ .../streamlit/pages/0_Dynamic_Factor_Model.py | 99 +++++++++++++++++++ 2 files changed, 136 insertions(+) create mode 100644 covid19_drdfm/data/processed/Test_Data.csv create mode 100644 covid19_drdfm/streamlit/pages/0_Dynamic_Factor_Model.py diff --git a/covid19_drdfm/data/processed/Test_Data.csv b/covid19_drdfm/data/processed/Test_Data.csv new file mode 100644 index 0000000..79976a7 --- /dev/null +++ b/covid19_drdfm/data/processed/Test_Data.csv @@ -0,0 +1,37 @@ +Time,Pandemic_10,Pandemic_9,Pandemic_1,Pandemic_7,Factor +1/1/2020,0,0,0,0,0.2045454545 +2/1/2020,0,0,0,0,0 +3/1/2020,0.0058740602,0.0397350993,0.0042229655,0.0609357998,0 +4/1/2020,0.0693139098,0.0198675497,0.0024131231,0.0282916213,0.0619834711 +5/1/2020,0.1379229323,0.0264900662,0.0025728151,0.0304678999,0.0289256198 +6/1/2020,0.1734022556,0.0264900662,0.0108767976,0.0631120783,0.0309917355 +7/1/2020,0.4814379699,0.1059602649,0.040277864,0.1828073993,0.0640495868 +8/1/2020,0.313674812,0.1258278146,0.0370130504,0.204570185,0.1838842975 +9/1/2020,0.1459116541,0.1125827815,0.0532484009,0.1566920566,0.2045454545 +10/1/2020,0.0773026316,0.3112582781,0.1582192571,0.4243743199,0.1570247934 +11/1/2020,0.155075188,0.5894039735,0.3026162868,0.752992383,0.423553719 +12/1/2020,0.5227913534,0.5033112583,0.1744191205,0.5571273123,0.7541322314 +1/1/2021,1,0.178807947,0.0982904087,0.2176278564,0.5578512397 +2/1/2021,0.6719924812,0.0993377483,0.0558034724,0.1523394995,0.2190082645 +3/1/2021,0.2319078947,0.1125827815,0.072748565,0.2241566921,0.152892562 +4/1/2021,0.0838815789,0.0066225166,0.0778320928,0.289445049,0.2252066116 +5/1/2021,0.0714285714,0.1390728477,0.0299511165,0.1284004353,0.2892561983 
+6/1/2021,0.0723684211,0.0529801325,0.0149666865,0.0957562568,0.1301652893 +7/1/2021,0.0728383459,0.0794701987,0.0817711614,0.39390642,0.0950413223 +8/1/2021,0.1268796992,0.3245033113,0.26157545,0.8019586507,0.3925619835 +9/1/2021,0.2814849624,0.8278145695,0.4060079669,1,0.8037190083 +10/1/2021,0.274906015,0.9470198675,0.3689594294,0.8335146899,1 +11/1/2021,0.2709116541,1,0.1806914662,0.4494015234,0.8347107438 +12/1/2021,0.4515977444,0.1059602649,0.1362261238,0.2655059848,0.4504132231 +1/1/2022,0.4642857143,0.2847682119,1,0.478781284,0.2665289256 +2/1/2022,0.4090695489,0.5165562914,0.2854316563,0.2491838955,0.479338843 +3/1/2022,0.3106203008,0.4304635762,0.1077388504,0.1284004353,0.25 +4/1/2022,0.1604793233,0.1589403974,0.0940142126,0.0718171926,0.1280991736 +5/1/2022,0.0817669173,0.2185430464,0.1257574279,0.1099020675,0.0723140496 +6/1/2022,0.0507518797,0,0.1572433617,0.1077257889,0.1095041322 +7/1/2022,0.0594454887,0.1523178808,0.1569505931,0.1088139282,0.1074380165 +8/1/2022,0.0812969925,0.1920529801,0.0951143128,0.112078346,0.1095041322 +9/1/2022,0.0601503759,0.1655629139,0.0211769298,0.0402611534,0.1115702479 +10/1/2022,0,0,0.0135738176,0.0293797606,0.041322314 +11/1/2022,0,0,0.0134850999,0.0054406964,0.0289256198 +12/1/2022,0,0,0.0121365899,0.0087051143,0.0061983471 diff --git a/covid19_drdfm/streamlit/pages/0_Dynamic_Factor_Model.py b/covid19_drdfm/streamlit/pages/0_Dynamic_Factor_Model.py new file mode 100644 index 0000000..c457335 --- /dev/null +++ b/covid19_drdfm/streamlit/pages/0_Dynamic_Factor_Model.py @@ -0,0 +1,99 @@ +import time +from pathlib import Path + +import pandas as pd +import plotly.io as pio +import streamlit as st +import yaml + +from covid19_drdfm.constants import FACTORS +from covid19_drdfm.dfm import ModelRunner +import anndata as ann + +st.set_page_config(layout="wide") +pio.templates.default = "plotly_white" + + +def center_title(text): + return st.markdown(f"

{text}

", unsafe_allow_html=True) + + +def load_data(file): + if "csv" in file.type: + return pd.read_csv(file, index_col=0) + elif "tsv" in file.type: + return pd.read_csv(file, index_col=0, sep="\t") + elif "xlsx" in file.type: + return pd.read_excel(file, index_col=0) + else: + return None + + +def create_anndata(df, factor_mappings, batch_col=None): + if batch_col: + adata = ann.AnnData(df.drop(columns=batch_col)) + adata.obs[batch_col] = df[batch_col] + else: + adata = ann.AnnData(df) + adata.var["factor"] = [factor_mappings[x] for x in adata.var.index] + return adata + + +def file_uploader(): + # File uploader + file = st.file_uploader("Upload a data file (CSV, TSV, XLSX)", type=["csv", "tsv", "xlsx"]) + if file is None: + st.error("Please provide input file") + st.stop() + df = load_data(file) + with st.expander("Raw Input Data"): + st.dataframe(df) + if df is not None: + # Optional batch column + batch_col = st.selectbox("Select a batch column (optional):", ["None"] + list(df.columns)) + if batch_col == "None": + batch_col = None + + # Ask for non-batch variables and their factor mappings + non_batch_cols = [col for col in df.columns if col != batch_col] + factor_mappings = {} + for col in non_batch_cols: + factor = st.text_input(f"Enter factor for {col}:", key=col) + if factor: + # factor_cats = factor.split(",") + # factor_mappings[col] = pd.Categorical(df[col], categories=factor_cats, ordered=True) + factor_mappings[col] = factor + if len(factor_mappings) != len(non_batch_cols): + st.warning("Fill in a Factor label for all variables!") + st.stop() + + # Create anndata + ad = create_anndata(df, factor_mappings, batch_col) + + # Transformations + options = st.multiselect( + "Select columns to apply transformations:", non_batch_cols, format_func=lambda x: f"Transform {x}" + ) + transforms = {} + for opt in options: + transform = st.radio(f"Select transform type for {opt}:", ("difference", "logdiff"), key=f"trans_{opt}") + transforms[opt] = transform + 
ad.var[transform] = None + ad.var.loc[opt, transform] = True + + # Show anndata and transforms + st.write("Anndata object:", ad) + st.dataframe(ad.var) + return ad + + +ad = file_uploader() + +global_multiplier = st.slider("Global Multiplier", min_value=0, max_value=4, value=0) +outdir = st.text_input("Location of output!", value=None) +if not outdir: + st.stop() +batch = None if ad.obs.empty else ad.obs.columns[0] +dfm = ModelRunner(ad, Path(outdir), batch=batch) +dfm.run(global_multiplier=global_multiplier) +st.write(dfm.results) From 401f2c434d6e13c304969cabaf7bf00d58fd4c41 Mon Sep 17 00:00:00 2001 From: jvivian Date: Sun, 9 Jun 2024 16:22:35 -0700 Subject: [PATCH 2/2] Refactor dynamic factor model runner based off user input resolves #78 --- covid19_drdfm/streamlit/Dashboard.py | 4 +- .../streamlit/pages/0_Dynamic_Factor_Model.py | 228 ++++++++++++------ 2 files changed, 156 insertions(+), 76 deletions(-) diff --git a/covid19_drdfm/streamlit/Dashboard.py b/covid19_drdfm/streamlit/Dashboard.py index fa72324..f1af655 100644 --- a/covid19_drdfm/streamlit/Dashboard.py +++ b/covid19_drdfm/streamlit/Dashboard.py @@ -1,10 +1,10 @@ -import yaml import time from pathlib import Path import pandas as pd import plotly.io as pio import streamlit as st +import yaml from covid19_drdfm.constants import FACTORS from covid19_drdfm.covid19 import get_df, get_project_h5ad @@ -32,7 +32,7 @@ def get_data(): var_df["Variables"] = var_df.index ad.obs["Time"] = pd.to_datetime(ad.obs.index) -center_title("Dynamic Factor Model Runner") +center_title("Legacy Dynamic Factor Model Runner for Covid-19") with st.expander("Variable correlations"): st.write("Data is normalized between [0, 1] before calculating correlation") diff --git a/covid19_drdfm/streamlit/pages/0_Dynamic_Factor_Model.py b/covid19_drdfm/streamlit/pages/0_Dynamic_Factor_Model.py index c457335..e79c412 100644 --- a/covid19_drdfm/streamlit/pages/0_Dynamic_Factor_Model.py +++ 
from pathlib import Path
from typing import Optional

import anndata as ann
import pandas as pd
import plotly.io as pio
import streamlit as st

from covid19_drdfm.dfm import ModelRunner

st.set_page_config(layout="wide")
pio.templates.default = "plotly_white"


def center_title(text):
    """Render *text* as a centered page title.

    Args:
        text: Title string to display.

    Returns:
        The Streamlit element produced by ``st.markdown``.
    """
    # NOTE(review): the original HTML payload of this f-string was lost in
    # extraction; reconstructed as a centered <h1>, matching the same helper
    # in Dashboard.py — confirm against the repository.
    # unsafe_allow_html is required because centering is done with raw HTML.
    return st.markdown(f"<h1 style='text-align: center;'>{text}</h1>", unsafe_allow_html=True)
class DataHandler:
    """Handles data loading and preprocessing for a Streamlit application.

    Wraps the upload -> factor-assignment -> AnnData -> transform pipeline used
    by the dynamic factor model page. Each step is a fluent method returning
    ``self`` so the steps can be chained (see :meth:`get_data`).
    """

    def __init__(self):
        # Raw uploaded table.
        self.df: Optional[pd.DataFrame] = None
        # AnnData built from ``df`` (populated by ``create_anndata``).
        self.ad: Optional[ann.AnnData] = None
        # Name of the batch column chosen by the user, or None.
        self.batch_col: Optional[str] = None
        # Columns that are modeled variables (everything except the batch column).
        self.non_batch_cols: Optional[list[str]] = None
        # Variable name -> factor label (populated by ``get_factor_mappings``).
        self.factor_mappings: Optional[dict[str, str]] = None

    def get_data(self) -> "DataHandler":
        """Run the full input pipeline in dependency order.

        BUG FIX: ``create_anndata`` must run *before* ``apply_transforms``,
        because the transform step writes into ``self.ad.var``. The original
        chained ``apply_transforms`` first, which raised ``AttributeError`` on
        the still-``None`` AnnData object.
        """
        self.file_uploader().get_factor_mappings().create_anndata()
        self.apply_transforms()
        return self

    def file_uploader(self) -> "DataHandler":
        """Upload a data file and read it into ``self.df``.

        Supported file types are CSV, TSV, and XLSX. Also asks for an
        optional batch column and records the remaining variable columns.

        Returns:
            self, for fluent chaining.

        Raises:
            Stops the Streamlit script if no file is uploaded or the
            DataFrame could not be read.
        """
        file = st.file_uploader("Upload a data file (CSV, TSV, XLSX)", type=["csv", "tsv", "xlsx"])
        if file is None:
            st.error("Please provide input file")
            st.stop()
        self.df = self.load_data(file)
        with st.expander("Raw Input Data"):
            st.dataframe(self.df)
        if self.df is None:
            st.error("DataFrame is empty! Check input data")
            st.stop()
        batch_col = st.sidebar.selectbox("Select a batch column (optional):", ["None", *list(self.df.columns)])
        # BUG FIX: the original assigned self.batch_col only in the "None"
        # branch, so a user-selected batch column was silently ignored and
        # create_anndata never moved it into ``.obs``.
        self.batch_col = None if batch_col == "None" else batch_col
        self.non_batch_cols = [col for col in self.df.columns if col != batch_col]
        return self

    @staticmethod
    def load_data(file) -> pd.DataFrame:
        """Load a DataFrame from an uploaded file.

        BUG FIX: dispatch is on the file *name* suffix rather than the MIME
        type. The XLSX MIME type
        ("application/vnd.openxmlformats-officedocument.spreadsheetml.sheet")
        does not end in "xlsx", so the original MIME-based lookup could never
        select the Excel reader. The original also defaulted the lookup to
        ``lambda _: None``, which made its ``read_function is None`` check —
        and therefore the ValueError — unreachable.

        Args:
            file: UploadedFile object from Streamlit (has a ``.name``).

        Returns:
            A DataFrame containing the data from the file.

        Raises:
            ValueError: If the file type is unsupported.
        """
        suffix = Path(file.name).suffix.lstrip(".").lower()
        readers = {
            "csv": lambda f: pd.read_csv(f, index_col=0),
            "tsv": lambda f: pd.read_csv(f, index_col=0, sep="\t"),
            "xlsx": lambda f: pd.read_excel(f, index_col=0),
        }
        if suffix not in readers:
            raise ValueError(f"Unsupported file type: {suffix}")
        return readers[suffix](file)

    def apply_transforms(self) -> "DataHandler":
        """Let the user tag columns with a transform type.

        Writes one column per selected transform ("difference" / "logdiff")
        into ``self.ad.var`` with ``True`` at the tagged variables. Requires
        ``create_anndata`` to have run first (``self.ad`` must exist).
        """
        options = st.multiselect(
            "Select columns to apply transformations:", self.non_batch_cols, format_func=lambda x: f"Transform {x}"
        )
        transforms = {}
        for i, opt in enumerate(options):
            # Lay the radio widgets out two per row.
            if i % 2 == 0:
                cols = st.columns(2)
            transform = cols[i % 2].radio(
                f"Select transform type for {opt}:", ("difference", "logdiff"), key=f"trans_{opt}"
            )
            transforms[opt] = transform
            self.ad.var[transform] = None
            self.ad.var.loc[opt, transform] = True
        return self

    def get_factor_mappings(self) -> "DataHandler":
        """Collect a factor label for every non-batch variable.

        The user first enters the set of factor names, then assigns one to
        each variable via radio buttons (two per row). Stops the script until
        every variable has a factor.
        """
        factor_input = st.text_input("Enter all factor options separated by space:")
        factor_options = factor_input.split()
        if not factor_options:
            st.warning("Enter at least one factor to assign to variables")
            st.stop()
        factor_mappings = {}
        for i, col in enumerate(self.non_batch_cols):
            if i % 2 == 0:
                cols = st.columns(2)
            col_factor = cols[i % 2].radio(
                f"Select factor for {col}:",
                options=factor_options,
                key=col,
                format_func=lambda x: f"{x}",
                horizontal=True,
            )
            if col_factor:
                factor_mappings[col] = col_factor

        if len(factor_mappings) != len(self.non_batch_cols):
            st.warning("Select a factor for each variable!")
            st.stop()
        self.factor_mappings = factor_mappings
        return self

    def create_anndata(self) -> ann.AnnData:
        """Create an AnnData object from the loaded DataFrame.

        If a batch column was selected it is removed from the matrix and
        stored in ``.obs``; factor labels go into ``.var["factor"]``.
        (Docstring fixed: the original documented ``factor_mappings`` /
        ``batch_col`` arguments that this method no longer takes — it reads
        them from ``self``.)

        Returns:
            The AnnData object, also stored on ``self.ad``.
        """
        if self.batch_col and self.batch_col in self.df.columns:
            ad = ann.AnnData(self.df.drop(columns=self.batch_col))
            ad.obs[self.batch_col] = self.df[self.batch_col]
        else:
            ad = ann.AnnData(self.df)

        ad.var["factor"] = [self.factor_mappings[x] for x in ad.var.index]
        self.ad = ad
        return ad


def additional_params():
    """Collect run parameters from the sidebar.

    Returns:
        Tuple of (global_multiplier, out_dir). Stops the script until an
        output directory is provided.
    """
    global_multiplier = st.sidebar.slider("Global Multiplier", min_value=0, max_value=4, value=0)
    out_dir = st.sidebar.text_input("Output Directory", value=None)
    if not out_dir:
        st.warning("Specify output directory (in sidebar) to continue")
        st.stop()
    return global_multiplier, out_dir


def run_model(ad, out_dir, batch, global_multiplier) -> ModelRunner:
    """Fit the dynamic factor model and render per-batch summaries.

    Args:
        ad: AnnData with variables, factors, and optional batch in ``.obs``.
        out_dir: Directory where the model writes its outputs.
        batch: Name of the batch column in ``ad.obs``, or None.
        global_multiplier: Strength of the global factor (0 disables it).

    Returns:
        The fitted ModelRunner.
    """
    dfm = ModelRunner(ad, Path(out_dir), batch=batch)
    dfm.run(global_multiplier=global_multiplier)
    st.subheader("Results")
    for result in dfm.results:
        if batch is not None:
            # Only label sub-sections when there is more than one batch run.
            st.subheader(result.name)
        st.write(result.result.summary())
        st.divider()
        st.write(result.model.summary())
    return dfm


center_title("Dynamic Factor Model Runner")
data = DataHandler().get_data()
ad = data.ad
global_multiplier, out_dir = additional_params()
batch = None if ad.obs.empty else ad.obs.columns[0]
dfm = run_model(ad, out_dir, batch, global_multiplier)
st.balloons()
st.stop()