[chore] Run full agent pre-commit (#8235)

Engel Nyst 2025-05-03 17:24:03 +02:00 committed by GitHub
parent 98cb2e24ee
commit 985e20d529
GPG Key ID: B5690EEEBB952194
27 changed files with 186 additions and 147 deletions

View File

@ -20,7 +20,7 @@ MCP configuration is defined in the `[mcp]` section of your `config.toml` file.
sse_servers = [
# Basic SSE server with just a URL
"http://example.com:8080/mcp",
# SSE server with API key authentication
{url="https://secure-example.com/mcp", api_key="your-api-key"}
]
@ -29,7 +29,7 @@ sse_servers = [
stdio_servers = [
# Basic stdio server
{name="fetch", command="uvx", args=["mcp-server-fetch"]},
# Stdio server with environment variables
{
name="data-processor",

View File

@ -55,4 +55,4 @@
"node": ">=18.0"
},
"packageManager": "npm@10.5.0"
}
}

View File

@ -36,13 +36,12 @@ from openhands.core.config import (
)
from openhands.core.logger import openhands_logger as logger
from openhands.core.main import create_runtime, run_controller
from openhands.events.action import CmdRunAction, MessageAction, FileReadAction
from openhands.events.action import CmdRunAction, FileReadAction, MessageAction
from openhands.events.observation import CmdOutputObservation, ErrorObservation
from openhands.events.serialization.event import event_to_dict
from openhands.runtime.base import Runtime
from openhands.utils.async_utils import call_async_from_sync
from openhands.utils.shutdown_listener import sleep_if_should_continue
import pdb
USE_HINT_TEXT = os.environ.get('USE_HINT_TEXT', 'false').lower() == 'true'
USE_INSTANCE_IMAGE = os.environ.get('USE_INSTANCE_IMAGE', 'true').lower() == 'true'
@ -51,7 +50,7 @@ RUN_WITH_BROWSING = os.environ.get('RUN_WITH_BROWSING', 'false').lower() == 'tru
# TODO: migrate all swe-bench docker to ghcr.io/openhands
# TODO: adapt to all languages
DOCKER_IMAGE_PREFIX = os.environ.get('EVAL_DOCKER_IMAGE_PREFIX', '')
LANGUAGE =os.environ.get('LANGUAGE', 'python')
LANGUAGE = os.environ.get('LANGUAGE', 'python')
logger.info(f'Using docker image prefix: {DOCKER_IMAGE_PREFIX}')
@ -71,7 +70,7 @@ def get_instruction(instance: pd.Series, metadata: EvalMetadata):
# Instruction based on Anthropic's official trajectory
# https://github.com/eschluntz/swe-bench-experiments/tree/main/evaluation/verified/20241022_tools_claude-3-5-sonnet-updated/trajs
instructions = {
"python":(
'python': (
'<uploaded_files>\n'
f'/workspace/{workspace_dir_name}\n'
'</uploaded_files>\n'
@ -96,7 +95,7 @@ def get_instruction(instance: pd.Series, metadata: EvalMetadata):
' Make sure all these tests pass with your changes.\n'
"Your thinking should be thorough and so it's fine if it's very long.\n"
),
"java": (
'java': (
'<uploaded_files>\n'
f'/workspace/{workspace_dir_name}\n'
'</uploaded_files>\n'
@ -121,7 +120,7 @@ def get_instruction(instance: pd.Series, metadata: EvalMetadata):
" Make sure all these tests pass with your changes.\n"
"Your thinking should be thorough and so it's fine if it's very long.\n"
),
"go": (
'go': (
'<uploaded_files>\n'
f'/workspace/{workspace_dir_name}\n'
'</uploaded_files>\n'
@ -146,7 +145,7 @@ def get_instruction(instance: pd.Series, metadata: EvalMetadata):
' Make sure all these tests pass with your changes.\n'
"Your thinking should be thorough and so it's fine if it's very long.\n"
),
"c": (
'c': (
'<uploaded_files>\n'
f'/workspace/{workspace_dir_name}\n'
'</uploaded_files>\n'
@ -171,7 +170,7 @@ def get_instruction(instance: pd.Series, metadata: EvalMetadata):
' Make sure all these tests pass with your changes.\n'
"Your thinking should be thorough and so it's fine if it's very long.\n"
),
"cpp": (
'cpp': (
'<uploaded_files>\n'
f'/workspace/{workspace_dir_name}\n'
'</uploaded_files>\n'
@ -196,7 +195,7 @@ def get_instruction(instance: pd.Series, metadata: EvalMetadata):
' Make sure all these tests pass with your changes.\n'
"Your thinking should be thorough and so it's fine if it's very long.\n"
),
"javascript": (
'javascript': (
'<uploaded_files>\n'
f'/workspace/{workspace_dir_name}\n'
'</uploaded_files>\n'
@ -221,7 +220,7 @@ def get_instruction(instance: pd.Series, metadata: EvalMetadata):
' Make sure all these tests pass with your changes.\n'
"Your thinking should be thorough and so it's fine if it's very long.\n"
),
"typescript":(
'typescript': (
'<uploaded_files>\n'
f'/workspace/{workspace_dir_name}\n'
'</uploaded_files>\n'
@ -246,7 +245,7 @@ def get_instruction(instance: pd.Series, metadata: EvalMetadata):
' Make sure all these tests pass with your changes.\n'
"Your thinking should be thorough and so it's fine if it's very long.\n"
),
"rust":(
'rust': (
'<uploaded_files>\n'
f'/workspace/{workspace_dir_name}\n'
'</uploaded_files>\n'
@ -270,11 +269,10 @@ def get_instruction(instance: pd.Series, metadata: EvalMetadata):
' - The functions you changed\n'
' Make sure all these tests pass with your changes.\n'
"Your thinking should be thorough and so it's fine if it's very long.\n"
)
),
}
instruction = instructions.get(LANGUAGE.lower())
if instruction and RUN_WITH_BROWSING:
instruction += (
'<IMPORTANT!>\n'
@ -284,7 +282,6 @@ def get_instruction(instance: pd.Series, metadata: EvalMetadata):
return instruction
# TODO: adapt to all languages
# def get_instance_docker_image(instance_id: str) -> str:
# image_name = 'sweb.eval.x86_64.' + instance_id
@ -307,16 +304,15 @@ def get_instance_docker_image(instance: pd.Series):
container_name = container_name.replace('/', '_m_')
instance_id = instance.get('instance_id', '')
tag_suffix = instance_id.split('-')[-1] if instance_id else ''
container_tag = f"pr-{tag_suffix}"
container_tag = f'pr-{tag_suffix}'
# pdb.set_trace()
return f"mswebench/{container_name}:{container_tag}"
return f'mswebench/{container_name}:{container_tag}'
# return "kong/insomnia:pr-8284"
# return "'sweb.eval.x86_64.local_insomnia"
# return "local_insomnia_why"
# return "local/kong-insomnia:pr-8117"
def get_config(
instance: pd.Series,
metadata: EvalMetadata,
@ -569,7 +565,6 @@ def complete_runtime(
f'Failed to git config --global core.pager "": {str(obs)}',
)
action = CmdRunAction(command='git add -A')
action.set_hard_timeout(600)
logger.info(action, extra={'msg_type': 'ACTION'})
@ -582,14 +577,14 @@ def complete_runtime(
## Remove binary files
action = CmdRunAction(
command=f'''
command="""
for file in $(git status --porcelain | grep -E "^(M| M|\\?\\?|A| A)" | cut -c4-); do
if [ -f "$file" ] && (file "$file" | grep -q "executable" || git check-attr binary "$file" | grep -q "binary: set"); then
git rm -f "$file" 2>/dev/null || rm -f "$file"
echo "Removed: $file"
fi
done
'''
"""
)
action.set_hard_timeout(600)
logger.info(action, extra={'msg_type': 'ACTION'})
@ -626,14 +621,12 @@ def complete_runtime(
else:
assert_and_raise(False, f'Unexpected observation type: {str(obs)}')
action = FileReadAction(
path='patch.diff'
)
action = FileReadAction(path='patch.diff')
action.set_hard_timeout(max(300 + 100 * n_retries, 600))
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
git_patch = obs.content
# pdb.set_trace()
# pdb.set_trace()
assert_and_raise(git_patch is not None, 'Failed to get git diff (None)')
@ -714,12 +707,12 @@ def process_instance(
is_binary_block = False
for line in lines:
if line.startswith("diff --git "):
if line.startswith('diff --git '):
if block and not is_binary_block:
cleaned_lines.extend(block)
block = [line]
is_binary_block = False
elif "Binary files" in line:
elif 'Binary files' in line:
is_binary_block = True
block.append(line)
else:
@ -727,7 +720,8 @@ def process_instance(
if block and not is_binary_block:
cleaned_lines.extend(block)
return "\n".join(cleaned_lines)
return '\n'.join(cleaned_lines)
git_patch = remove_binary_diffs(git_patch)
test_result = {
'git_patch': git_patch,
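
For readability, here is `remove_binary_diffs` assembled into one piece from the fragments in the two hunks above; the overall shape is a reconstruction, so anything not visible in the hunks should be treated as an assumption rather than the file's exact contents.

# Sketch of remove_binary_diffs as implied by the hunks above (reconstruction).
def remove_binary_diffs(patch_text: str) -> str:
    lines = patch_text.splitlines()
    cleaned_lines: list[str] = []
    block: list[str] = []
    is_binary_block = False
    for line in lines:
        if line.startswith('diff --git '):
            # New per-file block: keep the previous block unless it was binary.
            if block and not is_binary_block:
                cleaned_lines.extend(block)
            block = [line]
            is_binary_block = False
        elif 'Binary files' in line:
            is_binary_block = True
            block.append(line)
        else:
            block.append(line)
    # Flush the last block.
    if block and not is_binary_block:
        cleaned_lines.extend(block)
    return '\n'.join(cleaned_lines)

As the surrounding context shows, the caller then passes the agent's patch through it: git_patch = remove_binary_diffs(git_patch).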
@ -797,7 +791,7 @@ if __name__ == '__main__':
# so we don't need to manage file uploading to OpenHands's repo
# dataset = load_dataset(args.dataset, split=args.split)
# dataset = load_dataset(args.dataset)
dataset = load_dataset("json", data_files = args.dataset)
dataset = load_dataset('json', data_files=args.dataset)
dataset = dataset[args.split]
swe_bench_tests = filter_dataset(dataset.to_pandas(), 'instance_id')
logger.info(
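
One note on the load_dataset('json', data_files=args.dataset) call above: with the Hugging Face `datasets` library it returns a DatasetDict, which is why the next line indexes it with args.split. A minimal hedged sketch (the file name and split below are illustrative):

# Illustrative: loading a local JSONL file with Hugging Face `datasets`.
# A single data_files path produces a DatasetDict whose default split is 'train'.
from datasets import load_dataset

dataset_dict = load_dataset('json', data_files='instances.jsonl')  # hypothetical file
dataset = dataset_dict['train']
print(dataset.to_pandas().columns.tolist())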

View File

@ -3,7 +3,9 @@ import json
input_file = 'XXX.jsonl'
output_file = 'YYY.jsonl'
with open(input_file, 'r', encoding='utf-8') as fin, open(output_file, 'w', encoding='utf-8') as fout:
with open(input_file, 'r', encoding='utf-8') as fin, open(
output_file, 'w', encoding='utf-8'
) as fout:
for line in fin:
line = line.strip()
if not line:
@ -13,18 +15,22 @@ with open(input_file, 'r', encoding='utf-8') as fin, open(output_file, 'w', enco
item = data
# Extract the original data
org = item.get("org", "")
repo = item.get("repo", "")
number = str(item.get("number", ""))
org = item.get('org', '')
repo = item.get('repo', '')
number = str(item.get('number', ''))
new_item = {}
new_item["repo"] = f"{org}/{repo}"
new_item["instance_id"] = f"{org}__{repo}-{number}"
new_item["problem_statement"] = item["resolved_issues"][0].get("title", "") + "\n" + item["resolved_issues"][0].get("body", "")
new_item["FAIL_TO_PASS"] = []
new_item["PASS_TO_PASS"] = []
new_item["base_commit"] = item['base'].get("sha","")
new_item["version"] = "0.1" # depends
new_item['repo'] = f'{org}/{repo}'
new_item['instance_id'] = f'{org}__{repo}-{number}'
new_item['problem_statement'] = (
item['resolved_issues'][0].get('title', '')
+ '\n'
+ item['resolved_issues'][0].get('body', '')
)
new_item['FAIL_TO_PASS'] = []
new_item['PASS_TO_PASS'] = []
new_item['base_commit'] = item['base'].get('sha', '')
new_item['version'] = '0.1' # depends
output_data = new_item
fout.write(json.dumps(output_data, ensure_ascii=False) + "\n")
fout.write(json.dumps(output_data, ensure_ascii=False) + '\n')

View File

@ -15,7 +15,7 @@ def main():
'org': groups.group(1),
'repo': groups.group(2),
'number': groups.group(3),
'fix_patch': data['test_result']['git_patch']
'fix_patch': data['test_result']['git_patch'],
}
fout.write(json.dumps(patch) + '\n')

View File

@ -27,7 +27,7 @@ describe("AuthModal", () => {
it("should render the GitHub and GitLab buttons", () => {
render(<AuthModal githubAuthUrl="mock-url" appMode="saas" />);
const githubButton = screen.getByRole("button", { name: "GITHUB$CONNECT_TO_GITHUB" });
const gitlabButton = screen.getByRole("button", { name: "GITLAB$CONNECT_TO_GITLAB" });

View File

@ -43,7 +43,7 @@ const createWrapper = () => {
},
},
});
return ({ children }: { children: React.ReactNode }) => (
<QueryClientProvider client={queryClient}>{children}</QueryClientProvider>
);
@ -61,7 +61,7 @@ describe("AcceptTOS", () => {
it("should render a TOS checkbox that is unchecked by default", () => {
render(<AcceptTOS />, { wrapper: createWrapper() });
const checkbox = screen.getByRole("checkbox");
const continueButton = screen.getByRole("button", { name: "TOS$CONTINUE" });
@ -72,7 +72,7 @@ describe("AcceptTOS", () => {
it("should enable the continue button when the TOS checkbox is checked", async () => {
const user = userEvent.setup();
render(<AcceptTOS />, { wrapper: createWrapper() });
const checkbox = screen.getByRole("checkbox");
const continueButton = screen.getByRole("button", { name: "TOS$CONTINUE" });
@ -96,7 +96,7 @@ describe("AcceptTOS", () => {
const user = userEvent.setup();
render(<AcceptTOS />, { wrapper: createWrapper() });
const checkbox = screen.getByRole("checkbox");
await user.click(checkbox);
@ -121,7 +121,7 @@ describe("AcceptTOS", () => {
const user = userEvent.setup();
render(<AcceptTOS />, { wrapper: createWrapper() });
const checkbox = screen.getByRole("checkbox");
await user.click(checkbox);
@ -133,4 +133,4 @@ describe("AcceptTOS", () => {
expect(window.location.href).toBe(externalUrl);
});
});
});

View File

@ -390,7 +390,9 @@ class GitHubService(BaseGitService, GitService):
except Exception:
return []
async def get_repository_details_from_repo_name(self, repository: str) -> Repository:
async def get_repository_details_from_repo_name(
self, repository: str
) -> Repository:
url = f'{self.BASE_URL}/repos/{repository}'
repo, _ = await self._make_request(url)

View File

@ -382,9 +382,10 @@ class GitLabService(BaseGitService, GitService):
except Exception:
return []
async def get_repository_details_from_repo_name(self, repository: str) -> Repository:
encoded_name = repository.replace("/", "%2F")
async def get_repository_details_from_repo_name(
self, repository: str
) -> Repository:
encoded_name = repository.replace('/', '%2F')
url = f'{self.BASE_URL}/projects/{encoded_name}'
repo, _ = await self._make_request(url)
@ -396,8 +397,6 @@ class GitLabService(BaseGitService, GitService):
git_provider=ProviderType.GITLAB,
is_public=repo.get('visibility') == 'public',
)
gitlab_service_cls = os.environ.get(
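
The replace('/', '%2F') above reflects GitLab's requirement that a namespace/project path be URL-encoded when used as a project identifier. A hedged standard-library equivalent (the base URL and project path below are illustrative):

# urllib.parse.quote with safe='' also percent-encodes '/', matching
# replace('/', '%2F') for typical namespace/project strings.
from urllib.parse import quote

repository = 'example-group/example-repo'  # hypothetical project path
encoded_name = quote(repository, safe='')
url = f'https://gitlab.example.com/api/v4/projects/{encoded_name}'
print(url)  # .../projects/example-group%2Fexample-repo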

View File

@ -3,4 +3,4 @@ Use the {{ apiName }} with the {{ tokenEnvVar }} environment variable to retriev
Check out the branch from that {{ requestVerb }} and look at the diff versus the base branch of the {{ requestTypeShort }} to understand the {{ requestTypeShort }}'s intention.
Then use the {{ apiName }} to look at the {{ ciSystem }} that are failing on the most recent commit. Try and reproduce the failure locally.
Get things working locally, then push your changes. Sleep for 30 seconds at a time until the {{ ciProvider }} {{ ciSystem.lower() }} have run again.
If they are still failing, repeat the process.
If they are still failing, repeat the process.

View File

@ -1,4 +1,4 @@
You are working on {{ requestType }} #{{ issue_number }} in repository {{ repo }}. You need to fix the merge conflicts.
Use the {{ apiName }} with the {{ tokenEnvVar }} environment variable to retrieve the {{ requestTypeShort }} details.
Check out the branch from that {{ requestVerb }} and look at the diff versus the base branch of the {{ requestTypeShort }} to understand the {{ requestTypeShort }}'s intention.
Then resolve the merge conflicts. If you aren't sure what the right solution is, look back through the commit history at the commits that introduced the conflict and resolve them accordingly.
Then resolve the merge conflicts. If you aren't sure what the right solution is, look back through the commit history at the commits that introduced the conflict and resolve them accordingly.

View File

@ -1,4 +1,4 @@
You are working on Issue #{{ issue_number }} in repository {{ repo }}. Your goal is to fix the issue.
Use the {{ apiName }} with the {{ tokenEnvVar }} environment variable to retrieve the issue details and any comments on the issue.
Then check out a new branch and investigate what changes will need to be made.
Finally, make the required changes and open up a {{ requestVerb }}. Be sure to reference the issue in the {{ requestTypeShort }} description.
Finally, make the required changes and open up a {{ requestVerb }}. Be sure to reference the issue in the {{ requestTypeShort }} description.

View File

@ -2,4 +2,4 @@ You are working on {{ requestType }} #{{ issue_number }} in repository {{ repo }
Use the {{ apiName }} with the {{ tokenEnvVar }} environment variable to retrieve the {{ requestTypeShort }} details.
Check out the branch from that {{ requestVerb }} and look at the diff versus the base branch of the {{ requestTypeShort }} to understand the {{ requestTypeShort }}'s intention.
Then use the {{ apiName }} to retrieve all the feedback on the {{ requestTypeShort }} so far.
If anything hasn't been addressed, address it and commit your changes back to the same branch.
If anything hasn't been addressed, address it and commit your changes back to the same branch.

View File

@ -1,4 +1,3 @@
import asyncio
import os
import tempfile
import threading
@ -46,6 +45,7 @@ from openhands.runtime.utils.request import send_request
from openhands.utils.http_session import HttpSession
from openhands.utils.tenacity_stop import stop_if_should_exit
def _is_retryable_error(exception):
return isinstance(
exception, (httpx.RemoteProtocolError, httpcore.RemoteProtocolError)
@ -358,26 +358,27 @@ class ActionExecutionClient(Runtime):
async def call_tool_mcp(self, action: MCPAction) -> Observation:
# Import here to avoid circular imports
from openhands.mcp.utils import create_mcp_clients, call_tool_mcp as call_tool_mcp_handler
from openhands.mcp.utils import call_tool_mcp as call_tool_mcp_handler
from openhands.mcp.utils import create_mcp_clients
# Get the updated MCP config
updated_mcp_config = self.get_updated_mcp_config()
self.log(
'debug',
f'Creating MCP clients with servers: {updated_mcp_config.sse_servers}',
)
# Create clients for this specific operation
mcp_clients = await create_mcp_clients(updated_mcp_config.sse_servers)
# Call the tool and return the result
# No need for try/finally since disconnect() is now just resetting state
result = await call_tool_mcp_handler(mcp_clients, action)
# Reset client state (no active connections to worry about)
for client in mcp_clients:
await client.disconnect()
return result
def close(self) -> None:
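
A side note on the _is_retryable_error predicate near the top of this file: such a predicate is typically wired into tenacity through retry_if_exception. The sketch below shows that general pattern only; the wait/stop values and the fetch helper are illustrative, not this module's configuration.

# Illustrative tenacity wiring for a retryable-error predicate; parameters
# here are examples, not the values used by ActionExecutionClient.
import httpx
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential

def _is_retryable_error(exception: BaseException) -> bool:
    return isinstance(exception, httpx.RemoteProtocolError)

@retry(
    retry=retry_if_exception(_is_retryable_error),
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, max=10),
)
def fetch(url: str) -> str:
    # A RemoteProtocolError raised here triggers a retried call with backoff.
    return httpx.get(url).text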

View File

@ -10,8 +10,8 @@ from openhands.events.event_store import EventStore
from openhands.server.config.server_config import ServerConfig
from openhands.server.monitoring import MonitoringListener
from openhands.server.session.conversation import Conversation
from openhands.storage.data_models.settings import Settings
from openhands.storage.conversation.conversation_store import ConversationStore
from openhands.storage.data_models.settings import Settings
from openhands.storage.files import FileStore

View File

@ -18,9 +18,9 @@ from openhands.server.monitoring import MonitoringListener
from openhands.server.session.agent_session import WAIT_TIME_BEFORE_CLOSE
from openhands.server.session.conversation import Conversation
from openhands.server.session.session import ROOM_KEY, Session
from openhands.storage.data_models.settings import Settings
from openhands.storage.conversation.conversation_store import ConversationStore
from openhands.storage.data_models.conversation_metadata import ConversationMetadata
from openhands.storage.data_models.settings import Settings
from openhands.storage.files import FileStore
from openhands.utils.async_utils import GENERAL_TIMEOUT, call_async_from_sync, wait_all
from openhands.utils.import_utils import get_impl

View File

@ -14,7 +14,11 @@ from openhands.integrations.provider import (
PROVIDER_TOKEN_TYPE,
ProviderHandler,
)
from openhands.integrations.service_types import AuthenticationError, ProviderType, Repository, SuggestedTask
from openhands.integrations.service_types import (
AuthenticationError,
ProviderType,
SuggestedTask,
)
from openhands.runtime import get_runtime_cls
from openhands.server.data_models.conversation_info import ConversationInfo
from openhands.server.data_models.conversation_info_result_set import (
@ -45,7 +49,6 @@ from openhands.storage.data_models.conversation_status import ConversationStatus
from openhands.utils.async_utils import wait_all
from openhands.utils.conversation_summary import generate_conversation_title
app = APIRouter(prefix='/api')
@ -58,7 +61,7 @@ class InitSessionRequest(BaseModel):
image_urls: list[str] | None = None
replay_json: str | None = None
suggested_task: SuggestedTask | None = None
async def _create_new_conversation(
user_id: str | None,
@ -71,10 +74,13 @@ async def _create_new_conversation(
conversation_trigger: ConversationTrigger = ConversationTrigger.GUI,
attach_convo_id: bool = False,
):
logger.info(
'Creating conversation',
extra={'signal': 'create_conversation', 'user_id': user_id, 'trigger': conversation_trigger.value},
extra={
'signal': 'create_conversation',
'user_id': user_id,
'trigger': conversation_trigger.value,
},
)
logger.info('Loading settings')
settings_store = await SettingsStoreImpl.get_instance(config, user_id)
@ -163,7 +169,7 @@ async def new_conversation(
data: InitSessionRequest,
user_id: str = Depends(get_user_id),
provider_tokens: PROVIDER_TOKEN_TYPE = Depends(get_provider_tokens),
auth_type: AuthType | None = Depends(get_auth_type)
auth_type: AuthType | None = Depends(get_auth_type),
):
"""Initialize a new session or join an existing one.
@ -202,7 +208,7 @@ async def new_conversation(
initial_user_msg=initial_user_msg,
image_urls=image_urls,
replay_json=replay_json,
conversation_trigger=conversation_trigger
conversation_trigger=conversation_trigger,
)
return JSONResponse(
@ -227,13 +233,13 @@ async def new_conversation(
},
status_code=status.HTTP_400_BAD_REQUEST,
)
except AuthenticationError as e:
return JSONResponse(
content={
'status': 'error',
'message': str(e),
'msg_id': 'STATUS$GIT_PROVIDER_AUTHENTICATION_ERROR'
'msg_id': 'STATUS$GIT_PROVIDER_AUTHENTICATION_ERROR',
},
status_code=status.HTTP_400_BAD_REQUEST,
)

View File

@ -2,9 +2,8 @@ from typing import Any
from fastapi import APIRouter
from openhands.security.options import SecurityAnalyzers
from openhands.controller.agent import Agent
from openhands.security.options import SecurityAnalyzers
from openhands.server.shared import config, server_config
from openhands.utils.llm import get_supported_llm_models

View File

@ -15,12 +15,12 @@ from openhands.server.settings import (
POSTSettingsModel,
)
from openhands.server.shared import config
from openhands.storage.data_models.settings import Settings
from openhands.server.user_auth import (
get_provider_tokens,
get_user_settings,
get_user_settings_store,
)
from openhands.storage.data_models.settings import Settings
from openhands.storage.settings.settings_store import SettingsStore
app = APIRouter(prefix='/api')
@ -38,7 +38,7 @@ async def load_settings(
content={'error': 'Settings not found'},
)
provider_tokens_set: dict[ProviderType, str | None] = {}
provider_tokens_set: dict[ProviderType, str | None] = {}
if provider_tokens:
for provider_type, provider_token in provider_tokens.items():
if provider_token.token or provider_token.user_id:
@ -227,8 +227,7 @@ async def store_provider_tokens(
if existing_settings:
if existing_settings.secrets_store:
existing_providers = [
provider
for provider in existing_settings.secrets_store.provider_tokens
provider for provider in existing_settings.secrets_store.provider_tokens
]
# Merge incoming settings store with the existing one
@ -245,7 +244,7 @@ async def store_provider_tokens(
else: # nothing passed in means keep current settings
provider_tokens = dict(existing_settings.secrets_store.provider_tokens)
settings.provider_tokens = provider_tokens
return settings
@ -334,7 +333,11 @@ def convert_to_settings(settings_with_token_data: POSTSettingsModel) -> Settings
# Create new provider tokens immutably
if settings_with_token_data.provider_tokens:
settings = settings.model_copy(
update={'secrets_store': SecretStore(provider_tokens=settings_with_token_data.provider_tokens)}
update={
'secrets_store': SecretStore(
provider_tokens=settings_with_token_data.provider_tokens
)
}
)
return settings
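
The hunk above only rewraps an existing model_copy(update=...) call; for readers unfamiliar with the pattern, the sketch below shows the pydantic v2 immutable-update idiom with stand-in models (these are not the OpenHands classes).

# Illustrative pydantic v2 immutable-update pattern; model names are stand-ins.
from pydantic import BaseModel

class SecretStore(BaseModel):
    provider_tokens: dict[str, str] = {}

class Settings(BaseModel):
    secrets_store: SecretStore = SecretStore()

settings = Settings()
updated = settings.model_copy(
    update={'secrets_store': SecretStore(provider_tokens={'github': 'token'})}
)
print(settings.secrets_store.provider_tokens)  # {} - original is unchanged
print(updated.secrets_store.provider_tokens)   # {'github': 'token'}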

View File

@ -17,7 +17,6 @@ from openhands.events.action import ChangeAgentStateAction, MessageAction
from openhands.events.event import Event, EventSource
from openhands.events.stream import EventStream
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderHandler
from openhands.integrations.service_types import Repository
from openhands.mcp import add_mcp_tools_to_agent
from openhands.memory.memory import Memory
from openhands.microagent.microagent import BaseMicroagent
@ -420,9 +419,7 @@ class AgentSession:
memory.load_user_workspace_microagents(microagents)
if selected_repository and repo_directory:
memory.set_repository_info(
selected_repository, repo_directory
)
memory.set_repository_info(selected_repository, repo_directory)
return memory
def _maybe_restore_state(self) -> State | None:

View File

@ -21,8 +21,8 @@ from openhands.events.observation import (
CmdOutputObservation,
NullObservation,
)
from openhands.events.observation.error import ErrorObservation
from openhands.events.observation.agent import RecallObservation
from openhands.events.observation.error import ErrorObservation
from openhands.events.serialization import event_from_dict, event_to_dict
from openhands.events.stream import EventStreamSubscriber
from openhands.llm.llm import LLM
@ -214,7 +214,8 @@ class Session:
await self.send(event_to_dict(event))
# NOTE: ipython observations are not sent here currently
elif event.source == EventSource.ENVIRONMENT and isinstance(
event, (CmdOutputObservation, AgentStateChangedObservation, RecallObservation)
event,
(CmdOutputObservation, AgentStateChangedObservation, RecallObservation),
):
# feedback from the environment to agent actions is understood as agent events by the UI
event_dict = event_to_dict(event)

View File

@ -50,4 +50,4 @@ async def get_user_settings_store(request: Request) -> SettingsStore | None:
async def get_auth_type(request: Request) -> AuthType | None:
user_auth = await get_user_auth(request)
return user_auth.get_auth_type()
return user_auth.get_auth_type()

View File

@ -51,7 +51,6 @@ class DefaultUserAuth(UserAuth):
provider_tokens = getattr(secrets_store, 'provider_tokens', None)
return provider_tokens
@classmethod
async def get_instance(cls, request: Request) -> UserAuth:
user_auth = DefaultUserAuth()

View File

@ -14,8 +14,8 @@ from openhands.utils.import_utils import get_impl
class AuthType(Enum):
COOKIE = "cookie"
BEARER = "bearer"
COOKIE = 'cookie'
BEARER = 'bearer'
class UserAuth(ABC):

View File

@ -4,8 +4,8 @@ import json
from dataclasses import dataclass
from openhands.core.config.app_config import AppConfig
from openhands.storage.data_models.settings import Settings
from openhands.storage import get_file_store
from openhands.storage.data_models.settings import Settings
from openhands.storage.files import FileStore
from openhands.storage.settings.settings_store import SettingsStore
from openhands.utils.async_utils import call_sync_from_async

View File

@ -26,7 +26,7 @@ from openhands.resolver.resolver_output import ResolverOutput
@pytest.fixture
def default_mock_args():
"""Fixture that provides a default mock args object with common values.
Tests can override specific attributes as needed.
"""
mock_args = MagicMock()
@ -53,10 +53,13 @@ def default_mock_args():
@pytest.fixture
def mock_github_token():
"""Fixture that patches the identify_token function to return GitHub provider type.
This eliminates the need for repeated patching in each test function.
"""
with patch('openhands.resolver.resolve_issue.identify_token', return_value=ProviderType.GITHUB) as patched:
with patch(
'openhands.resolver.resolve_issue.identify_token',
return_value=ProviderType.GITHUB,
) as patched:
yield patched
@ -152,7 +155,9 @@ async def test_resolve_issue_no_issues_found(default_mock_args, mock_github_toke
# Verify that the handler was correctly configured and called
resolver.issue_handler_factory.assert_called_once()
mock_handler.get_converted_issues.assert_called_once_with(issue_numbers=[5432], comment_id=None)
mock_handler.get_converted_issues.assert_called_once_with(
issue_numbers=[5432], comment_id=None
)
def test_download_issues_from_github():
@ -348,9 +353,7 @@ async def test_complete_runtime(default_mock_args, mock_github_token):
# Create resolver with mocked token identification
resolver = IssueResolver(default_mock_args)
result = await resolver.complete_runtime(
mock_runtime, 'base_commit_hash'
)
result = await resolver.complete_runtime(mock_runtime, 'base_commit_hash')
assert result == {'git_patch': 'git diff content'}
assert mock_runtime.run_action.call_count == 5
@ -358,7 +361,7 @@ async def test_complete_runtime(default_mock_args, mock_github_token):
@pytest.mark.asyncio
@pytest.mark.parametrize(
"test_case",
'test_case',
[
{
'name': 'successful_run',
@ -410,11 +413,20 @@ async def test_complete_runtime(default_mock_args, mock_github_token):
'expected_error': None,
'expected_explanation': 'Non-JSON explanation',
'is_pr': True,
'comment_success': [True, False], # To trigger the PR success logging code path
'comment_success': [
True,
False,
], # To trigger the PR success logging code path
},
],
)
async def test_process_issue(default_mock_args, mock_github_token, mock_output_dir, mock_prompt_template, test_case):
async def test_process_issue(
default_mock_args,
mock_github_token,
mock_output_dir,
mock_prompt_template,
test_case,
):
"""Test the process_issue method with different scenarios."""
# Set up test data
@ -426,7 +438,7 @@ async def test_process_issue(default_mock_args, mock_github_token, mock_output_d
body='This is a test issue',
)
base_commit = 'abcdef1234567890'
# Customize the mock args for this test
default_mock_args.output_dir = mock_output_dir
default_mock_args.issue_type = 'pr' if test_case.get('is_pr', False) else 'issue'
@ -457,7 +469,7 @@ async def test_process_issue(default_mock_args, mock_github_token, mock_output_d
# Mock the create_runtime function
mock_create_runtime = MagicMock(return_value=mock_runtime)
# Mock the run_controller function
mock_run_controller = AsyncMock()
if test_case['run_controller_raises']:
@ -466,14 +478,15 @@ async def test_process_issue(default_mock_args, mock_github_token, mock_output_d
mock_run_controller.return_value = test_case['run_controller_return']
# Patch the necessary functions and methods
with patch('openhands.resolver.resolve_issue.create_runtime', mock_create_runtime), \
patch('openhands.resolver.resolve_issue.run_controller', mock_run_controller), \
patch.object(resolver, 'complete_runtime', return_value={'git_patch': 'test patch'}), \
patch.object(resolver, 'initialize_runtime') as mock_initialize_runtime:
with patch(
'openhands.resolver.resolve_issue.create_runtime', mock_create_runtime
), patch(
'openhands.resolver.resolve_issue.run_controller', mock_run_controller
), patch.object(
resolver, 'complete_runtime', return_value={'git_patch': 'test patch'}
), patch.object(resolver, 'initialize_runtime') as mock_initialize_runtime:
# Call the process_issue method
result = await resolver.process_issue(issue, base_commit, handler_instance)
# Assert the result matches our expectations
assert isinstance(result, ResolverOutput)
@ -490,16 +503,17 @@ async def test_process_issue(default_mock_args, mock_github_token, mock_output_d
mock_initialize_runtime.assert_called_once()
mock_run_controller.assert_called_once()
resolver.complete_runtime.assert_awaited_once_with(mock_runtime, base_commit)
# Assert run_controller was called with the right parameters
if not test_case['run_controller_raises']:
# Check that the first positional argument is a config
assert 'config' in mock_run_controller.call_args[1]
# Check that initial_user_action is a MessageAction with the right content
assert isinstance(mock_run_controller.call_args[1]['initial_user_action'], MessageAction)
assert isinstance(
mock_run_controller.call_args[1]['initial_user_action'], MessageAction
)
assert mock_run_controller.call_args[1]['runtime'] == mock_runtime
# Assert that guess_success was called only for successful runs
if test_case['expected_success']:
handler_instance.guess_success.assert_called_once()

View File

@ -19,14 +19,16 @@ from openhands.resolver.interfaces.issue_definitions import (
ServiceContextIssue,
ServiceContextPR,
)
from openhands.resolver.resolve_issue import IssueResolver, SandboxConfig, AppConfig, AgentConfig
from openhands.resolver.resolve_issue import (
IssueResolver,
)
from openhands.resolver.resolver_output import ResolverOutput
@pytest.fixture
def default_mock_args():
"""Fixture that provides a default mock args object with common values.
Tests can override specific attributes as needed.
"""
mock_args = MagicMock()
@ -52,10 +54,13 @@ def default_mock_args():
@pytest.fixture
def mock_gitlab_token():
"""Fixture that patches the identify_token function to return GitLab provider type.
This eliminates the need for repeated patching in each test function.
"""
with patch('openhands.resolver.resolve_issue.identify_token', return_value=ProviderType.GITLAB) as patched:
with patch(
'openhands.resolver.resolve_issue.identify_token',
return_value=ProviderType.GITLAB,
) as patched:
yield patched
@ -124,10 +129,10 @@ def test_initialize_runtime(default_mock_args, mock_gitlab_token):
exit_code=0, content='', command='git config --global core.pager ""'
),
]
# Create resolver with mocked token identification
resolver = IssueResolver(default_mock_args)
resolver.initialize_runtime(mock_runtime)
if os.getenv('GITLAB_CI') == 'true':
@ -154,24 +159,26 @@ async def test_resolve_issue_no_issues_found(default_mock_args, mock_gitlab_toke
# Customize the mock args for this test
default_mock_args.issue_number = 5432
# Create a resolver instance with mocked token identification
resolver = IssueResolver(default_mock_args)
# Mock the issue_handler_factory method
resolver.issue_handler_factory = MagicMock(return_value=mock_handler)
# Test that the correct exception is raised
with pytest.raises(ValueError) as exc_info:
await resolver.resolve_issue()
# Verify the error message
assert 'No issues found for issue number 5432' in str(exc_info.value)
assert 'test-owner/test-repo' in str(exc_info.value)
# Verify that the handler was correctly configured and called
resolver.issue_handler_factory.assert_called_once()
mock_handler.get_converted_issues.assert_called_once_with(issue_numbers=[5432], comment_id=None)
mock_handler.get_converted_issues.assert_called_once_with(
issue_numbers=[5432], comment_id=None
)
def test_download_issues_from_gitlab():
@ -377,12 +384,14 @@ async def test_complete_runtime(default_mock_args, mock_gitlab_token):
content='',
command='git config --global --add safe.directory /workspace',
),
create_cmd_output(exit_code=0, content='', command='git add -A'),
create_cmd_output(
exit_code=0, content='', command='git add -A'
exit_code=0,
content='git diff content',
command='git diff --no-color --cached base_commit_hash',
),
create_cmd_output(exit_code=0, content='git diff content', command='git diff --no-color --cached base_commit_hash'),
]
# Create a resolver instance with mocked token identification
resolver = IssueResolver(default_mock_args)
@ -394,7 +403,7 @@ async def test_complete_runtime(default_mock_args, mock_gitlab_token):
@pytest.mark.asyncio
@pytest.mark.parametrize(
"test_case",
'test_case',
[
{
'name': 'successful_run',
@ -448,7 +457,13 @@ async def test_complete_runtime(default_mock_args, mock_gitlab_token):
},
],
)
async def test_process_issue(default_mock_args, mock_gitlab_token, mock_output_dir, mock_prompt_template, test_case):
async def test_process_issue(
default_mock_args,
mock_gitlab_token,
mock_output_dir,
mock_prompt_template,
test_case,
):
"""Test the process_issue method with different scenarios."""
# Set up test data
issue = Issue(
@ -482,7 +497,7 @@ async def test_process_issue(default_mock_args, mock_gitlab_token, mock_output_d
mock_runtime = MagicMock()
mock_runtime.connect = AsyncMock()
mock_create_runtime = MagicMock(return_value=mock_runtime)
# Configure run_controller mock based on test case
mock_run_controller = AsyncMock()
if test_case.get('run_controller_raises'):
@ -491,16 +506,18 @@ async def test_process_issue(default_mock_args, mock_gitlab_token, mock_output_d
mock_run_controller.return_value = test_case['run_controller_return']
# Patch the necessary functions and methods
with patch('openhands.resolver.resolve_issue.create_runtime', mock_create_runtime), \
patch('openhands.resolver.resolve_issue.run_controller', mock_run_controller), \
patch.object(resolver, 'complete_runtime', return_value={'git_patch': 'test patch'}), \
patch.object(resolver, 'initialize_runtime') as mock_initialize_runtime, \
patch('openhands.resolver.resolve_issue.SandboxConfig', return_value=MagicMock()), \
patch('openhands.resolver.resolve_issue.AppConfig', return_value=MagicMock()):
with patch(
'openhands.resolver.resolve_issue.create_runtime', mock_create_runtime
), patch(
'openhands.resolver.resolve_issue.run_controller', mock_run_controller
), patch.object(
resolver, 'complete_runtime', return_value={'git_patch': 'test patch'}
), patch.object(resolver, 'initialize_runtime') as mock_initialize_runtime, patch(
'openhands.resolver.resolve_issue.SandboxConfig', return_value=MagicMock()
), patch('openhands.resolver.resolve_issue.AppConfig', return_value=MagicMock()):
# Call the process_issue method
result = await resolver.process_issue(issue, base_commit, handler_instance)
mock_create_runtime.assert_called_once()
mock_runtime.connect.assert_called_once()
mock_initialize_runtime.assert_called_once()
@ -521,6 +538,7 @@ async def test_process_issue(default_mock_args, mock_gitlab_token, mock_output_d
else:
handler_instance.guess_success.assert_not_called()
def test_get_instruction(mock_prompt_template, mock_followup_prompt_template):
issue = Issue(
owner='test_owner',
@ -923,4 +941,4 @@ def test_download_issue_with_specific_comment():
if __name__ == '__main__':
pytest.main()
pytest.main()