SAAS: Introducing orgs (phase 1) (#11265)

Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: rohitvinodmalhotra@gmail.com <rohitvinodmalhotra@gmail.com>
Co-authored-by: Hiep Le <69354317+hieptl@users.noreply.github.com>
Co-authored-by: Tim O'Farrell <tofarr@gmail.com>
This commit is contained in:
chuckbutkus
2026-01-15 22:03:31 -05:00
committed by GitHub
parent f5315887ec
commit d5e66b4f3a
106 changed files with 19515 additions and 15865 deletions

View File

@@ -69,8 +69,6 @@ class StoredConversationMetadata(Base): # type: ignore
conversation_id = Column(
String, primary_key=True, default=lambda: str(uuid.uuid4())
)
github_user_id = Column(String, nullable=True) # The GitHub user ID
user_id = Column(String, nullable=False) # The Keycloak User ID
selected_repository = Column(String, nullable=True)
selected_branch = Column(String, nullable=True)
git_provider = Column(
@@ -199,10 +197,9 @@ class SQLAppConversationInfoService(AppConversationInfoService):
updated_at__lt: datetime | None = None,
) -> int:
"""Count sandboxed conversations matching the given filters."""
query = select(func.count(StoredConversationMetadata.conversation_id))
user_id = await self.user_context.get_user_id()
if user_id:
query = query.where(StoredConversationMetadata.user_id == user_id)
query = select(func.count(StoredConversationMetadata.conversation_id)).where(
StoredConversationMetadata.conversation_version == 'V1'
)
query = self._apply_filters(
query=query,
@@ -319,22 +316,11 @@ class SQLAppConversationInfoService(AppConversationInfoService):
async def save_app_conversation_info(
self, info: AppConversationInfo
) -> AppConversationInfo:
user_id = await self.user_context.get_user_id()
if user_id:
query = select(StoredConversationMetadata).where(
StoredConversationMetadata.conversation_id == str(info.id)
)
result = await self.db_session.execute(query)
existing = result.scalar_one_or_none()
assert existing is None or existing.user_id == user_id
metrics = info.metrics or MetricsSnapshot()
usage = metrics.accumulated_token_usage or TokenUsage()
stored = StoredConversationMetadata(
conversation_id=str(info.id),
github_user_id=None, # TODO: Should we add this to the conversation info?
user_id=info.created_by_user_id or '',
selected_repository=info.selected_repository,
selected_branch=info.selected_branch,
git_provider=info.git_provider.value if info.git_provider else None,
@@ -342,7 +328,7 @@ class SQLAppConversationInfoService(AppConversationInfoService):
last_updated_at=info.updated_at,
created_at=info.created_at,
trigger=info.trigger.value if info.trigger else None,
pr_number=info.pr_number,
pr_number=info.pr_number or [],
# Cost and token metrics
accumulated_cost=metrics.accumulated_cost,
prompt_tokens=usage.prompt_tokens,
@@ -496,11 +482,6 @@ class SQLAppConversationInfoService(AppConversationInfoService):
query = select(StoredConversationMetadata).where(
StoredConversationMetadata.conversation_version == 'V1'
)
user_id = await self.user_context.get_user_id()
if user_id:
query = query.where(
StoredConversationMetadata.user_id == user_id,
)
return query
def _to_info(
@@ -535,7 +516,7 @@ class SQLAppConversationInfoService(AppConversationInfoService):
return AppConversationInfo(
id=UUID(stored.conversation_id),
created_by_user_id=stored.user_id if stored.user_id else None,
created_by_user_id=None, # User ID is now stored in ConversationMetadataSaas
sandbox_id=stored.sandbox_id,
selected_repository=stored.selected_repository,
selected_branch=stored.selected_branch,
@@ -580,13 +561,6 @@ class SQLAppConversationInfoService(AppConversationInfoService):
StoredConversationMetadata.conversation_id == str(conversation_id)
)
# Apply user security filtering - only allow deletion of conversations owned by the current user
user_id = await self.user_context.get_user_id()
if user_id:
delete_query = delete_query.where(
StoredConversationMetadata.user_id == user_id
)
# Execute the secure delete query
result = await self.db_session.execute(delete_query)

View File

@@ -0,0 +1,46 @@
"""Update conversation_metadata table to match StoredConversationMetadata dataclass
Revision ID: 005
Revises: 004
Create Date: 2025-11-11 00:00:00.000000
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '005'
down_revision: Union[str, Sequence[str], None] = '004'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
with op.batch_alter_table('conversation_metadata') as batch_op:
# Drop columns not in StoredConversationMetadata dataclass
batch_op.drop_column('github_user_id')
# Alter user_id to become nullable
batch_op.alter_column(
'user_id',
existing_type=sa.String(),
nullable=True,
)
def downgrade() -> None:
"""Downgrade schema."""
with op.batch_alter_table('conversation_metadata') as batch_op:
# Add back removed column
batch_op.add_column(sa.Column('github_user_id', sa.String(), nullable=True))
# Restore NOT NULL constraint
batch_op.alter_column(
'user_id',
existing_type=sa.String(),
nullable=False,
)

View File

@@ -233,7 +233,20 @@ def config_from_env() -> AppServerConfig:
config.sandbox_spec = DockerSandboxSpecServiceInjector()
if config.app_conversation_info is None:
config.app_conversation_info = SQLAppConversationInfoServiceInjector()
# Use enterprise injector if running in SAAS mode
if 'saas' in (os.getenv('OPENHANDS_CONFIG_CLS') or '').lower():
try:
# Import enterprise injector dynamically
from enterprise.server.utils.saas_app_conversation_info_injector import (
SaasAppConversationInfoServiceInjector,
)
config.app_conversation_info = SaasAppConversationInfoServiceInjector()
except ImportError:
# Fallback to OSS injector if enterprise module is not available
config.app_conversation_info = SQLAppConversationInfoServiceInjector()
else:
config.app_conversation_info = SQLAppConversationInfoServiceInjector()
if config.app_conversation_start_task is None:
config.app_conversation_start_task = (

View File

@@ -153,15 +153,15 @@ class JwtService:
# Add standard JWT claims
now = utc_now()
if expires_in is None:
expires_in = timedelta(hours=1)
jwt_payload = {
**payload,
'iat': int(now.timestamp()),
'exp': int((now + expires_in).timestamp()),
}
# Only add exp if expires_in is provided
if expires_in is not None:
jwt_payload['exp'] = int((now + expires_in).timestamp())
# Get the raw key for JWE encryption and derive a 256-bit key
secret_key = self._keys[key_id].key.get_secret_value()
key_bytes = secret_key.encode() if isinstance(secret_key, str) else secret_key

View File

@@ -1,4 +1,5 @@
from openhands.events.event import Event, EventSource, RecallType
from openhands.events.event import Event, EventSource
from openhands.events.recall_type import RecallType
from openhands.events.stream import EventStream, EventStreamSubscriber
__all__ = [

View File

@@ -3,7 +3,7 @@ from typing import Any
from openhands.core.schema import ActionType
from openhands.events.action.action import Action
from openhands.events.event import RecallType
from openhands.events.recall_type import RecallType
@dataclass

View File

@@ -22,16 +22,6 @@ class FileReadSource(str, Enum):
DEFAULT = 'default'
class RecallType(str, Enum):
"""The type of information that can be retrieved from microagents."""
WORKSPACE_CONTEXT = 'workspace_context'
"""Workspace context (repo instructions, runtime, etc.)"""
KNOWLEDGE = 'knowledge'
"""A knowledge microagent."""
@dataclass
class Event:
INVALID_ID = -1

View File

@@ -1,4 +1,3 @@
from openhands.events.event import RecallType
from openhands.events.observation.agent import (
AgentCondensationObservation,
AgentStateChangedObservation,
@@ -28,6 +27,7 @@ from openhands.events.observation.observation import Observation
from openhands.events.observation.reject import UserRejectObservation
from openhands.events.observation.success import SuccessObservation
from openhands.events.observation.task_tracking import TaskTrackingObservation
from openhands.events.recall_type import RecallType
__all__ = [
'Observation',

View File

@@ -1,8 +1,8 @@
from dataclasses import dataclass, field
from openhands.core.schema import ObservationType
from openhands.events.event import RecallType
from openhands.events.observation.observation import Observation
from openhands.events.recall_type import RecallType
@dataclass

View File

@@ -0,0 +1,11 @@
from enum import Enum
class RecallType(str, Enum):
"""The type of information that can be retrieved from microagents."""
WORKSPACE_CONTEXT = 'workspace_context'
"""Workspace context (repo instructions, runtime, etc.)"""
KNOWLEDGE = 'knowledge'
"""A knowledge microagent."""

View File

@@ -1,7 +1,6 @@
import copy
from typing import Any
from openhands.events.event import RecallType
from openhands.events.observation.agent import (
AgentCondensationObservation,
AgentStateChangedObservation,
@@ -32,6 +31,7 @@ from openhands.events.observation.observation import Observation
from openhands.events.observation.reject import UserRejectObservation
from openhands.events.observation.success import SuccessObservation
from openhands.events.observation.task_tracking import TaskTrackingObservation
from openhands.events.recall_type import RecallType
observations = (
NullObservation,

View File

@@ -28,7 +28,7 @@ from openhands.integrations.service_types import (
from openhands.utils.import_utils import get_impl
class AzureDevOpsServiceImpl(
class AzureDevOpsService(
AzureDevOpsResolverMixin,
AzureDevOpsReposMixin,
AzureDevOpsBranchesMixin,
@@ -242,8 +242,34 @@ class AzureDevOpsServiceImpl(
# Dynamic class loading to support custom implementations (e.g., SaaS)
azure_devops_service_cls = os.environ.get(
'OPENHANDS_AZURE_DEVOPS_SERVICE_CLS',
'openhands.integrations.azure_devops.azure_devops_service.AzureDevOpsServiceImpl',
)
AzureDevOpsServiceImpl = get_impl( # type: ignore[misc]
AzureDevOpsServiceImpl, azure_devops_service_cls
'openhands.integrations.azure_devops.azure_devops_service.AzureDevOpsService',
)
# Lazy loading to avoid circular imports
_azure_devops_service_impl = None
def get_azure_devops_service_impl():
"""Get the Azure DevOps service implementation with lazy loading."""
global _azure_devops_service_impl
if _azure_devops_service_impl is None:
_azure_devops_service_impl = get_impl( # type: ignore[misc]
AzureDevOpsService, azure_devops_service_cls
)
return _azure_devops_service_impl
# For backward compatibility, provide the implementation as a property
class _AzureDevOpsServiceImplProxy:
"""Proxy class to provide lazy loading for AzureDevOpsServiceImpl."""
def __getattr__(self, name):
impl = get_azure_devops_service_impl()
return getattr(impl, name)
def __call__(self, *args, **kwargs):
impl = get_azure_devops_service_impl()
return impl(*args, **kwargs)
AzureDevOpsServiceImpl: type[AzureDevOpsService] = _AzureDevOpsServiceImplProxy() # type: ignore[assignment]

View File

@@ -64,4 +64,30 @@ bitbucket_service_cls = os.environ.get(
'OPENHANDS_BITBUCKET_SERVICE_CLS',
'openhands.integrations.bitbucket.bitbucket_service.BitBucketService',
)
BitBucketServiceImpl = get_impl(BitBucketService, bitbucket_service_cls)
# Lazy loading to avoid circular imports
_bitbucket_service_impl = None
def get_bitbucket_service_impl():
"""Get the BitBucket service implementation with lazy loading."""
global _bitbucket_service_impl
if _bitbucket_service_impl is None:
_bitbucket_service_impl = get_impl(BitBucketService, bitbucket_service_cls)
return _bitbucket_service_impl
# For backward compatibility, provide the implementation as a property
class _BitBucketServiceImplProxy:
"""Proxy class to provide lazy loading for BitBucketServiceImpl."""
def __getattr__(self, name):
impl = get_bitbucket_service_impl()
return getattr(impl, name)
def __call__(self, *args, **kwargs):
impl = get_bitbucket_service_impl()
return impl(*args, **kwargs)
BitBucketServiceImpl: type[BitBucketService] = _BitBucketServiceImplProxy() # type: ignore[assignment]

View File

@@ -75,4 +75,30 @@ github_service_cls = os.environ.get(
'OPENHANDS_GITHUB_SERVICE_CLS',
'openhands.integrations.github.github_service.GitHubService',
)
GithubServiceImpl = get_impl(GitHubService, github_service_cls)
# Lazy loading to avoid circular imports
_github_service_impl = None
def get_github_service_impl():
"""Get the GitHub service implementation with lazy loading."""
global _github_service_impl
if _github_service_impl is None:
_github_service_impl = get_impl(GitHubService, github_service_cls)
return _github_service_impl
# For backward compatibility, provide the implementation as a property
class _GitHubServiceImplProxy:
"""Proxy class to provide lazy loading for GithubServiceImpl."""
def __getattr__(self, name):
impl = get_github_service_impl()
return getattr(impl, name)
def __call__(self, *args, **kwargs):
impl = get_github_service_impl()
return impl(*args, **kwargs)
GithubServiceImpl: type[GitHubService] = _GitHubServiceImplProxy() # type: ignore[assignment]

View File

@@ -79,4 +79,30 @@ gitlab_service_cls = os.environ.get(
'OPENHANDS_GITLAB_SERVICE_CLS',
'openhands.integrations.gitlab.gitlab_service.GitLabService',
)
GitLabServiceImpl = get_impl(GitLabService, gitlab_service_cls)
# Lazy loading to avoid circular imports
_gitlab_service_impl = None
def get_gitlab_service_impl():
"""Get the GitLab service implementation with lazy loading."""
global _gitlab_service_impl
if _gitlab_service_impl is None:
_gitlab_service_impl = get_impl(GitLabService, gitlab_service_cls)
return _gitlab_service_impl
# For backward compatibility, provide the implementation as a property
class _GitLabServiceImplProxy:
"""Proxy class to provide lazy loading for GitLabServiceImpl."""
def __getattr__(self, name):
impl = get_gitlab_service_impl()
return getattr(impl, name)
def __call__(self, *args, **kwargs):
impl = get_gitlab_service_impl()
return impl(*args, **kwargs)
GitLabServiceImpl: type[GitLabService] = _GitLabServiceImplProxy() # type: ignore[assignment]

View File

@@ -22,7 +22,7 @@ from openhands.events.action import (
)
from openhands.events.action.mcp import MCPAction
from openhands.events.action.message import SystemMessageAction
from openhands.events.event import Event, RecallType
from openhands.events.event import Event
from openhands.events.observation import (
AgentCondensationObservation,
AgentDelegateObservation,
@@ -44,6 +44,7 @@ from openhands.events.observation.agent import (
from openhands.events.observation.error import ErrorObservation
from openhands.events.observation.mcp import MCPObservation
from openhands.events.observation.observation import Observation
from openhands.events.recall_type import RecallType
from openhands.events.serialization.event import truncate_content
from openhands.utils.prompt import (
ConversationInstructions,

View File

@@ -9,12 +9,13 @@ import openhands
from openhands.core.config.mcp_config import MCPConfig
from openhands.core.logger import openhands_logger as logger
from openhands.events.action.agent import RecallAction
from openhands.events.event import Event, EventSource, RecallType
from openhands.events.event import Event, EventSource
from openhands.events.observation.agent import (
MicroagentKnowledge,
RecallObservation,
)
from openhands.events.observation.empty import NullObservation
from openhands.events.recall_type import RecallType
from openhands.events.stream import EventStream, EventStreamSubscriber
from openhands.microagent import (
BaseMicroagent,

View File

@@ -159,7 +159,7 @@ class AgentSession:
await provider_handler.set_event_stream_secrets(self.event_stream)
if custom_secrets:
custom_secrets_handler.set_event_stream_secrets(self.event_stream)
self.event_stream.set_secrets(custom_secrets_handler.get_env_vars())
self.memory = await self._create_memory(
selected_repository=selected_repository,

View File

@@ -13,7 +13,6 @@ from pydantic import (
)
from pydantic.json import pydantic_encoder
from openhands.events.stream import EventStream
from openhands.integrations.provider import (
CUSTOM_SECRETS_TYPE,
PROVIDER_TOKEN_TYPE,
@@ -144,14 +143,6 @@ class Secrets(BaseModel):
return new_data
def set_event_stream_secrets(self, event_stream: EventStream) -> None:
"""This ensures that provider tokens and custom secrets masked from the event stream
Args:
event_stream: Agent session's event stream
"""
secrets = self.get_env_vars()
event_stream.set_secrets(secrets)
def get_env_vars(self) -> dict[str, str]:
secret_store = self.model_dump(context={'expose_secrets': True})
custom_secrets = secret_store.get('custom_secrets', {})

View File

@@ -1,7 +1,5 @@
from __future__ import annotations
import os
from pydantic import (
BaseModel,
ConfigDict,
@@ -51,7 +49,7 @@ class Settings(BaseModel):
email_verified: bool | None = None
git_user_name: str | None = None
git_user_email: str | None = None
v1_enabled: bool | None = Field(default=bool(os.getenv('V1_ENABLED') == '1'))
v1_enabled: bool = True
model_config = ConfigDict(
validate_assignment=True,