diff --git a/dfmdash/streamlit/pages/4_Synthetic_Control_Model.py b/dfmdash/streamlit/pages/4_Synthetic_Control_Model.py index 191ce2f..ba5cbe1 100644 --- a/dfmdash/streamlit/pages/4_Synthetic_Control_Model.py +++ b/dfmdash/streamlit/pages/4_Synthetic_Control_Model.py @@ -69,6 +69,8 @@ def get_results(df, outcome_var, treatment_time, treated_unit): outcome_var = st.selectbox("Outcome Variable", df.columns) treatment_time = st.date_input("Treatment Time", value=min_time, min_value=min_time, max_value=max_time) + invert = st.checkbox("Invert Factor") + submit = st.form_submit_button() if not submit: @@ -109,6 +111,10 @@ def get_results(df, outcome_var, treatment_time, treated_unit): ys = synth[0, :] y = df[df.State == treated_unit][outcome_var] +if invert: + y *= -1 + ys *= -1 + st.subheader("Original") fig = go.Figure() fig.add_trace(go.Scatter(x=x, y=y, mode="lines", name=f"{treated_unit} {outcome_var}")) @@ -123,20 +129,31 @@ def get_results(df, outcome_var, treatment_time, treated_unit): normalized_treated_outcome = normalized_treated_outcome[:, 0] normalized_synth = np.zeros(data.periods_all) most_extreme_value = np.max(np.absolute(normalized_treated_outcome)) + +if invert: + normalized_treated_outcome *= -1 + normalized_synth *= -1 fig = go.Figure() -fig.add_trace(go.Scatter(x=x, y=normalized_treated_outcome, mode="lines", name=f"Synthetic {outcome_var}")) fig.add_trace( - go.Scatter(x=x, y=normalized_synth, mode="lines", name=f"{treated_unit} {outcome_var}", line=dict(dash="dot")) + go.Scatter( + x=x, + y=normalized_synth, + mode="lines", + name=f"{treated_unit} {outcome_var}", + ) +) +fig.add_trace( + go.Scatter(x=x, y=normalized_treated_outcome, mode="lines", name=f"Synthetic {outcome_var}", line=dict(dash="dot")) ) fig.add_vline(treatment_time, name=f"{outcome_var} Response") st.plotly_chart(fig, use_container_width=True) -st.subheader("Cumulative") -cumulative_effect = np.cumsum(normalized_treated_outcome[data.periods_pre_treatment :]) -cummulative_treated_outcome = np.concatenate((np.zeros(data.periods_pre_treatment), cumulative_effect), axis=None) +# st.subheader("Cumulative") +# cumulative_effect = np.cumsum(normalized_treated_outcome[data.periods_pre_treatment :]) +# cummulative_treated_outcome = np.concatenate((np.zeros(data.periods_pre_treatment), cumulative_effect), axis=None) -fig = go.Figure() -fig.add_trace(go.Scatter(x=x, y=normalized_synth, mode="lines", name=f"Synthetic {outcome_var}", line=dict(dash="dot"))) -fig.add_trace(go.Scatter(x=x, y=cummulative_treated_outcome, mode="lines", name=f"{treated_unit} {outcome_var}")) -fig.add_vline(treatment_time, name=f"{outcome_var} Response") -st.plotly_chart(fig, use_container_width=True) +# fig = go.Figure() +# fig.add_trace(go.Scatter(x=x, y=normalized_synth, mode="lines", name=f"Synthetic {outcome_var}", line=dict(dash="dot"))) +# fig.add_trace(go.Scatter(x=x, y=cummulative_treated_outcome, mode="lines", name=f"{treated_unit} {outcome_var}")) +# fig.add_vline(treatment_time, name=f"{outcome_var} Response") +# st.plotly_chart(fig, use_container_width=True)