Commit 86d2a0d

Merge pull request #812 from mlcommons/dev
Dev -> main
2 parents: 3c61cc4 + 5f6a2ff; commit 86d2a0d

45 files changed: +602 −403 lines


DOCUMENTATION.md

Lines changed: 2 additions & 0 deletions
@@ -199,6 +199,7 @@ def update_params(
     batch: Dict[str, Tensor],
     loss_type: LossType,
     optimizer_state: OptimizerState,
+    train_state: Dict[str, Any],
     eval_results: List[Tuple[int, float]],
     global_step: int,
     rng: RandomState
@@ -212,6 +213,7 @@ def update_params(
 - The `loss_fn` produces a loss per example and a summed loss (both only for one device), which both can be used.
 - Allowed to update state for the optimizer.
 - Uses the `model_fn` of the `workload` in order to decouple the loss from the model so that model outputs (forward passes) can be reused (by storing them in the optimizer state).
+- The submission can access the elapsed training time and get further information about the evaluation through `train_state`.
 - The submission can access the target evaluation metric via the `workload` variable.
 - **A call to this function will be considered a step**
 - The time between a call to this function and the next call to this function will be considered the per-step time.
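
For context, here is a hedged sketch of a submission-side `update_params` that consumes the new `train_state` argument. The parameter names follow the spec above; the body and the keys read from `train_state` (e.g. `accumulated_submission_time`) are illustrative assumptions, not part of this commit.

```python
def update_params(workload, current_param_container, current_params_types,
                  model_state, hyperparameters, batch, loss_type,
                  optimizer_state, train_state, eval_results, global_step,
                  rng):
  """Illustrative only: reacts to elapsed training time via train_state."""
  del current_params_types, eval_results, global_step, rng
  # Hypothetical key; the dict contents are defined by the harness.
  elapsed = train_state.get('accumulated_submission_time', 0.0)
  if elapsed > 0.75 * workload.max_allowed_runtime_sec:
    # e.g. decay the learning rate once most of the time budget is spent.
    optimizer_state['lr'] = optimizer_state.get('lr', 1e-3) * 0.1
  # ... compute gradients with workload.model_fn / loss_fn and apply them ...
  return optimizer_state, current_param_container, model_state
```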

README.md

Lines changed: 1 addition & 1 deletion
@@ -89,7 +89,7 @@ python3 submission_runner.py \
     --workload=mnist \
     --experiment_dir=$HOME/experiments \
     --experiment_name=my_first_experiment \
-    --submission_path=reference_algorithms/paper_baselines/adamw/jax/submission.py \
+    --submission_path=reference_algorithms/paper_baselines/adamw/pytorch/submission.py \
     --tuning_search_space=reference_algorithms/paper_baselines/adamw/tuning_search_space.json
 ```

algorithmic_efficiency/pytorch_utils.py

Lines changed: 7 additions & 4 deletions
@@ -67,10 +67,13 @@ def update_batch_norm_fn(module: spec.ParameterContainer,
   )
   if isinstance(module, bn_layers):
     if not update_batch_norm:
-      module.eval()
-      module.momentum_backup = module.momentum
+      if not hasattr(module, 'momentum_backup'):
+        module.momentum_backup = module.momentum
+
       # module.momentum can be float or torch.Tensor.
-      module.momentum = 0. * module.momentum_backup
+      if torch.is_tensor(module.momentum_backup):
+        module.momentum = torch.zeros_like(module.momentum_backup)
+      else:
+        module.momentum = 0.0
     elif hasattr(module, 'momentum_backup'):
       module.momentum = module.momentum_backup
-    module.track_running_stats = update_batch_norm
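
The `hasattr` guard matters because the hook can be applied more than once: without it, a second call would overwrite `momentum_backup` with the already-zeroed momentum and the original value could never be restored. A self-contained sketch of the same logic against a stock `torch.nn.BatchNorm2d` (the repo's `spec` types and custom BN layers are omitted here):

```python
import torch

def update_batch_norm_fn(module, update_batch_norm):
  """Mirror of the patched helper, reduced to stock PyTorch BN layers."""
  if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
    if not update_batch_norm:
      if not hasattr(module, 'momentum_backup'):
        module.momentum_backup = module.momentum
      # module.momentum can be float or torch.Tensor.
      if torch.is_tensor(module.momentum_backup):
        module.momentum = torch.zeros_like(module.momentum_backup)
      else:
        module.momentum = 0.0
    elif hasattr(module, 'momentum_backup'):
      module.momentum = module.momentum_backup

bn = torch.nn.BatchNorm2d(4)                        # default momentum = 0.1
update_batch_norm_fn(bn, update_batch_norm=False)   # freeze running stats
update_batch_norm_fn(bn, update_batch_norm=False)   # safe to call again
update_batch_norm_fn(bn, update_batch_norm=True)
print(bn.momentum)                                  # 0.1, correctly restored
```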

algorithmic_efficiency/spec.py

Lines changed: 4 additions & 2 deletions
@@ -403,7 +403,8 @@ def init_optimizer_state(workload: Workload,
     OptimizerState,
     List[Tuple[int, float]],
     int,
-    RandomState
+    RandomState,
+    Optional[Dict[str, Any]]
 ],
 UpdateReturn]

@@ -424,7 +425,8 @@ def update_params(workload: Workload,
                   optimizer_state: OptimizerState,
                   eval_results: List[Tuple[int, float]],
                   global_step: int,
-                  rng: RandomState) -> UpdateReturn:
+                  rng: RandomState,
+                  train_state: Optional[Dict[str, Any]] = None) -> UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   pass
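
Because `train_state` is appended with a `None` default, older submissions that do not accept it keep working. A hypothetical caller-side shim (not the repo's `submission_runner` logic) could forward it only when the submission's signature declares it:

```python
import inspect

def call_update_params(update_params_fn, *args, train_state=None):
  """Pass train_state only if the submission's update_params accepts it."""
  params = inspect.signature(update_params_fn).parameters
  if 'train_state' in params:
    return update_params_fn(*args, train_state=train_state)
  return update_params_fn(*args)
```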

algorithmic_efficiency/workloads/cifar/cifar_jax/models.py

Lines changed: 7 additions & 2 deletions
@@ -28,11 +28,16 @@ class ResNet(nn.Module):
   @nn.compact
   def __call__(self,
                x: spec.Tensor,
-               update_batch_norm: bool = True) -> spec.Tensor:
+               update_batch_norm: bool = True,
+               use_running_average_bn: bool = None) -> spec.Tensor:
     conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype)
+
+    # Preserve default behavior for backwards compatibility
+    if use_running_average_bn is None:
+      use_running_average_bn = not update_batch_norm
     norm = functools.partial(
         nn.BatchNorm,
-        use_running_average=not update_batch_norm,
+        use_running_average=use_running_average_bn,
         momentum=0.9,
         epsilon=1e-5,
         dtype=self.dtype)
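
The new flag separates "which statistics normalize the batch" from the existing `update_batch_norm` switch, while the `None` default preserves the old coupling. A minimal self-contained Flax sketch of the same pattern (a toy module, not the repo's ResNet):

```python
from typing import Optional

import flax.linen as nn
import jax
import jax.numpy as jnp

class TinyBNModel(nn.Module):
  """Toy stand-in illustrating the use_running_average_bn plumbing."""

  @nn.compact
  def __call__(self,
               x,
               update_batch_norm: bool = True,
               use_running_average_bn: Optional[bool] = None):
    # Preserve default behavior for backwards compatibility.
    if use_running_average_bn is None:
      use_running_average_bn = not update_batch_norm
    x = nn.BatchNorm(
        use_running_average=use_running_average_bn,
        momentum=0.9,
        epsilon=1e-5)(x)
    return nn.Dense(4)(x)

x = jnp.ones((8, 16))
model = TinyBNModel()
variables = model.init(jax.random.PRNGKey(0), x)

# Unchanged call sites: batch statistics are used and the running
# averages are updated (batch_stats is mutable).
y, new_state = model.apply(
    variables, x, update_batch_norm=True, mutable=['batch_stats'])

# New option: keep batch_stats mutable but normalize with the stored
# running averages instead of the current batch statistics.
y, _ = model.apply(
    variables, x, update_batch_norm=True,
    use_running_average_bn=True, mutable=['batch_stats'])
```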

algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py

Lines changed: 7 additions & 3 deletions
@@ -110,7 +110,9 @@ def model_fn(
       model_state: spec.ModelAuxiliaryState,
       mode: spec.ForwardPassMode,
       rng: spec.RandomState,
-      update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+      update_batch_norm: bool,
+      use_running_average_bn: Optional[bool] = None
+  ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
     del mode
     del rng
     variables = {'params': params, **model_state}
@@ -119,14 +121,16 @@ def model_fn(
           variables,
           augmented_and_preprocessed_input_batch['inputs'],
           update_batch_norm=update_batch_norm,
-          mutable=['batch_stats'])
+          mutable=['batch_stats'],
+          use_running_average_bn=use_running_average_bn)
       return logits, new_model_state
     else:
       logits = self._model.apply(
           variables,
           augmented_and_preprocessed_input_batch['inputs'],
           update_batch_norm=update_batch_norm,
-          mutable=False)
+          mutable=False,
+          use_running_average_bn=use_running_average_bn)
       return logits, model_state

   # Does NOT apply regularization, which is left to the submitter to do in

algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/models.py

Lines changed: 7 additions & 2 deletions
@@ -84,11 +84,16 @@ class ResNet(nn.Module):
   @nn.compact
   def __call__(self,
                x: spec.Tensor,
-               update_batch_norm: bool = True) -> spec.Tensor:
+               update_batch_norm: bool = True,
+               use_running_average_bn: Optional[bool] = None) -> spec.Tensor:
     conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype)
+
+    # Preserve default behavior for backwards compatibility
+    if use_running_average_bn is None:
+      use_running_average_bn = not update_batch_norm
     norm = functools.partial(
         nn.BatchNorm,
-        use_running_average=not update_batch_norm,
+        use_running_average=use_running_average_bn,
         momentum=0.9,
         epsilon=1e-5,
         dtype=self.dtype)

algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py

Lines changed: 7 additions & 3 deletions
@@ -148,7 +148,9 @@ def model_fn(
       model_state: spec.ModelAuxiliaryState,
       mode: spec.ForwardPassMode,
       rng: spec.RandomState,
-      update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+      update_batch_norm: bool,
+      use_running_average_bn: Optional[bool] = None
+  ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
     del mode
     del rng
     variables = {'params': params, **model_state}
@@ -157,14 +159,16 @@ def model_fn(
           variables,
           augmented_and_preprocessed_input_batch['inputs'],
           update_batch_norm=update_batch_norm,
-          mutable=['batch_stats'])
+          mutable=['batch_stats'],
+          use_running_average_bn=use_running_average_bn)
       return logits, new_model_state
     else:
       logits = self._model.apply(
           variables,
           augmented_and_preprocessed_input_batch['inputs'],
           update_batch_norm=update_batch_norm,
-          mutable=False)
+          mutable=False,
+          use_running_average_bn=use_running_average_bn)
       return logits, model_state

   # Does NOT apply regularization, which is left to the submitter to do in

algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py

Lines changed: 55 additions & 16 deletions
@@ -454,15 +454,24 @@ def setup(self):
     self.beta = self.param('bias', nn.initializers.zeros, dim, dtype)

   @nn.compact
-  def __call__(self, inputs, input_paddings, train):
+  def __call__(self,
+               inputs,
+               input_paddings,
+               update_batch_norm,
+               use_running_average_bn):
     rank = inputs.ndim
     reduce_over_dims = list(range(0, rank - 1))

     padding = jnp.expand_dims(input_paddings, -1)
     momentum = self.config.batch_norm_momentum
     epsilon = self.config.batch_norm_epsilon

-    if train:
+    if use_running_average_bn:
+      mean = self.ra_mean.value
+      var = self.ra_var.value
+
+    else:
+      # compute batch statistics
       mask = 1.0 - padding
       sum_v = jnp.sum(inputs * mask, axis=reduce_over_dims, keepdims=True)
       count_v = jnp.sum(
@@ -478,16 +487,13 @@ def __call__(self, inputs, input_paddings, train):

       var = sum_vv / count_v

-      self.ra_mean.value = momentum * \
-        self.ra_mean.value + (1 - momentum) * mean
-      self.ra_var.value = momentum * \
-        self.ra_var.value + (1 - momentum) * var
-    else:
-      mean = self.ra_mean.value
-      var = self.ra_var.value
+      if update_batch_norm:
+        self.ra_mean.value = momentum * \
+          self.ra_mean.value + (1 - momentum) * mean
+        self.ra_var.value = momentum * \
+          self.ra_var.value + (1 - momentum) * var

     inv = (1 + self.gamma) / jnp.sqrt(var + epsilon)
-
     bn_output = (inputs - mean) * inv + self.beta
     bn_output *= 1.0 - padding

@@ -517,7 +523,12 @@ class ConvolutionBlock(nn.Module):
   config: ConformerConfig

   @nn.compact
-  def __call__(self, inputs, input_paddings, train):
+  def __call__(self,
+               inputs,
+               input_paddings,
+               train,
+               update_batch_norm,
+               use_running_average_bn):
     config = self.config
     inputs = LayerNorm(dim=config.encoder_dim)(inputs)

@@ -546,7 +557,10 @@ def __call__(self, inputs, input_paddings, train):
         kernel_init=nn.initializers.xavier_uniform())(
             inputs)

-    inputs = BatchNorm(config)(inputs, input_paddings, train)
+    inputs = BatchNorm(config)(inputs,
+                               input_paddings,
+                               update_batch_norm,
+                               use_running_average_bn)
     if config.activation_function_name == 'swish':
       activation_fn = nn.swish
     elif config.activation_function_name == 'gelu':
@@ -586,7 +600,12 @@ class ConformerBlock(nn.Module):
   config: ConformerConfig

   @nn.compact
-  def __call__(self, inputs, input_paddings, train):
+  def __call__(self,
+               inputs,
+               input_paddings,
+               train,
+               update_batch_norm,
+               use_running_average):
     config = self.config
     padding_mask = jnp.expand_dims(1 - input_paddings, -1)

@@ -597,7 +616,12 @@ def __call__(self, inputs, input_paddings, train):
         inputs, input_paddings, train)

     inputs = inputs + \
-        ConvolutionBlock(config)(inputs, input_paddings, train)
+        ConvolutionBlock(config)(inputs,
+                                 input_paddings,
+                                 train,
+                                 update_batch_norm,
+                                 use_running_average
+                                 )

     inputs = inputs + 0.5 * FeedForwardModule(config=self.config)(
         inputs, padding_mask, train)
@@ -629,12 +653,23 @@ def setup(self):
                              .use_dynamic_time_mask_max_frames)

   @nn.compact
-  def __call__(self, inputs, input_paddings, train):
+  def __call__(self,
+               inputs,
+               input_paddings,
+               train,
+               update_batch_norm: Optional[bool] = None,
+               use_running_average_bn: Optional[bool] = None):
     config = self.config

     outputs = inputs
     output_paddings = input_paddings

+    # Set BN args if not supplied for backwards compatibility
+    if update_batch_norm is None:
+      update_batch_norm = train
+    if use_running_average_bn is None:
+      use_running_average_bn = not train
+
     # Compute normalized log mel spectrograms from input audio signal.
     preprocessing_config = preprocessor.LibrispeechPreprocessingConfig()
     outputs, output_paddings = preprocessor.MelFilterbankFrontend(
@@ -660,7 +695,11 @@ def __call__(self, inputs, input_paddings, train):

     # Run the conformer encoder layers.
     for _ in range(config.num_encoder_layers):
-      outputs = ConformerBlock(config)(outputs, output_paddings, train)
+      outputs = ConformerBlock(config)(outputs,
+                                       output_paddings,
+                                       train,
+                                       update_batch_norm,
+                                       use_running_average_bn)

     outputs = LayerNorm(config.encoder_dim)(outputs)
     # Run the decoder which in this case is a trivial projection layer.
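
Summarizing the BatchNorm change above: `use_running_average_bn` now decides whether normalization uses the stored running averages or freshly computed masked batch statistics, and `update_batch_norm` alone gates the exponential-moving-average update. A standalone `jax.numpy` sketch of that control flow (a functional rewrite for illustration, not the repo's Flax module):

```python
import jax.numpy as jnp

def masked_batch_norm(inputs, input_paddings, ra_mean, ra_var, gamma, beta,
                      momentum=0.999, epsilon=0.001,
                      update_batch_norm=True, use_running_average_bn=False):
  """Pad-aware batch norm with decoupled 'use' and 'update' switches."""
  padding = jnp.expand_dims(input_paddings, -1)
  reduce_over_dims = tuple(range(inputs.ndim - 1))

  if use_running_average_bn:
    mean, var = ra_mean, ra_var
  else:
    # Compute batch statistics over unpadded positions only.
    mask = 1.0 - padding
    count = jnp.sum(mask, axis=reduce_over_dims, keepdims=True)
    mean = jnp.sum(inputs * mask, axis=reduce_over_dims, keepdims=True) / count
    var = jnp.sum(jnp.square(inputs - mean) * mask,
                  axis=reduce_over_dims, keepdims=True) / count
    if update_batch_norm:
      # Only here are the running averages advanced.
      ra_mean = momentum * ra_mean + (1 - momentum) * mean
      ra_var = momentum * ra_var + (1 - momentum) * var

  inv = (1 + gamma) / jnp.sqrt(var + epsilon)
  out = ((inputs - mean) * inv + beta) * (1.0 - padding)
  return out, ra_mean, ra_var
```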

algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py

Lines changed: 7 additions & 3 deletions
@@ -107,7 +107,9 @@ def model_fn(
       model_state: spec.ModelAuxiliaryState,
       mode: spec.ForwardPassMode,
       rng: spec.RandomState,
-      update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+      update_batch_norm: bool,
+      use_running_average_bn: Optional[bool] = None
+  ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
     variables = {'params': params, **model_state}
     inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs']
     is_train_mode = mode == spec.ForwardPassMode.TRAIN
@@ -118,15 +120,17 @@ def model_fn(
           input_paddings,
           train=True,
           rngs={'dropout' : rng},
-          mutable=['batch_stats'])
+          mutable=['batch_stats'],
+          use_running_average_bn=use_running_average_bn)
       return (logits, logit_paddings), new_model_state
     else:
       logits, logit_paddings = self._model.apply(
           variables,
           inputs,
           input_paddings,
           train=False,
-          mutable=False)
+          mutable=False,
+          use_running_average_bn=use_running_average_bn)
       return (logits, logit_paddings), model_state

   def _build_input_queue(

algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py

Lines changed: 6 additions & 5 deletions
@@ -40,7 +40,7 @@ class ConformerConfig:
   time_masks_per_frame: float = 0.0
   use_dynamic_time_mask_max_frames: bool = True
   input_dropout_rate: float = 0.1
-  batch_norm_momentum: float = 0.999
+  batch_norm_momentum: float = 1 - 0.999
   batch_norm_epsilon: float = 0.001
   use_specaug: bool = True
   attention_temperature: float = 1.0

@@ -369,10 +369,11 @@ def forward(self, inputs, input_paddings):
       mean = (masked_inp).sum(dim=(0, 1)) / count
       var = (torch.square(masked_inp - mean) * mask).sum(dim=(0, 1)) / count

-      self.running_mean = self.momentum * self.running_mean + (
-          1 - self.momentum) * mean.detach()
-      self.running_var = self.momentum * self.running_var + (
-          1 - self.momentum) * var.detach()
+      self.running_mean = (1 - self.momentum) * self.running_mean + (
+          self.momentum) * mean.detach()
+      self.running_var = (1 - self.momentum) * self.running_var + (
+          self.momentum) * var.detach()
+
     else:
       mean = self.running_mean
       var = self.running_var
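
The two PyTorch changes are meant to cancel out numerically: the update rule now follows PyTorch's BatchNorm convention, where `momentum` is the weight of the *new* batch statistic, and the config value is rewritten as `1 - 0.999` to compensate. A quick check with made-up numbers (not repo code):

```python
import torch

running, new = torch.tensor(2.0), torch.tensor(5.0)

# Before: momentum was the weight of the *running* statistic.
old_momentum = 0.999
old_update = old_momentum * running + (1 - old_momentum) * new

# After: momentum is the weight of the *new* statistic (PyTorch convention),
# and the config stores 1 - 0.999 = 0.001 to keep the decay identical.
new_momentum = 1 - 0.999
new_update = (1 - new_momentum) * running + new_momentum * new

assert torch.allclose(old_update, new_update)
print(old_update.item(), new_update.item())  # both ≈ 2.003
```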
