Skip to content

Commit 6cffdc9

Browse files
committed
Create a local validator
1 parent a0e8c8d commit 6cffdc9

File tree

1 file changed

+40
-0
lines changed

1 file changed

+40
-0
lines changed

Diff for: validate.py

+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import os
2+
3+
import h5py
4+
import numpy as np
5+
6+
from submission.run import Evaluator
7+
8+
DATA_PATH = "data/validation/data.hdf5"
9+
10+
11+
def main():
12+
# Load the data (combined features & targets)
13+
try:
14+
data = h5py.File(DATA_PATH, "r")
15+
except FileNotFoundError:
16+
print(f"Unable to load features at `{DATA_PATH}`")
17+
return
18+
19+
# Switch into the submission directory
20+
cwd = os.getcwd()
21+
os.chdir("submission")
22+
23+
# Make predictions on the data
24+
try:
25+
evaluator = Evaluator()
26+
27+
predictions = []
28+
for batch in evaluator.predict(features=data):
29+
assert batch.shape[-1] == 48
30+
predictions.append(batch)
31+
finally:
32+
os.chdir(cwd)
33+
34+
# Output the mean absolute error
35+
mae = np.mean(np.absolute(data["targets"] - np.concatenate(predictions)))
36+
print("MAE:", mae)
37+
38+
39+
if __name__ == "__main__":
40+
main()

0 commit comments

Comments
 (0)