Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add correlation matrices to data explorer Fixes #25 #30

Merged
merged 4 commits into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion coverage.xml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
<?xml version="1.0" ?>
<coverage version="7.3.1" timestamp="1708489476617" lines-valid="180" lines-covered="140" line-rate="0.7778" branches-valid="64" branches-covered="47" branch-rate="0.7344" complexity="0">
<coverage version="7.3.1" timestamp="1708493672279" lines-valid="182" lines-covered="138" line-rate="0.7582" branches-valid="66" branches-covered="47" branch-rate="0.7121" complexity="0">
<!-- Generated by coverage.py: https://coverage.readthedocs.io/en/7.3.1 -->
<!-- Based on https://raw.githubusercontent.com/cobertura/web/master/htdocs/xml/coverage-04.dtd -->
<sources>
Expand Down
25 changes: 6 additions & 19 deletions covid19_drdfm/streamlit/Dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from covid19_drdfm.constants import FACTORS
from covid19_drdfm.dfm import run_parameterized_model
from covid19_drdfm.processing import get_df
from covid19_drdfm.streamlit.plots import plot_correlations

st.set_page_config(layout="wide")
pio.templates.default = "plotly_white"
Expand All @@ -28,30 +29,16 @@ def get_data():


df = get_df()
sub = pd.Series([x for x in df.columns if x not in ["State", "Time"]], name="Variables").to_frame()
var_df = pd.Series([x for x in df.columns if x not in ["State", "Time"]], name="Variables").to_frame()
factors = FACTORS.copy()
factor_vars = list(factors.keys())
sub["Group"] = [factors[x][1] for x in sub.Variables if x in df.columns]
var_df["Group"] = [factors[x][1] for x in var_df.Variables if x in df.columns]

center_title("Dynamic Factor Model Runner")

with st.expander("Variable correlations"):
st.write("Data is normalized between [0, 1] before calculating correlation")
c1, c2 = st.columns(2)
c3, c4 = st.columns(2)
c5, c6 = st.columns(2)
for group, stcol in zip(sub.Group.unique(), [c1, c2, c3, c4, c5, c6]):
cols = sub[sub.Group == group].Variables
corr = px.imshow(
pd.DataFrame(MinMaxScaler().fit_transform(df[cols]), columns=cols).corr(),
zmin=-1,
zmax=1,
color_continuous_scale="rdbu_r",
color_continuous_midpoint=0,
)
stcol.subheader(group)
stcol.plotly_chart(corr)

plot_correlations(df, normalize=True)

with st.form("DFM Model Runner"):
st.markdown(
Expand All @@ -70,8 +57,8 @@ def get_data():
c3, c4 = st.columns(2)
c5, c6 = st.columns(2)
selectors = {}
for group, stcol in zip(sub.Group.unique(), [c1, c2, c3, c4, c5, c6]):
variables = sub[sub.Group == group].Variables
for group, stcol in zip(var_df.Group.unique(), [c1, c2, c3, c4, c5, c6]):
variables = var_df[var_df.Group == group].Variables
selectors[group] = stcol.multiselect(group, variables, variables)

# State selections
Expand Down
43 changes: 14 additions & 29 deletions covid19_drdfm/streamlit/pages/0_Data_Explorer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
# Raw
# Post-processed
# Normalized

from functools import reduce
from pathlib import Path

Expand All @@ -10,16 +6,15 @@
import plotly_express as px
import streamlit as st

from covid19_drdfm.constants import DIFF_COLS, LOG_DIFF_COLS, FACTORS_GROUPED
from covid19_drdfm.constants import DIFF_COLS, FACTORS_GROUPED, LOG_DIFF_COLS
from covid19_drdfm.processing import (
add_datetime,
adjust_inflation,
adjust_pandemic_response,
diff_vars,
fix_names,
normalize,
)
from covid19_drdfm.dfm import state_process
from covid19_drdfm.streamlit.plots import plot_correlations

st.set_page_config(layout="wide")
pio.templates.default = "plotly_white"
Expand All @@ -45,7 +40,7 @@ def raw_data():


@st.cache_data
def processed_data(raw_data: pd.DataFrame, state: str) -> pd.DataFrame:
def process_data(raw_data: pd.DataFrame, state: str) -> pd.DataFrame:
return (
raw_data[raw_data.State == state]
.pipe(diff_vars, cols=DIFF_COLS)
Expand All @@ -62,36 +57,26 @@ def processed_data(raw_data: pd.DataFrame, state: str) -> pd.DataFrame:
selections = ["Raw", "Processed", "Normalized"]
selection = st.sidebar.selectbox("Data Processing", selections)

proc = processed_data(raw, state)

# Filter DataFrame based on user inputs
# df_filtered = raw[(raw["State"] == state) & (raw["Time"].between(time_range[0], time_range[1]))]

# Show normalized if so
# Specify dataframe based on user choice
proc = process_data(raw, state)
df = proc if selection == "Processed" else raw
df = normalize(proc).fillna(0) if selection == "Normalized" else df[df["State"] == state]

with st.expander("Raw Data"):
st.dataframe(raw[raw.State == state])

with st.expander("Processed Data"):
st.dataframe(proc)

with st.expander("Normalized"):
norm = normalize(proc).fillna(0)
st.dataframe(norm)

df = norm if selection == "Normalized" else df[df["State"] == state]
with st.expander(f"{selection} Dataframe"):
st.dataframe(df)

# Tidy data
variables = FACTORS_GROUPED[factor] + ["Time"]

melt = df[variables].melt(
id_vars=["Time"],
var_name="Variable",
value_name="Value",
)

# Create Plotly figure
fig = px.line(melt, x="Time", y="Value", color="Variable") # replace y with your variable column

# Display Plotly figure in Streamlit
fig = px.line(melt, x="Time", y="Value", color="Variable", title=f"{selection} Data of {factor} Factor Variables")
st.plotly_chart(fig, use_container_width=True)

# Display correlations for state
st.warning(f"Correlations are calculated using {selection} dataframe")
plot_correlations(df)
36 changes: 36 additions & 0 deletions covid19_drdfm/streamlit/plots.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import pandas as pd
import plotly.express as px
import streamlit as st
from sklearn.preprocessing import MinMaxScaler

from covid19_drdfm.constants import FACTORS


def plot_correlations(df: pd.DataFrame, normalize: bool = False) -> None:
    """
    Render one correlation heatmap per factor group in a two-column Streamlit grid.

    Every column of ``df`` other than "State" and "Time" is treated as a
    variable. Variables are bucketed by the group label stored at index 1 of
    their ``FACTORS`` entry, and each bucket gets its own heatmap.

    Parameters:
    - df (pd.DataFrame): The DataFrame containing the variables. "State" and
      "Time" columns are ignored when present.
    - normalize (bool): When True, each group's variables are min-max scaled
      with ``MinMaxScaler`` before correlations are computed. Defaults to False.

    Returns:
    None

    Raises:
    KeyError: If a variable column has no entry in ``FACTORS``.
    """
    variables = [col for col in df.columns if col not in ("State", "Time")]

    # Bucket variables by group, keeping groups in first-seen order so the
    # heatmaps appear in the same order as the DataFrame's columns.
    grouped = {}
    for variable in variables:
        grouped.setdefault(FACTORS[variable][1], []).append(variable)

    groups = list(grouped.items())
    # Lay the heatmaps out two per row, creating rows on demand so that no
    # group is dropped when there are more than six of them.
    for start in range(0, len(groups), 2):
        row = st.columns(2)
        for stcol, (group, cols) in zip(row, groups[start : start + 2]):
            if normalize:
                data = pd.DataFrame(MinMaxScaler().fit_transform(df[cols]), columns=cols)
            else:
                data = df[cols]
            # NOTE: missing values are replaced with 0 before correlating; this
            # keeps every row in the calculation but treats gaps as real zeros.
            corr = px.imshow(
                data.fillna(0).corr(),
                zmin=-1,
                zmax=1,
                color_continuous_scale="rdbu_r",
                color_continuous_midpoint=0,
            )
            stcol.subheader(group)
            stcol.plotly_chart(corr)
Loading