
Commit ea0f8fe
Merge branch 'refs/heads/develop' into stratified_metrics
2 parents: 1059248 + d1afaf2

File tree: 16 files changed (+367, -96 lines changed)

CHANGELOG.md

Lines changed: 10 additions & 1 deletion
@@ -7,8 +7,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 Most recent change on the bottom.
 
 
-## Unreleased
+## Unreleased - 0.7.0
+### Added
+- `--override` now supported as a `nequip-train` flag (similar to its use in `nequip-deploy`)
+- add SoftAdapt (https://arxiv.org/abs/2403.18122) callback option
+
+### Changed
+- [Breaking] training restart behavior altered: file-wise consistency checks performed between original config and config passed to `nequip-train` on restart (instead of checking the config dicts)
+- [Breaking] config format for callbacks changed (see `configs/full.yaml` for an example)
 
+### Fixed
+- fixed `wandb_watch` bug
 
 ## [0.6.1] - 2024-7-9
 ### Added
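Note on the SoftAdapt entry above: SoftAdapt-style schemes reweight loss terms based on the recent rate of change of each term. The snippet below is only an illustrative sketch of that idea; the function name and update details are ours, not the NequIP implementation.

import numpy as np

def softadapt_weights(loss_history: dict, beta: float = 1.1) -> dict:
    # Illustrative SoftAdapt-style reweighting (idea from https://arxiv.org/abs/2403.18122):
    # loss terms whose values are improving more slowly receive larger weights.
    rates = {k: v[-1] - v[-2] for k, v in loss_history.items()}
    scale = max(abs(r) for r in rates.values()) or 1.0
    exps = {k: np.exp(beta * r / scale) for k, r in rates.items()}
    total = sum(exps.values())
    return {k: e / total for k, e in exps.items()}

# forces plateaued while energy kept improving, so the force weight grows
print(softadapt_weights({"forces": [1.00, 0.99], "total_energy": [1.00, 0.70]}))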

CITATION.cff

Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
+cff-version: "1.2.0"
+message: "If you use this software, please cite our article."
+authors:
+  - family-names: Batzner
+    given-names: Simon
+  - family-names: Musaelian
+    given-names: Albert
+  - family-names: Sun
+    given-names: Lixin
+  - family-names: Geiger
+    given-names: Mario
+  - family-names: Mailoa
+    given-names: Jonathan P.
+  - family-names: Kornbluth
+    given-names: Mordechai
+  - family-names: Molinari
+    given-names: Nicola
+  - family-names: Smidt
+    given-names: Tess E.
+  - family-names: Kozinsky
+    given-names: Boris
+doi: 10.1038/s41467-022-29939-5
+preferred-citation:
+  authors:
+    - family-names: Batzner
+      given-names: Simon
+    - family-names: Musaelian
+      given-names: Albert
+    - family-names: Sun
+      given-names: Lixin
+    - family-names: Geiger
+      given-names: Mario
+    - family-names: Mailoa
+      given-names: Jonathan P.
+    - family-names: Kornbluth
+      given-names: Mordechai
+    - family-names: Molinari
+      given-names: Nicola
+    - family-names: Smidt
+      given-names: Tess E.
+    - family-names: Kozinsky
+      given-names: Boris
+  doi: 10.1038/s41467-022-29939-5
+  date-published: 2022-05-04
+  issn: 2041-1723
+  journal: Nature Communications
+  start: 2453
+  title: "E(3)-equivariant graph neural networks for data-efficient and accurate interatomic potentials"
+  type: article
+  url: "https://www.nature.com/articles/s41467-022-29939-5"
+  volume: 13

README.md

Lines changed: 3 additions & 1 deletion
@@ -141,7 +141,9 @@ Details on writing and using plugins can be found in the [Allegro tutorial](http
 
 ## References & citing
 
-The theory behind NequIP is described in our preprint (1). NequIP's backend builds on e3nn, a general framework for building E(3)-equivariant neural networks (2). If you use this repository in your work, please consider citing NequIP (1) and e3nn (3):
+The theory behind NequIP is described in our [article](https://www.nature.com/articles/s41467-022-29939-5) (1).
+NequIP's backend builds on [`e3nn`](https://e3nn.org), a general framework for building E(3)-equivariant
+neural networks (2). If you use this repository in your work, please consider citing `NequIP` (1) and `e3nn` (3):
 
 1. https://www.nature.com/articles/s41467-022-29939-5
 2. https://e3nn.org

configs/full.yaml

Lines changed: 11 additions & 4 deletions
@@ -256,10 +256,17 @@ loss_coeffs:
 # In the "schedule" key each entry is a two-element list of:
 #   - the 1-based epoch index at which to start the new loss coefficients
 #   - the new loss coefficients as a dict
-#
-# start_of_epoch_callbacks:
-#  - !!python/object:nequip.train.callbacks.loss_schedule.SimpleLossSchedule {"schedule": [[2, {"forces": 0.0, "total_energy": 1.0}]]}
-#
+# callbacks:
+#   start_of_epoch:
+#    - !!python/object:nequip.train.callbacks.SimpleLossSchedule {"schedule": [[2, {"forces": 0.0, "total_energy": 1.0}]]}
+
+# You can also try using the SoftAdapt strategy for adaptively changing loss coefficients
+# (see https://arxiv.org/abs/2403.18122)
+#callbacks:
+#  end_of_batch:
+#    - !!python/object:nequip.train.callbacks.SoftAdapt {"batches_per_update": 5, "beta": 1.1}
+
+
 
 # output metrics
 metrics_components:
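For context on the commented YAML above: the `!!python/object:` tags mean each list entry is deserialized into a Python callback object, grouped under the hook name (`start_of_epoch`, `end_of_batch`, ...) at which the trainer should invoke it. The sketch below illustrates only that dispatch pattern; the `TrainerSketch` class and the assumption that a callback is a callable receiving the trainer are ours, not the NequIP API.

from typing import Callable, Dict, List

class TrainerSketch:
    # Hypothetical skeleton showing how hook-keyed callbacks could be dispatched.
    def __init__(self, callbacks: Dict[str, List[Callable]]):
        # e.g. {"start_of_epoch": [SimpleLossSchedule(...)], "end_of_batch": [SoftAdapt(...)]}
        self.callbacks = callbacks

    def _run_callbacks(self, hook: str) -> None:
        for callback in self.callbacks.get(hook, []):
            callback(self)  # assumed signature: callback(trainer)

    def train_one_epoch(self, batches) -> None:
        self._run_callbacks("start_of_epoch")
        for batch in batches:
            ...  # forward/backward/optimizer step would go here
            self._run_callbacks("end_of_batch")
        self._run_callbacks("end_of_epoch")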

docs/cite.rst

Lines changed: 23 additions & 1 deletion
@@ -1,3 +1,25 @@
-Citing Nequip
+Citing NequIP
 =============
+If you use ``NequIP`` in your research, please cite our `article <https://doi.org/10.1038/s41467-022-29939-5>`_:
 
+.. code-block:: bibtex
+
+   @article{batzner_e3-equivariant_2022,
+     title = {E(3)-Equivariant Graph Neural Networks for Data-Efficient and Accurate Interatomic Potentials},
+     author = {Batzner, Simon and Musaelian, Albert and Sun, Lixin and Geiger, Mario and Mailoa, Jonathan P. and Kornbluth, Mordechai and Molinari, Nicola and Smidt, Tess E. and Kozinsky, Boris},
+     year = {2022},
+     month = may,
+     journal = {Nature Communications},
+     volume = {13},
+     number = {1},
+     pages = {2453},
+     issn = {2041-1723},
+     doi = {10.1038/s41467-022-29939-5},
+   }
+
+The theory behind NequIP is described in our `article <https://doi.org/10.1038/s41467-022-29939-5>`_ above.
+NequIP's backend builds on `e3nn <https://e3nn.org>`_, a general framework for building E(3)-equivariant
+neural networks (1). If you use this repository in your work, please consider citing ``NequIP`` and ``e3nn`` (2):
+
+1. https://e3nn.org
+2. https://doi.org/10.5281/zenodo.3724963

examples/plot_dimers.py

Lines changed: 15 additions & 16 deletions
@@ -39,19 +39,20 @@
 print("Computing dimers...")
 potential = {}
 N_sample = args.n_samples
-N_combs = len(list(itertools.combinations_with_replacement(range(num_types), 2)))
-r = torch.zeros(N_sample * N_combs, 2, 3, device=args.device)
-rs_one = torch.linspace(args.r_min, model_r_max, 500, device=args.device)
-rs = rs_one.repeat([N_combs])
-assert rs.shape == (N_combs * N_sample,)
+type_combos = [
+    list(e) for e in itertools.combinations_with_replacement(range(num_types), 2)
+]
+N_combos = len(type_combos)
+r = torch.zeros(N_sample * N_combos, 2, 3, device=args.device)
+rs_one = torch.linspace(args.r_min, model_r_max, N_sample, device=args.device)
+rs = rs_one.repeat([N_combos])
+assert rs.shape == (N_combos * N_sample,)
 r[:, 1, 0] += rs  # offset second atom along x axis
-types = torch.as_tensor(
-    [list(e) for e in itertools.combinations_with_replacement(range(num_types), 2)]
-)
-types = types.reshape(N_combs, 1, 2).expand(N_combs, N_sample, 2).reshape(-1)
+types = torch.as_tensor(type_combos)
+types = types.reshape(N_combos, 1, 2).expand(N_combos, N_sample, 2).reshape(-1)
 r = r.reshape(-1, 3)
 assert types.shape == r.shape[:1]
-N_at_total = N_sample * N_combs * 2
+N_at_total = N_sample * N_combos * 2
 assert len(types) == N_at_total
 edge_index = torch.vstack(
     (
@@ -61,14 +62,14 @@
     )
 )
 data = AtomicData(pos=r, atom_types=types, edge_index=edge_index)
-data.batch = torch.arange(N_sample * N_combs, device=args.device).repeat_interleave(2)
-data.ptr = torch.arange(0, 2 * N_sample * N_combs + 1, 2, device=args.device)
+data.batch = torch.arange(N_sample * N_combos, device=args.device).repeat_interleave(2)
+data.ptr = torch.arange(0, 2 * N_sample * N_combos + 1, 2, device=args.device)
 result = model(AtomicData.to_AtomicDataDict(data.to(device=args.device)))
 
 print("Plotting...")
 energies = (
     result[AtomicDataDict.TOTAL_ENERGY_KEY]
-    .reshape(N_combs, N_sample)
+    .reshape(N_combos, N_sample)
     .cpu()
     .detach()
     .numpy()
@@ -83,9 +84,7 @@
     dpi=120,
 )
 
-for i, (type1, type2) in enumerate(
-    itertools.combinations_with_replacement(range(num_types), 2)
-):
+for i, (type1, type2) in enumerate(type_combos):
     ax = axs[i]
     ax.set_ylabel(f"{type_names[type1]}-{type_names[type2]}")
     ax.plot(rs_one, energies[i])
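The change above is a refactor plus one fix: the type-pair list is built once as `type_combos`, `N_combs` is renamed to `N_combos`, and the number of sampled separations now follows `N_sample` instead of the hard-coded 500. As a self-contained illustration of the geometry bookkeeping, here is a hedged sketch with made-up sizes that reproduces the position/type construction without the model or `AtomicData`:

import itertools
import torch

num_types, N_sample, r_min, r_max = 3, 5, 1.0, 4.0

type_combos = [
    list(e) for e in itertools.combinations_with_replacement(range(num_types), 2)
]
N_combos = len(type_combos)

# one dimer per (type pair, separation): atom 0 at the origin, atom 1 offset along x
r = torch.zeros(N_sample * N_combos, 2, 3)
rs_one = torch.linspace(r_min, r_max, N_sample)
r[:, 1, 0] += rs_one.repeat([N_combos])

# per-atom type indices aligned with the flattened positions
types = torch.as_tensor(type_combos)
types = types.reshape(N_combos, 1, 2).expand(N_combos, N_sample, 2).reshape(-1)

assert r.reshape(-1, 3).shape[0] == types.shape[0] == N_sample * N_combos * 2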

nequip/__init__.py

Lines changed: 5 additions & 1 deletion
@@ -1,3 +1,4 @@
+import os
 import sys
 
 from ._version import __version__  # noqa: F401
@@ -16,7 +17,10 @@
 ), f"NequIP supports PyTorch 1.11.* or 1.13.* or later, but {torch_version} found"
 
 # warn if using 1.13* or 2.0.*
-if packaging.version.parse("1.13.0") <= torch_version:
+if (
+    packaging.version.parse("1.13.0") <= torch_version
+    and int(os.environ.get("PYTORCH_VERSION_WARNING", 1)) != 0
+):
     warnings.warn(
         f"!! PyTorch version {torch_version} found. Upstream issues in PyTorch versions 1.13.* and 2.* have been seen to cause unusual performance degredations on some CUDA systems that become worse over time; see https://github.com/mir-group/nequip/discussions/311. The best tested PyTorch version to use with CUDA devices is 1.11; while using other versions if you observe this problem, an unexpected lack of this problem, or other strange behavior, please post in the linked GitHub issue."
     )
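The behavioral change here is that the PyTorch version warning can now be silenced by setting `PYTORCH_VERSION_WARNING=0` in the environment before importing `nequip`; any other value, or leaving it unset (which defaults to 1), keeps the warning. A minimal sketch of the same env-var-gated warning pattern in isolation (the message below is a placeholder, not NequIP's):

import os
import warnings

# the warning fires unless the environment variable is explicitly set to 0
if int(os.environ.get("PYTORCH_VERSION_WARNING", 1)) != 0:
    warnings.warn("placeholder: potentially problematic PyTorch version detected")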

nequip/scripts/train.py

Lines changed: 73 additions & 27 deletions
@@ -3,6 +3,9 @@
 import logging
 import argparse
 import warnings
+import shutil
+import difflib
+import yaml
 
 # This is a weird hack to avoid Intel MKL issues on the cluster when this is called as a subprocess of a process that has itself initialized PyTorch.
 # Since numpy gets imported later anyway for dataset stuff, this shouldn't affect performance.
@@ -29,6 +32,8 @@
     root="./",
     tensorboard=False,
     wandb=False,
+    wandb_watch=False,
+    wandb_watch_kwargs={},
     model_builders=[
         "SimpleIrrepsConfig",
         "EnergyModel",
@@ -46,7 +51,7 @@
     equivariance_test=False,
     grad_anomaly_mode=False,
     gpu_oom_offload=False,
-    append=False,
+    append=True,
     warn_unused=False,
     _jit_bailout_depth=2,  # avoid 20 iters of pain, see https://github.com/pytorch/pytorch/issues/52286
     # Quote from eelison in PyTorch slack:
@@ -68,32 +73,61 @@
 
 
 def main(args=None, running_as_script: bool = True):
-    config = parse_command_line(args)
+    config, path_to_config, override_options = parse_command_line(args)
 
     if running_as_script:
         set_up_script_logger(config.get("log", None), config.verbose)
 
-    found_restart_file = exists(f"{config.root}/{config.run_name}/trainer.pth")
+    train_dir = f"{config.root}/{config.run_name}"
+    found_restart_file = exists(f"{train_dir}/trainer.pth")
     if found_restart_file and not config.append:
         raise RuntimeError(
-            f"Training instance exists at {config.root}/{config.run_name}; "
+            f"Training instance exists at {train_dir}; "
             "either set append to True or use a different root or runname"
         )
-    elif not found_restart_file and isdir(f"{config.root}/{config.run_name}"):
+    elif not found_restart_file and isdir(train_dir):
         # output directory exists but no ``trainer.pth`` file, suggesting previous run crash during
         # first training epoch (usually due to memory):
         warnings.warn(
-            f"Previous run folder at {config.root}/{config.run_name} exists, but a saved model "
+            f"Previous run folder at {train_dir} exists, but a saved model "
             f"(trainer.pth file) was not found. This folder will be cleared and a fresh training run will "
             f"be started."
         )
-        rmtree(f"{config.root}/{config.run_name}")
+        rmtree(train_dir)
 
-    # for fresh new train
-    if not found_restart_file:
+    if not found_restart_file:  # fresh start
+        # update config with override parameters for setting up train-dir
+        config.update(override_options)
         trainer = fresh_start(config)
-    else:
-        trainer = restart(config)
+        # copy original config to training directory
+        shutil.copyfile(path_to_config, f"{train_dir}/original_config.yaml")
+    else:  # restart
+        # perform string matching for original config and restart config
+        # throw error if they are different
+        with (
+            open(f"{train_dir}/original_config.yaml") as orig_f,
+            open(path_to_config) as current_f,
+        ):
+            diffs = [
+                x
+                for x in difflib.Differ().compare(
+                    orig_f.readlines(), current_f.readlines()
+                )
+                if x[0] in ("+", "-")
+            ]
+        if diffs:
+            raise RuntimeError(
+                f"Config {path_to_config} used for restart differs from original config for training run in {train_dir}.\n"
+                + "The following differences were found:\n\n"
+                + "".join(diffs)
+                + "\n"
+                + "If you intend to override the original config parameters, use the --override flag. For example, use\n"
+                + f'`nequip-train {path_to_config} --override "max_epochs: 42"`\n'
+                + 'on the command line to override the config parameter "max_epochs"\n'
+                + "BE WARNED that use of the --override flag is not protected by consistency checks performed by NequIP."
            )
+        else:
+            trainer = restart(config, override_options)
 
     # Train
     trainer.save()
@@ -157,6 +191,12 @@ def parse_command_line(args=None):
         help="Warn instead of error when the config contains unused keys",
         action="store_true",
     )
+    parser.add_argument(
+        "--override",
+        help="Override top-level configuration keys from the `--train-dir`/`--model`'s config YAML file. This should be a valid YAML string. Unless you know why you need to, do not use this option.",
+        type=str,
+        default=None,
+    )
     args = parser.parse_args(args=args)
 
     config = Config.from_file(args.config, defaults=default_config)
@@ -169,10 +209,26 @@ def parse_command_line(args=None):
     ):
         config[flag] = getattr(args, flag) or config[flag]
 
-    return config
+    # Set override options before _set_global_options so that things like allow_tf32 are correctly handled
+    if args.override is not None:
+        override_options = yaml.load(args.override, Loader=yaml.Loader)
+        assert isinstance(
+            override_options, dict
+        ), "--override's YAML string must define a dictionary of top-level options"
+        overridden_keys = set(config.keys()).intersection(override_options.keys())
+        set_keys = set(override_options.keys()) - set(overridden_keys)
+        logging.info(
+            f"--override: overrode keys {list(overridden_keys)} and set new keys {list(set_keys)}"
+        )
+        del overridden_keys, set_keys
+    else:
+        override_options = {}
+
+    return config, args.config, override_options
 
 
 def fresh_start(config):
+
     # we use add_to_config cause it's a fresh start and need to record it
     check_code_version(config, add_to_config=True)
     _set_global_options(config)
@@ -267,7 +323,7 @@ def _unused_check():
     return trainer
 
 
-def restart(config):
+def restart(config, override_options):
     # load the dictionary
     restart_file = f"{config.root}/{config.run_name}/trainer.pth"
     dictionary = load_file(
@@ -276,20 +332,6 @@ def restart(config):
         enforced_format="torch",
     )
 
-    # compare dictionary to config and update stop condition related arguments
-    for k in config.keys():
-        if config[k] != dictionary.get(k, ""):
-            if k == "max_epochs":
-                dictionary[k] = config[k]
-                logging.info(f'Update "{k}" to {dictionary[k]}')
-            elif k.startswith("early_stop"):
-                dictionary[k] = config[k]
-                logging.info(f'Update "{k}" to {dictionary[k]}')
-            elif isinstance(config[k], type(dictionary.get(k, ""))):
-                raise ValueError(
-                    f'Key "{k}" is different in config and the result trainer.pth file. Please double check'
-                )
-
     # note, "trainer.pth"/dictionary also store code versions,
     # which will not be stored in config and thus not checked here
     check_code_version(config)
@@ -299,6 +341,10 @@ def restart(config):
 
     config = Config(dictionary, exclude_keys=["state_dict", "progress"])
 
+    # override configs loaded from save
+    dictionary.update(override_options)
+    config.update(override_options)
+
     # dtype, etc.
     _set_global_options(config)
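Taken together, the train.py changes replace the old key-by-key comparison against `trainer.pth` with a file-level check: on a fresh start the config is copied to `original_config.yaml` in the training directory, and on restart the YAML passed to `nequip-train` must match it line for line, with `--override` as the explicit escape hatch. A hedged standalone sketch of that consistency check (the function name and error text are ours; the `difflib` usage mirrors the diff above):

import difflib

def check_restart_config_unchanged(original_path: str, current_path: str) -> None:
    # Raise if the restart config differs from the original config, reporting the diff lines.
    with open(original_path) as orig_f, open(current_path) as current_f:
        diffs = [
            line
            for line in difflib.Differ().compare(
                orig_f.readlines(), current_f.readlines()
            )
            if line[0] in ("+", "-")
        ]
    if diffs:
        raise RuntimeError(
            "Restart config differs from the original config:\n" + "".join(diffs)
        )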
