-
Notifications
You must be signed in to change notification settings - Fork 20
Update skypilot, use async api #81
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,6 +2,7 @@ | |
import sky | ||
import os | ||
import semver | ||
import asyncio | ||
from dotenv import dotenv_values | ||
|
||
from .utils import ( | ||
|
@@ -10,6 +11,7 @@ | |
wait_for_art_server_to_start, | ||
get_art_server_base_url, | ||
get_vllm_base_url, | ||
get_task_job_id, | ||
) | ||
|
||
from .. import dev | ||
|
@@ -34,7 +36,7 @@ async def initialize_cluster( | |
) -> None: | ||
self = cls.__new__(cls) | ||
self._cluster_name = cluster_name | ||
|
||
self._job_id = None | ||
if gpu is None and resources is None: | ||
raise ValueError("Either gpu or resources must be provided") | ||
|
||
|
@@ -57,37 +59,49 @@ async def initialize_cluster( | |
|
||
# check if cluster already exists | ||
cluster_status = await to_thread_typed( | ||
lambda: sky.status(cluster_names=[self._cluster_name]) | ||
lambda: sky.stream_and_get(sky.status(cluster_names=[self._cluster_name])) | ||
) | ||
if ( | ||
len(cluster_status) == 0 | ||
or cluster_status[0]["status"] != sky.ClusterStatus.UP | ||
): | ||
await self._launch_cluster(resources, art_version, env_path) | ||
launch_job_id = await self._launch_cluster(resources, art_version, env_path) | ||
await asyncio.to_thread( | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'd prefer to use the |
||
sky.tail_logs, | ||
cluster_name=self._cluster_name, | ||
job_id=launch_job_id, | ||
follow=True, | ||
) | ||
else: | ||
print(f"Cluster {self._cluster_name} exists, using it...") | ||
|
||
if await is_task_created( | ||
cluster_name=self._cluster_name, task_name="art_server" | ||
): | ||
self._job_id = await get_task_job_id( | ||
cluster_name=self._cluster_name, task_name="art_server" | ||
) | ||
print("Art server task already running, using it...") | ||
else: | ||
art_server_task = sky.Task(name="art_server", run="uv run art") | ||
resources = await to_thread_typed( | ||
lambda: sky.status(cluster_names=[self._cluster_name])[0][ | ||
"handle" | ||
].launched_resources | ||
clusters = await to_thread_typed( | ||
lambda: sky.stream_and_get( | ||
sky.status(cluster_names=[self._cluster_name]) | ||
) | ||
) | ||
art_server_task.set_resources(resources) | ||
art_server_task.set_resources(clusters[0]["handle"].launched_resources) | ||
|
||
# run art server task | ||
await to_thread_typed( | ||
lambda: sky.exec( | ||
task=art_server_task, | ||
cluster_name=self._cluster_name, | ||
detach_run=True, | ||
job_id, handler = await to_thread_typed( | ||
lambda: sky.stream_and_get( | ||
sky.exec( | ||
task=art_server_task, | ||
cluster_name=self._cluster_name, | ||
) | ||
) | ||
) | ||
self._job_id = job_id | ||
|
||
print("Task launched, waiting for it to start...") | ||
await wait_for_art_server_to_start(cluster_name=self._cluster_name) | ||
print("Art server task started") | ||
|
@@ -98,6 +112,15 @@ async def initialize_cluster( | |
# Manually call the real __init__ now that base_url is ready | ||
super(SkyPilotAPI, self).__init__(base_url=base_url) | ||
|
||
if self._job_id is not None: | ||
asyncio.create_task( | ||
asyncio.to_thread( | ||
sky.tail_logs, | ||
cluster_name=self._cluster_name, | ||
job_id=self._job_id, | ||
follow=True, | ||
) | ||
) | ||
return self | ||
|
||
async def _launch_cluster( | ||
|
@@ -163,9 +186,13 @@ async def _launch_cluster( | |
print(task) | ||
|
||
try: | ||
await to_thread_typed( | ||
lambda: sky.launch(task=task, cluster_name=self._cluster_name) | ||
job_id, handler = await to_thread_typed( | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Do we use `handler` anywhere? |
||
lambda: sky.stream_and_get( | ||
sky.launch(task=task, cluster_name=self._cluster_name) | ||
) | ||
) | ||
|
||
return job_id | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Instead of returning `job_id` here, would it make sense to start tailing logs within `_launch_cluster`? That would reduce the amount of code in the main `initialize_cluster` function, and it seems like it should be part of the launching process. |
||
except Exception as e: | ||
print(f"Error launching cluster: {e}") | ||
print() | ||
|
@@ -204,4 +231,6 @@ async def _prepare_backend_for_training( | |
return [vllm_base_url, api_key] | ||
|
||
async def down(self) -> None: | ||
await to_thread_typed(lambda: sky.down(cluster_name=self._cluster_name)) | ||
await to_thread_typed( | ||
lambda: sky.stream_and_get(sky.down(cluster_name=self._cluster_name)) | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we rename this to `_art_server_job_id`? We may have more tasks in the future, and I can imagine myself forgetting whether this ID refers to the task of launching the cluster as a whole or the art server in particular.