Mirror of https://github.com/OpenHands/OpenHands.git

Various async fixes (#4722)

Parent: 0595d2336a
Commit: 250fcbe62c
@@ -11,6 +11,7 @@ from openhands.events.event import Event, EventSource
 from openhands.events.serialization.event import event_from_dict, event_to_dict
 from openhands.runtime.utils.shutdown_listener import should_continue
 from openhands.storage import FileStore
+from openhands.utils.async_utils import call_sync_from_async


 class EventStreamSubscriber(str, Enum):
@@ -22,14 +23,29 @@ class EventStreamSubscriber(str, Enum):
     TEST = 'test'


-def session_exists(sid: str, file_store: FileStore) -> bool:
+async def session_exists(sid: str, file_store: FileStore) -> bool:
     try:
-        file_store.list(f'sessions/{sid}')
+        await call_sync_from_async(file_store.list, f'sessions/{sid}')
         return True
     except FileNotFoundError:
         return False


+class AsyncEventStreamWrapper:
+    def __init__(self, event_stream, *args, **kwargs):
+        self.event_stream = event_stream
+        self.args = args
+        self.kwargs = kwargs
+
+    async def __aiter__(self):
+        loop = asyncio.get_running_loop()
+
+        # Create an async generator that yields events
+        for event in self.event_stream.get_events(*self.args, **self.kwargs):
+            # Run the blocking get_events() in a thread pool
+            yield await loop.run_in_executor(None, lambda e=event: e)  # type: ignore
+
+
 @dataclass
 class EventStream:
     sid: str
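
Note (illustrative, not part of the commit): the new AsyncEventStreamWrapper lets coroutine
code consume a synchronous EventStream with `async for` instead of a blocking `for` loop,
yielding control back to the event loop between events. A minimal usage sketch, assuming an
existing EventStream instance named `event_stream`:

    # Sketch only; mirrors how the server endpoints below consume the wrapper.
    async def replay_events(event_stream, start_id: int) -> None:
        async_stream = AsyncEventStreamWrapper(event_stream, start_id)
        async for event in async_stream:
            # Handle each event while letting other tasks run between iterations.
            print(event)
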
@@ -82,6 +82,7 @@ class LLM(RetryMixin, DebugMixin):
             config: The LLM configuration.
             metrics: The metrics to use.
         """
+        self._tried_model_info = False
         self.metrics: Metrics = (
             metrics if metrics is not None else Metrics(model_name=config.model)
         )
@@ -91,56 +92,6 @@ class LLM(RetryMixin, DebugMixin):
         # litellm actually uses base Exception here for unknown model
         self.model_info: ModelInfo | None = None

-        try:
-            if self.config.model.startswith('openrouter'):
-                self.model_info = litellm.get_model_info(self.config.model)
-        except Exception as e:
-            logger.debug(f'Error getting model info: {e}')
-
-        if self.config.model.startswith('litellm_proxy/'):
-            # IF we are using LiteLLM proxy, get model info from LiteLLM proxy
-            # GET {base_url}/v1/model/info with litellm_model_id as path param
-            response = requests.get(
-                f'{self.config.base_url}/v1/model/info',
-                headers={'Authorization': f'Bearer {self.config.api_key}'},
-            )
-            resp_json = response.json()
-            if 'data' not in resp_json:
-                logger.error(
-                    f'Error getting model info from LiteLLM proxy: {resp_json}'
-                )
-            all_model_info = resp_json.get('data', [])
-            current_model_info = next(
-                (
-                    info
-                    for info in all_model_info
-                    if info['model_name']
-                    == self.config.model.removeprefix('litellm_proxy/')
-                ),
-                None,
-            )
-            if current_model_info:
-                self.model_info = current_model_info['model_info']
-
-        # Last two attempts to get model info from NAME
-        if not self.model_info:
-            try:
-                self.model_info = litellm.get_model_info(
-                    self.config.model.split(':')[0]
-                )
-            # noinspection PyBroadException
-            except Exception:
-                pass
-        if not self.model_info:
-            try:
-                self.model_info = litellm.get_model_info(
-                    self.config.model.split('/')[-1]
-                )
-            # noinspection PyBroadException
-            except Exception:
-                pass
-        logger.debug(f'Model info: {self.model_info}')
-
         if self.config.log_completions:
             if self.config.log_completions_folder is None:
                 raise RuntimeError(
@@ -148,32 +99,26 @@ class LLM(RetryMixin, DebugMixin):
                 )
             os.makedirs(self.config.log_completions_folder, exist_ok=True)

-        # Set the max tokens in an LM-specific way if not set
-        if self.config.max_input_tokens is None:
-            if (
-                self.model_info is not None
-                and 'max_input_tokens' in self.model_info
-                and isinstance(self.model_info['max_input_tokens'], int)
-            ):
-                self.config.max_input_tokens = self.model_info['max_input_tokens']
-            else:
-                # Safe fallback for any potentially viable model
-                self.config.max_input_tokens = 4096
+        self._completion = partial(
+            litellm_completion,
+            model=self.config.model,
+            api_key=self.config.api_key,
+            base_url=self.config.base_url,
+            api_version=self.config.api_version,
+            custom_llm_provider=self.config.custom_llm_provider,
+            max_tokens=self.config.max_output_tokens,
+            timeout=self.config.timeout,
+            temperature=self.config.temperature,
+            top_p=self.config.top_p,
+            drop_params=self.config.drop_params,
+        )

-        if self.config.max_output_tokens is None:
-            # Safe default for any potentially viable model
-            self.config.max_output_tokens = 4096
-            if self.model_info is not None:
-                # max_output_tokens has precedence over max_tokens, if either exists.
-                # litellm has models with both, one or none of these 2 parameters!
-                if 'max_output_tokens' in self.model_info and isinstance(
-                    self.model_info['max_output_tokens'], int
-                ):
-                    self.config.max_output_tokens = self.model_info['max_output_tokens']
-                elif 'max_tokens' in self.model_info and isinstance(
-                    self.model_info['max_tokens'], int
-                ):
-                    self.config.max_output_tokens = self.model_info['max_tokens']
+        if self.vision_is_active():
+            logger.debug('LLM: model has vision enabled')
+        if self.is_caching_prompt_active():
+            logger.debug('LLM: caching prompt enabled')
+        if self.is_function_calling_active():
+            logger.debug('LLM: model supports function calling')

         self._completion = partial(
             litellm_completion,
@@ -207,6 +152,7 @@ class LLM(RetryMixin, DebugMixin):
         )
         def wrapper(*args, **kwargs):
             """Wrapper for the litellm completion function. Logs the input and output of the completion function."""
+            self.init_model_info()
             messages: list[dict[str, Any]] | dict[str, Any] = []

             # some callers might send the model and messages directly
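
Note (illustrative, not part of the commit): calling init_model_info() inside the completion
wrapper, instead of doing the lookup in __init__, is a lazy-initialization pattern: the
potentially slow model-info fetch happens at most once, on the first completion call. A
self-contained sketch of the same guard, with hypothetical names:

    def expensive_lookup() -> dict:
        # Stand-in for a slow call such as a provider query or an HTTP request.
        return {'max_input_tokens': 4096}

    class LazyInfo:
        def __init__(self) -> None:
            self._tried = False
            self.info: dict | None = None

        def ensure_info(self) -> None:
            # The flag makes repeated calls cheap and idempotent.
            if self._tried:
                return
            self._tried = True
            self.info = expensive_lookup()
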
@@ -300,6 +246,87 @@ class LLM(RetryMixin, DebugMixin):
         """
         return self._completion

+    def init_model_info(self):
+        if self._tried_model_info:
+            return
+        self._tried_model_info = True
+        try:
+            if self.config.model.startswith('openrouter'):
+                self.model_info = litellm.get_model_info(self.config.model)
+        except Exception as e:
+            logger.debug(f'Error getting model info: {e}')
+
+        if self.config.model.startswith('litellm_proxy/'):
+            # IF we are using LiteLLM proxy, get model info from LiteLLM proxy
+            # GET {base_url}/v1/model/info with litellm_model_id as path param
+            response = requests.get(
+                f'{self.config.base_url}/v1/model/info',
+                headers={'Authorization': f'Bearer {self.config.api_key}'},
+            )
+            resp_json = response.json()
+            if 'data' not in resp_json:
+                logger.error(
+                    f'Error getting model info from LiteLLM proxy: {resp_json}'
+                )
+            all_model_info = resp_json.get('data', [])
+            current_model_info = next(
+                (
+                    info
+                    for info in all_model_info
+                    if info['model_name']
+                    == self.config.model.removeprefix('litellm_proxy/')
+                ),
+                None,
+            )
+            if current_model_info:
+                self.model_info = current_model_info['model_info']
+
+        # Last two attempts to get model info from NAME
+        if not self.model_info:
+            try:
+                self.model_info = litellm.get_model_info(
+                    self.config.model.split(':')[0]
+                )
+            # noinspection PyBroadException
+            except Exception:
+                pass
+        if not self.model_info:
+            try:
+                self.model_info = litellm.get_model_info(
+                    self.config.model.split('/')[-1]
+                )
+            # noinspection PyBroadException
+            except Exception:
+                pass
+        logger.debug(f'Model info: {self.model_info}')
+
+        # Set the max tokens in an LM-specific way if not set
+        if self.config.max_input_tokens is None:
+            if (
+                self.model_info is not None
+                and 'max_input_tokens' in self.model_info
+                and isinstance(self.model_info['max_input_tokens'], int)
+            ):
+                self.config.max_input_tokens = self.model_info['max_input_tokens']
+            else:
+                # Safe fallback for any potentially viable model
+                self.config.max_input_tokens = 4096
+
+        if self.config.max_output_tokens is None:
+            # Safe default for any potentially viable model
+            self.config.max_output_tokens = 4096
+            if self.model_info is not None:
+                # max_output_tokens has precedence over max_tokens, if either exists.
+                # litellm has models with both, one or none of these 2 parameters!
+                if 'max_output_tokens' in self.model_info and isinstance(
+                    self.model_info['max_output_tokens'], int
+                ):
+                    self.config.max_output_tokens = self.model_info['max_output_tokens']
+                elif 'max_tokens' in self.model_info and isinstance(
+                    self.model_info['max_tokens'], int
+                ):
+                    self.config.max_output_tokens = self.model_info['max_tokens']
+
     def vision_is_active(self):
         return not self.config.disable_vision and self._supports_vision()

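
Note (illustrative, not part of the commit): the token-limit logic at the end of
init_model_info() is a small precedence rule: an explicit config value wins, then the
provider-reported max_output_tokens, then max_tokens, then a 4096 default. A standalone
sketch of that rule (resolve_output_limit is not an OpenHands function):

    def resolve_output_limit(configured: int | None, model_info: dict | None) -> int:
        # Explicit configuration always wins.
        if configured is not None:
            return configured
        if model_info is not None:
            if isinstance(model_info.get('max_output_tokens'), int):
                return model_info['max_output_tokens']
            if isinstance(model_info.get('max_tokens'), int):
                return model_info['max_tokens']
        # Safe default for any potentially viable model.
        return 4096
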
@@ -240,7 +240,7 @@ class EventStreamRuntime(Runtime):

     @tenacity.retry(
         stop=tenacity.stop_after_attempt(5) | stop_if_should_exit(),
-        wait=tenacity.wait_exponential(multiplier=1, min=4, max=60),
+        wait=tenacity.wait_fixed(5),
     )
     def _init_container(self):
         try:
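
Note (illustrative, not part of the commit): the retry hunk above swaps exponential backoff
(capped at 60 seconds) for a flat 5-second wait between container-init attempts; the real
decorator also ORs in the project's stop_if_should_exit() condition, omitted here. A
self-contained tenacity sketch of the two wait policies:

    import tenacity

    @tenacity.retry(
        stop=tenacity.stop_after_attempt(5),
        wait=tenacity.wait_fixed(5),  # new policy: pause 5 seconds between attempts
        # Previous policy, for comparison:
        # wait=tenacity.wait_exponential(multiplier=1, min=4, max=60),
    )
    def wait_for_container() -> None:
        # Stand-in for the real health check; raising triggers another attempt.
        raise ConnectionError('container not ready yet')
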
@@ -57,6 +57,7 @@ from openhands.events.observation import (
     NullObservation,
 )
 from openhands.events.serialization import event_to_dict
+from openhands.events.stream import AsyncEventStreamWrapper
 from openhands.llm import bedrock
 from openhands.runtime.base import Runtime
 from openhands.server.auth import get_sid_from_token, sign_token
@@ -339,9 +340,12 @@ async def websocket_endpoint(websocket: WebSocket):
     latest_event_id = -1
     if websocket.query_params.get('latest_event_id'):
         latest_event_id = int(websocket.query_params.get('latest_event_id'))
-    for event in session.agent_session.event_stream.get_events(
-        start_id=latest_event_id + 1
-    ):
+
+    async_stream = AsyncEventStreamWrapper(
+        session.agent_session.event_stream, latest_event_id + 1
+    )
+
+    async for event in async_stream:
         if isinstance(
             event,
             (
@@ -665,9 +669,11 @@ async def submit_feedback(request: Request):
     # Assuming the storage service is already configured in the backend
     # and there is a function to handle the storage.
     body = await request.json()
-    events = request.state.conversation.event_stream.get_events(filter_hidden=True)
+    async_stream = AsyncEventStreamWrapper(
+        request.state.conversation.event_stream, filter_hidden=True
+    )
     trajectory = []
-    for event in events:
+    async for event in async_stream:
         trajectory.append(event_to_dict(event))
     feedback = FeedbackDataModel(
         email=body.get('email', ''),
@@ -678,7 +684,7 @@ async def submit_feedback(request: Request):
         trajectory=trajectory,
     )
     try:
-        feedback_data = store_feedback(feedback)
+        feedback_data = await call_sync_from_async(store_feedback, feedback)
         return JSONResponse(status_code=200, content=feedback_data)
     except Exception as e:
         logger.error(f'Error submitting feedback: {e}')
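
Note (illustrative, not part of the commit): call_sync_from_async comes from
openhands.utils.async_utils and is used so the blocking store_feedback call does not stall
the event loop; its exact implementation is not shown in this diff. A typical sketch of the
idea, offloading a blocking callable to the default thread-pool executor (run_blocking is a
hypothetical name):

    import asyncio
    import functools
    from typing import Any, Callable

    async def run_blocking(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
        # Not the OpenHands implementation; just the common run_in_executor pattern.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, functools.partial(fn, *args, **kwargs))
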
@@ -101,7 +101,6 @@ class AgentSession:
         agent_configs: dict[str, AgentConfig] | None = None,
         status_message_callback: Optional[Callable] = None,
     ):
-        self.loop = asyncio.get_running_loop()
         self._create_security_analyzer(config.security.security_analyzer)
         await self._create_runtime(
             runtime_name=runtime_name,
@@ -124,9 +123,14 @@ class AgentSession:
         self.controller.agent_task = self.controller.start_step_loop()
         await self.controller.agent_task  # type: ignore

-    async def close(self):
+    def close(self):
         """Closes the Agent session"""
+        self._closed = True
+        def inner_close():
+            asyncio.run(self._close())
+        asyncio.get_event_loop().run_in_executor(None, inner_close)

+    async def _close(self):
         if self._closed:
             return
         if self.controller is not None:
@@ -138,16 +142,6 @@ class AgentSession:
         if self.security_analyzer is not None:
             await self.security_analyzer.close()
-
-        if self.loop:
-            if self.loop.is_closed():
-                logger.debug(
-                    'Trying to close already closed loop. (It probably never started correctly)'
-                )
-            else:
-                self.loop.stop()
-        self.loop = None
-
-        self._closed = True
+
     def _create_security_analyzer(self, security_analyzer: str | None):
         """Creates a SecurityAnalyzer instance that will be used to analyze the agent actions

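
Note (illustrative, not part of the commit): AgentSession.close() is now synchronous and
hands the coroutine _close() off to a worker, where asyncio.run() gives it its own event
loop; the old code that stopped self.loop by hand is removed. A simplified standalone sketch
of that hand-off, using a plain thread instead of the loop's default executor:

    import asyncio
    import threading

    class Closable:
        async def _shutdown(self) -> None:
            await asyncio.sleep(0)  # async cleanup work would go here

        def close(self) -> None:
            # Run the coroutine on a fresh event loop in a background thread so that
            # synchronous callers can trigger async cleanup without awaiting it.
            threading.Thread(target=lambda: asyncio.run(self._shutdown())).start()
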
@@ -35,7 +35,7 @@ class SessionManager:

     def add_or_restart_session(self, sid: str, ws_conn: WebSocket) -> Session:
         if sid in self._sessions:
-            asyncio.create_task(self._sessions[sid].close())
+            self._sessions[sid].close()
         self._sessions[sid] = Session(
             sid=sid, file_store=self.file_store, ws=ws_conn, config=self.config
         )
@@ -47,7 +47,7 @@ class SessionManager:
         return self._sessions.get(sid)

     async def attach_to_conversation(self, sid: str) -> Conversation | None:
-        if not session_exists(sid, self.file_store):
+        if not await session_exists(sid, self.file_store):
             return None
         c = Conversation(sid, file_store=self.file_store, config=self.config)
         await c.connect()
@@ -87,7 +87,7 @@ class SessionManager:
         for sid in session_ids_to_remove:
             to_del_session: Session | None = self._sessions.pop(sid, None)
             if to_del_session is not None:
-                await to_del_session.close()
+                to_del_session.close()
                 logger.debug(
                     f'Session {sid} and related resource have been removed due to inactivity.'
                 )
@@ -47,9 +47,9 @@ class Session:
         self.config = config
         self.loop = asyncio.get_event_loop()

-    async def close(self):
+    def close(self):
         self.is_alive = False
-        await self.agent_session.close()
+        self.agent_session.close()

     async def loop_recv(self):
         try:
@@ -63,11 +63,11 @@ class Session:
                     continue
                 await self.dispatch(data)
         except WebSocketDisconnect:
-            await self.close()
-            logger.debug('WebSocket disconnected, sid: %s', self.sid)
+            logger.info('WebSocket disconnected, sid: %s', self.sid)
+            self.close()
         except RuntimeError as e:
-            await self.close()
             logger.exception('Error in loop_recv: %s', e)
+            self.close()

     async def _initialize_agent(self, data: dict):
         self.agent_session.event_stream.add_event(
@@ -171,10 +171,12 @@ class Session:
                 'Model does not support image upload, change to a different model or try without an image.'
             )
             return
-        if self.agent_session.loop:
+        if self.loop:
             asyncio.run_coroutine_threadsafe(
-                self._add_event(event, EventSource.USER), self.agent_session.loop
+                self._add_event(event, EventSource.USER), self.loop
             )  # type: ignore
+        else:
+            raise RuntimeError('No event loop found')

     async def _add_event(self, event, event_source):
         self.agent_session.event_stream.add_event(event, EventSource.USER)
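
Note (illustrative, not part of the commit): asyncio.run_coroutine_threadsafe() is the
standard way to hand a coroutine to an event loop running in another thread, which is what
the dispatch code above does with the session's loop; it returns a concurrent.futures.Future.
A self-contained sketch with hypothetical names:

    import asyncio
    import threading

    loop = asyncio.new_event_loop()
    threading.Thread(target=loop.run_forever, daemon=True).start()

    async def handle(payload: str) -> str:
        # Stand-in for work that must run on the loop's own thread.
        return f'handled {payload}'

    future = asyncio.run_coroutine_threadsafe(handle('user message'), loop)
    print(future.result(timeout=5))  # -> handled user message
    loop.call_soon_threadsafe(loop.stop)
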
@@ -50,6 +50,7 @@ def test_llm_init_with_model_info(mock_get_model_info, default_config):
         'max_output_tokens': 2000,
     }
     llm = LLM(default_config)
+    llm.init_model_info()
     assert llm.config.max_input_tokens == 8000
     assert llm.config.max_output_tokens == 2000

@@ -58,6 +59,7 @@ def test_llm_init_with_model_info(mock_get_model_info, default_config):
 def test_llm_init_without_model_info(mock_get_model_info, default_config):
     mock_get_model_info.side_effect = Exception('Model info not available')
     llm = LLM(default_config)
+    llm.init_model_info()
     assert llm.config.max_input_tokens == 4096
     assert llm.config.max_output_tokens == 4096

@@ -108,6 +110,7 @@ def test_llm_init_with_openrouter_model(mock_get_model_info, default_config):
         'max_output_tokens': 1500,
     }
     llm = LLM(default_config)
+    llm.init_model_info()
    assert llm.config.max_input_tokens == 7000
    assert llm.config.max_output_tokens == 1500
    mock_get_model_info.assert_called_once_with('openrouter:gpt-4o-mini')