Commit 57ff296: Add MaxText Llama 3.1 70B training with GCS recipe
1 parent af2a7cd. 5 files changed: 263 additions, 0 deletions.

Lines changed: 160 additions & 0 deletions
# Instructions for training Llama3.1-70B-MaxText on TPU Trillium (v6e-256) with Google Cloud Storage (GCS)

## GCS Bucket setup

Create one bucket holding the dataset for data loading and another bucket for writing checkpoints. To create a regional HNS (hierarchical namespace) bucket, use the following command:

```
gcloud storage buckets create gs://${BUCKET_NAME} --location=$ZONE --default-storage-class=Standard --enable-hierarchical-namespace --uniform-bucket-level-access
```

Check out this [link](https://github.com/AI-Hypercomputer/maxtext/blob/b93beba652db6b3f4e6c82dc48a83b03229f5d3a/getting_started/Data_Input_Pipeline.md#tfds-pipeline) for downloading the AllenAI C4 dataset, which is used in this recipe.
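As a cross-check on the flags above, the bucket-creation command can be assembled programmatically. This is an illustrative sketch, not part of the recipe; the `bucket_create_cmd` helper and the sample bucket name and location are assumptions:

```python
# Illustrative helper: assembles the `gcloud storage buckets create` command
# from its parts so each flag is easy to audit before running it.
def bucket_create_cmd(bucket_name: str, location: str) -> str:
    flags = [
        f"--location={location}",
        "--default-storage-class=Standard",
        "--enable-hierarchical-namespace",   # HNS bucket
        "--uniform-bucket-level-access",
    ]
    return " ".join(["gcloud storage buckets create", f"gs://{bucket_name}"] + flags)

print(bucket_create_cmd("my-training-bucket", "us-east5"))
```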
## XPK setup

1. Please follow this [link](https://github.com/AI-Hypercomputer/tpu-recipes/blob/main/training/trillium/XPK_README.md) to create your GKE cluster with XPK.
2. A FUSE adapter lets you mount and access Cloud Storage buckets as local file systems, so applications can read and write objects in your bucket using standard file system semantics. Attaching storage adds a PV and PVC to the cluster; see https://github.com/AI-Hypercomputer/xpk?tab=readme-ov-file#storage.

```
cd ~/xpk

python3 xpk.py storage attach $USER-dataset-bucket --type=gcsfuse --project=$PROJECT --cluster=$CLUSTER --zone=$ZONE --mount-point=/tmp/dataset --readonly=false --bucket=<bucket> --size=64 --auto-mount=false --manifest=$RECIPE_REPO/tpu-recipes/training/trillium/Llama3.1-70B-MaxText-with-Storage/dataset_pvc.yaml

python3 xpk.py storage attach $USER-ckpt-bucket --type=gcsfuse --project=$PROJECT --cluster=$CLUSTER --zone=$ZONE --mount-point=/tmp/ckpt --readonly=false --bucket=<bucket> --size=64 --auto-mount=false --manifest=$RECIPE_REPO/tpu-recipes/training/trillium/Llama3.1-70B-MaxText-with-Storage/checkpoint_pvc.yaml
```

Create separate manifest files for the dataset bucket and the checkpoint bucket; `checkpoint_pvc.yaml` and `dataset_pvc.yaml` are example manifests. Creating a bucket and attaching XPK storage is a one-time setup.
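The two attach commands differ only in mount name, mount point, and manifest. A small table makes that explicit; this is an illustrative sketch, not an official xpk API, and the `attach_cmd` helper and its sample values are assumptions (several flags are omitted for brevity):

```python
# Illustrative only: renders the two `xpk.py storage attach` invocations from
# one table so the dataset and checkpoint mounts stay consistent.
MOUNTS = {
    "dataset": {"mount_point": "/tmp/dataset", "manifest": "dataset_pvc.yaml"},
    "ckpt": {"mount_point": "/tmp/ckpt", "manifest": "checkpoint_pvc.yaml"},
}

def attach_cmd(name: str, cfg: dict, user: str = "user") -> str:
    # project/cluster/zone/bucket/size flags from the recipe are omitted here
    return (
        f"python3 xpk.py storage attach {user}-{name}-bucket "
        f"--type=gcsfuse --mount-point={cfg['mount_point']} "
        f"--readonly=false --auto-mount=false --manifest={cfg['manifest']}"
    )

for name, cfg in MOUNTS.items():
    print(attach_cmd(name, cfg))
```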
## Prep for MaxText

### Install MaxText and Build Docker Image

Please follow this [link](https://github.com/AI-Hypercomputer/tpu-recipes/blob/main/training/trillium/MAXTEXT_README.md) to install MaxText and build the Docker image.

In step 2, use the jax-stable-stack image containing JAX 0.5.2:

```
BASE_IMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.5.2-rev1
bash docker_build_dependency_image.sh DEVICE=tpu MODE=stable_stack BASEIMAGE=${BASE_IMAGE}
```
## Run MaxText Llama3.1-70B workloads on GKE

### Starting workload

From the MaxText root directory, start your Llama3.1-70B workload.

Run MaxText Llama 3.1 70B with synthetic data and no checkpointing:

```
python3 benchmarks/benchmark_runner.py xpk \
    --project=$PROJECT \
    --zone=$ZONE \
    --device_type=v6e-256 \
    --num_slices=1 \
    --cluster_name=$CLUSTER \
    --base_output_directory=$OUTPUT_DIR \
    --model_name="llama3_1_70b_8192_synthetic" \
    --num_steps=100 \
    --base_docker_image=maxtext_base_image
```
Run MaxText Llama 3.1 70B with checkpointing, loading real data from GCS:

```
python3 benchmarks/benchmark_runner.py xpk \
    --project=$PROJECT \
    --zone=$ZONE \
    --device_type=v6e-256 \
    --num_slices=1 \
    --cluster_name=${CLUSTER} \
    --base_output_directory=/tmp/ckpt \
    --model_name="llama3_1_70b_8192_rd_ckpt_grain" \
    --num_steps=100 \
    --base_docker_image=maxtext_base_image \
    --xpk_storage="your-dataset-bucket" --xpk_storage="your-ckpt-bucket"
```

If you would like to run on multiple slices of v6e-256, modify the `--num_slices` flag.
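When changing `--num_slices`, it helps to know how the effective batch scales: a v6e-256 slice has 256 chips, so the global batch grows linearly with slices. A back-of-envelope sketch; the helper is illustrative, with the per-device batch sizes taken from the workload configs in this recipe:

```python
# Illustrative arithmetic: global batch and tokens per step for a v6e-256 run.
CHIPS_PER_SLICE = 256   # chips in one v6e-256 slice
SEQ_LEN = 8192          # max_target_length in both workloads

def tokens_per_step(per_device_batch_size: int, num_slices: int = 1):
    global_batch = per_device_batch_size * CHIPS_PER_SLICE * num_slices
    return global_batch, global_batch * SEQ_LEN

# synthetic workload uses per_device_batch_size=4
print(tokens_per_step(4))               # (1024, 8388608)
# grain + checkpointing workload uses per_device_batch_size=2
print(tokens_per_step(2, num_slices=2)) # (1024, 8388608)
```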
### Workload Details

For reference, here are the `llama3_1_70b_8192_synthetic` and `llama3_1_70b_8192_rd_ckpt_grain` workload details:

```
MaxTextModel(
    model_name="llama3_1-70b-8192",
    model_type="llama3.1-70b",
    tuning_params={
        "per_device_batch_size": 4,
        "ici_fsdp_parallelism": -1,
        "remat_policy": "custom",
        "decoder_layer_input": "offload",
        "query_proj": "offload",
        "key_proj": "offload",
        "value_proj": "offload",
        "max_target_length": 8192,
        "attention": "flash",
        "use_iota_embed": True,
        "dataset_path": "gs://max-datasets-rogue",
        "dataset_type": "synthetic",
        "enable_checkpointing": False,
        "sa_block_q": 2048,
        "sa_block_kv": 2048,
        "sa_block_kv_compute": 2048,
        "sa_block_q_dkv": 2048,
        "sa_block_kv_dkv": 2048,
        "sa_block_kv_dkv_compute": 2048,
        "sa_block_q_dq": 2048,
        "sa_block_kv_dq": 2048,
        "sa_use_fused_bwd_kernel": True,
        "profiler": "xplane",
        "skip_first_n_steps_for_profiler": 10,
        "profiler_steps": 5,
    },
    xla_flags=(
        xla_flags_library.DENSE_VMEM_LIMIT_FLAG
        + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER
        + xla_flags_library.DATA_PARALLEL_OVERLAP
        + xla_flags_library.CF_FOR_ALL_GATHER
        + xla_flags_library.HOST_OFFLOAD_FLAGS
    ),
)


MaxTextModel(
    model_name="llama3_1_70b_8192_rd_ckpt_grain",
    model_type="llama3.1-70b",
    tuning_params={
        "per_device_batch_size": 2,
        "ici_fsdp_parallelism": -1,
        "remat_policy": "custom",
        "decoder_layer_input": "offload",
        "query_proj": "offload",
        "key_proj": "offload",
        "value_proj": "offload",
        "max_target_length": 8192,
        "attention": "flash",
        "use_iota_embed": True,
        "dataset_path": "/tmp/dataset",
        "dataset_type": "grain",
        "grain_train_files": "/tmp/dataset/array-record/c4/en/3.0.1/c4-train.array_record*",
        "grain_worker_count": 24,
        "enable_checkpointing": True,
        "async_checkpointing": True,
        "checkpoint_period": 20,
        "sa_block_q": 2048,
        "sa_block_kv": 2048,
        "sa_block_kv_compute": 2048,
        "sa_block_q_dkv": 2048,
        "sa_block_kv_dkv": 2048,
        "sa_block_kv_dkv_compute": 2048,
        "sa_block_q_dq": 2048,
        "sa_block_kv_dq": 2048,
        "sa_use_fused_bwd_kernel": True,
    },
    xla_flags=(
        xla_flags_library.DENSE_VMEM_LIMIT_FLAG
        + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER
        + xla_flags_library.DATA_PARALLEL_OVERLAP
        + xla_flags_library.CF_FOR_ALL_GATHER
        + xla_flags_library.HOST_OFFLOAD_FLAGS
        + xla_flags_library.ENABLE_SPARSECORE_OFFLOADING_FOR_ALL_REDUCE
        + " --xla_tpu_iova_dma_chunk_size_bytes=104857"
    ),
)
```

The equivalent workload code can be found in the [maxtext_trillium_model_configs.py](https://github.com/AI-Hypercomputer/maxtext/blob/1e4d513ad70dd4074d975a9f7936295008d4b900/benchmarks/maxtext_trillium_model_configs.py#L1103-L1146) file within the MaxText repository.
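For a rough sense of the checkpoint traffic these settings generate against the GCS bucket, here is a back-of-envelope sketch. The 70e9 parameter count and bf16 sizing are assumptions rather than values from the recipe, and optimizer state (which the checkpointer may also save) is ignored:

```python
# Illustrative arithmetic only; exact checkpoint contents depend on the
# MaxText/Orbax checkpointing configuration.
PARAMS = 70e9          # assumed parameter count for Llama 3.1 70B
BYTES_PER_PARAM = 2    # bf16

weights_bytes = PARAMS * BYTES_PER_PARAM
print(f"~{weights_bytes / 1e9:.0f} GB of weights per checkpoint")  # ~140 GB

# checkpoint_period=20 over num_steps=100 implies roughly:
num_saves = 100 // 20
print(f"~{num_saves} checkpoint saves per run")
```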
`checkpoint_pvc.yaml` (42 additions & 0 deletions):

apiVersion: v1
kind: PersistentVolume
metadata:
  name: your-ckpt-bucket-pv
spec:
  accessModes:
    - ReadWriteMany
  capacity:
    storage: 64Gi
  persistentVolumeReclaimPolicy: Retain
  storageClassName: gcsfuse-sc # dummy storage class
  claimRef:
    namespace: default
    name: your-ckpt-bucket-pvc
  mountOptions:
    - metadata-cache:ttl-secs:-1
    - metadata-cache:negative-ttl-secs:0
    - metadata-cache:stat-cache-max-size-mb:-1
    - metadata-cache:type-cache-max-size-mb:-1
    - file-cache:enable-parallel-downloads:false
    - file-system:kernel-list-cache-ttl-secs:0
    - write:enable-streaming-writes:true
    - file-system:precondition-errors:false
  csi:
    driver: gcsfuse.csi.storage.gke.io
    volumeHandle: your-trillium-chkpt # unique bucket name
    volumeAttributes:
      gcsfuseMetadataPrefetchOnMount: "true"
---
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
  name: your-ckpt-bucket-pvc
  namespace: default
spec:
  accessModes:
    - ReadWriteMany
  resources:
    requests:
      storage: 64Gi
  volumeName: your-ckpt-bucket-pv
  storageClassName: gcsfuse-sc # dummy storage class
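A PV only binds to a PVC when the names line up. This sketch, with hand-written dicts standing in for parsed YAML (e.g. from `yaml.safe_load_all`), shows the fields that must agree between the two halves of a manifest like the one above; the `claim_matches` helper is an assumption, not a Kubernetes API:

```python
# Illustrative consistency check for a pre-bound PV/PVC pair: the PV's
# claimRef must name the PVC, and both must use the same storageClassName.
def claim_matches(pv: dict, pvc: dict) -> bool:
    ref = pv["spec"]["claimRef"]
    return (
        ref["name"] == pvc["metadata"]["name"]
        and ref["namespace"] == pvc["metadata"].get("namespace", "default")
        and pv["spec"]["storageClassName"] == pvc["spec"]["storageClassName"]
    )

pv = {"spec": {"claimRef": {"name": "your-ckpt-bucket-pvc", "namespace": "default"},
               "storageClassName": "gcsfuse-sc"}}
pvc = {"metadata": {"name": "your-ckpt-bucket-pvc", "namespace": "default"},
       "spec": {"storageClassName": "gcsfuse-sc"}}
print(claim_matches(pv, pvc))  # True
```

A typo such as the claim namespace not matching the PVC's namespace would make this check fail, which is exactly the kind of mistake that leaves a PVC stuck in Pending.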
`dataset_pvc.yaml` (40 additions & 0 deletions):

apiVersion: v1
kind: PersistentVolume
metadata:
  name: your-dataset-bucket-pv
spec:
  accessModes:
    - ReadWriteMany
  capacity:
    storage: 64Gi
  persistentVolumeReclaimPolicy: Retain
  storageClassName: gcsfuse-sc # dummy storage class
  claimRef:
    namespace: default
    name: your-dataset-bucket-pvc
  mountOptions:
    - metadata-cache:ttl-secs:-1
    - metadata-cache:stat-cache-max-size-mb:-1
    - metadata-cache:type-cache-max-size-mb:-1
    - file-cache:enable-parallel-downloads:false
    - file-system:kernel-list-cache-ttl-secs:-1
    - write:enable-streaming-writes:true
  csi:
    driver: gcsfuse.csi.storage.gke.io
    volumeHandle: your-trillium-dataloading # unique bucket name
    volumeAttributes:
      gcsfuseMetadataPrefetchOnMount: "true"
---
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
  name: your-dataset-bucket-pvc
  namespace: default
spec:
  accessModes:
    - ReadWriteMany
  resources:
    requests:
      storage: 64Gi
  volumeName: your-dataset-bucket-pv
  storageClassName: gcsfuse-sc # dummy storage class
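The dataset and checkpoint manifests share most gcsfuse mount options but differ where it matters for a bucket that is written during training. Comparing them as sets (values copied from the two manifests in this commit) highlights the differences; the checkpoint mount disables kernel list caching and negative-entry caching, presumably so freshly written checkpoint files are visible immediately:

```python
# Mount options from checkpoint_pvc.yaml; each entry is "section:key:value".
CKPT = {
    "metadata-cache:ttl-secs:-1",
    "metadata-cache:negative-ttl-secs:0",
    "metadata-cache:stat-cache-max-size-mb:-1",
    "metadata-cache:type-cache-max-size-mb:-1",
    "file-cache:enable-parallel-downloads:false",
    "file-system:kernel-list-cache-ttl-secs:0",
    "write:enable-streaming-writes:true",
    "file-system:precondition-errors:false",
}
# Mount options from dataset_pvc.yaml.
DATASET = {
    "metadata-cache:ttl-secs:-1",
    "metadata-cache:stat-cache-max-size-mb:-1",
    "metadata-cache:type-cache-max-size-mb:-1",
    "file-cache:enable-parallel-downloads:false",
    "file-system:kernel-list-cache-ttl-secs:-1",
    "write:enable-streaming-writes:true",
}

print("checkpoint-only:", sorted(CKPT - DATASET))
print("dataset-only:", sorted(DATASET - CKPT))
```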
Lines changed: 11 additions & 0 deletions

python3 benchmarks/benchmark_runner.py xpk \
    --project=$PROJECT \
    --zone=$ZONE \
    --device_type=v6e-256 \
    --num_slices=1 \
    --cluster_name=${CLUSTER} \
    --base_output_directory=/tmp/ckpt \
    --model_name="llama3_1_70b_8192_rd_ckpt_grain" \
    --num_steps=100 \
    --base_docker_image=maxtext_base_image \
    --xpk_storage="your-dataset-bucket" --xpk_storage="your-ckpt-bucket"
Lines changed: 10 additions & 0 deletions

python3 benchmarks/benchmark_runner.py xpk \
    --project=$PROJECT \
    --zone=$ZONE \
    --device_type=v6e-256 \
    --num_slices=1 \
    --cluster_name=$CLUSTER \
    --base_output_directory=$OUTPUT_DIR \
    --model_name="llama3_1_70b_8192_synthetic" \
    --num_steps=100 \
    --base_docker_image=maxtext_base_image
