Move current user_id to github_user_id and create a new user_id field (#7231)

Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: Rohit Malhotra <rohitvinodmalhotra@gmail.com>
This commit is contained in:
chuckbutkus 2025-03-16 16:32:27 -04:00 committed by GitHub
parent 999a59f938
commit 8074b261d3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 128 additions and 57 deletions

View File

@ -5,6 +5,7 @@ from typing import Any
import httpx
from pydantic import SecretStr
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.service_types import (
AuthenticationError,
GitService,
@ -15,7 +16,7 @@ from openhands.integrations.service_types import (
User,
)
from openhands.utils.import_utils import get_impl
from openhands.core.logger import openhands_logger as logger
class GitHubService(GitService):
BASE_URL = 'https://api.github.com'
@ -25,6 +26,7 @@ class GitHubService(GitService):
def __init__(
self,
user_id: str | None = None,
external_auth_id: str | None = None,
external_auth_token: SecretStr | None = None,
token: SecretStr | None = None,
external_token_manager: bool = False,

View File

@ -97,7 +97,7 @@ class Runtime(FileEditRuntimeMixin):
status_callback: Callable | None = None,
attach_to_existing: bool = False,
headless_mode: bool = False,
github_user_id: str | None = None,
user_id: str | None = None,
):
self.sid = sid
self.event_stream = event_stream
@ -130,7 +130,7 @@ class Runtime(FileEditRuntimeMixin):
self, enable_llm_editor=config.get_agent_config().codeact_enable_llm_editor
)
self.github_user_id = github_user_id
self.user_id = user_id
def setup_initial_env(self) -> None:
if self.attach_to_existing:
@ -220,9 +220,9 @@ class Runtime(FileEditRuntimeMixin):
assert event.timeout is not None
try:
if isinstance(event, CmdRunAction):
if self.github_user_id and '$GITHUB_TOKEN' in event.command:
if self.user_id and '$GITHUB_TOKEN' in event.command:
gh_client = GithubServiceImpl(
user_id=self.github_user_id, external_token_manager=True
external_auth_id=self.user_id, external_token_manager=True
)
token = await gh_client.get_latest_token()
if token:

View File

@ -59,7 +59,7 @@ class ActionExecutionClient(Runtime):
status_callback: Any | None = None,
attach_to_existing: bool = False,
headless_mode: bool = True,
github_user_id: str | None = None,
user_id: str | None = None,
):
self.session = HttpSession()
self.action_semaphore = threading.Semaphore(1) # Ensure one action at a time
@ -75,7 +75,7 @@ class ActionExecutionClient(Runtime):
status_callback,
attach_to_existing,
headless_mode,
github_user_id,
user_id,
)
@abstractmethod

View File

@ -45,7 +45,7 @@ class RemoteRuntime(ActionExecutionClient):
status_callback: Callable | None = None,
attach_to_existing: bool = False,
headless_mode: bool = True,
github_user_id: str | None = None,
user_id: str | None = None,
):
super().__init__(
config,
@ -56,7 +56,7 @@ class RemoteRuntime(ActionExecutionClient):
status_callback,
attach_to_existing,
headless_mode,
github_user_id,
user_id,
)
if self.config.sandbox.api_key is None:
raise ValueError(

View File

@ -46,7 +46,12 @@ class ConversationManager(ABC):
@abstractmethod
async def join_conversation(
self, sid: str, connection_id: str, settings: Settings, user_id: str | None
self,
sid: str,
connection_id: str,
settings: Settings,
user_id: str | None,
github_user_id: str | None,
) -> EventStream | None:
"""Join a conversation and return its event stream."""
@ -74,6 +79,7 @@ class ConversationManager(ABC):
settings: Settings,
user_id: str | None,
initial_user_msg: MessageAction | None = None,
github_user_id: str | None = None,
) -> EventStream:
"""Start an event loop if one is not already running"""

View File

@ -106,7 +106,12 @@ class StandaloneConversationManager(ConversationManager):
return c
async def join_conversation(
self, sid: str, connection_id: str, settings: Settings, user_id: str | None
self,
sid: str,
connection_id: str,
settings: Settings,
user_id: str | None,
github_user_id: str | None,
):
logger.info(
f'join_conversation:{sid}:{connection_id}',
@ -116,7 +121,9 @@ class StandaloneConversationManager(ConversationManager):
self._local_connection_id_to_session_id[connection_id] = sid
event_stream = await self._get_event_stream(sid)
if not event_stream:
return await self.maybe_start_agent_loop(sid, settings, user_id)
return await self.maybe_start_agent_loop(
sid, settings, user_id, github_user_id=github_user_id
)
for event in event_stream.get_events(reverse=True):
if isinstance(event, AgentStateChangedObservation):
if event.agent_state in (
@ -187,14 +194,18 @@ class StandaloneConversationManager(ConversationManager):
logger.error('error_cleaning_stale')
await asyncio.sleep(_CLEANUP_INTERVAL)
async def _get_conversation_store(self, user_id: str | None) -> ConversationStore:
async def _get_conversation_store(
self, user_id: str | None, github_user_id: str | None
) -> ConversationStore:
conversation_store_class = self._conversation_store_class
if not conversation_store_class:
self._conversation_store_class = conversation_store_class = get_impl(
ConversationStore, # type: ignore
self.server_config.conversation_store_class,
)
store = await conversation_store_class.get_instance(self.config, user_id)
store = await conversation_store_class.get_instance(
self.config, user_id, github_user_id
)
return store
async def get_running_agent_loops(
@ -243,6 +254,7 @@ class StandaloneConversationManager(ConversationManager):
settings: Settings,
user_id: str | None,
initial_user_msg: MessageAction | None = None,
github_user_id: str | None = None,
) -> EventStream:
logger.info(f'maybe_start_agent_loop:{sid}', extra={'session_id': sid})
session: Session | None = None
@ -256,7 +268,9 @@ class StandaloneConversationManager(ConversationManager):
extra={'session_id': sid, 'user_id': user_id},
)
# Get the conversations sorted (oldest first)
conversation_store = await self._get_conversation_store(user_id)
conversation_store = await self._get_conversation_store(
user_id, github_user_id
)
conversations = await conversation_store.get_all_metadata(response_ids)
conversations.sort(key=_last_updated_at_key, reverse=True)
@ -277,7 +291,9 @@ class StandaloneConversationManager(ConversationManager):
try:
session.agent_session.event_stream.subscribe(
EventStreamSubscriber.SERVER,
self._create_conversation_update_callback(user_id, sid),
self._create_conversation_update_callback(
user_id, github_user_id, sid
),
UPDATED_AT_CALLBACK_ID,
)
except ValueError:
@ -374,22 +390,23 @@ class StandaloneConversationManager(ConversationManager):
)
def _create_conversation_update_callback(
self, user_id: str | None, conversation_id: str
self, user_id: str | None, github_user_id: str | None, conversation_id: str
) -> Callable:
def callback(*args, **kwargs):
call_async_from_sync(
self._update_timestamp_for_conversation,
GENERAL_TIMEOUT,
user_id,
github_user_id,
conversation_id,
)
return callback
async def _update_timestamp_for_conversation(
self, user_id: str, conversation_id: str
self, user_id: str, github_user_id: str, conversation_id: str
):
conversation_store = await self._get_conversation_store(user_id)
conversation_store = await self._get_conversation_store(user_id, github_user_id)
conversation = await conversation_store.get_metadata(conversation_id)
conversation.last_updated_at = datetime.now(timezone.utc)
await conversation_store.save_metadata(conversation)

View File

@ -35,7 +35,9 @@ async def connect(connection_id: str, environ):
cookies_str = environ.get('HTTP_COOKIE', '')
conversation_validator = ConversationValidatorImpl()
user_id = await conversation_validator.validate(conversation_id, cookies_str)
user_id, github_user_id = await conversation_validator.validate(
conversation_id, cookies_str
)
settings_store = await SettingsStoreImpl.get_instance(config, user_id)
settings = await settings_store.load()
@ -46,7 +48,7 @@ async def connect(connection_id: str, environ):
)
event_stream = await conversation_manager.join_conversation(
conversation_id, connection_id, settings, user_id
conversation_id, connection_id, settings, user_id, github_user_id
)
agent_state_changed = None

View File

@ -10,7 +10,12 @@ from openhands.events.action.message import MessageAction
from openhands.integrations.github.github_service import GithubServiceImpl
from openhands.integrations.provider import ProviderType
from openhands.runtime import get_runtime_cls
from openhands.server.auth import get_provider_tokens, get_access_token, get_github_user_id
from openhands.server.auth import (
get_access_token,
get_github_user_id,
get_provider_tokens,
get_user_id,
)
from openhands.server.data_models.conversation_info import ConversationInfo
from openhands.server.data_models.conversation_info_result_set import (
ConversationInfoResultSet,
@ -73,12 +78,12 @@ async def _create_new_conversation(
logger.warn('Settings not present, not starting conversation')
raise MissingSettingsError('Settings not found')
session_init_args['github_token'] = token or SecretStr('')
session_init_args['provider_token'] = token
session_init_args['selected_repository'] = selected_repository
session_init_args['selected_branch'] = selected_branch
conversation_init_data = ConversationInitData(**session_init_args)
logger.info('Loading conversation store')
conversation_store = await ConversationStoreImpl.get_instance(config, user_id)
conversation_store = await ConversationStoreImpl.get_instance(config, user_id, None)
logger.info('Conversation store loaded')
conversation_id = uuid.uuid4().hex
@ -100,7 +105,8 @@ async def _create_new_conversation(
ConversationMetadata(
conversation_id=conversation_id,
title=conversation_title,
github_user_id=user_id,
user_id=user_id,
github_user_id=None,
selected_repository=selected_repository,
selected_branch=selected_branch,
)
@ -122,7 +128,10 @@ async def _create_new_conversation(
image_urls=image_urls or [],
)
await conversation_manager.maybe_start_agent_loop(
conversation_id, conversation_init_data, user_id, initial_message_action
conversation_id,
conversation_init_data,
user_id,
initial_user_msg=initial_message_action,
)
logger.info(f'Finished initializing conversation {conversation_id}')
@ -158,7 +167,7 @@ async def new_conversation(request: Request, data: InitSessionRequest):
try:
# Create conversation with initial message
conversation_id = await _create_new_conversation(
user_id,
get_user_id(request),
github_token,
selected_repository,
selected_branch,
@ -197,7 +206,7 @@ async def search_conversations(
limit: int = 20,
) -> ConversationInfoResultSet:
conversation_store = await ConversationStoreImpl.get_instance(
config, get_github_user_id(request)
config, get_user_id(request), get_github_user_id(request)
)
conversation_metadata_result_set = await conversation_store.search(page_id, limit)
@ -216,7 +225,7 @@ async def search_conversations(
conversation.conversation_id for conversation in filtered_results
)
running_conversations = await conversation_manager.get_running_agent_loops(
get_github_user_id(request), set(conversation_ids)
get_user_id(request), set(conversation_ids)
)
result = ConversationInfoResultSet(
results=await wait_all(
@ -236,7 +245,7 @@ async def get_conversation(
conversation_id: str, request: Request
) -> ConversationInfo | None:
conversation_store = await ConversationStoreImpl.get_instance(
config, get_github_user_id(request)
config, get_user_id(request), get_github_user_id(request)
)
try:
metadata = await conversation_store.get_metadata(conversation_id)
@ -252,7 +261,7 @@ async def update_conversation(
request: Request, conversation_id: str, title: str = Body(embed=True)
) -> bool:
conversation_store = await ConversationStoreImpl.get_instance(
config, get_github_user_id(request)
config, get_user_id(request), get_github_user_id(request)
)
metadata = await conversation_store.get_metadata(conversation_id)
if not metadata:
@ -268,7 +277,7 @@ async def delete_conversation(
request: Request,
) -> bool:
conversation_store = await ConversationStoreImpl.get_instance(
config, get_github_user_id(request)
config, get_user_id(request), get_github_user_id(request)
)
try:
await conversation_store.get_metadata(conversation_id)

View File

@ -53,7 +53,7 @@ class AgentSession:
sid: str,
file_store: FileStore,
status_callback: Callable | None = None,
github_user_id: str | None = None,
user_id: str | None = None,
):
"""Initializes a new instance of the Session class
@ -66,9 +66,9 @@ class AgentSession:
self.event_stream = EventStream(sid, file_store)
self.file_store = file_store
self._status_callback = status_callback
self.github_user_id = github_user_id
self.user_id = user_id
self.logger = OpenHandsLoggerAdapter(
extra={'session_id': sid, 'user_id': github_user_id}
extra={'session_id': sid, 'user_id': user_id}
)
async def start(
@ -241,7 +241,7 @@ class AgentSession:
kwargs = {}
if runtime_cls == RemoteRuntime:
kwargs['github_user_id'] = self.github_user_id
kwargs['user_id'] = self.user_id
self.runtime = runtime_cls(
config=config,

View File

@ -8,6 +8,6 @@ class ConversationInitData(Settings):
Session initialization data for the web environment - a deep copy of the global config is made and then overridden with this data.
"""
github_token: SecretStr | None = Field(default=None)
provider_token: SecretStr | None = Field(default=None)
selected_repository: str | None = Field(default=None)
selected_branch: str | None = Field(default=None)

View File

@ -61,7 +61,7 @@ class Session:
sid,
file_store,
status_callback=self.queue_status_message,
github_user_id=user_id,
user_id=user_id,
)
self.agent_session.event_stream.subscribe(
EventStreamSubscriber.SERVER, self.on_event, self.sid
@ -123,11 +123,11 @@ class Session:
agent = Agent.get_cls(agent_cls)(llm, agent_config)
github_token = None
provider_token = None
selected_repository = None
selected_branch = None
if isinstance(settings, ConversationInitData):
github_token = settings.github_token
provider_token = settings.provider_token
selected_repository = settings.selected_repository
selected_branch = settings.selected_branch
@ -140,7 +140,7 @@ class Session:
max_budget_per_task=self.config.max_budget_per_task,
agent_to_llm_config=self.config.get_agent_to_llm_config_map(),
agent_configs=self.config.get_agent_configs(),
github_token=github_token,
github_token=provider_token,
selected_repository=selected_repository,
selected_branch=selected_branch,
initial_message=initial_message,

View File

@ -43,7 +43,7 @@ class Settings(BaseModel):
if context and context.get('expose_secrets', False):
return llm_api_key.get_secret_value()
return pydantic_encoder(llm_api_key)
return pydantic_encoder(llm_api_key) if llm_api_key else None
@staticmethod
def _convert_token_value(

View File

@ -12,25 +12,36 @@ from openhands.utils.async_utils import wait_all
class ConversationStore(ABC):
"""
Storage for conversation metadata. May or may not support multiple users depending on the environment
"""
"""Storage for conversation metadata. May or may not support multiple users depending on the environment."""
@abstractmethod
async def save_metadata(self, metadata: ConversationMetadata) -> None:
"""Store conversation metadata"""
"""Store conversation metadata."""
@abstractmethod
async def get_metadata(self, conversation_id: str) -> ConversationMetadata:
"""Load conversation metadata"""
"""Load conversation metadata."""
async def validate_metadata(
self, conversation_id: str, user_id: str, github_user_id: str
) -> bool:
"""Validate that conversation belongs to the current user."""
# TODO: remove github_user_id after transition to Keycloak is complete.
metadata = await self.get_metadata(conversation_id)
if (not metadata.user_id and not metadata.github_user_id) or (
metadata.user_id != user_id and metadata.github_user_id != github_user_id
):
return False
else:
return True
@abstractmethod
async def delete_metadata(self, conversation_id: str) -> None:
"""delete conversation metadata"""
"""Delete conversation metadata."""
@abstractmethod
async def exists(self, conversation_id: str) -> bool:
"""Check if conversation exists"""
"""Check if conversation exists."""
@abstractmethod
async def search(
@ -49,6 +60,6 @@ class ConversationStore(ABC):
@classmethod
@abstractmethod
async def get_instance(
cls, config: AppConfig, user_id: str | None
cls, config: AppConfig, user_id: str | None, github_user_id: str | None
) -> ConversationStore:
"""Get a store for the user represented by the token given"""

View File

@ -7,7 +7,7 @@ class ConversationValidator:
"""Storage for conversation metadata. May or may not support multiple users depending on the environment."""
async def validate(self, conversation_id: str, cookies_str: str):
return None
return None, None
conversation_validator_cls = os.environ.get(

View File

@ -101,7 +101,7 @@ class FileConversationStore(ConversationStore):
@classmethod
async def get_instance(
cls, config: AppConfig, user_id: str | None
cls, config: AppConfig, user_id: str | None, github_user_id: str | None
) -> FileConversationStore:
file_store = get_file_store(config.file_store, config.file_store_path)
return FileConversationStore(file_store)

View File

@ -5,6 +5,7 @@ from datetime import datetime, timezone
@dataclass
class ConversationMetadata:
conversation_id: str
user_id: str | None
github_user_id: str | None
selected_repository: str | None
selected_branch: str | None = None

View File

@ -32,6 +32,7 @@ def _patch_store():
'selected_repository': 'foobar',
'conversation_id': 'some_conversation_id',
'github_user_id': '12345',
'user_id': '12345',
'created_at': '2025-01-01T00:00:00+00:00',
'last_updated_at': '2025-01-01T00:01:00+00:00',
}

View File

@ -13,7 +13,8 @@ async def test_load_store():
store = FileConversationStore(InMemoryFileStore({}))
expected = ConversationMetadata(
conversation_id='some-conversation-id',
github_user_id='some-user-id',
user_id='some-user-id',
github_user_id='12345',
selected_repository='some-repo',
title="Let's talk about trains",
)
@ -31,6 +32,7 @@ async def test_load_int_user_id():
{
'conversation_id': 'some-conversation-id',
'github_user_id': 12345,
'user_id': '67890',
'selected_repository': 'some-repo',
'title': "Let's talk about trains",
'created_at': '2025-01-16T19:51:04.886331Z',
@ -41,6 +43,7 @@ async def test_load_int_user_id():
)
found = await store.get_metadata('some-conversation-id')
assert found.github_user_id == '12345'
assert found.user_id == '67890'
@pytest.mark.asyncio
@ -61,6 +64,7 @@ async def test_search_basic():
{
'conversation_id': 'conv1',
'github_user_id': '123',
'user_id': '123',
'selected_repository': 'repo1',
'title': 'First conversation',
'created_at': '2025-01-16T19:51:04Z',
@ -70,6 +74,7 @@ async def test_search_basic():
{
'conversation_id': 'conv2',
'github_user_id': '123',
'user_id': '123',
'selected_repository': 'repo1',
'title': 'Second conversation',
'created_at': '2025-01-17T19:51:04Z',
@ -79,6 +84,7 @@ async def test_search_basic():
{
'conversation_id': 'conv3',
'github_user_id': '123',
'user_id': '123',
'selected_repository': 'repo1',
'title': 'Third conversation',
'created_at': '2025-01-15T19:51:04Z',
@ -107,6 +113,7 @@ async def test_search_pagination():
{
'conversation_id': f'conv{i}',
'github_user_id': '123',
'user_id': '123',
'selected_repository': 'repo1',
'title': f'Conversation {i}',
'created_at': f'2025-01-{15+i}T19:51:04Z',
@ -148,6 +155,7 @@ async def test_search_with_invalid_conversation():
{
'conversation_id': 'conv1',
'github_user_id': '123',
'user_id': '123',
'selected_repository': 'repo1',
'title': 'Valid conversation',
'created_at': '2025-01-16T19:51:04Z',
@ -176,6 +184,7 @@ async def test_get_all_metadata():
{
'conversation_id': 'conv1',
'github_user_id': '123',
'user_id': '123',
'selected_repository': 'repo1',
'title': 'First conversation',
'created_at': '2025-01-16T19:51:04Z',
@ -185,6 +194,7 @@ async def test_get_all_metadata():
{
'conversation_id': 'conv2',
'github_user_id': '123',
'user_id': '123',
'selected_repository': 'repo1',
'title': 'Second conversation',
'created_at': '2025-01-17T19:51:04Z',

View File

@ -49,6 +49,7 @@ async def test_iterate_single_page():
{
'conversation_id': 'conv1',
'github_user_id': '123',
'user_id': '123',
'selected_repository': 'repo1',
'title': 'First conversation',
'created_at': '2025-01-16T19:51:04Z',
@ -58,6 +59,7 @@ async def test_iterate_single_page():
{
'conversation_id': 'conv2',
'github_user_id': '123',
'user_id': '123',
'selected_repository': 'repo1',
'title': 'Second conversation',
'created_at': '2025-01-17T19:51:04Z',
@ -86,6 +88,7 @@ async def test_iterate_multiple_pages():
{
'conversation_id': f'conv{i}',
'github_user_id': '123',
'user_id': '123',
'selected_repository': 'repo1',
'title': f'Conversation {i}',
'created_at': f'2025-01-{15+i}T19:51:04Z',
@ -120,6 +123,7 @@ async def test_iterate_with_invalid_conversation():
{
'conversation_id': 'conv1',
'github_user_id': '123',
'user_id': '123',
'selected_repository': 'repo1',
'title': 'Valid conversation',
'created_at': '2025-01-16T19:51:04Z',

View File

@ -61,7 +61,7 @@ async def test_init_new_local_session():
'new-session-id', ConversationInitData(), 1
)
await conversation_manager.join_conversation(
'new-session-id', 'new-session-id', ConversationInitData(), 1
'new-session-id', 'new-session-id', ConversationInitData(), 1, '12345'
)
assert session_instance.initialize_agent.call_count == 1
assert sio.enter_room.await_count == 1
@ -93,10 +93,18 @@ async def test_join_local_session():
'new-session-id', ConversationInitData(), None
)
await conversation_manager.join_conversation(
'new-session-id', 'new-session-id', ConversationInitData(), None
'new-session-id',
'new-session-id',
ConversationInitData(),
None,
'12345',
)
await conversation_manager.join_conversation(
'new-session-id', 'new-session-id', ConversationInitData(), None
'new-session-id',
'new-session-id',
ConversationInitData(),
None,
'12345',
)
assert session_instance.initialize_agent.call_count == 1
assert sio.enter_room.await_count == 2
@ -128,7 +136,7 @@ async def test_add_to_local_event_stream():
'new-session-id', ConversationInitData(), 1
)
await conversation_manager.join_conversation(
'new-session-id', 'connection-id', ConversationInitData(), 1
'new-session-id', 'connection-id', ConversationInitData(), 1, '12345'
)
await conversation_manager.send_to_event_stream(
'connection-id', {'event_type': 'some_event'}