[WIP] LM Workload #860
Draft: rka97 wants to merge 88 commits into dev from lm_workload (base: dev)
Changes from all commits (88 commits):
1d81455  Merge pull request #847 from mlcommons/dev (priyakasimbeg)
da5f85a  first LM commit (Niccolo-Ajroldi)
a12a364  lm data pipeline (Niccolo-Ajroldi)
ca83ab8  testing (Niccolo-Ajroldi)
e3e78dc  LM workload tested torch pipeline (Niccolo-Ajroldi)
e619495  LM workload - fix torch tests (Niccolo-Ajroldi)
d8e9c56  add LM tests, remove dev files (Niccolo-Ajroldi)
6b4ff12  add LM tests, remove dev files (Niccolo-Ajroldi)
3c5c847  Stop tracking .gitignore (Niccolo-Ajroldi)
20d841b  Remove dev/ from repo, keep locally (Niccolo-Ajroldi)
f3ba059  fix comments (Niccolo-Ajroldi)
381451f  add class specifications (Niccolo-Ajroldi)
f111d2e  add workload LM info (Niccolo-Ajroldi)
808d398  restore data_utils.py tree map (Niccolo-Ajroldi)
35f8f89  fixed NFS bug (Niccolo-Ajroldi)
cbb6ee6  train/val split before concat (Niccolo-Ajroldi)
868987c  renamed datasets to avoid conflict with HF (Niccolo-Ajroldi)
8191f6d  Merge remote-tracking branch 'upstream/lm_workload' into lm_workload (Niccolo-Ajroldi)
dd59ded  renamed datasets to dataset (Niccolo-Ajroldi)
496b9c3  fix style (Niccolo-Ajroldi)
50989eb  fix formatting (Niccolo-Ajroldi)
5af0fdc  fix style (Niccolo-Ajroldi)
2683099  fix style (Niccolo-Ajroldi)
6b7ee29  fix yapf (Niccolo-Ajroldi)
46b645b  fix style (Niccolo-Ajroldi)
b3ae647  HF datasets pipeline (rka97)
f095d4b  Testing with linear model (rka97)
4189ae0  Merge branch 'jit_switch' into lm_workload (rka97)
0c22f3d  lm workload with linear model (rka97)
99c7b9b  add nanodo model (rka97)
706d9f7  torch model (rka97)
c335e34  lm workload dataset integration in jax (rka97)
2d54365  lm workload dataset integration in jax (rka97)
af8cce4  set package versions for transformers and datasets (priyakasimbeg)
d68c54e  use train_test_split method to shuffle and split fineweb-edu dataset (priyakasimbeg)
9737367  modifications to fwedu datasetup (priyakasimbeg)
1bf0750  rename fwedu data dir (priyakasimbeg)
a333391  fix (priyakasimbeg)
05dc4dd  add back batch mapping in tokenization for fwedu (priyakasimbeg)
b374cf8  debugging (priyakasimbeg)
c0c1e3c  debugging (priyakasimbeg)
f76dc39  debugging (priyakasimbeg)
e805fa7  use tfds to shuffle and split dataset (priyakasimbeg)
362cbda  Merge remote-tracking branch 'origin/dev' into lm_workload (rka97)
c9e9abc  add command for fineweb-edu (priyakasimbeg)
e4323de  fix (priyakasimbeg)
f0c6e75  update calls to sharing utils (priyakasimbeg)
f4ffbe7  Fix torch sharding issue, update input pipeline and workload classes … (rka97)
5c85c7e  test working, lm workload training not working (debugging) (rka97)
a59dfda  updates to input_pipeline and model spec (priyakasimbeg)
1c3cb66  add defaults for lm workload (priyakasimbeg)
af91b12  refactor eval pipeline and loss fn for lm (priyakasimbeg)
6b55adf  refactor evaluation pipeline for lm (priyakasimbeg)
210d671  remove temporary flag for hlo dumps (priyakasimbeg)
0ad7788  fix in workload target condition check (priyakasimbeg)
01921d5  fix in mlp for glu (priyakasimbeg)
e420450  Fix OOM error in weighted cross entropy calculation (rka97)
3b31ad5  fix issue with checkpointing bool (rka97)
bbc114f  increase buffer size (priyakasimbeg)
f531b35  Merge branch 'lm_workload_priya' of github.com:mlcommons/algorithmic-… (priyakasimbeg)
2b162e8  remove _eval_batch from jax workload (priyakasimbeg)
617e1a3  add todo for pytorch _eval_batch cleanup (priyakasimbeg)
bebc80a  Merge pull request #891 from mlcommons/lm_workload_priya (rka97)
64ea658  add target setting algorithm for fineweb edu lm workload (priyakasimbeg)
b38ade0  update step hint for lm workload (priyakasimbeg)
65369f2  update target (priyakasimbeg)
6171b2d  update eval split sizes for lm workload and target setting point (priyakasimbeg)
d7a885c  Porting workload input pipeline to torch (rka97)
f111aea  Merge branch 'lm_workload' of github.com:mlcommons/algorithmic-effici… (rka97)
1f0439a  Fix OOM bug in lm eval (rka97)
b11c193  repeat dataset (rka97)
42d1d1a  label smoothing default fix (priyakasimbeg)
c334c97  finish merge (priyakasimbeg)
d95f2bf  Make sure to take the correct number of batches in lm (rka97)
7deb070  Merge branch 'lm_workload' of github.com:mlcommons/algorithmic-effici… (rka97)
0dc16db  Properly handle repetition in LM training and evaluation splits (rka97)
7edb702  move eval_batch from shared class to framework specific classes since… (priyakasimbeg)
0879e68  finish merge (priyakasimbeg)
73e3ea6  Refactor imports and clean up unused code in LM workload and related … (rka97)
91988af  pass linter checks (rka97)
bb4a380  Refactor loss function in LM workloads to unify label handling and im… (rka97)
a58fbd5  Fix init in both models to be the same, add lm model diff test (rka97)
b59afa0  Refactor model configuration classes to make them consistent between … (rka97)
d35cdde  Add query-key normalization to CausalAttn and Attention classes, incl… (rka97)
ffb8163  update target (priyakasimbeg)
2cc9dff  Merge branch 'lm_workload' of github.com:mlcommons/algorithmic-effici… (priyakasimbeg)
202e5cb  add pytorch nadamw_target_setting (priyakasimbeg)
98e491a  docker updates for a100 (priyakasimbeg)
The diff adds a single new file (+153 lines), the TF input pipeline for the LM dataset:

```python
"""Input pipeline for an LM dataset."""

import functools
import os
from typing import Optional

import jax
import tensorflow as tf

from algoperf import data_utils

AUTOTUNE = tf.data.experimental.AUTOTUNE
PAD_ID = tf.constant(-1, dtype=tf.int64)

TFDS_SPLIT_NAME = {'train': 'train', 'eval_train': 'train', 'validation': 'val'}

SEQUENCE_LENGTH = 1024
MAX_CORPUS_CHARS = 1_000_000_000
SHUFFLE_BUFFER_SIZE = 1000
VOCAB_SIZE = 50_257


def batch_with_padding(
    dataset: tf.data.Dataset,
    batch_size,
    padded_shapes=None,
    padding_id=PAD_ID,
):
  """Batches a tf.data.Dataset and pads the final batch if len(dataset) is
  not divisible by the batch size.

  Args:
    dataset: tf.data.Dataset
    batch_size: batch size of the resulting batched dataset
    padded_shapes: shapes to which the batches are padded
    padding_id: value used to pad elements in the new batch

  Returns:
    The batched dataset, with the remainder batch padded to `batch_size`.
  """
  batched_dataset = dataset.batch(batch_size, drop_remainder=False)

  # tf.data.Dataset.padded_batch pads elements within each batch, so we call
  # it again with batch_size=1 to pad each element of the original batch.
  padded_batched_dataset = batched_dataset.padded_batch(
      1, padded_shapes=padded_shapes, padding_values=padding_id
  )

  # Remove the extra dimension resulting from batch_size=1.
  padded_batched_dataset = padded_batched_dataset.unbatch()

  return padded_batched_dataset


def get_data_iter(
    data_rng: jax.random.PRNGKey,
    split: str,
    data_dir: str,
    batch_size: int,
    num_batches: Optional[int] = None,
):
  ds = get_lm_dataset(data_rng, split, data_dir, batch_size, num_batches)

  it = map(
      functools.partial(
          data_utils.shard_and_maybe_pad_np, global_batch_size=batch_size
      ),
      ds,
  )

  return iter(it)


def get_lm_dataset(
    data_rng: jax.random.PRNGKey,
    split: str,
    data_dir: str,
    batch_size: int,
    num_batches: Optional[int] = None,
):
  """Load a preprocessed TF dataset."""
  if split not in TFDS_SPLIT_NAME:
    raise NotImplementedError

  shuffle_seed = jax.random.randint(data_rng, (), -(2**31), 2**31 - 1)

  data_dir = os.path.join(data_dir, TFDS_SPLIT_NAME[split])
  tokens_ds = tf.data.Dataset.load(data_dir)

  # Flatten the stored documents into a stream of tokens.
  tokens_ds = tokens_ds.flat_map(tf.data.Dataset.from_tensor_slices)

  # Chunk the token stream into fixed-length sequences.
  sequences_ds = tokens_ds.batch(SEQUENCE_LENGTH + 1, drop_remainder=True)

  # Split each sequence into inputs and next-token targets.
  sequences_ds = sequences_ds.map(
      lambda x: {
          'inputs': x['input_ids'][:SEQUENCE_LENGTH],
          'targets': x['input_ids'][1:],
      },
      num_parallel_calls=AUTOTUNE,
  )

  if split == 'train':
    ds = sequences_ds.shuffle(SHUFFLE_BUFFER_SIZE, seed=shuffle_seed)
    ds = ds.batch(batch_size, drop_remainder=False)
    ds = ds.take(num_batches) if num_batches is not None else ds
    ds = ds.repeat()
    ds = ds.map(
        lambda x: {
            'inputs': x['inputs'],
            'targets': x['targets'],
            'weights': None,
        }
    )
    ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
  elif split == 'eval_train':
    ds = batch_with_padding(
        sequences_ds,
        batch_size,
        padded_shapes={
            'inputs': (batch_size, None),
            'targets': (batch_size, None),
        },
    )
    ds = ds.take(num_batches) if num_batches is not None else ds
    ds = ds.repeat()
    ds = ds.map(
        lambda x: {
            'inputs': x['inputs'],
            'targets': x['targets'],
            'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0),
        }
    )
    ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
  elif split == 'validation':
    ds = batch_with_padding(
        sequences_ds,
        batch_size,
        padded_shapes={
            'inputs': (batch_size, None),
            'targets': (batch_size, None),
        },
    )
    ds = ds.take(num_batches) if num_batches is not None else ds
    ds = ds.repeat()
    ds = ds.map(
        lambda x: {
            'inputs': x['inputs'],
            'targets': x['targets'],
            'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0),
        }
    )
    ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
  return ds
```
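The double-batching trick in `batch_with_padding` is easy to miss: `padded_batch` pads elements within a batch, so wrapping each already-formed batch into a batch of size 1 pads the short remainder batch itself. A toy sketch of the behavior (illustrative values only, assuming the `batch_with_padding` defined above is in scope):

```python
import tensorflow as tf

PAD_ID = tf.constant(-1, dtype=tf.int64)

# Five rows of length 3; with batch_size=2 the last batch has only one row.
toy = tf.data.Dataset.from_tensor_slices(
    tf.reshape(tf.range(15, dtype=tf.int64), (5, 3)))

padded = batch_with_padding(toy, batch_size=2, padded_shapes=(2, None))
for batch in padded:
  print(batch.numpy())
# The first two batches pass through unchanged; the final batch becomes
# [[12 13 14]
#  [-1 -1 -1]]: the missing row is filled with PAD_ID, so every batch has
# the same leading dimension. Downstream, the eval splits zero out padded
# positions via the 'weights' field.
```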
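For context, a minimal consumption sketch, assuming the functions above are in scope and that `data_dir` points at a dataset saved with `tf.data.Dataset.save` under `train/` and `val/` subdirectories with an `'input_ids'` feature (the path and batch size below are hypothetical):

```python
import jax

data_rng = jax.random.PRNGKey(0)
train_iter = get_data_iter(
    data_rng,
    split='train',
    data_dir='/data/fineweb_edu',  # hypothetical location of the saved data
    batch_size=8,
)

batch = next(train_iter)
# data_utils.shard_and_maybe_pad_np is expected to shard each array across
# local devices, giving (num_devices, per_device_batch, SEQUENCE_LENGTH).
print(batch['inputs'].shape, batch['targets'].shape)
```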