[Fix]: Ensure refresh logic works for restarted conversations in cloud openhands (#7670)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Rohit Malhotra 2025-04-03 11:58:01 -04:00 committed by GitHub
parent d3043ec898
commit f5fa076fdd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 37 additions and 22 deletions

View File

@ -8,6 +8,7 @@ import {
WsClientProvider,
useWsClient,
} from "#/context/ws-client-provider";
import { AuthProvider } from "#/context/auth-context";
describe("Propagate error message", () => {
it("should do nothing when no message was passed from server", () => {
@ -90,9 +91,11 @@ describe("WsClientProvider", () => {
const { getByText } = render(<TestComponent />, {
wrapper: ({ children }) => (
<QueryClientProvider client={new QueryClient()}>
<WsClientProvider conversationId="test-conversation-id">
{children}
</WsClientProvider>
<AuthProvider initialProviderTokens={[]}>
<WsClientProvider conversationId="test-conversation-id">
{children}
</WsClientProvider>
</AuthProvider>
</QueryClientProvider>
),
});

View File

@ -9,6 +9,7 @@ import {
AssistantMessageAction,
UserMessageAction,
} from "#/types/core/actions";
import { useAuth } from "./auth-context";
const isOpenHandsEvent = (event: unknown): event is OpenHandsParsedEvent =>
typeof event === "object" &&
@ -110,6 +111,7 @@ export function WsClientProvider({
);
const [events, setEvents] = React.useState<Record<string, unknown>[]>([]);
const lastEventRef = React.useRef<Record<string, unknown> | null>(null);
const { providerTokensSet } = useAuth();
const messageRateHandler = useRate({ threshold: 250 });
@ -168,6 +170,7 @@ export function WsClientProvider({
const query = {
latest_event_id: lastEvent?.id ?? -1,
conversation_id: conversationId,
providers_set: providerTokensSet,
};
const baseUrl =

View File

@ -14,7 +14,6 @@ from pydantic import (
)
from pydantic.json import pydantic_encoder
from openhands.core.logger import openhands_logger as logger
from openhands.events.action.action import Action
from openhands.events.action.commands import CmdRunAction
from openhands.events.stream import EventStream
@ -293,9 +292,7 @@ class ProviderHandler:
get_latest: Get the latest working token for the providers if True, otherwise get the existing ones
"""
# TODO: We should remove `not get_latest` in the future. More
# details about the error this fixes is in the next comment below
if not self.provider_tokens and not get_latest:
if not self.provider_tokens:
return {}
env_vars: dict[ProviderType, SecretStr] = {}
@ -316,20 +313,6 @@ class ProviderHandler:
if token:
env_vars[provider] = token
# TODO: we have an error where reinitializing the runtime doesn't happen with
# the provider tokens; thus the code above believes that github isn't a provider
# when it really is. We need to share information about current providers set
# for the user when the socket event for connect is sent
if ProviderType.GITHUB not in env_vars and get_latest:
logger.info(
f'Force refresh runtime token for user: {self.external_auth_id}'
)
service = GithubServiceImpl(
external_auth_id=self.external_auth_id,
external_token_manager=self.external_token_manager,
)
env_vars[ProviderType.GITHUB] = await service.get_latest_token()
if not expose_secrets:
return env_vars

View File

@ -1,3 +1,4 @@
from types import MappingProxyType
from urllib.parse import parse_qs
from socketio.exceptions import ConnectionRefusedError
@ -16,6 +17,9 @@ from openhands.events.observation.agent import (
RecallObservation,
)
from openhands.events.serialization import event_to_dict
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderToken
from openhands.integrations.service_types import ProviderType
from openhands.server.session.conversation_init_data import ConversationInitData
from openhands.server.shared import (
SettingsStoreImpl,
config,
@ -27,12 +31,26 @@ from openhands.storage.conversation.conversation_validator import (
)
def create_provider_tokens_object(
providers_set: list[ProviderType],
) -> PROVIDER_TOKEN_TYPE:
provider_information = {}
for provider in providers_set:
provider_information[provider] = ProviderToken(token=None, user_id=None)
return MappingProxyType(provider_information)
@sio.event
async def connect(connection_id: str, environ):
logger.info(f'sio:connect: {connection_id}')
query_params = parse_qs(environ.get('QUERY_STRING', ''))
latest_event_id = int(query_params.get('latest_event_id', [-1])[0])
conversation_id = query_params.get('conversation_id', [None])[0]
providers_raw: list[str] = query_params.get('providers_set', [])
providers_set: list[ProviderType] = [ProviderType(p) for p in providers_raw]
if not conversation_id:
logger.error('No conversation_id in query params')
raise ConnectionRefusedError('No conversation_id in query params')
@ -50,9 +68,17 @@ async def connect(connection_id: str, environ):
raise ConnectionRefusedError(
'Settings not found', {'msg_id': 'CONFIGURATION$SETTINGS_NOT_FOUND'}
)
session_init_args: dict = {}
if settings:
session_init_args = {**settings.__dict__, **session_init_args}
session_init_args['git_provider_tokens'] = create_provider_tokens_object(
providers_set
)
conversation_init_data = ConversationInitData(**session_init_args)
event_stream = await conversation_manager.join_conversation(
conversation_id, connection_id, settings, user_id, github_user_id
conversation_id, connection_id, conversation_init_data, user_id, github_user_id
)
logger.info(
f'Connected to conversation {conversation_id} with connection_id {connection_id}. Replaying event stream...'