mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
[Fix]: Ensure refresh logic works for restarted conversations in cloud openhands (#7670)
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
parent
d3043ec898
commit
f5fa076fdd
@ -8,6 +8,7 @@ import {
|
||||
WsClientProvider,
|
||||
useWsClient,
|
||||
} from "#/context/ws-client-provider";
|
||||
import { AuthProvider } from "#/context/auth-context";
|
||||
|
||||
describe("Propagate error message", () => {
|
||||
it("should do nothing when no message was passed from server", () => {
|
||||
@ -90,9 +91,11 @@ describe("WsClientProvider", () => {
|
||||
const { getByText } = render(<TestComponent />, {
|
||||
wrapper: ({ children }) => (
|
||||
<QueryClientProvider client={new QueryClient()}>
|
||||
<WsClientProvider conversationId="test-conversation-id">
|
||||
{children}
|
||||
</WsClientProvider>
|
||||
<AuthProvider initialProviderTokens={[]}>
|
||||
<WsClientProvider conversationId="test-conversation-id">
|
||||
{children}
|
||||
</WsClientProvider>
|
||||
</AuthProvider>
|
||||
</QueryClientProvider>
|
||||
),
|
||||
});
|
||||
|
||||
@ -9,6 +9,7 @@ import {
|
||||
AssistantMessageAction,
|
||||
UserMessageAction,
|
||||
} from "#/types/core/actions";
|
||||
import { useAuth } from "./auth-context";
|
||||
|
||||
const isOpenHandsEvent = (event: unknown): event is OpenHandsParsedEvent =>
|
||||
typeof event === "object" &&
|
||||
@ -110,6 +111,7 @@ export function WsClientProvider({
|
||||
);
|
||||
const [events, setEvents] = React.useState<Record<string, unknown>[]>([]);
|
||||
const lastEventRef = React.useRef<Record<string, unknown> | null>(null);
|
||||
const { providerTokensSet } = useAuth();
|
||||
|
||||
const messageRateHandler = useRate({ threshold: 250 });
|
||||
|
||||
@ -168,6 +170,7 @@ export function WsClientProvider({
|
||||
const query = {
|
||||
latest_event_id: lastEvent?.id ?? -1,
|
||||
conversation_id: conversationId,
|
||||
providers_set: providerTokensSet,
|
||||
};
|
||||
|
||||
const baseUrl =
|
||||
|
||||
@ -14,7 +14,6 @@ from pydantic import (
|
||||
)
|
||||
from pydantic.json import pydantic_encoder
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action.action import Action
|
||||
from openhands.events.action.commands import CmdRunAction
|
||||
from openhands.events.stream import EventStream
|
||||
@ -293,9 +292,7 @@ class ProviderHandler:
|
||||
get_latest: Get the latest working token for the providers if True, otherwise get the existing ones
|
||||
"""
|
||||
|
||||
# TODO: We should remove `not get_latest` in the future. More
|
||||
# details about the error this fixes is in the next comment below
|
||||
if not self.provider_tokens and not get_latest:
|
||||
if not self.provider_tokens:
|
||||
return {}
|
||||
|
||||
env_vars: dict[ProviderType, SecretStr] = {}
|
||||
@ -316,20 +313,6 @@ class ProviderHandler:
|
||||
if token:
|
||||
env_vars[provider] = token
|
||||
|
||||
# TODO: we have an error where reinitializing the runtime doesn't happen with
|
||||
# the provider tokens; thus the code above believes that github isn't a provider
|
||||
# when it really is. We need to share information about current providers set
|
||||
# for the user when the socket event for connect is sent
|
||||
if ProviderType.GITHUB not in env_vars and get_latest:
|
||||
logger.info(
|
||||
f'Force refresh runtime token for user: {self.external_auth_id}'
|
||||
)
|
||||
service = GithubServiceImpl(
|
||||
external_auth_id=self.external_auth_id,
|
||||
external_token_manager=self.external_token_manager,
|
||||
)
|
||||
env_vars[ProviderType.GITHUB] = await service.get_latest_token()
|
||||
|
||||
if not expose_secrets:
|
||||
return env_vars
|
||||
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from types import MappingProxyType
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
from socketio.exceptions import ConnectionRefusedError
|
||||
@ -16,6 +17,9 @@ from openhands.events.observation.agent import (
|
||||
RecallObservation,
|
||||
)
|
||||
from openhands.events.serialization import event_to_dict
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderToken
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.server.session.conversation_init_data import ConversationInitData
|
||||
from openhands.server.shared import (
|
||||
SettingsStoreImpl,
|
||||
config,
|
||||
@ -27,12 +31,26 @@ from openhands.storage.conversation.conversation_validator import (
|
||||
)
|
||||
|
||||
|
||||
def create_provider_tokens_object(
|
||||
providers_set: list[ProviderType],
|
||||
) -> PROVIDER_TOKEN_TYPE:
|
||||
provider_information = {}
|
||||
|
||||
for provider in providers_set:
|
||||
provider_information[provider] = ProviderToken(token=None, user_id=None)
|
||||
|
||||
return MappingProxyType(provider_information)
|
||||
|
||||
|
||||
@sio.event
|
||||
async def connect(connection_id: str, environ):
|
||||
logger.info(f'sio:connect: {connection_id}')
|
||||
query_params = parse_qs(environ.get('QUERY_STRING', ''))
|
||||
latest_event_id = int(query_params.get('latest_event_id', [-1])[0])
|
||||
conversation_id = query_params.get('conversation_id', [None])[0]
|
||||
providers_raw: list[str] = query_params.get('providers_set', [])
|
||||
providers_set: list[ProviderType] = [ProviderType(p) for p in providers_raw]
|
||||
|
||||
if not conversation_id:
|
||||
logger.error('No conversation_id in query params')
|
||||
raise ConnectionRefusedError('No conversation_id in query params')
|
||||
@ -50,9 +68,17 @@ async def connect(connection_id: str, environ):
|
||||
raise ConnectionRefusedError(
|
||||
'Settings not found', {'msg_id': 'CONFIGURATION$SETTINGS_NOT_FOUND'}
|
||||
)
|
||||
session_init_args: dict = {}
|
||||
if settings:
|
||||
session_init_args = {**settings.__dict__, **session_init_args}
|
||||
|
||||
session_init_args['git_provider_tokens'] = create_provider_tokens_object(
|
||||
providers_set
|
||||
)
|
||||
conversation_init_data = ConversationInitData(**session_init_args)
|
||||
|
||||
event_stream = await conversation_manager.join_conversation(
|
||||
conversation_id, connection_id, settings, user_id, github_user_id
|
||||
conversation_id, connection_id, conversation_init_data, user_id, github_user_id
|
||||
)
|
||||
logger.info(
|
||||
f'Connected to conversation {conversation_id} with connection_id {connection_id}. Replaying event stream...'
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user