diff --git a/README.md b/README.md
index 18b05df..abb9e5b 100644
--- a/README.md
+++ b/README.md
@@ -54,9 +54,11 @@ We list the entries of other papers and submissions that used our LRA
 benchmark.
 
 Model | ListOps | Text | Retrieval | Image | Path | Path-X | Avg
 --------------- | --------- | --------- | --------- | --------- | --------- | ------ | ---
 IGLOO | 39.23 | 82 | 75.5 | 47.0 | 67.50 | NA | 62.25
+TLB | 37.05 | 81.88 | 76.91 | 57.51 | 79.06 | FAIL | 66.48
 
-IGLOO Submissions (by Vsevolod Sourkov) - https://github.com/redna11/lra-igloo
+IGLOO Submissions (by Vsevolod Sourkov) - https://github.com/redna11/lra-igloo \
+TLB ([Temporal Latent Bottleneck](lra_benchmarks/models/transformer_tlb)) - [transformer_tlb](lra_benchmarks/models/transformer_tlb)
 
 ## Citation
@@ -138,7 +140,7 @@ To run a task, run the train.py file in the corresponding task directory.
 (please see how to obtain the data for certain tasks if applicable).
 
 ```
-PYTHONPATH="$(pwd)":"$PYTHONPATH" python lra_benchmarks/listops/train.py \
+PYTHONPATH="$(pwd)":"$PYTHONPATH" python lra_benchmarks/listops/train.py \
 --config=lra_benchmarks/listops/configs/transformer_base.py \
 --model_dir=/tmp/listops \
 --task_name=basic \
@@ -164,7 +166,7 @@ If you would like to go to longer/shorter sequence lengths, we also support
 generating your own split, run the following comment:
 
 ```
-PYTHONPATH="$(pwd)":"$PYTHONPATH" python lra_benchmarks/data/listops.py -- \
+PYTHONPATH="$(pwd)":"$PYTHONPATH" python lra_benchmarks/data/listops.py -- \
 --output_dir=$HOME/lra_data/listops/
 ```
 
diff --git a/lra_benchmarks/image/configs/cifar10/transformer_tlb_base.py b/lra_benchmarks/image/configs/cifar10/transformer_tlb_base.py
new file mode 100644
index 0000000..58e2229
--- /dev/null
+++ b/lra_benchmarks/image/configs/cifar10/transformer_tlb_base.py
@@ -0,0 +1,38 @@
+# Copyright 2022 Google LLC
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# https://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Configuration and hyperparameter sweeps.""" + +from lra_benchmarks.image.configs.cifar10 import base_cifar10_config + + +def get_config(): + """Get the hyperparameter configuration.""" + config = base_cifar10_config.get_config() + config.model_type = "transformer_tlb" + config.learning_rate = .001 + config.model.emb_dim = 128 + config.model.mlp_dim = 128 + config.model.num_heads = 8 + config.model.qkv_dim = 64 + config.model.self_to_cross_ratio_input_updater = 1 + config.model.num_cross_layers_input_updater = 1 + config.model.num_cross_layers_state_updater = 1 + config.model.num_state_tokens = 5 + config.model.block_size = 32 + config.model.use_global_pos_encoding = False + return config + + +def get_hyper(hyper): + return hyper.product([]) diff --git a/lra_benchmarks/image/configs/pathfinder32/transformer_tlb_base.py b/lra_benchmarks/image/configs/pathfinder32/transformer_tlb_base.py new file mode 100644 index 0000000..3ad009e --- /dev/null +++ b/lra_benchmarks/image/configs/pathfinder32/transformer_tlb_base.py @@ -0,0 +1,39 @@ +# Copyright 2022 Google LLC + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# https://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Base Configuration.""" + +from lra_benchmarks.image.configs.pathfinder32 import base_pathfinder32_config + + +def get_config(): + """Get the default hyperparameter configuration.""" + config = base_pathfinder32_config.get_config() + config.model_type = "transformer_tlb" + config.learning_rate = 0.0005 + config.model.num_heads = 8 + config.model.emb_dim = 128 + config.model.dropout_rate = 0.1 + config.model.qkv_dim = 64 + config.model.mlp_dim = 128 + config.model.self_to_cross_ratio_input_updater = 2 + config.model.num_cross_layers_input_updater = 1 + config.model.num_cross_layers_state_updater = 1 + config.model.num_state_tokens = 5 + config.model.block_size = 64 + config.model.use_global_pos_encoding = True + return config + + +def get_hyper(hyper): + return hyper.product([]) diff --git a/lra_benchmarks/image/train.py b/lra_benchmarks/image/train.py index 0580bee..f044b8b 100644 --- a/lra_benchmarks/image/train.py +++ b/lra_benchmarks/image/train.py @@ -24,8 +24,8 @@ from absl import flags from absl import logging from flax import jax_utils -from flax import nn from flax import optim +from flax.deprecated import nn from flax.metrics import tensorboard from flax.training import checkpoints from flax.training import common_utils diff --git a/lra_benchmarks/listops/configs/transformer_tlb_base.py b/lra_benchmarks/listops/configs/transformer_tlb_base.py new file mode 100644 index 0000000..f75bb31 --- /dev/null +++ b/lra_benchmarks/listops/configs/transformer_tlb_base.py @@ -0,0 +1,37 @@ +# Copyright 2022 Google LLC + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at + +# https://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Configuration and hyperparameter sweeps.""" + +from lra_benchmarks.listops.configs import base_listops_config +import ml_collections + + +def get_config(): + """Get the default hyperparameter configuration.""" + config = base_listops_config.get_config() + config.model_type = "transformer_tlb" + config.model = ml_collections.ConfigDict() + config.learning_rate = 0.05 / 4. + config.model.self_to_cross_ratio_input_updater = 2 + config.model.num_cross_layers_input_updater = 1 + config.model.num_cross_layers_state_updater = 1 + config.model.num_state_tokens = 100 + config.model.block_size = 100 + config.model.use_global_pos_encoding = False + return config + + +def get_hyper(hyper): + return hyper.product([]) diff --git a/lra_benchmarks/listops/train.py b/lra_benchmarks/listops/train.py index 05e99cd..d03fb06 100644 --- a/lra_benchmarks/listops/train.py +++ b/lra_benchmarks/listops/train.py @@ -22,8 +22,8 @@ from absl import flags from absl import logging from flax import jax_utils -from flax import nn from flax import optim +from flax.deprecated import nn from flax.metrics import tensorboard from flax.training import checkpoints from flax.training import common_utils @@ -188,6 +188,8 @@ def main(argv): 'classifier': True, 'num_classes': 10 }) + if 'model' in config: + model_kwargs.update(config.model) rng = random.PRNGKey(random_seed) rng = jax.random.fold_in(rng, jax.process_index()) diff --git a/lra_benchmarks/matching/configs/transformer_tlb_base.py b/lra_benchmarks/matching/configs/transformer_tlb_base.py new file mode 100644 index 0000000..2f0d75a --- /dev/null +++ b/lra_benchmarks/matching/configs/transformer_tlb_base.py @@ -0,0 +1,39 @@ +# Copyright 2022 Google LLC + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# https://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Configuration and hyperparameter sweeps.""" + +from lra_benchmarks.matching.configs import base_match_config +import ml_collections + + +def get_config(): + """Get the default hyperparameter configuration.""" + config = base_match_config.get_config() + config.model_type = "transformer_tlb_dual" + config.batch_size = 128 + config.learning_rate = 0.05 / 2. 
+ config.eval_frequency = 5000 + config.num_train_steps = 20000 + config.model = ml_collections.ConfigDict() + config.model.self_to_cross_ratio_input_updater = 2 + config.model.num_cross_layers_input_updater = 1 + config.model.num_cross_layers_state_updater = 1 + config.model.num_state_tokens = 10 + config.model.block_size = 10 + config.model.use_global_pos_encoding = False + return config + + +def get_hyper(hyper): + return hyper.product([]) diff --git a/lra_benchmarks/matching/train.py b/lra_benchmarks/matching/train.py index 44ac106..6eacf23 100644 --- a/lra_benchmarks/matching/train.py +++ b/lra_benchmarks/matching/train.py @@ -22,8 +22,8 @@ from absl import flags from absl import logging from flax import jax_utils -from flax import nn from flax import optim +from flax.deprecated import nn from flax.metrics import tensorboard from flax.training import checkpoints from flax.training import common_utils @@ -186,6 +186,8 @@ def main(argv): 'num_classes': 2, 'classifier_pool': config.pooling_mode } + if 'model' in config: + model_kwargs.update(config.model) rng = random.PRNGKey(random_seed) rng = jax.random.fold_in(rng, jax.process_index()) @@ -195,7 +197,7 @@ def main(argv): dropout_rngs = random.split(rng, jax.local_device_count()) model = train_utils.get_model(model_type, create_model, model_kwargs, - init_rng, input_shape) + init_rng, input_shape, input_shape) optimizer = create_optimizer( model, learning_rate, weight_decay=FLAGS.config.weight_decay) diff --git a/lra_benchmarks/models/bigbird/bigbird.py b/lra_benchmarks/models/bigbird/bigbird.py index 8b45118..9732e84 100644 --- a/lra_benchmarks/models/bigbird/bigbird.py +++ b/lra_benchmarks/models/bigbird/bigbird.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Transformer using BigBird (https://arxiv.org/abs/2007.14062).""" -from flax import nn +from flax.deprecated import nn import jax.numpy as jnp from lra_benchmarks.models.bigbird import bigbird_attention from lra_benchmarks.models.layers import common_layers diff --git a/lra_benchmarks/models/bigbird/bigbird_attention.py b/lra_benchmarks/models/bigbird/bigbird_attention.py index 915b6bd..25315a3 100644 --- a/lra_benchmarks/models/bigbird/bigbird_attention.py +++ b/lra_benchmarks/models/bigbird/bigbird_attention.py @@ -13,8 +13,8 @@ # limitations under the License. """Big Bird attention mechanism. See https://arxiv.org/abs/2007.14062.""" from absl import logging -from flax import nn -from flax.nn import attention +from flax.deprecated import nn +from flax.deprecated.nn import attention import jax from jax import lax import jax.numpy as jnp @@ -462,7 +462,7 @@ def apply(self, key_padding_mask: boolean specifying key-value tokens that are pad token. segmentation: segment indices for packed inputs_q data. key_segmentation: segment indices for packed inputs_kv data. - cache: an instance of `flax.nn.attention.Cache` used for efficient + cache: an instance of `flax.deprecated.nn.attention.Cache` used for efficient autoregressive decoding. broadcast_dropout: bool: use a broadcasted dropout along batch dims. dropout_rng: JAX PRNGKey: to be used for dropout diff --git a/lra_benchmarks/models/layers/common_layers.py b/lra_benchmarks/models/layers/common_layers.py index 0705dff..1ca1083 100644 --- a/lra_benchmarks/models/layers/common_layers.py +++ b/lra_benchmarks/models/layers/common_layers.py @@ -11,9 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. -# Lint as: python3 """Common layers used in models.""" -from flax import nn +from flax.deprecated import nn from jax import lax import jax.numpy as jnp import numpy as np diff --git a/lra_benchmarks/models/linear_transformer/linear_attention.py b/lra_benchmarks/models/linear_transformer/linear_attention.py index 3e37c84..d95c97b 100644 --- a/lra_benchmarks/models/linear_transformer/linear_attention.py +++ b/lra_benchmarks/models/linear_transformer/linear_attention.py @@ -13,7 +13,7 @@ # limitations under the License. """Custom Attention modules for Linear Transformer.""" -from flax import nn +from flax.deprecated import nn import jax.numpy as jnp @@ -118,7 +118,7 @@ def apply(self, key_padding_mask: boolean specifying key-value tokens that are pad token. segmentation: segment indices for packed inputs_q data. key_segmentation: segment indices for packed inputs_kv data. - cache: an instance of `flax.nn.attention.Cache` used for efficient + cache: an instance of `flax.deprecated.nn.attention.Cache` used for efficient autoregressive decoding. broadcast_dropout: bool: use a broadcasted dropout along batch dims. dropout_rng: JAX PRNGKey: to be used for dropout diff --git a/lra_benchmarks/models/linear_transformer/linear_transformer.py b/lra_benchmarks/models/linear_transformer/linear_transformer.py index 9fb420e..c637663 100644 --- a/lra_benchmarks/models/linear_transformer/linear_transformer.py +++ b/lra_benchmarks/models/linear_transformer/linear_transformer.py @@ -13,7 +13,7 @@ # limitations under the License. """LinearTransformer model.""" -from flax import nn +from flax.deprecated import nn import jax.numpy as jnp from lra_benchmarks.models.layers import common_layers from lra_benchmarks.models.linear_transformer import linear_attention diff --git a/lra_benchmarks/models/linformer/linformer.py b/lra_benchmarks/models/linformer/linformer.py index 5a2b0ab..6463b3b 100644 --- a/lra_benchmarks/models/linformer/linformer.py +++ b/lra_benchmarks/models/linformer/linformer.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Linformer models.""" -from flax import nn +from flax.deprecated import nn import jax.numpy as jnp from lra_benchmarks.models.layers import common_layers from lra_benchmarks.models.linformer import linformer_attention diff --git a/lra_benchmarks/models/linformer/linformer_attention.py b/lra_benchmarks/models/linformer/linformer_attention.py index 99611b8..f5a5f01 100644 --- a/lra_benchmarks/models/linformer/linformer_attention.py +++ b/lra_benchmarks/models/linformer/linformer_attention.py @@ -13,8 +13,8 @@ # limitations under the License. """Custom Attention core modules for Flax.""" -from flax import nn -from flax.nn.attention import dot_product_attention +from flax.deprecated import nn +from flax.deprecated.nn.attention import dot_product_attention from jax import lax import jax.numpy as jnp @@ -74,7 +74,7 @@ def apply(self, key_padding_mask: boolean specifying key-value tokens that are pad token. segmentation: segment indices for packed inputs_q data. key_segmentation: segment indices for packed inputs_kv data. - cache: an instance of `flax.nn.attention.Cache` used for efficient + cache: an instance of `flax.deprecated.nn.attention.Cache` used for efficient autoregressive decoding. broadcast_dropout: bool: use a broadcasted dropout along batch dims. 
dropout_rng: JAX PRNGKey: to be used for dropout diff --git a/lra_benchmarks/models/local/local.py b/lra_benchmarks/models/local/local.py index b2e3bc0..b52989d 100644 --- a/lra_benchmarks/models/local/local.py +++ b/lra_benchmarks/models/local/local.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Local Attention Transformer models.""" -from flax import nn +from flax.deprecated import nn import jax.numpy as jnp from lra_benchmarks.models.layers import common_layers from lra_benchmarks.models.local import local_attention diff --git a/lra_benchmarks/models/local/local_attention.py b/lra_benchmarks/models/local/local_attention.py index e0dc4a4..a988670 100644 --- a/lra_benchmarks/models/local/local_attention.py +++ b/lra_benchmarks/models/local/local_attention.py @@ -16,12 +16,12 @@ from collections.abc import Iterable # pylint: disable=g-importing-member from absl import logging -from flax import nn -from flax.nn.attention import _CacheEntry -from flax.nn.attention import _make_causal_mask -from flax.nn.attention import Cache -from flax.nn.attention import make_padding_mask -from flax.nn.stochastic import make_rng +from flax.deprecated import nn +from flax.deprecated.nn.attention import _CacheEntry +from flax.deprecated.nn.attention import _make_causal_mask +from flax.deprecated.nn.attention import Cache +from flax.deprecated.nn.attention import make_padding_mask +from flax.deprecated.nn.stochastic import make_rng import jax from jax import lax from jax import random @@ -42,7 +42,7 @@ def local_dot_product_attention(query, precision=None): """Computes dot-product attention given query, key, and value. - Note: This is equivalent to the dot product attention in flax.nn. + Note: This is equivalent to the dot product attention in flax.deprecated.nn. However, we do extra broadcasting of the bias in this function. I'm leaving this here in case we need to modify something later. @@ -206,7 +206,7 @@ def apply(self, key_padding_mask: boolean specifying key-value tokens that are pad token. segmentation: segment indices for packed inputs_q data. key_segmentation: segment indices for packed inputs_kv data. - cache: an instance of `flax.nn.attention.Cache` used for efficient + cache: an instance of `flax.deprecated.nn.attention.Cache` used for efficient autoregressive decoding. broadcast_dropout: bool: use a broadcasted dropout along batch dims. dropout_rng: JAX PRNGKey: to be used for dropout diff --git a/lra_benchmarks/models/longformer/longformer.py b/lra_benchmarks/models/longformer/longformer.py index 0957b6f..c1d6590 100644 --- a/lra_benchmarks/models/longformer/longformer.py +++ b/lra_benchmarks/models/longformer/longformer.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Longformer modules.""" -from flax import nn +from flax.deprecated import nn import jax.numpy as jnp from lra_benchmarks.models.layers import common_layers from lra_benchmarks.models.longformer import longformer_attention diff --git a/lra_benchmarks/models/longformer/longformer_attention.py b/lra_benchmarks/models/longformer/longformer_attention.py index 96bfbba..31f35f8 100644 --- a/lra_benchmarks/models/longformer/longformer_attention.py +++ b/lra_benchmarks/models/longformer/longformer_attention.py @@ -19,7 +19,7 @@ supported, however. 
""" -from flax import nn +from flax.deprecated import nn from jax import lax import jax.numpy as jnp import numpy as np diff --git a/lra_benchmarks/models/performer/performer.py b/lra_benchmarks/models/performer/performer.py index 70c90df..a9077d5 100644 --- a/lra_benchmarks/models/performer/performer.py +++ b/lra_benchmarks/models/performer/performer.py @@ -15,7 +15,7 @@ import functools from absl import logging -from flax import nn +from flax.deprecated import nn import jax.numpy as jnp from lra_benchmarks.models.layers import common_layers from lra_benchmarks.models.performer import performer_attention diff --git a/lra_benchmarks/models/reformer/reformer.py b/lra_benchmarks/models/reformer/reformer.py index 6535965..f8f582c 100644 --- a/lra_benchmarks/models/reformer/reformer.py +++ b/lra_benchmarks/models/reformer/reformer.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Reformer language models.""" -from flax import nn +from flax.deprecated import nn import jax.numpy as jnp from lra_benchmarks.models.layers import common_layers from lra_benchmarks.models.reformer import reformer_attention diff --git a/lra_benchmarks/models/reformer/reformer_attention.py b/lra_benchmarks/models/reformer/reformer_attention.py index 22ac618..fac38dc 100644 --- a/lra_benchmarks/models/reformer/reformer_attention.py +++ b/lra_benchmarks/models/reformer/reformer_attention.py @@ -14,7 +14,7 @@ """Attention modules for Reformer model.""" from absl import logging -from flax import nn +from flax.deprecated import nn import jax import jax.numpy as jnp from jax.scipy.special import logsumexp @@ -270,7 +270,7 @@ def apply(self, key_padding_mask: boolean specifying key-value tokens that are pad token. segmentation: segment indices for packed inputs_q data. key_segmentation: segment indices for packed inputs_kv data. - cache: an instance of `flax.nn.attention.Cache` used for efficient + cache: an instance of `flax.deprecated.nn.attention.Cache` used for efficient autoregressive decoding. broadcast_dropout: bool: use a broadcasted dropout along batch dims. dropout_rng: JAX PRNGKey: to be used for dropout diff --git a/lra_benchmarks/models/sinkhorn_transformer/sinkhorn_attention.py b/lra_benchmarks/models/sinkhorn_transformer/sinkhorn_attention.py index b5311c5..6606c44 100644 --- a/lra_benchmarks/models/sinkhorn_transformer/sinkhorn_attention.py +++ b/lra_benchmarks/models/sinkhorn_transformer/sinkhorn_attention.py @@ -15,12 +15,12 @@ from collections.abc import Iterable # pylint: disable=g-importing-member -from flax import nn -from flax.nn.attention import _CacheEntry -from flax.nn.attention import _make_causal_mask -from flax.nn.attention import Cache -from flax.nn.attention import make_padding_mask -from flax.nn.stochastic import make_rng +from flax.deprecated import nn +from flax.deprecated.nn.attention import _CacheEntry +from flax.deprecated.nn.attention import _make_causal_mask +from flax.deprecated.nn.attention import Cache +from flax.deprecated.nn.attention import make_padding_mask +from flax.deprecated.nn.stochastic import make_rng import jax from jax import lax from jax import random @@ -77,7 +77,7 @@ def local_dot_product_attention(query, precision=None): """Computes dot-product attention given query, key, and value. - Note: This is equivalent to the dot product attention in flax.nn. + Note: This is equivalent to the dot product attention in flax.deprecated.nn. However, we do extra broadcasting of the bias in this function. 
I'm leaving this here incase we need to modify something later. @@ -243,7 +243,7 @@ def apply(self, key_padding_mask: boolean specifying key-value tokens that are pad token. segmentation: segment indices for packed inputs_q data. key_segmentation: segment indices for packed inputs_kv data. - cache: an instance of `flax.nn.attention.Cache` used for efficient + cache: an instance of `flax.deprecated.nn.attention.Cache` used for efficient autoregressive decoding. broadcast_dropout: bool: use a broadcasted dropout along batch dims. dropout_rng: JAX PRNGKey: to be used for dropout diff --git a/lra_benchmarks/models/sinkhorn_transformer/sinkhorn_transformer.py b/lra_benchmarks/models/sinkhorn_transformer/sinkhorn_transformer.py index 3a9df47..9904a32 100644 --- a/lra_benchmarks/models/sinkhorn_transformer/sinkhorn_transformer.py +++ b/lra_benchmarks/models/sinkhorn_transformer/sinkhorn_transformer.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Sinkhorn Attention Transformer models.""" -from flax import nn +from flax.deprecated import nn import jax.numpy as jnp from lra_benchmarks.models.layers import common_layers from lra_benchmarks.models.sinkhorn_transformer import sinkhorn_attention diff --git a/lra_benchmarks/models/sparse_transformer/sparse_attention.py b/lra_benchmarks/models/sparse_transformer/sparse_attention.py index c08acb3..6f52646 100644 --- a/lra_benchmarks/models/sparse_transformer/sparse_attention.py +++ b/lra_benchmarks/models/sparse_transformer/sparse_attention.py @@ -21,7 +21,7 @@ from typing import Iterable import attr -from flax import nn +from flax.deprecated import nn from jax import lax import jax.numpy as jnp import numpy as np diff --git a/lra_benchmarks/models/sparse_transformer/sparse_transformer.py b/lra_benchmarks/models/sparse_transformer/sparse_transformer.py index 9471b69..04fe004 100644 --- a/lra_benchmarks/models/sparse_transformer/sparse_transformer.py +++ b/lra_benchmarks/models/sparse_transformer/sparse_transformer.py @@ -13,7 +13,7 @@ # limitations under the License. """Sparse Transformer modules.""" from absl import logging -from flax import nn +from flax.deprecated import nn import jax.numpy as jnp from lra_benchmarks.models.layers import common_layers from lra_benchmarks.models.sparse_transformer import sparse_attention diff --git a/lra_benchmarks/models/synthesizer/synthesizer.py b/lra_benchmarks/models/synthesizer/synthesizer.py index ec80760..f6c9daf 100644 --- a/lra_benchmarks/models/synthesizer/synthesizer.py +++ b/lra_benchmarks/models/synthesizer/synthesizer.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Synthesizer models.""" -from flax import nn +from flax.deprecated import nn import jax.numpy as jnp from lra_benchmarks.models.layers import common_layers from lra_benchmarks.models.synthesizer import synthesizer_attention diff --git a/lra_benchmarks/models/synthesizer/synthesizer_attention.py b/lra_benchmarks/models/synthesizer/synthesizer_attention.py index e95251b..9002860 100644 --- a/lra_benchmarks/models/synthesizer/synthesizer_attention.py +++ b/lra_benchmarks/models/synthesizer/synthesizer_attention.py @@ -17,12 +17,12 @@ from absl import logging -from flax import nn -from flax.nn.attention import _CacheEntry -from flax.nn.attention import _make_causal_mask -from flax.nn.attention import Cache -from flax.nn.attention import make_padding_mask -from flax.nn.stochastic import make_rng +from flax.deprecated import nn +from flax.deprecated.nn.attention import _CacheEntry +from flax.deprecated.nn.attention import _make_causal_mask +from flax.deprecated.nn.attention import Cache +from flax.deprecated.nn.attention import make_padding_mask +from flax.deprecated.nn.stochastic import make_rng import jax from jax import lax from jax import random @@ -225,7 +225,7 @@ def apply(self, key_padding_mask: boolean specifying key-value tokens that are pad token. segmentation: segment indices for packed inputs_q data. key_segmentation: segment indices for packed inputs_kv data. - cache: an instance of `flax.nn.attention.Cache` used for efficient + cache: an instance of `flax.deprecated.nn.attention.Cache` used for efficient autoregressive decoding. broadcast_dropout: bool: use a broadcasted dropout along batch dims. dropout_rng: JAX PRNGKey: to be used for dropout diff --git a/lra_benchmarks/models/transformer/transformer.py b/lra_benchmarks/models/transformer/transformer.py index afffaa8..0cea7aa 100644 --- a/lra_benchmarks/models/transformer/transformer.py +++ b/lra_benchmarks/models/transformer/transformer.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Transformer model.""" -from flax import nn +from flax.deprecated import nn import jax.numpy as jnp from lra_benchmarks.models.layers import common_layers diff --git a/lra_benchmarks/models/transformer_tlb/__init__.py b/lra_benchmarks/models/transformer_tlb/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/lra_benchmarks/models/transformer_tlb/transformer_tlb.py b/lra_benchmarks/models/transformer_tlb/transformer_tlb.py new file mode 100644 index 0000000..506de62 --- /dev/null +++ b/lra_benchmarks/models/transformer_tlb/transformer_tlb.py @@ -0,0 +1,416 @@ +# Copyright 2022 Google LLC + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# https://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Transformer-based stateful lra models.""" +from flax.deprecated import nn +import jax +import jax.numpy as jnp +from lra_benchmarks.models.layers import common_layers +from lra_benchmarks.models.transformer import transformer + + +class CrossTransformerBlock(nn.Module): + """Cross Transformer layer.""" + + def apply(self, + inputs_q, + inputs_kv, + qkv_dim, + mlp_dim, + num_heads, + dtype=jnp.float32, + inputs_segmentation=None, + causal_mask=False, + padding_mask=None, + key_padding_mask=None, + dropout_rate=0.1, + attention_dropout_rate=0.1, + deterministic=False, + cache=None, + residual=True): + """Applies CrossTransformerBlock module. + + Args: + inputs_q: input query + inputs_kv: input key-value + qkv_dim: dimension of the query/key/value + mlp_dim: dimension of the mlp on top of attention block + num_heads: number of heads + dtype: the dtype of the computation (default: float32). + inputs_segmentation: input segmentation info for packed examples. + causal_mask: bool, mask future or not + padding_mask: bool, mask padding tokens + key_padding_mask: bool, mask padding tokens + dropout_rate: dropout rate + attention_dropout_rate: dropout rate for attention weights + deterministic: bool, deterministic or not (to apply dropout) + cache: flax autoregressive cache for fast decoding. + residual: Boolean, to use residual connectors or not. + + Returns: + output after transformer block. + + """ + + # Attention block. + assert inputs_q.ndim == 3 + x = nn.LayerNorm(inputs_q) + s = nn.LayerNorm(inputs_kv) + x = nn.MultiHeadDotProductAttention( + x, s, + num_heads=num_heads, + dtype=dtype, + qkv_features=qkv_dim, + attention_axis=(1,), + causal_mask=causal_mask, + segmentation=inputs_segmentation, + padding_mask=padding_mask, + key_padding_mask=key_padding_mask, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.normal(stddev=1e-6), + bias=False, + broadcast_dropout=False, + dropout_rate=attention_dropout_rate, + deterministic=deterministic, + cache=cache) + x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic) + x = x + inputs_q + + # MLP block. 
+ y = nn.LayerNorm(x) + y = common_layers.MlpBlock( + y, + mlp_dim=mlp_dim, + dtype=dtype, + dropout_rate=dropout_rate, + deterministic=deterministic) + + if residual: + output = x + y + else: + output = x + if padding_mask is not None: + corner_case = (jnp.sum(padding_mask, + axis=1) == 0)[..., None] + output = jnp.where(corner_case, inputs_q, output) + elif key_padding_mask is not None: + corner_case = (jnp.sum(key_padding_mask, + axis=1) == 0)[..., None] + output = jnp.where(corner_case, inputs_q, output) + return output + + +class StatefulTransformerEncoder(nn.Module): + """Stateful Transformer Model Encoder (https://arxiv.org/abs/2205.14794).""" + + def apply(self, + inputs, + vocab_size, + inputs_positions=None, + inputs_segmentation=None, + shared_embedding=None, + use_bfloat16=False, + emb_dim=512, + num_heads=8, + dtype=jnp.float32, + num_layers=6, + qkv_dim=512, + mlp_dim=2048, + max_len=512, + train=True, + dropout_rate=0.1, + attention_dropout_rate=0.1, + learn_pos_emb=False, + classifier=False, + classifier_pool='CLS', + num_classes=10, + tied_weights=False, + meta_network=False, + meta_layers=1, + meta_pool='last', + use_residual=True, + meta_partition=3, + meta_layer_output=False, + self_to_cross_ratio_input_updater=2, + num_cross_layers_input_updater=1, + num_cross_layers_state_updater=1, + num_state_tokens=20, + block_size=20, + use_global_pos_encoding=False): + """Applies Transformer model on the inputs. + + Args: + inputs: input data + vocab_size: size of the vocabulary + inputs_positions: input subsequence positions for packed examples. + inputs_segmentation: input segmentation info for packed examples. + shared_embedding: a shared embedding layer to use. + use_bfloat16: bool: whether use bfloat16. + emb_dim: dimension of embedding + num_heads: number of heads + dtype: the dtype of the computation (default: float32) + num_layers: number of layers + qkv_dim: dimension of the query/key/value + mlp_dim: dimension of the mlp on top of attention block + max_len: maximum length. + train: if it is training, + dropout_rate: dropout rate + attention_dropout_rate: dropout rate for attention weights + learn_pos_emb: boolean, if learn the positional embedding or use the + sinusoidal positional embedding. + classifier: boolean, for classification mode (output N-class logits) + classifier_pool: str, supports "MEAN", "MAX" pooling. + num_classes: int, number of classification classes. + tied_weights: bool, to tie weights or not. + meta_network: boolean, experimental extreme self-attention. + meta_layers: int, number of meta_layers + meta_pool: str, the type of meta pooling. + use_residual: boolean, turn off transformer residuals. + meta_partition: int. + meta_layer_output: boolean. + self_to_cross_ratio_input_updater: number of self-attention layers before + each cross attention layer in the input-update direction + num_cross_layers_input_updater: number of cross-attention layers + in the input update direction + num_cross_layers_state_updater: number of cross-attention layers + in the state update direction + num_state_tokens: number of state tokens + block_size: chunk size of inputs + use_global_pos_encoding: Whether the input position embedding is global + or local + + Returns: + output of a transformer encoder or logits if classifier_mode is true. 
+ """ + assert inputs.ndim == 2 # (batch, len) + + # Padding Masks + src_padding_mask = (inputs > 0)[..., None] + + # Input Embedding + if shared_embedding is None: + input_embed = nn.Embed.partial( + num_embeddings=vocab_size, + features=emb_dim, + embedding_init=nn.initializers.normal(stddev=1.0)) + else: + input_embed = shared_embedding + x = inputs.astype('int32') + x = input_embed(x) + + # Input Positional Encoding + pe_init = nn.initializers.normal(stddev=0.02) if learn_pos_emb else None + if use_global_pos_encoding: + x = common_layers.AddPositionEmbs( + x, + inputs_positions=inputs_positions, + posemb_init=pe_init, + max_len=max_len, + name='global_posembed_input') + pe = common_layers.AddPositionEmbs( + jnp.zeros((x.shape[0], block_size, x.shape[2]), + dtype=x.dtype), + inputs_positions=inputs_positions, + posemb_init=pe_init, + max_len=max_len, + name='posembed_input') + + if use_bfloat16: + x = x.astype(jnp.bfloat16) + dtype = jnp.bfloat16 + else: + dtype = jnp.float32 + + # Create layers + horizontal_blocks = [] + for block_idx in range(num_cross_layers_input_updater): + horizontal_layers = [] + for layer_idx in range(self_to_cross_ratio_input_updater): + horizontal_layers.append( + transformer.TransformerBlock.shared( + qkv_dim=qkv_dim, + mlp_dim=mlp_dim, + num_heads=num_heads, + dtype=dtype, + inputs_segmentation=inputs_segmentation, + dropout_rate=dropout_rate, + attention_dropout_rate=attention_dropout_rate, + deterministic=not train, + name=f'horizontal_block_{block_idx}_self_{layer_idx}')) + horizontal_layers.append( + CrossTransformerBlock.shared( + qkv_dim=qkv_dim, + mlp_dim=mlp_dim, + num_heads=num_heads, + dtype=dtype, + inputs_segmentation=inputs_segmentation, + dropout_rate=dropout_rate, + attention_dropout_rate=attention_dropout_rate, + deterministic=not train, + name=f'horizontal_block_{block_idx}_cross', + residual=use_residual)) + horizontal_blocks.append(horizontal_layers) + vertical_layers = [] + for layer_idx in range(num_cross_layers_state_updater): + vertical_layers.append( + CrossTransformerBlock.shared( + qkv_dim=qkv_dim, + mlp_dim=mlp_dim, + num_heads=num_heads, + dtype=dtype, + inputs_segmentation=inputs_segmentation, + dropout_rate=dropout_rate, + attention_dropout_rate=attention_dropout_rate, + deterministic=not train, + name=f'vertical_block_{block_idx}_cross', + residual=use_residual)) + + num_tokens = x.shape[1] + num_chunks = num_tokens // block_size + init_state = jnp.zeros((x.shape[0], num_state_tokens, x.shape[2])) + x_with_pad = jnp.concatenate([x, src_padding_mask], axis=2) + # Split inputs into chunks of block_size. 
+ x_with_pad = jnp.stack( + jnp.split(x_with_pad, num_chunks, axis=1), axis=0) + + # State positional encoding + state_pos_embed = common_layers.AddPositionEmbs( + init_state, + inputs_positions=None, + posemb_init=None, + max_len=num_state_tokens, + name='posembed_state') + + # Processing function for each chunk + def scan_inner(cur_state, cur_x_with_pad): + padding_mask_cur = cur_x_with_pad[:, :, -1][:, :, None] + x_cur = cur_x_with_pad[:, :, :-1] + if not use_global_pos_encoding: + x_cur = x_cur + pe + x_cur = nn.dropout(x_cur, rate=dropout_rate, deterministic=not train) + cur_state = cur_state + state_pos_embed + for block_idx in range(num_cross_layers_input_updater): + for layer_idx in range(self_to_cross_ratio_input_updater): + x_cur = horizontal_blocks[block_idx][layer_idx]( + x_cur, padding_mask=padding_mask_cur) + x_cur = horizontal_blocks[block_idx][-1]( + x_cur, cur_state, padding_mask=padding_mask_cur) + for layer_idx in range(num_cross_layers_state_updater): + cur_state = vertical_layers[layer_idx]( + cur_state, x_cur, key_padding_mask=padding_mask_cur) + return cur_state, None + + # Scan + cur_state, _ = jax.lax.scan(scan_inner, init_state, x_with_pad, unroll=40) + + assert cur_state.shape == init_state.shape + + encoded = nn.LayerNorm(cur_state, dtype=dtype, name='encoder_norm') + + if classifier: + encoded = common_layers.classifier_head( + encoded, num_classes, mlp_dim, pooling_mode='MEAN') + return encoded + + +class StatefulTransformerDualEncoder(nn.Module): + """Stateful Transformer Model Encoder (https://arxiv.org/abs/2205.14794).""" + + def apply(self, + inputs1, + inputs2, + vocab_size, + inputs1_positions=None, + inputs2_positions=None, + inputs1_segmentation=None, + inputs2_segmentation=None, + shared_embedding=None, + use_bfloat16=False, + emb_dim=512, + num_heads=8, + dtype=jnp.float32, + num_layers=6, + qkv_dim=512, + mlp_dim=2048, + max_len=512, + train=True, + dropout_rate=0.1, + attention_dropout_rate=0.1, + learn_pos_emb=False, + classifier=False, + classifier_pool='CLS', + num_classes=10, + tied_weights=False, + meta_network=False, + meta_layers=1, + meta_pool='last', + use_residual=True, + meta_partition=3, + meta_layer_output=False, + self_to_cross_ratio_input_updater=2, + num_cross_layers_input_updater=1, + num_cross_layers_state_updater=1, + num_state_tokens=20, + block_size=20, + use_global_pos_encoding=False, + interaction=None): + """Applies Transformer model on the inputs.""" + encoder = StatefulTransformerEncoder.shared( + vocab_size=vocab_size, + shared_embedding=shared_embedding, + use_bfloat16=use_bfloat16, + emb_dim=emb_dim, + num_heads=num_heads, + dtype=dtype, + num_layers=num_layers, + qkv_dim=qkv_dim, + mlp_dim=mlp_dim, + max_len=max_len, + train=train, + dropout_rate=dropout_rate, + attention_dropout_rate=attention_dropout_rate, + learn_pos_emb=learn_pos_emb, + classifier=False, + classifier_pool=classifier_pool, + num_classes=num_classes, + tied_weights=tied_weights, + meta_network=meta_network, + meta_layers=meta_layers, + meta_pool=meta_pool, + use_residual=use_residual, + meta_partition=meta_partition, + meta_layer_output=meta_layer_output, + self_to_cross_ratio_input_updater=self_to_cross_ratio_input_updater, + num_cross_layers_input_updater=num_cross_layers_input_updater, + num_cross_layers_state_updater=num_cross_layers_state_updater, + num_state_tokens=num_state_tokens, + block_size=block_size, + use_global_pos_encoding=use_global_pos_encoding, + name='stateful_encoder') + inputs1_encoded = encoder( + inputs=inputs1, + 
inputs_positions=inputs1_positions, + inputs_segmentation=inputs1_segmentation) + inputs2_encoded = encoder( + inputs=inputs2, + inputs_positions=inputs2_positions, + inputs_segmentation=inputs2_segmentation) + + encoded = common_layers.classifier_head_dual( + inputs1_encoded, + inputs2_encoded, + num_classes, + mlp_dim, + pooling_mode='MEAN', + interaction=interaction) + return encoded diff --git a/lra_benchmarks/text_classification/configs/transformer_tlb_base.py b/lra_benchmarks/text_classification/configs/transformer_tlb_base.py new file mode 100644 index 0000000..7dd75ca --- /dev/null +++ b/lra_benchmarks/text_classification/configs/transformer_tlb_base.py @@ -0,0 +1,39 @@ +# Copyright 2022 Google LLC + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# https://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Configuration and hyperparameter sweeps.""" + +from lra_benchmarks.text_classification.configs import base_tc_config +import ml_collections + + +def get_config(): + """Get the default hyperparameter configuration.""" + config = base_tc_config.get_config() + config.model_type = "transformer_tlb" + config.learning_rate = 0.05/2. + config.model = ml_collections.ConfigDict() + config.model.self_to_cross_ratio_input_updater = 2 + config.model.num_cross_layers_input_updater = 1 + config.model.num_cross_layers_state_updater = 1 + config.model.num_state_tokens = 10 + config.model.block_size = 10 + config.model.use_global_pos_encoding = False + + config.max_length = 4000 + return config + + +def get_hyper(hyper): + return hyper.product([]) diff --git a/lra_benchmarks/text_classification/train.py b/lra_benchmarks/text_classification/train.py index 1289750..b3b44f2 100644 --- a/lra_benchmarks/text_classification/train.py +++ b/lra_benchmarks/text_classification/train.py @@ -22,8 +22,8 @@ from absl import flags from absl import logging from flax import jax_utils -from flax import nn from flax import optim +from flax.deprecated import nn from flax.metrics import tensorboard from flax.training import checkpoints from flax.training import common_utils @@ -184,6 +184,8 @@ def main(argv): 'num_classes': CLASS_MAP[FLAGS.task_name], 'classifier_pool': config.classifier_pool } + if 'model' in config: + model_kwargs.update(config.model) rng = random.PRNGKey(random_seed) rng = jax.random.fold_in(rng, jax.process_index()) diff --git a/lra_benchmarks/utils/train_utils.py b/lra_benchmarks/utils/train_utils.py index 75e04dc..f7851e1 100644 --- a/lra_benchmarks/utils/train_utils.py +++ b/lra_benchmarks/utils/train_utils.py @@ -13,7 +13,7 @@ # limitations under the License. 
"""This contains utility functions for model training and evaluation.""" -from flax import nn +from flax.deprecated import nn from flax.training import common_utils import jax.numpy as jnp from lra_benchmarks.models.bigbird import bigbird @@ -28,6 +28,7 @@ from lra_benchmarks.models.sparse_transformer import sparse_transformer from lra_benchmarks.models.synthesizer import synthesizer from lra_benchmarks.models.transformer import transformer +from lra_benchmarks.models.transformer_tlb import transformer_tlb import numpy as onp @@ -80,6 +81,12 @@ def get_model(model_type, create_model_fn, model_kwargs, *create_model_args): elif model_type == 'longformer': return create_model_fn(longformer.LongformerEncoder, model_kwargs, *create_model_args) + elif model_type == 'transformer_tlb': + return create_model_fn(transformer_tlb.StatefulTransformerEncoder, + model_kwargs, *create_model_args) + elif model_type == 'transformer_tlb_dual': + return create_model_fn(transformer_tlb.StatefulTransformerDualEncoder, + model_kwargs, *create_model_args) else: raise ValueError('Model type not supported') diff --git a/requirements.txt b/requirements.txt index b52c9c1..08997ce 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,3 @@ ml-collections>=0.1.0 tensorboard>=2.3.0 tensorflow>=2.3.1 tensorflow-datasets>=4.0.1 -tensorflow_text -gin -gin-config>=0.1.3