Skip to content

Update skypilot, use async api #81

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ dev-dependencies = [
"openpipe>=4.49.0",
"hatch>=1.14.1",
"ruff>=0.12.1",
"skypilot[cudo,do,fluidstack,gcp,lambda,paperspace,runpod]==0.8.0",
"skypilot[cudo,do,fluidstack,gcp,lambda,paperspace,runpod]==0.9.3",
]

[tool.uv.sources]
Expand Down
65 changes: 50 additions & 15 deletions src/art/skypilot/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
wait_for_art_server_to_start,
get_art_server_base_url,
get_vllm_base_url,
get_task_job_id,
)

from .. import dev
Expand Down Expand Up @@ -39,6 +40,7 @@ async def initialize_cluster(
self = cls.__new__(cls)
self._cluster_name = cluster_name
self._envs = {}
self._art_server_job_id = None

if env_path is not None:
self._envs = {
Expand Down Expand Up @@ -69,13 +71,14 @@ async def initialize_cluster(

# check if cluster already exists
cluster_status = await to_thread_typed(
lambda: sky.status(cluster_names=[self._cluster_name])
lambda: sky.stream_and_get(sky.status(cluster_names=[self._cluster_name]))
)
if (
len(cluster_status) == 0
or cluster_status[0]["status"] != sky.ClusterStatus.UP
):
await self._launch_cluster(resources, art_version)

else:
print(f"Cluster {self._cluster_name} exists, using it...")

Expand All @@ -93,31 +96,43 @@ async def initialize_cluster(
art_server_running = False

if art_server_running:
self._art_server_job_id = await get_task_job_id(
cluster_name=self._cluster_name, task_name="art_server"
)
print("Art server task already running, using it…")
else:
art_server_task = sky.Task(name="art_server", run="uv run art")
resources = await to_thread_typed(
lambda: sky.status(cluster_names=[self._cluster_name])[0][
"handle"
].launched_resources

clusters = await to_thread_typed(
lambda: sky.stream_and_get(
sky.status(cluster_names=[self._cluster_name])
)
)
resources = clusters[0]["handle"].launched_resources

# For some reason, skypilot doesn't support the region and zone set
resources = resources.copy(region=None, zone=None)

# If a local path was provided for art_version, ensure it is mounted so the latest
# code is synced to the remote cluster every time we (re)launch the art_server task.
if art_version is not None and os.path.exists(art_version):
art_server_task.workdir = art_version

# print(clusters[0]["handle"].launched_resources)
art_server_task.set_resources(cast(sky.Resources, resources))
art_server_task.update_envs(self._envs)

# run art server task
await to_thread_typed(
lambda: sky.exec(
task=art_server_task,
cluster_name=self._cluster_name,
detach_run=True,
job_id, _ = await to_thread_typed(
lambda: sky.stream_and_get(
sky.exec(
task=art_server_task,
cluster_name=self._cluster_name,
)
)
)
self._art_server_job_id = job_id

print("Task launched, waiting for it to start...")
await wait_for_art_server_to_start(cluster_name=self._cluster_name)
print("Art server task started")
Expand All @@ -126,8 +141,17 @@ async def initialize_cluster(
print(f"Using base_url: {base_url}")

# Manually call the real __init__ now that base_url is ready
super(SkyPilotBackend, self).__init__(base_url=base_url)
super(cls, self).__init__(base_url=base_url)

if self._art_server_job_id is not None:
asyncio.create_task(
asyncio.to_thread(
sky.tail_logs,
cluster_name=self._cluster_name,
job_id=self._art_server_job_id,
follow=True,
)
)
return self

async def _launch_cluster(
Expand Down Expand Up @@ -185,13 +209,22 @@ async def _launch_cluster(
task.setup = setup_script

try:
job_id, _ = await to_thread_typed(
lambda: sky.stream_and_get(
sky.launch(
task=task, cluster_name=self._cluster_name, retry_until_up=True
)
)
)

await to_thread_typed(
lambda: sky.launch(
task=task,
lambda: sky.tail_logs(
cluster_name=self._cluster_name,
retry_until_up=True,
job_id=job_id,
follow=True,
)
)

except Exception as e:
print(f"Error launching cluster: {e}")
print()
Expand Down Expand Up @@ -230,4 +263,6 @@ async def _prepare_backend_for_training(
return (vllm_base_url, api_key)

async def down(self) -> None:
await to_thread_typed(lambda: sky.down(cluster_name=self._cluster_name))
await to_thread_typed(
lambda: sky.stream_and_get(sky.down(cluster_name=self._cluster_name))
)
15 changes: 14 additions & 1 deletion src/art/skypilot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,27 @@ async def to_thread_typed(func: Callable[[], T]) -> T:


async def get_task_status(cluster_name: str, task_name: str) -> sky.JobStatus | None:
job_queue = await to_thread_typed(lambda: sky.queue(cluster_name))
job_queue = await to_thread_typed(
lambda: sky.stream_and_get(sky.queue(cluster_name))
)

for job in job_queue:
if job["job_name"] == task_name:
return job["status"]
return None


async def get_task_job_id(cluster_name: str, task_name: str) -> str:
job_queue = await to_thread_typed(
lambda: sky.stream_and_get(sky.queue(cluster_name))
)

for job in job_queue:
if job["job_name"] == task_name:
return job["job_id"]
return None


async def is_task_created(cluster_name: str, task_name: str) -> bool:
task_status = await get_task_status(cluster_name, task_name)
if task_status is None:
Expand Down
107 changes: 12 additions & 95 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.