232 lines
8.5 KiB
Python
232 lines
8.5 KiB
Python
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
|
#
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
#
|
|
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
|
# SPDX-License-Identifier: MIT
|
|
from __future__ import annotations
|
|
|
|
import datetime
|
|
import json
|
|
import sys
|
|
import uuid
|
|
from dataclasses import dataclass
|
|
from types import TracebackType
|
|
from typing import Any, Optional, cast
|
|
|
|
from ...doc_utils import export_module
|
|
|
|
if sys.version_info >= (3, 11):
|
|
from typing import Self
|
|
else:
|
|
from typing_extensions import Self
|
|
|
|
|
|
import requests
|
|
from requests.adapters import HTTPAdapter, Retry
|
|
|
|
from ...import_utils import optional_import_block, require_optional_import
|
|
from .base import JupyterConnectionInfo
|
|
|
|
with optional_import_block():
|
|
import websocket
|
|
from websocket import WebSocket
|
|
|
|
|
|
@export_module("autogen.coding.jupyter")
class JupyterClient:
    """(Experimental) A client for the REST and websocket APIs of a Jupyter gateway server."""

    def __init__(self, connection_info: JupyterConnectionInfo):
        """(Experimental) A client for communicating with a Jupyter gateway server.

        Args:
            connection_info (JupyterConnectionInfo): Connection information
        """
        self._connection_info = connection_info
        self._session = requests.Session()
        retries = Retry(total=5, backoff_factor=0.1)
        adapter = HTTPAdapter(max_retries=retries)
        # Mount the retrying adapter for both schemes: ``use_https`` switches the
        # API base URL to ``https://``, and an ``http://``-only mount would leave
        # those requests without this retry policy.
        self._session.mount("http://", adapter)
        self._session.mount("https://", adapter)

    def _get_headers(self) -> dict[str, str]:
        """Return request headers, adding token authentication when a token is configured."""
        if self._connection_info.token is None:
            return {}
        return {"Authorization": f"token {self._connection_info.token}"}

    def _get_host_and_port(self) -> str:
        """Return ``host`` or ``host:port``; the port is omitted when it is falsy."""
        port = f":{self._connection_info.port}" if self._connection_info.port else ""
        return f"{self._connection_info.host}{port}"

    def _get_api_base_url(self) -> str:
        """Return the base URL of the REST API (``https`` when ``use_https`` is set, else ``http``)."""
        protocol = "https" if self._connection_info.use_https else "http"
        return f"{protocol}://{self._get_host_and_port()}"

    def _get_ws_base_url(self) -> str:
        """Return the base URL for kernel websocket channels.

        Uses ``wss`` when ``use_https`` is set so the websocket is secured the
        same way as the REST API, and ``ws`` otherwise.
        """
        protocol = "wss" if self._connection_info.use_https else "ws"
        return f"{protocol}://{self._get_host_and_port()}"

    def list_kernel_specs(self) -> dict[str, dict[str, str]]:
        """List the kernel specs available on the server.

        Returns:
            dict[str, dict[str, str]]: Decoded JSON body of ``GET /api/kernelspecs``.

        Raises:
            requests.HTTPError: If the server responds with an error status.
        """
        response = self._session.get(f"{self._get_api_base_url()}/api/kernelspecs", headers=self._get_headers())
        # Fail on error statuses (as delete/restart do) instead of returning an
        # error payload as if it were the listing.
        response.raise_for_status()
        return cast(dict[str, dict[str, str]], response.json())

    def list_kernels(self) -> list[dict[str, str]]:
        """List the kernels currently known to the server.

        Returns:
            list[dict[str, str]]: Decoded JSON body of ``GET /api/kernels``.

        Raises:
            requests.HTTPError: If the server responds with an error status.
        """
        response = self._session.get(f"{self._get_api_base_url()}/api/kernels", headers=self._get_headers())
        response.raise_for_status()
        return cast(list[dict[str, str]], response.json())

    def start_kernel(self, kernel_spec_name: str) -> str:
        """Start a new kernel.

        Args:
            kernel_spec_name (str): Name of the kernel spec to start

        Returns:
            str: ID of the started kernel

        Raises:
            requests.HTTPError: If the server responds with an error status.
        """
        response = self._session.post(
            f"{self._get_api_base_url()}/api/kernels",
            headers=self._get_headers(),
            json={"name": kernel_spec_name},
        )
        # Without this check a failed start surfaced as an opaque KeyError on "id".
        response.raise_for_status()
        return cast(str, response.json()["id"])

    def delete_kernel(self, kernel_id: str) -> None:
        """Delete (shut down) the kernel with the given ID.

        Args:
            kernel_id (str): ID of the kernel to delete.

        Raises:
            requests.HTTPError: If the server responds with an error status.
        """
        response = self._session.delete(
            f"{self._get_api_base_url()}/api/kernels/{kernel_id}", headers=self._get_headers()
        )
        response.raise_for_status()

    def restart_kernel(self, kernel_id: str) -> None:
        """Restart the kernel with the given ID.

        Args:
            kernel_id (str): ID of the kernel to restart.

        Raises:
            requests.HTTPError: If the server responds with an error status.
        """
        response = self._session.post(
            f"{self._get_api_base_url()}/api/kernels/{kernel_id}/restart", headers=self._get_headers()
        )
        response.raise_for_status()

    @require_optional_import("websocket", "jupyter-executor")
    def get_kernel_client(self, kernel_id: str) -> JupyterKernelClient:
        """Open a websocket to a kernel's channels endpoint and wrap it in a client.

        Args:
            kernel_id (str): ID of a running kernel.

        Returns:
            JupyterKernelClient: Client bound to the newly opened websocket.
        """
        ws_url = f"{self._get_ws_base_url()}/api/kernels/{kernel_id}/channels"
        ws = websocket.create_connection(ws_url, header=self._get_headers())
        return JupyterKernelClient(ws)
|
|
|
|
|
|
@require_optional_import("websocket", "jupyter-executor")
class JupyterKernelClient:
    """(Experimental) A client for communicating with a Jupyter kernel."""

    @dataclass
    class ExecutionResult:
        """Outcome of :meth:`JupyterKernelClient.execute`."""

        @dataclass
        class DataItem:
            """A rich (non-plain-text) output item, e.g. ``image/*`` or ``text/html``."""

            mime_type: str
            data: str

        # False when the kernel reported an error or a receive timed out.
        is_ok: bool
        # Plain-text outputs joined with newlines, or an "ERROR: ..." message.
        output: str
        # Rich outputs collected during execution; always empty when is_ok is False.
        data_items: list[DataItem]

    def __init__(self, websocket: WebSocket):  # type: ignore[no-any-unimported]
        """Wrap an already-open websocket connected to a kernel's channels endpoint.

        Args:
            websocket (WebSocket): Open websocket connection; closed by :meth:`stop`.
        """
        # One random session id per client, stamped on every outgoing message header.
        self._session_id: str = uuid.uuid4().hex
        self._websocket: WebSocket = websocket  # type: ignore[no-any-unimported]

    def __enter__(self) -> Self:
        return self

    def __exit__(
        self, exc_type: Optional[type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]
    ) -> None:
        self.stop()

    def stop(self) -> None:
        """Close the underlying websocket connection."""
        self._websocket.close()

    def _send_message(self, *, content: dict[str, Any], channel: str, message_type: str) -> str:
        """Send a Jupyter protocol message over the websocket.

        Args:
            content: Message ``content`` payload.
            channel: Target channel name (e.g. ``"shell"``).
            message_type: Value for the header's ``msg_type``.

        Returns:
            str: The generated ``msg_id``, used to match replies via ``parent_header``.
        """
        # Timezone-aware UTC timestamp: a naive local time is ambiguous for the
        # ISO 8601 ``date`` header field.
        timestamp = datetime.datetime.now(datetime.timezone.utc).isoformat()
        message_id = uuid.uuid4().hex
        message = {
            "header": {
                "username": "autogen",
                "version": "5.0",
                "session": self._session_id,
                "msg_id": message_id,
                "msg_type": message_type,
                "date": timestamp,
            },
            "parent_header": {},
            "channel": channel,
            "content": content,
            "metadata": {},
            # The messaging protocol defines ``buffers`` as a list.
            "buffers": [],
        }
        self._websocket.send_text(json.dumps(message))
        return message_id

    def _receive_message(self, timeout_seconds: Optional[float]) -> Optional[dict[str, Any]]:
        """Receive and JSON-decode the next websocket message.

        Args:
            timeout_seconds: Passed as-is to the websocket's ``settimeout``.

        Returns:
            The decoded message, or None if the receive timed out.
        """
        self._websocket.settimeout(timeout_seconds)
        try:
            data = self._websocket.recv()
            if isinstance(data, bytes):
                data = data.decode("utf-8")
            return cast(dict[str, Any], json.loads(data))
        except websocket.WebSocketTimeoutException:
            return None

    @staticmethod
    def _is_reply_to(message: dict[str, Any], message_id: str) -> bool:
        """Return True if ``message`` was emitted in response to the request ``message_id``."""
        return message.get("parent_header", {}).get("msg_id") == message_id

    def wait_for_ready(self, timeout_seconds: Optional[float] = None) -> bool:
        """Send a ``kernel_info_request`` and wait for the matching reply.

        Args:
            timeout_seconds: Timeout applied to each individual receive, not to the
                whole wait.

        Returns:
            bool: True once the matching ``kernel_info_reply`` arrives, False if a
            receive times out first.
        """
        message_id = self._send_message(content={}, channel="shell", message_type="kernel_info_request")
        while True:
            message = self._receive_message(timeout_seconds)
            # This means we timed out with no new messages.
            if message is None:
                return False
            if self._is_reply_to(message, message_id) and message["msg_type"] == "kernel_info_reply":
                return True

    def execute(self, code: str, timeout_seconds: Optional[float] = None) -> ExecutionResult:
        """Execute ``code`` on the kernel and collect its output.

        Args:
            code: Source code to execute.
            timeout_seconds: Timeout applied to each individual receive. A single
                receive exceeding it aborts with an ``is_ok=False`` result.

        Returns:
            ExecutionResult: On success, ``text/plain`` and stream output joined by
            newlines plus any ``image/*`` / ``text/html`` items. On kernel error or
            timeout, ``is_ok`` is False and ``output`` starts with ``"ERROR: "``.
        """
        message_id = self._send_message(
            content={
                "code": code,
                "silent": False,
                "store_history": True,
                "user_expressions": {},
                "allow_stdin": False,
                "stop_on_error": True,
            },
            channel="shell",
            message_type="execute_request",
        )

        text_output: list[Any] = []
        data_output: list[Any] = []
        while True:
            message = self._receive_message(timeout_seconds)
            if message is None:
                return JupyterKernelClient.ExecutionResult(
                    is_ok=False, output="ERROR: Timeout waiting for output from code block.", data_items=[]
                )

            # Ignore messages that are not for this execution.
            if not self._is_reply_to(message, message_id):
                continue

            msg_type = message["msg_type"]
            content = message["content"]
            if msg_type in ("execute_result", "display_data"):
                for data_type, data in content["data"].items():
                    if data_type == "text/plain":
                        text_output.append(data)
                    elif data_type.startswith("image/") or data_type == "text/html":
                        data_output.append(self.ExecutionResult.DataItem(mime_type=data_type, data=data))
                    else:
                        # Other MIME bundles are serialized into the text output.
                        text_output.append(json.dumps(data))
            elif msg_type == "stream":
                text_output.append(content["text"])
            elif msg_type == "error":
                # The protocol sends ``traceback`` as a list of lines; join them
                # instead of embedding the list's repr in the message.
                traceback_lines = content["traceback"]
                if isinstance(traceback_lines, (list, tuple)):
                    traceback_text = "\n".join(str(line) for line in traceback_lines)
                else:
                    traceback_text = str(traceback_lines)
                return JupyterKernelClient.ExecutionResult(
                    is_ok=False,
                    output=f"ERROR: {content['ename']}: {content['evalue']}\n{traceback_text}",
                    data_items=[],
                )
            elif msg_type == "status" and content["execution_state"] == "idle":
                # The kernel went idle after our request: all output has been received.
                break

        return JupyterKernelClient.ExecutionResult(
            is_ok=True, output="\n".join(str(output) for output in text_output), data_items=data_output
        )
|