Merge pull request #51 from NiteshBharadwaj/main
Adding Temporal Latent Bottleneck & Fixing Flax Deprecated dependencies
MostafaDehghani authored Sep 14, 2022
2 parents 09c2916 + faf9b61 commit cd31e5c
Showing 35 changed files with 679 additions and 60 deletions.
8 changes: 5 additions & 3 deletions README.md
@@ -54,9 +54,11 @@ We list the entries of other papers and submissions that used our LRA benchmark.
Model | ListOps | Text | Retrieval | Image | Path | Path-X | Avg
--------------- | --------- | --------- | --------- | --------- | --------- | ------ | ---
IGLOO | 39.23 | 82 | 75.5 | 47.0 | 67.50 | NA | 62.25
TLB | 37.05 | 81.88 | 76.91 | 57.51 | 79.06 | FAIL | 66.48


IGLOO Submissions (by Vsevolod Sourkov) - https://github.com/redna11/lra-igloo
IGLOO Submissions (by Vsevolod Sourkov) - https://github.com/redna11/lra-igloo \
TLB ([Temporal Latent Bottleneck](lra_benchmarks/models/transformer_tlb)) - [transformer_tlb](lra_benchmarks/models/transformer_tlb)

## Citation

@@ -138,7 +140,7 @@ To run a task, run the train.py file in the corresponding task directory.
(please see how to obtain the data for certain tasks if applicable).

```
PYTHONPATH="$(pwd)":"$PYTHONPATH" python lra_benchmarks/listops/train.py \
PYTHONPATH="$(pwd)":"$PYTHON_PATH" python lra_benchmarks/listops/train.py \
--config=lra_benchmarks/listops/configs/transformer_base.py \
--model_dir=/tmp/listops \
--task_name=basic \
@@ -164,7 +166,7 @@ If you would like to go to longer/shorter sequence lengths, we also support
generating your own split; run the following command:

```
PYTHONPATH="$(pwd)":"$PYTHONPATH" python lra_benchmarks/data/listops.py -- \
PYTHONPATH="$(pwd)":"$PYTHON_PATH" python lra_benchmarks/data/listops.py -- \
--output_dir=$HOME/lra_data/listops/
```

38 changes: 38 additions & 0 deletions lra_benchmarks/image/configs/cifar10/transformer_tlb_base.py
@@ -0,0 +1,38 @@
# Copyright 2022 Google LLC

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configuration and hyperparameter sweeps."""

from lra_benchmarks.image.configs.cifar10 import base_cifar10_config


def get_config():
  """Get the hyperparameter configuration."""
  config = base_cifar10_config.get_config()
  config.model_type = "transformer_tlb"
  config.learning_rate = .001
  config.model.emb_dim = 128
  config.model.mlp_dim = 128
  config.model.num_heads = 8
  config.model.qkv_dim = 64
  config.model.self_to_cross_ratio_input_updater = 1
  config.model.num_cross_layers_input_updater = 1
  config.model.num_cross_layers_state_updater = 1
  config.model.num_state_tokens = 5
  config.model.block_size = 32
  config.model.use_global_pos_encoding = False
  return config


def get_hyper(hyper):
  return hyper.product([])
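
For orientation, the new `block_size` and `num_state_tokens` knobs control how the Temporal Latent Bottleneck chunks a sequence and how large its recurrent state is. The sketch below is illustrative only (plain NumPy, not the model in `lra_benchmarks/models/transformer_tlb`): the stand-in arithmetic replaces the self/cross-attention updates, but the shapes follow the config above.

```
import numpy as np

seq_len, emb_dim = 1024, 128          # 32x32 CIFAR-10 images flattened to 1024 tokens
block_size, num_state_tokens = 32, 5  # values from this config

tokens = np.random.randn(seq_len, emb_dim)
state = np.zeros((num_state_tokens, emb_dim))  # slow recurrent state carried across blocks

for block in tokens.reshape(seq_len // block_size, block_size, emb_dim):
  # Per block, the input stream would self-attend within the block and
  # cross-attend to the current state; the state is then updated from the block.
  processed = block + state.mean(axis=0)   # stand-in for the input-updater layers
  state = state + processed.mean(axis=0)   # stand-in for the state-updater layers

print(state.shape)  # (5, 128): a fixed-size summary, regardless of sequence length
```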
39 changes: 39 additions & 0 deletions lra_benchmarks/image/configs/pathfinder32/transformer_tlb_base.py
@@ -0,0 +1,39 @@
# Copyright 2022 Google LLC

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base Configuration."""

from lra_benchmarks.image.configs.pathfinder32 import base_pathfinder32_config


def get_config():
  """Get the default hyperparameter configuration."""
  config = base_pathfinder32_config.get_config()
  config.model_type = "transformer_tlb"
  config.learning_rate = 0.0005
  config.model.num_heads = 8
  config.model.emb_dim = 128
  config.model.dropout_rate = 0.1
  config.model.qkv_dim = 64
  config.model.mlp_dim = 128
  config.model.self_to_cross_ratio_input_updater = 2
  config.model.num_cross_layers_input_updater = 1
  config.model.num_cross_layers_state_updater = 1
  config.model.num_state_tokens = 5
  config.model.block_size = 64
  config.model.use_global_pos_encoding = True
  return config


def get_hyper(hyper):
  return hyper.product([])
2 changes: 1 addition & 1 deletion lra_benchmarks/image/train.py
@@ -24,8 +24,8 @@
from absl import flags
from absl import logging
from flax import jax_utils
from flax import nn
from flax import optim
from flax.deprecated import nn
from flax.metrics import tensorboard
from flax.training import checkpoints
from flax.training import common_utils
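
The bulk of this commit is this one-line pattern repeated across files: the pre-Linen `flax.nn` API moved under `flax.deprecated` in newer Flax releases, so every `from flax import nn` becomes `from flax.deprecated import nn`. If you needed a single checkout to work against both old and new Flax, a fallback import along these lines would do (this shim is an assumption for mixed environments, not part of the commit):

```
# Not part of this commit: a compatibility fallback for environments where the
# pre-Linen module may live in either location, depending on the Flax version.
try:
  from flax.deprecated import nn   # newer Flax releases that still ship the old API
except ImportError:
  from flax import nn              # older Flax releases
```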
37 changes: 37 additions & 0 deletions lra_benchmarks/listops/configs/transformer_tlb_base.py
@@ -0,0 +1,37 @@
# Copyright 2022 Google LLC

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Configuration and hyperparameter sweeps."""

from lra_benchmarks.listops.configs import base_listops_config
import ml_collections


def get_config():
  """Get the default hyperparameter configuration."""
  config = base_listops_config.get_config()
  config.model_type = "transformer_tlb"
  config.model = ml_collections.ConfigDict()
  config.learning_rate = 0.05 / 4.
  config.model.self_to_cross_ratio_input_updater = 2
  config.model.num_cross_layers_input_updater = 1
  config.model.num_cross_layers_state_updater = 1
  config.model.num_state_tokens = 100
  config.model.block_size = 100
  config.model.use_global_pos_encoding = False
  return config


def get_hyper(hyper):
  return hyper.product([])
4 changes: 3 additions & 1 deletion lra_benchmarks/listops/train.py
@@ -22,8 +22,8 @@
from absl import flags
from absl import logging
from flax import jax_utils
from flax import nn
from flax import optim
from flax.deprecated import nn
from flax.metrics import tensorboard
from flax.training import checkpoints
from flax.training import common_utils
@@ -188,6 +188,8 @@ def main(argv):
      'classifier': True,
      'num_classes': 10
  })
  if 'model' in config:
    model_kwargs.update(config.model)

  rng = random.PRNGKey(random_seed)
  rng = jax.random.fold_in(rng, jax.process_index())
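
The two lines added here (and the matching change in lra_benchmarks/matching/train.py below) are what let the new TLB configs work without new flags: the TLB-specific hyperparameters live in a `config.model` ConfigDict and are folded into the keyword arguments passed to the model. A minimal sketch of that flow, with placeholder values (the base kwargs are illustrative, not the exact ones train.py builds):

```
import ml_collections

# Mirrors the shape of the configs added in this commit.
config = ml_collections.ConfigDict()
config.model_type = "transformer_tlb"
config.model = ml_collections.ConfigDict()
config.model.num_state_tokens = 100
config.model.block_size = 100

# Stand-in for the kwargs train.py assembles before this point.
model_kwargs = {"vocab_size": 32, "classifier": True, "num_classes": 10}

# The pattern added by this commit: model-specific settings extend the base kwargs,
# and configs without a `model` sub-config are unaffected.
if "model" in config:
  model_kwargs.update(config.model)

print(sorted(model_kwargs))  # base keys plus block_size and num_state_tokens
```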
39 changes: 39 additions & 0 deletions lra_benchmarks/matching/configs/transformer_tlb_base.py
@@ -0,0 +1,39 @@
# Copyright 2022 Google LLC

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configuration and hyperparameter sweeps."""

from lra_benchmarks.matching.configs import base_match_config
import ml_collections


def get_config():
  """Get the default hyperparameter configuration."""
  config = base_match_config.get_config()
  config.model_type = "transformer_tlb_dual"
  config.batch_size = 128
  config.learning_rate = 0.05 / 2.
  config.eval_frequency = 5000
  config.num_train_steps = 20000
  config.model = ml_collections.ConfigDict()
  config.model.self_to_cross_ratio_input_updater = 2
  config.model.num_cross_layers_input_updater = 1
  config.model.num_cross_layers_state_updater = 1
  config.model.num_state_tokens = 10
  config.model.block_size = 10
  config.model.use_global_pos_encoding = False
  return config


def get_hyper(hyper):
  return hyper.product([])
6 changes: 4 additions & 2 deletions lra_benchmarks/matching/train.py
@@ -22,8 +22,8 @@
from absl import flags
from absl import logging
from flax import jax_utils
from flax import nn
from flax import optim
from flax.deprecated import nn
from flax.metrics import tensorboard
from flax.training import checkpoints
from flax.training import common_utils
@@ -186,6 +186,8 @@ def main(argv):
      'num_classes': 2,
      'classifier_pool': config.pooling_mode
  }
  if 'model' in config:
    model_kwargs.update(config.model)

  rng = random.PRNGKey(random_seed)
  rng = jax.random.fold_in(rng, jax.process_index())
@@ -195,7 +197,7 @@ def main(argv):
  dropout_rngs = random.split(rng, jax.local_device_count())

  model = train_utils.get_model(model_type, create_model, model_kwargs,
                                init_rng, input_shape)
                                init_rng, input_shape, input_shape)

  optimizer = create_optimizer(
      model, learning_rate, weight_decay=FLAGS.config.weight_decay)
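
One subtlety in the matching task: `transformer_tlb_dual` is a two-document (dual-encoder) model, so its initialization needs a dummy input per document, hence `input_shape` is now passed twice to `train_utils.get_model`. A rough sketch of why the second shape matters (the function and shapes below are assumptions for illustration, not the repo's exact API):

```
import jax.numpy as jnp

def init_dual_encoder(shape_1, shape_2):
  """Stand-in for a dual-encoder init: one dummy input per document tower."""
  doc1 = jnp.zeros(shape_1, dtype=jnp.int32)
  doc2 = jnp.zeros(shape_2, dtype=jnp.int32)
  return doc1.shape, doc2.shape

# A single-input model would be traced with one dummy batch; the matching task
# compares two documents, so the same (batch, max_len) shape is supplied twice.
print(init_dual_encoder((1, 4000), (1, 4000)))
```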
2 changes: 1 addition & 1 deletion lra_benchmarks/models/bigbird/bigbird.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformer using BigBird (https://arxiv.org/abs/2007.14062)."""
from flax import nn
from flax.deprecated import nn
import jax.numpy as jnp
from lra_benchmarks.models.bigbird import bigbird_attention
from lra_benchmarks.models.layers import common_layers
6 changes: 3 additions & 3 deletions lra_benchmarks/models/bigbird/bigbird_attention.py
@@ -13,8 +13,8 @@
# limitations under the License.
"""Big Bird attention mechanism. See https://arxiv.org/abs/2007.14062."""
from absl import logging
from flax import nn
from flax.nn import attention
from flax.deprecated import nn
from flax.deprecated.nn import attention
import jax
from jax import lax
import jax.numpy as jnp
@@ -462,7 +462,7 @@ def apply(self,
key_padding_mask: boolean specifying key-value tokens that are pad token.
segmentation: segment indices for packed inputs_q data.
key_segmentation: segment indices for packed inputs_kv data.
cache: an instance of `flax.nn.attention.Cache` used for efficient
cache: an instance of `flax.deprecated.nn.attention.Cache` used for efficient
autoregressive decoding.
broadcast_dropout: bool: use a broadcasted dropout along batch dims.
dropout_rng: JAX PRNGKey: to be used for dropout
3 changes: 1 addition & 2 deletions lra_benchmarks/models/layers/common_layers.py
@@ -11,9 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Common layers used in models."""
from flax import nn
from flax.deprecated import nn
from jax import lax
import jax.numpy as jnp
import numpy as np
4 changes: 2 additions & 2 deletions lra_benchmarks/models/linear_transformer/linear_attention.py
@@ -13,7 +13,7 @@
# limitations under the License.
"""Custom Attention modules for Linear Transformer."""

from flax import nn
from flax.deprecated import nn
import jax.numpy as jnp


@@ -118,7 +118,7 @@ def apply(self,
key_padding_mask: boolean specifying key-value tokens that are pad token.
segmentation: segment indices for packed inputs_q data.
key_segmentation: segment indices for packed inputs_kv data.
cache: an instance of `flax.nn.attention.Cache` used for efficient
cache: an instance of `flax.deprecated.nn.attention.Cache` used for efficient
autoregressive decoding.
broadcast_dropout: bool: use a broadcasted dropout along batch dims.
dropout_rng: JAX PRNGKey: to be used for dropout
@@ -13,7 +13,7 @@
# limitations under the License.
"""LinearTransformer model."""

from flax import nn
from flax.deprecated import nn
import jax.numpy as jnp
from lra_benchmarks.models.layers import common_layers
from lra_benchmarks.models.linear_transformer import linear_attention
2 changes: 1 addition & 1 deletion lra_benchmarks/models/linformer/linformer.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Linformer models."""
from flax import nn
from flax.deprecated import nn
import jax.numpy as jnp
from lra_benchmarks.models.layers import common_layers
from lra_benchmarks.models.linformer import linformer_attention
6 changes: 3 additions & 3 deletions lra_benchmarks/models/linformer/linformer_attention.py
@@ -13,8 +13,8 @@
# limitations under the License.
"""Custom Attention core modules for Flax."""

from flax import nn
from flax.nn.attention import dot_product_attention
from flax.deprecated import nn
from flax.deprecated.nn.attention import dot_product_attention
from jax import lax
import jax.numpy as jnp

@@ -74,7 +74,7 @@ def apply(self,
key_padding_mask: boolean specifying key-value tokens that are pad token.
segmentation: segment indices for packed inputs_q data.
key_segmentation: segment indices for packed inputs_kv data.
cache: an instance of `flax.nn.attention.Cache` used for efficient
cache: an instance of `flax.deprecated.nn.attention.Cache` used for efficient
autoregressive decoding.
broadcast_dropout: bool: use a broadcasted dropout along batch dims.
dropout_rng: JAX PRNGKey: to be used for dropout
2 changes: 1 addition & 1 deletion lra_benchmarks/models/local/local.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Local Attention Transformer models."""
from flax import nn
from flax.deprecated import nn
import jax.numpy as jnp
from lra_benchmarks.models.layers import common_layers
from lra_benchmarks.models.local import local_attention
16 changes: 8 additions & 8 deletions lra_benchmarks/models/local/local_attention.py
@@ -16,12 +16,12 @@
from collections.abc import Iterable # pylint: disable=g-importing-member

from absl import logging
from flax import nn
from flax.nn.attention import _CacheEntry
from flax.nn.attention import _make_causal_mask
from flax.nn.attention import Cache
from flax.nn.attention import make_padding_mask
from flax.nn.stochastic import make_rng
from flax.deprecated import nn
from flax.deprecated.nn.attention import _CacheEntry
from flax.deprecated.nn.attention import _make_causal_mask
from flax.deprecated.nn.attention import Cache
from flax.deprecated.nn.attention import make_padding_mask
from flax.deprecated.nn.stochastic import make_rng
import jax
from jax import lax
from jax import random
@@ -42,7 +42,7 @@ def local_dot_product_attention(query,
precision=None):
"""Computes dot-product attention given query, key, and value.
Note: This is equivalent to the dot product attention in flax.nn.
Note: This is equivalent to the dot product attention in flax.deprecated.nn.
However, we do extra broadcasting of the bias in this function.
I'm leaving this here in case we need to modify something later.
@@ -206,7 +206,7 @@ def apply(self,
key_padding_mask: boolean specifying key-value tokens that are pad token.
segmentation: segment indices for packed inputs_q data.
key_segmentation: segment indices for packed inputs_kv data.
cache: an instance of `flax.nn.attention.Cache` used for efficient
cache: an instance of `flax.deprecated.nn.attention.Cache` used for efficient
autoregressive decoding.
broadcast_dropout: bool: use a broadcasted dropout along batch dims.
dropout_rng: JAX PRNGKey: to be used for dropout
2 changes: 1 addition & 1 deletion lra_benchmarks/models/longformer/longformer.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Longformer modules."""
from flax import nn
from flax.deprecated import nn
import jax.numpy as jnp
from lra_benchmarks.models.layers import common_layers
from lra_benchmarks.models.longformer import longformer_attention