From 71212156bff7ba322bfd16225a7e207a434a92c7 Mon Sep 17 00:00:00 2001 From: Fabian Zills Date: Fri, 19 Jan 2024 18:57:59 +0100 Subject: [PATCH] new znframe versionn --- pyproject.toml | 2 +- znframe/frame.py | 38 ++++++++++++++++++++++++++++++++------ 2 files changed, 33 insertions(+), 7 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 604d32b..0274b20 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "znframe" -version = "0.1.4" +version = "0.1.5" description = "ZnFrame - ASE-like Interface based on dataclasses" authors = ["zincwarecode "] readme = "README.md" diff --git a/znframe/frame.py b/znframe/frame.py index f841591..aac9583 100644 --- a/znframe/frame.py +++ b/znframe/frame.py @@ -7,10 +7,17 @@ from copy import deepcopy import json import typing as t +import enum from znframe.bonds import ASEComputeBonds +class ComputeProperties(enum.Enum): + bonds = "bonds" + radii = "radii" + colors = "colors" + + def _cell_to_array(cell: t.Union[np.ndarray, ase.cell.Cell]) -> np.ndarray: if isinstance(cell, np.ndarray): return cell @@ -79,21 +86,40 @@ class Frame: converter=_cell_to_array, eq=cmp_using(np.array_equal), default=np.zeros(3) ) + recompute: t.List[ComputeProperties] = field( + factory=lambda: [ + ComputeProperties.bonds, + ComputeProperties.radii, + ComputeProperties.colors, + ] + ) + def __attrs_post_init__(self): + if ComputeProperties.bonds in self.recompute: + self.connectivity = None + if ComputeProperties.radii in self.recompute: + self.arrays.pop("radii", None) + if ComputeProperties.colors in self.recompute: + self.arrays.pop("colors", None) if self.connectivity is None: ase_bond_calculator = ASEComputeBonds() self.connectivity = ase_bond_calculator.build_graph(self.to_atoms()) self.connectivity = ase_bond_calculator.get_bonds(self.connectivity) if "colors" not in self.arrays: - self.arrays["colors"] = [ - rgb2hex(jmol_colors[number]) for number in self.numbers - ] + self.arrays["colors"] = np.array( + [rgb2hex(jmol_colors[number]) for number in self.numbers] + ) if "radii" not in self.arrays: - self.arrays["radii"] = [get_radius(number) for number in self.numbers] + self.arrays["radii"] = np.array( + [get_radius(number) for number in self.numbers] + ) @classmethod - def from_atoms(cls, atoms: ase.Atoms): + def from_atoms( + cls, + atoms: ase.Atoms, + ): arrays = deepcopy(atoms.arrays) info = deepcopy(atoms.info) @@ -158,7 +184,7 @@ def to_atoms(self) -> ase.Atoms: return atoms def to_dict(self, built_in_types: bool = True) -> dict: - data = attrs.asdict(self) + data = attrs.asdict(self, filter=lambda attr, _: attr.name != "recompute") if built_in_types: return data else: