
Commit ca16303

type check

cathyzbn committed Feb 10, 2025
1 parent 5b4d3a4 commit ca16303
Showing 2 changed files with 79 additions and 60 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -2,7 +2,7 @@
# See https://pre-commit.com/hooks.html for more hooks
exclude: ".git"
default_stages:
- commit
- pre-commit
fail_fast: true

repos:
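(Context: pre-commit 3.0 renamed the "commit" stage to "pre-commit"; the old name is deprecated, hence the one-line rename above.)
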
137 changes: 78 additions & 59 deletions src/levanter/infra/ray_tpu.py
@@ -1,5 +1,4 @@
import dataclasses
import functools
import logging
import multiprocessing
import os
@@ -8,7 +7,8 @@
import tempfile
import time
from dataclasses import dataclass
from typing import Callable, Optional, Sequence
from queue import Empty as QueueEmpty
from typing import List, Optional, Sequence, Tuple

import draccus
import mergedeep
@@ -63,6 +63,7 @@ class TpuFailed(_TpuRunResult):
class TpuRunError(_TpuRunResult):
error: Exception
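
For orientation, a minimal sketch of the result hierarchy these hunks touch. Only TpuFailed and TpuRunError are visible in the header above (TpuPreempted and _TpuInfo appear in later hunks), so the remaining names and fields are assumptions:

from dataclasses import dataclass

# Sketch of the _TpuRunResult family; each variant carries the TPU it
# ran on, and the non-success variants carry the triggering exception.
@dataclass
class _TpuInfo:
    name: str

@dataclass
class _TpuRunResult:
    info: _TpuInfo

@dataclass
class TpuSuccess(_TpuRunResult):
    result: object

@dataclass
class TpuPreempted(_TpuRunResult):
    error: Exception

@dataclass
class TpuFailed(_TpuRunResult):
    error: Exception

@dataclass
class TpuRunError(_TpuRunResult):
    error: Exception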


@ray.remote
class TPUHeadNodeActor:
def __init__(self):
@@ -71,9 +72,9 @@ def __init__(self):
self.ip = socket.gethostbyname(socket.gethostname())
self.worker_actors = None

def get_info(self):
def get_info(self) -> Tuple[str, int, str]:
return self.pod_name, self.num_hosts, self.ip

def run(self, remote_fn) -> _TpuRunResult:
if self.worker_actors is not None:
raise RuntimeError("Actors already created")
@@ -84,15 +85,15 @@ def run(self, remote_fn) -> _TpuRunResult:
for i, (actor, info) in enumerate(zip(self.worker_actors, worker_infos)):
if info[0] is None:
raise RuntimeError(f"Worker actor {i} returned invalid info: {info}")
logger.info(f"Initialized worker slice actors {self.worker_actors}")
except Exception as e:
self.cleanup()
raise RuntimeError("Failed to initialize worker actors") from e

# Start process on all workers
try:
futures = [
actor.run.remote(remote_fn, self.ip, i, self.num_hosts)
for i, actor in enumerate(self.worker_actors)
actor.run.remote(remote_fn, self.ip, i, self.num_hosts) for i, actor in enumerate(self.worker_actors)
]
run_infos = ray.get(futures)

@@ -113,8 +114,8 @@ def run(self, remote_fn) -> _TpuRunResult:
logger.exception(f"Failed to run job on {self.pod_name}")
_cancel_all_futures(futures)
raise RuntimeError(f"Failed to run job on {self.pod_name}") from e
def cleanup(self):

def cleanup(self) -> None:
if self.worker_actors is not None:
for actor in self.worker_actors:
try:
@@ -124,6 +125,7 @@ def cleanup(self):
logger.exception(f"Failed to kill worker actor {actor}")
self.worker_actors = None
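
The head actor's run() is split across the hunks above; as a reading aid, here is a condensed sketch of its two-phase fan-out, with elided details filled in by assumption:

import ray

# Condensed sketch: first validate every worker via get_info, then
# launch run() on each worker with its index and gather the results.
def _fan_out(worker_actors, remote_fn, coordinator_ip, num_hosts):
    worker_infos = ray.get([a.get_info.remote() for a in worker_actors])
    for i, info in enumerate(worker_infos):
        if info[0] is None:
            raise RuntimeError(f"Worker actor {i} returned invalid info: {info}")
    futures = [
        a.run.remote(remote_fn, coordinator_ip, i, num_hosts)
        for i, a in enumerate(worker_actors)
    ]
    return ray.get(futures)  # one info per worker slice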


@ray.remote
class TPUWorkerActor:
def __init__(self):
@@ -133,13 +135,13 @@ def __init__(self):
self.process: multiprocessing.Process | None = None
self.queue: multiprocessing.Queue | None = None

def get_info(self):
def get_info(self) -> Tuple[str, int, str]:
return self.pod_name, self.num_hosts, self.ip
def run(self, remote_fn, coordinator_ip, slice_id, num_slices) -> _TpuRunResult:

def run(self, remote_fn, coordinator_ip, slice_id, num_slices) -> _TpuInfo:
if self.process is not None:
raise RuntimeError("Process already started")

port = 8081
mxla_env = {
"MEGASCALE_COORDINATOR_ADDRESS": f"{coordinator_ip}:{port}",
@@ -154,14 +156,11 @@ def run(self, remote_fn, coordinator_ip, slice_id, num_slices) -> _TpuRunResult:

# Create queue for process communication
self.queue = multiprocessing.Queue()
self.process = multiprocessing.Process(
target=self._run_in_process,
args=(remote_fn, self.queue)
)
self.process = multiprocessing.Process(target=self._run_in_process, args=(remote_fn, self.queue))
self.process.start()
return info
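
The hunk above truncates the MXLA environment; a sketch of the full dict a worker might export. Every key except MEGASCALE_COORDINATOR_ADDRESS is an assumption based on JAX multislice conventions:

# Hypothetical completion of mxla_env; only MEGASCALE_COORDINATOR_ADDRESS
# is shown in the diff, the remaining keys are assumptions.
def _mxla_env(coordinator_ip: str, slice_id: int, num_slices: int, port: int = 8081) -> dict:
    return {
        "MEGASCALE_COORDINATOR_ADDRESS": f"{coordinator_ip}:{port}",
        "MEGASCALE_NUM_SLICES": str(num_slices),
        "MEGASCALE_PORT": str(port),
        "MEGASCALE_SLICE_ID": str(slice_id),
    }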

def _run_in_process(self, fn, queue):
def _run_in_process(self, fn, queue) -> None:
try:
future = fn.remote()
result = ray.get(future)
@@ -173,37 +172,41 @@ def _run_in_process(self, fn, queue):
def wait(self) -> object:
if self.process is None or self.queue is None:
raise RuntimeError("No process running")

self.process.join()
try:
success, value = self.queue.get()
self.process = None
self.queue = None

if success:
return value
else:
value.reraise() # Reraise the exception with original traceback
except multiprocessing.queues.Empty:
logger.log("Process timed out")
except QueueEmpty:
logger.exception("Process timed out")
self.cleanup()

def cleanup(self):
raise RuntimeError("Process timed out")

if success:
return value
else:
value.reraise()
return None

def cleanup(self) -> None:
logger.info(f"Cleaning up worker actor {self.pod_name}")
if self.process is not None:
self.process.terminate()
self.process.join()
self.process = None
self.queue = None
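
The wait()/cleanup() pair above implements a simple parent/child handshake; a self-contained sketch of the same pattern, with all names illustrative: the child pushes a (success, value) pair onto a multiprocessing.Queue, and the parent joins the process before draining it.

import multiprocessing

def _child(queue: multiprocessing.Queue) -> None:
    try:
        queue.put((True, 40 + 2))   # success: ship the result
    except Exception as e:
        queue.put((False, e))       # failure: ship the exception

if __name__ == "__main__":
    q = multiprocessing.Queue()
    p = multiprocessing.Process(target=_child, args=(q,))
    p.start()
    p.join()                        # mirror wait(): join, then drain
    ok, value = q.get(timeout=5)    # raises queue.Empty on timeout
    print(ok, value)                # True 42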

def _cancel_all_futures(futures):

def _cancel_all_futures(futures) -> None:
for f in futures:
try:
ray.cancel(f)
except Exception:
logger.exception("Failed to kill job after primary failure")

def _redecorate_remote_fn_for_tpu(remote_fn, num_hosts, **runtime_env):

def _redecorate_remote_fn_for_tpu(remote_fn, num_hosts, **runtime_env) -> Tuple[RemoteFunction, str]:
if not isinstance(remote_fn, RemoteFunction):
logger.info("CATHY log: decorating non remote function")
remote_fn = ray.remote(remote_fn)
@@ -233,13 +236,14 @@ def _redecorate_remote_fn_for_tpu(remote_fn, num_hosts, **runtime_env):
logger.info(f"Running on TPU {tpu_name} with {num_hosts} hosts and {num_tpus_per_host} TPUs per host")
return remote_fn, tpu_name
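
The elided middle of _redecorate_remote_fn_for_tpu pins the task to a TPU pod; a hedged sketch of what that redecoration looks like in Ray (the pod name and chip count come from elsewhere and the helper name is an assumption):

import ray

# Hypothetical sketch: reserve the named TPU pod resource plus every
# chip on the host so Ray schedules exactly one task per TPU host.
def _pin_to_tpu(remote_fn, tpu_name: str, num_tpus_per_host: int, runtime_env: dict):
    return remote_fn.options(
        runtime_env=runtime_env,
        resources={tpu_name: 1, "TPU": num_tpus_per_host},
    )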

def run_on_pod_resumable(
remote_fn,
tpu_type: str,

def run_on_pod_resumable_shared(
remote_fn,
tpu_type: str,
num_slices: int = 1,
max_retries_preemption: int = int(1e6),
max_retries_failure: int = 10
) -> object | list[object]:
max_retries_preemption: int = int(1e6),
max_retries_failure: int = 10,
) -> List[object]:
"""
Repeatedly run a function on a TPU pod until it succeeds or a maximum number of retries is reached.
Handles both single-slice and multi-slice cases.
@@ -254,7 +258,6 @@ def run_on_pod_resumable(
Returns:
For single-slice: The result of the function
For multi-slice: List of results from each slice
TODO(cathy): refactor to always return a list of results
"""
num_failures = 0
num_preemptions = 0
@@ -266,7 +269,7 @@ def run_on_pod_resumable(
logger.info(f"Running on TPU {tpu_type}. Attempt {attempt}")
attempt += 1
problem = None

# Create head node actors, one per slice
head_actors = [HeadNodeActor.remote() for _ in range(num_slices)]
infos_futures = [actor.get_info.remote() for actor in head_actors]
@@ -277,9 +280,9 @@ def run_on_pod_resumable(
raise RuntimeError(f"Worker actor {actor} returned invalid info: {info}")

# Run the job on all slices
futures = [actor.run.remote(remote_fn) for actor in head_actors]
outs = ray.get(futures)
run_futures = [actor.run.remote(remote_fn) for actor in head_actors]
outs = ray.get(run_futures)

# Check results from all slices
results = []
all_succeeded = True
@@ -305,11 +308,10 @@ def run_on_pod_resumable(
break
else:
raise RuntimeError(f"Unexpected result: {out}")

if all_succeeded:
logger.info("Success")
return results[0] if num_slices == 1 else results

return results
except ray.exceptions.ActorUnavailableError as e:
problem = e
num_preemptions += 1
@@ -329,6 +331,8 @@ def run_on_pod_resumable(
num_failures += 1
logger.warning(f"Failed {num_failures} times", exc_info=e)
except Exception as e:
if "run_futures" in locals():
_cancel_all_futures(run_futures)
problem = e
num_failures += 1
if num_failures >= max_retries_failure:
@@ -347,21 +351,33 @@ def run_on_pod_resumable(

if num_preemptions >= max_retries_preemption:
raise RuntimeError("Preempted too many times") from problem
elif num_failures >= max_retries_failure:
else: # num_failures >= max_retries_failure
raise RuntimeError("Failed too many times") from problem


def run_on_pod_multislice_resumable(
remote_fn,
tpu_type: str,
num_slices,
def run_on_pod_resumable(
remote_fn,
tpu_type: str,
max_retries_preemption: int = int(1e6),
max_retries_failure: int = 10
) -> list[object]:
"""
TODO(cathy): deprecated?
"""
return run_on_pod_resumable(remote_fn, tpu_type, num_slices, max_retries_preemption, max_retries_failure)
max_retries_failure: int = 10,
) -> object:
result = run_on_pod_resumable_shared(
remote_fn,
tpu_type,
num_slices=1,
max_retries_preemption=max_retries_preemption,
max_retries_failure=max_retries_failure,
)
assert len(result) == 1
return result[0]


def run_on_pod_multislice_resumable(
remote_fn, tpu_type: str, num_slices, max_retries_preemption: int = int(1e6), max_retries_failure: int = 10
) -> List[object]:
"""TODO: (cathy) deprecate this"""
return run_on_pod_resumable_shared(remote_fn, tpu_type, num_slices, max_retries_preemption, max_retries_failure)
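
A hedged usage sketch of the two public entry points after this refactor, assuming a TPU-enabled Ray cluster; train_step and the TPU type are illustrative:

import ray

@ray.remote
def train_step():
    return "step complete"  # placeholder workload

# Single slice: the wrapper asserts one result and unwraps it.
result = run_on_pod_resumable(train_step, tpu_type="v4-8", max_retries_failure=3)

# Multislice: one result per slice, returned as a list.
per_slice = run_on_pod_multislice_resumable(train_step, tpu_type="v4-8", num_slices=2)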


def _run_command(*args, **kwargs):
return subprocess.check_call(args, **kwargs)
@@ -382,10 +398,15 @@ def run_docker():
logger.exception("Failed to run docker command")
raise e

run_on_pod_resumable(
ray.remote(run_docker), tpu_type=tpu_type, num_slices=num_slices, max_retries_failure=retries, max_retries_preemption=10000
run_on_pod_resumable_shared(
ray.remote(run_docker),
tpu_type=tpu_type,
num_slices=num_slices,
max_retries_failure=retries,
max_retries_preemption=10000,
)


def _kill_old_container(name):
try:
logger.info(f"Killing old container {name}")
@@ -409,11 +430,9 @@ def _handle_ray_error(tpu_info: _TpuInfo, e: RayError):
elif isinstance(e, WorkerCrashedError):
logger.exception("Worker crashed", exc_info=e)
return TpuPreempted(tpu_info, e)

elif isinstance(e, RayTaskError):
logger.exception("Ray task error", exc_info=e)
return TpuPreempted(tpu_info, e)

elif isinstance(e, RaySystemError):
logger.exception("System error", exc_info=e)
return TpuRunError(tpu_info, e)
