configs/full.yaml (+3)

@@ -181,6 +181,9 @@ save_ema_checkpoint_freq: -1
 # training
 n_train: 100  # number of training data
 n_val: 50  # number of validation data
+# alternatively, n_train and n_val can be set as percentages of the dataset size:
+# n_train: 70%  # 70% of dataset
+# n_val: 30%  # 30% of dataset (if validation_dataset not set), or 30% of validation_dataset (if set)
 learning_rate: 0.005  # learning rate, we found values between 0.01 and 0.005 to work best - this is often one of the most important hyperparameters to tune
 batch_size: 5  # batch size, we found it important to keep this small for most applications including forces (1-5); for energy-only training, higher batch sizes work better
 validation_batch_size: 10  # batch size for evaluating the model during validation. This does not affect the training results, but using the highest value possible (<=n_val) without running out of memory will speed up your training.
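As a quick illustration of how these percentage strings resolve to frame counts, here is a minimal sketch. The `resolve_count` helper is hypothetical (not part of nequip), and rounding fractional counts to the nearest whole frame is an assumption, since the parsing code later in this diff is truncated before the final conversion:

# Hypothetical helper, sketching how "70%" resolves against a dataset size.
def resolve_count(n, dataset_size: int) -> int:
    if isinstance(n, str) and n.endswith("%"):
        # Assumption: fractional frame counts are rounded to the nearest int.
        return round(float(n.rstrip("%")) / 100 * dataset_size)
    return int(n)

# With a 1000-frame dataset and no separate validation_dataset:
resolve_count("70%", 1000)  # -> 700 frames for training
resolve_count("30%", 1000)  # -> 300 frames for validation
resolve_count(100, 1000)    # -> 100; plain ints pass through unchanged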
nequip/data/_keys.py (+1)

@@ -2,6 +2,7 @@

 This is a separate module to compensate for a TorchScript bug that can only recognize constants when they are accessed as attributes of an imported module.
nequip/train/trainer.py (+70, -16)

@@ -7,6 +7,7 @@
 make an interface with ray

 """
+
 import sys
 import inspect
 import logging

@@ -107,7 +108,7 @@ class Trainer:
     - "trainer_save.pth": all the training information. The file used for loading and restart

     For restart run, the default set up is to not append to the original folders and files.
-    The Output class will automatically build a folder call root/run_name
+    The Output class will automatically build a folder called ``root/run_name``
     If append mode is on, the log file will be appended and the best model and last model will be overwritten.

     More examples can be found in tests/train/test_trainer.py

@@ -157,9 +158,9 @@ class Trainer:
     batch_size (int): size of each batch
     validation_batch_size (int): batch size for evaluating the model for validation
     shuffle (bool): parameters for dataloader
-    n_train (int): # of frames for training
+    n_train (int, str): # of frames for training (as int, or as a percentage string)
     n_train_per_epoch (optional int): how many frames from `n_train` to use each epoch; see `PartialSampler`. When `None`, all `n_train` frames will be used each epoch.
-    n_val (int): # of frames for validation
+    n_val (int, str): # of frames for validation (as int, or as a percentage string)
     exclude_keys (list): fields from dataset to ignore.
     dataloader_num_workers (int): `num_workers` for the `DataLoader`s
     train_idcs (optional, list): list of frames to use for training
@@ -250,9 +251,9 @@ def __init__(
         batch_size: int = 5,
         validation_batch_size: int = 5,
         shuffle: bool = True,
-        n_train: Optional[int] = None,
+        n_train: Optional[Union[int, str]] = None,
         n_train_per_epoch: Optional[int] = None,
-        n_val: Optional[int] = None,
+        n_val: Optional[Union[int, str]] = None,
         dataloader_num_workers: int = 0,
         train_idcs: Optional[list] = None,
         val_idcs: Optional[list] = None,
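With the widened signature, either form can be passed straight through from Python as well as from YAML. A minimal sketch of just these keyword arguments, mirroring the config options shown above (other required Trainer arguments are omitted, so this is not a complete construction):

# Sketch: the same options as keyword arguments; the remaining Trainer
# arguments (model, loss coefficients, etc.) are elided here.
trainer_kwargs = dict(
    batch_size=5,
    validation_batch_size=10,
    n_train="70%",  # 70% of the training dataset
    n_val="30%",    # 30% of the dataset, or of validation_dataset if one is set
)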
@@ -754,7 +755,6 @@ def init_metrics(self):
         )

     def train(self):
-
         """Training"""
         if getattr(self, "dl_train", None) is None:
             raise RuntimeError("You must call `set_dataset()` before calling `train()`")
@@ -1144,12 +1144,59 @@ def __del__(self):
         for i in range(len(logger.handlers)):
             logger.handlers.pop()

+    def _parse_n_train_n_val(
+        self, train_dataset_size: int, val_dataset_size: int
+    ) -> Tuple[int, int]:
+        # parse n_train and n_val (can be ints or str with percentage):
+        n_train_n_val = []
+        for n_name, dataset_size in (
+            ("n_train", train_dataset_size),
+            ("n_val", val_dataset_size),
+        ):
+            n = getattr(self, n_name)
+            if isinstance(n, str) and "%" in n:
+                n_train_n_val.append(
+                    (float(n.rstrip("%")) / 100) * dataset_size
+                )  # convert to float first
+            elif isinstance(n, int):
+                n_train_n_val.append(n)
+            else:
+                raise ValueError(
+                    f"Invalid value/type for {n_name}: {n} -- must be either int or str with %!"