mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 13:52:43 +08:00
Add extensive typing to openhands/runtime/plugins directory (#7726)
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
parent
bb98d94b35
commit
883da1b28c
@ -1,5 +1,7 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from openhands.events.action import Action
|
||||
from openhands.events.observation import Observation
|
||||
from openhands.runtime.plugins.agent_skills import agentskills
|
||||
from openhands.runtime.plugins.requirement import Plugin, PluginRequirement
|
||||
|
||||
@ -12,3 +14,11 @@ class AgentSkillsRequirement(PluginRequirement):
|
||||
|
||||
class AgentSkillsPlugin(Plugin):
|
||||
name: str = 'agent_skills'
|
||||
|
||||
async def initialize(self, username: str) -> None:
|
||||
"""Initialize the plugin."""
|
||||
pass
|
||||
|
||||
async def run(self, action: Action) -> Observation:
|
||||
"""Run the plugin for a given action."""
|
||||
raise NotImplementedError('AgentSkillsPlugin does not support run method')
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import asyncio
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
@ -19,8 +18,14 @@ class JupyterRequirement(PluginRequirement):
|
||||
|
||||
class JupyterPlugin(Plugin):
|
||||
name: str = 'jupyter'
|
||||
kernel_gateway_port: int
|
||||
kernel_id: str
|
||||
gateway_process: asyncio.subprocess.Process
|
||||
python_interpreter_path: str
|
||||
|
||||
async def initialize(self, username: str, kernel_id: str = 'openhands-default'):
|
||||
async def initialize(
|
||||
self, username: str, kernel_id: str = 'openhands-default'
|
||||
) -> None:
|
||||
self.kernel_gateway_port = find_available_tcp_port(40000, 49999)
|
||||
self.kernel_id = kernel_id
|
||||
if username in ['root', 'openhands']:
|
||||
@ -61,19 +66,22 @@ class JupyterPlugin(Plugin):
|
||||
)
|
||||
logger.debug(f'Jupyter launch command: {jupyter_launch_command}')
|
||||
|
||||
self.gateway_process = subprocess.Popen(
|
||||
# Using asyncio.create_subprocess_shell instead of subprocess.Popen
|
||||
# to avoid ASYNC101 linting error
|
||||
self.gateway_process = await asyncio.create_subprocess_shell(
|
||||
jupyter_launch_command,
|
||||
stderr=subprocess.STDOUT,
|
||||
shell=True,
|
||||
stderr=asyncio.subprocess.STDOUT,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
)
|
||||
# read stdout until the kernel gateway is ready
|
||||
output = ''
|
||||
while should_continue() and self.gateway_process.stdout is not None:
|
||||
line = self.gateway_process.stdout.readline().decode('utf-8')
|
||||
line_bytes = await self.gateway_process.stdout.readline()
|
||||
line = line_bytes.decode('utf-8')
|
||||
output += line
|
||||
if 'at' in line:
|
||||
break
|
||||
time.sleep(1)
|
||||
await asyncio.sleep(1)
|
||||
logger.debug('Waiting for jupyter kernel gateway to start...')
|
||||
|
||||
logger.debug(
|
||||
|
||||
@ -50,22 +50,22 @@ def strip_ansi(o: str) -> str:
|
||||
|
||||
|
||||
class JupyterKernel:
|
||||
def __init__(self, url_suffix, convid, lang='python'):
|
||||
def __init__(self, url_suffix: str, convid: str, lang: str = 'python') -> None:
|
||||
self.base_url = f'http://{url_suffix}'
|
||||
self.base_ws_url = f'ws://{url_suffix}'
|
||||
self.lang = lang
|
||||
self.kernel_id = None
|
||||
self.ws = None
|
||||
self.kernel_id: str | None = None
|
||||
self.ws: tornado.websocket.WebSocketClientConnection | None = None
|
||||
self.convid = convid
|
||||
logging.info(
|
||||
f'Jupyter kernel created for conversation {convid} at {url_suffix}'
|
||||
)
|
||||
|
||||
self.heartbeat_interval = 10000 # 10 seconds
|
||||
self.heartbeat_callback = None
|
||||
self.heartbeat_callback: PeriodicCallback | None = None
|
||||
self.initialized = False
|
||||
|
||||
async def initialize(self):
|
||||
async def initialize(self) -> None:
|
||||
await self.execute(r'%colors nocolor')
|
||||
# pre-defined tools
|
||||
self.tools_to_run: list[str] = [
|
||||
@ -76,7 +76,7 @@ class JupyterKernel:
|
||||
logging.info(f'Tool [{tool}] initialized:\n{res}')
|
||||
self.initialized = True
|
||||
|
||||
async def _send_heartbeat(self):
|
||||
async def _send_heartbeat(self) -> None:
|
||||
if not self.ws:
|
||||
return
|
||||
try:
|
||||
@ -91,7 +91,7 @@ class JupyterKernel:
|
||||
'ConnectionRefusedError: Failed to reconnect to kernel websocket - Is the kernel still running?'
|
||||
)
|
||||
|
||||
async def _connect(self):
|
||||
async def _connect(self) -> None:
|
||||
if self.ws:
|
||||
self.ws.close()
|
||||
self.ws = None
|
||||
@ -138,7 +138,7 @@ class JupyterKernel:
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_fixed(2),
|
||||
)
|
||||
async def execute(self, code, timeout=120):
|
||||
async def execute(self, code: str, timeout: int = 120) -> str:
|
||||
if not self.ws:
|
||||
await self._connect()
|
||||
|
||||
@ -170,45 +170,49 @@ class JupyterKernel:
|
||||
)
|
||||
logging.info(f'Executed code in jupyter kernel:\n{res}')
|
||||
|
||||
outputs = []
|
||||
outputs: list[str] = []
|
||||
|
||||
async def wait_for_messages():
|
||||
async def wait_for_messages() -> bool:
|
||||
execution_done = False
|
||||
while not execution_done:
|
||||
assert self.ws is not None
|
||||
msg = await self.ws.read_message()
|
||||
msg = json_decode(msg)
|
||||
msg_type = msg['msg_type']
|
||||
parent_msg_id = msg['parent_header'].get('msg_id', None)
|
||||
if msg is None:
|
||||
continue
|
||||
msg_dict = json_decode(msg)
|
||||
msg_type = msg_dict['msg_type']
|
||||
parent_msg_id = msg_dict['parent_header'].get('msg_id', None)
|
||||
|
||||
if parent_msg_id != msg_id:
|
||||
continue
|
||||
|
||||
if os.environ.get('DEBUG'):
|
||||
logging.info(
|
||||
f"MSG TYPE: {msg_type.upper()} DONE:{execution_done}\nCONTENT: {msg['content']}"
|
||||
f"MSG TYPE: {msg_type.upper()} DONE:{execution_done}\nCONTENT: {msg_dict['content']}"
|
||||
)
|
||||
|
||||
if msg_type == 'error':
|
||||
traceback = '\n'.join(msg['content']['traceback'])
|
||||
traceback = '\n'.join(msg_dict['content']['traceback'])
|
||||
outputs.append(traceback)
|
||||
execution_done = True
|
||||
elif msg_type == 'stream':
|
||||
outputs.append(msg['content']['text'])
|
||||
outputs.append(msg_dict['content']['text'])
|
||||
elif msg_type in ['execute_result', 'display_data']:
|
||||
outputs.append(msg['content']['data']['text/plain'])
|
||||
if 'image/png' in msg['content']['data']:
|
||||
outputs.append(msg_dict['content']['data']['text/plain'])
|
||||
if 'image/png' in msg_dict['content']['data']:
|
||||
# use markdone to display image (in case of large image)
|
||||
outputs.append(
|
||||
f"\n\n"
|
||||
f"\n\n"
|
||||
)
|
||||
|
||||
elif msg_type == 'execute_reply':
|
||||
execution_done = True
|
||||
return execution_done
|
||||
|
||||
async def interrupt_kernel():
|
||||
async def interrupt_kernel() -> None:
|
||||
client = AsyncHTTPClient()
|
||||
if self.kernel_id is None:
|
||||
return
|
||||
interrupt_response = await client.fetch(
|
||||
f'{self.base_url}/api/kernels/{self.kernel_id}/interrupt',
|
||||
method='POST',
|
||||
@ -234,7 +238,7 @@ class JupyterKernel:
|
||||
logging.info(f'OUTPUT:\n{ret}')
|
||||
return ret
|
||||
|
||||
async def shutdown_async(self):
|
||||
async def shutdown_async(self) -> None:
|
||||
if self.kernel_id:
|
||||
client = AsyncHTTPClient()
|
||||
await client.fetch(
|
||||
@ -248,10 +252,10 @@ class JupyterKernel:
|
||||
|
||||
|
||||
class ExecuteHandler(tornado.web.RequestHandler):
|
||||
def initialize(self, jupyter_kernel):
|
||||
def initialize(self, jupyter_kernel: JupyterKernel) -> None:
|
||||
self.jupyter_kernel = jupyter_kernel
|
||||
|
||||
async def post(self):
|
||||
async def post(self) -> None:
|
||||
data = json_decode(self.request.body)
|
||||
code = data.get('code')
|
||||
|
||||
@ -265,10 +269,10 @@ class ExecuteHandler(tornado.web.RequestHandler):
|
||||
self.write(output)
|
||||
|
||||
|
||||
def make_app():
|
||||
def make_app() -> tornado.web.Application:
|
||||
jupyter_kernel = JupyterKernel(
|
||||
f"localhost:{os.environ.get('JUPYTER_GATEWAY_PORT')}",
|
||||
os.environ.get('JUPYTER_GATEWAY_KERNEL_ID'),
|
||||
f"localhost:{os.environ.get('JUPYTER_GATEWAY_PORT', '8888')}",
|
||||
os.environ.get('JUPYTER_GATEWAY_KERNEL_ID', 'default'),
|
||||
)
|
||||
asyncio.get_event_loop().run_until_complete(jupyter_kernel.initialize())
|
||||
|
||||
|
||||
@ -14,7 +14,7 @@ class Plugin:
|
||||
name: str
|
||||
|
||||
@abstractmethod
|
||||
async def initialize(self, username: str):
|
||||
async def initialize(self, username: str) -> None:
|
||||
"""Initialize the plugin."""
|
||||
pass
|
||||
|
||||
|
||||
@ -1,10 +1,12 @@
|
||||
import asyncio
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action import Action
|
||||
from openhands.events.observation import Observation
|
||||
from openhands.runtime.plugins.requirement import Plugin, PluginRequirement
|
||||
from openhands.runtime.utils.system import check_port_available
|
||||
from openhands.utils.shutdown_listener import should_continue
|
||||
@ -17,10 +19,11 @@ class VSCodeRequirement(PluginRequirement):
|
||||
|
||||
class VSCodePlugin(Plugin):
|
||||
name: str = 'vscode'
|
||||
vscode_port: int | None = None
|
||||
vscode_connection_token: str | None = None
|
||||
vscode_port: Optional[int] = None
|
||||
vscode_connection_token: Optional[str] = None
|
||||
gateway_process: asyncio.subprocess.Process
|
||||
|
||||
async def initialize(self, username: str):
|
||||
async def initialize(self, username: str) -> None:
|
||||
if username not in ['root', 'openhands']:
|
||||
self.vscode_port = None
|
||||
self.vscode_connection_token = None
|
||||
@ -41,22 +44,29 @@ class VSCodePlugin(Plugin):
|
||||
'EOF'
|
||||
)
|
||||
|
||||
self.gateway_process = subprocess.Popen(
|
||||
# Using asyncio.create_subprocess_shell instead of subprocess.Popen
|
||||
# to avoid ASYNC101 linting error
|
||||
self.gateway_process = await asyncio.create_subprocess_shell(
|
||||
cmd,
|
||||
stderr=subprocess.STDOUT,
|
||||
shell=True,
|
||||
stderr=asyncio.subprocess.STDOUT,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
)
|
||||
# read stdout until the kernel gateway is ready
|
||||
output = ''
|
||||
while should_continue() and self.gateway_process.stdout is not None:
|
||||
line = self.gateway_process.stdout.readline().decode('utf-8')
|
||||
line_bytes = await self.gateway_process.stdout.readline()
|
||||
line = line_bytes.decode('utf-8')
|
||||
print(line)
|
||||
output += line
|
||||
if 'at' in line:
|
||||
break
|
||||
time.sleep(1)
|
||||
await asyncio.sleep(1)
|
||||
logger.debug('Waiting for VSCode server to start...')
|
||||
|
||||
logger.debug(
|
||||
f'VSCode server started at port {self.vscode_port}. Output: {output}'
|
||||
)
|
||||
|
||||
async def run(self, action: Action) -> Observation:
|
||||
"""Run the plugin for a given action."""
|
||||
raise NotImplementedError('VSCodePlugin does not support run method')
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user