Skip to content
This repository has been archived by the owner on Jun 19, 2024. It is now read-only.

Commit

Permalink
Add ase.calculator to frame (#5)
Browse files Browse the repository at this point in the history
* Add ase.calculator to frame

* Fix type errors

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove duplicates

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix duplicate check

* Fix duplicate case

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
phohenberger and pre-commit-ci[bot] authored Jan 19, 2024
1 parent 1635c2c commit cb648dd
Showing 1 changed file with 39 additions and 1 deletion.
40 changes: 39 additions & 1 deletion znframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
import ase.cell
from ase.data.colors import jmol_colors
from ase.calculators.singlepoint import SinglePointCalculator
from copy import deepcopy
import json
import typing as t
Expand Down Expand Up @@ -38,6 +39,14 @@ def _ndarray_to_list(array: t.Union[dict, np.ndarray]) -> t.Union[dict, list]:
return array


def _npnumber_to_number(number):
if isinstance(number, np.floating):
return float(number)
if isinstance(number, np.integer):
return int(number)
return number


@define
class Frame:
numbers: np.ndarray = field(converter=_list_to_array, eq=cmp_using(np.array_equal))
Expand Down Expand Up @@ -83,6 +92,11 @@ def from_atoms(cls, atoms: ase.Atoms):
arrays = deepcopy(atoms.arrays)
info = deepcopy(atoms.info)

if hasattr(atoms.calc, "results"):
duplicates = list(set(atoms.calc.results.keys()) & set(atoms.arrays.keys()))
for key in duplicates:
arrays.pop(key)

frame = cls(
numbers=arrays.pop("numbers"),
positions=arrays.pop("positions"),
Expand All @@ -92,6 +106,16 @@ def from_atoms(cls, atoms: ase.Atoms):
cell=atoms.cell,
)

try:
calc_data = {}
for key, value in atoms.calc.results.items():
if isinstance(value, np.ndarray):
value = value.tolist()
calc_data[key] = value
frame.info["calc"] = calc_data
except AttributeError:
pass

try:
frame.connectivity = atoms.connectivity
except AttributeError:
Expand All @@ -107,6 +131,14 @@ def to_atoms(self) -> ase.Atoms:
cell=self.cell,
)

if "calc" in self.info:
calc = self.info.pop("calc", None)
atoms.calc = SinglePointCalculator(atoms)
atoms.calc.results = {
key: np.array(val) if isinstance(val, list) else val
for key, val in calc
}

atoms.arrays.update(self.arrays)
atoms.info.update(self.info)
atoms.connectivity = self.connectivity
Expand All @@ -118,7 +150,13 @@ def to_dict(self, built_in_types: bool = True) -> dict:
if built_in_types:
return data
else:
return _ndarray_to_list(data)
data = _ndarray_to_list(data)
for key, value in data["info"].items():
if isinstance(value, np.generic):
data["info"][key] = _npnumber_to_number(value)
else:
data["info"][key] = value
return data

@classmethod
def from_dict(cls, d: dict):
Expand Down

0 comments on commit cb648dd

Please sign in to comment.