Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@
# Container image for the server; no default is provided.
_SERVER_IMAGE = flags.DEFINE_string(
    name="server_image",
    default=None,
    help="Full path to the server Docker image",
)
# Container image for the sidecar; defaults to a pinned, dated image tag.
_SIDECAR_IMAGE = flags.DEFINE_string(
    name="sidecar_image",
    default="us-docker.pkg.dev/cloud-tpu-v2-images/pathways-colocated-python/sidecar:20260423-python_3.12-jax_0.10.0",
    help="Full path to the sidecar Docker image",
)
# TPU generation to deploy on; restricted to the listed values.
_TPU_TYPE = flags.DEFINE_enum(
    name="tpu_type",
    default="v6e",
    enum_values=["v5e", "v5p", "v6e", "tpu7x"],
    help="TPU type",
)
Expand All @@ -52,6 +57,7 @@
False,
"If true, only print the generated YAML without deploying.",
)
# Sidecar shared-memory directory path; exposed to the deployment template
# context under the SIDECAR_SHM_DIR key.
_SIDECAR_SHM_DIR = "/tmp/sidecar_dir"


@dataclasses.dataclass(frozen=True)
Expand Down Expand Up @@ -191,6 +197,7 @@ def run_deployment(
jobset_name,
gcs_bucket,
server_image,
sidecar_image,
template_file,
dry_run,
deploy_func: Callable[[dict[str, Any]], None] = deploy_jobset,
Expand All @@ -202,6 +209,8 @@ def run_deployment(
context = {
"JOBSET_NAME": jobset_name,
"SERVER_IMAGE": server_image,
"SIDECAR_IMAGE": sidecar_image,
"SIDECAR_SHM_DIR": _SIDECAR_SHM_DIR,
"GCS_SCRATCH_LOCATION": gcs_bucket,
"NUM_SLICES": num_slices,
"INSTANCE_TYPE": f"{tpu_config.instance_prefix}:{topology}",
Expand Down Expand Up @@ -246,6 +255,7 @@ def main(argv: Sequence[str]) -> None:
jobset_name=_JOBSET_NAME.value,
gcs_bucket=_GCS_BUCKET.value,
server_image=server_image,
sidecar_image=_SIDECAR_IMAGE.value,
template_file=_TEMPLATE_FILE.value,
dry_run=_DRY_RUN.value,
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Build MaxText on top of the JAX image that ships the custom-built sidecar.
FROM us-docker.pkg.dev/cloud-tpu-v2-images/pathways-colocated-python/sidecar:20260423-python_3.12-jax_0.10.0

# Run all subsequent instructions relative to /app.
WORKDIR /app

# 1. Upgrade pip and build tools. The uv download cache lives in a build cache
# mount so it is reused across builds instead of being baked into the layer.
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --upgrade pip setuptools wheel

# 2. Clone MaxText into /app/maxtext.
RUN git clone https://github.com/google/maxtext.git

# 3. Install the MaxText base requirements (reusing the same uv cache mount),
# then re-pin JAX and jaxlib to 0.10.0, the JAX version named in the base
# image tag, in case the requirements install changed them.
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install -r maxtext/src/dependencies/requirements/base_requirements/requirements.txt && \
uv pip install --upgrade jax==0.10.0 jaxlib==0.10.0

# 4. (optional) Copy your local edits to MaxText requirements and src, if any.
# Make sure you're running this docker build from the root of your local MaxText
# checkout.
# NOTE(review): these COPY lines come after the install step above, so an edited
# requirements.txt is not installed unless the install is repeated after it.
# COPY maxtext/src/dependencies/requirements/base_requirements/requirements.txt ./requirements.txt
# COPY maxtext/src /app/maxtext/src

# (optional) Ensure MaxText src is in PYTHONPATH.
# ENV PYTHONPATH=/app/maxtext/src:$PYTHONPATH
Original file line number Diff line number Diff line change
Expand Up @@ -48,21 +48,26 @@ class ProxyOptions:
use_insecure_credentials: Whether to use insecure gRPC credentials for the
proxy server.
xla_flags: A list of XLA flags to pass to the proxy server.
sidecar: Whether to use the worker sidecar or not.
"""
use_insecure_credentials: bool = False
xla_flags: list[str] = dataclasses.field(default_factory=list)
sidecar: bool = False

@classmethod
def from_list(cls, options: Iterable[str] | None) -> "ProxyOptions":
"""Creates a ProxyOptions object from a list of 'key:value' strings."""
use_insecure = False
use_sidecar = False
xla_flags = []
for option in options or []:
if ":" in option:
key, value = option.split(":", 1)
key_strip = key.strip().lower()
if key_strip == "use_insecure_credentials":
use_insecure = value.strip().lower() == "true"
elif key_strip == "sidecar":
use_sidecar = value.strip().lower() == "true"
elif key_strip == "xla_flags":
val_strip = value.strip()
if (
Expand All @@ -78,7 +83,11 @@ def from_list(cls, options: Iterable[str] | None) -> "ProxyOptions":
if xla_flags:
validators.validate_xla_flags(xla_flags)

return cls(use_insecure_credentials=use_insecure, xla_flags=xla_flags)
return cls(
use_insecure_credentials=use_insecure,
xla_flags=xla_flags,
sidecar=use_sidecar,
)


def _deploy_pathways_proxy_server(
Expand Down Expand Up @@ -134,6 +143,9 @@ def _deploy_pathways_proxy_server(
)
proxy_args_str = "\n" + proxy_args_str

if proxy_options.sidecar:
proxy_args_str += "\n - --sidecar_name=external"

template = string.Template(yaml_template)
substituted_yaml = template.substitute(
PROXY_JOB_NAME=proxy_job_name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ spec:
- --server_port=29005
- --resource_manager_address=$$(PATHWAYS_HEAD):29001
- --gcs_scratch_location=${GCS_SCRATCH_LOCATION}
- --cloud_pathways_sidecar_shm_directory=${SIDECAR_SHM_DIR}
env:
- name: TPU_MIN_LOG_LEVEL
value: "0"
Expand Down Expand Up @@ -133,8 +134,47 @@ spec:
limits:
google.com/tpu: "${CHIPS_PER_VM}"
volumeMounts:
- mountPath: /tmp
name: shared-tmp
- name: shared-tmp
mountPath: /tmp
- name: sidecar-shared-memory
mountPath: ${SIDECAR_SHM_DIR}
initContainers:
- name: colocated-python-sidecar
image: ${SIDECAR_IMAGE}
imagePullPolicy: Always
env:
# Address the sidecar gRPC server binds to. Quoted once: inside a YAML
# single-quoted scalar '' is an escaped quote, so the previous triple-quoted
# form put literal quote characters into the environment variable.
- name: GRPC_SERVER_ADDRESS
value: '0.0.0.0:50051'
- name: CLOUD_PATHWAYS_SIDECAR_SHM_DIRECTORY
value: ${SIDECAR_SHM_DIR}
- name: PYTHONUNBUFFERED
value: '1'
# --- High Verbosity Logging Variables ---
- name: LOGLEVEL
value: 'DEBUG'
- name: GLOG_minloglevel
value: '0' # 0 = INFO level base
- name: GLOG_v
value: '5' # Extreme verbosity for all C++ modules
- name: TF_CPP_MIN_LOG_LEVEL
value: '0'
- name: TF_CPP_MIN_VLOG_LEVEL
value: '5' # TF/XLA verbose logging
- name: TPU_MIN_LOG_LEVEL
value: '0'
- name: GLOG_vmodule
value: 'jax_array_handlers=5,type_handlers=5,tensorstore_utils=5'
# ----------------------------------------
ports:
- containerPort: 50051
protocol: TCP
resources: {}
restartPolicy: Always
volumeMounts:
- name: shared-tmp
mountPath: /tmp
- name: sidecar-shared-memory
mountPath: ${SIDECAR_SHM_DIR}
dnsPolicy: ClusterFirstWithHostNet
hostNetwork: true
nodeSelector:
Expand All @@ -146,6 +186,9 @@ spec:
hostPath:
path: /tmp
type: DirectoryOrCreate
- name: sidecar-shared-memory
emptyDir:
medium: Memory
startupPolicy:
startupPolicyOrder: InOrder
successPolicy:
Expand Down
Loading