-
Notifications
You must be signed in to change notification settings - Fork 5.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement local jupyter notebook execution support
- Loading branch information
Showing
9 changed files
with
999 additions
and
349 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
14 changes: 14 additions & 0 deletions
14
python/packages/autogen-ext/src/autogen_ext/code_executors/jupyter/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from ._jupyter_client import JupyterClient | ||
from ._jupyter_code_executor import JupyterCodeExecutor, JupyterCodeResult | ||
from ._jupyter_connectable import JupyterConnectable | ||
from ._jupyter_connection_info import JupyterConnectionInfo | ||
from ._local_jupyter_server import LocalJupyterServer | ||
|
||
__all__ = [ | ||
"JupyterConnectable", | ||
"JupyterConnectionInfo", | ||
"JupyterClient", | ||
"LocalJupyterServer", | ||
"JupyterCodeExecutor", | ||
"JupyterCodeResult", | ||
] |
206 changes: 206 additions & 0 deletions
206
python/packages/autogen-ext/src/autogen_ext/code_executors/jupyter/_jupyter_client.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,206 @@ | ||
from __future__ import annotations | ||
|
||
import sys | ||
from dataclasses import dataclass | ||
from types import TracebackType | ||
from typing import Any, AsyncGenerator, cast | ||
|
||
if sys.version_info >= (3, 11): | ||
from typing import Self | ||
else: | ||
from typing_extensions import Self | ||
|
||
import datetime | ||
import json | ||
import uuid | ||
|
||
import requests | ||
from requests.adapters import HTTPAdapter, Retry | ||
from websockets.asyncio.client import ClientConnection, connect | ||
|
||
from ._jupyter_connection_info import JupyterConnectionInfo | ||
|
||
|
||
class JupyterClient: | ||
def __init__(self, connection_info: JupyterConnectionInfo): | ||
"""(Experimental) A client for communicating with a Jupyter gateway server. | ||
Args: | ||
connection_info (JupyterConnectionInfo): Connection information | ||
""" | ||
self._connection_info = connection_info | ||
self._session = requests.Session() | ||
retries = Retry(total=5, backoff_factor=0.1) | ||
self._session.mount("http://", HTTPAdapter(max_retries=retries)) | ||
|
||
def _get_headers(self) -> dict[str, str]: | ||
if self._connection_info.token is None: | ||
return {} | ||
return {"Authorization": f"token {self._connection_info.token}"} | ||
|
||
def _get_cookies(self) -> str: | ||
cookies = self._session.cookies.get_dict() | ||
return "; ".join([f"{name}={value}" for name, value in cookies.items()]) | ||
|
||
def _get_api_base_url(self) -> str: | ||
protocol = "https" if self._connection_info.use_https else "http" | ||
port = f":{self._connection_info.port}" if self._connection_info.port else "" | ||
return f"{protocol}://{self._connection_info.host}{port}" | ||
|
||
def _get_ws_base_url(self) -> str: | ||
port = f":{self._connection_info.port}" if self._connection_info.port else "" | ||
return f"ws://{self._connection_info.host}{port}" | ||
|
||
def list_kernel_specs(self) -> dict[str, dict[str, str]]: | ||
response = self._session.get(f"{self._get_api_base_url()}/api/kernelspecs", headers=self._get_headers()) | ||
return cast(dict[str, dict[str, str]], response.json()) | ||
|
||
def list_kernels(self) -> list[dict[str, str]]: | ||
response = self._session.get(f"{self._get_api_base_url()}/api/kernels", headers=self._get_headers()) | ||
return cast(list[dict[str, str]], response.json()) | ||
|
||
def start_kernel(self, kernel_spec_name: str) -> str: | ||
"""Start a new kernel. | ||
Args: | ||
kernel_spec_name (str): Name of the kernel spec to start | ||
Returns: | ||
str: ID of the started kernel | ||
""" | ||
|
||
response = self._session.post( | ||
f"{self._get_api_base_url()}/api/kernels", | ||
headers=self._get_headers(), | ||
json={"name": kernel_spec_name}, | ||
) | ||
return cast(str, response.json()["id"]) | ||
|
||
def delete_kernel(self, kernel_id: str) -> None: | ||
response = self._session.delete( | ||
f"{self._get_api_base_url()}/api/kernels/{kernel_id}", headers=self._get_headers() | ||
) | ||
response.raise_for_status() | ||
|
||
def restart_kernel(self, kernel_id: str) -> None: | ||
response = self._session.post( | ||
f"{self._get_api_base_url()}/api/kernels/{kernel_id}/restart", headers=self._get_headers() | ||
) | ||
response.raise_for_status() | ||
|
||
async def get_kernel_client(self, kernel_id: str) -> JupyterKernelClient: | ||
ws_url = f"{self._get_ws_base_url()}/api/kernels/{kernel_id}/channels" | ||
headers = self._get_headers() | ||
headers["Cookie"] = self._get_cookies() | ||
websocket = await connect(ws_url, additional_headers=headers) | ||
return JupyterKernelClient(websocket) | ||
|
||
|
||
class JupyterKernelClient: | ||
"""A client for communicating with a Jupyter kernel.""" | ||
|
||
@dataclass | ||
class ExecutionResult: | ||
@dataclass | ||
class DataItem: | ||
mime_type: str | ||
data: str | ||
|
||
is_ok: bool | ||
output: str | ||
data_items: list[DataItem] | ||
|
||
def __init__(self, websocket: ClientConnection): | ||
self._session_id: str = uuid.uuid4().hex | ||
self._websocket = websocket | ||
|
||
async def _send_message(self, *, content: dict[str, Any], channel: str, message_type: str) -> str: | ||
timestamp = datetime.datetime.now().isoformat() | ||
message_id = uuid.uuid4().hex | ||
message = { | ||
"header": { | ||
"username": "autogen", | ||
"version": "5.0", | ||
"session": self._session_id, | ||
"msg_id": message_id, | ||
"msg_type": message_type, | ||
"date": timestamp, | ||
}, | ||
"parent_header": {}, | ||
"channel": channel, | ||
"content": content, | ||
"metadata": {}, | ||
"buffers": {}, | ||
} | ||
await self._websocket.send(json.dumps(message)) | ||
return message_id | ||
|
||
async def wait_for_ready(self) -> None: | ||
message_id = await self._send_message(content={}, channel="shell", message_type="kernel_info_request") | ||
|
||
async for message in self._receive_message(message_id): | ||
if message["msg_type"] == "kernel_info_reply": | ||
break | ||
|
||
async def execute(self, code: str) -> ExecutionResult: | ||
message_id = await self._send_message( | ||
content={ | ||
"code": code, | ||
"silent": False, | ||
"store_history": True, | ||
"user_expressions": {}, | ||
"allow_stdin": False, | ||
"stop_on_error": True, | ||
}, | ||
channel="shell", | ||
message_type="execute_request", | ||
) | ||
|
||
text_output: list[str] = [] | ||
data_output: list[JupyterKernelClient.ExecutionResult.DataItem] = [] | ||
|
||
async for message in self._receive_message(message_id): | ||
content = message["content"] | ||
match message["msg_type"]: | ||
case "execute_result" | "display_data": | ||
for data_type, data in content["data"].items(): | ||
match data_type: | ||
case "text/plain": | ||
text_output.append(data) | ||
case data if data.startswith("image/") or data == "text/html": | ||
data_output.append(self.ExecutionResult.DataItem(mime_type=data_type, data=data)) | ||
case _: | ||
text_output.append(json.dumps(data)) | ||
case "stream": | ||
text_output.append(content["text"]) | ||
case "error": | ||
return JupyterKernelClient.ExecutionResult( | ||
is_ok=False, | ||
output=f"ERROR: {content['ename']}: {content['evalue']}\n{content['traceback']}", | ||
data_items=[], | ||
) | ||
case _: | ||
pass | ||
|
||
if message["msg_type"] == "status" and content["execution_state"] == "idle": | ||
break | ||
|
||
return JupyterKernelClient.ExecutionResult( | ||
is_ok=True, output="\n".join([output for output in text_output]), data_items=data_output | ||
) | ||
|
||
async def __aenter__(self) -> Self: | ||
return self | ||
|
||
async def __aexit__( | ||
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None | ||
) -> None: | ||
await self._websocket.close() | ||
|
||
async def _receive_message(self, message_id: str) -> AsyncGenerator[dict[str, Any]]: | ||
async for data in self._websocket: | ||
if isinstance(data, bytes): | ||
data = data.decode("utf-8") | ||
message = cast(dict[str, Any], json.loads(data)) | ||
if message.get("parent_header", {}).get("msg_id") == message_id: | ||
yield message |
146 changes: 146 additions & 0 deletions
146
python/packages/autogen-ext/src/autogen_ext/code_executors/jupyter/_jupyter_code_executor.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
import asyncio | ||
import base64 | ||
import json | ||
import sys | ||
import uuid | ||
from dataclasses import dataclass | ||
from pathlib import Path | ||
from types import TracebackType | ||
|
||
if sys.version_info >= (3, 11): | ||
from typing import Self | ||
else: | ||
from typing_extensions import Self | ||
|
||
from autogen_core import CancellationToken | ||
from autogen_core.code_executor import CodeBlock, CodeExecutor, CodeResult | ||
|
||
from .._common import silence_pip | ||
from ._jupyter_connectable import JupyterConnectable | ||
|
||
|
||
@dataclass | ||
class JupyterCodeResult(CodeResult): | ||
"""A code result class for Jupyter code executor.""" | ||
|
||
output_files: list[Path] | ||
|
||
|
||
class JupyterCodeExecutor(CodeExecutor): | ||
def __init__( | ||
self, | ||
server: JupyterConnectable, | ||
kernel_name: str = "python3", | ||
timeout: int = 60, | ||
output_dir: Path = Path("."), | ||
): | ||
"""A code executor class that executes code statefully using | ||
a Jupyter server supplied to this class. | ||
Each execution is stateful and can access variables created from previous | ||
executions in the same session. | ||
Args: | ||
server (JupyterConnectable): The Jupyter server to use. | ||
kernel_name (str): The kernel name to use. Make sure it is installed. | ||
By default, it is "python3". | ||
timeout (int): The timeout for code execution, by default 60. | ||
output_dir (Path): The directory to save output files, by default ".". | ||
""" | ||
if timeout < 1: | ||
raise ValueError("Timeout must be greater than or equal to 1.") | ||
|
||
self._jupyter_client = server.get_client() | ||
self._kernel_name = kernel_name | ||
self._timeout = timeout | ||
self._output_dir = output_dir | ||
self.start() | ||
|
||
async def execute_code_blocks( | ||
self, code_blocks: list[CodeBlock], cancellation_token: CancellationToken | ||
) -> JupyterCodeResult: | ||
"""Execute code blocks and return the result. | ||
Args: | ||
code_blocks (list[CodeBlock]): The code blocks to execute. | ||
Returns: | ||
JupyterCodeResult: The result of the code execution. | ||
Raises: | ||
asyncio.TimeoutError: Code execution timeouts | ||
asyncio.CancelledError: CancellationToken evoked during execution | ||
""" | ||
async with await self._jupyter_client.get_kernel_client(self._kernel_id) as kernel_client: | ||
wait_for_ready_task = asyncio.create_task(kernel_client.wait_for_ready()) | ||
cancellation_token.link_future(wait_for_ready_task) | ||
await asyncio.wait_for(wait_for_ready_task, timeout=self._timeout) | ||
|
||
outputs: list[str] = [] | ||
output_files: list[Path] = [] | ||
for code_block in code_blocks: | ||
code = silence_pip(code_block.code, code_block.language) | ||
execute_task = asyncio.create_task(kernel_client.execute(code)) | ||
cancellation_token.link_future(execute_task) | ||
result = await asyncio.wait_for(execute_task, timeout=self._timeout) | ||
|
||
if result.is_ok: | ||
outputs.append(result.output) | ||
for data in result.data_items: | ||
match data.mime_type: | ||
case "image/png": | ||
path = self._save_image(data.data) | ||
outputs.append(f"Image data saved to {path}") | ||
output_files.append(path) | ||
case "text/html": | ||
path = self._save_html(data.data) | ||
outputs.append(f"HTML data saved to {path}") | ||
output_files.append(path) | ||
case _: | ||
outputs.append(json.dumps(data.data)) | ||
else: | ||
return JupyterCodeResult(exit_code=1, output=f"ERROR: {result.output}", output_files=[]) | ||
|
||
return JupyterCodeResult( | ||
exit_code=0, output="\n".join([output for output in outputs]), output_files=output_files | ||
) | ||
|
||
async def restart(self) -> None: | ||
"""Restart the code executor.""" | ||
self._jupyter_client.restart_kernel(self._kernel_id) | ||
self._jupyter_kernel_client = self._jupyter_client.get_kernel_client(self._kernel_id) | ||
|
||
def start(self) -> None: | ||
"""Start the kernel.""" | ||
available_kernels = self._jupyter_client.list_kernel_specs() | ||
if self._kernel_name not in available_kernels["kernelspecs"]: | ||
raise ValueError(f"Kernel {self._kernel_name} is not installed.") | ||
|
||
self._kernel_id = self._jupyter_client.start_kernel(self._kernel_name) | ||
|
||
def stop(self) -> None: | ||
"""Stop the kernel.""" | ||
if self._kernel_id is not None: | ||
self._jupyter_client.delete_kernel(self._kernel_id) | ||
self._kernel_id = None | ||
|
||
def __enter__(self) -> Self: | ||
return self | ||
|
||
def __exit__( | ||
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None | ||
) -> None: | ||
self.stop() | ||
|
||
def _save_image(self, image_data_base64: str) -> Path: | ||
"""Save image data to a file.""" | ||
image_data = base64.b64decode(image_data_base64) | ||
path = self._output_dir / f"{uuid.uuid4().hex}.png" | ||
path.write_bytes(image_data) | ||
return path.absolute() | ||
|
||
def _save_html(self, html_data: str) -> Path: | ||
"""Save html data to a file.""" | ||
path = self._output_dir / f"{uuid.uuid4().hex}.html" | ||
path.write_text(html_data) | ||
return path.absolute() |
Oops, something went wrong.