Skip to content

Commit 7f8fc52

Browse files
added MD22 dataset node (#247)
* added MD22 dataset node
* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci
* linting
* changed base class of dataset
* use tmpdir
* remove debug prints
* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci
* removed unnecessary `outs_path`

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 78a42c4 commit 7f8fc52

File tree

5 files changed

+106
-0
lines changed

5 files changed

+106
-0
lines changed

ipsuite/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
configuration_generation,
1212
configuration_selection,
1313
data_loading,
14+
datasets,
1415
fields,
1516
geometry,
1617
models,
@@ -38,6 +39,7 @@
3839
"geometry",
3940
"combine",
4041
"data_loading",
42+
"datasets",
4143
"nodes",
4244
]
4345

ipsuite/datasets/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from ipsuite.datasets.md22 import MD22Dataset
2+
3+
__all__ = ["MD22Dataset"]

ipsuite/datasets/md22.py

+79
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import tempfile
2+
import typing
3+
import urllib
4+
import zipfile
5+
from pathlib import Path
6+
7+
import ase
8+
import zntrack
9+
from ase import units
10+
11+
import ipsuite as ips
12+
from ipsuite import fields
13+
14+
15+
def modify_xyz_file(file_path, target_string, replacement_string):
    """Write a copy of a text file with every occurrence of a string replaced.

    The copy is created next to the original and is named
    ``<stem>_mod<suffix>``; the original file is left untouched.

    Parameters
    ----------
    file_path : str or os.PathLike
        File to read.
    target_string : str
        Substring searched for in each line.
    replacement_string : str
        Text substituted for every occurrence of ``target_string``.

    Returns
    -------
    pathlib.Path
        Path of the newly written file.
    """
    # Coerce so that plain string paths work as well as Path objects.
    file_path = Path(file_path)
    new_file_path = file_path.with_name(file_path.stem + "_mod" + file_path.suffix)

    with file_path.open("r") as input_file, new_file_path.open("w") as output_file:
        # Stream line by line so large trajectory files are never held in memory.
        output_file.writelines(
            line.replace(target_string, replacement_string) for line in input_file
        )
    return new_file_path
24+
25+
26+
def download_data(url: str, data_path: Path):
    """Download a zip archive, unpack it and rewrite its xyz file.

    The archive is stored in ``data_path`` under the stem of the URL's file
    name and extracted into the same directory. The extracted
    ``<stem>.xyz`` file is then copied with every ``"Energy"`` replaced by
    ``"energy"``.

    Parameters
    ----------
    url : str
        Location of the zip archive.
    data_path : pathlib.Path
        Existing directory that receives the archive and its contents.

    Returns
    -------
    pathlib.Path
        Path of the rewritten (``*_mod.xyz``) file.
    """
    # A bare ``import urllib`` does not load these submodules; import them
    # explicitly so the function does not depend on other modules' imports.
    import urllib.parse
    import urllib.request

    url_path = Path(urllib.parse.urlparse(url).path)
    zip_path = data_path / url_path.stem
    # Append the suffix instead of using ``with_suffix`` so a stem that
    # contains a dot is not truncated.
    # NOTE(review): assumes the archive holds a file named "<stem>.xyz".
    file_path = data_path / (url_path.stem + ".xyz")
    urllib.request.urlretrieve(url, zip_path)

    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(data_path)

    return modify_xyz_file(
        file_path, target_string="Energy", replacement_string="energy"
    )
39+
40+
41+
class MD22Dataset(ips.base.IPSNode):
    """Node that downloads one MD22 dataset and stores it as ASE atoms.

    Attributes
    ----------
    dataset : str
        Name of the dataset to fetch; must be a key of ``datasets``.
    atoms : list of ase.Atoms
        Configurations read from the downloaded xyz file.
    datasets : dict
        Maps every supported dataset name to its download URL.
    """

    dataset: str = zntrack.params()

    atoms: typing.List[ase.Atoms] = fields.Atoms()

    datasets = {
        "Ac-Ala3-NHMe": (
            "http://www.quantum-machine.org/gdml/repo/static/md22_Ac-Ala3-NHMe.zip"
        ),
        "DHA": "http://www.quantum-machine.org/gdml/repo/static/md22_DHA.zip",
        "stachyose": "http://www.quantum-machine.org/gdml/repo/static/md22_stachyose.zip",
        "AT-AT": "http://www.quantum-machine.org/gdml/repo/static/md22_AT-AT.zip",
        "AT-AT-CG-CG": (
            "http://www.quantum-machine.org/gdml/repo/static/md22_AT-AT-CG-CG.zip"
        ),
        "buckyball-catcher": (
            "http://www.quantum-machine.org/gdml/repo/static/md22_buckyball-catcher.zip"
        ),
        "double-walled_nanotube": "http://www.quantum-machine.org/gdml/repo/static/md22_double-walled_nanotube.zip",
    }

    def run(self):
        """Download the selected dataset and fill ``self.atoms``.

        Raises
        ------
        FileNotFoundError
            If ``self.dataset`` is not a key of ``datasets``.
        """
        # ``import ase`` is not guaranteed to expose the ``ase.io`` submodule.
        import ase.io

        # Validate before touching the file system or the network.
        if self.dataset not in self.datasets:
            raise FileNotFoundError(
                f"Dataset {self.dataset} is not known. "
                f"Key has to be one of {list(self.datasets)}"
            )
        url = self.datasets[self.dataset]

        # The context manager removes the download directory even when the
        # download or the parsing raises.
        with tempfile.TemporaryDirectory() as tmpdir:
            raw_data_dir = Path(tmpdir) / "raw_data"
            raw_data_dir.mkdir(parents=True, exist_ok=True)

            file_path = download_data(url, raw_data_dir)
            self.atoms = ase.io.read(file_path, ":")

        # Scale by kcal/mol expressed in ASE units.
        # NOTE(review): presumes the raw values are in kcal/mol (energies)
        # and kcal/mol/Angstrom (forces) - confirm against the MD22 docs.
        for atoms in self.atoms:
            atoms.calc.results["energy"] *= units.kcal / units.mol
            atoms.calc.results["forces"] *= units.kcal / units.mol

ipsuite/nodes.py

+3
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@ class _Nodes:
4444
AddDataH5MD = "ipsuite.data_loading.AddDataH5MD"
4545
ReadData = "ipsuite.data_loading.ReadData"
4646

47+
# Datasets
48+
MD22Dataset = "ipsuite.datasets.MD22Dataset"
49+
4750
# Bootstrap
4851
RattleAtoms = "ipsuite.bootstrap.RattleAtoms"
4952
TranslateMolecules = "ipsuite.bootstrap.TranslateMolecules"

tests/integration/test_i_datasets.py

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import requests
2+
3+
import ipsuite as ips
4+
5+
6+
def test_md22():
    """Check that all MD22 URLs are reachable and that one dataset loads."""
    for url in ips.datasets.MD22Dataset.datasets.values():
        # stream=True defers the body download: only the status is needed,
        # so the archives are not fetched in full. The timeout keeps an
        # unresponsive server from hanging the test run.
        with requests.get(url, stream=True, timeout=60) as response:
            assert response.status_code == 200

    project = ips.Project(automatic_node_names=True)

    with project:
        data = ips.datasets.MD22Dataset("AT-AT")

    project.run()

    data.load()
    assert len(data.atoms) > 0

0 commit comments

Comments
 (0)