Various async fixes (#4722)

Robert Brennan 2024-11-04 07:08:09 -08:00 committed by GitHub
parent 0595d2336a
commit 250fcbe62c
8 changed files with 154 additions and 106 deletions

View File

@@ -11,6 +11,7 @@ from openhands.events.event import Event, EventSource
from openhands.events.serialization.event import event_from_dict, event_to_dict
from openhands.runtime.utils.shutdown_listener import should_continue
from openhands.storage import FileStore
from openhands.utils.async_utils import call_sync_from_async
class EventStreamSubscriber(str, Enum):
@@ -22,14 +23,29 @@ class EventStreamSubscriber(str, Enum):
TEST = 'test'
def session_exists(sid: str, file_store: FileStore) -> bool:
async def session_exists(sid: str, file_store: FileStore) -> bool:
try:
file_store.list(f'sessions/{sid}')
await call_sync_from_async(file_store.list, f'sessions/{sid}')
return True
except FileNotFoundError:
return False
class AsyncEventStreamWrapper:
def __init__(self, event_stream, *args, **kwargs):
self.event_stream = event_stream
self.args = args
self.kwargs = kwargs
async def __aiter__(self):
loop = asyncio.get_running_loop()
# Create an async generator that yields events
for event in self.event_stream.get_events(*self.args, **self.kwargs):
# Hop through the executor before yielding each event so the loop can service other tasks between events
yield await loop.run_in_executor(None, lambda e=event: e) # type: ignore
@dataclass
class EventStream:
sid: str
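
For context, call_sync_from_async (imported above from openhands.utils.async_utils) offloads a blocking call onto a worker thread so coroutines like session_exists do not stall the event loop. A minimal sketch of one plausible shape for such a helper, assuming it simply defers to the default executor; the project's actual implementation may differ:

import asyncio
from functools import partial

async def call_sync_from_async(fn, *args, **kwargs):
    # Hand the blocking callable to the default thread-pool executor so the
    # running loop keeps servicing other tasks while it executes.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, partial(fn, *args, **kwargs))

# e.g. files = await call_sync_from_async(file_store.list, f'sessions/{sid}')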

View File

@@ -82,6 +82,7 @@ class LLM(RetryMixin, DebugMixin):
config: The LLM configuration.
metrics: The metrics to use.
"""
self._tried_model_info = False
self.metrics: Metrics = (
metrics if metrics is not None else Metrics(model_name=config.model)
)
@@ -91,56 +92,6 @@ class LLM(RetryMixin, DebugMixin):
# litellm actually uses base Exception here for unknown model
self.model_info: ModelInfo | None = None
try:
if self.config.model.startswith('openrouter'):
self.model_info = litellm.get_model_info(self.config.model)
except Exception as e:
logger.debug(f'Error getting model info: {e}')
if self.config.model.startswith('litellm_proxy/'):
# IF we are using LiteLLM proxy, get model info from LiteLLM proxy
# GET {base_url}/v1/model/info with litellm_model_id as path param
response = requests.get(
f'{self.config.base_url}/v1/model/info',
headers={'Authorization': f'Bearer {self.config.api_key}'},
)
resp_json = response.json()
if 'data' not in resp_json:
logger.error(
f'Error getting model info from LiteLLM proxy: {resp_json}'
)
all_model_info = resp_json.get('data', [])
current_model_info = next(
(
info
for info in all_model_info
if info['model_name']
== self.config.model.removeprefix('litellm_proxy/')
),
None,
)
if current_model_info:
self.model_info = current_model_info['model_info']
# Last two attempts to get model info from NAME
if not self.model_info:
try:
self.model_info = litellm.get_model_info(
self.config.model.split(':')[0]
)
# noinspection PyBroadException
except Exception:
pass
if not self.model_info:
try:
self.model_info = litellm.get_model_info(
self.config.model.split('/')[-1]
)
# noinspection PyBroadException
except Exception:
pass
logger.debug(f'Model info: {self.model_info}')
if self.config.log_completions:
if self.config.log_completions_folder is None:
raise RuntimeError(
@@ -148,32 +99,26 @@ class LLM(RetryMixin, DebugMixin):
)
os.makedirs(self.config.log_completions_folder, exist_ok=True)
# Set the max tokens in an LM-specific way if not set
if self.config.max_input_tokens is None:
if (
self.model_info is not None
and 'max_input_tokens' in self.model_info
and isinstance(self.model_info['max_input_tokens'], int)
):
self.config.max_input_tokens = self.model_info['max_input_tokens']
else:
# Safe fallback for any potentially viable model
self.config.max_input_tokens = 4096
self._completion = partial(
litellm_completion,
model=self.config.model,
api_key=self.config.api_key,
base_url=self.config.base_url,
api_version=self.config.api_version,
custom_llm_provider=self.config.custom_llm_provider,
max_tokens=self.config.max_output_tokens,
timeout=self.config.timeout,
temperature=self.config.temperature,
top_p=self.config.top_p,
drop_params=self.config.drop_params,
)
if self.config.max_output_tokens is None:
# Safe default for any potentially viable model
self.config.max_output_tokens = 4096
if self.model_info is not None:
# max_output_tokens has precedence over max_tokens, if either exists.
# litellm has models with both, one or none of these 2 parameters!
if 'max_output_tokens' in self.model_info and isinstance(
self.model_info['max_output_tokens'], int
):
self.config.max_output_tokens = self.model_info['max_output_tokens']
elif 'max_tokens' in self.model_info and isinstance(
self.model_info['max_tokens'], int
):
self.config.max_output_tokens = self.model_info['max_tokens']
if self.vision_is_active():
logger.debug('LLM: model has vision enabled')
if self.is_caching_prompt_active():
logger.debug('LLM: caching prompt enabled')
if self.is_function_calling_active():
logger.debug('LLM: model supports function calling')
self._completion = partial(
litellm_completion,
@@ -207,6 +152,7 @@ class LLM(RetryMixin, DebugMixin):
)
def wrapper(*args, **kwargs):
"""Wrapper for the litellm completion function. Logs the input and output of the completion function."""
self.init_model_info()
messages: list[dict[str, Any]] | dict[str, Any] = []
# some callers might send the model and messages directly
@@ -300,6 +246,87 @@ class LLM(RetryMixin, DebugMixin):
"""
return self._completion
def init_model_info(self):
if self._tried_model_info:
return
self._tried_model_info = True
try:
if self.config.model.startswith('openrouter'):
self.model_info = litellm.get_model_info(self.config.model)
except Exception as e:
logger.debug(f'Error getting model info: {e}')
if self.config.model.startswith('litellm_proxy/'):
# IF we are using LiteLLM proxy, get model info from LiteLLM proxy
# GET {base_url}/v1/model/info with litellm_model_id as path param
response = requests.get(
f'{self.config.base_url}/v1/model/info',
headers={'Authorization': f'Bearer {self.config.api_key}'},
)
resp_json = response.json()
if 'data' not in resp_json:
logger.error(
f'Error getting model info from LiteLLM proxy: {resp_json}'
)
all_model_info = resp_json.get('data', [])
current_model_info = next(
(
info
for info in all_model_info
if info['model_name']
== self.config.model.removeprefix('litellm_proxy/')
),
None,
)
if current_model_info:
self.model_info = current_model_info['model_info']
# Last two attempts to get model info from NAME
if not self.model_info:
try:
self.model_info = litellm.get_model_info(
self.config.model.split(':')[0]
)
# noinspection PyBroadException
except Exception:
pass
if not self.model_info:
try:
self.model_info = litellm.get_model_info(
self.config.model.split('/')[-1]
)
# noinspection PyBroadException
except Exception:
pass
logger.debug(f'Model info: {self.model_info}')
# Set the max tokens in an LM-specific way if not set
if self.config.max_input_tokens is None:
if (
self.model_info is not None
and 'max_input_tokens' in self.model_info
and isinstance(self.model_info['max_input_tokens'], int)
):
self.config.max_input_tokens = self.model_info['max_input_tokens']
else:
# Safe fallback for any potentially viable model
self.config.max_input_tokens = 4096
if self.config.max_output_tokens is None:
# Safe default for any potentially viable model
self.config.max_output_tokens = 4096
if self.model_info is not None:
# max_output_tokens has precedence over max_tokens, if either exists.
# litellm has models with both, one or none of these 2 parameters!
if 'max_output_tokens' in self.model_info and isinstance(
self.model_info['max_output_tokens'], int
):
self.config.max_output_tokens = self.model_info['max_output_tokens']
elif 'max_tokens' in self.model_info and isinstance(
self.model_info['max_tokens'], int
):
self.config.max_output_tokens = self.model_info['max_tokens']
def vision_is_active(self):
return not self.config.disable_vision and self._supports_vision()
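
The model-info lookup above now runs lazily: __init__ only sets _tried_model_info = False, and the first call to init_model_info() (invoked from the completion wrapper earlier in this file) performs the lookup exactly once, falling back to the 4096-token defaults on failure. A minimal sketch of that run-once guard, with illustrative names rather than the project's exact code:

class LazyModelInfo:
    def __init__(self, lookup):
        self._tried_model_info = False   # mirrors LLM._tried_model_info
        self._lookup = lookup            # hypothetical callable standing in for litellm.get_model_info
        self.model_info = None

    def init_model_info(self):
        if self._tried_model_info:       # at most one (possibly slow) lookup
            return
        self._tried_model_info = True
        try:
            self.model_info = self._lookup()
        except Exception:
            self.model_info = None       # keep the safe 4096-token fallbacks

lazy = LazyModelInfo(lambda: {'max_input_tokens': 8000})
lazy.init_model_info()
lazy.init_model_info()                   # second call returns immediately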

View File

@@ -240,7 +240,7 @@ class EventStreamRuntime(Runtime):
@tenacity.retry(
stop=tenacity.stop_after_attempt(5) | stop_if_should_exit(),
wait=tenacity.wait_exponential(multiplier=1, min=4, max=60),
wait=tenacity.wait_fixed(5),
)
def _init_container(self):
try:
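
The container-init retry policy switches from exponential backoff to a flat five-second pause between the five attempts. A small sketch of the tenacity pattern in isolation (the project-specific stop_if_should_exit() condition is omitted here):

import tenacity

attempts = {'count': 0}

@tenacity.retry(
    stop=tenacity.stop_after_attempt(5),
    wait=tenacity.wait_fixed(5),   # constant 5 s between attempts, no backoff
)
def init_container():
    attempts['count'] += 1
    if attempts['count'] < 3:
        raise RuntimeError('container not ready yet')   # retried automatically
    return 'ready'

# Calling init_container() would succeed on the third attempt, roughly 10 s in.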

View File

@@ -57,6 +57,7 @@ from openhands.events.observation import (
NullObservation,
)
from openhands.events.serialization import event_to_dict
from openhands.events.stream import AsyncEventStreamWrapper
from openhands.llm import bedrock
from openhands.runtime.base import Runtime
from openhands.server.auth import get_sid_from_token, sign_token
@@ -339,9 +340,12 @@ async def websocket_endpoint(websocket: WebSocket):
latest_event_id = -1
if websocket.query_params.get('latest_event_id'):
latest_event_id = int(websocket.query_params.get('latest_event_id'))
for event in session.agent_session.event_stream.get_events(
start_id=latest_event_id + 1
):
async_stream = AsyncEventStreamWrapper(
session.agent_session.event_stream, latest_event_id + 1
)
async for event in async_stream:
if isinstance(
event,
(
@@ -665,9 +669,11 @@ async def submit_feedback(request: Request):
# Assuming the storage service is already configured in the backend
# and there is a function to handle the storage.
body = await request.json()
events = request.state.conversation.event_stream.get_events(filter_hidden=True)
async_stream = AsyncEventStreamWrapper(
request.state.conversation.event_stream, filter_hidden=True
)
trajectory = []
for event in events:
async for event in async_stream:
trajectory.append(event_to_dict(event))
feedback = FeedbackDataModel(
email=body.get('email', ''),
@@ -678,7 +684,7 @@ async def submit_feedback(request: Request):
trajectory=trajectory,
)
try:
feedback_data = store_feedback(feedback)
feedback_data = await call_sync_from_async(store_feedback, feedback)
return JSONResponse(status_code=200, content=feedback_data)
except Exception as e:
logger.error(f'Error submitting feedback: {e}')
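
Both endpoints above now iterate the event stream through AsyncEventStreamWrapper instead of looping over get_events() directly, so control returns to the event loop between events. A simplified, self-contained illustration of the idea, using asyncio.sleep(0) as a stand-in for the wrapper's executor hop:

import asyncio

def get_events():                    # stand-in for the synchronous event generator
    yield from ('start', 'message', 'finish')

class AsyncWrapper:
    def __init__(self, gen):
        self.gen = gen

    async def __aiter__(self):
        for item in self.gen:
            await asyncio.sleep(0)   # cooperative yield so other tasks can run
            yield item

async def main():
    async for event in AsyncWrapper(get_events()):
        print('event:', event)

asyncio.run(main())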

View File

@@ -101,7 +101,6 @@ class AgentSession:
agent_configs: dict[str, AgentConfig] | None = None,
status_message_callback: Optional[Callable] = None,
):
self.loop = asyncio.get_running_loop()
self._create_security_analyzer(config.security.security_analyzer)
await self._create_runtime(
runtime_name=runtime_name,
@@ -124,9 +123,14 @@ class AgentSession:
self.controller.agent_task = self.controller.start_step_loop()
await self.controller.agent_task # type: ignore
async def close(self):
def close(self):
"""Closes the Agent session"""
self._closed = True
def inner_close():
asyncio.run(self._close())
asyncio.get_event_loop().run_in_executor(None, inner_close)
async def _close(self):
if self._closed:
return
if self.controller is not None:
@@ -138,16 +142,6 @@ class AgentSession:
if self.security_analyzer is not None:
await self.security_analyzer.close()
if self.loop:
if self.loop.is_closed():
logger.debug(
'Trying to close already closed loop. (It probably never started correctly)'
)
else:
self.loop.stop()
self.loop = None
self._closed = True
def _create_security_analyzer(self, security_analyzer: str | None):
"""Creates a SecurityAnalyzer instance that will be used to analyze the agent actions

View File

@@ -35,7 +35,7 @@ class SessionManager:
def add_or_restart_session(self, sid: str, ws_conn: WebSocket) -> Session:
if sid in self._sessions:
asyncio.create_task(self._sessions[sid].close())
self._sessions[sid].close()
self._sessions[sid] = Session(
sid=sid, file_store=self.file_store, ws=ws_conn, config=self.config
)
@@ -47,7 +47,7 @@ class SessionManager:
return self._sessions.get(sid)
async def attach_to_conversation(self, sid: str) -> Conversation | None:
if not session_exists(sid, self.file_store):
if not await session_exists(sid, self.file_store):
return None
c = Conversation(sid, file_store=self.file_store, config=self.config)
await c.connect()
@@ -87,7 +87,7 @@ class SessionManager:
for sid in session_ids_to_remove:
to_del_session: Session | None = self._sessions.pop(sid, None)
if to_del_session is not None:
await to_del_session.close()
to_del_session.close()
logger.debug(
f'Session {sid} and related resource have been removed due to inactivity.'
)
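
Since session_exists is now a coroutine, the manager must await it; calling it bare would hand back an un-run coroutine object, which is always truthy. A small illustration of that failure mode (illustrative code, not from the project):

import asyncio

async def session_exists(sid):
    return sid == 'known-session'

async def main():
    assert await session_exists('known-session')   # correct: awaited
    coro = session_exists('missing-session')        # bug: not awaited
    print(bool(coro))                               # True even though the session is missing
    coro.close()                                    # suppress the "never awaited" warning

asyncio.run(main())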

View File

@@ -47,9 +47,9 @@ class Session:
self.config = config
self.loop = asyncio.get_event_loop()
async def close(self):
def close(self):
self.is_alive = False
await self.agent_session.close()
self.agent_session.close()
async def loop_recv(self):
try:
@@ -63,11 +63,11 @@ class Session:
continue
await self.dispatch(data)
except WebSocketDisconnect:
await self.close()
logger.debug('WebSocket disconnected, sid: %s', self.sid)
logger.info('WebSocket disconnected, sid: %s', self.sid)
self.close()
except RuntimeError as e:
await self.close()
logger.exception('Error in loop_recv: %s', e)
self.close()
async def _initialize_agent(self, data: dict):
self.agent_session.event_stream.add_event(
@@ -171,10 +171,12 @@ class Session:
'Model does not support image upload, change to a different model or try without an image.'
)
return
if self.agent_session.loop:
if self.loop:
asyncio.run_coroutine_threadsafe(
self._add_event(event, EventSource.USER), self.agent_session.loop
self._add_event(event, EventSource.USER), self.loop
) # type: ignore
else:
raise RuntimeError('No event loop found')
async def _add_event(self, event, event_source):
self.agent_session.event_stream.add_event(event, EventSource.USER)
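
dispatch() now targets the loop captured by the Session itself (self.loop) rather than one owned by the agent session. asyncio.run_coroutine_threadsafe is used because it safely schedules work onto a loop that may be running in another thread; a standalone sketch of that mechanism:

import asyncio
import threading

loop = asyncio.new_event_loop()
threading.Thread(target=loop.run_forever, daemon=True).start()

async def add_event(event):
    print('handled:', event)

# Thread-safe: schedules the coroutine on `loop` and returns a concurrent.futures.Future.
future = asyncio.run_coroutine_threadsafe(add_event('user_message'), loop)
future.result(timeout=5)
loop.call_soon_threadsafe(loop.stop)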

View File

@@ -50,6 +50,7 @@ def test_llm_init_with_model_info(mock_get_model_info, default_config):
'max_output_tokens': 2000,
}
llm = LLM(default_config)
llm.init_model_info()
assert llm.config.max_input_tokens == 8000
assert llm.config.max_output_tokens == 2000
@@ -58,6 +59,7 @@ def test_llm_init_without_model_info(mock_get_model_info, default_config):
def test_llm_init_without_model_info(mock_get_model_info, default_config):
mock_get_model_info.side_effect = Exception('Model info not available')
llm = LLM(default_config)
llm.init_model_info()
assert llm.config.max_input_tokens == 4096
assert llm.config.max_output_tokens == 4096
@@ -108,6 +110,7 @@ def test_llm_init_with_openrouter_model(mock_get_model_info, default_config):
'max_output_tokens': 1500,
}
llm = LLM(default_config)
llm.init_model_info()
assert llm.config.max_input_tokens == 7000
assert llm.config.max_output_tokens == 1500
mock_get_model_info.assert_called_once_with('openrouter:gpt-4o-mini')