Commit e52b1b8

Add MaxText Llama 3.1 70B training with GCS recipe
1 parent af2a7cd commit e52b1b8

File tree: 5 files changed (+293 −0 lines changed)
Lines changed: 190 additions & 0 deletions
@@ -0,0 +1,190 @@
# Instructions for training Llama3.1-70B-MaxText on TPU Trillium (v6e-256) with Google Cloud Storage (GCS)

## GCS Bucket setup

1. Create two buckets: one to hold the dataset and one to hold checkpoints. To create regional hierarchical namespace (HNS) buckets, use the following commands:
```
# Set variables
export DATASET_BUCKET="dataloading-bucket-name"
export CHECKPOINT_BUCKET="checkpoint-bucket-name"
export REGION="us-central1"

# Create dataset bucket
gcloud storage buckets create gs://${DATASET_BUCKET} --location=${REGION} --default-storage-class=Standard --enable-hierarchical-namespace --uniform-bucket-level-access

# Create checkpoint bucket
gcloud storage buckets create gs://${CHECKPOINT_BUCKET} --location=${REGION} --default-storage-class=Standard --enable-hierarchical-namespace --uniform-bucket-level-access
```
Replace the following placeholder values:
- `DATASET_BUCKET`: the name of your Cloud Storage bucket that holds the training dataset. Do not include the `gs://` prefix.
- `CHECKPOINT_BUCKET`: the name of your Cloud Storage bucket where checkpoints will be written. Do not include the `gs://` prefix.
- `REGION`: the region where your GKE cluster is located ([available locations](https://cloud.google.com/storage/docs/locations#location-r)).
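Optionally, you can confirm both buckets were created as expected with standard `gcloud storage` commands (a minimal sanity check; the exact output fields depend on your gcloud version):
```
gcloud storage buckets describe gs://${DATASET_BUCKET}
gcloud storage buckets describe gs://${CHECKPOINT_BUCKET}
```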
2. Follow these [instructions](https://github.com/AI-Hypercomputer/maxtext/blob/b93beba652db6b3f4e6c82dc48a83b03229f5d3a/getting_started/Data_Input_Pipeline.md#tfds-pipeline) to download the AllenAI C4 dataset to the dataset bucket.
Then follow these [instructions](https://github.com/google/array_record/tree/main/beam) to convert the dataset into ArrayRecord format.
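As a quick sanity check after the conversion, you can list the converted ArrayRecord shards. This assumes you wrote them under the `array-record/c4/en/3.0.1/` prefix, which matches the `grain_train_files` path used by the `llama3_1_70b_8192_rd_ckpt_grain` workload below:
```
gcloud storage ls gs://${DATASET_BUCKET}/array-record/c4/en/3.0.1/ | head
```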
## XPK setup

1. Please follow this [link](https://github.com/AI-Hypercomputer/tpu-recipes/blob/main/training/trillium/XPK_README.md) to create your GKE cluster with XPK.
2. GCSFuse lets you mount and access Cloud Storage buckets as local file systems, so applications can read and write objects in your buckets using standard file system semantics. Use the commands below to create [XPK storage resources](https://github.com/AI-Hypercomputer/xpk?tab=readme-ov-file#storage) for both the dataset and checkpoint buckets so they can be mounted into the MaxText workload with GCSFuse. Use the separate manifest files `dataset_pvc.yaml` and `checkpoint_pvc.yaml` from this repo for the dataset and checkpoint buckets, respectively.
Be sure to update `volumeHandle` in the YAMLs with your correct bucket names. Creating the buckets and XPK storage resources is a one-time setup.
```
export RECIPE_REPO="path-to-this-recipe-repo" # Update

cd ~/xpk

python3 xpk.py storage attach dataset-bucket type=gcsfuse project=$PROJECT cluster=$CLUSTER zone=$ZONE mountpoint=/tmp/dataset readonly=false bucket=$DATASET_BUCKET size=64 automount=false manifest=$RECIPE_REPO/tpu-recipes/training/trillium/Llama3.1-70B-MaxText-with-Storage/dataset_pvc.yaml

python3 xpk.py storage attach checkpoint-bucket type=gcsfuse project=$PROJECT cluster=$CLUSTER zone=$ZONE mountpoint=/tmp/ckpt readonly=false bucket=$CHECKPOINT_BUCKET size=64 automount=false manifest=$RECIPE_REPO/tpu-recipes/training/trillium/Llama3.1-70B-MaxText-with-Storage/checkpoint_pvc.yaml
```
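Each `storage attach` command creates a PersistentVolume and PersistentVolumeClaim on the cluster from the corresponding manifest (see the manifests below). Assuming your kubeconfig points at the XPK cluster, a quick way to confirm the claims are bound:
```
kubectl get pvc -n default
kubectl get pv
```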
## Prep for MaxText

### Install MaxText and Build Docker Image
Please follow this [link](https://github.com/AI-Hypercomputer/tpu-recipes/blob/main/training/trillium/MAXTEXT_README.md) to install MaxText and build the Docker image.

In step 2, use the jax-stable-stack image containing JAX 0.5.2:
```
BASE_IMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.5.2-rev1
bash docker_build_dependency_image.sh DEVICE=tpu MODE=stable_stack BASEIMAGE=${BASE_IMAGE}
```
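The build step above is expected to produce a local image tagged `maxtext_base_image`, which is the name referenced by the `base_docker_image` flag in the run commands below. You can confirm the image exists locally with:
```
docker images maxtext_base_image
```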
## Run MaxText Llama3.1-70B workloads on GKE

### Starting workload

From the MaxText root directory, start your Llama3.1-70B workload.
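The commands in this section assume the same environment variables used during cluster and storage setup. A minimal example with placeholder values (replace them with your own):
```
export PROJECT="your-gcp-project"        # GCP project that hosts the GKE cluster
export ZONE="us-east5-b"                 # zone used when creating the XPK cluster (example value)
export CLUSTER="your-xpk-cluster"        # name of the XPK/GKE cluster
export OUTPUT_DIR="gs://your-output-bucket/llama3-1-70b"  # base output directory for the synthetic-data run
```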
Run MaxText Llama 3.1 70B with synthetic data and no checkpointing:
```
python3 benchmarks/benchmark_runner.py xpk \
    project=$PROJECT \
    zone=$ZONE \
    device_type=v6e-256 \
    num_slices=1 \
    cluster_name=$CLUSTER \
    base_output_directory=$OUTPUT_DIR \
    model_name="llama3_1_70b_8192_synthetic" \
    num_steps=100 \
    base_docker_image=maxtext_base_image
```
Run MaxText Llama 3.1 70B with checkpointing and loading real data from GCS:
```
python3 benchmarks/benchmark_runner.py xpk \
    project=$PROJECT \
    zone=$ZONE \
    device_type=v6e-256 \
    num_slices=1 \
    cluster_name=${CLUSTER} \
    base_output_directory=/tmp/ckpt \
    model_name="llama3_1_70b_8192_rd_ckpt_grain" \
    num_steps=100 \
    base_docker_image=maxtext_base_image \
    xpk_storage=dataset-bucket xpk_storage=checkpoint-bucket
```

If you would like to run on multiple slices of v6e-256, you may modify the `num_slices` flag.
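Because `/tmp/ckpt` is the GCSFuse mountpoint for the checkpoint bucket (set in the `storage attach` step) and `base_output_directory` points at it, checkpoints from this run should land in the checkpoint bucket itself, typically nested under a run-name directory. After the first `checkpoint_period` (20 steps) has elapsed, you can check:
```
gcloud storage ls gs://${CHECKPOINT_BUCKET}/
```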
### Workload Details

For reference, here are the `llama3_1_70b_8192_synthetic` and `llama3_1_70b_8192_rd_ckpt_grain` workload details:

```
MaxTextModel(
    model_name="llama3_1-70b-8192",
    model_type="llama3.1-70b",
    tuning_params={
        "per_device_batch_size": 4,
        "ici_fsdp_parallelism": -1,
        "remat_policy": "custom",
        "decoder_layer_input": "offload",
        "query_proj": "offload",
        "key_proj": "offload",
        "value_proj": "offload",
        "max_target_length": 8192,
        "attention": "flash",
        "use_iota_embed": True,
        "dataset_path": "gs://max-datasets-rogue",
        "dataset_type": "synthetic",
        "enable_checkpointing": False,
        "sa_block_q": 2048,
        "sa_block_kv": 2048,
        "sa_block_kv_compute": 2048,
        "sa_block_q_dkv": 2048,
        "sa_block_kv_dkv": 2048,
        "sa_block_kv_dkv_compute": 2048,
        "sa_block_q_dq": 2048,
        "sa_block_kv_dq": 2048,
        "sa_use_fused_bwd_kernel": True,
        "profiler": "xplane",
        "skip_first_n_steps_for_profiler": 10,
        "profiler_steps": 5,
    },
    xla_flags=(
        xla_flags_library.DENSE_VMEM_LIMIT_FLAG
        + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER
        + xla_flags_library.DATA_PARALLEL_OVERLAP
        + xla_flags_library.CF_FOR_ALL_GATHER
        + xla_flags_library.HOST_OFFLOAD_FLAGS
    ),
)


MaxTextModel(
    model_name="llama3_1_70b_8192_rd_ckpt_grain",
    model_type="llama3.1-70b",
    tuning_params={
        "per_device_batch_size": 2,
        "ici_fsdp_parallelism": -1,
        "remat_policy": "custom",
        "decoder_layer_input": "offload",
        "query_proj": "offload",
        "key_proj": "offload",
        "value_proj": "offload",
        "max_target_length": 8192,
        "attention": "flash",
        "use_iota_embed": True,
        "dataset_path": "/tmp/dataset",
        "dataset_type": "grain",
        "grain_train_files": "/tmp/dataset/array-record/c4/en/3.0.1/c4-train.array_record*",
        "grain_worker_count": 24,
        "enable_checkpointing": True,
        "async_checkpointing": True,
        "checkpoint_period": 20,
        "sa_block_q": 2048,
        "sa_block_kv": 2048,
        "sa_block_kv_compute": 2048,
        "sa_block_q_dkv": 2048,
        "sa_block_kv_dkv": 2048,
        "sa_block_kv_dkv_compute": 2048,
        "sa_block_q_dq": 2048,
        "sa_block_kv_dq": 2048,
        "sa_use_fused_bwd_kernel": True,
    },
    xla_flags=(
        xla_flags_library.DENSE_VMEM_LIMIT_FLAG
        + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER
        + xla_flags_library.DATA_PARALLEL_OVERLAP
        + xla_flags_library.CF_FOR_ALL_GATHER
        + xla_flags_library.HOST_OFFLOAD_FLAGS
        + xla_flags_library.ENABLE_SPARSECORE_OFFLOADING_FOR_ALL_REDUCE
        + " --xla_tpu_iova_dma_chunk_size_bytes=104857"
    ),
)
```
The equivalent workload code can be found in the [maxtext_trillium_model_configs.py](https://github.com/AI-Hypercomputer/maxtext/blob/1e4d513ad70dd4074d975a9f7936295008d4b900/benchmarks/maxtext_trillium_model_configs.py#L1103-L1146) file within the MaxText repository.

## Clean-up
You can run the following commands to detach the XPK storage resources (this removes the PersistentVolumes and PersistentVolumeClaims created by the `xpk storage attach` commands from your GKE cluster).
```
# Detach dataset storage
python3 xpk.py storage detach dataset-bucket \
    --project=$PROJECT --cluster=$CLUSTER --zone=$ZONE

# Detach checkpoint storage
python3 xpk.py storage detach checkpoint-bucket \
    --project=$PROJECT --cluster=$CLUSTER --zone=$ZONE
```
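If you also want to remove the buckets created for this recipe, you can delete them and their contents. This permanently deletes the dataset and all checkpoints, so only do it once you no longer need them:
```
gcloud storage rm --recursive gs://${DATASET_BUCKET}
gcloud storage rm --recursive gs://${CHECKPOINT_BUCKET}
```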
checkpoint_pvc.yaml
Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
apiVersion: v1
kind: PersistentVolume
metadata:
  name: checkpoint-bucket-pv
spec:
  accessModes:
    - ReadWriteMany
  capacity:
    storage: 64Gi
  persistentVolumeReclaimPolicy: Retain
  storageClassName: gcsfuse-sc # dummy storage class
  claimRef:
    namespace: default
    name: checkpoint-bucket-pvc
  mountOptions:
    - metadata-cache:ttl-secs:-1
    - metadata-cache:negative-ttl-secs:0
    - metadata-cache:stat-cache-max-size-mb:-1
    - metadata-cache:type-cache-max-size-mb:-1
    - file-cache:enable-parallel-downloads:false
    - file-system:kernel-list-cache-ttl-secs:0
    - write:enable-streaming-writes:true
    - file-system:precondition-errors:false
  csi:
    driver: gcsfuse.csi.storage.gke.io
    volumeHandle: checkpoint-bucket-name # Update with your checkpoint bucket name
    volumeAttributes:
      gcsfuseMetadataPrefetchOnMount: "true"
---
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
  name: checkpoint-bucket-pvc
  namespace: default
spec:
  accessModes:
    - ReadWriteMany
  resources:
    requests:
      storage: 64Gi
  volumeName: checkpoint-bucket-pv
  storageClassName: gcsfuse-sc # dummy storage class
dataset_pvc.yaml
Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
apiVersion: v1
kind: PersistentVolume
metadata:
  name: dataset-bucket-pv
spec:
  accessModes:
    - ReadWriteMany
  capacity:
    storage: 64Gi
  persistentVolumeReclaimPolicy: Retain
  storageClassName: gcsfuse-sc # dummy storage class
  claimRef:
    namespace: default
    name: dataset-bucket-pvc
  mountOptions:
    - metadata-cache:ttl-secs:-1
    - metadata-cache:stat-cache-max-size-mb:-1
    - metadata-cache:type-cache-max-size-mb:-1
    - file-cache:enable-parallel-downloads:false
    - file-system:kernel-list-cache-ttl-secs:-1
    - write:enable-streaming-writes:true
  csi:
    driver: gcsfuse.csi.storage.gke.io
    volumeHandle: dataloading-bucket-name # Update with your bucket name
    volumeAttributes:
      gcsfuseMetadataPrefetchOnMount: "true"
---
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
  name: dataset-bucket-pvc
  namespace: default
spec:
  accessModes:
    - ReadWriteMany
  resources:
    requests:
      storage: 64Gi
  volumeName: dataset-bucket-pv
  storageClassName: gcsfuse-sc # dummy storage class
Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
python3 benchmarks/benchmark_runner.py xpk \
    project=$PROJECT \
    zone=$ZONE \
    device_type=v6e-256 \
    num_slices=1 \
    cluster_name=${CLUSTER} \
    base_output_directory=/tmp/ckpt \
    model_name="llama3_1_70b_8192_rd_ckpt_grain" \
    num_steps=100 \
    base_docker_image=maxtext_base_image \
    xpk_storage=$DATASET_STORAGE_NAME xpk_storage=$CHECKPOINT_STORAGE_NAME
Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
python3 benchmarks/benchmark_runner.py xpk \
    project=$PROJECT \
    zone=$ZONE \
    device_type=v6e-256 \
    num_slices=1 \
    cluster_name=$CLUSTER \
    base_output_directory=$OUTPUT_DIR \
    model_name="llama3_1_70b_8192_synthetic" \
    num_steps=100 \
    base_docker_image=maxtext_base_image
