Skip to content
This repository has been archived by the owner on Jun 19, 2024. It is now read-only.

New tests #7

Merged
merged 24 commits into from
Jan 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
7e133a1
test singlepoint calculator
PythonFZ Jan 19, 2024
cea38ef
add forces
PythonFZ Jan 19, 2024
ae2006c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 19, 2024
daed33a
Fix calc bug
phohenberger Jan 19, 2024
34e357f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 19, 2024
3510941
PBC
PythonFZ Jan 19, 2024
345bfb6
Update test_water_with_calc
phohenberger Jan 19, 2024
2f44ded
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 19, 2024
18619cb
Fix
phohenberger Jan 19, 2024
0763246
Another try to fix
phohenberger Jan 19, 2024
100c517
add assert
PythonFZ Jan 19, 2024
d7dd30c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 19, 2024
369b657
update test
PythonFZ Jan 19, 2024
e3a3c47
merge
PythonFZ Jan 19, 2024
3a7f988
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 19, 2024
ee249de
update test (calrify)
PythonFZ Jan 19, 2024
75c3e75
Merge branch 'tests' of https://github.com/zincware/ZnFrame into tests
PythonFZ Jan 19, 2024
c190057
update how it should look like
PythonFZ Jan 19, 2024
713b05b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 19, 2024
d722cf0
Move self.info["calc"] to self.calc
phohenberger Jan 19, 2024
de04c02
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 19, 2024
5b80719
Fix tests
phohenberger Jan 19, 2024
f46db0d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 19, 2024
96e09fc
Another try to fix test
phohenberger Jan 19, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "znframe"
version = "0.1.2"
version = "0.1.3"
description = "ZnFrame - ASE-like Interface based on dataclasses"
authors = ["zincwarecode <[email protected]>"]
readme = "README.md"
Expand Down
54 changes: 54 additions & 0 deletions tests/test_znframe.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from ase.build import molecule
import ase.io
from ase.calculators.singlepoint import SinglePointCalculator
from znframe import Frame
import pytest
import numpy as np
Expand All @@ -18,6 +20,17 @@ def ammonia() -> Frame:
return Frame.from_atoms(ammonia)


@pytest.fixture
def waterWithCalc() -> ase.Atoms:
atoms = molecule("H2O")
atoms.cell = [[10, 0, 0], [0, 10, 0], [0, 0, 10]]
atoms.pbc = [True, True, True]
atoms.calc = SinglePointCalculator(
atoms, energy=-1234, forces=np.random.random((3, 3)), stress=np.random.random(6)
)
return atoms


def test_frame_from_ase_molecule(ammonia):
assert ammonia.numbers.shape == (4,)
assert ammonia.positions.shape == (4, 3)
Expand All @@ -44,6 +57,7 @@ def test_frame_to_dict(water):
"connectivity": water.connectivity,
"arrays": water.arrays,
"info": water.info,
"calc": water.calc,
"pbc": water.pbc,
"cell": water.cell,
}
Expand All @@ -55,3 +69,43 @@ def test_frame_from_dict(ammonia):

def test_to_json(ammonia):
assert Frame.from_json(ammonia.to_json()) == ammonia


def test_water_with_calc(waterWithCalc):
assert "forces" in waterWithCalc.calc.results
assert "stress" in waterWithCalc.calc.results
assert "energy" in waterWithCalc.calc.results

assert "forces" not in waterWithCalc.info
assert "stress" not in waterWithCalc.arrays
assert "energy" not in waterWithCalc.arrays

assert "forces" not in waterWithCalc.arrays
assert "stress" not in waterWithCalc.info
assert "energy" not in waterWithCalc.info

frame = Frame.from_atoms(waterWithCalc)
intersection = set(frame.info) & set(frame.arrays)
if intersection:
raise ValueError(f"Duplicate keys: {intersection}")

assert "forces" not in frame.arrays
assert "stress" not in frame.arrays
assert "energy" not in frame.arrays

assert "stress" not in frame.info
assert "energy" not in frame.info
assert "forces" not in frame.info

assert "forces" in frame.calc
assert "stress" in frame.calc
assert "energy" in frame.calc

atoms = frame.to_atoms()
for key in atoms.calc.results.keys():
if isinstance(atoms.calc.results[key], np.ndarray):
np.testing.assert_array_equal(
atoms.calc.results[key], waterWithCalc.calc.results[key]
)
else:
assert atoms.calc.results[key] == waterWithCalc.calc.results[key]
24 changes: 18 additions & 6 deletions znframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ class Frame:
converter=_list_to_array, eq=False, factory=dict
)

# Do we need this, or can we just create calc in post_init?
calc: dict[str, t.Union[float, int, np.ndarray]] = field(
converter=_list_to_array, eq=False, factory=dict
)

pbc: np.ndarray = field(
converter=_list_to_array,
eq=cmp_using(np.array_equal),
Expand Down Expand Up @@ -93,10 +98,18 @@ def from_atoms(cls, atoms: ase.Atoms):
info = deepcopy(atoms.info)

if hasattr(atoms.calc, "results"):
duplicates = list(set(atoms.calc.results.keys()) & set(atoms.arrays.keys()))
for key in duplicates:
duplicates_arrays = list(
set(atoms.calc.results.keys()) & set(atoms.arrays.keys())
)
for key in duplicates_arrays:
arrays.pop(key)

duplicates_info = list(
set(atoms.calc.results.keys()) & set(atoms.info.keys())
)
for key in duplicates_info:
info.pop(key)

frame = cls(
numbers=arrays.pop("numbers"),
positions=arrays.pop("positions"),
Expand All @@ -112,7 +125,7 @@ def from_atoms(cls, atoms: ase.Atoms):
if isinstance(value, np.ndarray):
value = value.tolist()
calc_data[key] = value
frame.info["calc"] = calc_data
frame.calc = calc_data
except AttributeError:
pass

Expand All @@ -131,12 +144,11 @@ def to_atoms(self) -> ase.Atoms:
cell=self.cell,
)

if "calc" in self.info:
calc = self.info.pop("calc", None)
if self.calc:
atoms.calc = SinglePointCalculator(atoms)
atoms.calc.results = {
key: np.array(val) if isinstance(val, list) else val
for key, val in calc
for key, val in self.calc.items()
}

atoms.arrays.update(self.arrays)
Expand Down