Skip to content

Commit bf61255

Browse files
Merge pull request #825 from mlcommons/dev
[do not merge] Dev -> Main
2 parents 86d2a0d + 6c8fd56 commit bf61255

File tree

35 files changed

+832
-100
lines changed

35 files changed

+832
-100
lines changed

DOCUMENTATION.md

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ In principle, submissions are allowed to use the available hardware systems in a
8080
Submissions provide a [per-workload batch size](#batch-size-getter) to use. Specification of the batch size for each workload is necessary to avoid running out of memory for different workloads. Therefore, submitters can determine this batch size in advance and specify it as part of the submission. Submitters may also provide per-workload batch sizes for all [randomized workloads](#randomized-workloads). If no such batch size is provided for a randomized workload, by default, submissions will then use the batch size of the most similar [fixed workload](#fixed-workloads) (for example, if there is an ImageNet fixed workload and also a randomized workload with a similarly sized model on similarly sized images, the ImageNet batch size will be used for held-out workloads generated from this randomized workload).
8181
Note that submitters are *not* allowed to modify the *evaluation batch size*, which is set by the benchmarking codebase. However, you can file an issue if you believe that the evaluation batch size of a particular workload is set inappropriately. The working group will review this request and consider adjusting the evaluation batch size in the benchmarking codebase, thus affecting all submitters equally.
8282

83-
The **submission functions** are the *batch size getter*, *optimizer state initializer*, *variable update*, and *data selection functions*. The *fixed functions* are the *data augmentation/preprocessing*, *model initialization*, *forward pass*, and *loss function*. The trained model will be evaluated in a separate step that does not call any of the submitted code.
83+
The **submission functions** are the *batch size getter*, *optimizer state initializer*, *variable update*, *prepare for evaluation function*, and *data selection functions*. The *fixed functions* are the *data augmentation/preprocessing*, *model initialization*, *forward pass*, and *loss function*. The trained model will be evaluated in a separate step that does not call any of the submitted code.
8484

8585
##### Fixed functions
8686

@@ -220,9 +220,35 @@ def update_params(
220220
- Cannot modify the given hyperparameters in a workload-conditional way (please see the [Valid submission](#valid-submissions) section). This rule is intended to prohibit circumventing the tuning rules by looking up a pre-tuned optimal set of hyperparameters for each workload. It is not intended to prohibit line searches and other similar techniques.
221221
- The fixed `init_model_fn` can optionally be called during training, for example, to reinitialize the model after a failed training effort.
222222
- Cannot replace the model parameters with pre-trained ones.
223-
- This API supports Polyak averaging and similar methods that implement moving averages of model parameters.
224223
- Batch norm should work here because the `model_fn` will return updated batch norm moving averages when it is told to with `update_batch_norm`.
225224

225+
226+
###### Prepare for evaluation function
227+
228+
```python
229+
def prepare_for_eval(
230+
workload: Workload,
231+
current_param_container: ParameterContainer,
232+
current_params_types: ParameterTypeTree,
233+
model_state: ModelAuxiliaryState,
234+
hyperparameters: Hyperparameters,
235+
loss_type: LossType,
236+
optimizer_state: OptimizerState,
237+
eval_results: List[Tuple[int, float]],
238+
global_step: int,
239+
rng: RandomState
240+
) -> (updated_optimizer_state, updated_variables, updated_model_state)
241+
```
242+
243+
- Arguments are the same of `update_param`, with the only exception of `batch`.
244+
- This function is called when a submission is deemed eligible for an evaluation (see [Evluation during training](#evaluation-during-training) section).
245+
- The call to `prepare_for_eval` is timed and its runtime accumulates to the overall submission time.
246+
- The returned model parameters are evaluated on the validation and test sets, provided that the accumulated submission time does not exceed the maximum runtime after this function call.
247+
- This API supports Polyak averaging and similar methods that implement moving averages of model parameters.
248+
- Allowed to update model state and model parameters.
249+
- Allowed to update state for the optimizer.
250+
- Cannot replace the model parameters with pre-trained ones.
251+
226252
###### Data selection
227253

228254
```python
@@ -252,7 +278,8 @@ def data_selection(
252278

253279
In general, with noisy, non-deterministic training, evaluation frequency can affect training time measurements as more "bites of the apple" potentially allows the training code to exploit instability. We also want to discourage submissions from complicated and unrealistic logic that attempts to guess when training is close to complete and increases the evaluation rate, while not producing a well-sampled training curve at the start of training. Simply allowing submissions complete freedom over evaluation frequency encourages competitors to work to minimize the number of evaluations, which distracts from the primary goal of finding better training algorithms.
254280

255-
Submissions are eligible for an untimed eval every `eval_period` seconds, run as soon as the current call of `update_params` completes. Any additional evaluations performed by the submission code count against the runtime for scoring. The harness that runs the submission code will attempt to eval every `eval_period` seconds by checking between each submission step (call of `update_params`) whether it has been at least `eval_period` seconds since that last eval and, if so, pausing the clock and running an eval. This means that if calls to `update_params` typically take a lot more than `eval_period` seconds, such submissions will not receive as many untimed evals as a submission that had an `update_params` function that took less time. However, for appropriate settings of `eval_period`, we expect this to be quite rare. Submissions are always free to restructure their `update_params` code to split work into two subsequent steps to regain the potential benefits of these untimed model evaluations. For each workload, the `eval_period` will be set such that the total evaluation time is roughly between 10% and 20% of the total training time for the target-setting runs.
281+
Submissions are eligible for an untimed eval every `eval_period` seconds. Before proceeding to evaluation, the submission can prepare the model through a call to `prepare_for_eval`, effectively modifying the model parameters and state as well as the the optimizer state. Any additional evaluations performed by the submission code count against the runtime for scoring.
282+
The harness that runs the submission code will attempt to eval every `eval_period` seconds by checking between each submission step (call of `update_params`) whether it has been at least `eval_period` seconds since that last eval, if so, the submission is given the possibility to prepare for evaluation (through a timed call to `prepare_for_eval`). If the accumulated runtime does not exceed the maximum allowed runtime after the preparation step, the clock is paused, and the submission is evaluated. This means that if calls to `update_params` typically take a lot more than `eval_period` seconds, such submissions will not receive as many untimed evals as a submission that had an `update_params` function that took less time. However, for appropriate settings of `eval_period`, we expect this to be quite rare. Submissions are always free to restructure their `update_params` code to split work into two subsequent steps to regain the potential benefits of these untimed model evaluations. For each workload, the `eval_period` will be set such that the total evaluation time is roughly between 10% and 20% of the total training time for the target-setting runs.
256283

257284
#### Valid submissions
258285

@@ -419,6 +446,19 @@ The currently eight fixed workloads are:
419446
| **7** | Molecular property prediction | OGBG | GNN | CE | mAP | 0.28098 | 0.268729 | 18,477 |
420447
| **8** | Translation | WMT | Transformer | CE | BLEU | 30.8491 | 30.7219 | 48,151 |
421448

449+
Default Dropout Values for Different Workloads:
450+
451+
| Workload | Dropout Values |
452+
|------------------------|------------------------------------------------------------------------------------------------------|
453+
| criteo 1tb | dropout_rate: 0.0 |
454+
| fastmri | dropout_rate: 0.0 |
455+
| imagenet_resnet | dropout not used |
456+
| imagenet_vit | dropout_rate: 0.0 |
457+
| librispeech_conformer | attention_dropout_rate: 0.0 <br> attention_residual_dropout_rate: 0.1 <br> conv_residual_dropout_rate: 0.0 <br> feed_forward_dropout_rate: 0.0 <br> feed_forward_residual_dropout_rate: 0.1 <br> input_dropout_rate: 0.1 |
458+
| librispeech_deepspeech | input_dropout_rate: 0.1 <br> feed_forward_dropout_rate: 0.1 <br> (Only for JAX - dropout_rate in CudnnLSTM class: 0.0) |
459+
| ogbg | dropout_rate: 0.1 |
460+
| wmt | dropout_rate: 0.1 <br> attention_dropout_rate: 0.1 |
461+
422462
#### Randomized workloads
423463

424464
In addition to the [fixed and known workloads](#fixed-workloads), there will also be randomized workloads in our benchmark. These randomized workloads will introduce minor modifications to a fixed workload (e.g. small model changes). The exact instances of these randomized workloads will only be created after the submission deadline and are thus unknown to both the submitters as well as the benchmark organizers. The instructions for creating them, i.e. providing a set or distribution of workloads to sample from, will be defined by this working group and made public with the call for submissions, to allow the members of this working group to submit as well as ensure that they do not possess any additional information compared to other submitters. We will refer to the unspecific workloads as *randomized workloads*, e.g. the set or distribution. The specific instance of such a randomized workload we call a *held-out workload*. That is, a held-out workload is a specific sample of a randomized workload that is used for one iteration of the benchmark. While we may reuse randomized workloads between iterations of the benchmark, new held-out workloads will be sampled for each new benchmark iteration.

algorithmic_efficiency/init_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,6 @@ def pytorch_default_init(module: nn.Module) -> None:
1313
# Perform lecun_normal initialization.
1414
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
1515
std = math.sqrt(1. / fan_in) / .87962566103423978
16-
nn.init.trunc_normal_(module.weight, std=std)
16+
nn.init.trunc_normal_(module.weight, std=std, a=-2 * std, b=2 * std)
1717
if module.bias is not None:
1818
nn.init.constant_(module.bias, 0.)

algorithmic_efficiency/random_utils.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,21 @@
1616

1717
FLAGS = flags.FLAGS
1818

19-
# Annoyingly, RandomState(seed) requires seed to be in [0, 2 ** 32 - 1] (an
19+
# Annoyingly, RandomState(seed) requires seed to be in [0, 2 ** 31 - 1] (an
2020
# unsigned int), while RandomState.randint only accepts and returns signed ints.
21-
MAX_INT32 = 2**31
22-
MIN_INT32 = -MAX_INT32
21+
MAX_INT32 = 2**31 - 1
22+
MIN_INT32 = 0
2323

2424
SeedType = Union[int, list, np.ndarray]
2525

2626

2727
def _signed_to_unsigned(seed: SeedType) -> SeedType:
2828
if isinstance(seed, int):
29-
return seed % 2**32
29+
return seed % MAX_INT32
3030
if isinstance(seed, list):
31-
return [s % 2**32 for s in seed]
31+
return [s % MAX_INT32 for s in seed]
3232
if isinstance(seed, np.ndarray):
33-
return np.array([s % 2**32 for s in seed.tolist()])
33+
return np.array([s % MAX_INT32 for s in seed.tolist()])
3434

3535

3636
def _fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]:

algorithmic_efficiency/spec.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,36 @@ def update_params(workload: Workload,
431431
pass
432432

433433

434+
PrepareForEvalFn = Callable[[
435+
Workload,
436+
ParameterContainer,
437+
ParameterTypeTree,
438+
ModelAuxiliaryState,
439+
Hyperparameters,
440+
LossType,
441+
OptimizerState,
442+
List[Tuple[int, float]],
443+
int,
444+
RandomState
445+
],
446+
UpdateReturn]
447+
448+
449+
# Prepare model and optimizer for evaluation.
450+
def prepare_for_eval(workload: Workload,
451+
current_param_container: ParameterContainer,
452+
current_params_types: ParameterTypeTree,
453+
model_state: ModelAuxiliaryState,
454+
hyperparameters: Hyperparameters,
455+
loss_type: LossType,
456+
optimizer_state: OptimizerState,
457+
eval_results: List[Tuple[int, float]],
458+
global_step: int,
459+
rng: RandomState) -> UpdateReturn:
460+
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
461+
pass
462+
463+
434464
DataSelectionFn = Callable[[
435465
Workload,
436466
Iterator[Dict[str, Any]],

prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,27 @@ def update_params(
302302
return (new_optimizer_state, opt_update_fn), new_params, new_model_state
303303

304304

305+
def prepare_for_eval(workload: spec.Workload,
306+
current_param_container: spec.ParameterContainer,
307+
current_params_types: spec.ParameterTypeTree,
308+
model_state: spec.ModelAuxiliaryState,
309+
hyperparameters: spec.Hyperparameters,
310+
loss_type: spec.LossType,
311+
optimizer_state: spec.OptimizerState,
312+
eval_results: List[Tuple[int, float]],
313+
global_step: int,
314+
rng: spec.RandomState) -> spec.UpdateReturn:
315+
"""Return (updated_optimizer_state, updated_params)."""
316+
del workload
317+
del hyperparameters
318+
del current_params_types
319+
del loss_type
320+
del eval_results
321+
del global_step
322+
del rng
323+
return (optimizer_state, current_param_container, model_state)
324+
325+
305326
def get_batch_size(workload_name):
306327
# Return the global batch size.
307328
if workload_name == 'criteo1tb':

prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,27 @@ def update_params(
302302
return (new_optimizer_state, opt_update_fn), new_params, new_model_state
303303

304304

305+
def prepare_for_eval(workload: spec.Workload,
306+
current_param_container: spec.ParameterContainer,
307+
current_params_types: spec.ParameterTypeTree,
308+
model_state: spec.ModelAuxiliaryState,
309+
hyperparameters: spec.Hyperparameters,
310+
loss_type: spec.LossType,
311+
optimizer_state: spec.OptimizerState,
312+
eval_results: List[Tuple[int, float]],
313+
global_step: int,
314+
rng: spec.RandomState) -> spec.UpdateReturn:
315+
"""Return (updated_optimizer_state, updated_params)."""
316+
del workload
317+
del hyperparameters
318+
del current_params_types
319+
del loss_type
320+
del eval_results
321+
del global_step
322+
del rng
323+
return (optimizer_state, current_param_container, model_state)
324+
325+
305326
def get_batch_size(workload_name):
306327
# Return the global batch size.
307328
if workload_name == 'criteo1tb':

prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,27 @@ def update_params(
304304
return (optimizer_state, current_param_container, new_model_state)
305305

306306

307+
def prepare_for_eval(workload: spec.Workload,
308+
current_param_container: spec.ParameterContainer,
309+
current_params_types: spec.ParameterTypeTree,
310+
model_state: spec.ModelAuxiliaryState,
311+
hyperparameters: spec.Hyperparameters,
312+
loss_type: spec.LossType,
313+
optimizer_state: spec.OptimizerState,
314+
eval_results: List[Tuple[int, float]],
315+
global_step: int,
316+
rng: spec.RandomState) -> spec.UpdateReturn:
317+
"""Return (updated_optimizer_state, updated_params)."""
318+
del workload
319+
del hyperparameters
320+
del current_params_types
321+
del loss_type
322+
del eval_results
323+
del global_step
324+
del rng
325+
return (optimizer_state, current_param_container, model_state)
326+
327+
307328
def get_batch_size(workload_name):
308329
# Return the global batch size.
309330
if workload_name == 'criteo1tb':

prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,27 @@ def update_params(
304304
return (optimizer_state, current_param_container, new_model_state)
305305

306306

307+
def prepare_for_eval(workload: spec.Workload,
308+
current_param_container: spec.ParameterContainer,
309+
current_params_types: spec.ParameterTypeTree,
310+
model_state: spec.ModelAuxiliaryState,
311+
hyperparameters: spec.Hyperparameters,
312+
loss_type: spec.LossType,
313+
optimizer_state: spec.OptimizerState,
314+
eval_results: List[Tuple[int, float]],
315+
global_step: int,
316+
rng: spec.RandomState) -> spec.UpdateReturn:
317+
"""Return (updated_optimizer_state, updated_params)."""
318+
del workload
319+
del hyperparameters
320+
del current_params_types
321+
del loss_type
322+
del eval_results
323+
del global_step
324+
del rng
325+
return (optimizer_state, current_param_container, model_state)
326+
327+
307328
def get_batch_size(workload_name):
308329
# Return the global batch size.
309330
if workload_name == 'criteo1tb':

prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,27 @@ def update_params(
317317
return (new_optimizer_state, opt_update_fn), new_params, new_model_state
318318

319319

320+
def prepare_for_eval(workload: spec.Workload,
321+
current_param_container: spec.ParameterContainer,
322+
current_params_types: spec.ParameterTypeTree,
323+
model_state: spec.ModelAuxiliaryState,
324+
hyperparameters: spec.Hyperparameters,
325+
loss_type: spec.LossType,
326+
optimizer_state: spec.OptimizerState,
327+
eval_results: List[Tuple[int, float]],
328+
global_step: int,
329+
rng: spec.RandomState) -> spec.UpdateReturn:
330+
"""Return (updated_optimizer_state, updated_params)."""
331+
del workload
332+
del hyperparameters
333+
del current_params_types
334+
del loss_type
335+
del eval_results
336+
del global_step
337+
del rng
338+
return (optimizer_state, current_param_container, model_state)
339+
340+
320341
def get_batch_size(workload_name):
321342
# Return the global batch size.
322343
if workload_name == 'criteo1tb':

prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,27 @@ def update_params(
317317
return (new_optimizer_state, opt_update_fn), new_params, new_model_state
318318

319319

320+
def prepare_for_eval(workload: spec.Workload,
321+
current_param_container: spec.ParameterContainer,
322+
current_params_types: spec.ParameterTypeTree,
323+
model_state: spec.ModelAuxiliaryState,
324+
hyperparameters: spec.Hyperparameters,
325+
loss_type: spec.LossType,
326+
optimizer_state: spec.OptimizerState,
327+
eval_results: List[Tuple[int, float]],
328+
global_step: int,
329+
rng: spec.RandomState) -> spec.UpdateReturn:
330+
"""Return (updated_optimizer_state, updated_params)."""
331+
del workload
332+
del hyperparameters
333+
del current_params_types
334+
del loss_type
335+
del eval_results
336+
del global_step
337+
del rng
338+
return (optimizer_state, current_param_container, model_state)
339+
340+
320341
def get_batch_size(workload_name):
321342
# Return the global batch size.
322343
if workload_name == 'criteo1tb':

0 commit comments

Comments
 (0)