mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 13:47:19 +08:00
All-1465 Move user conversations (#7340)
This commit is contained in:
@@ -102,22 +102,42 @@ class State:
|
||||
extra_data: dict[str, Any] = field(default_factory=dict)
|
||||
last_error: str = ''
|
||||
|
||||
def save_to_session(self, sid: str, file_store: FileStore):
|
||||
def save_to_session(self, sid: str, file_store: FileStore, user_id: str | None):
|
||||
pickled = pickle.dumps(self)
|
||||
logger.debug(f'Saving state to session {sid}:{self.agent_state}')
|
||||
encoded = base64.b64encode(pickled).decode('utf-8')
|
||||
try:
|
||||
file_store.write(get_conversation_agent_state_filename(sid), encoded)
|
||||
file_store.write(
|
||||
get_conversation_agent_state_filename(sid, user_id), encoded
|
||||
)
|
||||
|
||||
# see if state is in old directory. If yes, delete it.
|
||||
filename = get_conversation_agent_state_filename(sid)
|
||||
try:
|
||||
file_store.delete(filename)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f'Failed to save state to session: {e}')
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
def restore_from_session(sid: str, file_store: FileStore) -> 'State':
|
||||
def restore_from_session(
|
||||
sid: str, file_store: FileStore, user_id: str | None = None
|
||||
) -> 'State':
|
||||
try:
|
||||
encoded = file_store.read(get_conversation_agent_state_filename(sid))
|
||||
encoded = file_store.read(
|
||||
get_conversation_agent_state_filename(sid, user_id)
|
||||
)
|
||||
pickled = base64.b64decode(encoded)
|
||||
state = pickle.loads(pickled)
|
||||
except FileNotFoundError:
|
||||
if user_id:
|
||||
# see if state is in old directory. If yes, load it.
|
||||
filename = get_conversation_agent_state_filename(sid)
|
||||
encoded = file_store.read(filename)
|
||||
pickled = base64.b64decode(encoded)
|
||||
state = pickle.loads(pickled)
|
||||
except Exception as e:
|
||||
logger.debug(f'Could not restore state from session: {e}')
|
||||
raise e
|
||||
|
||||
@@ -194,7 +194,9 @@ async def run_controller(
|
||||
if config.file_store is not None and config.file_store != 'memory':
|
||||
end_state = controller.get_state()
|
||||
# NOTE: the saved state does not include delegates events
|
||||
end_state.save_to_session(event_stream.sid, event_stream.file_store)
|
||||
end_state.save_to_session(
|
||||
event_stream.sid, event_stream.file_store, event_stream.user_id
|
||||
)
|
||||
|
||||
await controller.close(set_stop_state=False)
|
||||
|
||||
|
||||
@@ -32,9 +32,11 @@ class EventStreamSubscriber(str, Enum):
|
||||
TEST = 'test'
|
||||
|
||||
|
||||
async def session_exists(sid: str, file_store: FileStore) -> bool:
|
||||
async def session_exists(
|
||||
sid: str, file_store: FileStore, user_id: str | None = None
|
||||
) -> bool:
|
||||
try:
|
||||
await call_sync_from_async(file_store.list, get_conversation_dir(sid))
|
||||
await call_sync_from_async(file_store.list, get_conversation_dir(sid, user_id))
|
||||
return True
|
||||
except FileNotFoundError:
|
||||
return False
|
||||
@@ -57,6 +59,7 @@ class AsyncEventStreamWrapper:
|
||||
|
||||
class EventStream:
|
||||
sid: str
|
||||
user_id: str | None
|
||||
file_store: FileStore
|
||||
secrets: dict[str, str]
|
||||
# For each subscriber ID, there is a map of callback functions - useful
|
||||
@@ -70,9 +73,10 @@ class EventStream:
|
||||
_thread_pools: dict[str, dict[str, ThreadPoolExecutor]]
|
||||
_thread_loops: dict[str, dict[str, asyncio.AbstractEventLoop]]
|
||||
|
||||
def __init__(self, sid: str, file_store: FileStore):
|
||||
def __init__(self, sid: str, file_store: FileStore, user_id: str | None = None):
|
||||
self.sid = sid
|
||||
self.file_store = file_store
|
||||
self.user_id = user_id
|
||||
self._stop_flag = threading.Event()
|
||||
self._queue: queue.Queue[Event] = queue.Queue()
|
||||
self._thread_pools = {}
|
||||
@@ -90,10 +94,24 @@ class EventStream:
|
||||
self.__post_init__()
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
events = []
|
||||
|
||||
try:
|
||||
events = self.file_store.list(get_conversation_events_dir(self.sid))
|
||||
events_dir = get_conversation_events_dir(self.sid, self.user_id)
|
||||
events += self.file_store.list(events_dir)
|
||||
except FileNotFoundError:
|
||||
logger.debug(f'No events found for session {self.sid}')
|
||||
logger.debug(f'No events found for session {self.sid} at {events_dir}')
|
||||
|
||||
if self.user_id:
|
||||
# During transition to new location, try old location if user_id is set
|
||||
# TODO: remove this code after 5/1/2025
|
||||
try:
|
||||
events_dir = get_conversation_events_dir(self.sid)
|
||||
events += self.file_store.list(events_dir)
|
||||
except FileNotFoundError:
|
||||
logger.debug(f'No events found for session {self.sid} at {events_dir}')
|
||||
|
||||
if not events:
|
||||
self._cur_id = 0
|
||||
return
|
||||
|
||||
@@ -156,8 +174,8 @@ class EventStream:
|
||||
|
||||
del self._subscribers[subscriber_id][callback_id]
|
||||
|
||||
def _get_filename_for_id(self, id: int) -> str:
|
||||
return get_conversation_event_filename(self.sid, id)
|
||||
def _get_filename_for_id(self, id: int, user_id: str | None) -> str:
|
||||
return get_conversation_event_filename(self.sid, id, user_id)
|
||||
|
||||
@staticmethod
|
||||
def _get_id_from_filename(filename: str) -> int:
|
||||
@@ -223,10 +241,20 @@ class EventStream:
|
||||
event_id += 1
|
||||
|
||||
def get_event(self, id: int) -> Event:
|
||||
filename = self._get_filename_for_id(id)
|
||||
content = self.file_store.read(filename)
|
||||
data = json.loads(content)
|
||||
return event_from_dict(data)
|
||||
filename = self._get_filename_for_id(id, self.user_id)
|
||||
try:
|
||||
content = self.file_store.read(filename)
|
||||
data = json.loads(content)
|
||||
return event_from_dict(data)
|
||||
except FileNotFoundError:
|
||||
logger.debug(f'File {filename} not found')
|
||||
# TODO remove this block after 5/1/2025
|
||||
if self.user_id:
|
||||
filename = self._get_filename_for_id(id, None)
|
||||
content = self.file_store.read(filename)
|
||||
data = json.loads(content)
|
||||
return event_from_dict(data)
|
||||
raise
|
||||
|
||||
def get_latest_event(self) -> Event:
|
||||
return self.get_event(self._cur_id - 1)
|
||||
@@ -277,7 +305,9 @@ class EventStream:
|
||||
data = self._replace_secrets(data)
|
||||
event = event_from_dict(data)
|
||||
if event.id is not None:
|
||||
self.file_store.write(self._get_filename_for_id(event.id), json.dumps(data))
|
||||
self.file_store.write(
|
||||
self._get_filename_for_id(event.id, self.user_id), json.dumps(data)
|
||||
)
|
||||
self._queue.put(event)
|
||||
|
||||
def set_secrets(self, secrets: dict[str, str]):
|
||||
|
||||
@@ -37,7 +37,9 @@ class ConversationManager(ABC):
|
||||
"""Clean up the conversation manager."""
|
||||
|
||||
@abstractmethod
|
||||
async def attach_to_conversation(self, sid: str) -> Conversation | None:
|
||||
async def attach_to_conversation(
|
||||
self, sid: str, user_id: str | None = None
|
||||
) -> Conversation | None:
|
||||
"""Attach to an existing conversation or create a new one."""
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -63,9 +63,11 @@ class StandaloneConversationManager(ConversationManager):
|
||||
self._cleanup_task.cancel()
|
||||
self._cleanup_task = None
|
||||
|
||||
async def attach_to_conversation(self, sid: str) -> Conversation | None:
|
||||
async def attach_to_conversation(
|
||||
self, sid: str, user_id: str | None = None
|
||||
) -> Conversation | None:
|
||||
start_time = time.time()
|
||||
if not await session_exists(sid, self.file_store):
|
||||
if not await session_exists(sid, self.file_store, user_id=user_id):
|
||||
return None
|
||||
|
||||
async with self._conversations_lock:
|
||||
@@ -88,7 +90,9 @@ class StandaloneConversationManager(ConversationManager):
|
||||
return conversation
|
||||
|
||||
# Create new conversation if none exists
|
||||
c = Conversation(sid, file_store=self.file_store, config=self.config)
|
||||
c = Conversation(
|
||||
sid, file_store=self.file_store, config=self.config, user_id=user_id
|
||||
)
|
||||
try:
|
||||
await c.connect()
|
||||
except AgentRuntimeUnavailableError as e:
|
||||
@@ -119,7 +123,7 @@ class StandaloneConversationManager(ConversationManager):
|
||||
)
|
||||
await self.sio.enter_room(connection_id, ROOM_KEY.format(sid=sid))
|
||||
self._local_connection_id_to_session_id[connection_id] = sid
|
||||
event_stream = await self._get_event_stream(sid)
|
||||
event_stream = await self._get_event_stream(sid, user_id)
|
||||
if not event_stream:
|
||||
return await self.maybe_start_agent_loop(
|
||||
sid, settings, user_id, github_user_id=github_user_id
|
||||
@@ -299,7 +303,7 @@ class StandaloneConversationManager(ConversationManager):
|
||||
except ValueError:
|
||||
pass # Already subscribed - take no action
|
||||
|
||||
event_stream = await self._get_event_stream(sid)
|
||||
event_stream = await self._get_event_stream(sid, user_id)
|
||||
if not event_stream:
|
||||
logger.error(
|
||||
f'No event stream after starting agent loop: {sid}',
|
||||
@@ -308,7 +312,9 @@ class StandaloneConversationManager(ConversationManager):
|
||||
raise RuntimeError(f'no_event_stream:{sid}')
|
||||
return event_stream
|
||||
|
||||
async def _get_event_stream(self, sid: str) -> EventStream | None:
|
||||
async def _get_event_stream(
|
||||
self, sid: str, user_id: str | None
|
||||
) -> EventStream | None:
|
||||
logger.info(f'_get_event_stream:{sid}', extra={'session_id': sid})
|
||||
session = self._local_agent_loops_by_sid.get(sid)
|
||||
if session:
|
||||
|
||||
@@ -148,7 +148,9 @@ class AttachConversationMiddleware(SessionMiddlewareInterface):
|
||||
Attach the user's session based on the provided authentication token.
|
||||
"""
|
||||
request.state.conversation = (
|
||||
await shared.conversation_manager.attach_to_conversation(request.state.sid)
|
||||
await shared.conversation_manager.attach_to_conversation(
|
||||
request.state.sid, get_user_id(request)
|
||||
)
|
||||
)
|
||||
if not request.state.conversation:
|
||||
return JSONResponse(
|
||||
|
||||
@@ -37,6 +37,7 @@ class AgentSession:
|
||||
"""
|
||||
|
||||
sid: str
|
||||
user_id: str | None
|
||||
event_stream: EventStream
|
||||
file_store: FileStore
|
||||
controller: AgentController | None = None
|
||||
@@ -63,7 +64,7 @@ class AgentSession:
|
||||
"""
|
||||
|
||||
self.sid = sid
|
||||
self.event_stream = EventStream(sid, file_store)
|
||||
self.event_stream = EventStream(sid, file_store, user_id)
|
||||
self.file_store = file_store
|
||||
self._status_callback = status_callback
|
||||
self.user_id = user_id
|
||||
@@ -186,7 +187,7 @@ class AgentSession:
|
||||
self.event_stream.close()
|
||||
if self.controller is not None:
|
||||
end_state = self.controller.get_state()
|
||||
end_state.save_to_session(self.sid, self.file_store)
|
||||
end_state.save_to_session(self.sid, self.file_store, self.user_id)
|
||||
await self.controller.close()
|
||||
if self.runtime is not None:
|
||||
self.runtime.close()
|
||||
@@ -371,7 +372,9 @@ class AgentSession:
|
||||
# Use a heuristic to figure out if we should have a state:
|
||||
# if we have events in the stream.
|
||||
try:
|
||||
restored_state = State.restore_from_session(self.sid, self.file_store)
|
||||
restored_state = State.restore_from_session(
|
||||
self.sid, self.file_store, self.user_id
|
||||
)
|
||||
self.logger.debug(f'Restored state from session, sid: {self.sid}')
|
||||
except Exception as e:
|
||||
if self.event_stream.get_latest_event_id() > 0:
|
||||
|
||||
@@ -14,17 +14,16 @@ class Conversation:
|
||||
file_store: FileStore
|
||||
event_stream: EventStream
|
||||
runtime: Runtime
|
||||
user_id: str | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sid: str,
|
||||
file_store: FileStore,
|
||||
config: AppConfig,
|
||||
self, sid: str, file_store: FileStore, config: AppConfig, user_id: str | None
|
||||
):
|
||||
self.sid = sid
|
||||
self.config = config
|
||||
self.file_store = file_store
|
||||
self.event_stream = EventStream(sid, file_store)
|
||||
self.user_id = user_id
|
||||
self.event_stream = EventStream(sid, file_store, user_id)
|
||||
if config.security.security_analyzer:
|
||||
self.security_analyzer = options.SecurityAnalyzers.get(
|
||||
config.security.security_analyzer, SecurityAnalyzer
|
||||
|
||||
@@ -1,25 +1,30 @@
|
||||
CONVERSATION_BASE_DIR = 'sessions'
|
||||
|
||||
|
||||
def get_conversation_dir(sid: str) -> str:
|
||||
return f'{CONVERSATION_BASE_DIR}/{sid}/'
|
||||
def get_conversation_dir(sid: str, user_id: str | None = None) -> str:
|
||||
if user_id:
|
||||
return f'users/{user_id}/conversations/{sid}/'
|
||||
else:
|
||||
return f'{CONVERSATION_BASE_DIR}/{sid}/'
|
||||
|
||||
|
||||
def get_conversation_events_dir(sid: str) -> str:
|
||||
return f'{get_conversation_dir(sid)}events/'
|
||||
def get_conversation_events_dir(sid: str, user_id: str | None = None) -> str:
|
||||
return f'{get_conversation_dir(sid, user_id)}events/'
|
||||
|
||||
|
||||
def get_conversation_event_filename(sid: str, id: int) -> str:
|
||||
return f'{get_conversation_events_dir(sid)}{id}.json'
|
||||
def get_conversation_event_filename(
|
||||
sid: str, id: int, user_id: str | None = None
|
||||
) -> str:
|
||||
return f'{get_conversation_events_dir(sid, user_id)}{id}.json'
|
||||
|
||||
|
||||
def get_conversation_metadata_filename(sid: str) -> str:
|
||||
return f'{get_conversation_dir(sid)}metadata.json'
|
||||
def get_conversation_metadata_filename(sid: str, user_id: str | None = None) -> str:
|
||||
return f'{get_conversation_dir(sid, user_id)}metadata.json'
|
||||
|
||||
|
||||
def get_conversation_init_data_filename(sid: str) -> str:
|
||||
return f'{get_conversation_dir(sid)}init.json'
|
||||
def get_conversation_init_data_filename(sid: str, user_id: str | None = None) -> str:
|
||||
return f'{get_conversation_dir(sid, user_id)}init.json'
|
||||
|
||||
|
||||
def get_conversation_agent_state_filename(sid: str) -> str:
|
||||
return f'{get_conversation_dir(sid)}agent_state.pkl'
|
||||
def get_conversation_agent_state_filename(sid: str, user_id: str | None = None) -> str:
|
||||
return f'{get_conversation_dir(sid, user_id)}agent_state.pkl'
|
||||
|
||||
Reference in New Issue
Block a user