diff --git a/znframe/frame.py b/znframe/frame.py index f22843d..6edc98c 100644 --- a/znframe/frame.py +++ b/znframe/frame.py @@ -3,6 +3,7 @@ import numpy as np import ase.cell from ase.data.colors import jmol_colors +from ase.calculators.singlepoint import SinglePointCalculator from copy import deepcopy import json import typing as t @@ -38,6 +39,14 @@ def _ndarray_to_list(array: t.Union[dict, np.ndarray]) -> t.Union[dict, list]: return array +def _npnumber_to_number(number): + if isinstance(number, np.floating): + return float(number) + if isinstance(number, np.integer): + return int(number) + return number + + @define class Frame: numbers: np.ndarray = field(converter=_list_to_array, eq=cmp_using(np.array_equal)) @@ -83,6 +92,11 @@ def from_atoms(cls, atoms: ase.Atoms): arrays = deepcopy(atoms.arrays) info = deepcopy(atoms.info) + if hasattr(atoms.calc, "results"): + duplicates = list(set(atoms.calc.results.keys()) & set(atoms.arrays.keys())) + for key in duplicates: + arrays.pop(key) + frame = cls( numbers=arrays.pop("numbers"), positions=arrays.pop("positions"), @@ -92,6 +106,16 @@ def from_atoms(cls, atoms: ase.Atoms): cell=atoms.cell, ) + try: + calc_data = {} + for key, value in atoms.calc.results.items(): + if isinstance(value, np.ndarray): + value = value.tolist() + calc_data[key] = value + frame.info["calc"] = calc_data + except AttributeError: + pass + try: frame.connectivity = atoms.connectivity except AttributeError: @@ -107,6 +131,14 @@ def to_atoms(self) -> ase.Atoms: cell=self.cell, ) + if "calc" in self.info: + calc = self.info.pop("calc", None) + atoms.calc = SinglePointCalculator(atoms) + atoms.calc.results = { + key: np.array(val) if isinstance(val, list) else val + for key, val in calc + } + atoms.arrays.update(self.arrays) atoms.info.update(self.info) atoms.connectivity = self.connectivity @@ -118,7 +150,13 @@ def to_dict(self, built_in_types: bool = True) -> dict: if built_in_types: return data else: - return _ndarray_to_list(data) + data = _ndarray_to_list(data) + for key, value in data["info"].items(): + if isinstance(value, np.generic): + data["info"][key] = _npnumber_to_number(value) + else: + data["info"][key] = value + return data @classmethod def from_dict(cls, d: dict):