[Feat]: Add experiment manager (#8820)

This commit is contained in:
Rohit Malhotra 2025-06-05 14:49:20 -04:00 committed by GitHub
parent 412e265745
commit 93b1276768
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 35 additions and 3 deletions

View File

@ -0,0 +1,19 @@
import os
from openhands.server.session.conversation_init_data import ConversationInitData
from openhands.utils.import_utils import get_impl
class ExperimentManager:
@staticmethod
def run_conversation_variant_test(
user_id: str, conversation_id: str, conversation_settings: ConversationInitData
) -> ConversationInitData:
return conversation_settings
experiment_manager_cls = os.environ.get(
'OPENHANDS_EXPERIMENT_MANAGER_CLS',
'openhands.experiments.experiment_manager.ExperimentManager',
)
ExperimentManagerImpl = get_impl(ExperimentManager, experiment_manager_cls)

View File

@ -19,6 +19,7 @@ from openhands.events.observation.agent import (
AgentStateChangedObservation,
)
from openhands.events.serialization import event_to_dict
from openhands.experiments.experiment_manager import ExperimentManagerImpl
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderToken
from openhands.integrations.service_types import ProviderType
from openhands.server.session.conversation_init_data import ConversationInitData
@ -49,7 +50,7 @@ def create_provider_tokens_object(
async def setup_init_convo_settings(
user_id: str | None, providers_set: list[ProviderType]
user_id: str | None, conversation_id: str, providers_set: list[ProviderType]
) -> ConversationInitData:
settings_store = await SettingsStoreImpl.get_instance(config, user_id)
settings = await settings_store.load()
@ -73,7 +74,11 @@ async def setup_init_convo_settings(
if user_secrets:
session_init_args['custom_secrets'] = user_secrets.custom_secrets
return ConversationInitData(**session_init_args)
convo_init_data = ConversationInitData(**session_init_args)
# We should recreate the same experiment conditions when restarting a conversation
return ExperimentManagerImpl.run_conversation_variant_test(
user_id, conversation_id, convo_init_data
)
@sio.event
@ -119,7 +124,9 @@ async def connect(connection_id: str, environ: dict) -> None:
f'User {user_id} is allowed to connect to conversation {conversation_id}'
)
conversation_init_data = await setup_init_convo_settings(user_id, providers_set)
conversation_init_data = await setup_init_convo_settings(
user_id, conversation_id, providers_set
)
agent_loop_info = await conversation_manager.join_conversation(
conversation_id,
connection_id,

View File

@ -4,6 +4,7 @@ from typing import Any
from openhands.core.logger import openhands_logger as logger
from openhands.events.action.message import MessageAction
from openhands.experiments.experiment_manager import ExperimentManagerImpl
from openhands.integrations.provider import (
CUSTOM_SECRETS_TYPE_WITH_JSON_SCHEMA,
PROVIDER_TOKEN_TYPE,
@ -78,6 +79,8 @@ async def create_new_conversation(
session_init_args['git_provider'] = git_provider
session_init_args['conversation_instructions'] = conversation_instructions
conversation_init_data = ConversationInitData(**session_init_args)
logger.info('Loading conversation store')
conversation_store = await ConversationStoreImpl.get_instance(config, user_id)
logger.info('ServerConversation store loaded')
@ -93,6 +96,7 @@ async def create_new_conversation(
extra={'user_id': user_id, 'session_id': conversation_id},
)
conversation_init_data = ExperimentManagerImpl.run_conversation_variant_test(user_id, conversation_id, conversation_init_data)
conversation_title = get_default_conversation_title(conversation_id)
logger.info(f'Saving metadata for conversation {conversation_id}')
@ -105,6 +109,7 @@ async def create_new_conversation(
selected_repository=selected_repository,
selected_branch=selected_branch,
git_provider=git_provider,
llm_model=settings.llm_model,
)
)

View File

@ -25,6 +25,7 @@ class ConversationMetadata:
trigger: ConversationTrigger | None = None
pr_number: list[int] = field(default_factory=list)
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
llm_model: str | None = None
# Cost and token metrics
accumulated_cost: float = 0.0
prompt_tokens: int = 0