Skip to content

Commit 99ae6ef

Browse files
M-R-Schaeferpre-commit-ci[bot]PythonFZ
authored
Change PredictionMetrics from pred vs true to error vs true (#305)
* added optional ylim parameter to predictionMetrics and plots * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update .pre-commit-config.yaml * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Fabian Zills <[email protected]>
1 parent 0592e8e commit 99ae6ef

File tree

3 files changed

+64
-18
lines changed

3 files changed

+64
-18
lines changed

Diff for: .pre-commit-config.yaml

+3-2
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,11 @@ repos:
3333
- id: codespell
3434
additional_dependencies: ["tomli"]
3535
- repo: https://github.com/astral-sh/ruff-pre-commit
36-
# Ruff version.
37-
rev: 'v0.1.3'
36+
rev: v0.4.9
3837
hooks:
3938
- id: ruff
39+
args: [--fix]
40+
- id: ruff-format
4041
- repo: https://github.com/executablebooks/mdformat
4142
rev: 0.7.17
4243
hooks:

Diff for: ipsuite/analysis/model/plots.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,13 @@ def density_scatter(ax, x, y, bins, **kwargs) -> None:
4646

4747

4848
def get_figure(
49-
true, prediction, datalabel: str, xlabel: str, ylabel: str, figsize: tuple = (10, 7)
49+
true,
50+
prediction,
51+
datalabel: str,
52+
xlabel: str,
53+
ylabel: str,
54+
ymax: typing.Optional[float] = None,
55+
figsize: tuple = (10, 7),
5056
) -> plt.Figure:
5157
"""Create a correlation plot for true, prediction values.
5258
@@ -66,7 +72,7 @@ def get_figure(
6672
"""
6773
sns.set()
6874
fig, ax = plt.subplots(figsize=figsize)
69-
ax.plot(true, true, color="grey", zorder=0) # plot the diagonal in the background
75+
ax.plot(true, np.zeros_like(true), color="grey", zorder=0)
7076
bins = 25
7177
if true.shape[0] < 20:
7278
# don't use density for very small datasets
@@ -77,6 +83,8 @@ def get_figure(
7783
)
7884
ax.set_xlabel(xlabel)
7985
ax.set_ylabel(ylabel)
86+
if ymax:
87+
ax.set_ylim([-ymax, ymax])
8088
ax.legend()
8189
return fig
8290

Diff for: ipsuite/analysis/model/predict.py

+51-14
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,18 @@ class PredictionMetrics(base.ComparePredictions):
5656
- energy: meV/atom
5757
- forces: meV/Å
5858
- stress: eV/Å^3
59+
60+
Attributes
61+
----------
62+
ymax: dict of label key, and figure ylim values.
63+
Should be set when trying to compare different models.
64+
5965
"""
6066

67+
# TODO ADD OPTIONAL YMAX PARAMETER
68+
69+
figure_ymax: dict[str, float] = zntrack.params({})
70+
6171
data_file = zntrack.outs_path(zntrack.nwd / "data.npz")
6272

6373
energy: dict = zntrack.metrics()
@@ -91,6 +101,7 @@ def get_data(self):
91101
energy_prediction = [x.get_potential_energy() / len(x) for x in self.y]
92102
energy_prediction = np.array(energy_prediction) * 1000
93103
self.content["energy_pred"] = energy_prediction
104+
self.content["energy_error"] = energy_true - energy_prediction
94105

95106
if "forces" in true_keys and "forces" in pred_keys:
96107
true_forces = [x.get_forces() for x in self.x]
@@ -100,6 +111,9 @@ def get_data(self):
100111
pred_forces = [x.get_forces() for x in self.y]
101112
pred_forces = np.concatenate(pred_forces, axis=0) * 1000
102113
self.content["forces_pred"] = np.reshape(pred_forces, (-1,))
114+
self.content["forces_error"] = (
115+
self.content["forces_true"] - self.content["forces_pred"]
116+
)
103117

104118
if "stress" in true_keys and "stress" in pred_keys:
105119
true_stress = np.array([x.get_stress(voigt=False) for x in self.x])
@@ -109,10 +123,19 @@ def get_data(self):
109123

110124
self.content["stress_true"] = np.reshape(true_stress, (-1,))
111125
self.content["stress_pred"] = np.reshape(pred_stress, (-1,))
126+
self.content["stress_error"] = (
127+
self.content["stress_true"] - self.content["stress_pred"]
128+
)
112129
self.content["stress_hydro_true"] = np.reshape(hydro_true, (-1,))
113130
self.content["stress_hydro_pred"] = np.reshape(hydro_pred, (-1,))
131+
self.content["stress_hydro_error"] = (
132+
self.content["stress_hydro_true"] - self.content["stress_hydro_pred"]
133+
)
114134
self.content["stress_deviat_true"] = np.reshape(deviat_true, (-1,))
115135
self.content["stress_deviat_pred"] = np.reshape(deviat_pred, (-1,))
136+
self.content["stress_deviat_error"] = (
137+
self.content["stress_deviat_true"] - self.content["stress_deviat_pred"]
138+
)
116139

117140
def get_metrics(self):
118141
"""Update the metrics."""
@@ -146,57 +169,71 @@ def get_plots(self, save=False):
146169
"""Create figures for all available data."""
147170
self.plots_dir.mkdir(exist_ok=True)
148171

172+
e_ymax = self.figure_ymax.get("energy", None)
149173
energy_plot = get_figure(
150174
self.content["energy_true"],
151-
self.content["energy_pred"],
175+
self.content["energy_error"],
152176
datalabel=f"MAE: {self.energy['mae']:.2f} meV/atom",
153177
xlabel=r"$ab~initio$ energy $E$ / meV/atom",
154-
ylabel=r"predicted energy $E$ / meV/atom",
178+
ylabel=r"$\Delta E$ / meV/atom",
179+
ymax=e_ymax,
155180
)
156181
if save:
157182
energy_plot.savefig(self.plots_dir / "energy.png")
158183

159184
if "forces_true" in self.content:
160-
xlabel = r"$ab~initio$ force components per atom $|F|$ / meV$ \cdot \AA^{-1}$"
161-
ylabel = r"predicted force components per atom $|F|$ / meV$ \cdot \AA^{-1}$"
185+
xlabel = (
186+
r"$ab~initio$ force components per atom $F_{alpha,i}$ / meV$ \cdot"
187+
r" \AA^{-1}$"
188+
)
189+
ylabel = r"$\Delta F_{alpha,i}$ / meV$ \cdot \AA^{-1}$"
190+
f_ymax = self.figure_ymax.get("forces", None)
162191
forces_plot = get_figure(
163192
self.content["forces_true"],
164-
self.content["forces_pred"],
193+
self.content["forces_error"],
165194
datalabel=rf"MAE: {self.forces['mae']:.2f} meV$ / \AA$",
166195
xlabel=xlabel,
167196
ylabel=ylabel,
197+
ymax=f_ymax,
168198
)
169199
if save:
170200
forces_plot.savefig(self.plots_dir / "forces.png")
171201

172202
if "stress_true" in self.content:
173203
s_true = self.content["stress_true"]
174-
s_pred = self.content["stress_pred"]
204+
s_error = self.content["stress_error"]
175205
shydro_true = self.content["stress_hydro_true"]
176-
shydro_pred = self.content["stress_hydro_pred"]
206+
shydro_error = self.content["stress_hydro_error"]
177207
sdeviat_true = self.content["stress_deviat_true"]
178-
sdeviat_pred = self.content["stress_deviat_pred"]
208+
sdeviat_error = self.content["stress_deviat_error"]
209+
210+
s_ymax = self.figure_ymax.get("stress", None)
211+
hs_ymax = self.figure_ymax.get("stress_hydro", None)
212+
ds_ymax = self.figure_ymax.get("stress_deviat", None)
179213

180214
stress_plot = get_figure(
181215
s_true,
182-
s_pred,
216+
s_error,
183217
datalabel=rf"Max: {self.stress['max']:.4f}",
184218
xlabel=r"$ab~initio$ stress",
185-
ylabel=r"predicted stress",
219+
ylabel=r"$\Delta$ stress",
220+
ymax=s_ymax,
186221
)
187222
hydrostatic_stress_plot = get_figure(
188223
shydro_true,
189-
shydro_pred,
224+
shydro_error,
190225
datalabel=rf"Max: {self.stress_hydro['max']:.4f}",
191226
xlabel=r"$ab~initio$ hydrostatic stress",
192-
ylabel=r"predicted hydrostatic stress",
227+
ylabel=r"$\Delta$ hydrostatic stress",
228+
ymax=hs_ymax,
193229
)
194230
deviatoric_stress_plot = get_figure(
195231
sdeviat_true,
196-
sdeviat_pred,
232+
sdeviat_error,
197233
datalabel=rf"Max: {self.stress_deviat['max']:.4f}",
198234
xlabel=r"$ab~initio$ deviatoric stress",
199-
ylabel=r"predicted deviatoric stress",
235+
ylabel=r"$\Delta$ deviatoric stress",
236+
ymax=ds_ymax,
200237
)
201238
if save:
202239
stress_plot.savefig(self.plots_dir / "stress.png")

0 commit comments

Comments
 (0)