Skip to content

Update skypilot, use async api #81

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Jul 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dev/tau-bench/run_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ def launch_model(model_key: str):
print(f"Launching task on cluster: {cluster_name}")

print("Checking for existing cluster and jobs…")
cluster_status = sky.get(sky.status(cluster_names=[cluster_name]))
cluster_status = sky.stream_and_get(sky.status(cluster_names=[cluster_name]))
if len(cluster_status) > 0 and cluster_status[0]["status"] == ClusterStatus.UP:
print(f"Cluster {cluster_name} is UP. Canceling any active jobs…")
sky.stream_and_get(sky.cancel(cluster_name, all=True))
Expand Down
2 changes: 1 addition & 1 deletion examples/art-e/run_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def launch_model(model_key: str):
print(f"Launching task on cluster: {cluster_name}")

print("Checking for existing cluster and jobs…")
cluster_status = sky.get(sky.status(cluster_names=[cluster_name]))
cluster_status = sky.stream_and_get(sky.status(cluster_names=[cluster_name]))
if len(cluster_status) > 0 and cluster_status[0]["status"] == ClusterStatus.UP:
print(f"Cluster {cluster_name} is UP. Canceling any active jobs…")
sky.stream_and_get(sky.cancel(cluster_name, all=True))
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ dev-dependencies = [
"openpipe>=4.49.0",
"hatch>=1.14.1",
"ruff>=0.12.1",
"skypilot[cudo,do,fluidstack,gcp,lambda,paperspace,runpod]==0.8.0",
"skypilot[cudo,do,fluidstack,gcp,lambda,paperspace,runpod]==0.9.3",
]

[tool.uv.sources]
Expand Down
65 changes: 50 additions & 15 deletions src/art/skypilot/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
wait_for_art_server_to_start,
get_art_server_base_url,
get_vllm_base_url,
get_task_job_id,
)

from .. import dev
Expand Down Expand Up @@ -39,6 +40,7 @@ async def initialize_cluster(
self = cls.__new__(cls)
self._cluster_name = cluster_name
self._envs = {}
self._art_server_job_id = None

if env_path is not None:
self._envs = {
Expand Down Expand Up @@ -69,13 +71,14 @@ async def initialize_cluster(

# check if cluster already exists
cluster_status = await to_thread_typed(
lambda: sky.status(cluster_names=[self._cluster_name])
lambda: sky.stream_and_get(sky.status(cluster_names=[self._cluster_name]))
)
if (
len(cluster_status) == 0
or cluster_status[0]["status"] != sky.ClusterStatus.UP
):
await self._launch_cluster(resources, art_version)

else:
print(f"Cluster {self._cluster_name} exists, using it...")

Expand All @@ -93,31 +96,43 @@ async def initialize_cluster(
art_server_running = False

if art_server_running:
self._art_server_job_id = await get_task_job_id(
cluster_name=self._cluster_name, task_name="art_server"
)
print("Art server task already running, using it…")
else:
art_server_task = sky.Task(name="art_server", run="uv run art")
resources = await to_thread_typed(
lambda: sky.status(cluster_names=[self._cluster_name])[0][
"handle"
].launched_resources

clusters = await to_thread_typed(
lambda: sky.stream_and_get(
sky.status(cluster_names=[self._cluster_name])
)
)
resources = clusters[0]["handle"].launched_resources

# For some reason, skypilot doesn't support the region and zone set
resources = resources.copy(region=None, zone=None)

# If a local path was provided for art_version, ensure it is mounted so the latest
# code is synced to the remote cluster every time we (re)launch the art_server task.
if art_version is not None and os.path.exists(art_version):
art_server_task.workdir = art_version

# print(clusters[0]["handle"].launched_resources)
art_server_task.set_resources(cast(sky.Resources, resources))
art_server_task.update_envs(self._envs)

# run art server task
await to_thread_typed(
lambda: sky.exec(
task=art_server_task,
cluster_name=self._cluster_name,
detach_run=True,
job_id, _ = await to_thread_typed(
lambda: sky.stream_and_get(
sky.exec(
task=art_server_task,
cluster_name=self._cluster_name,
)
)
)
self._art_server_job_id = job_id

print("Task launched, waiting for it to start...")
await wait_for_art_server_to_start(cluster_name=self._cluster_name)
print("Art server task started")
Expand All @@ -126,8 +141,17 @@ async def initialize_cluster(
print(f"Using base_url: {base_url}")

# Manually call the real __init__ now that base_url is ready
super(SkyPilotBackend, self).__init__(base_url=base_url)
super(cls, self).__init__(base_url=base_url)

if self._art_server_job_id is not None:
asyncio.create_task(
asyncio.to_thread(
sky.tail_logs,
cluster_name=self._cluster_name,
job_id=self._art_server_job_id,
follow=True,
)
)
return self

async def _launch_cluster(
Expand Down Expand Up @@ -185,13 +209,22 @@ async def _launch_cluster(
task.setup = setup_script

try:
job_id, _ = await to_thread_typed(
lambda: sky.stream_and_get(
sky.launch(
task=task, cluster_name=self._cluster_name, retry_until_up=True
)
)
)

await to_thread_typed(
lambda: sky.launch(
task=task,
lambda: sky.tail_logs(
cluster_name=self._cluster_name,
retry_until_up=True,
job_id=job_id,
follow=True,
)
)

except Exception as e:
print(f"Error launching cluster: {e}")
print()
Expand Down Expand Up @@ -230,4 +263,6 @@ async def _prepare_backend_for_training(
return (vllm_base_url, api_key)

async def down(self) -> None:
    """Tear down the cluster named by ``self._cluster_name``.

    ``sky.down`` is submitted exactly once and its result is awaited via
    ``sky.stream_and_get``; the call is dispatched through
    ``to_thread_typed`` so the coroutine can await it.
    """
    await to_thread_typed(
        lambda: sky.stream_and_get(sky.down(cluster_name=self._cluster_name))
    )
2 changes: 1 addition & 1 deletion src/art/skypilot/stop_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

async def stop_server() -> None:
cluster_status = await to_thread_typed(
lambda: sky.status(cluster_names=[args.cluster])
lambda: sky.stream_and_get(sky.status(cluster_names=[args.cluster]))
)
if len(cluster_status) == 0 or cluster_status[0]["status"] != sky.ClusterStatus.UP:
raise ValueError(f"Cluster {args.cluster} is not running")
Expand Down
15 changes: 14 additions & 1 deletion src/art/skypilot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,27 @@ async def to_thread_typed(func: Callable[[], T]) -> T:


async def get_task_status(cluster_name: str, task_name: str) -> sky.JobStatus | None:
    """Return the status of the first queued job named ``task_name``.

    The cluster's job queue is fetched once, by resolving the ``sky.queue``
    request with ``sky.stream_and_get`` (dispatched through
    ``to_thread_typed``).

    Args:
        cluster_name: Name of the cluster whose job queue is inspected.
        task_name: Job name to look for (compared against ``job["job_name"]``).

    Returns:
        The ``status`` entry of the first matching job, or ``None`` when no
        job in the queue has that name.
    """
    job_queue = await to_thread_typed(
        lambda: sky.stream_and_get(sky.queue(cluster_name))
    )

    for job in job_queue:
        if job["job_name"] == task_name:
            return job["status"]
    return None


async def get_task_job_id(cluster_name: str, task_name: str) -> int | None:
    """Return the job id of the first queued job named ``task_name``.

    The cluster's job queue is fetched by resolving the ``sky.queue`` request
    with ``sky.stream_and_get`` (dispatched through ``to_thread_typed``).

    Args:
        cluster_name: Name of the cluster whose job queue is inspected.
        task_name: Job name to look for (compared against ``job["job_name"]``).

    Returns:
        The ``job_id`` entry of the first matching job, or ``None`` when no
        job in the queue has that name. Callers must handle the ``None`` case.
    """
    job_queue = await to_thread_typed(
        lambda: sky.stream_and_get(sky.queue(cluster_name))
    )

    for job in job_queue:
        if job["job_name"] == task_name:
            return job["job_id"]
    # No job with that name is in the queue.
    return None


async def is_task_created(cluster_name: str, task_name: str) -> bool:
task_status = await get_task_status(cluster_name, task_name)
if task_status is None:
Expand Down
Loading