Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dynamically generate Dashboard (resolves #78) #85

Merged
merged 2 commits into from
Jun 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions covid19_drdfm/data/processed/Test_Data.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
Time,Pandemic_10,Pandemic_9,Pandemic_1,Pandemic_7,Factor
1/1/2020,0,0,0,0,0.2045454545
2/1/2020,0,0,0,0,0
3/1/2020,0.0058740602,0.0397350993,0.0042229655,0.0609357998,0
4/1/2020,0.0693139098,0.0198675497,0.0024131231,0.0282916213,0.0619834711
5/1/2020,0.1379229323,0.0264900662,0.0025728151,0.0304678999,0.0289256198
6/1/2020,0.1734022556,0.0264900662,0.0108767976,0.0631120783,0.0309917355
7/1/2020,0.4814379699,0.1059602649,0.040277864,0.1828073993,0.0640495868
8/1/2020,0.313674812,0.1258278146,0.0370130504,0.204570185,0.1838842975
9/1/2020,0.1459116541,0.1125827815,0.0532484009,0.1566920566,0.2045454545
10/1/2020,0.0773026316,0.3112582781,0.1582192571,0.4243743199,0.1570247934
11/1/2020,0.155075188,0.5894039735,0.3026162868,0.752992383,0.423553719
12/1/2020,0.5227913534,0.5033112583,0.1744191205,0.5571273123,0.7541322314
1/1/2021,1,0.178807947,0.0982904087,0.2176278564,0.5578512397
2/1/2021,0.6719924812,0.0993377483,0.0558034724,0.1523394995,0.2190082645
3/1/2021,0.2319078947,0.1125827815,0.072748565,0.2241566921,0.152892562
4/1/2021,0.0838815789,0.0066225166,0.0778320928,0.289445049,0.2252066116
5/1/2021,0.0714285714,0.1390728477,0.0299511165,0.1284004353,0.2892561983
6/1/2021,0.0723684211,0.0529801325,0.0149666865,0.0957562568,0.1301652893
7/1/2021,0.0728383459,0.0794701987,0.0817711614,0.39390642,0.0950413223
8/1/2021,0.1268796992,0.3245033113,0.26157545,0.8019586507,0.3925619835
9/1/2021,0.2814849624,0.8278145695,0.4060079669,1,0.8037190083
10/1/2021,0.274906015,0.9470198675,0.3689594294,0.8335146899,1
11/1/2021,0.2709116541,1,0.1806914662,0.4494015234,0.8347107438
12/1/2021,0.4515977444,0.1059602649,0.1362261238,0.2655059848,0.4504132231
1/1/2022,0.4642857143,0.2847682119,1,0.478781284,0.2665289256
2/1/2022,0.4090695489,0.5165562914,0.2854316563,0.2491838955,0.479338843
3/1/2022,0.3106203008,0.4304635762,0.1077388504,0.1284004353,0.25
4/1/2022,0.1604793233,0.1589403974,0.0940142126,0.0718171926,0.1280991736
5/1/2022,0.0817669173,0.2185430464,0.1257574279,0.1099020675,0.0723140496
6/1/2022,0.0507518797,0,0.1572433617,0.1077257889,0.1095041322
7/1/2022,0.0594454887,0.1523178808,0.1569505931,0.1088139282,0.1074380165
8/1/2022,0.0812969925,0.1920529801,0.0951143128,0.112078346,0.1095041322
9/1/2022,0.0601503759,0.1655629139,0.0211769298,0.0402611534,0.1115702479
10/1/2022,0,0,0.0135738176,0.0293797606,0.041322314
11/1/2022,0,0,0.0134850999,0.0054406964,0.0289256198
12/1/2022,0,0,0.0121365899,0.0087051143,0.0061983471
4 changes: 2 additions & 2 deletions covid19_drdfm/streamlit/Dashboard.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import yaml
import time
from pathlib import Path

import pandas as pd
import plotly.io as pio
import streamlit as st
import yaml

from covid19_drdfm.constants import FACTORS
from covid19_drdfm.covid19 import get_df, get_project_h5ad
Expand Down Expand Up @@ -32,7 +32,7 @@ def get_data():
var_df["Variables"] = var_df.index
ad.obs["Time"] = pd.to_datetime(ad.obs.index)

center_title("Dynamic Factor Model Runner")
center_title("Legacy Dynamic Factor Model Runner for Covid-19")

with st.expander("Variable correlations"):
st.write("Data is normalized between [0, 1] before calculating correlation")
Expand Down
179 changes: 179 additions & 0 deletions covid19_drdfm/streamlit/pages/0_Dynamic_Factor_Model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
from pathlib import Path
from typing import Optional

import anndata as ann
import pandas as pd
import plotly.io as pio
import streamlit as st

from covid19_drdfm.dfm import ModelRunner

# Configure the Streamlit page to use the wide layout.
st.set_page_config(layout="wide")
# Make "plotly_white" the default Plotly template for this page.
pio.templates.default = "plotly_white"


def center_title(text):
    """Render ``text`` as a centred, grey ``<h1>`` heading.

    The heading is emitted as raw HTML through ``st.markdown`` with
    ``unsafe_allow_html=True``; ``text`` is interpolated verbatim.

    Returns:
        Whatever ``st.markdown`` returns.
    """
    heading_html = f"<h1 style='text-align: center; color: grey;'>{text}</h1>"
    return st.markdown(heading_html, unsafe_allow_html=True)


class DataHandler:
    """
    Handles data loading and preprocessing for a Streamlit application.

    Typical use is ``DataHandler().get_data()``, which renders the upload,
    factor-assignment and transform widgets and leaves the resulting
    ``AnnData`` object in ``self.ad``.
    """

    def __init__(self):
        # Raw table read from the uploaded file.
        self.df: Optional[pd.DataFrame] = None
        # AnnData built from ``df`` by ``create_anndata``.
        self.ad: Optional[ann.AnnData] = None
        # Name of the user-selected batch column, or None when not batching.
        self.batch_col: Optional[str] = None
        # Every column of ``df`` except the batch column.
        self.non_batch_cols: Optional[list[str]] = None
        # Variable name -> factor name, filled by ``get_factor_mappings``.
        self.factor_mappings: dict[str, str] = {}

    def get_data(self) -> "DataHandler":
        """
        Runs the full pipeline: upload, factor mapping, AnnData creation, transforms.

        Returns:
            This handler, with ``df``, ``factor_mappings`` and ``ad`` populated.
        """
        self.file_uploader().get_factor_mappings()
        # The AnnData object must exist before transforms are recorded on
        # ``ad.var``; ``create_anndata`` returns the AnnData (not ``self``),
        # so it is called as a statement rather than chained.
        self.create_anndata()
        self.apply_transforms()
        return self

    def file_uploader(self) -> "DataHandler":
        """
        Uploads a file and reads it into ``self.df``. Supported file types are CSV, TSV, and XLSX.

        Also renders the optional batch-column selector in the sidebar and
        records the choice in ``self.batch_col`` / ``self.non_batch_cols``.
        The script is halted via ``st.stop()`` when no file is provided, the
        file type is unsupported, or the loaded table is empty.

        Returns:
            This handler, for chaining.
        """
        file = st.file_uploader("Upload a data file (CSV, TSV, XLSX)", type=["csv", "tsv", "xlsx"])
        if file is None:
            st.error("Please provide input file")
            st.stop()
        try:
            self.df = self.load_data(file)
        except ValueError as err:
            st.error(str(err))
            st.stop()
        # Validate before rendering so an unusable table is never displayed.
        if self.df is None or self.df.empty:
            st.error("DataFrame is empty! Check input data")
            st.stop()
        with st.expander("Raw Input Data"):
            st.dataframe(self.df)
        choice = st.sidebar.selectbox("Select a batch column (optional):", ["None", *list(self.df.columns)])
        # "None" is the sentinel entry meaning "no batch column".
        self.batch_col = None if choice == "None" else choice
        self.non_batch_cols = [col for col in self.df.columns if col != self.batch_col]
        return self

    @staticmethod
    def load_data(file) -> pd.DataFrame:
        """
        Loads a DataFrame from an uploaded file.

        The reader is chosen from the file-name extension when it is one of
        the supported types; otherwise the MIME subtype of ``file.type`` is
        used. The first column is always used as the index.

        Args:
            file: UploadedFile object from Streamlit (any file-like object
                exposing ``name`` and/or ``type`` works).

        Returns:
            A DataFrame containing the data from the file.

        Raises:
            ValueError: If the file type is unsupported.
        """
        readers = {
            "csv": lambda f: pd.read_csv(f, index_col=0),
            "tsv": lambda f: pd.read_csv(f, index_col=0, sep="\t"),
            "xlsx": lambda f: pd.read_excel(f, index_col=0),
        }
        # MIME subtypes that do not literally match the reader keys above.
        mime_aliases = {
            "tab-separated-values": "tsv",
            "vnd.openxmlformats-officedocument.spreadsheetml.sheet": "xlsx",
        }
        suffix = Path(getattr(file, "name", "") or "").suffix.lower().lstrip(".")
        subtype = (getattr(file, "type", "") or "").split("/")[-1].lower()
        file_type = suffix if suffix in readers else mime_aliases.get(subtype, subtype)

        read_function = readers.get(file_type)
        if read_function is None:
            raise ValueError(f"Unsupported file type: {file_type}")

        return read_function(file)

    def apply_transforms(self) -> "DataHandler":
        """
        Lets the user pick variables to transform and flags them on ``self.ad.var``.

        For each selected variable a radio button chooses ``"difference"`` or
        ``"logdiff"``; the matching column of ``ad.var`` is set to ``True``
        for that variable (the column is created with ``None`` when absent).

        Returns:
            This handler, for chaining.

        Raises:
            RuntimeError: If a transform is selected before ``create_anndata`` ran.
        """
        options = st.multiselect(
            "Select columns to apply transformations:", self.non_batch_cols, format_func=lambda x: f"Transform {x}"
        )
        if not options:
            return self
        if self.ad is None:
            raise RuntimeError("create_anndata() must be called before apply_transforms()")
        for i, opt in enumerate(options):
            # Lay the radio groups out two per row.
            if i % 2 == 0:
                cols = st.columns(2)
            transform = cols[i % 2].radio(
                f"Select transform type for {opt}:", ("difference", "logdiff"), key=f"trans_{opt}"
            )
            # Create the flag column only once so earlier flags are not wiped.
            if transform not in self.ad.var.columns:
                self.ad.var[transform] = None
            self.ad.var.loc[opt, transform] = True
        return self

    def get_factor_mappings(self) -> "DataHandler":
        """
        Asks the user for factor names and assigns one factor to every non-batch column.

        The script is halted via ``st.stop()`` until at least one factor is
        entered and every variable has a factor selected.

        Returns:
            This handler, with ``self.factor_mappings`` populated.
        """
        factor_input = st.text_input("Enter all factor options separated by space:")
        factor_options = factor_input.split()
        if not factor_options:
            st.warning("Enter at least one factor to assign to variables")
            st.stop()
        factor_mappings = {}
        for i, col in enumerate(self.non_batch_cols):
            # Lay the radio groups out two per row.
            if i % 2 == 0:
                cols = st.columns(2)
            col_factor = cols[i % 2].radio(
                f"Select factor for {col}:",
                options=factor_options,
                key=col,
                format_func=lambda x: f"{x}",
                horizontal=True,
            )
            if col_factor:
                factor_mappings[col] = col_factor

        if len(factor_mappings) != len(self.non_batch_cols):
            st.warning("Select a factor for each variable!")
            st.stop()
        self.factor_mappings = factor_mappings
        return self

    def create_anndata(self) -> "ann.AnnData":
        """
        Creates an AnnData object from the loaded DataFrame with optional batch column handling.

        When ``self.batch_col`` names a column of ``self.df``, that column is
        dropped from the data matrix and stored in ``ad.obs``. The factor of
        each variable (from ``self.factor_mappings``) is written to
        ``ad.var["factor"]``. The result is also stored in ``self.ad``.

        Returns:
            An AnnData object with additional metadata.

        Raises:
            RuntimeError: If no DataFrame has been loaded yet.
        """
        if self.df is None:
            raise RuntimeError("file_uploader() must be called before create_anndata()")
        if self.batch_col and self.batch_col in self.df.columns:
            ad = ann.AnnData(self.df.drop(columns=self.batch_col))
            ad.obs[self.batch_col] = self.df[self.batch_col]
        else:
            ad = ann.AnnData(self.df)

        ad.var["factor"] = [self.factor_mappings[x] for x in ad.var.index]
        self.ad = ad
        return ad


def additional_params():
    """Collect the run parameters from the sidebar.

    Renders a "Global Multiplier" slider (0-4, default 0) and an
    "Output Directory" text box. While the output directory is blank, a
    warning is shown and ``st.stop()`` is called.

    Returns:
        A ``(global_multiplier, out_dir)`` tuple.
    """
    sidebar = st.sidebar
    multiplier = sidebar.slider("Global Multiplier", min_value=0, max_value=4, value=0)
    output_directory = sidebar.text_input("Output Directory", value=None)
    directory_missing = not output_directory
    if directory_missing:
        st.warning("Specify output directory (in sidebar) to continue")
        st.stop()
    return multiplier, output_directory


def run_model(ad, out_dir, batch, global_multiplier) -> ModelRunner:
    """Run the model and write its summaries to the page.

    Args:
        ad: Data object handed to ``ModelRunner``.
        out_dir: Output directory; wrapped in ``Path`` before use.
        batch: Batch column name passed to ``ModelRunner``, or ``None``.
        global_multiplier: Forwarded to ``ModelRunner.run``.

    Returns:
        The ``ModelRunner`` instance after ``run`` has been called.
    """
    runner = ModelRunner(ad, Path(out_dir), batch=batch)
    runner.run(global_multiplier=global_multiplier)

    st.subheader("Results")
    # Per-result names are only shown when a batch column is in use.
    show_names = batch is not None
    for entry in runner.results:
        if show_names:
            st.subheader(entry.name)
        st.write(entry.result.summary())
        st.divider()
        st.write(entry.model.summary())
    return runner


# Page script: collect the inputs, run the model, then celebrate.
center_title("Dynamic Factor Model Runner")

data = DataHandler().get_data()
ad = data.ad

global_multiplier, out_dir = additional_params()

# The first column of ``ad.obs`` is used as the batch key; no batching when
# ``ad.obs`` is empty.
batch = ad.obs.columns[0] if not ad.obs.empty else None

dfm = run_model(ad, out_dir, batch, global_multiplier)

st.balloons()
st.stop()
Loading