mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 05:37:20 +08:00
SAAS: Introducing orgs (phase 1) (#11265)
Co-authored-by: openhands <openhands@all-hands.dev> Co-authored-by: rohitvinodmalhotra@gmail.com <rohitvinodmalhotra@gmail.com> Co-authored-by: Hiep Le <69354317+hieptl@users.noreply.github.com> Co-authored-by: Tim O'Farrell <tofarr@gmail.com>
This commit is contained in:
@@ -69,8 +69,6 @@ class StoredConversationMetadata(Base): # type: ignore
|
||||
conversation_id = Column(
|
||||
String, primary_key=True, default=lambda: str(uuid.uuid4())
|
||||
)
|
||||
github_user_id = Column(String, nullable=True) # The GitHub user ID
|
||||
user_id = Column(String, nullable=False) # The Keycloak User ID
|
||||
selected_repository = Column(String, nullable=True)
|
||||
selected_branch = Column(String, nullable=True)
|
||||
git_provider = Column(
|
||||
@@ -199,10 +197,9 @@ class SQLAppConversationInfoService(AppConversationInfoService):
|
||||
updated_at__lt: datetime | None = None,
|
||||
) -> int:
|
||||
"""Count sandboxed conversations matching the given filters."""
|
||||
query = select(func.count(StoredConversationMetadata.conversation_id))
|
||||
user_id = await self.user_context.get_user_id()
|
||||
if user_id:
|
||||
query = query.where(StoredConversationMetadata.user_id == user_id)
|
||||
query = select(func.count(StoredConversationMetadata.conversation_id)).where(
|
||||
StoredConversationMetadata.conversation_version == 'V1'
|
||||
)
|
||||
|
||||
query = self._apply_filters(
|
||||
query=query,
|
||||
@@ -319,22 +316,11 @@ class SQLAppConversationInfoService(AppConversationInfoService):
|
||||
async def save_app_conversation_info(
|
||||
self, info: AppConversationInfo
|
||||
) -> AppConversationInfo:
|
||||
user_id = await self.user_context.get_user_id()
|
||||
if user_id:
|
||||
query = select(StoredConversationMetadata).where(
|
||||
StoredConversationMetadata.conversation_id == str(info.id)
|
||||
)
|
||||
result = await self.db_session.execute(query)
|
||||
existing = result.scalar_one_or_none()
|
||||
assert existing is None or existing.user_id == user_id
|
||||
|
||||
metrics = info.metrics or MetricsSnapshot()
|
||||
usage = metrics.accumulated_token_usage or TokenUsage()
|
||||
|
||||
stored = StoredConversationMetadata(
|
||||
conversation_id=str(info.id),
|
||||
github_user_id=None, # TODO: Should we add this to the conversation info?
|
||||
user_id=info.created_by_user_id or '',
|
||||
selected_repository=info.selected_repository,
|
||||
selected_branch=info.selected_branch,
|
||||
git_provider=info.git_provider.value if info.git_provider else None,
|
||||
@@ -342,7 +328,7 @@ class SQLAppConversationInfoService(AppConversationInfoService):
|
||||
last_updated_at=info.updated_at,
|
||||
created_at=info.created_at,
|
||||
trigger=info.trigger.value if info.trigger else None,
|
||||
pr_number=info.pr_number,
|
||||
pr_number=info.pr_number or [],
|
||||
# Cost and token metrics
|
||||
accumulated_cost=metrics.accumulated_cost,
|
||||
prompt_tokens=usage.prompt_tokens,
|
||||
@@ -496,11 +482,6 @@ class SQLAppConversationInfoService(AppConversationInfoService):
|
||||
query = select(StoredConversationMetadata).where(
|
||||
StoredConversationMetadata.conversation_version == 'V1'
|
||||
)
|
||||
user_id = await self.user_context.get_user_id()
|
||||
if user_id:
|
||||
query = query.where(
|
||||
StoredConversationMetadata.user_id == user_id,
|
||||
)
|
||||
return query
|
||||
|
||||
def _to_info(
|
||||
@@ -535,7 +516,7 @@ class SQLAppConversationInfoService(AppConversationInfoService):
|
||||
|
||||
return AppConversationInfo(
|
||||
id=UUID(stored.conversation_id),
|
||||
created_by_user_id=stored.user_id if stored.user_id else None,
|
||||
created_by_user_id=None, # User ID is now stored in ConversationMetadataSaas
|
||||
sandbox_id=stored.sandbox_id,
|
||||
selected_repository=stored.selected_repository,
|
||||
selected_branch=stored.selected_branch,
|
||||
@@ -580,13 +561,6 @@ class SQLAppConversationInfoService(AppConversationInfoService):
|
||||
StoredConversationMetadata.conversation_id == str(conversation_id)
|
||||
)
|
||||
|
||||
# Apply user security filtering - only allow deletion of conversations owned by the current user
|
||||
user_id = await self.user_context.get_user_id()
|
||||
if user_id:
|
||||
delete_query = delete_query.where(
|
||||
StoredConversationMetadata.user_id == user_id
|
||||
)
|
||||
|
||||
# Execute the secure delete query
|
||||
result = await self.db_session.execute(delete_query)
|
||||
|
||||
|
||||
46
openhands/app_server/app_lifespan/alembic/versions/005.py
Normal file
46
openhands/app_server/app_lifespan/alembic/versions/005.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""Update conversation_metadata table to match StoredConversationMetadata dataclass
|
||||
|
||||
Revision ID: 005
|
||||
Revises: 004
|
||||
Create Date: 2025-11-11 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '005'
|
||||
down_revision: Union[str, Sequence[str], None] = '004'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
with op.batch_alter_table('conversation_metadata') as batch_op:
|
||||
# Drop columns not in StoredConversationMetadata dataclass
|
||||
batch_op.drop_column('github_user_id')
|
||||
|
||||
# Alter user_id to become nullable
|
||||
batch_op.alter_column(
|
||||
'user_id',
|
||||
existing_type=sa.String(),
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
with op.batch_alter_table('conversation_metadata') as batch_op:
|
||||
# Add back removed column
|
||||
batch_op.add_column(sa.Column('github_user_id', sa.String(), nullable=True))
|
||||
|
||||
# Restore NOT NULL constraint
|
||||
batch_op.alter_column(
|
||||
'user_id',
|
||||
existing_type=sa.String(),
|
||||
nullable=False,
|
||||
)
|
||||
@@ -233,7 +233,20 @@ def config_from_env() -> AppServerConfig:
|
||||
config.sandbox_spec = DockerSandboxSpecServiceInjector()
|
||||
|
||||
if config.app_conversation_info is None:
|
||||
config.app_conversation_info = SQLAppConversationInfoServiceInjector()
|
||||
# Use enterprise injector if running in SAAS mode
|
||||
if 'saas' in (os.getenv('OPENHANDS_CONFIG_CLS') or '').lower():
|
||||
try:
|
||||
# Import enterprise injector dynamically
|
||||
from enterprise.server.utils.saas_app_conversation_info_injector import (
|
||||
SaasAppConversationInfoServiceInjector,
|
||||
)
|
||||
|
||||
config.app_conversation_info = SaasAppConversationInfoServiceInjector()
|
||||
except ImportError:
|
||||
# Fallback to OSS injector if enterprise module is not available
|
||||
config.app_conversation_info = SQLAppConversationInfoServiceInjector()
|
||||
else:
|
||||
config.app_conversation_info = SQLAppConversationInfoServiceInjector()
|
||||
|
||||
if config.app_conversation_start_task is None:
|
||||
config.app_conversation_start_task = (
|
||||
|
||||
@@ -153,15 +153,15 @@ class JwtService:
|
||||
|
||||
# Add standard JWT claims
|
||||
now = utc_now()
|
||||
if expires_in is None:
|
||||
expires_in = timedelta(hours=1)
|
||||
|
||||
jwt_payload = {
|
||||
**payload,
|
||||
'iat': int(now.timestamp()),
|
||||
'exp': int((now + expires_in).timestamp()),
|
||||
}
|
||||
|
||||
# Only add exp if expires_in is provided
|
||||
if expires_in is not None:
|
||||
jwt_payload['exp'] = int((now + expires_in).timestamp())
|
||||
|
||||
# Get the raw key for JWE encryption and derive a 256-bit key
|
||||
secret_key = self._keys[key_id].key.get_secret_value()
|
||||
key_bytes = secret_key.encode() if isinstance(secret_key, str) else secret_key
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from openhands.events.event import Event, EventSource, RecallType
|
||||
from openhands.events.event import Event, EventSource
|
||||
from openhands.events.recall_type import RecallType
|
||||
from openhands.events.stream import EventStream, EventStreamSubscriber
|
||||
|
||||
__all__ = [
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Any
|
||||
|
||||
from openhands.core.schema import ActionType
|
||||
from openhands.events.action.action import Action
|
||||
from openhands.events.event import RecallType
|
||||
from openhands.events.recall_type import RecallType
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -22,16 +22,6 @@ class FileReadSource(str, Enum):
|
||||
DEFAULT = 'default'
|
||||
|
||||
|
||||
class RecallType(str, Enum):
|
||||
"""The type of information that can be retrieved from microagents."""
|
||||
|
||||
WORKSPACE_CONTEXT = 'workspace_context'
|
||||
"""Workspace context (repo instructions, runtime, etc.)"""
|
||||
|
||||
KNOWLEDGE = 'knowledge'
|
||||
"""A knowledge microagent."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class Event:
|
||||
INVALID_ID = -1
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from openhands.events.event import RecallType
|
||||
from openhands.events.observation.agent import (
|
||||
AgentCondensationObservation,
|
||||
AgentStateChangedObservation,
|
||||
@@ -28,6 +27,7 @@ from openhands.events.observation.observation import Observation
|
||||
from openhands.events.observation.reject import UserRejectObservation
|
||||
from openhands.events.observation.success import SuccessObservation
|
||||
from openhands.events.observation.task_tracking import TaskTrackingObservation
|
||||
from openhands.events.recall_type import RecallType
|
||||
|
||||
__all__ = [
|
||||
'Observation',
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from openhands.core.schema import ObservationType
|
||||
from openhands.events.event import RecallType
|
||||
from openhands.events.observation.observation import Observation
|
||||
from openhands.events.recall_type import RecallType
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
11
openhands/events/recall_type.py
Normal file
11
openhands/events/recall_type.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class RecallType(str, Enum):
|
||||
"""The type of information that can be retrieved from microagents."""
|
||||
|
||||
WORKSPACE_CONTEXT = 'workspace_context'
|
||||
"""Workspace context (repo instructions, runtime, etc.)"""
|
||||
|
||||
KNOWLEDGE = 'knowledge'
|
||||
"""A knowledge microagent."""
|
||||
@@ -1,7 +1,6 @@
|
||||
import copy
|
||||
from typing import Any
|
||||
|
||||
from openhands.events.event import RecallType
|
||||
from openhands.events.observation.agent import (
|
||||
AgentCondensationObservation,
|
||||
AgentStateChangedObservation,
|
||||
@@ -32,6 +31,7 @@ from openhands.events.observation.observation import Observation
|
||||
from openhands.events.observation.reject import UserRejectObservation
|
||||
from openhands.events.observation.success import SuccessObservation
|
||||
from openhands.events.observation.task_tracking import TaskTrackingObservation
|
||||
from openhands.events.recall_type import RecallType
|
||||
|
||||
observations = (
|
||||
NullObservation,
|
||||
|
||||
@@ -28,7 +28,7 @@ from openhands.integrations.service_types import (
|
||||
from openhands.utils.import_utils import get_impl
|
||||
|
||||
|
||||
class AzureDevOpsServiceImpl(
|
||||
class AzureDevOpsService(
|
||||
AzureDevOpsResolverMixin,
|
||||
AzureDevOpsReposMixin,
|
||||
AzureDevOpsBranchesMixin,
|
||||
@@ -242,8 +242,34 @@ class AzureDevOpsServiceImpl(
|
||||
# Dynamic class loading to support custom implementations (e.g., SaaS)
|
||||
azure_devops_service_cls = os.environ.get(
|
||||
'OPENHANDS_AZURE_DEVOPS_SERVICE_CLS',
|
||||
'openhands.integrations.azure_devops.azure_devops_service.AzureDevOpsServiceImpl',
|
||||
)
|
||||
AzureDevOpsServiceImpl = get_impl( # type: ignore[misc]
|
||||
AzureDevOpsServiceImpl, azure_devops_service_cls
|
||||
'openhands.integrations.azure_devops.azure_devops_service.AzureDevOpsService',
|
||||
)
|
||||
|
||||
# Lazy loading to avoid circular imports
|
||||
_azure_devops_service_impl = None
|
||||
|
||||
|
||||
def get_azure_devops_service_impl():
|
||||
"""Get the Azure DevOps service implementation with lazy loading."""
|
||||
global _azure_devops_service_impl
|
||||
if _azure_devops_service_impl is None:
|
||||
_azure_devops_service_impl = get_impl( # type: ignore[misc]
|
||||
AzureDevOpsService, azure_devops_service_cls
|
||||
)
|
||||
return _azure_devops_service_impl
|
||||
|
||||
|
||||
# For backward compatibility, provide the implementation as a property
|
||||
class _AzureDevOpsServiceImplProxy:
|
||||
"""Proxy class to provide lazy loading for AzureDevOpsServiceImpl."""
|
||||
|
||||
def __getattr__(self, name):
|
||||
impl = get_azure_devops_service_impl()
|
||||
return getattr(impl, name)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
impl = get_azure_devops_service_impl()
|
||||
return impl(*args, **kwargs)
|
||||
|
||||
|
||||
AzureDevOpsServiceImpl: type[AzureDevOpsService] = _AzureDevOpsServiceImplProxy() # type: ignore[assignment]
|
||||
|
||||
@@ -64,4 +64,30 @@ bitbucket_service_cls = os.environ.get(
|
||||
'OPENHANDS_BITBUCKET_SERVICE_CLS',
|
||||
'openhands.integrations.bitbucket.bitbucket_service.BitBucketService',
|
||||
)
|
||||
BitBucketServiceImpl = get_impl(BitBucketService, bitbucket_service_cls)
|
||||
|
||||
# Lazy loading to avoid circular imports
|
||||
_bitbucket_service_impl = None
|
||||
|
||||
|
||||
def get_bitbucket_service_impl():
|
||||
"""Get the BitBucket service implementation with lazy loading."""
|
||||
global _bitbucket_service_impl
|
||||
if _bitbucket_service_impl is None:
|
||||
_bitbucket_service_impl = get_impl(BitBucketService, bitbucket_service_cls)
|
||||
return _bitbucket_service_impl
|
||||
|
||||
|
||||
# For backward compatibility, provide the implementation as a property
|
||||
class _BitBucketServiceImplProxy:
|
||||
"""Proxy class to provide lazy loading for BitBucketServiceImpl."""
|
||||
|
||||
def __getattr__(self, name):
|
||||
impl = get_bitbucket_service_impl()
|
||||
return getattr(impl, name)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
impl = get_bitbucket_service_impl()
|
||||
return impl(*args, **kwargs)
|
||||
|
||||
|
||||
BitBucketServiceImpl: type[BitBucketService] = _BitBucketServiceImplProxy() # type: ignore[assignment]
|
||||
|
||||
@@ -75,4 +75,30 @@ github_service_cls = os.environ.get(
|
||||
'OPENHANDS_GITHUB_SERVICE_CLS',
|
||||
'openhands.integrations.github.github_service.GitHubService',
|
||||
)
|
||||
GithubServiceImpl = get_impl(GitHubService, github_service_cls)
|
||||
|
||||
# Lazy loading to avoid circular imports
|
||||
_github_service_impl = None
|
||||
|
||||
|
||||
def get_github_service_impl():
|
||||
"""Get the GitHub service implementation with lazy loading."""
|
||||
global _github_service_impl
|
||||
if _github_service_impl is None:
|
||||
_github_service_impl = get_impl(GitHubService, github_service_cls)
|
||||
return _github_service_impl
|
||||
|
||||
|
||||
# For backward compatibility, provide the implementation as a property
|
||||
class _GitHubServiceImplProxy:
|
||||
"""Proxy class to provide lazy loading for GithubServiceImpl."""
|
||||
|
||||
def __getattr__(self, name):
|
||||
impl = get_github_service_impl()
|
||||
return getattr(impl, name)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
impl = get_github_service_impl()
|
||||
return impl(*args, **kwargs)
|
||||
|
||||
|
||||
GithubServiceImpl: type[GitHubService] = _GitHubServiceImplProxy() # type: ignore[assignment]
|
||||
|
||||
@@ -79,4 +79,30 @@ gitlab_service_cls = os.environ.get(
|
||||
'OPENHANDS_GITLAB_SERVICE_CLS',
|
||||
'openhands.integrations.gitlab.gitlab_service.GitLabService',
|
||||
)
|
||||
GitLabServiceImpl = get_impl(GitLabService, gitlab_service_cls)
|
||||
|
||||
# Lazy loading to avoid circular imports
|
||||
_gitlab_service_impl = None
|
||||
|
||||
|
||||
def get_gitlab_service_impl():
|
||||
"""Get the GitLab service implementation with lazy loading."""
|
||||
global _gitlab_service_impl
|
||||
if _gitlab_service_impl is None:
|
||||
_gitlab_service_impl = get_impl(GitLabService, gitlab_service_cls)
|
||||
return _gitlab_service_impl
|
||||
|
||||
|
||||
# For backward compatibility, provide the implementation as a property
|
||||
class _GitLabServiceImplProxy:
|
||||
"""Proxy class to provide lazy loading for GitLabServiceImpl."""
|
||||
|
||||
def __getattr__(self, name):
|
||||
impl = get_gitlab_service_impl()
|
||||
return getattr(impl, name)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
impl = get_gitlab_service_impl()
|
||||
return impl(*args, **kwargs)
|
||||
|
||||
|
||||
GitLabServiceImpl: type[GitLabService] = _GitLabServiceImplProxy() # type: ignore[assignment]
|
||||
|
||||
@@ -22,7 +22,7 @@ from openhands.events.action import (
|
||||
)
|
||||
from openhands.events.action.mcp import MCPAction
|
||||
from openhands.events.action.message import SystemMessageAction
|
||||
from openhands.events.event import Event, RecallType
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.observation import (
|
||||
AgentCondensationObservation,
|
||||
AgentDelegateObservation,
|
||||
@@ -44,6 +44,7 @@ from openhands.events.observation.agent import (
|
||||
from openhands.events.observation.error import ErrorObservation
|
||||
from openhands.events.observation.mcp import MCPObservation
|
||||
from openhands.events.observation.observation import Observation
|
||||
from openhands.events.recall_type import RecallType
|
||||
from openhands.events.serialization.event import truncate_content
|
||||
from openhands.utils.prompt import (
|
||||
ConversationInstructions,
|
||||
|
||||
@@ -9,12 +9,13 @@ import openhands
|
||||
from openhands.core.config.mcp_config import MCPConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action.agent import RecallAction
|
||||
from openhands.events.event import Event, EventSource, RecallType
|
||||
from openhands.events.event import Event, EventSource
|
||||
from openhands.events.observation.agent import (
|
||||
MicroagentKnowledge,
|
||||
RecallObservation,
|
||||
)
|
||||
from openhands.events.observation.empty import NullObservation
|
||||
from openhands.events.recall_type import RecallType
|
||||
from openhands.events.stream import EventStream, EventStreamSubscriber
|
||||
from openhands.microagent import (
|
||||
BaseMicroagent,
|
||||
|
||||
@@ -159,7 +159,7 @@ class AgentSession:
|
||||
await provider_handler.set_event_stream_secrets(self.event_stream)
|
||||
|
||||
if custom_secrets:
|
||||
custom_secrets_handler.set_event_stream_secrets(self.event_stream)
|
||||
self.event_stream.set_secrets(custom_secrets_handler.get_env_vars())
|
||||
|
||||
self.memory = await self._create_memory(
|
||||
selected_repository=selected_repository,
|
||||
|
||||
@@ -13,7 +13,6 @@ from pydantic import (
|
||||
)
|
||||
from pydantic.json import pydantic_encoder
|
||||
|
||||
from openhands.events.stream import EventStream
|
||||
from openhands.integrations.provider import (
|
||||
CUSTOM_SECRETS_TYPE,
|
||||
PROVIDER_TOKEN_TYPE,
|
||||
@@ -144,14 +143,6 @@ class Secrets(BaseModel):
|
||||
|
||||
return new_data
|
||||
|
||||
def set_event_stream_secrets(self, event_stream: EventStream) -> None:
|
||||
"""This ensures that provider tokens and custom secrets masked from the event stream
|
||||
Args:
|
||||
event_stream: Agent session's event stream
|
||||
"""
|
||||
secrets = self.get_env_vars()
|
||||
event_stream.set_secrets(secrets)
|
||||
|
||||
def get_env_vars(self) -> dict[str, str]:
|
||||
secret_store = self.model_dump(context={'expose_secrets': True})
|
||||
custom_secrets = secret_store.get('custom_secrets', {})
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
@@ -51,7 +49,7 @@ class Settings(BaseModel):
|
||||
email_verified: bool | None = None
|
||||
git_user_name: str | None = None
|
||||
git_user_email: str | None = None
|
||||
v1_enabled: bool | None = Field(default=bool(os.getenv('V1_ENABLED') == '1'))
|
||||
v1_enabled: bool = True
|
||||
|
||||
model_config = ConfigDict(
|
||||
validate_assignment=True,
|
||||
|
||||
Reference in New Issue
Block a user