Fix restore cli sessions (#3409)

* fix restore cli sessions

* pytest

* fix log message

* make sure sid is set

---------

Co-authored-by: mamoodi <mamoodiha@gmail.com>
This commit is contained in:
Engel Nyst 2024-08-17 23:31:42 +02:00 committed by GitHub
parent 8d7bf83224
commit 9cb0bf97c1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 43 additions and 8 deletions

View File

@ -665,7 +665,11 @@ def get_parser() -> argparse.ArgumentParser:
help='The working directory for the agent',
)
parser.add_argument(
'-t', '--task', type=str, default='', help='The task for the agent to perform'
'-t',
'--task',
type=str,
default='',
help='The task for the agent to perform',
)
parser.add_argument(
'-f',
@ -725,6 +729,13 @@ def get_parser() -> argparse.ArgumentParser:
type=str,
help='Replace default LLM ([llm] section in config.toml) config with the specified LLM config, e.g. "llama3" for [llm.llama3] section in config.toml',
)
parser.add_argument(
'-n',
'--name',
default='default',
type=str,
help='Name for the session',
)
return parser

View File

@ -1,4 +1,5 @@
import asyncio
import hashlib
import sys
import uuid
from typing import Callable, Type
@ -47,9 +48,13 @@ async def create_runtime(
sid: The session id.
runtime_tools_config: (will be deprecated) The runtime tools config.
"""
# if sid is provided on the command line, use it as the name of the event stream
# otherwise generate it on the basis of the configured jwt_secret
# we can do this better, this is just so that the sid is retrieved when we want to restore the session
session_id = sid or generate_sid(config)
# set up the event stream
file_store = get_file_store(config.file_store, config.file_store_path)
session_id = 'main' + ('_' + sid if sid else str(uuid.uuid4()))
event_stream = EventStream(session_id, file_store)
# agent class
@ -72,6 +77,7 @@ async def create_runtime(
async def run_controller(
config: AppConfig,
task_str: str,
sid: str | None = None,
runtime: Runtime | None = None,
agent: Agent | None = None,
exit_on_message: bool = False,
@ -100,15 +106,18 @@ async def run_controller(
config=agent_config,
)
# make sure the session id is set
sid = sid or generate_sid(config)
if runtime is None:
runtime = await create_runtime(config)
runtime = await create_runtime(config, sid=sid)
event_stream = runtime.event_stream
# restore cli session if enabled
initial_state = None
if config.enable_cli_session:
try:
logger.info('Restoring agent state from cli session')
logger.info(f'Restoring agent state from cli session {event_stream.sid}')
initial_state = State.restore_from_session(
event_stream.sid, event_stream.file_store
)
@ -179,6 +188,15 @@ async def run_controller(
return state
def generate_sid(config: AppConfig, session_name: str | None = None) -> str:
"""Generate a session id based on the session name and the jwt secret."""
session_name = session_name or str(uuid.uuid4())
jwt_secret = config.jwt_secret
hash_str = hashlib.sha256(f'{session_name}{jwt_secret}'.encode('utf-8')).hexdigest()
return f'{session_name}_{hash_str[:16]}'
if __name__ == '__main__':
args = parse_arguments()
@ -207,6 +225,10 @@ if __name__ == '__main__':
# Set default agent
config.default_agent = args.agent_cls
# Set session name
session_name = args.name
sid = generate_sid(config, session_name)
# if max budget per task is not sent on the command line, use the config value
if args.max_budget_per_task is not None:
config.max_budget_per_task = args.max_budget_per_task
@ -217,5 +239,6 @@ if __name__ == '__main__':
run_controller(
config=config,
task_str=task_str,
sid=sid,
)
)

View File

@ -57,7 +57,7 @@ class EventStreamRuntime(Runtime):
self.session: aiohttp.ClientSession | None = None
self.instance_id = (
sid + str(uuid.uuid4()) if sid is not None else str(uuid.uuid4())
sid + '_' + str(uuid.uuid4()) if sid is not None else str(uuid.uuid4())
)
# TODO: We can switch to aiodocker when `get_od_sandbox_image` is updated to use aiodocker
self.docker_client: docker.DockerClient = self._init_docker_client()
@ -193,7 +193,7 @@ class EventStreamRuntime(Runtime):
wait=tenacity.wait_exponential(multiplier=2, min=10, max=60),
)
async def _wait_until_alive(self):
logger.info('Reconnecting session')
logger.debug('Getting container logs...')
container = self.docker_client.containers.get(self.container_name)
# get logs
_logs = container.logs(tail=10).decode('utf-8').split('\n')

View File

@ -135,4 +135,4 @@ class AgentSession:
)
logger.info(f'Restored agent state from session, sid: {self.sid}')
except Exception as e:
print('Error restoring state', e)
logger.info(f'Error restoring state: {e}')

View File

@ -14,7 +14,7 @@ usage: pytest [-h] [-d DIRECTORY] [-t TASK] [-f FILE] [-c AGENT_CLS]
[--eval-output-dir EVAL_OUTPUT_DIR]
[--eval-n-limit EVAL_N_LIMIT]
[--eval-num-workers EVAL_NUM_WORKERS] [--eval-note EVAL_NOTE]
[-l LLM_CONFIG]
[-l LLM_CONFIG] [-n NAME]
Run an agent with a specific task
@ -44,6 +44,7 @@ options:
Replace default LLM ([llm] section in config.toml)
config with the specified LLM config, e.g. "llama3"
for [llm.llama3] section in config.toml
-n NAME, --name NAME Name for the session
"""
actual_lines = captured.out.strip().split('\n')