Skip to content

Commit

Permalink
add missing site assets: single model energy parity svelte figs and J…
Browse files Browse the repository at this point in the history
…SON hull dist errors for GRACE-{1,2}L-OAM

- Generate new energy parity plots for GRACE-1L-OAM, GRACE-2L-MPTRJ, and GRACE-2L-OAM models
- Remove deprecated GRACE2L-R6 energy parity plots
- Modify single_model_parity_energy.py to support command-line model selection
  • Loading branch information
janosh committed Feb 7, 2025
1 parent 7a0f8c8 commit 012ccfe
Show file tree
Hide file tree
Showing 11 changed files with 17 additions and 7 deletions.
File renamed without changes.
14 changes: 10 additions & 4 deletions scripts/model_figs/single_model_parity_energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# %%
import itertools
import os
import sys
from typing import Literal, get_args

import pymatviz as pmv
Expand Down Expand Up @@ -34,14 +35,19 @@
df_preds = df_preds.query(MbdKey.uniq_proto)
df_metrics = df_metrics_uniq_protos

# Get list of models from command line args, fall back to all models if none specified
models_to_update = sys.argv[1:] if len(sys.argv) > 1 else df_metrics


# %% parity plot of actual vs predicted e_form_per_atom
parity_scatters_dir = f"{SITE_FIGS}/energy-parity"
os.makedirs(parity_scatters_dir, exist_ok=True)

for model_name, which_energy in itertools.product(df_metrics, (use_e_form, use_each)):
model_key = Model.from_label(model_name).key
img_name = f"{which_energy}-parity-{model_name.lower().replace(' ', '-')}"
for model_name, which_energy in itertools.product(
models_to_update, (use_e_form, use_each)
):
model = Model[model_name]
img_name = f"{which_energy}-parity-{model.key.lower().replace(' ', '-')}"
img_path = f"{parity_scatters_dir}/{img_name}.svelte"
if os.path.isfile(img_path) and not update_existing:
continue
Expand All @@ -59,7 +65,7 @@
raise ValueError(f"Unexpected {which_energy=}")

e_pred_col = f"{model_name} {e_true_col.label.replace('DFT ', '')}"
df_in = df_in.rename(columns={model_name: e_pred_col})
df_in = df_in.rename(columns={model.label: e_pred_col})

fig = pmv.density_scatter_plotly(
df=df_in.reset_index(drop=True),
Expand Down

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

This file was deleted.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion site/src/figs/energy-parity/each-parity-grace2l-r6.svelte

This file was deleted.

2 changes: 1 addition & 1 deletion site/src/figs/per-element-each-errors.json

Large diffs are not rendered by default.

0 comments on commit 012ccfe

Please sign in to comment.