From 9cb0bf97c1212c801037a39a9222da098d3b53d7 Mon Sep 17 00:00:00 2001 From: Engel Nyst Date: Sat, 17 Aug 2024 23:31:42 +0200 Subject: [PATCH] Fix restore cli sessions (#3409) * fix restore cli sessions * pytest * fix log message * make sure sid is set --------- Co-authored-by: mamoodi --- opendevin/core/config.py | 13 ++++++++++++- opendevin/core/main.py | 29 ++++++++++++++++++++++++++--- opendevin/runtime/client/runtime.py | 4 ++-- opendevin/server/session/agent.py | 2 +- tests/unit/test_arg_parser.py | 3 ++- 5 files changed, 43 insertions(+), 8 deletions(-) diff --git a/opendevin/core/config.py b/opendevin/core/config.py index 69b482fdd9..d40aa09718 100644 --- a/opendevin/core/config.py +++ b/opendevin/core/config.py @@ -665,7 +665,11 @@ def get_parser() -> argparse.ArgumentParser: help='The working directory for the agent', ) parser.add_argument( - '-t', '--task', type=str, default='', help='The task for the agent to perform' + '-t', + '--task', + type=str, + default='', + help='The task for the agent to perform', ) parser.add_argument( '-f', @@ -725,6 +729,13 @@ def get_parser() -> argparse.ArgumentParser: type=str, help='Replace default LLM ([llm] section in config.toml) config with the specified LLM config, e.g. "llama3" for [llm.llama3] section in config.toml', ) + parser.add_argument( + '-n', + '--name', + default='default', + type=str, + help='Name for the session', + ) return parser diff --git a/opendevin/core/main.py b/opendevin/core/main.py index 56bbef657b..75d45021ff 100644 --- a/opendevin/core/main.py +++ b/opendevin/core/main.py @@ -1,4 +1,5 @@ import asyncio +import hashlib import sys import uuid from typing import Callable, Type @@ -47,9 +48,13 @@ async def create_runtime( sid: The session id. runtime_tools_config: (will be deprecated) The runtime tools config. """ + # if sid is provided on the command line, use it as the name of the event stream + # otherwise generate it on the basis of the configured jwt_secret + # we can do this better, this is just so that the sid is retrieved when we want to restore the session + session_id = sid or generate_sid(config) + # set up the event stream file_store = get_file_store(config.file_store, config.file_store_path) - session_id = 'main' + ('_' + sid if sid else str(uuid.uuid4())) event_stream = EventStream(session_id, file_store) # agent class @@ -72,6 +77,7 @@ async def create_runtime( async def run_controller( config: AppConfig, task_str: str, + sid: str | None = None, runtime: Runtime | None = None, agent: Agent | None = None, exit_on_message: bool = False, @@ -100,15 +106,18 @@ async def run_controller( config=agent_config, ) + # make sure the session id is set + sid = sid or generate_sid(config) + if runtime is None: - runtime = await create_runtime(config) + runtime = await create_runtime(config, sid=sid) event_stream = runtime.event_stream # restore cli session if enabled initial_state = None if config.enable_cli_session: try: - logger.info('Restoring agent state from cli session') + logger.info(f'Restoring agent state from cli session {event_stream.sid}') initial_state = State.restore_from_session( event_stream.sid, event_stream.file_store ) @@ -179,6 +188,15 @@ async def run_controller( return state +def generate_sid(config: AppConfig, session_name: str | None = None) -> str: + """Generate a session id based on the session name and the jwt secret.""" + session_name = session_name or str(uuid.uuid4()) + jwt_secret = config.jwt_secret + + hash_str = hashlib.sha256(f'{session_name}{jwt_secret}'.encode('utf-8')).hexdigest() + return f'{session_name}_{hash_str[:16]}' + + if __name__ == '__main__': args = parse_arguments() @@ -207,6 +225,10 @@ if __name__ == '__main__': # Set default agent config.default_agent = args.agent_cls + # Set session name + session_name = args.name + sid = generate_sid(config, session_name) + # if max budget per task is not sent on the command line, use the config value if args.max_budget_per_task is not None: config.max_budget_per_task = args.max_budget_per_task @@ -217,5 +239,6 @@ if __name__ == '__main__': run_controller( config=config, task_str=task_str, + sid=sid, ) ) diff --git a/opendevin/runtime/client/runtime.py b/opendevin/runtime/client/runtime.py index 220e011528..04543c0a6b 100644 --- a/opendevin/runtime/client/runtime.py +++ b/opendevin/runtime/client/runtime.py @@ -57,7 +57,7 @@ class EventStreamRuntime(Runtime): self.session: aiohttp.ClientSession | None = None self.instance_id = ( - sid + str(uuid.uuid4()) if sid is not None else str(uuid.uuid4()) + sid + '_' + str(uuid.uuid4()) if sid is not None else str(uuid.uuid4()) ) # TODO: We can switch to aiodocker when `get_od_sandbox_image` is updated to use aiodocker self.docker_client: docker.DockerClient = self._init_docker_client() @@ -193,7 +193,7 @@ class EventStreamRuntime(Runtime): wait=tenacity.wait_exponential(multiplier=2, min=10, max=60), ) async def _wait_until_alive(self): - logger.info('Reconnecting session') + logger.debug('Getting container logs...') container = self.docker_client.containers.get(self.container_name) # get logs _logs = container.logs(tail=10).decode('utf-8').split('\n') diff --git a/opendevin/server/session/agent.py b/opendevin/server/session/agent.py index 82f14415e2..7c0d8b51ce 100644 --- a/opendevin/server/session/agent.py +++ b/opendevin/server/session/agent.py @@ -135,4 +135,4 @@ class AgentSession: ) logger.info(f'Restored agent state from session, sid: {self.sid}') except Exception as e: - print('Error restoring state', e) + logger.info(f'Error restoring state: {e}') diff --git a/tests/unit/test_arg_parser.py b/tests/unit/test_arg_parser.py index 2e2aed38c8..2c8714cf36 100644 --- a/tests/unit/test_arg_parser.py +++ b/tests/unit/test_arg_parser.py @@ -14,7 +14,7 @@ usage: pytest [-h] [-d DIRECTORY] [-t TASK] [-f FILE] [-c AGENT_CLS] [--eval-output-dir EVAL_OUTPUT_DIR] [--eval-n-limit EVAL_N_LIMIT] [--eval-num-workers EVAL_NUM_WORKERS] [--eval-note EVAL_NOTE] - [-l LLM_CONFIG] + [-l LLM_CONFIG] [-n NAME] Run an agent with a specific task @@ -44,6 +44,7 @@ options: Replace default LLM ([llm] section in config.toml) config with the specified LLM config, e.g. "llama3" for [llm.llama3] section in config.toml + -n NAME, --name NAME Name for the session """ actual_lines = captured.out.strip().split('\n')