All-1465 Move user conversations (#7340)

This commit is contained in:
chuckbutkus
2025-03-19 16:03:09 -04:00
committed by GitHub
parent 35b70ca915
commit c3d60b31d1
9 changed files with 114 additions and 45 deletions

View File

@@ -102,22 +102,42 @@ class State:
extra_data: dict[str, Any] = field(default_factory=dict)
last_error: str = ''
def save_to_session(self, sid: str, file_store: FileStore, user_id: str | None):
    """Pickle this state, base64-encode it and write it to the session's agent-state file.

    Args:
        sid: Session (conversation) id used to build the storage path.
        file_store: Store that receives the encoded state.
        user_id: Owner of the conversation, or None. Forwarded to
            ``get_conversation_agent_state_filename`` to pick the location.

    Raises:
        Exception: Any error raised while writing the state is logged and
            re-raised unchanged.
    """
    pickled = pickle.dumps(self)
    logger.debug(f'Saving state to session {sid}:{self.agent_state}')
    encoded = base64.b64encode(pickled).decode('utf-8')
    try:
        file_store.write(
            get_conversation_agent_state_filename(sid, user_id), encoded
        )
    except Exception as e:
        logger.error(f'Failed to save state to session: {e}')
        raise

    # Only clean up the old directory when the state was written to a
    # user-scoped path. Without a user_id the "old" filename is the very file
    # that was just written, so deleting it would discard the saved state.
    if user_id:
        legacy_filename = get_conversation_agent_state_filename(sid)
        try:
            file_store.delete(legacy_filename)
        except Exception as e:
            # Best-effort cleanup: a missing or undeletable legacy file must
            # not fail the save, but leave a trace for debugging.
            logger.debug(f'Could not delete legacy state {legacy_filename}: {e}')
@staticmethod
def restore_from_session(sid: str, file_store: FileStore) -> 'State':
def restore_from_session(
sid: str, file_store: FileStore, user_id: str | None = None
) -> 'State':
try:
encoded = file_store.read(get_conversation_agent_state_filename(sid))
encoded = file_store.read(
get_conversation_agent_state_filename(sid, user_id)
)
pickled = base64.b64decode(encoded)
state = pickle.loads(pickled)
except FileNotFoundError:
if user_id:
# see if state is in old directory. If yes, load it.
filename = get_conversation_agent_state_filename(sid)
encoded = file_store.read(filename)
pickled = base64.b64decode(encoded)
state = pickle.loads(pickled)
except Exception as e:
logger.debug(f'Could not restore state from session: {e}')
raise e

View File

@@ -194,7 +194,9 @@ async def run_controller(
if config.file_store is not None and config.file_store != 'memory':
end_state = controller.get_state()
# NOTE: the saved state does not include delegates events
end_state.save_to_session(event_stream.sid, event_stream.file_store)
end_state.save_to_session(
event_stream.sid, event_stream.file_store, event_stream.user_id
)
await controller.close(set_stop_state=False)

View File

@@ -32,9 +32,11 @@ class EventStreamSubscriber(str, Enum):
TEST = 'test'
async def session_exists(
    sid: str, file_store: FileStore, user_id: str | None = None
) -> bool:
    """Report whether a conversation directory can be listed for ``sid``.

    Args:
        sid: Session (conversation) id.
        file_store: Store whose ``list`` method is used to probe the directory.
        user_id: Owner of the conversation, or None for the legacy layout.

    Returns:
        True if listing a candidate directory succeeds, False if every
        candidate raises ``FileNotFoundError``.
    """
    candidate_dirs = [get_conversation_dir(sid, user_id)]
    if user_id:
        # During the transition to the per-user location, also accept the old
        # location, matching the fallback used when events and agent state are
        # loaded for a user-scoped session.
        # TODO: remove this fallback after 5/1/2025
        candidate_dirs.append(get_conversation_dir(sid))

    for conversation_dir in candidate_dirs:
        try:
            await call_sync_from_async(file_store.list, conversation_dir)
        except FileNotFoundError:
            continue
        return True
    return False
@@ -57,6 +59,7 @@ class AsyncEventStreamWrapper:
class EventStream:
sid: str
user_id: str | None
file_store: FileStore
secrets: dict[str, str]
# For each subscriber ID, there is a map of callback functions - useful
@@ -70,9 +73,10 @@ class EventStream:
_thread_pools: dict[str, dict[str, ThreadPoolExecutor]]
_thread_loops: dict[str, dict[str, asyncio.AbstractEventLoop]]
def __init__(self, sid: str, file_store: FileStore):
def __init__(self, sid: str, file_store: FileStore, user_id: str | None = None):
self.sid = sid
self.file_store = file_store
self.user_id = user_id
self._stop_flag = threading.Event()
self._queue: queue.Queue[Event] = queue.Queue()
self._thread_pools = {}
@@ -90,10 +94,24 @@ class EventStream:
self.__post_init__()
def __post_init__(self) -> None:
events = []
try:
events = self.file_store.list(get_conversation_events_dir(self.sid))
events_dir = get_conversation_events_dir(self.sid, self.user_id)
events += self.file_store.list(events_dir)
except FileNotFoundError:
logger.debug(f'No events found for session {self.sid}')
logger.debug(f'No events found for session {self.sid} at {events_dir}')
if self.user_id:
# During transition to new location, try old location if user_id is set
# TODO: remove this code after 5/1/2025
try:
events_dir = get_conversation_events_dir(self.sid)
events += self.file_store.list(events_dir)
except FileNotFoundError:
logger.debug(f'No events found for session {self.sid} at {events_dir}')
if not events:
self._cur_id = 0
return
@@ -156,8 +174,8 @@ class EventStream:
del self._subscribers[subscriber_id][callback_id]
def _get_filename_for_id(self, id: int, user_id: str | None) -> str:
    """Build the storage filename of event ``id`` in this stream's session.

    Args:
        id: Numeric event id.
        user_id: User whose location should be used; pass None to get the
            filename in the non-user-scoped location.

    Returns:
        The filename produced by ``get_conversation_event_filename``.
    """
    return get_conversation_event_filename(self.sid, id, user_id)
@staticmethod
def _get_id_from_filename(filename: str) -> int:
@@ -223,10 +241,20 @@ class EventStream:
event_id += 1
def get_event(self, id: int) -> Event:
    """Load a single event of this stream from the file store.

    Args:
        id: Numeric event id.

    Returns:
        The event rebuilt from its stored JSON via ``event_from_dict``.

    Raises:
        FileNotFoundError: If the event file is missing, and, when a
            ``user_id`` is set, is also missing from the old location.
    """
    filename = self._get_filename_for_id(id, self.user_id)
    try:
        content = self.file_store.read(filename)
    except FileNotFoundError:
        logger.debug(f'File {filename} not found')
        if not self.user_id:
            raise
        # TODO remove this block after 5/1/2025
        # Fall back to the location used before events were stored per user.
        content = self.file_store.read(self._get_filename_for_id(id, None))
    return event_from_dict(json.loads(content))
def get_latest_event(self) -> Event:
    """Return the event whose id is one less than ``_cur_id``."""
    latest_id = self._cur_id - 1
    return self.get_event(latest_id)
@@ -277,7 +305,9 @@ class EventStream:
data = self._replace_secrets(data)
event = event_from_dict(data)
if event.id is not None:
self.file_store.write(self._get_filename_for_id(event.id), json.dumps(data))
self.file_store.write(
self._get_filename_for_id(event.id, self.user_id), json.dumps(data)
)
self._queue.put(event)
def set_secrets(self, secrets: dict[str, str]):

View File

@@ -37,7 +37,9 @@ class ConversationManager(ABC):
"""Clean up the conversation manager."""
@abstractmethod
async def attach_to_conversation(
    self, sid: str, user_id: str | None = None
) -> Conversation | None:
    """Attach to an existing conversation or create a new one.

    Args:
        sid: Session (conversation) id to attach to.
        user_id: Optional id of the user the conversation belongs to.

    Returns:
        The attached ``Conversation``, or None when the implementation
        cannot provide one.
    """
@abstractmethod

View File

@@ -63,9 +63,11 @@ class StandaloneConversationManager(ConversationManager):
self._cleanup_task.cancel()
self._cleanup_task = None
async def attach_to_conversation(self, sid: str) -> Conversation | None:
async def attach_to_conversation(
self, sid: str, user_id: str | None = None
) -> Conversation | None:
start_time = time.time()
if not await session_exists(sid, self.file_store):
if not await session_exists(sid, self.file_store, user_id=user_id):
return None
async with self._conversations_lock:
@@ -88,7 +90,9 @@ class StandaloneConversationManager(ConversationManager):
return conversation
# Create new conversation if none exists
c = Conversation(sid, file_store=self.file_store, config=self.config)
c = Conversation(
sid, file_store=self.file_store, config=self.config, user_id=user_id
)
try:
await c.connect()
except AgentRuntimeUnavailableError as e:
@@ -119,7 +123,7 @@ class StandaloneConversationManager(ConversationManager):
)
await self.sio.enter_room(connection_id, ROOM_KEY.format(sid=sid))
self._local_connection_id_to_session_id[connection_id] = sid
event_stream = await self._get_event_stream(sid)
event_stream = await self._get_event_stream(sid, user_id)
if not event_stream:
return await self.maybe_start_agent_loop(
sid, settings, user_id, github_user_id=github_user_id
@@ -299,7 +303,7 @@ class StandaloneConversationManager(ConversationManager):
except ValueError:
pass # Already subscribed - take no action
event_stream = await self._get_event_stream(sid)
event_stream = await self._get_event_stream(sid, user_id)
if not event_stream:
logger.error(
f'No event stream after starting agent loop: {sid}',
@@ -308,7 +312,9 @@ class StandaloneConversationManager(ConversationManager):
raise RuntimeError(f'no_event_stream:{sid}')
return event_stream
async def _get_event_stream(self, sid: str) -> EventStream | None:
async def _get_event_stream(
self, sid: str, user_id: str | None
) -> EventStream | None:
logger.info(f'_get_event_stream:{sid}', extra={'session_id': sid})
session = self._local_agent_loops_by_sid.get(sid)
if session:

View File

@@ -148,7 +148,9 @@ class AttachConversationMiddleware(SessionMiddlewareInterface):
Attach the user's session based on the provided authentication token.
"""
request.state.conversation = (
await shared.conversation_manager.attach_to_conversation(request.state.sid)
await shared.conversation_manager.attach_to_conversation(
request.state.sid, get_user_id(request)
)
)
if not request.state.conversation:
return JSONResponse(

View File

@@ -37,6 +37,7 @@ class AgentSession:
"""
sid: str
user_id: str | None
event_stream: EventStream
file_store: FileStore
controller: AgentController | None = None
@@ -63,7 +64,7 @@ class AgentSession:
"""
self.sid = sid
self.event_stream = EventStream(sid, file_store)
self.event_stream = EventStream(sid, file_store, user_id)
self.file_store = file_store
self._status_callback = status_callback
self.user_id = user_id
@@ -186,7 +187,7 @@ class AgentSession:
self.event_stream.close()
if self.controller is not None:
end_state = self.controller.get_state()
end_state.save_to_session(self.sid, self.file_store)
end_state.save_to_session(self.sid, self.file_store, self.user_id)
await self.controller.close()
if self.runtime is not None:
self.runtime.close()
@@ -371,7 +372,9 @@ class AgentSession:
# Use a heuristic to figure out if we should have a state:
# if we have events in the stream.
try:
restored_state = State.restore_from_session(self.sid, self.file_store)
restored_state = State.restore_from_session(
self.sid, self.file_store, self.user_id
)
self.logger.debug(f'Restored state from session, sid: {self.sid}')
except Exception as e:
if self.event_stream.get_latest_event_id() > 0:

View File

@@ -14,17 +14,16 @@ class Conversation:
file_store: FileStore
event_stream: EventStream
runtime: Runtime
user_id: str | None
def __init__(
self,
sid: str,
file_store: FileStore,
config: AppConfig,
self, sid: str, file_store: FileStore, config: AppConfig, user_id: str | None
):
self.sid = sid
self.config = config
self.file_store = file_store
self.event_stream = EventStream(sid, file_store)
self.user_id = user_id
self.event_stream = EventStream(sid, file_store, user_id)
if config.security.security_analyzer:
self.security_analyzer = options.SecurityAnalyzers.get(
config.security.security_analyzer, SecurityAnalyzer

View File

@@ -1,25 +1,30 @@
CONVERSATION_BASE_DIR = 'sessions'
def get_conversation_dir(sid: str) -> str:
return f'{CONVERSATION_BASE_DIR}/{sid}/'
def get_conversation_dir(sid: str, user_id: str | None = None) -> str:
if user_id:
return f'users/{user_id}/conversations/{sid}/'
else:
return f'{CONVERSATION_BASE_DIR}/{sid}/'
def get_conversation_events_dir(sid: str) -> str:
return f'{get_conversation_dir(sid)}events/'
def get_conversation_events_dir(sid: str, user_id: str | None = None) -> str:
return f'{get_conversation_dir(sid, user_id)}events/'
def get_conversation_event_filename(sid: str, id: int) -> str:
return f'{get_conversation_events_dir(sid)}{id}.json'
def get_conversation_event_filename(
sid: str, id: int, user_id: str | None = None
) -> str:
return f'{get_conversation_events_dir(sid, user_id)}{id}.json'
def get_conversation_metadata_filename(sid: str) -> str:
return f'{get_conversation_dir(sid)}metadata.json'
def get_conversation_metadata_filename(sid: str, user_id: str | None = None) -> str:
return f'{get_conversation_dir(sid, user_id)}metadata.json'
def get_conversation_init_data_filename(sid: str) -> str:
return f'{get_conversation_dir(sid)}init.json'
def get_conversation_init_data_filename(sid: str, user_id: str | None = None) -> str:
return f'{get_conversation_dir(sid, user_id)}init.json'
def get_conversation_agent_state_filename(sid: str) -> str:
return f'{get_conversation_dir(sid)}agent_state.pkl'
def get_conversation_agent_state_filename(sid: str, user_id: str | None = None) -> str:
return f'{get_conversation_dir(sid, user_id)}agent_state.pkl'