# Instructions for training Llama3.1-70B-MaxText on TPU trillium (v6e-256) with Google Cloud Storage (GCS)

## GCS Bucket setup
Create a bucket with a dataset for dataloading and a bucket to write checkpoints. To create a regional HNS bucket, use the following command:
```
gcloud storage buckets create gs://${BUCKET_NAME} --location=$ZONE --default-storage-class=Standard --enable-hierarchical-namespace --uniform-bucket-level-access
```
Check out this [link](https://github.com/AI-Hypercomputer/maxtext/blob/b93beba652db6b3f4e6c82dc48a83b03229f5d3a/getting_started/Data_Input_Pipeline.md#tfds-pipeline) for downloading the AllenAI c4 dataset, which is used in this recipe.
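Once the c4 data has been converted following that guide, copy it into the dataset bucket so it will be visible under the GCSFuse mount used later in this recipe. A minimal sketch, assuming the converted ArrayRecord files are staged locally (the source path is a placeholder):
```
# Copy the converted ArrayRecord c4 files into the dataset bucket.
# /path/to/array-record/c4 is a placeholder for your staging location.
gcloud storage cp -r /path/to/array-record/c4 gs://${BUCKET_NAME}/array-record/
```
The files should end up under `array-record/c4/en/3.0.1/` inside the bucket, matching the `grain_train_files` path used later in this recipe.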

## XPK setup
1. Please follow this [link](https://github.com/AI-Hypercomputer/tpu-recipes/blob/main/training/trillium/XPK_README.md) to create your GKE cluster with XPK.
2. A FUSE adapter lets you mount and access Cloud Storage buckets as local file systems, so applications can read and write objects in your bucket using standard file system semantics. Attaching a bucket with XPK adds a PersistentVolume (PV) and PersistentVolumeClaim (PVC) to the cluster; see https://github.com/AI-Hypercomputer/xpk?tab=readme-ov-file#storage for details.
```
cd ~/xpk

python3 xpk.py storage attach $USER-dataset-bucket --type=gcsfuse --project=$PROJECT --cluster=$CLUSTER --zone=$ZONE --mount-point=/tmp/dataset --readonly=false --bucket=<bucket> --size=64 --automount=false --manifest=$RECIPE_REPO/tpu-recipes/training/trillium/Llama3.1-70B-MaxText-with-Storage/dataset_pvc.yaml

python3 xpk.py storage attach $USER-ckpt-bucket --type=gcsfuse --project=$PROJECT --cluster=$CLUSTER --zone=$ZONE --mount-point=/tmp/ckpt --readonly=false --bucket=<bucket> --size=64 --automount=false --manifest=$RECIPE_REPO/tpu-recipes/training/trillium/Llama3.1-70B-MaxText-with-Storage/checkpoint_pvc.yaml
```
Create separate manifest files for the dataset bucket and the checkpoint bucket.
Creating a bucket and attaching XPK storage is a one-time setup. Check out `checkpoint_pvc.yaml` and `dataset_pvc.yaml` for example manifest files.
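As a quick sanity check after attaching, you can confirm the storage is registered and that the corresponding Kubernetes objects exist; the `storage list` subcommand is assumed to be available in your xpk version:
```
# List storages known to XPK for this cluster.
python3 xpk.py storage list --project=$PROJECT --cluster=$CLUSTER --zone=$ZONE

# Or inspect the PV/PVC objects created from the manifests directly.
kubectl get pv,pvc
```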

## Prep for MaxText

### Install MaxText and Build Docker Image
Please follow this [link](https://github.com/AI-Hypercomputer/tpu-recipes/blob/main/training/trillium/MAXTEXT_README.md) to install MaxText and build the Docker image.

In step 2, use the jax-stable-stack image containing JAX 0.5.2:
```
BASE_IMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.5.2-rev1
bash docker_build_dependency_image.sh DEVICE=tpu MODE=stable_stack BASEIMAGE=${BASE_IMAGE}
```
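Before launching workloads, it can be worth confirming the dependency image exists locally; `maxtext_base_image` is the tag referenced by the run commands below:
```
# Verify the locally built dependency image is present.
docker images maxtext_base_image
```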

## Run MaxText Llama3.1-70B workloads on GKE

### Starting workload

From the MaxText root directory, start your Llama3.1-70B workload.

Run MaxText Llama 3.1 70B with synthetic data and no checkpointing:
```
python3 benchmarks/benchmark_runner.py xpk \
    --project=$PROJECT \
    --zone=$ZONE \
    --device_type=v6e-256 \
    --num_slices=1 \
    --cluster_name=$CLUSTER \
    --base_output_directory=$OUTPUT_DIR \
    --model_name="llama3_1_70b_8192_synthetic" \
    --num_steps=100 \
    --base_docker_image=maxtext_base_image
```
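After submission, the run can be tracked from xpk; the flags below mirror those used elsewhere in this recipe and may vary with your xpk version:
```
# Check the status of workloads running on the cluster.
python3 xpk.py workload list --project=$PROJECT --zone=$ZONE --cluster=$CLUSTER
```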

Run MaxText Llama 3.1 70B with checkpointing and loading real data from GCS:
```
python3 benchmarks/benchmark_runner.py xpk \
    --project=$PROJECT \
    --zone=$ZONE \
    --device_type=v6e-256 \
    --num_slices=1 \
    --cluster_name=${CLUSTER} \
    --base_output_directory=/tmp/ckpt \
    --model_name="llama3_1_70b_8192_rd_ckpt_grain" \
    --num_steps=100 \
    --base_docker_image=maxtext_base_image \
    --xpk_storage="your-dataset-bucket" --xpk_storage="your-ckpt-bucket"
```

If you would like to run on multiple slices of v6e-256, you may modify the `--num_slices` flag.
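For example, a two-slice version of the synthetic-data run above changes only that flag; everything else stays the same:
```
python3 benchmarks/benchmark_runner.py xpk \
    --project=$PROJECT \
    --zone=$ZONE \
    --device_type=v6e-256 \
    --num_slices=2 \
    --cluster_name=$CLUSTER \
    --base_output_directory=$OUTPUT_DIR \
    --model_name="llama3_1_70b_8192_synthetic" \
    --num_steps=100 \
    --base_docker_image=maxtext_base_image
```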

### Workload Details

For reference, here are the `llama3_1_70b_8192_synthetic` and `llama3_1_70b_8192_rd_ckpt_grain` workload details:

```
MaxTextModel(
    model_name="llama3_1-70b-8192",
    model_type="llama3.1-70b",
    tuning_params={
        "per_device_batch_size": 4,
        "ici_fsdp_parallelism": -1,
        "remat_policy": "custom",
        "decoder_layer_input": "offload",
        "query_proj": "offload",
        "key_proj": "offload",
        "value_proj": "offload",
        "max_target_length": 8192,
        "attention": "flash",
        "use_iota_embed": True,
        "dataset_path": "gs://max-datasets-rogue",
        "dataset_type": "synthetic",
        "enable_checkpointing": False,
        "sa_block_q": 2048,
        "sa_block_kv": 2048,
        "sa_block_kv_compute": 2048,
        "sa_block_q_dkv": 2048,
        "sa_block_kv_dkv": 2048,
        "sa_block_kv_dkv_compute": 2048,
        "sa_block_q_dq": 2048,
        "sa_block_kv_dq": 2048,
        "sa_use_fused_bwd_kernel": True,
        "profiler": "xplane",
        "skip_first_n_steps_for_profiler": 10,
        "profiler_steps": 5,
    },
    xla_flags=(
        xla_flags_library.DENSE_VMEM_LIMIT_FLAG
        + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER
        + xla_flags_library.DATA_PARALLEL_OVERLAP
        + xla_flags_library.CF_FOR_ALL_GATHER
        + xla_flags_library.HOST_OFFLOAD_FLAGS
    ),
)


MaxTextModel(
    model_name="llama3_1_70b_8192_rd_ckpt_grain",
    model_type="llama3.1-70b",
    tuning_params={
        "per_device_batch_size": 2,
        "ici_fsdp_parallelism": -1,
        "remat_policy": "custom",
        "decoder_layer_input": "offload",
        "query_proj": "offload",
        "key_proj": "offload",
        "value_proj": "offload",
        "max_target_length": 8192,
        "attention": "flash",
        "use_iota_embed": True,
        "dataset_path": "/tmp/dataset",
        "dataset_type": "grain",
        "grain_train_files": "/tmp/dataset/array-record/c4/en/3.0.1/c4-train.array_record*",
        "grain_worker_count": 24,
        "enable_checkpointing": True,
        "async_checkpointing": True,
        "checkpoint_period": 20,
        "sa_block_q": 2048,
        "sa_block_kv": 2048,
        "sa_block_kv_compute": 2048,
        "sa_block_q_dkv": 2048,
        "sa_block_kv_dkv": 2048,
        "sa_block_kv_dkv_compute": 2048,
        "sa_block_q_dq": 2048,
        "sa_block_kv_dq": 2048,
        "sa_use_fused_bwd_kernel": True,
    },
    xla_flags=(
        xla_flags_library.DENSE_VMEM_LIMIT_FLAG
        + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER
        + xla_flags_library.DATA_PARALLEL_OVERLAP
        + xla_flags_library.CF_FOR_ALL_GATHER
        + xla_flags_library.HOST_OFFLOAD_FLAGS
        + xla_flags_library.ENABLE_SPARSECORE_OFFLOADING_FOR_ALL_REDUCE
        + " --xla_tpu_iova_dma_chunk_size_bytes=104857"
    ),
)
```

This equivalent workload code can be found in the [maxtext_trillium_model_configs.py](https://github.com/AI-Hypercomputer/maxtext/blob/1e4d513ad70dd4074d975a9f7936295008d4b900/benchmarks/maxtext_trillium_model_configs.py#L1103-L1146) file within the MaxText repository.
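Because `base_output_directory=/tmp/ckpt` points at the GCSFuse mount of the checkpoint bucket, the checkpoints written every `checkpoint_period=20` steps land in that bucket and can be inspected directly; the bucket name below is a placeholder:
```
# List run output and checkpoints written by the grain workload.
gcloud storage ls gs://<your-ckpt-bucket>/
```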