Skip to content

Commit af9ca4c

Browse files
Pathways-on-Cloud Teamcopybara-github
authored andcommitted
Add managed_pathways_service for Pathways-on-Cloud
This change introduces Managed Pathways Service for GKE. It includes: `tpu_manager.py` uses `kubectl` to deploy a Pathways proxy JobSet on a GKE cluster, sets up port forwarding, and configures JAX environment variables to connect to the proxy. `run_connect_example.py` is an example script to start the proxy. Prerequisite: A Pathways cluster is up and running with Resource Manager and worker pods deployed successfully, e.g., using pw-cluster.yaml. TESTED: on proxy 2 clients each requesting `2 x v5e-32` simultaneously. PiperOrigin-RevId: 801034430
1 parent 315f578 commit af9ca4c

File tree

5 files changed

+569
-0
lines changed

5 files changed

+569
-0
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# This file marks this directory as a Python package.
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
apiVersion: jobset.x-k8s.io/v1alpha2
2+
kind: JobSet
3+
metadata:
4+
name: pathways-akshu-s4-rw7
5+
spec:
6+
coordinator:
7+
replicatedJob: pathways-head
8+
failurePolicy:
9+
maxRestarts: 1
10+
restartStrategy: Recreate
11+
network:
12+
enableDNSHostnames: true
13+
publishNotReadyAddresses: true
14+
replicatedJobs:
15+
- name: pathways-head
16+
replicas: 1
17+
template:
18+
metadata:
19+
annotations:
20+
alpha.jobset.sigs.k8s.io/exclusive-topology: kubernetes.io/hostname
21+
spec:
22+
backoffLimit: 0
23+
completionMode: Indexed
24+
completions: 1
25+
parallelism: 1
26+
template:
27+
metadata:
28+
labels:
29+
kueue.x-k8s.io/podset: pathways-head
30+
spec:
31+
containers:
32+
- args:
33+
- --server_port=29001
34+
- --gcs_scratch_location=gs://akshu-v5e
35+
- --node_type=resource_manager
36+
- --instance_count=4
37+
- --instance_type=tpuv5e:4x8
38+
- --xla_tpu_use_enhanced_launch_barrier=true
39+
- --logtostderr
40+
- --stderrthreshold=0
41+
- --v=1
42+
env:
43+
- name: REPLICATED_JOB_NAME
44+
valueFrom:
45+
fieldRef:
46+
fieldPath: metadata.annotations['jobset.sigs.k8s.io/replicatedjob-name']
47+
- name: JOBSET_NAME
48+
valueFrom:
49+
fieldRef:
50+
fieldPath: metadata.annotations['jobset.sigs.k8s.io/jobset-name']
51+
- name: HOST_ADDRESS
52+
valueFrom:
53+
fieldRef:
54+
fieldPath: metadata.labels['jobset.sigs.k8s.io/coordinator']
55+
- name: TPU_SKIP_MDS_QUERY
56+
value: "true"
57+
image: us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/gke/akshu/unsanitized_server:latest
58+
imagePullPolicy: Always
59+
name: pathways-rm
60+
ports:
61+
- containerPort: 29001
62+
protocol: TCP
63+
- containerPort: 29002
64+
protocol: TCP
65+
resources:
66+
limits:
67+
cpu: "8"
68+
memory: 16G
69+
nodeSelector:
70+
cloud.google.com/gke-nodepool: cpu-np
71+
dnsPolicy: ClusterFirstWithHostNet
72+
hostNetwork: true
73+
restartPolicy: OnFailure
74+
- name: worker
75+
replicas: 4
76+
template:
77+
metadata: {}
78+
spec:
79+
backoffLimit: 64
80+
completionMode: Indexed
81+
completions: 8
82+
parallelism: 8
83+
template:
84+
metadata:
85+
annotations:
86+
alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
87+
labels:
88+
kueue.x-k8s.io/podset: worker
89+
spec:
90+
containers:
91+
- args:
92+
- --server_port=29005
93+
- --resource_manager_address=$(PATHWAYS_HEAD):29001
94+
- --gcs_scratch_location=gs://akshu-v5e
95+
- --xla_tpu_use_enhanced_launch_barrier=true
96+
- --logtostderr
97+
- --stderrthreshold=0
98+
- --v=1
99+
env:
100+
- name: TPU_MIN_LOG_LEVEL
101+
value: "0"
102+
- name: TF_CPP_MIN_LOG_LEVEL
103+
value: "0"
104+
- name: XCLOUD_ENVIRONMENT
105+
value: GCP
106+
- name: MEGASCALE_GRPC_ENABLE_XOR_TRACER
107+
value: "false"
108+
- name: MEGASCALE_NUM_SLICES
109+
valueFrom:
110+
fieldRef:
111+
fieldPath: metadata.labels['jobset.sigs.k8s.io/replicatedjob-replicas']
112+
- name: JOBSET_NAME
113+
valueFrom:
114+
fieldRef:
115+
fieldPath: metadata.annotations['jobset.sigs.k8s.io/jobset-name']
116+
- name: REPLICATED_JOB_NAME
117+
valueFrom:
118+
fieldRef:
119+
fieldPath: metadata.annotations['jobset.sigs.k8s.io/replicatedjob-name']
120+
- name: MEGASCALE_SLICE_ID
121+
valueFrom:
122+
fieldRef:
123+
fieldPath: metadata.labels['jobset.sigs.k8s.io/job-index']
124+
- name: PATHWAYS_HEAD
125+
valueFrom:
126+
fieldRef:
127+
fieldPath: metadata.labels['jobset.sigs.k8s.io/coordinator']
128+
- name: MEGASCALE_COORDINATOR_ADDRESS
129+
valueFrom:
130+
fieldRef:
131+
fieldPath: metadata.labels['jobset.sigs.k8s.io/coordinator']
132+
image: us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/gke/akshu/unsanitized_server:latest
133+
imagePullPolicy: Always
134+
name: pathways-worker
135+
ports:
136+
- containerPort: 29005
137+
protocol: TCP
138+
- containerPort: 29006
139+
protocol: TCP
140+
- containerPort: 8471
141+
protocol: TCP
142+
- containerPort: 8080
143+
protocol: TCP
144+
resources:
145+
limits:
146+
google.com/tpu: "4"
147+
dnsPolicy: ClusterFirstWithHostNet
148+
hostNetwork: true
149+
nodeSelector:
150+
cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
151+
cloud.google.com/gke-tpu-topology: 4x8
152+
restartPolicy: OnFailure
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
apiVersion: jobset.x-k8s.io/v1alpha2
2+
kind: JobSet
3+
metadata:
4+
name: ${PROXY_NAME}
5+
spec:
6+
coordinator:
7+
replicatedJob: pathways-head
8+
failurePolicy:
9+
maxRestarts: 1
10+
restartStrategy: Recreate
11+
network:
12+
enableDNSHostnames: true
13+
publishNotReadyAddresses: true
14+
replicatedJobs:
15+
- name: pathways-head
16+
replicas: 1
17+
template:
18+
metadata:
19+
annotations:
20+
alpha.jobset.sigs.k8s.io/exclusive-topology: kubernetes.io/hostname
21+
spec:
22+
backoffLimit: 0
23+
completionMode: Indexed
24+
completions: 1
25+
parallelism: 1
26+
template:
27+
metadata:
28+
labels:
29+
kueue.x-k8s.io/podset: pathways-head
30+
spec:
31+
containers:
32+
- args:
33+
- --server_port=29000
34+
- --resource_manager_address=${PATHWAYS_HEAD}:${PATHWAYS_HEAD_PORT}
35+
- --gcs_scratch_location=${GCS_BUCKET}
36+
- --virtual_slices=${EXPECTED_INSTANCES}
37+
env:
38+
- name: PATHWAYS_HEAD
39+
valueFrom:
40+
fieldRef:
41+
fieldPath: metadata.labels['jobset.sigs.k8s.io/coordinator']
42+
image: us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/gke/akshu/unsanitized_proxy_server:latest
43+
imagePullPolicy: Always
44+
name: pathways-proxy
45+
ports:
46+
- containerPort: 29000
47+
protocol: TCP
48+
resources:
49+
limits:
50+
cpu: "16"
51+
memory: 100G
52+
nodeSelector:
53+
cloud.google.com/gke-nodepool: cpu-np
54+
dnsPolicy: ClusterFirstWithHostNet
55+
hostNetwork: true
56+
restartPolicy: OnFailure
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Script to run JAX code on TPU with the Managed Pathways service."""
15+
16+
from collections.abc import Sequence
17+
from absl import app
18+
import tpu_manager
19+
20+
21+
def main(argv: Sequence[str]) -> None:
22+
if len(argv) > 1:
23+
raise app.UsageError("Too many command-line arguments.")
24+
with tpu_manager.connect(
25+
"pw-scale-test-v5e-32",
26+
"cloud-tpu-multipod-dev",
27+
"us-south1",
28+
"gs://akshu-v5e",
29+
"pathways-akshu-s4-rw7-pathways-head-0-0.pathways-akshu-s4-rw7:29001",
30+
{"tpuv5e:4x8": 2},
31+
) as tm:
32+
print(tm)
33+
import jax
34+
import jax.numpy as jnp
35+
import pathwaysutils
36+
import pprint
37+
import time
38+
39+
pathwaysutils.initialize()
40+
41+
orig_matrix = jnp.zeros(5)
42+
43+
print("start")
44+
result_matrix = orig_matrix + 1
45+
print("Original Random Matrix:")
46+
pprint.pprint(orig_matrix)
47+
print("\nMatrix after adding 1:")
48+
pprint.pprint(result_matrix)
49+
50+
51+
if __name__ == "__main__":
52+
app.run(main)

0 commit comments

Comments
 (0)