@@ -56,8 +56,18 @@ class PredictionMetrics(base.ComparePredictions):
56
56
- energy: meV/atom
57
57
- forces: meV/Å
58
58
- stress: eV/Å^3
59
+
60
+ Attributes
61
+ ----------
62
+ ymax: dict of label key, and figure ylim values.
63
+ Should be set when trying to compare different models.
64
+
59
65
"""
60
66
67
+ # TODO ADD OPTIONAL YMAX PARAMETER
68
+
69
+ figure_ymax : dict [str , float ] = zntrack .params ({})
70
+
61
71
data_file = zntrack .outs_path (zntrack .nwd / "data.npz" )
62
72
63
73
energy : dict = zntrack .metrics ()
@@ -91,6 +101,7 @@ def get_data(self):
91
101
energy_prediction = [x .get_potential_energy () / len (x ) for x in self .y ]
92
102
energy_prediction = np .array (energy_prediction ) * 1000
93
103
self .content ["energy_pred" ] = energy_prediction
104
+ self .content ["energy_error" ] = energy_true - energy_prediction
94
105
95
106
if "forces" in true_keys and "forces" in pred_keys :
96
107
true_forces = [x .get_forces () for x in self .x ]
@@ -100,6 +111,9 @@ def get_data(self):
100
111
pred_forces = [x .get_forces () for x in self .y ]
101
112
pred_forces = np .concatenate (pred_forces , axis = 0 ) * 1000
102
113
self .content ["forces_pred" ] = np .reshape (pred_forces , (- 1 ,))
114
+ self .content ["forces_error" ] = (
115
+ self .content ["forces_true" ] - self .content ["forces_pred" ]
116
+ )
103
117
104
118
if "stress" in true_keys and "stress" in pred_keys :
105
119
true_stress = np .array ([x .get_stress (voigt = False ) for x in self .x ])
@@ -109,10 +123,19 @@ def get_data(self):
109
123
110
124
self .content ["stress_true" ] = np .reshape (true_stress , (- 1 ,))
111
125
self .content ["stress_pred" ] = np .reshape (pred_stress , (- 1 ,))
126
+ self .content ["stress_error" ] = (
127
+ self .content ["stress_true" ] - self .content ["stress_pred" ]
128
+ )
112
129
self .content ["stress_hydro_true" ] = np .reshape (hydro_true , (- 1 ,))
113
130
self .content ["stress_hydro_pred" ] = np .reshape (hydro_pred , (- 1 ,))
131
+ self .content ["stress_hydro_error" ] = (
132
+ self .content ["stress_hydro_true" ] - self .content ["stress_hydro_pred" ]
133
+ )
114
134
self .content ["stress_deviat_true" ] = np .reshape (deviat_true , (- 1 ,))
115
135
self .content ["stress_deviat_pred" ] = np .reshape (deviat_pred , (- 1 ,))
136
+ self .content ["stress_deviat_error" ] = (
137
+ self .content ["stress_deviat_true" ] - self .content ["stress_deviat_pred" ]
138
+ )
116
139
117
140
def get_metrics (self ):
118
141
"""Update the metrics."""
@@ -146,57 +169,71 @@ def get_plots(self, save=False):
146
169
"""Create figures for all available data."""
147
170
self .plots_dir .mkdir (exist_ok = True )
148
171
172
+ e_ymax = self .figure_ymax .get ("energy" , None )
149
173
energy_plot = get_figure (
150
174
self .content ["energy_true" ],
151
- self .content ["energy_pred " ],
175
+ self .content ["energy_error " ],
152
176
datalabel = f"MAE: { self .energy ['mae' ]:.2f} meV/atom" ,
153
177
xlabel = r"$ab~initio$ energy $E$ / meV/atom" ,
154
- ylabel = r"predicted energy $E$ / meV/atom" ,
178
+ ylabel = r"$\Delta E$ / meV/atom" ,
179
+ ymax = e_ymax ,
155
180
)
156
181
if save :
157
182
energy_plot .savefig (self .plots_dir / "energy.png" )
158
183
159
184
if "forces_true" in self .content :
160
- xlabel = r"$ab~initio$ force components per atom $|F|$ / meV$ \cdot \AA^{-1}$"
161
- ylabel = r"predicted force components per atom $|F|$ / meV$ \cdot \AA^{-1}$"
185
+ xlabel = (
186
+ r"$ab~initio$ force components per atom $F_{alpha,i}$ / meV$ \cdot"
187
+ r" \AA^{-1}$"
188
+ )
189
+ ylabel = r"$\Delta F_{alpha,i}$ / meV$ \cdot \AA^{-1}$"
190
+ f_ymax = self .figure_ymax .get ("forces" , None )
162
191
forces_plot = get_figure (
163
192
self .content ["forces_true" ],
164
- self .content ["forces_pred " ],
193
+ self .content ["forces_error " ],
165
194
datalabel = rf"MAE: { self .forces ['mae' ]:.2f} meV$ / \AA$" ,
166
195
xlabel = xlabel ,
167
196
ylabel = ylabel ,
197
+ ymax = f_ymax ,
168
198
)
169
199
if save :
170
200
forces_plot .savefig (self .plots_dir / "forces.png" )
171
201
172
202
if "stress_true" in self .content :
173
203
s_true = self .content ["stress_true" ]
174
- s_pred = self .content ["stress_pred " ]
204
+ s_error = self .content ["stress_error " ]
175
205
shydro_true = self .content ["stress_hydro_true" ]
176
- shydro_pred = self .content ["stress_hydro_pred " ]
206
+ shydro_error = self .content ["stress_hydro_error " ]
177
207
sdeviat_true = self .content ["stress_deviat_true" ]
178
- sdeviat_pred = self .content ["stress_deviat_pred" ]
208
+ sdeviat_error = self .content ["stress_deviat_error" ]
209
+
210
+ s_ymax = self .figure_ymax .get ("stress" , None )
211
+ hs_ymax = self .figure_ymax .get ("stress_hydro" , None )
212
+ ds_ymax = self .figure_ymax .get ("stress_deviat" , None )
179
213
180
214
stress_plot = get_figure (
181
215
s_true ,
182
- s_pred ,
216
+ s_error ,
183
217
datalabel = rf"Max: { self .stress ['max' ]:.4f} " ,
184
218
xlabel = r"$ab~initio$ stress" ,
185
- ylabel = r"predicted stress" ,
219
+ ylabel = r"$\Delta$ stress" ,
220
+ ymax = s_ymax ,
186
221
)
187
222
hydrostatic_stress_plot = get_figure (
188
223
shydro_true ,
189
- shydro_pred ,
224
+ shydro_error ,
190
225
datalabel = rf"Max: { self .stress_hydro ['max' ]:.4f} " ,
191
226
xlabel = r"$ab~initio$ hydrostatic stress" ,
192
- ylabel = r"predicted hydrostatic stress" ,
227
+ ylabel = r"$\Delta$ hydrostatic stress" ,
228
+ ymax = hs_ymax ,
193
229
)
194
230
deviatoric_stress_plot = get_figure (
195
231
sdeviat_true ,
196
- sdeviat_pred ,
232
+ sdeviat_error ,
197
233
datalabel = rf"Max: { self .stress_deviat ['max' ]:.4f} " ,
198
234
xlabel = r"$ab~initio$ deviatoric stress" ,
199
- ylabel = r"predicted deviatoric stress" ,
235
+ ylabel = r"$\Delta$ deviatoric stress" ,
236
+ ymax = ds_ymax ,
200
237
)
201
238
if save :
202
239
stress_plot .savefig (self .plots_dir / "stress.png" )
0 commit comments