From 8058ad85ef2ca31afe8410c920a75f53ac378922 Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Fri, 10 May 2024 20:40:30 -0400 Subject: [PATCH 01/27] use index_select --- nequip/data/AtomicDataDict.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/nequip/data/AtomicDataDict.py b/nequip/data/AtomicDataDict.py index f7713e6f..ba8c75d1 100644 --- a/nequip/data/AtomicDataDict.py +++ b/nequip/data/AtomicDataDict.py @@ -5,6 +5,7 @@ Authors: Albert Musaelian """ + from typing import Dict, Any import torch @@ -67,7 +68,10 @@ def with_edge_vectors(data: Type, with_lengths: bool = True) -> Type: # (2) works on a Batch constructed from AtomicData pos = data[_keys.POSITIONS_KEY] edge_index = data[_keys.EDGE_INDEX_KEY] - edge_vec = pos[edge_index[1]] - pos[edge_index[0]] + # edge_vec = pos[edge_index[1]] - pos[edge_index[0]] + edge_vec = torch.index_select(pos, 0, edge_index[1]) - torch.index_select( + pos, 0, edge_index[0] + ) if _keys.CELL_KEY in data: # ^ note that to save time we don't check that the edge_cell_shifts are trivial if no cell is provided; we just assume they are either not present or all zero. # -1 gives a batch dim no matter what From be22479f00202901ac838c3d1d09552d0a9d73dc Mon Sep 17 00:00:00 2001 From: Sean Kavanagh Date: Sun, 9 Jun 2024 16:38:45 -0400 Subject: [PATCH 02/27] Avoid unnecessary e3nn-related JIT warning --- nequip/scripts/train.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/nequip/scripts/train.py b/nequip/scripts/train.py index 3d10049b..f6046eb7 100644 --- a/nequip/scripts/train.py +++ b/nequip/scripts/train.py @@ -23,6 +23,9 @@ from nequip.utils._global_options import _set_global_options from nequip.scripts._logger import set_up_script_logger +warnings.filterwarnings( # unnecessary e3nn-related JIT warning + "ignore", message="The TorchScript type system doesn't support instance-level annotations" +) default_config = dict( root="./", tensorboard=False, From c3ac2e9913a2a6c2aa96cea013523336dd2a84c2 Mon Sep 17 00:00:00 2001 From: Chuin Wei Tan <87742566+cw-tan@users.noreply.github.com> Date: Fri, 21 Jun 2024 17:40:15 -0400 Subject: [PATCH 03/27] Cartesian tensors and entry point extension registration --------- Co-authored-by: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Co-authored-by: cw-tan --- CHANGELOG.md | 10 +++ nequip/__init__.py | 39 +++++++++++ nequip/data/AtomicData.py | 95 +++++++++++++++++++++----- nequip/data/__init__.py | 2 + nequip/data/_build.py | 6 +- nequip/data/_dataset/_base_datasets.py | 2 +- nequip/scripts/deploy.py | 2 + nequip/train/_loss.py | 3 +- nequip/utils/_global_options.py | 10 --- nequip/utils/test.py | 34 +++++---- setup.py | 1 + 11 files changed, 157 insertions(+), 47 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c60ed185..aefad560 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,8 +6,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 Most recent change on the bottom. +## Unreleased - 0.6.1 +### Added +- add support for equivariance testing of arbitrary Cartesian tensor outputs +- [Breaking] use entry points for `nequip.extension`s (e.g. 
for field registration) + + +### Changed + + ## Unreleased + ## [0.6.0] - 2024-5-10 ### Added - add Tensorboard as logger option diff --git a/nequip/__init__.py b/nequip/__init__.py index 3a8d6d5c..654cd78a 100644 --- a/nequip/__init__.py +++ b/nequip/__init__.py @@ -1 +1,40 @@ +import sys + from ._version import __version__ # noqa: F401 + +import packaging.version + +import torch +import warnings + +# torch version checks +torch_version = packaging.version.parse(torch.__version__) + +# only allow 1.11*, 1.13* or higher (no 1.12.*) +assert (torch_version > packaging.version.parse("1.11.0")) and not ( + packaging.version.parse("1.12.0") + <= torch_version + < packaging.version.parse("1.13.0") +), f"NequIP supports PyTorch 1.11.* or 1.13.* or later, but {torch_version} found" + +# warn if using 1.13* or 2.0.* +if packaging.version.parse("1.13.0") <= torch_version < packaging.version.parse("2.1"): + warnings.warn( + f"!! PyTorch version {torch_version} found. Upstream issues in PyTorch versions 1.13.* and 2.0.* have been seen to cause unusual performance degredations on some CUDA systems that become worse over time; see https://github.com/mir-group/nequip/discussions/311. The best tested PyTorch version to use with CUDA devices is 1.11; while using other versions if you observe this problem, an unexpected lack of this problem, or other strange behavior, please post in the linked GitHub issue." + ) + + +# Load all installed nequip extension packages +# This allows installed extensions to register themselves in +# the nequip infrastructure with calls like `register_fields` + +# see https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/#using-package-metadata +if sys.version_info < (3, 10): + from importlib_metadata import entry_points +else: + from importlib.metadata import entry_points + +_DISCOVERED_NEQUIP_EXTENSION = entry_points(group="nequip.extension") +for ep in _DISCOVERED_NEQUIP_EXTENSION: + if ep.name == "init_always": + ep.load() diff --git a/nequip/data/AtomicData.py b/nequip/data/AtomicData.py index 70c8fd2e..2267930c 100644 --- a/nequip/data/AtomicData.py +++ b/nequip/data/AtomicData.py @@ -18,6 +18,7 @@ import torch import e3nn.o3 +from e3nn.io import CartesianTensor from . import AtomicDataDict from ._util import _TORCH_INTEGER_DTYPES @@ -26,6 +27,7 @@ # A type representing ASE-style periodic boundary condtions, which can be partial (the tuple case) PBC = Union[bool, Tuple[bool, bool, bool]] +# === Key Registration === _DEFAULT_LONG_FIELDS: Set[str] = { AtomicDataDict.EDGE_INDEX_KEY, @@ -61,10 +63,15 @@ AtomicDataDict.CELL_KEY, AtomicDataDict.BATCH_PTR_KEY, } +_DEFAULT_CARTESIAN_TENSOR_FIELDS: Dict[str, str] = { + AtomicDataDict.STRESS_KEY: "ij=ji", + AtomicDataDict.VIRIAL_KEY: "ij=ji", +} _NODE_FIELDS: Set[str] = set(_DEFAULT_NODE_FIELDS) _EDGE_FIELDS: Set[str] = set(_DEFAULT_EDGE_FIELDS) _GRAPH_FIELDS: Set[str] = set(_DEFAULT_GRAPH_FIELDS) _LONG_FIELDS: Set[str] = set(_DEFAULT_LONG_FIELDS) +_CARTESIAN_TENSOR_FIELDS: Dict[str, str] = dict(_DEFAULT_CARTESIAN_TENSOR_FIELDS) def register_fields( @@ -72,6 +79,7 @@ def register_fields( edge_fields: Sequence[str] = [], graph_fields: Sequence[str] = [], long_fields: Sequence[str] = [], + cartesian_tensor_fields: Dict[str, str] = {}, ) -> None: r"""Register fields as being per-atom, per-edge, or per-frame. 
@@ -83,18 +91,36 @@ def register_fields( edge_fields: set = set(edge_fields) graph_fields: set = set(graph_fields) long_fields: set = set(long_fields) - allfields = node_fields.union(edge_fields, graph_fields) - assert len(allfields) == len(node_fields) + len(edge_fields) + len(graph_fields) + + # error checking: prevents registering fields as contradictory types + # potentially unregistered fields + assert len(node_fields.intersection(edge_fields)) == 0 + assert len(node_fields.intersection(graph_fields)) == 0 + assert len(edge_fields.intersection(graph_fields)) == 0 + # already registered fields + assert len(_NODE_FIELDS.intersection(edge_fields)) == 0 + assert len(_NODE_FIELDS.intersection(graph_fields)) == 0 + assert len(_EDGE_FIELDS.intersection(node_fields)) == 0 + assert len(_EDGE_FIELDS.intersection(graph_fields)) == 0 + assert len(_GRAPH_FIELDS.intersection(edge_fields)) == 0 + assert len(_GRAPH_FIELDS.intersection(node_fields)) == 0 + + # check that Cartesian tensor fields to add are rank-2 (higher ranks not supported) + for cart_tensor_key in cartesian_tensor_fields: + cart_tensor_rank = len( + CartesianTensor(cartesian_tensor_fields[cart_tensor_key]).indices + ) + if cart_tensor_rank != 2: + raise NotImplementedError( + f"Only rank-2 tensor data processing supported, but got {cart_tensor_key} is rank {cart_tensor_rank}. Consider raising a GitHub issue if higher-rank tensor data processing is desired." + ) + + # update fields _NODE_FIELDS.update(node_fields) _EDGE_FIELDS.update(edge_fields) _GRAPH_FIELDS.update(graph_fields) _LONG_FIELDS.update(long_fields) - if len(set.union(_NODE_FIELDS, _EDGE_FIELDS, _GRAPH_FIELDS)) < ( - len(_NODE_FIELDS) + len(_EDGE_FIELDS) + len(_GRAPH_FIELDS) - ): - raise ValueError( - "At least one key was registered as more than one of node, edge, or graph!" - ) + _CARTESIAN_TENSOR_FIELDS.update(cartesian_tensor_fields) def deregister_fields(*fields: Sequence[str]) -> None: @@ -109,9 +135,16 @@ def deregister_fields(*fields: Sequence[str]) -> None: assert f not in _DEFAULT_NODE_FIELDS, "Cannot deregister built-in field" assert f not in _DEFAULT_EDGE_FIELDS, "Cannot deregister built-in field" assert f not in _DEFAULT_GRAPH_FIELDS, "Cannot deregister built-in field" + assert f not in _DEFAULT_LONG_FIELDS, "Cannot deregister built-in field" + assert ( + f not in _DEFAULT_CARTESIAN_TENSOR_FIELDS + ), "Cannot deregister built-in field" + _NODE_FIELDS.discard(f) _EDGE_FIELDS.discard(f) _GRAPH_FIELDS.discard(f) + _LONG_FIELDS.discard(f) + _CARTESIAN_TENSOR_FIELDS.pop(f, None) def _register_field_prefix(prefix: str) -> None: @@ -125,6 +158,9 @@ def _register_field_prefix(prefix: str) -> None: ) +# === AtomicData === + + def _process_dict(kwargs, ignore_fields=[]): """Convert a dict of data into correct dtypes/shapes according to key""" # Deal with _some_ dtype issues @@ -449,17 +485,40 @@ def from_ase( cell = kwargs.pop("cell", atoms.get_cell()) pbc = kwargs.pop("pbc", atoms.pbc) - # handle ASE-style 6 element Voigt order stress - for key in (AtomicDataDict.STRESS_KEY, AtomicDataDict.VIRIAL_KEY): - if key in add_fields: - if add_fields[key].shape == (3, 3): - # it's already 3x3, do nothing else - pass - elif add_fields[key].shape == (6,): - # it's Voigt order - add_fields[key] = voigt_6_to_full_3x3_stress(add_fields[key]) + # IMPORTANT: the following reshape logic only applies to rank-2 Cartesian tensor fields + for key in add_fields: + if key in _CARTESIAN_TENSOR_FIELDS: + # enforce (3, 3) shape for graph fields, e.g. 
stress, virial + if key in _GRAPH_FIELDS: + # handle ASE-style 6 element Voigt order stress + if key in (AtomicDataDict.STRESS_KEY, AtomicDataDict.VIRIAL_KEY): + if add_fields[key].shape == (6,): + add_fields[key] = voigt_6_to_full_3x3_stress( + add_fields[key] + ) + if add_fields[key].shape == (3, 3): + # it's already 3x3, do nothing else + pass + elif add_fields[key].shape == (9,): + add_fields[key] = add_fields[key].reshape((3, 3)) + else: + raise RuntimeError( + f"bad shape for {key} registered as a Cartesian tensor graph field---please note that only rank-2 Cartesian tensors are currently supported" + ) + # enforce (N_atom, 3, 3) shape for node fields, e.g. Born effective charges + elif key in _NODE_FIELDS: + if add_fields[key].shape[1:] == (3, 3): + pass + elif add_fields[key].shape[1:] == (9,): + add_fields[key] = add_fields[key].reshape((-1, 3, 3)) + else: + raise RuntimeError( + f"bad shape for {key} registered as a Cartesian tensor node field---please note that only rank-2 Cartesian tensors are currently supported" + ) else: - raise RuntimeError(f"bad shape for {key}") + raise RuntimeError( + f"{key} registered as a Cartesian tensor field was not registered as either a graph or node field" + ) return cls.from_points( pos=atoms.positions, diff --git a/nequip/data/__init__.py b/nequip/data/__init__.py index 02c41d55..5cbbc853 100644 --- a/nequip/data/__init__.py +++ b/nequip/data/__init__.py @@ -8,6 +8,7 @@ _EDGE_FIELDS, _GRAPH_FIELDS, _LONG_FIELDS, + _CARTESIAN_TENSOR_FIELDS, ) from ._dataset import ( AtomicDataset, @@ -39,5 +40,6 @@ _EDGE_FIELDS, _GRAPH_FIELDS, _LONG_FIELDS, + _CARTESIAN_TENSOR_FIELDS, EMTTestDataset, ] diff --git a/nequip/data/_build.py b/nequip/data/_build.py index 35b59dba..80c46a94 100644 --- a/nequip/data/_build.py +++ b/nequip/data/_build.py @@ -3,7 +3,7 @@ from nequip import data from nequip.data.transforms import TypeMapper -from nequip.data import AtomicDataset, register_fields +from nequip.data import AtomicDataset from nequip.utils import instantiate, get_w_prefix @@ -71,10 +71,6 @@ def dataset_from_config(config, prefix: str = "dataset") -> AtomicDataset: # Build a TypeMapper from the config type_mapper, _ = instantiate(TypeMapper, prefix=prefix, optional_args=config) - # Register fields: - # This might reregister fields, but that's OK: - instantiate(register_fields, all_args=config) - instance, _ = instantiate( class_name, prefix=prefix, diff --git a/nequip/data/_dataset/_base_datasets.py b/nequip/data/_dataset/_base_datasets.py index bda86734..933a87a6 100644 --- a/nequip/data/_dataset/_base_datasets.py +++ b/nequip/data/_dataset/_base_datasets.py @@ -416,7 +416,7 @@ def statistics( if field not in selectors: # this means field is not selected and so not available raise RuntimeError( - f"Only per-node and per-graph fields can have statistics computed; `{field}` has not been registered as either. If it is per-node or per-graph, please register it as such using `nequip.data.register_fields`" + f"Only per-node and per-graph fields can have statistics computed; `{field}` has not been registered as either. 
If it is per-node or per-graph, please register it as such" ) arr = data_transformed[field] if field in _NODE_FIELDS: diff --git a/nequip/scripts/deploy.py b/nequip/scripts/deploy.py index a0772df9..19bc791a 100644 --- a/nequip/scripts/deploy.py +++ b/nequip/scripts/deploy.py @@ -71,6 +71,8 @@ def _set_deploy_metadata(key: str, value) -> None: global _current_metadata if _current_metadata is None: pass # not deploying right now + elif key not in _ALL_METADATA_KEYS: + raise KeyError(f"{key} is not a registered model deployment metadata key") elif key in _current_metadata: raise RuntimeError(f"{key} already set in the deployment metadata") else: diff --git a/nequip/train/_loss.py b/nequip/train/_loss.py index 6442c0d4..144d348d 100644 --- a/nequip/train/_loss.py +++ b/nequip/train/_loss.py @@ -83,7 +83,8 @@ def __call__( # zero the nan entries has_nan = self.ignore_nan and torch.isnan(ref.sum()) N = torch.bincount(ref_dict[AtomicDataDict.BATCH_KEY]) - N = N.reshape((-1, 1)) + # as many dimensions of size 1 as there are non-batch dimensions in the data + N = N.reshape((-1,) + (1,) * (pred.ndim - 1)) if has_nan: not_nan = (ref == ref).int() loss = self.func(pred, torch.nan_to_num(ref, nan=0.0)) * not_nan / N diff --git a/nequip/utils/_global_options.py b/nequip/utils/_global_options.py index bc5bc2d9..3a08e55e 100644 --- a/nequip/utils/_global_options.py +++ b/nequip/utils/_global_options.py @@ -7,9 +7,7 @@ import e3nn import e3nn.util.jit -from nequip.data import register_fields from .misc import dtype_from_name -from .auto_init import instantiate from .test import set_irreps_debug from .config import Config @@ -53,12 +51,6 @@ def _set_global_options(config, warn_on_override: bool = False) -> None: # Temporary warning due to unresolved upstream issue torch_version = version.parse(torch.__version__) - if torch_version < version.parse("1.11"): - warnings.warn("We currently recommend the use of PyTorch 1.11") - elif torch_version > version.parse("1.11"): - warnings.warn( - "!! Upstream issues in PyTorch versions >1.11 have been seen to cause unusual performance degredations on some CUDA systems that become worse over time; see https://github.com/mir-group/nequip/discussions/311. At present we *strongly* recommend the use of PyTorch 1.11 if using CUDA devices; while using other versions if you observe this problem, an unexpected lack of this problem, or other strange behavior, please post in the linked GitHub issue." - ) if torch_version >= version.parse("1.11"): # PyTorch >= 1.11 @@ -122,6 +114,4 @@ def _set_global_options(config, warn_on_override: bool = False) -> None: e3nn.set_optimization_defaults(**config.get("e3nn_optimization_defaults", {})) - # Register fields: - instantiate(register_fields, all_args=config) return diff --git a/nequip/utils/test.py b/nequip/utils/test.py index 7c0bde3f..de843418 100644 --- a/nequip/utils/test.py +++ b/nequip/utils/test.py @@ -10,6 +10,7 @@ AtomicDataDict, _NODE_FIELDS, _EDGE_FIELDS, + _CARTESIAN_TENSOR_FIELDS, ) @@ -209,17 +210,26 @@ def assert_AtomicData_equivariant( # must be this to actually rotate it when flattened irps[AtomicDataDict.CELL_KEY] = "3x1o" - stress_keys = (AtomicDataDict.STRESS_KEY, AtomicDataDict.VIRIAL_KEY) - for k in stress_keys: + cartesian_keys = _CARTESIAN_TENSOR_FIELDS.keys() + for k in ( + AtomicDataDict.STRESS_KEY, + AtomicDataDict.VIRIAL_KEY, + ): # TODO should this be cartesian_keys? 
irreps_in.pop(k, None) - if any(k in irreps_out for k in stress_keys): + if any(k in irreps_out for k in cartesian_keys): from e3nn.io import CartesianTensor - stress_cart_tensor = CartesianTensor("ij=ji") # stress is symmetric - stress_rtp = stress_cart_tensor.reduced_tensor_products().to(device, dtype) - # symmetric 3x3 cartesian tensor as irreps - for k in stress_keys: - irreps_out[k] = stress_cart_tensor + cartesian_tensor = { + k: CartesianTensor(_CARTESIAN_TENSOR_FIELDS[k]) + for k in cartesian_keys + if k in irreps_out + } + cartesian_rtp = { + k: ct.reduced_tensor_products().to(device, dtype) + for k, ct in cartesian_tensor.items() + } + for k, ct in cartesian_tensor.items(): + irreps_out[k] = ct def wrapper(*args): arg_dict = {k: v for k, v in zip(irreps_in, args)} @@ -238,12 +248,12 @@ def wrapper(*args): val = output[key] assert val.shape[-2:] == (3, 3) output[key] = val.reshape(val.shape[:-2] + (9,)) - # stress is also a special case, + # cartesian tensors like stress are also a special case, # we need it to be decomposed into irreps for equivar testing - for k in stress_keys: + for k in cartesian_keys: if k in output: - output[k] = stress_cart_tensor.from_cartesian( - output[k], rtp=stress_rtp.to(output[k].dtype) + output[k] = cartesian_tensor[k].from_cartesian( + output[k], rtp=cartesian_rtp[k].to(output[k].dtype) ) return [output[k] for k in irreps_out] diff --git a/setup.py b/setup.py index 6ca9e3cf..af851a0e 100644 --- a/setup.py +++ b/setup.py @@ -34,6 +34,7 @@ "contextlib2;python_version<'3.7'", # backport of nullcontext 'contextvars;python_version<"3.7"', # backport of contextvars for savenload "typing_extensions;python_version<'3.8'", # backport of Final + "importlib_metadata;python_version<'3.10'", # backport of importlib "torch-runstats>=0.2.0", "torch-ema>=0.3.0", ], From e625e123381cb5c057417ac0a88028049a78b125 Mon Sep 17 00:00:00 2001 From: Chuin Wei Tan <87742566+cw-tan@users.noreply.github.com> Date: Fri, 21 Jun 2024 17:43:49 -0400 Subject: [PATCH 04/27] Vesin neighborlist support --------- Co-authored-by: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Co-authored-by: cw-tan --- CHANGELOG.md | 6 ++---- nequip/data/AtomicData.py | 39 ++++++++++++++++++++++++++++++++------- 2 files changed, 34 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index aefad560..20eea468 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,12 +10,10 @@ Most recent change on the bottom. ### Added - add support for equivariance testing of arbitrary Cartesian tensor outputs - [Breaking] use entry points for `nequip.extension`s (e.g. 
for field registration) - +- alternate neighborlist support enabled with `NEQUIP_NL` environment variable, which can be set to `ase` (default), `matscipy` or `vesin` ### Changed - - -## Unreleased +- [Breaking] `NEQUIP_MATSCIPY_NL` environment variable no longer supported ## [0.6.0] - 2024-5-10 diff --git a/nequip/data/AtomicData.py b/nequip/data/AtomicData.py index 2267930c..805d0cf5 100644 --- a/nequip/data/AtomicData.py +++ b/nequip/data/AtomicData.py @@ -10,7 +10,6 @@ import os import numpy as np -import ase.neighborlist import ase from ase.calculators.singlepoint import SinglePointCalculator, SinglePointDFTCalculator from ase.calculators.calculator import all_properties as ase_all_properties @@ -764,12 +763,21 @@ def without_nodes(self, which_nodes): assert _ERROR_ON_NO_EDGES in ("true", "false") _ERROR_ON_NO_EDGES = _ERROR_ON_NO_EDGES == "true" -_NEQUIP_MATSCIPY_NL: Final[bool] = os.environ.get("NEQUIP_MATSCIPY_NL", "false").lower() -assert _NEQUIP_MATSCIPY_NL in ("true", "false") -_NEQUIP_MATSCIPY_NL = _NEQUIP_MATSCIPY_NL == "true" +# use "ase" as default +# TODO: eventually, choose fastest as default +# NOTE: +# - vesin and matscipy do not support self-interaction +# - vesin does not allow for mixed pbcs +_NEQUIP_NL: Final[str] = os.environ.get("NEQUIP_NL", "ase").lower() -if _NEQUIP_MATSCIPY_NL: +if _NEQUIP_NL == "vesin": + from vesin import NeighborList as vesin_nl +elif _NEQUIP_NL == "matscipy": import matscipy.neighbours +elif _NEQUIP_NL == "ase": + import ase.neighborlist +else: + raise NotImplementedError(f"Unknown neighborlist NEQUIP_NL = {_NEQUIP_NL}") def neighbor_list_and_relative_vec( @@ -849,7 +857,24 @@ def neighbor_list_and_relative_vec( # ASE dependent part temp_cell = ase.geometry.complete_cell(temp_cell) - if _NEQUIP_MATSCIPY_NL: + if _NEQUIP_NL == "vesin": + assert strict_self_interaction and not self_interaction + # use same mixed pbc logic as + # https://github.com/Luthaf/vesin/blob/main/python/vesin/src/vesin/_ase.py + if pbc[0] and pbc[1] and pbc[2]: + periodic = True + elif not pbc[0] and not pbc[1] and not pbc[2]: + periodic = False + else: + raise ValueError( + "different periodic boundary conditions on different axes are not supported by vesin neighborlist, use ASE or matscipy" + ) + + first_idex, second_idex, shifts = vesin_nl( + cutoff=float(r_max), full_list=True + ).compute(points=temp_pos, box=temp_cell, periodic=periodic, quantities="ijS") + + elif _NEQUIP_NL == "matscipy": assert strict_self_interaction and not self_interaction first_idex, second_idex, shifts = matscipy.neighbours.neighbour_list( "ijS", @@ -858,7 +883,7 @@ def neighbor_list_and_relative_vec( positions=temp_pos, cutoff=float(r_max), ) - else: + elif _NEQUIP_NL == "ase": first_idex, second_idex, shifts = ase.neighborlist.primitive_neighbor_list( "ijS", pbc, From 0a03b0bb79f4ff0d0905b32aff2ad44fae8f9186 Mon Sep 17 00:00:00 2001 From: Sean Kavanagh Date: Fri, 21 Jun 2024 22:22:08 -0400 Subject: [PATCH 05/27] Only attempt restart when `trainer.pth` present, not if folder exists (e.g. if training crashed during first epoch, during data loading (due to memory...), hitting walltime etc before model saved) --- nequip/scripts/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nequip/scripts/train.py b/nequip/scripts/train.py index f6046eb7..4caf5636 100644 --- a/nequip/scripts/train.py +++ b/nequip/scripts/train.py @@ -7,7 +7,7 @@ # Since numpy gets imported later anyway for dataset stuff, this shouldn't affect performance. 
import numpy as np # noqa: F401 -from os.path import isdir +from os.path import exists from pathlib import Path import torch @@ -74,7 +74,7 @@ def main(args=None, running_as_script: bool = True): if running_as_script: set_up_script_logger(config.get("log", None), config.verbose) - found_restart_file = isdir(f"{config.root}/{config.run_name}") + found_restart_file = exists(f"{config.root}/{config.run_name}/trainer.pth") if found_restart_file and not config.append: raise RuntimeError( f"Training instance exists at {config.root}/{config.run_name}; " From 4c5cd25d9f68b43ea217a6990c496b2981df3725 Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Sat, 22 Jun 2024 12:33:44 -0400 Subject: [PATCH 06/27] Likely resolve #431 --- nequip/scripts/deploy.py | 34 +++++++++++++++++++++++++--------- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/nequip/scripts/deploy.py b/nequip/scripts/deploy.py index 19bc791a..24d36a52 100644 --- a/nequip/scripts/deploy.py +++ b/nequip/scripts/deploy.py @@ -110,10 +110,23 @@ def load_deployed_model( f"{model_path} does not seem to be a deployed NequIP model file. Did you forget to deploy it using `nequip-deploy`? \n\n(Underlying error: {e})" ) # Confirm nequip made it - if metadata[NEQUIP_VERSION_KEY] == "": - raise ValueError( - f"{model_path} does not seem to be a deployed NequIP model file" - ) + if len(metadata[NEQUIP_VERSION_KEY]) == 0: + if len(metadata[JIT_BAILOUT_KEY]) != 0: + # In versions <0.6.0, there may have been a bug leading to empty "*_version" + # metadata keys. We can be pretty confident this is a NequIP model from + # those versions, though, if it stored "_jit_bailout_depth" + # https://github.com/mir-group/nequip/commit/2f43aa84542df733bbe38cb9d6cca176b0e98054 + # Likely addresses https://github.com/mir-group/nequip/issues/431 + warnings.warn( + f"{model_path} appears to be from a older (0.5.* or earlier) version of `nequip` " + "that pre-dates a variety of breaking changes. Please carefully check the " + "correctness of your results for unexpected behaviour, and consider re-deploying " + "your model using this current `nequip` installation." + ) + else: + raise ValueError( + f"{model_path} does not seem to be a deployed NequIP model file" + ) # Confirm its TorchScript assert isinstance(model, torch.jit.ScriptModule) # Make sure we're in eval mode @@ -129,11 +142,14 @@ def load_deployed_model( if metadata[DEFAULT_DTYPE_KEY] == "": # Default and model go together assert metadata[MODEL_DTYPE_KEY] == "" - # If there isn't a dtype, it should be older than 0.6.0: - assert packaging.version.parse( - metadata[NEQUIP_VERSION_KEY] - ) < packaging.version.parse("0.6.0") - # i.e. 
no value due to L85 above + # If there isn't a dtype, it should be older than 0.6.0---but + # this may not be reflected in the version fields (see above check) + # So we only check if it is available: + if len(metadata[NEQUIP_VERSION_KEY]) > 0: + assert packaging.version.parse( + metadata[NEQUIP_VERSION_KEY] + ) < packaging.version.parse("0.6.0") + # The old pre-0.6.0 defaults: metadata[DEFAULT_DTYPE_KEY] = "float32" metadata[MODEL_DTYPE_KEY] = "float32" From c32efeebd112b79c960a1ca3fb93c94324beb998 Mon Sep 17 00:00:00 2001 From: Sean Kavanagh Date: Mon, 24 Jun 2024 16:32:01 -0400 Subject: [PATCH 07/27] Update `pre-commit` config, `flake8` no longer on gitlab --- .pre-commit-config.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3a5f9bb9..4d4b9abb 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,11 +4,11 @@ fail_fast: true repos: - repo: https://github.com/psf/black - rev: 22.3.0 + rev: 24.4.2 hooks: - id: black - - repo: https://gitlab.com/pycqa/flake8 - rev: 4.0.1 + - repo: https://github.com/pycqa/flake8 + rev: 7.1.0 hooks: - id: flake8 From b27ec10ef564f95d9fc1a0277baf9c5ad40b06d4 Mon Sep 17 00:00:00 2001 From: Sean Kavanagh Date: Mon, 24 Jun 2024 16:32:28 -0400 Subject: [PATCH 08/27] Allow `n_train` and `n_val` to be set as percentages, and add tests --- nequip/train/trainer.py | 80 +++++++++++++--- tests/unit/trainer/test_trainer.py | 143 +++++++++++++++++++++++++++++ 2 files changed, 208 insertions(+), 15 deletions(-) diff --git a/nequip/train/trainer.py b/nequip/train/trainer.py index bdfb4f17..21c7199d 100644 --- a/nequip/train/trainer.py +++ b/nequip/train/trainer.py @@ -7,6 +7,7 @@ make an interface with ray """ + import sys import inspect import logging @@ -157,9 +158,9 @@ class Trainer: batch_size (int): size of each batch validation_batch_size (int): batch size for evaluating the model for validation shuffle (bool): parameters for dataloader - n_train (int): # of frames for training + n_train (int, str): # of frames for training (as int, or as a percentage string) n_train_per_epoch (optional int): how many frames from `n_train` to use each epoch; see `PartialSampler`. When `None`, all `n_train` frames will be used each epoch. - n_val (int): # of frames for validation + n_val (int), str: # of frames for validation (as int, or as a percentage string) exclude_keys (list): fields from dataset to ignore. 
dataloader_num_workers (int): `num_workers` for the `DataLoader`s train_idcs (optional, list): list of frames to use for training @@ -250,9 +251,9 @@ def __init__( batch_size: int = 5, validation_batch_size: int = 5, shuffle: bool = True, - n_train: Optional[int] = None, + n_train: Optional[Union[int, str]] = None, n_train_per_epoch: Optional[int] = None, - n_val: Optional[int] = None, + n_val: Optional[Union[int, str]] = None, dataloader_num_workers: int = 0, train_idcs: Optional[list] = None, val_idcs: Optional[list] = None, @@ -754,7 +755,6 @@ def init_metrics(self): ) def train(self): - """Training""" if getattr(self, "dl_train", None) is None: raise RuntimeError("You must call `set_dataset()` before calling `train()`") @@ -1144,12 +1144,55 @@ def __del__(self): for i in range(len(logger.handlers)): logger.handlers.pop() + def _parse_n_train_n_val( + self, train_dataset_size: int, val_dataset_size: int + ) -> tuple[int]: + # parse n_train and n_val (can be ints or str with percentage): + n_train_n_val = [] + for n_name in ["n_train", "n_val"]: + n = getattr(self, n_name) + if isinstance(n, str) and "%" in n: + dataset_size = ( + train_dataset_size if n_name == "n_train" else val_dataset_size + ) + n_train_n_val.append( + (float(n.strip("%")) / 100) * dataset_size + ) # convert to float first + elif isinstance(n, int): + n_train_n_val.append(n) + else: + raise ValueError( + f"Invalid value/type for {n_name}: {n} -- must be either int or str with %!" + ) + + floored_n_train_n_val = [int(n) for n in n_train_n_val] + # if n_train and n_val were both set as percentages which summed to 100%, make sure that sum of + # floored values comes to 100% of dataset size (i.e. that flooring doesn't omit a frame) + if ( + train_dataset_size == val_dataset_size + and isinstance(self.n_train, str) + and isinstance(self.n_val, str) + and np.isclose( + float(self.n_train.strip("%")) + float(self.n_val.strip("%")), 100 + ) + ): + if ( + sum(floored_n_train_n_val) != train_dataset_size + ): # one frame was cut, add to larger of the + # two float values (i.e. round up the percentage which gave a >= x.5 float value) + floored_n_train_n_val[ + np.argmax(n_train_n_val) + ] += train_dataset_size - sum(floored_n_train_n_val) + + return tuple(floored_n_train_n_val) + def set_dataset( self, dataset: AtomicDataset, validation_dataset: Optional[AtomicDataset] = None, ) -> None: - """Set the dataset(s) used by this trainer. + """ + Set the dataset(s) used by this trainer. Training and validation datasets will be sampled from them in accordance with the trainer's parameters. @@ -1163,7 +1206,10 @@ def set_dataset( if validation_dataset is None: # Sample both from `dataset`: total_n = len(dataset) - if (self.n_train + self.n_val) > total_n: + n_train, n_val = self._parse_n_train_n_val( + train_dataset_size=total_n, val_dataset_size=total_n + ) + if (n_train + n_val) > total_n: raise ValueError( "too little data for training and validation. 
please reduce n_train and n_val" ) @@ -1177,25 +1223,29 @@ def set_dataset( f"splitting mode {self.train_val_split} not implemented" ) - self.train_idcs = idcs[: self.n_train] - self.val_idcs = idcs[self.n_train : self.n_train + self.n_val] + self.train_idcs = idcs[:n_train] + self.val_idcs = idcs[n_train : n_train + n_val] else: - if self.n_train > len(dataset): + n_train, n_val = self._parse_n_train_n_val( + train_dataset_size=len(dataset), + val_dataset_size=len(validation_dataset), + ) + if n_train > len(dataset): raise ValueError("Not enough data in dataset for requested n_train") - if self.n_val > len(validation_dataset): + if n_val > len(validation_dataset): raise ValueError( "Not enough data in validation dataset for requested n_val" ) if self.train_val_split == "random": self.train_idcs = torch.randperm( len(dataset), generator=self.dataset_rng - )[: self.n_train] + )[:n_train] self.val_idcs = torch.randperm( len(validation_dataset), generator=self.dataset_rng - )[: self.n_val] + )[:n_val] elif self.train_val_split == "sequential": - self.train_idcs = torch.arange(self.n_train) - self.val_idcs = torch.arange(self.n_val) + self.train_idcs = torch.arange(n_train) + self.val_idcs = torch.arange(n_val) else: raise NotImplementedError( f"splitting mode {self.train_val_split} not implemented" diff --git a/tests/unit/trainer/test_trainer.py b/tests/unit/trainer/test_trainer.py index 197f3897..2383b829 100644 --- a/tests/unit/trainer/test_trainer.py +++ b/tests/unit/trainer/test_trainer.py @@ -1,6 +1,7 @@ """ Trainer tests """ + import pytest import numpy as np @@ -43,6 +44,10 @@ def dummy_builder(): early_stopping_lower_bounds={"LR": 1e-10}, model_builders=[dummy_builder], ) +N_TRAIN_PERCENT = "75%" +N_VAL_PERCENT = "15%" +N_TRAIN_PERCENT_100 = "70%" +N_VAL_PERCENT_100 = "30%" @pytest.fixture(scope="function") @@ -59,6 +64,45 @@ def trainer(float_tolerance): yield c +@pytest.fixture(scope="function") +def trainer_w_percent_n_train_n_val(float_tolerance): + """ + Generate a class instance with minimal configurations, + where n_train and n_val are given as percentage of the + dataset size. + """ + conf = minimal_config.copy() + conf["n_train"] = N_TRAIN_PERCENT + conf["n_val"] = N_VAL_PERCENT # note that summed percentages don't have to be 100% + conf["default_dtype"] = str(torch.get_default_dtype())[len("torch.") :] + model = model_from_config(conf) + with tempfile.TemporaryDirectory(prefix="output") as path: + conf["root"] = path + c = Trainer(model=model, **conf) + yield c + + +@pytest.fixture(scope="function") +def trainer_w_percent_n_train_n_val_flooring(float_tolerance): + """ + Generate a class instance with minimal configurations, + where n_train and n_val are given as percentage of the + dataset size, summing to 100% but with a split that gives + non-integer numbers of frames for n_train and n_val. + (i.e. 
n_train = 70% = 5.6 frames, n_val = 30% = 2.4 frames, + so final n_train is 6 and n_val is 2) + """ + conf = minimal_config.copy() + conf["n_train"] = N_TRAIN_PERCENT_100 + conf["n_val"] = N_VAL_PERCENT_100 + conf["default_dtype"] = str(torch.get_default_dtype())[len("torch.") :] + model = model_from_config(conf) + with tempfile.TemporaryDirectory(prefix="output") as path: + conf["root"] = path + c = Trainer(model=model, **conf) + yield c + + class TestTrainerSetUp: """ test initialization @@ -158,6 +202,105 @@ def test_split(self, trainer, nequip_dataset, mode): else: assert n_samples == trainer.n_train + @pytest.mark.parametrize("mode", ["random", "sequential"]) + def test_split_w_percent_n_train_n_val( + self, trainer_w_percent_n_train_n_val, nequip_dataset, mode + ): + # nequip_dataset has 8 frames, so setting n_train to 75% and n_val to 15% should give 6 and 1 + # frames respectively + trainer_w_percent_n_train_n_val.train_val_split = mode + trainer_w_percent_n_train_n_val.set_dataset(nequip_dataset) + for epoch_i in range(3): + trainer_w_percent_n_train_n_val.dl_train_sampler.step_epoch(epoch_i) + n_samples: int = 0 + n_val_samples: int = 0 + for i, batch in enumerate(trainer_w_percent_n_train_n_val.dl_train): + n_samples += batch[AtomicDataDict.BATCH_PTR_KEY].shape[0] - 1 + if trainer_w_percent_n_train_n_val.n_train_per_epoch is not None: + assert n_samples == trainer_w_percent_n_train_n_val.n_train_per_epoch + else: + assert ( + n_samples != trainer_w_percent_n_train_n_val.n_train + ) # n_train now a percentage + assert trainer_w_percent_n_train_n_val.n_train == N_TRAIN_PERCENT # 75% + assert n_samples == int( + (float(N_TRAIN_PERCENT.strip("%")) / 100) * len(nequip_dataset) + ) # 6 + assert trainer_w_percent_n_train_n_val.n_val == N_VAL_PERCENT # 15% + + for i, batch in enumerate(trainer_w_percent_n_train_n_val.dl_val): + n_val_samples += batch[AtomicDataDict.BATCH_PTR_KEY].shape[0] - 1 + + assert ( + n_val_samples != trainer_w_percent_n_train_n_val.n_val + ) # n_val now a percentage + assert trainer_w_percent_n_train_n_val.n_val == N_VAL_PERCENT # 15% + assert n_val_samples == int( + (float(N_VAL_PERCENT.strip("%")) / 100) * len(nequip_dataset) + ) # 1 (floored) + + @pytest.mark.parametrize("mode", ["random", "sequential"]) + def test_split_w_percent_n_train_n_val_flooring( + self, trainer_w_percent_n_train_n_val_flooring, nequip_dataset, mode + ): + # nequip_dataset has 8 frames, so n_train = 70% = 5.6 frames, n_val = 30% = 2.4 frames, + # so final n_train is 6 and n_val is 2 + trainer_w_percent_n_train_n_val_flooring.train_val_split = mode + trainer_w_percent_n_train_n_val_flooring.set_dataset(nequip_dataset) + for epoch_i in range(3): + trainer_w_percent_n_train_n_val_flooring.dl_train_sampler.step_epoch( + epoch_i + ) + n_samples: int = 0 + n_val_samples: int = 0 + for i, batch in enumerate( + trainer_w_percent_n_train_n_val_flooring.dl_train + ): + n_samples += batch[AtomicDataDict.BATCH_PTR_KEY].shape[0] - 1 + if trainer_w_percent_n_train_n_val_flooring.n_train_per_epoch is not None: + assert ( + n_samples + == trainer_w_percent_n_train_n_val_flooring.n_train_per_epoch + ) + else: + assert ( + n_samples != trainer_w_percent_n_train_n_val_flooring.n_train + ) # n_train now a percentage + assert ( + trainer_w_percent_n_train_n_val_flooring.n_train + == N_TRAIN_PERCENT_100 + ) # 70% + # _not_ equal to the bare floored value now: + assert n_samples != int( + (float(N_TRAIN_PERCENT_100.strip("%")) / 100) * len(nequip_dataset) + ) # 5 + assert ( + n_samples + == int( # equal to 
floored value plus 1 + (float(N_TRAIN_PERCENT_100.strip("%")) / 100) + * len(nequip_dataset) + ) + + 1 + ) # 6 + assert ( + trainer_w_percent_n_train_n_val_flooring.n_val == N_VAL_PERCENT_100 + ) # 30% + + for i, batch in enumerate(trainer_w_percent_n_train_n_val_flooring.dl_val): + n_val_samples += batch[AtomicDataDict.BATCH_PTR_KEY].shape[0] - 1 + + assert ( + n_val_samples != trainer_w_percent_n_train_n_val_flooring.n_val + ) # n_val now a percentage + assert ( + trainer_w_percent_n_train_n_val_flooring.n_val == N_VAL_PERCENT_100 + ) # 30% + assert n_val_samples == int( + (float(N_VAL_PERCENT_100.strip("%")) / 100) * len(nequip_dataset) + ) # 2 (floored) + + assert n_samples + n_val_samples == len(nequip_dataset) # 100% coverage + class TestTrain: def test_train(self, trainer, nequip_dataset): From 3021d4e0d1f87d928d460bebce04b7e42ed0b511 Mon Sep 17 00:00:00 2001 From: Sean Kavanagh Date: Mon, 24 Jun 2024 16:33:00 -0400 Subject: [PATCH 09/27] Pre-commit formatting --- nequip/data/AtomicDataDict.py | 1 + nequip/data/_dataset/_ase_dataset.py | 10 ++++++---- nequip/data/_keys.py | 1 + nequip/data/dataloader.py | 1 + nequip/nn/_convnetlayer.py | 6 +++--- nequip/nn/_grad_output.py | 3 +++ nequip/nn/_interaction_block.py | 1 + nequip/scripts/evaluate.py | 8 +++++--- nequip/scripts/train.py | 4 +++- nequip/utils/config.py | 1 + nequip/utils/savenload.py | 1 + nequip/utils/test.py | 8 +++++--- nequip/utils/unittests/model_tests.py | 10 ++++++---- tests/integration/test_evaluate.py | 8 +++++--- tests/unit/model/test_pair/test_zbl.py | 2 +- tests/unit/utils/test_config.py | 1 + tests/unit/utils/test_output.py | 1 + 17 files changed, 45 insertions(+), 22 deletions(-) diff --git a/nequip/data/AtomicDataDict.py b/nequip/data/AtomicDataDict.py index f7713e6f..9e7bf37c 100644 --- a/nequip/data/AtomicDataDict.py +++ b/nequip/data/AtomicDataDict.py @@ -5,6 +5,7 @@ Authors: Albert Musaelian """ + from typing import Dict, Any import torch diff --git a/nequip/data/_dataset/_ase_dataset.py b/nequip/data/_dataset/_ase_dataset.py index 3246d791..633b5e48 100644 --- a/nequip/data/_dataset/_ase_dataset.py +++ b/nequip/data/_dataset/_ase_dataset.py @@ -51,10 +51,12 @@ def _ase_dataset_reader( datas.append( ( global_index, - AtomicData.from_ase(atoms=atoms, **atomicdata_kwargs) - if global_index in include_frames - # in-memory dataset will ignore this later, but needed for indexing to work out - else None, + ( + AtomicData.from_ase(atoms=atoms, **atomicdata_kwargs) + if global_index in include_frames + # in-memory dataset will ignore this later, but needed for indexing to work out + else None + ), ) ) # Save to a tempfile--- diff --git a/nequip/data/_keys.py b/nequip/data/_keys.py index edd04cbe..93f926d3 100644 --- a/nequip/data/_keys.py +++ b/nequip/data/_keys.py @@ -2,6 +2,7 @@ This is a seperate module to compensate for a TorchScript bug that can only recognize constants when they are accessed as attributes of an imported module. """ + import sys from typing import List diff --git a/nequip/data/dataloader.py b/nequip/data/dataloader.py index ea9c7fc9..9b95cd66 100644 --- a/nequip/data/dataloader.py +++ b/nequip/data/dataloader.py @@ -84,6 +84,7 @@ class PartialSampler(Sampler[int]): If `None`, defaults to `len(data_source)`. generator (Generator): Generator used in sampling. 
""" + data_source: Dataset num_samples_per_epoch: int shuffle: bool diff --git a/nequip/nn/_convnetlayer.py b/nequip/nn/_convnetlayer.py index 9e5437a8..6d339cab 100644 --- a/nequip/nn/_convnetlayer.py +++ b/nequip/nn/_convnetlayer.py @@ -149,9 +149,9 @@ def __init__( # updated with whatever the convolution outputs (which is a full graph module) self.irreps_out.update(self.conv.irreps_out) # but with the features updated by the nonlinearity - self.irreps_out[ - AtomicDataDict.NODE_FEATURES_KEY - ] = self.equivariant_nonlin.irreps_out + self.irreps_out[AtomicDataDict.NODE_FEATURES_KEY] = ( + self.equivariant_nonlin.irreps_out + ) def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: # save old features for resnet diff --git a/nequip/nn/_grad_output.py b/nequip/nn/_grad_output.py index ee0ce6f9..eb04d78a 100644 --- a/nequip/nn/_grad_output.py +++ b/nequip/nn/_grad_output.py @@ -20,6 +20,7 @@ class GradientOutput(GraphModuleMixin, torch.nn.Module): out_field: the field in which to return the computed gradients. Defaults to ``f"d({of})/d({wrt})"`` for each field in ``wrt``. sign: either 1 or -1; the returned gradient is multiplied by this. """ + sign: float _negate: bool skip: bool @@ -119,6 +120,7 @@ class PartialForceOutput(GraphModuleMixin, torch.nn.Module): vectorize: the vectorize option to ``torch.autograd.functional.jacobian``, false by default since it doesn't work well. """ + vectorize: bool def __init__( @@ -183,6 +185,7 @@ class StressOutput(GraphModuleMixin, torch.nn.Module): func: the energy model to wrap do_forces: whether to compute forces as well """ + do_forces: bool def __init__( diff --git a/nequip/nn/_interaction_block.py b/nequip/nn/_interaction_block.py index f3164709..a9dcecd7 100644 --- a/nequip/nn/_interaction_block.py +++ b/nequip/nn/_interaction_block.py @@ -1,4 +1,5 @@ """ Interaction Block """ + from typing import Optional, Dict, Callable import torch diff --git a/nequip/scripts/evaluate.py b/nequip/scripts/evaluate.py index 20382eef..b40c3a8a 100644 --- a/nequip/scripts/evaluate.py +++ b/nequip/scripts/evaluate.py @@ -383,9 +383,11 @@ def main(args=None, running_as_script: bool = True): if do_metrics: display_bar = context_stack.enter_context( tqdm( - bar_format="" - if prog.disable # prog.ncols doesn't exist if disabled - else ("{desc:." + str(prog.ncols) + "}"), + bar_format=( + "" + if prog.disable # prog.ncols doesn't exist if disabled + else ("{desc:." + str(prog.ncols) + "}") + ), disable=None, ) ) diff --git a/nequip/scripts/train.py b/nequip/scripts/train.py index 4caf5636..02460aad 100644 --- a/nequip/scripts/train.py +++ b/nequip/scripts/train.py @@ -1,4 +1,5 @@ """ Train a network.""" + import logging import argparse import warnings @@ -24,7 +25,8 @@ from nequip.scripts._logger import set_up_script_logger warnings.filterwarnings( # unnecessary e3nn-related JIT warning - "ignore", message="The TorchScript type system doesn't support instance-level annotations" + "ignore", + message="The TorchScript type system doesn't support instance-level annotations", ) default_config = dict( root="./", diff --git a/nequip/utils/config.py b/nequip/utils/config.py index ca79f576..c160a135 100644 --- a/nequip/utils/config.py +++ b/nequip/utils/config.py @@ -34,6 +34,7 @@ If a parameter is updated, the updated value will be formatted back to the same type. 
""" + from typing import Set, Dict, Any, List import inspect diff --git a/nequip/utils/savenload.py b/nequip/utils/savenload.py index 53b09fcf..ffe60b19 100644 --- a/nequip/utils/savenload.py +++ b/nequip/utils/savenload.py @@ -1,6 +1,7 @@ """ utilities that involve file searching and operations (i.e. save/load) """ + from typing import Union, List, Tuple, Optional, Callable import sys import logging diff --git a/nequip/utils/test.py b/nequip/utils/test.py index 7c0bde3f..3de97194 100644 --- a/nequip/utils/test.py +++ b/nequip/utils/test.py @@ -45,9 +45,11 @@ def assert_permutation_equivariant( if tolerance is None: atol = PERMUTATION_FLOAT_TOLERANCE[ - func.model_dtype - if isinstance(func, GraphModel) - else torch.get_default_dtype() + ( + func.model_dtype + if isinstance(func, GraphModel) + else torch.get_default_dtype() + ) ] else: atol = tolerance diff --git a/nequip/utils/unittests/model_tests.py b/nequip/utils/unittests/model_tests.py index 37e9dcb6..f2755c9d 100644 --- a/nequip/utils/unittests/model_tests.py +++ b/nequip/utils/unittests/model_tests.py @@ -449,10 +449,12 @@ def test_partial_forces(self, config, atomic_batch, device, strict_locality): assert torch.allclose( output[k], output_partial[k], - atol=1e-8 - if k == AtomicDataDict.TOTAL_ENERGY_KEY - and torch.get_default_dtype() == torch.float64 - else 1e-5, + atol=( + 1e-8 + if k == AtomicDataDict.TOTAL_ENERGY_KEY + and torch.get_default_dtype() == torch.float64 + else 1e-5 + ), ) else: assert torch.equal(output[k], output_partial[k]) diff --git a/tests/integration/test_evaluate.py b/tests/integration/test_evaluate.py index 4dd9bce0..4a1388ed 100644 --- a/tests/integration/test_evaluate.py +++ b/tests/integration/test_evaluate.py @@ -171,9 +171,11 @@ def runit(params: dict): assert np.allclose( err, 0.0, - atol=1e-8 - if true_identity - else (1e-2 if metric.startswith("e") else 1e-4), + atol=( + 1e-8 + if true_identity + else (1e-2 if metric.startswith("e") else 1e-4) + ), ), f"Metric `{metric}` wasn't zero!" 
elif builder == ConstFactorModel: # TODO: check comperable to naive numpy compute diff --git a/tests/unit/model/test_pair/test_zbl.py b/tests/unit/model/test_pair/test_zbl.py index b862b624..c578cb6c 100644 --- a/tests/unit/model/test_pair/test_zbl.py +++ b/tests/unit/model/test_pair/test_zbl.py @@ -81,7 +81,7 @@ def test_lammps_repro(self, config): # $ lmp -in zbl_data.lmps # $ python -c "import numpy as np; d = np.loadtxt('zbl.dat', skiprows=1); np.save('zbl.npy', d)" refdata = np.load(Path(__file__).parent / "zbl.npy") - for (r, Zi, Zj, pe, fxi, fxj) in refdata: + for r, Zi, Zj, pe, fxi, fxj in refdata: if r >= r_max: continue atoms.positions[1, 0] = r diff --git a/tests/unit/utils/test_config.py b/tests/unit/utils/test_config.py index 35ae7b68..22025ffe 100644 --- a/tests/unit/utils/test_config.py +++ b/tests/unit/utils/test_config.py @@ -1,6 +1,7 @@ """ Config tests """ + import pytest from os import remove diff --git a/tests/unit/utils/test_output.py b/tests/unit/utils/test_output.py index cdc7b4ac..ec79dd1f 100644 --- a/tests/unit/utils/test_output.py +++ b/tests/unit/utils/test_output.py @@ -1,6 +1,7 @@ """ Config tests """ + import pytest import tempfile From 5c22f4a36a7d3366dc640754bd15dcfa05ed7051 Mon Sep 17 00:00:00 2001 From: Sean Kavanagh Date: Mon, 24 Jun 2024 18:10:20 -0400 Subject: [PATCH 10/27] Add example percentage setting to `full.yaml` --- configs/full.yaml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/configs/full.yaml b/configs/full.yaml index d13ff041..765d90d8 100644 --- a/configs/full.yaml +++ b/configs/full.yaml @@ -181,6 +181,9 @@ save_ema_checkpoint_freq: -1 # training n_train: 100 # number of training data n_val: 50 # number of validation data +# alternatively, n_train and n_val can be set as percentages of the dataset size: +# n_train: 70% # 70% of dataset +# n_val: 30% # 30% of dataset (if validation_dataset not set), or 30% of validation_dataset (if set) learning_rate: 0.005 # learning rate, we found values between 0.01 and 0.005 to work best - this is often one of the most important hyperparameters to tune batch_size: 5 # batch size, we found it important to keep this small for most applications including forces (1-5); for energy-only training, higher batch sizes work better validation_batch_size: 10 # batch size for evaluating the model during validation. This does not affect the training results, but using the highest value possible (<=n_val) without running out of memory will speed up your training. From db2c2e7d8b57759b5a31a7080d9049398fae80ea Mon Sep 17 00:00:00 2001 From: Sean Kavanagh Date: Mon, 24 Jun 2024 18:21:24 -0400 Subject: [PATCH 11/27] Update changelog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index c60ed185..dc6d5cce 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 Most recent change on the bottom. ## Unreleased +- Allow `n_train` and `n_val` to be specified as percentages of datasets. +- Only attempt training restart if `trainer.pth` file present (prevents unnecessary crashes due to file-not-found errors in some cases) ## [0.6.0] - 2024-5-10 ### Added From 950b5760547d59d358139346c9133ecb582fd2f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Se=C3=A1n=20Kavanagh?= <51478689+kavanase@users.noreply.github.com> Date: Thu, 27 Jun 2024 11:12:42 -0400 Subject: [PATCH 12/27] Update nequip/train/trainer.py Alby review point 2 Co-authored-by: Alby M. 
<1473644+Linux-cpp-lisp@users.noreply.github.com> --- nequip/train/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nequip/train/trainer.py b/nequip/train/trainer.py index 21c7199d..d421c68d 100644 --- a/nequip/train/trainer.py +++ b/nequip/train/trainer.py @@ -1156,7 +1156,7 @@ def _parse_n_train_n_val( train_dataset_size if n_name == "n_train" else val_dataset_size ) n_train_n_val.append( - (float(n.strip("%")) / 100) * dataset_size + (float(n.rstrip("%")) / 100) * dataset_size ) # convert to float first elif isinstance(n, int): n_train_n_val.append(n) From f05b395ad9ae8f51945e692073d695748fad5c80 Mon Sep 17 00:00:00 2001 From: Sean Kavanagh Date: Thu, 27 Jun 2024 11:12:53 -0400 Subject: [PATCH 13/27] Alby review --- nequip/train/trainer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/nequip/train/trainer.py b/nequip/train/trainer.py index 21c7199d..ffa177c9 100644 --- a/nequip/train/trainer.py +++ b/nequip/train/trainer.py @@ -1149,12 +1149,12 @@ def _parse_n_train_n_val( ) -> tuple[int]: # parse n_train and n_val (can be ints or str with percentage): n_train_n_val = [] - for n_name in ["n_train", "n_val"]: + for n_name, dataset_size in ( + ("n_train", train_dataset_size), + ("n_val", val_dataset_size), + ): n = getattr(self, n_name) if isinstance(n, str) and "%" in n: - dataset_size = ( - train_dataset_size if n_name == "n_train" else val_dataset_size - ) n_train_n_val.append( (float(n.strip("%")) / 100) * dataset_size ) # convert to float first From 7e1ff56b5d243d8ce87ec379295fbcba419e36fc Mon Sep 17 00:00:00 2001 From: Sean Kavanagh Date: Thu, 27 Jun 2024 11:19:32 -0400 Subject: [PATCH 14/27] Revert "Avoid unnecessary e3nn-related JIT warning" (be22479f), fixed upstream by @Linux-cpp-lisp in https://github.com/e3nn/e3nn/pull/437 --- nequip/scripts/train.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/nequip/scripts/train.py b/nequip/scripts/train.py index 02460aad..7ba1538b 100644 --- a/nequip/scripts/train.py +++ b/nequip/scripts/train.py @@ -24,10 +24,6 @@ from nequip.utils._global_options import _set_global_options from nequip.scripts._logger import set_up_script_logger -warnings.filterwarnings( # unnecessary e3nn-related JIT warning - "ignore", - message="The TorchScript type system doesn't support instance-level annotations", -) default_config = dict( root="./", tensorboard=False, From 1177cfd6297983d779b2dd780bd3d5dbec38de00 Mon Sep 17 00:00:00 2001 From: Sean Kavanagh Date: Thu, 27 Jun 2024 11:41:50 -0400 Subject: [PATCH 15/27] Update `lint.yaml` to latest versions --- .github/workflows/lint.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml index 5f7c96cd..df800d8e 100644 --- a/.github/workflows/lint.yaml +++ b/.github/workflows/lint.yaml @@ -18,7 +18,7 @@ jobs: - name: Black Check uses: psf/black@stable with: - version: "22.3.0" + version: "24.4.2" flake8: runs-on: ubuntu-latest @@ -29,7 +29,7 @@ jobs: python-version: '3.x' - name: Install flake8 run: | - pip install flake8==7.0.0 + pip install flake8==7.1.0 - name: run flake8 run: | flake8 . 
--count --show-source --statistics From 5db1bd733201dc8f64ea5309ae76a8d43e024214 Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Sat, 29 Jun 2024 18:04:21 -0400 Subject: [PATCH 16/27] fix version check --- nequip/__init__.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/nequip/__init__.py b/nequip/__init__.py index 654cd78a..537242b3 100644 --- a/nequip/__init__.py +++ b/nequip/__init__.py @@ -11,10 +11,8 @@ torch_version = packaging.version.parse(torch.__version__) # only allow 1.11*, 1.13* or higher (no 1.12.*) -assert (torch_version > packaging.version.parse("1.11.0")) and not ( - packaging.version.parse("1.12.0") - <= torch_version - < packaging.version.parse("1.13.0") +assert (torch_version == packaging.version.parse("1.11")) or ( + torch_version >= packaging.version.parse("1.13") ), f"NequIP supports PyTorch 1.11.* or 1.13.* or later, but {torch_version} found" # warn if using 1.13* or 2.0.* From 8d5179e0beed83a7d932251df757b78a6c68db4f Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Sat, 29 Jun 2024 18:05:49 -0400 Subject: [PATCH 17/27] Fix performance warning --- nequip/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nequip/__init__.py b/nequip/__init__.py index 537242b3..ce145b41 100644 --- a/nequip/__init__.py +++ b/nequip/__init__.py @@ -16,9 +16,9 @@ ), f"NequIP supports PyTorch 1.11.* or 1.13.* or later, but {torch_version} found" # warn if using 1.13* or 2.0.* -if packaging.version.parse("1.13.0") <= torch_version < packaging.version.parse("2.1"): +if packaging.version.parse("1.13.0") <= torch_version: warnings.warn( - f"!! PyTorch version {torch_version} found. Upstream issues in PyTorch versions 1.13.* and 2.0.* have been seen to cause unusual performance degredations on some CUDA systems that become worse over time; see https://github.com/mir-group/nequip/discussions/311. The best tested PyTorch version to use with CUDA devices is 1.11; while using other versions if you observe this problem, an unexpected lack of this problem, or other strange behavior, please post in the linked GitHub issue." + f"!! PyTorch version {torch_version} found. Upstream issues in PyTorch versions 1.13.* and 2.* have been seen to cause unusual performance degredations on some CUDA systems that become worse over time; see https://github.com/mir-group/nequip/discussions/311. The best tested PyTorch version to use with CUDA devices is 1.11; while using other versions if you observe this problem, an unexpected lack of this problem, or other strange behavior, please post in the linked GitHub issue." 
) From ee03f0c540f584129507dc4f74d441240a1cbf3c Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Mon, 1 Jul 2024 15:54:46 -0400 Subject: [PATCH 18/27] fewer spurious failures --- nequip/utils/unittests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nequip/utils/unittests/conftest.py b/nequip/utils/unittests/conftest.py index a2dc103d..59bee4ec 100644 --- a/nequip/utils/unittests/conftest.py +++ b/nequip/utils/unittests/conftest.py @@ -45,7 +45,7 @@ # The default float tolerance FLOAT_TOLERANCE = { t: torch.as_tensor(v, dtype=dtype_from_name(t)) - for t, v in {"float32": 1e-3, "float64": 1e-10}.items() + for t, v in {"float32": 1e-3, "float64": 1e-8}.items() } From 0eabec7c421f2330adbdc7e9d50c2d649563fc7f Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Mon, 1 Jul 2024 16:07:27 -0400 Subject: [PATCH 19/27] take FLOAT_TOLERANCE from nequip, not e3nn --- nequip/utils/test.py | 22 +++++++++++++++++----- nequip/utils/unittests/conftest.py | 9 +-------- 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/nequip/utils/test.py b/nequip/utils/test.py index de843418..4c7450b5 100644 --- a/nequip/utils/test.py +++ b/nequip/utils/test.py @@ -2,7 +2,7 @@ import torch from e3nn import o3 -from e3nn.util.test import equivariance_error, FLOAT_TOLERANCE +from e3nn.util.test import equivariance_error from nequip.nn import GraphModuleMixin, GraphModel from nequip.data import ( @@ -12,7 +12,17 @@ _EDGE_FIELDS, _CARTESIAN_TENSOR_FIELDS, ) - +from nequip.utils.misc import dtype_from_name + +# The default float tolerance +FLOAT_TOLERANCE = { + t: torch.as_tensor(v, dtype=dtype_from_name(t)) + for t, v in {"float32": 1e-3, "float64": 1e-10}.items() +} +# Allow lookup by name or dtype object: +for t, v in list(FLOAT_TOLERANCE.items()): + FLOAT_TOLERANCE[dtype_from_name(t)] = v +del t, v # This has to be somewhat large because of float32 sum reductions over many edges/atoms PERMUTATION_FLOAT_TOLERANCE = {torch.float32: 1e-4, torch.float64: 1e-10} @@ -46,9 +56,11 @@ def assert_permutation_equivariant( if tolerance is None: atol = PERMUTATION_FLOAT_TOLERANCE[ - func.model_dtype - if isinstance(func, GraphModel) - else torch.get_default_dtype() + ( + func.model_dtype + if isinstance(func, GraphModel) + else torch.get_default_dtype() + ) ] else: atol = tolerance diff --git a/nequip/utils/unittests/conftest.py b/nequip/utils/unittests/conftest.py index 59bee4ec..aa716d07 100644 --- a/nequip/utils/unittests/conftest.py +++ b/nequip/utils/unittests/conftest.py @@ -12,12 +12,11 @@ import torch -from nequip.utils.test import set_irreps_debug +from nequip.utils.test import set_irreps_debug, FLOAT_TOLERANCE from nequip.data import AtomicData, ASEDataset from nequip.data.transforms import TypeMapper from nequip.utils.torch_geometric import Batch from nequip.utils._global_options import _set_global_options -from nequip.utils.misc import dtype_from_name # Sometimes we run parallel using pytest-xdist, and want to be able to use # as many GPUs as are available @@ -42,12 +41,6 @@ # Test parallelization, but don't waste time spawning tons of workers if lots of cores available os.environ["NEQUIP_NUM_TASKS"] = "2" -# The default float tolerance -FLOAT_TOLERANCE = { - t: torch.as_tensor(v, dtype=dtype_from_name(t)) - for t, v in {"float32": 1e-3, "float64": 1e-8}.items() -} - @pytest.fixture(scope="session", autouse=True, params=["float32", "float64"]) def float_tolerance(request): From 
a4650260cc20bf3c03660b124eb43971f6f10a51 Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Mon, 1 Jul 2024 16:09:33 -0400 Subject: [PATCH 20/27] relax tolerance for F64 --- nequip/utils/unittests/model_tests.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/nequip/utils/unittests/model_tests.py b/nequip/utils/unittests/model_tests.py index 37e9dcb6..b9d6790b 100644 --- a/nequip/utils/unittests/model_tests.py +++ b/nequip/utils/unittests/model_tests.py @@ -228,7 +228,13 @@ def test_equivariance(self, model, atomic_batch, device): instance, out_fields = model instance = instance.to(device=device) atomic_batch = atomic_batch.to(device=device) - assert_AtomicData_equivariant(func=instance, data_in=atomic_batch) + assert_AtomicData_equivariant( + func=instance, + data_in=atomic_batch, + e3_tolerance={torch.float32: 1e-3, torch.float64: 1e-8}[ + torch.get_default_dtype() + ], + ) def test_embedding_cutoff(self, model, config, device): instance, out_fields = model @@ -449,10 +455,12 @@ def test_partial_forces(self, config, atomic_batch, device, strict_locality): assert torch.allclose( output[k], output_partial[k], - atol=1e-8 - if k == AtomicDataDict.TOTAL_ENERGY_KEY - and torch.get_default_dtype() == torch.float64 - else 1e-5, + atol=( + 1e-8 + if k == AtomicDataDict.TOTAL_ENERGY_KEY + and torch.get_default_dtype() == torch.float64 + else 1e-5 + ), ) else: assert torch.equal(output[k], output_partial[k]) From 94979e86af9377b4eeead6665bc230441abd2542 Mon Sep 17 00:00:00 2001 From: Sean Kavanagh Date: Tue, 2 Jul 2024 12:09:14 -0400 Subject: [PATCH 21/27] Clear previous run folder and warn if no model present --- nequip/scripts/train.py | 12 +++++++++++- nequip/train/trainer.py | 2 +- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/nequip/scripts/train.py b/nequip/scripts/train.py index 7ba1538b..e83bd299 100644 --- a/nequip/scripts/train.py +++ b/nequip/scripts/train.py @@ -8,7 +8,8 @@ # Since numpy gets imported later anyway for dataset stuff, this shouldn't affect performance. import numpy as np # noqa: F401 -from os.path import exists +from os.path import exists, isdir +from shutil import rmtree from pathlib import Path import torch @@ -78,6 +79,15 @@ def main(args=None, running_as_script: bool = True): f"Training instance exists at {config.root}/{config.run_name}; " "either set append to True or use a different root or runname" ) + elif not found_restart_file and isdir(f"{config.root}/{config.run_name}"): + # output directory exists but no ``trainer.pth`` file, suggesting previous run crash during + # first training epoch (usually due to memory): + warnings.warn( + f"Previous run folder at {config.root}/{config.run_name} exists, but a saved model " + f"(trainer.pth file) was not found. This folder will be cleared and a fresh training run will " + f"be started." + ) + rmtree(f"{config.root}/{config.run_name}") # for fresh new train if not found_restart_file: diff --git a/nequip/train/trainer.py b/nequip/train/trainer.py index 8788c84a..8fd4d803 100644 --- a/nequip/train/trainer.py +++ b/nequip/train/trainer.py @@ -108,7 +108,7 @@ class Trainer: - "trainer_save.pth": all the training information. The file used for loading and restart For restart run, the default set up is to not append to the original folders and files. 
- The Output class will automatically build a folder call root/run_name + The Output class will automatically build a folder called ``root/run_name`` If append mode is on, the log file will be appended and the best model and last model will be overwritten. More examples can be found in tests/train/test_trainer.py From 3e0009d39f7a1de89f0e84d07055224a1042b18c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Se=C3=A1n=20Kavanagh?= <51478689+kavanase@users.noreply.github.com> Date: Tue, 2 Jul 2024 12:10:36 -0400 Subject: [PATCH 22/27] Update nequip/train/trainer.py Co-authored-by: Alby M. <1473644+Linux-cpp-lisp@users.noreply.github.com> --- nequip/train/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nequip/train/trainer.py b/nequip/train/trainer.py index 8fd4d803..96a1e94b 100644 --- a/nequip/train/trainer.py +++ b/nequip/train/trainer.py @@ -1146,7 +1146,7 @@ def __del__(self): def _parse_n_train_n_val( self, train_dataset_size: int, val_dataset_size: int - ) -> tuple[int]: + ) -> Tuple[int, int]: # parse n_train and n_val (can be ints or str with percentage): n_train_n_val = [] for n_name, dataset_size in ( From 4de79c3faf0d0c8f9a88efeebe35fb397e28b7d8 Mon Sep 17 00:00:00 2001 From: Sean Kavanagh Date: Tue, 2 Jul 2024 12:14:09 -0400 Subject: [PATCH 23/27] Check n_train/n_val are at least 1 and add error if not --- nequip/train/trainer.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/nequip/train/trainer.py b/nequip/train/trainer.py index 8fd4d803..ab301e93 100644 --- a/nequip/train/trainer.py +++ b/nequip/train/trainer.py @@ -1166,6 +1166,10 @@ def _parse_n_train_n_val( ) floored_n_train_n_val = [int(n) for n in n_train_n_val] + for n, n_name in zip(floored_n_train_n_val, ["n_train", "n_val"]): + if n < 1: + raise ValueError(f"{n_name} must be at least 1! Got {n}.") + # if n_train and n_val were both set as percentages which summed to 100%, make sure that sum of # floored values comes to 100% of dataset size (i.e. that flooring doesn't omit a frame) if ( From 27dcae02340d9728c34edf0947604cea0d56df71 Mon Sep 17 00:00:00 2001 From: Sean Kavanagh Date: Tue, 2 Jul 2024 14:27:19 -0400 Subject: [PATCH 24/27] Update trainer tests; don't duplicate trainer setup, and also test some more combos via pytest parametrize --- tests/unit/trainer/test_trainer.py | 108 ++++++++++++++--------------- 1 file changed, 52 insertions(+), 56 deletions(-) diff --git a/tests/unit/trainer/test_trainer.py b/tests/unit/trainer/test_trainer.py index 2383b829..9e2bd64e 100644 --- a/tests/unit/trainer/test_trainer.py +++ b/tests/unit/trainer/test_trainer.py @@ -44,36 +44,16 @@ def dummy_builder(): early_stopping_lower_bounds={"LR": 1e-10}, model_builders=[dummy_builder], ) -N_TRAIN_PERCENT = "75%" -N_VAL_PERCENT = "15%" -N_TRAIN_PERCENT_100 = "70%" -N_VAL_PERCENT_100 = "30%" -@pytest.fixture(scope="function") -def trainer(float_tolerance): - """ - Generate a class instance with minimal configurations - """ - conf = minimal_config.copy() - conf["default_dtype"] = str(torch.get_default_dtype())[len("torch.") :] - model = model_from_config(conf) - with tempfile.TemporaryDirectory(prefix="output") as path: - conf["root"] = path - c = Trainer(model=model, **conf) - yield c - - -@pytest.fixture(scope="function") -def trainer_w_percent_n_train_n_val(float_tolerance): +def create_trainer(float_tolerance, **kwargs): """ Generate a class instance with minimal configurations, - where n_train and n_val are given as percentage of the - dataset size. 
+ with the option to modify the configurations using + kwargs. """ conf = minimal_config.copy() - conf["n_train"] = N_TRAIN_PERCENT - conf["n_val"] = N_VAL_PERCENT # note that summed percentages don't have to be 100% + conf.update(kwargs) conf["default_dtype"] = str(torch.get_default_dtype())[len("torch.") :] model = model_from_config(conf) with tempfile.TemporaryDirectory(prefix="output") as path: @@ -83,24 +63,11 @@ def trainer_w_percent_n_train_n_val(float_tolerance): @pytest.fixture(scope="function") -def trainer_w_percent_n_train_n_val_flooring(float_tolerance): +def trainer(float_tolerance): """ - Generate a class instance with minimal configurations, - where n_train and n_val are given as percentage of the - dataset size, summing to 100% but with a split that gives - non-integer numbers of frames for n_train and n_val. - (i.e. n_train = 70% = 5.6 frames, n_val = 30% = 2.4 frames, - so final n_train is 6 and n_val is 2) + Generate a class instance with minimal configurations. """ - conf = minimal_config.copy() - conf["n_train"] = N_TRAIN_PERCENT_100 - conf["n_val"] = N_VAL_PERCENT_100 - conf["default_dtype"] = str(torch.get_default_dtype())[len("torch.") :] - model = model_from_config(conf) - with tempfile.TemporaryDirectory(prefix="output") as path: - conf["root"] = path - c = Trainer(model=model, **conf) - yield c + yield from create_trainer(float_tolerance) class TestTrainerSetUp: @@ -203,11 +170,25 @@ def test_split(self, trainer, nequip_dataset, mode): assert n_samples == trainer.n_train @pytest.mark.parametrize("mode", ["random", "sequential"]) + @pytest.mark.parametrize( + "n_train_percent, n_val_percent", [("75%", "15%"), ("20%", "30%")] + ) def test_split_w_percent_n_train_n_val( - self, trainer_w_percent_n_train_n_val, nequip_dataset, mode + self, nequip_dataset, mode, float_tolerance, n_train_percent, n_val_percent ): + """ + Test case where n_train and n_val are given as percentage of the + dataset size, and here they don't sum to 100%. + """ # nequip_dataset has 8 frames, so setting n_train to 75% and n_val to 15% should give 6 and 1 - # frames respectively + # frames respectively. 
Note that summed percentages don't have to be 100% + trainer_w_percent_n_train_n_val = next( + create_trainer( + float_tolerance=float_tolerance, + n_train=n_train_percent, + n_val=n_val_percent, + ) + ) trainer_w_percent_n_train_n_val.train_val_split = mode trainer_w_percent_n_train_n_val.set_dataset(nequip_dataset) for epoch_i in range(3): @@ -222,11 +203,11 @@ def test_split_w_percent_n_train_n_val( assert ( n_samples != trainer_w_percent_n_train_n_val.n_train ) # n_train now a percentage - assert trainer_w_percent_n_train_n_val.n_train == N_TRAIN_PERCENT # 75% + assert trainer_w_percent_n_train_n_val.n_train == n_train_percent # 75% assert n_samples == int( - (float(N_TRAIN_PERCENT.strip("%")) / 100) * len(nequip_dataset) + (float(n_train_percent.strip("%")) / 100) * len(nequip_dataset) ) # 6 - assert trainer_w_percent_n_train_n_val.n_val == N_VAL_PERCENT # 15% + assert trainer_w_percent_n_train_n_val.n_val == n_val_percent # 15% for i, batch in enumerate(trainer_w_percent_n_train_n_val.dl_val): n_val_samples += batch[AtomicDataDict.BATCH_PTR_KEY].shape[0] - 1 @@ -234,17 +215,34 @@ def test_split_w_percent_n_train_n_val( assert ( n_val_samples != trainer_w_percent_n_train_n_val.n_val ) # n_val now a percentage - assert trainer_w_percent_n_train_n_val.n_val == N_VAL_PERCENT # 15% + assert trainer_w_percent_n_train_n_val.n_val == n_val_percent # 15% assert n_val_samples == int( - (float(N_VAL_PERCENT.strip("%")) / 100) * len(nequip_dataset) + (float(n_val_percent.strip("%")) / 100) * len(nequip_dataset) ) # 1 (floored) @pytest.mark.parametrize("mode", ["random", "sequential"]) + @pytest.mark.parametrize( + "n_train_percent, n_val_percent", [("70%", "30%"), ("55%", "45%")] + ) def test_split_w_percent_n_train_n_val_flooring( - self, trainer_w_percent_n_train_n_val_flooring, nequip_dataset, mode + self, nequip_dataset, mode, float_tolerance, n_train_percent, n_val_percent ): + """ + Test case where n_train and n_val are given as percentage of the + dataset size, summing to 100% but with a split that gives + non-integer numbers of frames for n_train and n_val. + (i.e. 
n_train = 70% = 5.6 frames, n_val = 30% = 2.4 frames, + so final n_train is 6 and n_val is 2) + """ # nequip_dataset has 8 frames, so n_train = 70% = 5.6 frames, n_val = 30% = 2.4 frames, # so final n_train is 6 and n_val is 2 + trainer_w_percent_n_train_n_val_flooring = next( + create_trainer( + float_tolerance=float_tolerance, + n_train=n_train_percent, + n_val=n_val_percent, + ) + ) trainer_w_percent_n_train_n_val_flooring.train_val_split = mode trainer_w_percent_n_train_n_val_flooring.set_dataset(nequip_dataset) for epoch_i in range(3): @@ -267,23 +265,21 @@ def test_split_w_percent_n_train_n_val_flooring( n_samples != trainer_w_percent_n_train_n_val_flooring.n_train ) # n_train now a percentage assert ( - trainer_w_percent_n_train_n_val_flooring.n_train - == N_TRAIN_PERCENT_100 + trainer_w_percent_n_train_n_val_flooring.n_train == n_train_percent ) # 70% # _not_ equal to the bare floored value now: assert n_samples != int( - (float(N_TRAIN_PERCENT_100.strip("%")) / 100) * len(nequip_dataset) + (float(n_train_percent.strip("%")) / 100) * len(nequip_dataset) ) # 5 assert ( n_samples == int( # equal to floored value plus 1 - (float(N_TRAIN_PERCENT_100.strip("%")) / 100) - * len(nequip_dataset) + (float(n_train_percent.strip("%")) / 100) * len(nequip_dataset) ) + 1 ) # 6 assert ( - trainer_w_percent_n_train_n_val_flooring.n_val == N_VAL_PERCENT_100 + trainer_w_percent_n_train_n_val_flooring.n_val == n_val_percent ) # 30% for i, batch in enumerate(trainer_w_percent_n_train_n_val_flooring.dl_val): @@ -293,10 +289,10 @@ def test_split_w_percent_n_train_n_val_flooring( n_val_samples != trainer_w_percent_n_train_n_val_flooring.n_val ) # n_val now a percentage assert ( - trainer_w_percent_n_train_n_val_flooring.n_val == N_VAL_PERCENT_100 + trainer_w_percent_n_train_n_val_flooring.n_val == n_val_percent ) # 30% assert n_val_samples == int( - (float(N_VAL_PERCENT_100.strip("%")) / 100) * len(nequip_dataset) + (float(n_val_percent.strip("%")) / 100) * len(nequip_dataset) ) # 2 (floored) assert n_samples + n_val_samples == len(nequip_dataset) # 100% coverage From b73bf6f32b93cdc6dfbeb77cf79231109e869438 Mon Sep 17 00:00:00 2001 From: cw-tan Date: Tue, 9 Jul 2024 10:08:48 -0400 Subject: [PATCH 25/27] prepare for 0.6.1 release --- CHANGELOG.md | 5 ++++- nequip/_version.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bfab0c68..e1c228d0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,7 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 Most recent change on the bottom. -## Unreleased - 0.6.1 +## Unreleased + + +## [0.6.1] - 2024-7-9 ### Added - add support for equivariance testing of arbitrary Cartesian tensor outputs - [Breaking] use entry points for `nequip.extension`s (e.g. 
for field registration) diff --git a/nequip/_version.py b/nequip/_version.py index 8e22989a..6c1533d0 100644 --- a/nequip/_version.py +++ b/nequip/_version.py @@ -2,4 +2,4 @@ # See Python packaging guide # https://packaging.python.org/guides/single-sourcing-package-version/ -__version__ = "0.6.0" +__version__ = "0.6.1" From feff269f8f273f041b906c63724dac86d629f762 Mon Sep 17 00:00:00 2001 From: cw-tan Date: Tue, 9 Jul 2024 11:09:53 -0400 Subject: [PATCH 26/27] fix torch-numpy dependency in github workflows --- .github/workflows/tests.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 0fb33150..b1e75392 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -31,6 +31,7 @@ jobs: run: | python -m pip install --upgrade pip pip install setuptools wheel + if [ ${TORCH} == 1.13.1 ]; then pip install numpy==1.26.4; fi # older torch versions fail with numpy 2 pip install torch==${TORCH} -f https://download.pytorch.org/whl/cpu/torch_stable.html pip install h5py scikit-learn # install packages that aren't required dependencies but that the tests do need pip install --upgrade-strategy only-if-needed . From 69385abe0e1a2cfed25851336fe17ec1c6571b71 Mon Sep 17 00:00:00 2001 From: Chuin Wei Tan <87742566+cw-tan@users.noreply.github.com> Date: Tue, 9 Jul 2024 11:34:30 -0400 Subject: [PATCH 27/27] Update github workflows Co-authored-by: Alby M. <1473644+Linux-cpp-lisp@users.noreply.github.com> --- .github/workflows/tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index b1e75392..b09cb076 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -31,7 +31,7 @@ jobs: run: | python -m pip install --upgrade pip pip install setuptools wheel - if [ ${TORCH} == 1.13.1 ]; then pip install numpy==1.26.4; fi # older torch versions fail with numpy 2 + if [ ${TORCH} = "1.13.1" ]; then pip install numpy==1.*; fi # older torch versions fail with numpy 2 pip install torch==${TORCH} -f https://download.pytorch.org/whl/cpu/torch_stable.html pip install h5py scikit-learn # install packages that aren't required dependencies but that the tests do need pip install --upgrade-strategy only-if-needed .
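
The n_train/n_val split semantics introduced in patches 22-24 above can be
summarised in a small standalone sketch. This is an illustration of the
intended behaviour only, not the code merged in those patches; the helper
name parse_n_train_n_val and its exact signature are hypothetical. Values
may be plain integers or percentage strings, each floored value must be at
least 1, and when two percentages sum to exactly 100% of the same dataset
the frame lost to flooring is handed back to n_train so the whole dataset
is covered.

from typing import Tuple, Union


def parse_n_train_n_val(
    n_train: Union[int, str],
    n_val: Union[int, str],
    train_dataset_size: int,
    val_dataset_size: int,
) -> Tuple[int, int]:
    # Resolve each requested size: "75%" -> 0.75 * dataset size, otherwise int.
    requested = []
    for n, size in ((n_train, train_dataset_size), (n_val, val_dataset_size)):
        if isinstance(n, str) and n.endswith("%"):
            requested.append((float(n.strip("%")) / 100.0) * size)
        else:
            requested.append(int(n))

    # Floor to whole frames and require at least one frame each.
    floored = [int(n) for n in requested]
    for name, n in zip(("n_train", "n_val"), floored):
        if n < 1:
            raise ValueError(f"{name} must be at least 1! Got {n}.")

    # With 8 frames and "70%"/"30%", flooring 5.6 and 2.4 would drop a frame;
    # when the percentages cover 100% of the same dataset, give it to n_train.
    both_percent = all(
        isinstance(n, str) and n.endswith("%") for n in (n_train, n_val)
    )
    if (
        both_percent
        and train_dataset_size == val_dataset_size
        and sum(floored) != train_dataset_size
        and abs(sum(requested) - train_dataset_size) < 1e-8
    ):
        floored[0] = train_dataset_size - floored[1]

    return (floored[0], floored[1])


# Matches the expectations in tests/unit/trainer/test_trainer.py above:
assert parse_n_train_n_val("70%", "30%", 8, 8) == (6, 2)  # 100% case, flooring fixed up
assert parse_n_train_n_val("75%", "15%", 8, 8) == (6, 1)  # <100% case, plain flooring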