From b4f9fedd1a61bf658062747bb2a8b01c8dd518f0 Mon Sep 17 00:00:00 2001 From: Nitesh Gundavarapu Date: Thu, 21 Apr 2022 06:39:23 +0000 Subject: [PATCH 1/5] Internal change PiperOrigin-RevId: 443291669 --- README.md | 4 ++-- lra_benchmarks/image/train.py | 2 +- lra_benchmarks/listops/train.py | 2 +- lra_benchmarks/matching/train.py | 2 +- lra_benchmarks/models/bigbird/bigbird.py | 2 +- .../models/bigbird/bigbird_attention.py | 6 +++--- lra_benchmarks/models/layers/common_layers.py | 3 +-- .../linear_transformer/linear_attention.py | 4 ++-- .../linear_transformer/linear_transformer.py | 2 +- lra_benchmarks/models/linformer/linformer.py | 2 +- .../models/linformer/linformer_attention.py | 6 +++--- lra_benchmarks/models/local/local.py | 2 +- lra_benchmarks/models/local/local_attention.py | 16 ++++++++-------- lra_benchmarks/models/longformer/longformer.py | 2 +- .../models/longformer/longformer_attention.py | 2 +- lra_benchmarks/models/performer/performer.py | 2 +- lra_benchmarks/models/reformer/reformer.py | 2 +- .../models/reformer/reformer_attention.py | 4 ++-- .../sinkhorn_transformer/sinkhorn_attention.py | 16 ++++++++-------- .../sinkhorn_transformer/sinkhorn_transformer.py | 2 +- .../sparse_transformer/sparse_attention.py | 2 +- .../sparse_transformer/sparse_transformer.py | 2 +- lra_benchmarks/models/synthesizer/synthesizer.py | 2 +- .../models/synthesizer/synthesizer_attention.py | 14 +++++++------- lra_benchmarks/models/transformer/transformer.py | 2 +- lra_benchmarks/text_classification/train.py | 2 +- lra_benchmarks/utils/train_utils.py | 2 +- requirements.txt | 3 --- 28 files changed, 54 insertions(+), 58 deletions(-) diff --git a/README.md b/README.md index 18b05df..b76dff7 100644 --- a/README.md +++ b/README.md @@ -138,7 +138,7 @@ To run a task, run the train.py file in the corresponding task directory. (please see how to obtain the data for certain tasks if applicable). 
``` -PYTHONPATH="$(pwd)":"$PYTHONPATH" python lra_benchmarks/listops/train.py \ +PYTHONPATH="$(pwd)":"$PYTHON_PATH" python lra_benchmarks/listops/train.py \ --config=lra_benchmarks/listops/configs/transformer_base.py \ --model_dir=/tmp/listops \ --task_name=basic \ @@ -164,7 +164,7 @@ If you would like to go to longer/shorter sequence lengths, we also support generating your own split, run the following comment: ``` -PYTHONPATH="$(pwd)":"$PYTHONPATH" python lra_benchmarks/data/listops.py -- \ +PYTHONPATH="$(pwd)":"$PYTHON_PATH" python lra_benchmarks/data/listops.py -- \ --output_dir=$HOME/lra_data/listops/ ``` diff --git a/lra_benchmarks/image/train.py b/lra_benchmarks/image/train.py index 0580bee..f044b8b 100644 --- a/lra_benchmarks/image/train.py +++ b/lra_benchmarks/image/train.py @@ -24,8 +24,8 @@ from absl import flags from absl import logging from flax import jax_utils -from flax import nn from flax import optim +from flax.deprecated import nn from flax.metrics import tensorboard from flax.training import checkpoints from flax.training import common_utils diff --git a/lra_benchmarks/listops/train.py b/lra_benchmarks/listops/train.py index 05e99cd..80cd5fc 100644 --- a/lra_benchmarks/listops/train.py +++ b/lra_benchmarks/listops/train.py @@ -22,8 +22,8 @@ from absl import flags from absl import logging from flax import jax_utils -from flax import nn from flax import optim +from flax.deprecated import nn from flax.metrics import tensorboard from flax.training import checkpoints from flax.training import common_utils diff --git a/lra_benchmarks/matching/train.py b/lra_benchmarks/matching/train.py index 44ac106..d56b752 100644 --- a/lra_benchmarks/matching/train.py +++ b/lra_benchmarks/matching/train.py @@ -22,8 +22,8 @@ from absl import flags from absl import logging from flax import jax_utils -from flax import nn from flax import optim +from flax.deprecated import nn from flax.metrics import tensorboard from flax.training import checkpoints from flax.training import common_utils diff --git a/lra_benchmarks/models/bigbird/bigbird.py b/lra_benchmarks/models/bigbird/bigbird.py index 8b45118..9732e84 100644 --- a/lra_benchmarks/models/bigbird/bigbird.py +++ b/lra_benchmarks/models/bigbird/bigbird.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Transformer using BigBird (https://arxiv.org/abs/2007.14062).""" -from flax import nn +from flax.deprecated import nn import jax.numpy as jnp from lra_benchmarks.models.bigbird import bigbird_attention from lra_benchmarks.models.layers import common_layers diff --git a/lra_benchmarks/models/bigbird/bigbird_attention.py b/lra_benchmarks/models/bigbird/bigbird_attention.py index 915b6bd..25315a3 100644 --- a/lra_benchmarks/models/bigbird/bigbird_attention.py +++ b/lra_benchmarks/models/bigbird/bigbird_attention.py @@ -13,8 +13,8 @@ # limitations under the License. """Big Bird attention mechanism. See https://arxiv.org/abs/2007.14062.""" from absl import logging -from flax import nn -from flax.nn import attention +from flax.deprecated import nn +from flax.deprecated.nn import attention import jax from jax import lax import jax.numpy as jnp @@ -462,7 +462,7 @@ def apply(self, key_padding_mask: boolean specifying key-value tokens that are pad token. segmentation: segment indices for packed inputs_q data. key_segmentation: segment indices for packed inputs_kv data. 
- cache: an instance of `flax.nn.attention.Cache` used for efficient + cache: an instance of `flax.deprecated.nn.attention.Cache` used for efficient autoregressive decoding. broadcast_dropout: bool: use a broadcasted dropout along batch dims. dropout_rng: JAX PRNGKey: to be used for dropout diff --git a/lra_benchmarks/models/layers/common_layers.py b/lra_benchmarks/models/layers/common_layers.py index 0705dff..1ca1083 100644 --- a/lra_benchmarks/models/layers/common_layers.py +++ b/lra_benchmarks/models/layers/common_layers.py @@ -11,9 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# Lint as: python3 """Common layers used in models.""" -from flax import nn +from flax.deprecated import nn from jax import lax import jax.numpy as jnp import numpy as np diff --git a/lra_benchmarks/models/linear_transformer/linear_attention.py b/lra_benchmarks/models/linear_transformer/linear_attention.py index 3e37c84..d95c97b 100644 --- a/lra_benchmarks/models/linear_transformer/linear_attention.py +++ b/lra_benchmarks/models/linear_transformer/linear_attention.py @@ -13,7 +13,7 @@ # limitations under the License. """Custom Attention modules for Linear Transformer.""" -from flax import nn +from flax.deprecated import nn import jax.numpy as jnp @@ -118,7 +118,7 @@ def apply(self, key_padding_mask: boolean specifying key-value tokens that are pad token. segmentation: segment indices for packed inputs_q data. key_segmentation: segment indices for packed inputs_kv data. - cache: an instance of `flax.nn.attention.Cache` used for efficient + cache: an instance of `flax.deprecated.nn.attention.Cache` used for efficient autoregressive decoding. broadcast_dropout: bool: use a broadcasted dropout along batch dims. dropout_rng: JAX PRNGKey: to be used for dropout diff --git a/lra_benchmarks/models/linear_transformer/linear_transformer.py b/lra_benchmarks/models/linear_transformer/linear_transformer.py index 9fb420e..c637663 100644 --- a/lra_benchmarks/models/linear_transformer/linear_transformer.py +++ b/lra_benchmarks/models/linear_transformer/linear_transformer.py @@ -13,7 +13,7 @@ # limitations under the License. """LinearTransformer model.""" -from flax import nn +from flax.deprecated import nn import jax.numpy as jnp from lra_benchmarks.models.layers import common_layers from lra_benchmarks.models.linear_transformer import linear_attention diff --git a/lra_benchmarks/models/linformer/linformer.py b/lra_benchmarks/models/linformer/linformer.py index 5a2b0ab..6463b3b 100644 --- a/lra_benchmarks/models/linformer/linformer.py +++ b/lra_benchmarks/models/linformer/linformer.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Linformer models.""" -from flax import nn +from flax.deprecated import nn import jax.numpy as jnp from lra_benchmarks.models.layers import common_layers from lra_benchmarks.models.linformer import linformer_attention diff --git a/lra_benchmarks/models/linformer/linformer_attention.py b/lra_benchmarks/models/linformer/linformer_attention.py index 99611b8..f5a5f01 100644 --- a/lra_benchmarks/models/linformer/linformer_attention.py +++ b/lra_benchmarks/models/linformer/linformer_attention.py @@ -13,8 +13,8 @@ # limitations under the License. 
"""Custom Attention core modules for Flax.""" -from flax import nn -from flax.nn.attention import dot_product_attention +from flax.deprecated import nn +from flax.deprecated.nn.attention import dot_product_attention from jax import lax import jax.numpy as jnp @@ -74,7 +74,7 @@ def apply(self, key_padding_mask: boolean specifying key-value tokens that are pad token. segmentation: segment indices for packed inputs_q data. key_segmentation: segment indices for packed inputs_kv data. - cache: an instance of `flax.nn.attention.Cache` used for efficient + cache: an instance of `flax.deprecated.nn.attention.Cache` used for efficient autoregressive decoding. broadcast_dropout: bool: use a broadcasted dropout along batch dims. dropout_rng: JAX PRNGKey: to be used for dropout diff --git a/lra_benchmarks/models/local/local.py b/lra_benchmarks/models/local/local.py index b2e3bc0..b52989d 100644 --- a/lra_benchmarks/models/local/local.py +++ b/lra_benchmarks/models/local/local.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Local Attention Transformer models.""" -from flax import nn +from flax.deprecated import nn import jax.numpy as jnp from lra_benchmarks.models.layers import common_layers from lra_benchmarks.models.local import local_attention diff --git a/lra_benchmarks/models/local/local_attention.py b/lra_benchmarks/models/local/local_attention.py index e0dc4a4..a988670 100644 --- a/lra_benchmarks/models/local/local_attention.py +++ b/lra_benchmarks/models/local/local_attention.py @@ -16,12 +16,12 @@ from collections.abc import Iterable # pylint: disable=g-importing-member from absl import logging -from flax import nn -from flax.nn.attention import _CacheEntry -from flax.nn.attention import _make_causal_mask -from flax.nn.attention import Cache -from flax.nn.attention import make_padding_mask -from flax.nn.stochastic import make_rng +from flax.deprecated import nn +from flax.deprecated.nn.attention import _CacheEntry +from flax.deprecated.nn.attention import _make_causal_mask +from flax.deprecated.nn.attention import Cache +from flax.deprecated.nn.attention import make_padding_mask +from flax.deprecated.nn.stochastic import make_rng import jax from jax import lax from jax import random @@ -42,7 +42,7 @@ def local_dot_product_attention(query, precision=None): """Computes dot-product attention given query, key, and value. - Note: This is equivalent to the dot product attention in flax.nn. + Note: This is equivalent to the dot product attention in flax.deprecated.nn. However, we do extra broadcasting of the bias in this function. I'm leaving this here in case we need to modify something later. @@ -206,7 +206,7 @@ def apply(self, key_padding_mask: boolean specifying key-value tokens that are pad token. segmentation: segment indices for packed inputs_q data. key_segmentation: segment indices for packed inputs_kv data. - cache: an instance of `flax.nn.attention.Cache` used for efficient + cache: an instance of `flax.deprecated.nn.attention.Cache` used for efficient autoregressive decoding. broadcast_dropout: bool: use a broadcasted dropout along batch dims. 
dropout_rng: JAX PRNGKey: to be used for dropout diff --git a/lra_benchmarks/models/longformer/longformer.py b/lra_benchmarks/models/longformer/longformer.py index 0957b6f..c1d6590 100644 --- a/lra_benchmarks/models/longformer/longformer.py +++ b/lra_benchmarks/models/longformer/longformer.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Longformer modules.""" -from flax import nn +from flax.deprecated import nn import jax.numpy as jnp from lra_benchmarks.models.layers import common_layers from lra_benchmarks.models.longformer import longformer_attention diff --git a/lra_benchmarks/models/longformer/longformer_attention.py b/lra_benchmarks/models/longformer/longformer_attention.py index 96bfbba..31f35f8 100644 --- a/lra_benchmarks/models/longformer/longformer_attention.py +++ b/lra_benchmarks/models/longformer/longformer_attention.py @@ -19,7 +19,7 @@ supported, however. """ -from flax import nn +from flax.deprecated import nn from jax import lax import jax.numpy as jnp import numpy as np diff --git a/lra_benchmarks/models/performer/performer.py b/lra_benchmarks/models/performer/performer.py index 70c90df..a9077d5 100644 --- a/lra_benchmarks/models/performer/performer.py +++ b/lra_benchmarks/models/performer/performer.py @@ -15,7 +15,7 @@ import functools from absl import logging -from flax import nn +from flax.deprecated import nn import jax.numpy as jnp from lra_benchmarks.models.layers import common_layers from lra_benchmarks.models.performer import performer_attention diff --git a/lra_benchmarks/models/reformer/reformer.py b/lra_benchmarks/models/reformer/reformer.py index 6535965..f8f582c 100644 --- a/lra_benchmarks/models/reformer/reformer.py +++ b/lra_benchmarks/models/reformer/reformer.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Reformer language models.""" -from flax import nn +from flax.deprecated import nn import jax.numpy as jnp from lra_benchmarks.models.layers import common_layers from lra_benchmarks.models.reformer import reformer_attention diff --git a/lra_benchmarks/models/reformer/reformer_attention.py b/lra_benchmarks/models/reformer/reformer_attention.py index 22ac618..fac38dc 100644 --- a/lra_benchmarks/models/reformer/reformer_attention.py +++ b/lra_benchmarks/models/reformer/reformer_attention.py @@ -14,7 +14,7 @@ """Attention modules for Reformer model.""" from absl import logging -from flax import nn +from flax.deprecated import nn import jax import jax.numpy as jnp from jax.scipy.special import logsumexp @@ -270,7 +270,7 @@ def apply(self, key_padding_mask: boolean specifying key-value tokens that are pad token. segmentation: segment indices for packed inputs_q data. key_segmentation: segment indices for packed inputs_kv data. - cache: an instance of `flax.nn.attention.Cache` used for efficient + cache: an instance of `flax.deprecated.nn.attention.Cache` used for efficient autoregressive decoding. broadcast_dropout: bool: use a broadcasted dropout along batch dims. 
dropout_rng: JAX PRNGKey: to be used for dropout diff --git a/lra_benchmarks/models/sinkhorn_transformer/sinkhorn_attention.py b/lra_benchmarks/models/sinkhorn_transformer/sinkhorn_attention.py index b5311c5..6606c44 100644 --- a/lra_benchmarks/models/sinkhorn_transformer/sinkhorn_attention.py +++ b/lra_benchmarks/models/sinkhorn_transformer/sinkhorn_attention.py @@ -15,12 +15,12 @@ from collections.abc import Iterable # pylint: disable=g-importing-member -from flax import nn -from flax.nn.attention import _CacheEntry -from flax.nn.attention import _make_causal_mask -from flax.nn.attention import Cache -from flax.nn.attention import make_padding_mask -from flax.nn.stochastic import make_rng +from flax.deprecated import nn +from flax.deprecated.nn.attention import _CacheEntry +from flax.deprecated.nn.attention import _make_causal_mask +from flax.deprecated.nn.attention import Cache +from flax.deprecated.nn.attention import make_padding_mask +from flax.deprecated.nn.stochastic import make_rng import jax from jax import lax from jax import random @@ -77,7 +77,7 @@ def local_dot_product_attention(query, precision=None): """Computes dot-product attention given query, key, and value. - Note: This is equivalent to the dot product attention in flax.nn. + Note: This is equivalent to the dot product attention in flax.deprecated.nn. However, we do extra broadcasting of the bias in this function. I'm leaving this here incase we need to modify something later. @@ -243,7 +243,7 @@ def apply(self, key_padding_mask: boolean specifying key-value tokens that are pad token. segmentation: segment indices for packed inputs_q data. key_segmentation: segment indices for packed inputs_kv data. - cache: an instance of `flax.nn.attention.Cache` used for efficient + cache: an instance of `flax.deprecated.nn.attention.Cache` used for efficient autoregressive decoding. broadcast_dropout: bool: use a broadcasted dropout along batch dims. dropout_rng: JAX PRNGKey: to be used for dropout diff --git a/lra_benchmarks/models/sinkhorn_transformer/sinkhorn_transformer.py b/lra_benchmarks/models/sinkhorn_transformer/sinkhorn_transformer.py index 3a9df47..9904a32 100644 --- a/lra_benchmarks/models/sinkhorn_transformer/sinkhorn_transformer.py +++ b/lra_benchmarks/models/sinkhorn_transformer/sinkhorn_transformer.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Sinkhorn Attention Transformer models.""" -from flax import nn +from flax.deprecated import nn import jax.numpy as jnp from lra_benchmarks.models.layers import common_layers from lra_benchmarks.models.sinkhorn_transformer import sinkhorn_attention diff --git a/lra_benchmarks/models/sparse_transformer/sparse_attention.py b/lra_benchmarks/models/sparse_transformer/sparse_attention.py index c08acb3..6f52646 100644 --- a/lra_benchmarks/models/sparse_transformer/sparse_attention.py +++ b/lra_benchmarks/models/sparse_transformer/sparse_attention.py @@ -21,7 +21,7 @@ from typing import Iterable import attr -from flax import nn +from flax.deprecated import nn from jax import lax import jax.numpy as jnp import numpy as np diff --git a/lra_benchmarks/models/sparse_transformer/sparse_transformer.py b/lra_benchmarks/models/sparse_transformer/sparse_transformer.py index 9471b69..04fe004 100644 --- a/lra_benchmarks/models/sparse_transformer/sparse_transformer.py +++ b/lra_benchmarks/models/sparse_transformer/sparse_transformer.py @@ -13,7 +13,7 @@ # limitations under the License. 
"""Sparse Transformer modules.""" from absl import logging -from flax import nn +from flax.deprecated import nn import jax.numpy as jnp from lra_benchmarks.models.layers import common_layers from lra_benchmarks.models.sparse_transformer import sparse_attention diff --git a/lra_benchmarks/models/synthesizer/synthesizer.py b/lra_benchmarks/models/synthesizer/synthesizer.py index ec80760..f6c9daf 100644 --- a/lra_benchmarks/models/synthesizer/synthesizer.py +++ b/lra_benchmarks/models/synthesizer/synthesizer.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Synthesizer models.""" -from flax import nn +from flax.deprecated import nn import jax.numpy as jnp from lra_benchmarks.models.layers import common_layers from lra_benchmarks.models.synthesizer import synthesizer_attention diff --git a/lra_benchmarks/models/synthesizer/synthesizer_attention.py b/lra_benchmarks/models/synthesizer/synthesizer_attention.py index e95251b..9002860 100644 --- a/lra_benchmarks/models/synthesizer/synthesizer_attention.py +++ b/lra_benchmarks/models/synthesizer/synthesizer_attention.py @@ -17,12 +17,12 @@ from absl import logging -from flax import nn -from flax.nn.attention import _CacheEntry -from flax.nn.attention import _make_causal_mask -from flax.nn.attention import Cache -from flax.nn.attention import make_padding_mask -from flax.nn.stochastic import make_rng +from flax.deprecated import nn +from flax.deprecated.nn.attention import _CacheEntry +from flax.deprecated.nn.attention import _make_causal_mask +from flax.deprecated.nn.attention import Cache +from flax.deprecated.nn.attention import make_padding_mask +from flax.deprecated.nn.stochastic import make_rng import jax from jax import lax from jax import random @@ -225,7 +225,7 @@ def apply(self, key_padding_mask: boolean specifying key-value tokens that are pad token. segmentation: segment indices for packed inputs_q data. key_segmentation: segment indices for packed inputs_kv data. - cache: an instance of `flax.nn.attention.Cache` used for efficient + cache: an instance of `flax.deprecated.nn.attention.Cache` used for efficient autoregressive decoding. broadcast_dropout: bool: use a broadcasted dropout along batch dims. dropout_rng: JAX PRNGKey: to be used for dropout diff --git a/lra_benchmarks/models/transformer/transformer.py b/lra_benchmarks/models/transformer/transformer.py index afffaa8..0cea7aa 100644 --- a/lra_benchmarks/models/transformer/transformer.py +++ b/lra_benchmarks/models/transformer/transformer.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Transformer model.""" -from flax import nn +from flax.deprecated import nn import jax.numpy as jnp from lra_benchmarks.models.layers import common_layers diff --git a/lra_benchmarks/text_classification/train.py b/lra_benchmarks/text_classification/train.py index 1289750..ffcf13d 100644 --- a/lra_benchmarks/text_classification/train.py +++ b/lra_benchmarks/text_classification/train.py @@ -22,8 +22,8 @@ from absl import flags from absl import logging from flax import jax_utils -from flax import nn from flax import optim +from flax.deprecated import nn from flax.metrics import tensorboard from flax.training import checkpoints from flax.training import common_utils diff --git a/lra_benchmarks/utils/train_utils.py b/lra_benchmarks/utils/train_utils.py index 75e04dc..ac70425 100644 --- a/lra_benchmarks/utils/train_utils.py +++ b/lra_benchmarks/utils/train_utils.py @@ -13,7 +13,7 @@ # limitations under the License. """This contains utility functions for model training and evaluation.""" -from flax import nn +from flax.deprecated import nn from flax.training import common_utils import jax.numpy as jnp from lra_benchmarks.models.bigbird import bigbird diff --git a/requirements.txt b/requirements.txt index b52c9c1..08997ce 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,3 @@ ml-collections>=0.1.0 tensorboard>=2.3.0 tensorflow>=2.3.1 tensorflow-datasets>=4.0.1 -tensorflow_text -gin -gin-config>=0.1.3 From 80cde0a98cb0e768e713dcaec46e1d3d985f79ab Mon Sep 17 00:00:00 2001 From: Nitesh Gundavarapu Date: Tue, 2 Aug 2022 17:28:35 +0000 Subject: [PATCH 2/5] Internal change PiperOrigin-RevId: 464826132 --- .../models/transformer_tlb/__init__.py | 0 .../models/transformer_tlb/transformer_tlb.py | 323 ++++++++++++++++++ .../configs/transformer_tlb_base.py | 37 ++ lra_benchmarks/text_classification/train.py | 15 + lra_benchmarks/utils/train_utils.py | 4 + 5 files changed, 379 insertions(+) create mode 100644 lra_benchmarks/models/transformer_tlb/__init__.py create mode 100644 lra_benchmarks/models/transformer_tlb/transformer_tlb.py create mode 100644 lra_benchmarks/text_classification/configs/transformer_tlb_base.py diff --git a/lra_benchmarks/models/transformer_tlb/__init__.py b/lra_benchmarks/models/transformer_tlb/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/lra_benchmarks/models/transformer_tlb/transformer_tlb.py b/lra_benchmarks/models/transformer_tlb/transformer_tlb.py new file mode 100644 index 0000000..6738822 --- /dev/null +++ b/lra_benchmarks/models/transformer_tlb/transformer_tlb.py @@ -0,0 +1,323 @@ +# Copyright 2022 Google LLC + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# https://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Transformer-based stateful lra models.""" +from flax.deprecated import nn +import jax +import jax.numpy as jnp +from lra_benchmarks.models.layers import common_layers +from lra_benchmarks.models.transformer import transformer + + +class CrossTransformerBlock(nn.Module): + """Cross Transformer layer.""" + + def apply(self, + inputs_q, + inputs_kv, + qkv_dim, + mlp_dim, + num_heads, + dtype=jnp.float32, + inputs_segmentation=None, + causal_mask=False, + padding_mask=None, + key_padding_mask=None, + dropout_rate=0.1, + attention_dropout_rate=0.1, + deterministic=False, + cache=None, + residual=True): + """Applies CrossTransformerBlock module. + + Args: + inputs_q: input query + inputs_kv: input key-value + qkv_dim: dimension of the query/key/value + mlp_dim: dimension of the mlp on top of attention block + num_heads: number of heads + dtype: the dtype of the computation (default: float32). + inputs_segmentation: input segmentation info for packed examples. + causal_mask: bool, mask future or not + padding_mask: bool, mask padding tokens + key_padding_mask: bool, mask padding tokens + dropout_rate: dropout rate + attention_dropout_rate: dropout rate for attention weights + deterministic: bool, deterministic or not (to apply dropout) + cache: flax autoregressive cache for fast decoding. + residual: Boolean, to use residual connectors or not. + + Returns: + output after transformer block. + + """ + + # Attention block. + assert inputs_q.ndim == 3 + x = nn.LayerNorm(inputs_q) + s = nn.LayerNorm(inputs_kv) + x = nn.MultiHeadDotProductAttention( + x, s, + num_heads=num_heads, + dtype=dtype, + qkv_features=qkv_dim, + attention_axis=(1,), + causal_mask=causal_mask, + segmentation=inputs_segmentation, + padding_mask=padding_mask, + key_padding_mask=key_padding_mask, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.normal(stddev=1e-6), + bias=False, + broadcast_dropout=False, + dropout_rate=attention_dropout_rate, + deterministic=deterministic, + cache=cache) + x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic) + x = x + inputs_q + + # MLP block. 
+ y = nn.LayerNorm(x) + y = common_layers.MlpBlock( + y, + mlp_dim=mlp_dim, + dtype=dtype, + dropout_rate=dropout_rate, + deterministic=deterministic) + + if residual: + output = x + y + else: + output = x + if padding_mask is not None: + corner_case = (jnp.sum(padding_mask, + axis=1) == 0)[..., None] + output = jnp.where(corner_case, inputs_q, output) + elif key_padding_mask is not None: + corner_case = (jnp.sum(key_padding_mask, + axis=1) == 0)[..., None] + output = jnp.where(corner_case, inputs_q, output) + return output + + +class StatefulTransformerEncoder(nn.Module): + """Stateful Transformer Model Encoder (https://arxiv.org/abs/2205.14794).""" + + def apply(self, + inputs, + vocab_size, + inputs_positions=None, + inputs_segmentation=None, + shared_embedding=None, + use_bfloat16=False, + emb_dim=512, + num_heads=8, + dtype=jnp.float32, + num_layers=6, + qkv_dim=512, + mlp_dim=2048, + max_len=512, + train=True, + dropout_rate=0.1, + attention_dropout_rate=0.1, + learn_pos_emb=False, + classifier=False, + classifier_pool='CLS', + num_classes=10, + tied_weights=False, + meta_network=False, + meta_layers=1, + meta_pool='last', + use_residual=True, + meta_partition=3, + meta_layer_output=False, + self_to_cross_ratio_input_updater=2, + num_cross_layers_input_updater=1, + num_cross_layers_state_updater=1, + num_state_tokens=20, + block_size=20, + use_global_pos_encoding=False): + """Applies Transformer model on the inputs. + + Args: + inputs: input data + vocab_size: size of the vocabulary + inputs_positions: input subsequence positions for packed examples. + inputs_segmentation: input segmentation info for packed examples. + shared_embedding: a shared embedding layer to use. + use_bfloat16: bool: whether use bfloat16. + emb_dim: dimension of embedding + num_heads: number of heads + dtype: the dtype of the computation (default: float32) + num_layers: number of layers + qkv_dim: dimension of the query/key/value + mlp_dim: dimension of the mlp on top of attention block + max_len: maximum length. + train: if it is training, + dropout_rate: dropout rate + attention_dropout_rate: dropout rate for attention weights + learn_pos_emb: boolean, if learn the positional embedding or use the + sinusoidal positional embedding. + classifier: boolean, for classification mode (output N-class logits) + classifier_pool: str, supports "MEAN", "MAX" pooling. + num_classes: int, number of classification classes. + tied_weights: bool, to tie weights or not. + meta_network: boolean, experimental extreme self-attention. + meta_layers: int, number of meta_layers + meta_pool: str, the type of meta pooling. + use_residual: boolean, turn off transformer residuals. + meta_partition: int. + meta_layer_output: boolean. + self_to_cross_ratio_input_updater: number of self-attention layers before + each cross attention layer in the input-update direction + num_cross_layers_input_updater: number of cross-attention layers + in the input update direction + num_cross_layers_state_updater: number of cross-attention layers + in the state update direction + num_state_tokens: number of state tokens + block_size: chunk size of inputs + use_global_pos_encoding: Whether the input position embedding is global + or local + + Returns: + output of a transformer encoder or logits if classifier_mode is true. 
+ """ + assert inputs.ndim == 2 # (batch, len) + + # Padding Masks + src_padding_mask = (inputs > 0)[..., None] + + # Input Embedding + if shared_embedding is None: + input_embed = nn.Embed.partial( + num_embeddings=vocab_size, + features=emb_dim, + embedding_init=nn.initializers.normal(stddev=1.0)) + else: + input_embed = shared_embedding + x = inputs.astype('int32') + x = input_embed(x) + + # Input Positional Encoding + pe_init = nn.initializers.normal(stddev=0.02) if learn_pos_emb else None + if use_global_pos_encoding: + x = common_layers.AddPositionEmbs( + x, + inputs_positions=inputs_positions, + posemb_init=pe_init, + max_len=max_len, + name='global_posembed_input') + pe = common_layers.AddPositionEmbs( + jnp.zeros((x.shape[0], block_size, x.shape[2]), + dtype=x.dtype), + inputs_positions=inputs_positions, + posemb_init=pe_init, + max_len=max_len, + name='posembed_input') + + if use_bfloat16: + x = x.astype(jnp.bfloat16) + dtype = jnp.bfloat16 + else: + dtype = jnp.float32 + + # Create layers + horizontal_blocks = [] + for block_idx in range(num_cross_layers_input_updater): + horizontal_layers = [] + for layer_idx in range(self_to_cross_ratio_input_updater): + horizontal_layers.append( + transformer.TransformerBlock.shared( + qkv_dim=qkv_dim, + mlp_dim=mlp_dim, + num_heads=num_heads, + dtype=dtype, + inputs_segmentation=inputs_segmentation, + dropout_rate=dropout_rate, + attention_dropout_rate=attention_dropout_rate, + deterministic=not train, + name=f'horizontal_block_{block_idx}_self_{layer_idx}')) + horizontal_layers.append( + CrossTransformerBlock.shared( + qkv_dim=qkv_dim, + mlp_dim=mlp_dim, + num_heads=num_heads, + dtype=dtype, + inputs_segmentation=inputs_segmentation, + dropout_rate=dropout_rate, + attention_dropout_rate=attention_dropout_rate, + deterministic=not train, + name=f'horizontal_block_{block_idx}_cross', + residual=use_residual)) + horizontal_blocks.append(horizontal_layers) + vertical_layers = [] + for layer_idx in range(num_cross_layers_state_updater): + vertical_layers.append( + CrossTransformerBlock.shared( + qkv_dim=qkv_dim, + mlp_dim=mlp_dim, + num_heads=num_heads, + dtype=dtype, + inputs_segmentation=inputs_segmentation, + dropout_rate=dropout_rate, + attention_dropout_rate=attention_dropout_rate, + deterministic=not train, + name=f'vertical_block_{block_idx}_cross', + residual=use_residual)) + + num_tokens = x.shape[1] + num_chunks = num_tokens // block_size + init_state = jnp.zeros((x.shape[0], num_state_tokens, x.shape[2])) + x_with_pad = jnp.concatenate([x, src_padding_mask], axis=2) + # Split inputs into chunks of block_size. 
+ x_with_pad = jnp.stack( + jnp.split(x_with_pad, num_chunks, axis=1), axis=0) + + # State positional encoding + state_pos_embed = common_layers.AddPositionEmbs( + init_state, + inputs_positions=None, + posemb_init=None, + max_len=num_state_tokens, + name='posembed_state') + + # Processing function for each chunk + def scan_inner(cur_state, cur_x_with_pad): + padding_mask_cur = cur_x_with_pad[:, :, -1][:, :, None] + x_cur = cur_x_with_pad[:, :, :-1] + if not use_global_pos_encoding: + x_cur = x_cur + pe + x_cur = nn.dropout(x_cur, rate=dropout_rate, deterministic=not train) + cur_state = cur_state + state_pos_embed + for block_idx in range(num_cross_layers_input_updater): + for layer_idx in range(self_to_cross_ratio_input_updater): + x_cur = horizontal_blocks[block_idx][layer_idx]( + x_cur, padding_mask=padding_mask_cur) + x_cur = horizontal_blocks[block_idx][-1]( + x_cur, cur_state, padding_mask=padding_mask_cur) + for layer_idx in range(num_cross_layers_state_updater): + cur_state = vertical_layers[layer_idx]( + cur_state, x_cur, key_padding_mask=padding_mask_cur) + return cur_state, None + + # Scan + cur_state, _ = jax.lax.scan(scan_inner, init_state, x_with_pad, unroll=40) + + assert cur_state.shape == init_state.shape + + encoded = nn.LayerNorm(cur_state, dtype=dtype, name='encoder_norm') + + if classifier: + encoded = common_layers.classifier_head( + encoded, num_classes, mlp_dim, pooling_mode='MEAN') + return encoded diff --git a/lra_benchmarks/text_classification/configs/transformer_tlb_base.py b/lra_benchmarks/text_classification/configs/transformer_tlb_base.py new file mode 100644 index 0000000..f28ea86 --- /dev/null +++ b/lra_benchmarks/text_classification/configs/transformer_tlb_base.py @@ -0,0 +1,37 @@ +# Copyright 2021 Google LLC + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# https://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Configuration and hyperparameter sweeps.""" + +from lra_benchmarks.text_classification.configs import base_tc_config + + +def get_config(): + """Get the default hyperparameter configuration.""" + config = base_tc_config.get_config() + config.model_type = "transformer_tlb" + config.learning_rate = 0.05/2. 
+ config.self_to_cross_ratio_input_updater = 2 + config.num_cross_layers_input_updater = 1 + config.num_cross_layers_state_updater = 1 + config.num_state_tokens = 10 + config.block_size = 10 + config.use_global_pos_encoding = False + + config.max_length = 4000 + return config + + +def get_hyper(hyper): + return hyper.product([]) diff --git a/lra_benchmarks/text_classification/train.py b/lra_benchmarks/text_classification/train.py index ffcf13d..bb67e34 100644 --- a/lra_benchmarks/text_classification/train.py +++ b/lra_benchmarks/text_classification/train.py @@ -184,6 +184,21 @@ def main(argv): 'num_classes': CLASS_MAP[FLAGS.task_name], 'classifier_pool': config.classifier_pool } + if config.model_type == 'transformer_tlb': + model_kwargs.update({ + 'self_to_cross_ratio_input_updater': + config.self_to_cross_ratio_input_updater, + 'num_cross_layers_input_updater': + config.num_cross_layers_input_updater, + 'num_cross_layers_state_updater': + config.num_cross_layers_state_updater, + 'num_state_tokens': + config.num_state_tokens, + 'block_size': + config.block_size, + 'use_global_pos_encoding': + config.use_global_pos_encoding + }) rng = random.PRNGKey(random_seed) rng = jax.random.fold_in(rng, jax.process_index()) diff --git a/lra_benchmarks/utils/train_utils.py b/lra_benchmarks/utils/train_utils.py index ac70425..3c5883d 100644 --- a/lra_benchmarks/utils/train_utils.py +++ b/lra_benchmarks/utils/train_utils.py @@ -28,6 +28,7 @@ from lra_benchmarks.models.sparse_transformer import sparse_transformer from lra_benchmarks.models.synthesizer import synthesizer from lra_benchmarks.models.transformer import transformer +from lra_benchmarks.models.transformer_tlb import transformer_tlb import numpy as onp @@ -80,6 +81,9 @@ def get_model(model_type, create_model_fn, model_kwargs, *create_model_args): elif model_type == 'longformer': return create_model_fn(longformer.LongformerEncoder, model_kwargs, *create_model_args) + elif model_type == 'transformer_tlb': + return create_model_fn(transformer_tlb.StatefulTransformerEncoder, + model_kwargs, *create_model_args) else: raise ValueError('Model type not supported') From 38a2484ef2ed4a97bff5fb54143ad80ec720091e Mon Sep 17 00:00:00 2001 From: Nitesh Gundavarapu Date: Tue, 16 Aug 2022 16:21:17 -0700 Subject: [PATCH 3/5] Update copyright --- .../text_classification/configs/transformer_tlb_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lra_benchmarks/text_classification/configs/transformer_tlb_base.py b/lra_benchmarks/text_classification/configs/transformer_tlb_base.py index f28ea86..e83ea2d 100644 --- a/lra_benchmarks/text_classification/configs/transformer_tlb_base.py +++ b/lra_benchmarks/text_classification/configs/transformer_tlb_base.py @@ -1,4 +1,4 @@ -# Copyright 2021 Google LLC +# Copyright 2022 Google LLC # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
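
Patch 2 above introduces the Temporal Latent Bottleneck (TLB) encoder, whose core loop splits the token sequence into fixed-size chunks and carries a small set of state tokens across them with `jax.lax.scan`: each chunk is updated by reading from the current state ("horizontal" cross-attention), and the state is then updated by reading from the processed chunk ("vertical" cross-attention). Below is a minimal, self-contained sketch of that chunked scan only. The names `toy_cross_attention` and `tlb_scan` are illustrative, the attention is a bare single-head stand-in for the patch's `CrossTransformerBlock`, and the real `StatefulTransformerEncoder`'s per-chunk self-attention layers, padding masks, dropout, positional embeddings, and batch dimension are all omitted here.

```
import jax
import jax.numpy as jnp


def toy_cross_attention(queries, keys_values):
  # Single-head scaled dot-product attention with a residual connection;
  # a bare stand-in for CrossTransformerBlock (no projections, LayerNorm, MLP).
  scores = queries @ keys_values.T / jnp.sqrt(queries.shape[-1])
  return queries + jax.nn.softmax(scores, axis=-1) @ keys_values


def tlb_scan(x, num_state_tokens=4, block_size=8):
  # x: [seq_len, dim]; seq_len must be divisible by block_size, mirroring
  # `num_chunks = num_tokens // block_size` in the encoder.
  num_chunks = x.shape[0] // block_size
  chunks = x.reshape(num_chunks, block_size, x.shape[-1])
  init_state = jnp.zeros((num_state_tokens, x.shape[-1]))

  def step(state, chunk):
    # Horizontal update: the chunk attends to the current state tokens.
    chunk = toy_cross_attention(chunk, state)
    # Vertical update: the state tokens attend to the processed chunk.
    state = toy_cross_attention(state, chunk)
    return state, None

  final_state, _ = jax.lax.scan(step, init_state, chunks)
  return final_state  # normalized and pooled downstream, as in the encoder


x = jax.random.normal(jax.random.PRNGKey(0), (32, 16))
print(tlb_scan(x).shape)  # (4, 16)
```
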
From fd1763a235078d703ccaa6b07fc4bb529b81bf4b Mon Sep 17 00:00:00 2001 From: Nitesh Gundavarapu Date: Mon, 12 Sep 2022 17:12:11 +0000 Subject: [PATCH 4/5] Internal change PiperOrigin-RevId: 473778528 --- .../configs/cifar10/transformer_tlb_base.py | 38 ++++++++ .../pathfinder32/transformer_tlb_base.py | 39 ++++++++ .../listops/configs/transformer_tlb_base.py | 37 ++++++++ lra_benchmarks/listops/train.py | 2 + .../matching/configs/transformer_tlb_base.py | 39 ++++++++ lra_benchmarks/matching/train.py | 4 +- .../models/transformer_tlb/transformer_tlb.py | 93 +++++++++++++++++++ .../configs/transformer_tlb_base.py | 14 +-- lra_benchmarks/text_classification/train.py | 17 +--- lra_benchmarks/utils/train_utils.py | 3 + 10 files changed, 264 insertions(+), 22 deletions(-) create mode 100644 lra_benchmarks/image/configs/cifar10/transformer_tlb_base.py create mode 100644 lra_benchmarks/image/configs/pathfinder32/transformer_tlb_base.py create mode 100644 lra_benchmarks/listops/configs/transformer_tlb_base.py create mode 100644 lra_benchmarks/matching/configs/transformer_tlb_base.py diff --git a/lra_benchmarks/image/configs/cifar10/transformer_tlb_base.py b/lra_benchmarks/image/configs/cifar10/transformer_tlb_base.py new file mode 100644 index 0000000..58e2229 --- /dev/null +++ b/lra_benchmarks/image/configs/cifar10/transformer_tlb_base.py @@ -0,0 +1,38 @@ +# Copyright 2022 Google LLC + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# https://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Configuration and hyperparameter sweeps.""" + +from lra_benchmarks.image.configs.cifar10 import base_cifar10_config + + +def get_config(): + """Get the hyperparameter configuration.""" + config = base_cifar10_config.get_config() + config.model_type = "transformer_tlb" + config.learning_rate = .001 + config.model.emb_dim = 128 + config.model.mlp_dim = 128 + config.model.num_heads = 8 + config.model.qkv_dim = 64 + config.model.self_to_cross_ratio_input_updater = 1 + config.model.num_cross_layers_input_updater = 1 + config.model.num_cross_layers_state_updater = 1 + config.model.num_state_tokens = 5 + config.model.block_size = 32 + config.model.use_global_pos_encoding = False + return config + + +def get_hyper(hyper): + return hyper.product([]) diff --git a/lra_benchmarks/image/configs/pathfinder32/transformer_tlb_base.py b/lra_benchmarks/image/configs/pathfinder32/transformer_tlb_base.py new file mode 100644 index 0000000..3ad009e --- /dev/null +++ b/lra_benchmarks/image/configs/pathfinder32/transformer_tlb_base.py @@ -0,0 +1,39 @@ +# Copyright 2022 Google LLC + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# https://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Base Configuration.""" + +from lra_benchmarks.image.configs.pathfinder32 import base_pathfinder32_config + + +def get_config(): + """Get the default hyperparameter configuration.""" + config = base_pathfinder32_config.get_config() + config.model_type = "transformer_tlb" + config.learning_rate = 0.0005 + config.model.num_heads = 8 + config.model.emb_dim = 128 + config.model.dropout_rate = 0.1 + config.model.qkv_dim = 64 + config.model.mlp_dim = 128 + config.model.self_to_cross_ratio_input_updater = 2 + config.model.num_cross_layers_input_updater = 1 + config.model.num_cross_layers_state_updater = 1 + config.model.num_state_tokens = 5 + config.model.block_size = 64 + config.model.use_global_pos_encoding = True + return config + + +def get_hyper(hyper): + return hyper.product([]) diff --git a/lra_benchmarks/listops/configs/transformer_tlb_base.py b/lra_benchmarks/listops/configs/transformer_tlb_base.py new file mode 100644 index 0000000..f75bb31 --- /dev/null +++ b/lra_benchmarks/listops/configs/transformer_tlb_base.py @@ -0,0 +1,37 @@ +# Copyright 2022 Google LLC + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# https://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Configuration and hyperparameter sweeps.""" + +from lra_benchmarks.listops.configs import base_listops_config +import ml_collections + + +def get_config(): + """Get the default hyperparameter configuration.""" + config = base_listops_config.get_config() + config.model_type = "transformer_tlb" + config.model = ml_collections.ConfigDict() + config.learning_rate = 0.05 / 4. + config.model.self_to_cross_ratio_input_updater = 2 + config.model.num_cross_layers_input_updater = 1 + config.model.num_cross_layers_state_updater = 1 + config.model.num_state_tokens = 100 + config.model.block_size = 100 + config.model.use_global_pos_encoding = False + return config + + +def get_hyper(hyper): + return hyper.product([]) diff --git a/lra_benchmarks/listops/train.py b/lra_benchmarks/listops/train.py index 80cd5fc..d03fb06 100644 --- a/lra_benchmarks/listops/train.py +++ b/lra_benchmarks/listops/train.py @@ -188,6 +188,8 @@ def main(argv): 'classifier': True, 'num_classes': 10 }) + if 'model' in config: + model_kwargs.update(config.model) rng = random.PRNGKey(random_seed) rng = jax.random.fold_in(rng, jax.process_index()) diff --git a/lra_benchmarks/matching/configs/transformer_tlb_base.py b/lra_benchmarks/matching/configs/transformer_tlb_base.py new file mode 100644 index 0000000..2f0d75a --- /dev/null +++ b/lra_benchmarks/matching/configs/transformer_tlb_base.py @@ -0,0 +1,39 @@ +# Copyright 2022 Google LLC + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at + +# https://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Configuration and hyperparameter sweeps.""" + +from lra_benchmarks.matching.configs import base_match_config +import ml_collections + + +def get_config(): + """Get the default hyperparameter configuration.""" + config = base_match_config.get_config() + config.model_type = "transformer_tlb_dual" + config.batch_size = 128 + config.learning_rate = 0.05 / 2. + config.eval_frequency = 5000 + config.num_train_steps = 20000 + config.model = ml_collections.ConfigDict() + config.model.self_to_cross_ratio_input_updater = 2 + config.model.num_cross_layers_input_updater = 1 + config.model.num_cross_layers_state_updater = 1 + config.model.num_state_tokens = 10 + config.model.block_size = 10 + config.model.use_global_pos_encoding = False + return config + + +def get_hyper(hyper): + return hyper.product([]) diff --git a/lra_benchmarks/matching/train.py b/lra_benchmarks/matching/train.py index d56b752..6eacf23 100644 --- a/lra_benchmarks/matching/train.py +++ b/lra_benchmarks/matching/train.py @@ -186,6 +186,8 @@ def main(argv): 'num_classes': 2, 'classifier_pool': config.pooling_mode } + if 'model' in config: + model_kwargs.update(config.model) rng = random.PRNGKey(random_seed) rng = jax.random.fold_in(rng, jax.process_index()) @@ -195,7 +197,7 @@ def main(argv): dropout_rngs = random.split(rng, jax.local_device_count()) model = train_utils.get_model(model_type, create_model, model_kwargs, - init_rng, input_shape) + init_rng, input_shape, input_shape) optimizer = create_optimizer( model, learning_rate, weight_decay=FLAGS.config.weight_decay) diff --git a/lra_benchmarks/models/transformer_tlb/transformer_tlb.py b/lra_benchmarks/models/transformer_tlb/transformer_tlb.py index 6738822..506de62 100644 --- a/lra_benchmarks/models/transformer_tlb/transformer_tlb.py +++ b/lra_benchmarks/models/transformer_tlb/transformer_tlb.py @@ -321,3 +321,96 @@ def scan_inner(cur_state, cur_x_with_pad): encoded = common_layers.classifier_head( encoded, num_classes, mlp_dim, pooling_mode='MEAN') return encoded + + +class StatefulTransformerDualEncoder(nn.Module): + """Stateful Transformer Model Encoder (https://arxiv.org/abs/2205.14794).""" + + def apply(self, + inputs1, + inputs2, + vocab_size, + inputs1_positions=None, + inputs2_positions=None, + inputs1_segmentation=None, + inputs2_segmentation=None, + shared_embedding=None, + use_bfloat16=False, + emb_dim=512, + num_heads=8, + dtype=jnp.float32, + num_layers=6, + qkv_dim=512, + mlp_dim=2048, + max_len=512, + train=True, + dropout_rate=0.1, + attention_dropout_rate=0.1, + learn_pos_emb=False, + classifier=False, + classifier_pool='CLS', + num_classes=10, + tied_weights=False, + meta_network=False, + meta_layers=1, + meta_pool='last', + use_residual=True, + meta_partition=3, + meta_layer_output=False, + self_to_cross_ratio_input_updater=2, + num_cross_layers_input_updater=1, + num_cross_layers_state_updater=1, + num_state_tokens=20, + block_size=20, + use_global_pos_encoding=False, + interaction=None): + """Applies Transformer model on the inputs.""" + encoder = StatefulTransformerEncoder.shared( + vocab_size=vocab_size, + shared_embedding=shared_embedding, + 
use_bfloat16=use_bfloat16, + emb_dim=emb_dim, + num_heads=num_heads, + dtype=dtype, + num_layers=num_layers, + qkv_dim=qkv_dim, + mlp_dim=mlp_dim, + max_len=max_len, + train=train, + dropout_rate=dropout_rate, + attention_dropout_rate=attention_dropout_rate, + learn_pos_emb=learn_pos_emb, + classifier=False, + classifier_pool=classifier_pool, + num_classes=num_classes, + tied_weights=tied_weights, + meta_network=meta_network, + meta_layers=meta_layers, + meta_pool=meta_pool, + use_residual=use_residual, + meta_partition=meta_partition, + meta_layer_output=meta_layer_output, + self_to_cross_ratio_input_updater=self_to_cross_ratio_input_updater, + num_cross_layers_input_updater=num_cross_layers_input_updater, + num_cross_layers_state_updater=num_cross_layers_state_updater, + num_state_tokens=num_state_tokens, + block_size=block_size, + use_global_pos_encoding=use_global_pos_encoding, + name='stateful_encoder') + inputs1_encoded = encoder( + inputs=inputs1, + inputs_positions=inputs1_positions, + inputs_segmentation=inputs1_segmentation) + inputs2_encoded = encoder( + inputs=inputs2, + inputs_positions=inputs2_positions, + inputs_segmentation=inputs2_segmentation) + + encoded = common_layers.classifier_head_dual( + inputs1_encoded, + inputs2_encoded, + num_classes, + mlp_dim, + pooling_mode='MEAN', + interaction=interaction) + return encoded diff --git a/lra_benchmarks/text_classification/configs/transformer_tlb_base.py b/lra_benchmarks/text_classification/configs/transformer_tlb_base.py index e83ea2d..7dd75ca 100644 --- a/lra_benchmarks/text_classification/configs/transformer_tlb_base.py +++ b/lra_benchmarks/text_classification/configs/transformer_tlb_base.py @@ -15,6 +15,7 @@ """Configuration and hyperparameter sweeps.""" from lra_benchmarks.text_classification.configs import base_tc_config +import ml_collections def get_config(): @@ -22,12 +23,13 @@ def get_config(): config = base_tc_config.get_config() config.model_type = "transformer_tlb" config.learning_rate = 0.05/2. 
- config.self_to_cross_ratio_input_updater = 2 - config.num_cross_layers_input_updater = 1 - config.num_cross_layers_state_updater = 1 - config.num_state_tokens = 10 - config.block_size = 10 - config.use_global_pos_encoding = False + config.model = ml_collections.ConfigDict() + config.model.self_to_cross_ratio_input_updater = 2 + config.model.num_cross_layers_input_updater = 1 + config.model.num_cross_layers_state_updater = 1 + config.model.num_state_tokens = 10 + config.model.block_size = 10 + config.model.use_global_pos_encoding = False config.max_length = 4000 return config diff --git a/lra_benchmarks/text_classification/train.py b/lra_benchmarks/text_classification/train.py index bb67e34..b3b44f2 100644 --- a/lra_benchmarks/text_classification/train.py +++ b/lra_benchmarks/text_classification/train.py @@ -184,21 +184,8 @@ def main(argv): 'num_classes': CLASS_MAP[FLAGS.task_name], 'classifier_pool': config.classifier_pool } - if config.model_type == 'transformer_tlb': - model_kwargs.update({ - 'self_to_cross_ratio_input_updater': - config.self_to_cross_ratio_input_updater, - 'num_cross_layers_input_updater': - config.num_cross_layers_input_updater, - 'num_cross_layers_state_updater': - config.num_cross_layers_state_updater, - 'num_state_tokens': - config.num_state_tokens, - 'block_size': - config.block_size, - 'use_global_pos_encoding': - config.use_global_pos_encoding - }) + if 'model' in config: + model_kwargs.update(config.model) rng = random.PRNGKey(random_seed) rng = jax.random.fold_in(rng, jax.process_index()) diff --git a/lra_benchmarks/utils/train_utils.py b/lra_benchmarks/utils/train_utils.py index 3c5883d..f7851e1 100644 --- a/lra_benchmarks/utils/train_utils.py +++ b/lra_benchmarks/utils/train_utils.py @@ -84,6 +84,9 @@ def get_model(model_type, create_model_fn, model_kwargs, *create_model_args): elif model_type == 'transformer_tlb': return create_model_fn(transformer_tlb.StatefulTransformerEncoder, model_kwargs, *create_model_args) + elif model_type == 'transformer_tlb_dual': + return create_model_fn(transformer_tlb.StatefulTransformerDualEncoder, + model_kwargs, *create_model_args) else: raise ValueError('Model type not supported') From faf9b610232e628d8fb624c5ce4c9e99e312d0d8 Mon Sep 17 00:00:00 2001 From: Nitesh Gundavarapu Date: Tue, 13 Sep 2022 18:50:33 -0700 Subject: [PATCH 5/5] Update Readme.MD to include TLB results --- README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index b76dff7..abb9e5b 100644 --- a/README.md +++ b/README.md @@ -54,9 +54,11 @@ We list the entries of other papers and submissions that used our LRA benchmark. Model | ListOps | Text | Retrieval | Image | Path | Path-X | Avg --------------- | --------- | --------- | --------- | --------- | --------- | ------ | --- IGLOO | 39.23 | 82 | 75.5 | 47.0 | 67.50 | NA | 62.25 +TLB | 37.05 | 81.88 | 76.91 | 57.51 | 79.06 | FAIL | 66.48 -IGLOO Submissions (by Vsevolod Sourkov) - https://github.com/redna11/lra-igloo +IGLOO Submissions (by Vsevolod Sourkov) - https://github.com/redna11/lra-igloo \ +TLB ([Temporal Latent Bottleneck](lra_benchmarks/models/transformer_tlb)) - [transformer_tlb](lra_benchmarks/models/transformer_tlb) ## Citation