diff --git a/openhands/app_server/sandbox/remote_sandbox_service.py b/openhands/app_server/sandbox/remote_sandbox_service.py index 035870cd45..dc1993440c 100644 --- a/openhands/app_server/sandbox/remote_sandbox_service.py +++ b/openhands/app_server/sandbox/remote_sandbox_service.py @@ -44,6 +44,7 @@ from openhands.app_server.services.injector import InjectorState from openhands.app_server.user.specifiy_user_context import ADMIN, USER_CONTEXT_ATTR from openhands.app_server.user.user_context import UserContext from openhands.app_server.utils.sql_utils import Base, UtcDateTime +from openhands.sdk.utils.paging import page_iterator _logger = logging.getLogger(__name__) WEBHOOK_CALLBACK_VARIABLE = 'OH_WEBHOOKS_0_BASE_URL' @@ -529,32 +530,26 @@ async def poll_agent_servers(api_url: str, api_key: str, sleep_interval: int): get_event_callback_service(state) as event_callback_service, get_httpx_client(state) as httpx_client, ): - page_id = None matches = 0 - while True: - page = await app_conversation_info_service.search_app_conversation_info( - page_id=page_id + async for app_conversation_info in page_iterator( + app_conversation_info_service.search_app_conversation_info + ): + runtime = runtimes_by_sandbox_id.get( + app_conversation_info.sandbox_id ) - for app_conversation_info in page.items: - runtime = runtimes_by_sandbox_id.get( - app_conversation_info.sandbox_id + if runtime: + matches += 1 + await refresh_conversation( + app_conversation_info_service=app_conversation_info_service, + event_service=event_service, + event_callback_service=event_callback_service, + app_conversation_info=app_conversation_info, + runtime=runtime, + httpx_client=httpx_client, ) - if runtime: - matches += 1 - await refresh_conversation( - app_conversation_info_service=app_conversation_info_service, - event_service=event_service, - event_callback_service=event_callback_service, - app_conversation_info=app_conversation_info, - runtime=runtime, - httpx_client=httpx_client, - ) - page_id = page.next_page_id - if page_id is None: - _logger.debug( - f'Matched {len(runtimes_by_sandbox_id)} Runtimes with {matches} Conversations.' - ) - break + _logger.debug( + f'Matched {len(runtimes_by_sandbox_id)} Runtimes with {matches} Conversations.' + ) except Exception as exc: _logger.exception( @@ -608,37 +603,29 @@ async def refresh_conversation( event_url = ( f'{url}/api/conversations/{app_conversation_info.id.hex}/events/search' ) - page_id = None - while True: + + async def fetch_events_page(page_id: str | None = None) -> EventPage: + """Helper function to fetch a page of events from the agent server.""" params: dict[str, str] = {} if page_id: - params['page_id'] = page_id # type: ignore[unreachable] + params['page_id'] = page_id response = await httpx_client.get( event_url, params=params, headers={'X-Session-API-Key': runtime['session_api_key']}, ) response.raise_for_status() - page = EventPage.model_validate(response.json()) + return EventPage.model_validate(response.json()) - to_process = [] - for event in page.items: - existing = await event_service.get_event(event.id) - if existing is None: - await event_service.save_event(app_conversation_info.id, event) - to_process.append(event) - - for event in to_process: + async for event in page_iterator(fetch_events_page): + existing = await event_service.get_event(event.id) + if existing is None: + await event_service.save_event(app_conversation_info.id, event) await event_callback_service.execute_callbacks( app_conversation_info.id, event ) - page_id = page.next_page_id - if page_id is None: - _logger.debug( - f'Finished Refreshing Conversation {app_conversation_info.id}' - ) - break + _logger.debug(f'Finished Refreshing Conversation {app_conversation_info.id}') except Exception as exc: _logger.exception(f'Error Refreshing Conversation: {exc}', stack_info=True) diff --git a/openhands/app_server/sandbox/sandbox_service.py b/openhands/app_server/sandbox/sandbox_service.py index b1144a47cc..e7f8afe9c2 100644 --- a/openhands/app_server/sandbox/sandbox_service.py +++ b/openhands/app_server/sandbox/sandbox_service.py @@ -8,6 +8,7 @@ from openhands.app_server.sandbox.sandbox_models import ( ) from openhands.app_server.services.injector import Injector from openhands.sdk.utils.models import DiscriminatedUnionMixin +from openhands.sdk.utils.paging import page_iterator class SandboxService(ABC): @@ -83,24 +84,11 @@ class SandboxService(ABC): if max_num_sandboxes <= 0: raise ValueError('max_num_sandboxes must be greater than 0') - # Get all sandboxes (we'll search through all pages) - all_sandboxes = [] - page_id = None - - while True: - page = await self.search_sandboxes(page_id=page_id, limit=100) - all_sandboxes.extend(page.items) - - if page.next_page_id is None: - break - page_id = page.next_page_id - - # Filter to only running sandboxes - running_sandboxes = [ - sandbox - for sandbox in all_sandboxes - if sandbox.status == SandboxStatus.RUNNING - ] + # Get all running sandboxes (iterate through all pages) + running_sandboxes = [] + async for sandbox in page_iterator(self.search_sandboxes, limit=100): + if sandbox.status == SandboxStatus.RUNNING: + running_sandboxes.append(sandbox) # If we're within the limit, no cleanup needed if len(running_sandboxes) <= max_num_sandboxes: