From 250fcbe62c340525c556eacb4b1a3086efecbec3 Mon Sep 17 00:00:00 2001
From: Robert Brennan
Date: Mon, 4 Nov 2024 07:08:09 -0800
Subject: [PATCH] Various async fixes (#4722)

---
 openhands/events/stream.py                    |  20 +-
 openhands/llm/llm.py                          | 177 ++++++++++--------
 .../impl/eventstream/eventstream_runtime.py   |   2 +-
 openhands/server/listen.py                    |  18 +-
 openhands/server/session/agent_session.py     |  18 +-
 openhands/server/session/manager.py           |   6 +-
 openhands/server/session/session.py           |  16 +-
 tests/unit/test_llm.py                        |   3 +
 8 files changed, 154 insertions(+), 106 deletions(-)

diff --git a/openhands/events/stream.py b/openhands/events/stream.py
index aafbcc2fc8..c2a335c3b5 100644
--- a/openhands/events/stream.py
+++ b/openhands/events/stream.py
@@ -11,6 +11,7 @@ from openhands.events.event import Event, EventSource
 from openhands.events.serialization.event import event_from_dict, event_to_dict
 from openhands.runtime.utils.shutdown_listener import should_continue
 from openhands.storage import FileStore
+from openhands.utils.async_utils import call_sync_from_async
 
 
 class EventStreamSubscriber(str, Enum):
@@ -22,14 +23,29 @@ class EventStreamSubscriber(str, Enum):
     TEST = 'test'
 
 
-def session_exists(sid: str, file_store: FileStore) -> bool:
+async def session_exists(sid: str, file_store: FileStore) -> bool:
     try:
-        file_store.list(f'sessions/{sid}')
+        await call_sync_from_async(file_store.list, f'sessions/{sid}')
         return True
     except FileNotFoundError:
         return False
 
 
+class AsyncEventStreamWrapper:
+    def __init__(self, event_stream, *args, **kwargs):
+        self.event_stream = event_stream
+        self.args = args
+        self.kwargs = kwargs
+
+    async def __aiter__(self):
+        loop = asyncio.get_running_loop()
+
+        # Create an async generator that yields events
+        for event in self.event_stream.get_events(*self.args, **self.kwargs):
+            # The executor hop lets other tasks on the loop run between events
+            yield await loop.run_in_executor(None, lambda e=event: e)  # type: ignore
+
+
 @dataclass
 class EventStream:
     sid: str
diff --git a/openhands/llm/llm.py b/openhands/llm/llm.py
index ab56418e8a..f431fff555 100644
--- a/openhands/llm/llm.py
+++ b/openhands/llm/llm.py
@@ -82,6 +82,7 @@ class LLM(RetryMixin, DebugMixin):
             config: The LLM configuration.
             metrics: The metrics to use.
""" + self._tried_model_info = False self.metrics: Metrics = ( metrics if metrics is not None else Metrics(model_name=config.model) ) @@ -91,56 +92,6 @@ class LLM(RetryMixin, DebugMixin): # litellm actually uses base Exception here for unknown model self.model_info: ModelInfo | None = None - try: - if self.config.model.startswith('openrouter'): - self.model_info = litellm.get_model_info(self.config.model) - except Exception as e: - logger.debug(f'Error getting model info: {e}') - - if self.config.model.startswith('litellm_proxy/'): - # IF we are using LiteLLM proxy, get model info from LiteLLM proxy - # GET {base_url}/v1/model/info with litellm_model_id as path param - response = requests.get( - f'{self.config.base_url}/v1/model/info', - headers={'Authorization': f'Bearer {self.config.api_key}'}, - ) - resp_json = response.json() - if 'data' not in resp_json: - logger.error( - f'Error getting model info from LiteLLM proxy: {resp_json}' - ) - all_model_info = resp_json.get('data', []) - current_model_info = next( - ( - info - for info in all_model_info - if info['model_name'] - == self.config.model.removeprefix('litellm_proxy/') - ), - None, - ) - if current_model_info: - self.model_info = current_model_info['model_info'] - - # Last two attempts to get model info from NAME - if not self.model_info: - try: - self.model_info = litellm.get_model_info( - self.config.model.split(':')[0] - ) - # noinspection PyBroadException - except Exception: - pass - if not self.model_info: - try: - self.model_info = litellm.get_model_info( - self.config.model.split('/')[-1] - ) - # noinspection PyBroadException - except Exception: - pass - logger.debug(f'Model info: {self.model_info}') - if self.config.log_completions: if self.config.log_completions_folder is None: raise RuntimeError( @@ -148,32 +99,26 @@ class LLM(RetryMixin, DebugMixin): ) os.makedirs(self.config.log_completions_folder, exist_ok=True) - # Set the max tokens in an LM-specific way if not set - if self.config.max_input_tokens is None: - if ( - self.model_info is not None - and 'max_input_tokens' in self.model_info - and isinstance(self.model_info['max_input_tokens'], int) - ): - self.config.max_input_tokens = self.model_info['max_input_tokens'] - else: - # Safe fallback for any potentially viable model - self.config.max_input_tokens = 4096 + self._completion = partial( + litellm_completion, + model=self.config.model, + api_key=self.config.api_key, + base_url=self.config.base_url, + api_version=self.config.api_version, + custom_llm_provider=self.config.custom_llm_provider, + max_tokens=self.config.max_output_tokens, + timeout=self.config.timeout, + temperature=self.config.temperature, + top_p=self.config.top_p, + drop_params=self.config.drop_params, + ) - if self.config.max_output_tokens is None: - # Safe default for any potentially viable model - self.config.max_output_tokens = 4096 - if self.model_info is not None: - # max_output_tokens has precedence over max_tokens, if either exists. - # litellm has models with both, one or none of these 2 parameters! 
-                if 'max_output_tokens' in self.model_info and isinstance(
-                    self.model_info['max_output_tokens'], int
-                ):
-                    self.config.max_output_tokens = self.model_info['max_output_tokens']
-                elif 'max_tokens' in self.model_info and isinstance(
-                    self.model_info['max_tokens'], int
-                ):
-                    self.config.max_output_tokens = self.model_info['max_tokens']
+        if self.vision_is_active():
+            logger.debug('LLM: model has vision enabled')
+        if self.is_caching_prompt_active():
+            logger.debug('LLM: caching prompt enabled')
+        if self.is_function_calling_active():
+            logger.debug('LLM: model supports function calling')
 
         self._completion = partial(
             litellm_completion,
@@ -207,6 +152,7 @@ class LLM(RetryMixin, DebugMixin):
         )
         def wrapper(*args, **kwargs):
             """Wrapper for the litellm completion function. Logs the input and output of the completion function."""
+            self.init_model_info()
             messages: list[dict[str, Any]] | dict[str, Any] = []
 
             # some callers might send the model and messages directly
@@ -300,6 +246,87 @@ class LLM(RetryMixin, DebugMixin):
         """
         return self._completion
 
+    def init_model_info(self):
+        if self._tried_model_info:
+            return
+        self._tried_model_info = True
+        try:
+            if self.config.model.startswith('openrouter'):
+                self.model_info = litellm.get_model_info(self.config.model)
+        except Exception as e:
+            logger.debug(f'Error getting model info: {e}')
+
+        if self.config.model.startswith('litellm_proxy/'):
+            # If we are using a LiteLLM proxy, get model info from the proxy:
+            # GET {base_url}/v1/model/info and pick out our model from the response
+            response = requests.get(
+                f'{self.config.base_url}/v1/model/info',
+                headers={'Authorization': f'Bearer {self.config.api_key}'},
+            )
+            resp_json = response.json()
+            if 'data' not in resp_json:
+                logger.error(
+                    f'Error getting model info from LiteLLM proxy: {resp_json}'
+                )
+            all_model_info = resp_json.get('data', [])
+            current_model_info = next(
+                (
+                    info
+                    for info in all_model_info
+                    if info['model_name']
+                    == self.config.model.removeprefix('litellm_proxy/')
+                ),
+                None,
+            )
+            if current_model_info:
+                self.model_info = current_model_info['model_info']
+
+        # Last two attempts to get model info from the model name
+        if not self.model_info:
+            try:
+                self.model_info = litellm.get_model_info(
+                    self.config.model.split(':')[0]
+                )
+            # noinspection PyBroadException
+            except Exception:
+                pass
+        if not self.model_info:
+            try:
+                self.model_info = litellm.get_model_info(
+                    self.config.model.split('/')[-1]
+                )
+            # noinspection PyBroadException
+            except Exception:
+                pass
+        logger.debug(f'Model info: {self.model_info}')
+
+        # Set the max tokens in an LM-specific way if not set
+        if self.config.max_input_tokens is None:
+            if (
+                self.model_info is not None
+                and 'max_input_tokens' in self.model_info
+                and isinstance(self.model_info['max_input_tokens'], int)
+            ):
+                self.config.max_input_tokens = self.model_info['max_input_tokens']
+            else:
+                # Safe fallback for any potentially viable model
+                self.config.max_input_tokens = 4096
+
+        if self.config.max_output_tokens is None:
+            # Safe default for any potentially viable model
+            self.config.max_output_tokens = 4096
+            if self.model_info is not None:
+                # max_output_tokens has precedence over max_tokens, if either exists.
+                # litellm has models with both, one, or neither of these two parameters!
+                if 'max_output_tokens' in self.model_info and isinstance(
+                    self.model_info['max_output_tokens'], int
+                ):
+                    self.config.max_output_tokens = self.model_info['max_output_tokens']
+                elif 'max_tokens' in self.model_info and isinstance(
+                    self.model_info['max_tokens'], int
+                ):
+                    self.config.max_output_tokens = self.model_info['max_tokens']
+
     def vision_is_active(self):
         return not self.config.disable_vision and self._supports_vision()
 
diff --git a/openhands/runtime/impl/eventstream/eventstream_runtime.py b/openhands/runtime/impl/eventstream/eventstream_runtime.py
index e76f258bac..e4d5f63067 100644
--- a/openhands/runtime/impl/eventstream/eventstream_runtime.py
+++ b/openhands/runtime/impl/eventstream/eventstream_runtime.py
@@ -240,7 +240,7 @@ class EventStreamRuntime(Runtime):
 
     @tenacity.retry(
         stop=tenacity.stop_after_attempt(5) | stop_if_should_exit(),
-        wait=tenacity.wait_exponential(multiplier=1, min=4, max=60),
+        wait=tenacity.wait_fixed(5),
     )
     def _init_container(self):
         try:
diff --git a/openhands/server/listen.py b/openhands/server/listen.py
index 5385f0fd26..7ccc704659 100644
--- a/openhands/server/listen.py
+++ b/openhands/server/listen.py
@@ -57,6 +57,7 @@ from openhands.events.observation import (
     NullObservation,
 )
 from openhands.events.serialization import event_to_dict
+from openhands.events.stream import AsyncEventStreamWrapper
 from openhands.llm import bedrock
 from openhands.runtime.base import Runtime
 from openhands.server.auth import get_sid_from_token, sign_token
@@ -339,9 +340,12 @@ async def websocket_endpoint(websocket: WebSocket):
     latest_event_id = -1
     if websocket.query_params.get('latest_event_id'):
         latest_event_id = int(websocket.query_params.get('latest_event_id'))
-    for event in session.agent_session.event_stream.get_events(
-        start_id=latest_event_id + 1
-    ):
+
+    async_stream = AsyncEventStreamWrapper(
+        session.agent_session.event_stream, latest_event_id + 1
+    )
+
+    async for event in async_stream:
         if isinstance(
             event,
             (
@@ -665,9 +669,11 @@ async def submit_feedback(request: Request):
     # Assuming the storage service is already configured in the backend
     # and there is a function to handle the storage.
    body = await request.json()
-    events = request.state.conversation.event_stream.get_events(filter_hidden=True)
+    async_stream = AsyncEventStreamWrapper(
+        request.state.conversation.event_stream, filter_hidden=True
+    )
     trajectory = []
-    for event in events:
+    async for event in async_stream:
         trajectory.append(event_to_dict(event))
     feedback = FeedbackDataModel(
         email=body.get('email', ''),
@@ -678,7 +684,7 @@ async def submit_feedback(request: Request):
         trajectory=trajectory,
     )
     try:
-        feedback_data = store_feedback(feedback)
+        feedback_data = await call_sync_from_async(store_feedback, feedback)
         return JSONResponse(status_code=200, content=feedback_data)
     except Exception as e:
         logger.error(f'Error submitting feedback: {e}')
diff --git a/openhands/server/session/agent_session.py b/openhands/server/session/agent_session.py
index 70c88f13ba..cb90dd3984 100644
--- a/openhands/server/session/agent_session.py
+++ b/openhands/server/session/agent_session.py
@@ -101,7 +101,6 @@ class AgentSession:
         agent_configs: dict[str, AgentConfig] | None = None,
         status_message_callback: Optional[Callable] = None,
     ):
-        self.loop = asyncio.get_running_loop()
         self._create_security_analyzer(config.security.security_analyzer)
         await self._create_runtime(
             runtime_name=runtime_name,
@@ -124,9 +123,14 @@ class AgentSession:
         self.controller.agent_task = self.controller.start_step_loop()
         await self.controller.agent_task  # type: ignore
 
-    async def close(self):
+    def close(self):
         """Closes the Agent session"""
         if self._closed:
             return
+        self._closed = True
+        def inner_close():
+            asyncio.run(self._close())
+        asyncio.get_event_loop().run_in_executor(None, inner_close)
 
+    async def _close(self):
         if self.controller is not None:
@@ -138,16 +142,6 @@ class AgentSession:
         if self.security_analyzer is not None:
             await self.security_analyzer.close()
 
-        if self.loop:
-            if self.loop.is_closed():
-                logger.debug(
-                    'Trying to close already closed loop. (It probably never started correctly)'
-                )
-            else:
-                self.loop.stop()
-            self.loop = None
-
-        self._closed = True
 
     def _create_security_analyzer(self, security_analyzer: str | None):
         """Creates a SecurityAnalyzer instance that will be used to analyze the agent actions
diff --git a/openhands/server/session/manager.py b/openhands/server/session/manager.py
index a2e8a688eb..15f7fbde44 100644
--- a/openhands/server/session/manager.py
+++ b/openhands/server/session/manager.py
@@ -35,7 +35,7 @@ class SessionManager:
 
     def add_or_restart_session(self, sid: str, ws_conn: WebSocket) -> Session:
         if sid in self._sessions:
-            asyncio.create_task(self._sessions[sid].close())
+            self._sessions[sid].close()
         self._sessions[sid] = Session(
             sid=sid, file_store=self.file_store, ws=ws_conn, config=self.config
         )
@@ -47,7 +47,7 @@ class SessionManager:
         return self._sessions.get(sid)
 
     async def attach_to_conversation(self, sid: str) -> Conversation | None:
-        if not session_exists(sid, self.file_store):
+        if not await session_exists(sid, self.file_store):
             return None
         c = Conversation(sid, file_store=self.file_store, config=self.config)
         await c.connect()
@@ -87,7 +87,7 @@ class SessionManager:
         for sid in session_ids_to_remove:
             to_del_session: Session | None = self._sessions.pop(sid, None)
             if to_del_session is not None:
-                await to_del_session.close()
+                to_del_session.close()
                 logger.debug(
                     f'Session {sid} and related resource have been removed due to inactivity.'
                )
diff --git a/openhands/server/session/session.py b/openhands/server/session/session.py
index c16e511d3c..ef58ae052a 100644
--- a/openhands/server/session/session.py
+++ b/openhands/server/session/session.py
@@ -47,9 +47,9 @@ class Session:
         self.config = config
         self.loop = asyncio.get_event_loop()
 
-    async def close(self):
+    def close(self):
         self.is_alive = False
-        await self.agent_session.close()
+        self.agent_session.close()
 
     async def loop_recv(self):
         try:
@@ -63,11 +63,11 @@ class Session:
                 continue
             await self.dispatch(data)
         except WebSocketDisconnect:
-            await self.close()
-            logger.debug('WebSocket disconnected, sid: %s', self.sid)
+            logger.info('WebSocket disconnected, sid: %s', self.sid)
+            self.close()
         except RuntimeError as e:
-            await self.close()
             logger.exception('Error in loop_recv: %s', e)
+            self.close()
 
     async def _initialize_agent(self, data: dict):
         self.agent_session.event_stream.add_event(
@@ -171,10 +171,12 @@ class Session:
                 'Model does not support image upload, change to a different model or try without an image.'
             )
             return
-        if self.agent_session.loop:
+        if self.loop:
             asyncio.run_coroutine_threadsafe(
-                self._add_event(event, EventSource.USER), self.agent_session.loop
+                self._add_event(event, EventSource.USER), self.loop
             )  # type: ignore
+        else:
+            raise RuntimeError('No event loop found')
 
     async def _add_event(self, event, event_source):
         self.agent_session.event_stream.add_event(event, EventSource.USER)
diff --git a/tests/unit/test_llm.py b/tests/unit/test_llm.py
index 347d383076..073743ea81 100644
--- a/tests/unit/test_llm.py
+++ b/tests/unit/test_llm.py
@@ -50,6 +50,7 @@ def test_llm_init_with_model_info(mock_get_model_info, default_config):
         'max_output_tokens': 2000,
     }
     llm = LLM(default_config)
+    llm.init_model_info()
     assert llm.config.max_input_tokens == 8000
     assert llm.config.max_output_tokens == 2000
 
@@ -58,6 +59,7 @@ def test_llm_init_without_model_info(mock_get_model_info, default_config):
     mock_get_model_info.side_effect = Exception('Model info not available')
     llm = LLM(default_config)
+    llm.init_model_info()
     assert llm.config.max_input_tokens == 4096
     assert llm.config.max_output_tokens == 4096
 
@@ -108,6 +110,7 @@ def test_llm_init_with_openrouter_model(mock_get_model_info, default_config):
         'max_output_tokens': 1500,
     }
     llm = LLM(default_config)
+    llm.init_model_info()
    assert llm.config.max_input_tokens == 7000
    assert llm.config.max_output_tokens == 1500
    mock_get_model_info.assert_called_once_with('openrouter:gpt-4o-mini')
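
Notes for reviewers: below are two minimal, self-contained sketches of the patterns this patch leans on. Neither is code from the patch; the names slow_events, AsyncIterWrapper, heartbeat, and Resource are made up for illustration.

The first mirrors AsyncEventStreamWrapper: a blocking generator is consumed with `async for`, with an executor hop per item. Note the limit of the technique: the hop lets other tasks run between items, but each item is still produced on the loop thread.

    import asyncio
    import time

    def slow_events(n):
        # Stand-in for EventStream.get_events(): a blocking generator.
        for i in range(n):
            time.sleep(0.1)  # simulates slow file-store reads
            yield i

    class AsyncIterWrapper:
        # Same shape as AsyncEventStreamWrapper in this patch.
        def __init__(self, iterable):
            self.iterable = iterable

        async def __aiter__(self):
            loop = asyncio.get_running_loop()
            for item in self.iterable:
                # The executor hop yields control back to the event loop
                # between items; the item itself just passes through.
                yield await loop.run_in_executor(None, lambda i=item: i)

    async def heartbeat():
        # Runs whenever the wrapper yields control back to the loop.
        for _ in range(3):
            print('loop is alive')
            await asyncio.sleep(0.15)

    async def main():
        hb = asyncio.create_task(heartbeat())
        async for item in AsyncIterWrapper(slow_events(5)):
            print('got', item)
        await hb

    asyncio.run(main())

If the producer itself must never block the loop, the iteration would have to run entirely on a worker thread (e.g. feeding an asyncio.Queue); the wrapper above only guarantees fairness between items.

The second mirrors the new AgentSession.close(): async cleanup exposed behind a synchronous close() by handing the coroutine to a worker thread, which runs it on its own event loop via asyncio.run. The closed-guard is checked and set in close(), before the coroutine is scheduled, so repeated calls are cheap no-ops and the scheduled cleanup still runs.

    import asyncio

    class Resource:
        # Illustrative stand-in for AgentSession.
        def __init__(self):
            self._closed = False

        def close(self):
            if self._closed:
                return
            self._closed = True

            def inner_close():
                # Worker thread: a fresh event loop, so the
                # caller's loop is never blocked by the cleanup.
                asyncio.run(self._close())

            asyncio.get_event_loop().run_in_executor(None, inner_close)

        async def _close(self):
            await asyncio.sleep(0)  # stands in for awaited cleanup calls
            print('cleaned up')

    async def main():
        r = Resource()
        r.close()   # returns immediately; cleanup runs off-loop
        r.close()   # no-op thanks to the guard
        await asyncio.sleep(0.1)  # give the worker thread time to finish

    asyncio.run(main())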