
Commit 2cdd8c5

Merge pull request #438 from kavanase/main

Minor Updates (`n_train/val` as percent, pre-commit, restart handling)

2 parents a465026 + 27dcae0

20 files changed: +269 -40 lines

.github/workflows/lint.yaml (+2 -2)

@@ -18,7 +18,7 @@ jobs:
       - name: Black Check
         uses: psf/black@stable
         with:
-          version: "22.3.0"
+          version: "24.4.2"
 
   flake8:
     runs-on: ubuntu-latest
@@ -29,7 +29,7 @@ jobs:
           python-version: '3.x'
       - name: Install flake8
         run: |
-          pip install flake8==7.0.0
+          pip install flake8==7.1.0
       - name: run flake8
         run: |
           flake8 . --count --show-source --statistics

.pre-commit-config.yaml (+3 -3)

@@ -4,11 +4,11 @@ fail_fast: true
 
 repos:
   - repo: https://github.com/psf/black
-    rev: 22.3.0
+    rev: 24.4.2
     hooks:
       - id: black
 
-  - repo: https://gitlab.com/pycqa/flake8
-    rev: 4.0.1
+  - repo: https://github.com/pycqa/flake8
+    rev: 7.1.0
     hooks:
       - id: flake8

CHANGELOG.md (+6 -0)

@@ -6,15 +6,21 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 Most recent change on the bottom.
 
+
 ## Unreleased - 0.6.1
 ### Added
 - add support for equivariance testing of arbitrary Cartesian tensor outputs
 - [Breaking] use entry points for `nequip.extension`s (e.g. for field registration)
 - alternate neighborlist support enabled with `NEQUIP_NL` environment variable, which can be set to `ase` (default), `matscipy` or `vesin`
+- Allow `n_train` and `n_val` to be specified as percentages of datasets.
+- Only attempt training restart if `trainer.pth` file present (prevents unnecessary crashes due to file-not-found errors in some cases)
 
 ### Changed
 - [Breaking] `NEQUIP_MATSCIPY_NL` environment variable no longer supported
 
+### Fixed
+- Fixed `flake8` install location in `pre-commit-config.yaml`
+
 
 ## [0.6.0] - 2024-5-10
 ### Added

configs/full.yaml (+3 -0)

@@ -181,6 +181,9 @@ save_ema_checkpoint_freq: -1
 # training
 n_train: 100 # number of training data
 n_val: 50 # number of validation data
+# alternatively, n_train and n_val can be set as percentages of the dataset size:
+# n_train: 70% # 70% of dataset
+# n_val: 30% # 30% of dataset (if validation_dataset not set), or 30% of validation_dataset (if set)
 learning_rate: 0.005 # learning rate, we found values between 0.01 and 0.005 to work best - this is often one of the most important hyperparameters to tune
 batch_size: 5 # batch size, we found it important to keep this small for most applications including forces (1-5); for energy-only training, higher batch sizes work better
 validation_batch_size: 10 # batch size for evaluating the model during validation. This does not affect the training results, but using the highest value possible (<=n_val) without running out of memory will speed up your training.
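
As a quick illustration of how such percentage strings resolve to frame counts (the numbers here are hypothetical; the real parsing lives in `Trainer._parse_n_train_n_val`, see the `nequip/train/trainer.py` diff below):

# Illustrative arithmetic only: "70%" and "30%" applied to a 1000-frame dataset
dataset_size = 1000
n_train = int((float("70%".rstrip("%")) / 100) * dataset_size)  # 700 frames
n_val = int((float("30%".rstrip("%")) / 100) * dataset_size)  # 300 frames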

nequip/data/_dataset/_ase_dataset.py (+6 -4)

@@ -51,10 +51,12 @@ def _ase_dataset_reader(
         datas.append(
             (
                 global_index,
-                AtomicData.from_ase(atoms=atoms, **atomicdata_kwargs)
-                if global_index in include_frames
-                # in-memory dataset will ignore this later, but needed for indexing to work out
-                else None,
+                (
+                    AtomicData.from_ase(atoms=atoms, **atomicdata_kwargs)
+                    if global_index in include_frames
+                    # in-memory dataset will ignore this later, but needed for indexing to work out
+                    else None
+                ),
             )
         )
     # Save to a tempfile---

nequip/data/_keys.py (+1 -0)

@@ -2,6 +2,7 @@
 
 This is a seperate module to compensate for a TorchScript bug that can only recognize constants when they are accessed as attributes of an imported module.
 """
+
 import sys
 from typing import List

nequip/data/dataloader.py (+1 -0)

@@ -84,6 +84,7 @@ class PartialSampler(Sampler[int]):
         If `None`, defaults to `len(data_source)`.
     generator (Generator): Generator used in sampling.
     """
+
     data_source: Dataset
     num_samples_per_epoch: int
     shuffle: bool

nequip/nn/_convnetlayer.py (+3 -3)

@@ -149,9 +149,9 @@ def __init__(
         # updated with whatever the convolution outputs (which is a full graph module)
         self.irreps_out.update(self.conv.irreps_out)
         # but with the features updated by the nonlinearity
-        self.irreps_out[
-            AtomicDataDict.NODE_FEATURES_KEY
-        ] = self.equivariant_nonlin.irreps_out
+        self.irreps_out[AtomicDataDict.NODE_FEATURES_KEY] = (
+            self.equivariant_nonlin.irreps_out
+        )
 
     def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
         # save old features for resnet

nequip/nn/_grad_output.py (+3 -0)

@@ -20,6 +20,7 @@ class GradientOutput(GraphModuleMixin, torch.nn.Module):
     out_field: the field in which to return the computed gradients. Defaults to ``f"d({of})/d({wrt})"`` for each field in ``wrt``.
     sign: either 1 or -1; the returned gradient is multiplied by this.
     """
+
     sign: float
     _negate: bool
     skip: bool
@@ -119,6 +120,7 @@ class PartialForceOutput(GraphModuleMixin, torch.nn.Module):
     vectorize: the vectorize option to ``torch.autograd.functional.jacobian``,
         false by default since it doesn't work well.
     """
+
     vectorize: bool
 
     def __init__(
@@ -183,6 +185,7 @@ class StressOutput(GraphModuleMixin, torch.nn.Module):
     func: the energy model to wrap
     do_forces: whether to compute forces as well
     """
+
     do_forces: bool
 
     def __init__(

nequip/nn/_interaction_block.py (+1 -0)

@@ -1,4 +1,5 @@
 """ Interaction Block """
+
 from typing import Optional, Dict, Callable
 
 import torch

nequip/scripts/evaluate.py (+5 -3)

@@ -383,9 +383,11 @@ def main(args=None, running_as_script: bool = True):
     if do_metrics:
         display_bar = context_stack.enter_context(
             tqdm(
-                bar_format=""
-                if prog.disable  # prog.ncols doesn't exist if disabled
-                else ("{desc:." + str(prog.ncols) + "}"),
+                bar_format=(
+                    ""
+                    if prog.disable  # prog.ncols doesn't exist if disabled
+                    else ("{desc:." + str(prog.ncols) + "}")
+                ),
                 disable=None,
             )
         )

nequip/scripts/train.py (+13 -2)

@@ -1,4 +1,5 @@
 """ Train a network."""
+
 import logging
 import argparse
 import warnings
@@ -7,7 +8,8 @@
 # Since numpy gets imported later anyway for dataset stuff, this shouldn't affect performance.
 import numpy as np  # noqa: F401
 
-from os.path import isdir
+from os.path import exists, isdir
+from shutil import rmtree
 from pathlib import Path
 
 import torch
@@ -71,12 +73,21 @@ def main(args=None, running_as_script: bool = True):
     if running_as_script:
         set_up_script_logger(config.get("log", None), config.verbose)
 
-    found_restart_file = isdir(f"{config.root}/{config.run_name}")
+    found_restart_file = exists(f"{config.root}/{config.run_name}/trainer.pth")
     if found_restart_file and not config.append:
         raise RuntimeError(
             f"Training instance exists at {config.root}/{config.run_name}; "
             "either set append to True or use a different root or runname"
         )
+    elif not found_restart_file and isdir(f"{config.root}/{config.run_name}"):
+        # output directory exists but no ``trainer.pth`` file, suggesting previous run crash during
+        # first training epoch (usually due to memory):
+        warnings.warn(
+            f"Previous run folder at {config.root}/{config.run_name} exists, but a saved model "
+            f"(trainer.pth file) was not found. This folder will be cleared and a fresh training run will "
+            f"be started."
+        )
+        rmtree(f"{config.root}/{config.run_name}")
 
     # for fresh new train
    if not found_restart_file:
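
For context, a minimal standalone sketch of this restart-detection flow (hedged: the wrapper function is hypothetical, but the checks mirror the diff above):

import warnings
from os.path import exists, isdir
from shutil import rmtree


def found_restartable_run(root: str, run_name: str, append: bool) -> bool:
    """Sketch: restart only if a trainer.pth checkpoint actually exists."""
    run_dir = f"{root}/{run_name}"
    found_restart_file = exists(f"{run_dir}/trainer.pth")
    if found_restart_file and not append:
        raise RuntimeError(
            f"Training instance exists at {run_dir}; "
            "either set append to True or use a different root or run name"
        )
    elif not found_restart_file and isdir(run_dir):
        # Folder exists but holds no trainer.pth, so a previous run likely
        # crashed before the first checkpoint; clear it and start fresh.
        warnings.warn(f"Clearing stale run folder at {run_dir}.")
        rmtree(run_dir)
    return found_restart_file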

nequip/train/trainer.py (+70 -16)

@@ -7,6 +7,7 @@
 make an interface with ray
 
 """
+
 import sys
 import inspect
 import logging
@@ -107,7 +108,7 @@ class Trainer:
     - "trainer_save.pth": all the training information. The file used for loading and restart
 
     For restart run, the default set up is to not append to the original folders and files.
-    The Output class will automatically build a folder call root/run_name
+    The Output class will automatically build a folder called ``root/run_name``
     If append mode is on, the log file will be appended and the best model and last model will be overwritten.
 
     More examples can be found in tests/train/test_trainer.py
@@ -157,9 +158,9 @@ class Trainer:
     batch_size (int): size of each batch
     validation_batch_size (int): batch size for evaluating the model for validation
     shuffle (bool): parameters for dataloader
-    n_train (int): # of frames for training
+    n_train (int, str): # of frames for training (as int, or as a percentage string)
     n_train_per_epoch (optional int): how many frames from `n_train` to use each epoch; see `PartialSampler`. When `None`, all `n_train` frames will be used each epoch.
-    n_val (int): # of frames for validation
+    n_val (int, str): # of frames for validation (as int, or as a percentage string)
     exclude_keys (list): fields from dataset to ignore.
     dataloader_num_workers (int): `num_workers` for the `DataLoader`s
     train_idcs (optional, list): list of frames to use for training
@@ -250,9 +251,9 @@ def __init__(
         batch_size: int = 5,
         validation_batch_size: int = 5,
         shuffle: bool = True,
-        n_train: Optional[int] = None,
+        n_train: Optional[Union[int, str]] = None,
         n_train_per_epoch: Optional[int] = None,
-        n_val: Optional[int] = None,
+        n_val: Optional[Union[int, str]] = None,
         dataloader_num_workers: int = 0,
         train_idcs: Optional[list] = None,
         val_idcs: Optional[list] = None,
@@ -754,7 +755,6 @@ def init_metrics(self):
         )
 
     def train(self):
-
        """Training"""
         if getattr(self, "dl_train", None) is None:
             raise RuntimeError("You must call `set_dataset()` before calling `train()`")
@@ -1144,12 +1144,59 @@ def __del__(self):
         for i in range(len(logger.handlers)):
             logger.handlers.pop()
 
+    def _parse_n_train_n_val(
+        self, train_dataset_size: int, val_dataset_size: int
+    ) -> Tuple[int, int]:
+        # parse n_train and n_val (can be ints or str with percentage):
+        n_train_n_val = []
+        for n_name, dataset_size in (
+            ("n_train", train_dataset_size),
+            ("n_val", val_dataset_size),
+        ):
+            n = getattr(self, n_name)
+            if isinstance(n, str) and "%" in n:
+                n_train_n_val.append(
+                    (float(n.rstrip("%")) / 100) * dataset_size
+                )  # convert to float first
+            elif isinstance(n, int):
+                n_train_n_val.append(n)
+            else:
+                raise ValueError(
+                    f"Invalid value/type for {n_name}: {n} -- must be either int or str with %!"
+                )
+
+        floored_n_train_n_val = [int(n) for n in n_train_n_val]
+        for n, n_name in zip(floored_n_train_n_val, ["n_train", "n_val"]):
+            if n < 1:
+                raise ValueError(f"{n_name} must be at least 1! Got {n}.")
+
+        # if n_train and n_val were both set as percentages which summed to 100%, make sure that sum of
+        # floored values comes to 100% of dataset size (i.e. that flooring doesn't omit a frame)
+        if (
+            train_dataset_size == val_dataset_size
+            and isinstance(self.n_train, str)
+            and isinstance(self.n_val, str)
+            and np.isclose(
+                float(self.n_train.strip("%")) + float(self.n_val.strip("%")), 100
+            )
+        ):
+            if (
+                sum(floored_n_train_n_val) != train_dataset_size
+            ):  # one frame was cut, add to larger of the
+                # two float values (i.e. round up the percentage which gave a >= x.5 float value)
+                floored_n_train_n_val[
+                    np.argmax(n_train_n_val)
+                ] += train_dataset_size - sum(floored_n_train_n_val)
+
+        return tuple(floored_n_train_n_val)
+
     def set_dataset(
         self,
         dataset: AtomicDataset,
         validation_dataset: Optional[AtomicDataset] = None,
     ) -> None:
-        """Set the dataset(s) used by this trainer.
+        """
+        Set the dataset(s) used by this trainer.
 
         Training and validation datasets will be sampled from
         them in accordance with the trainer's parameters.
@@ -1163,7 +1210,10 @@ def set_dataset(
         if validation_dataset is None:
             # Sample both from `dataset`:
             total_n = len(dataset)
-            if (self.n_train + self.n_val) > total_n:
+            n_train, n_val = self._parse_n_train_n_val(
+                train_dataset_size=total_n, val_dataset_size=total_n
+            )
+            if (n_train + n_val) > total_n:
                 raise ValueError(
                     "too little data for training and validation. please reduce n_train and n_val"
                 )
@@ -1177,25 +1227,29 @@ def set_dataset(
                     f"splitting mode {self.train_val_split} not implemented"
                 )
 
-            self.train_idcs = idcs[: self.n_train]
-            self.val_idcs = idcs[self.n_train : self.n_train + self.n_val]
+            self.train_idcs = idcs[:n_train]
+            self.val_idcs = idcs[n_train : n_train + n_val]
         else:
-            if self.n_train > len(dataset):
+            n_train, n_val = self._parse_n_train_n_val(
+                train_dataset_size=len(dataset),
+                val_dataset_size=len(validation_dataset),
+            )
+            if n_train > len(dataset):
                 raise ValueError("Not enough data in dataset for requested n_train")
-            if self.n_val > len(validation_dataset):
+            if n_val > len(validation_dataset):
                 raise ValueError(
                     "Not enough data in validation dataset for requested n_val"
                 )
             if self.train_val_split == "random":
                 self.train_idcs = torch.randperm(
                     len(dataset), generator=self.dataset_rng
-                )[: self.n_train]
+                )[:n_train]
                 self.val_idcs = torch.randperm(
                     len(validation_dataset), generator=self.dataset_rng
-                )[: self.n_val]
+                )[:n_val]
             elif self.train_val_split == "sequential":
-                self.train_idcs = torch.arange(self.n_train)
-                self.val_idcs = torch.arange(self.n_val)
+                self.train_idcs = torch.arange(n_train)
+                self.val_idcs = torch.arange(n_val)
             else:
                 raise NotImplementedError(
                     f"splitting mode {self.train_val_split} not implemented"

nequip/utils/config.py (+1 -0)

@@ -34,6 +34,7 @@
 If a parameter is updated, the updated value will be formatted back to the same type.
 
 """
+
 from typing import Set, Dict, Any, List
 
 import inspect

nequip/utils/savenload.py (+1 -0)

@@ -1,6 +1,7 @@
 """
 utilities that involve file searching and operations (i.e. save/load)
 """
+
 from typing import Union, List, Tuple, Optional, Callable
 import sys
 import logging

tests/integration/test_evaluate.py (+5 -3)

@@ -171,9 +171,11 @@ def runit(params: dict):
         assert np.allclose(
             err,
             0.0,
-            atol=1e-8
-            if true_identity
-            else (1e-2 if metric.startswith("e") else 1e-4),
+            atol=(
+                1e-8
+                if true_identity
+                else (1e-2 if metric.startswith("e") else 1e-4)
+            ),
         ), f"Metric `{metric}` wasn't zero!"
     elif builder == ConstFactorModel:
         # TODO: check comperable to naive numpy compute

tests/unit/model/test_pair/test_zbl.py (+1 -1)

@@ -81,7 +81,7 @@ def test_lammps_repro(self, config):
     # $ lmp -in zbl_data.lmps
     # $ python -c "import numpy as np; d = np.loadtxt('zbl.dat', skiprows=1); np.save('zbl.npy', d)"
     refdata = np.load(Path(__file__).parent / "zbl.npy")
-    for (r, Zi, Zj, pe, fxi, fxj) in refdata:
+    for r, Zi, Zj, pe, fxi, fxj in refdata:
         if r >= r_max:
             continue
         atoms.positions[1, 0] = r
