{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"import pandas as pd\n",
"from tqdm import tqdm\n",
"\n",
"tqdm.pandas()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 1. Load raw data and convert to training data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import gzip\n",
"import json\n",
"\n",
"from tqdm import tqdm\n",
"\n",
"FILE_PATHS = [\n",
"    'YOURPATH-no-hint-train-t05-run_1/output.with_completions.jsonl.gz',\n",
"    'YOURPATH-no-hint-train-t05-run_2/output.with_completions.jsonl.gz',\n",
"]\n",
"\n",
"# More memory efficient for large files\n",
"# Initialize lists to store the data\n",
"data = []\n",
"\n",
"\n",
"# Read file line by line\n",
"for FILE_PATH in FILE_PATHS:\n",
"    with gzip.open(FILE_PATH, 'rb') as f:  # Use 'rb' for gzipped files\n",
"        for i, line in tqdm(\n",
"            enumerate(f), desc=f'Processing {FILE_PATH.split(\"/\")[-1]}'\n",
"        ):\n",
"            # Parse only the fields we need\n",
"            raw_data = json.loads(line)\n",
"            data.append(\n",
"                {\n",
"                    'resolved': raw_data['report']['resolved'],\n",
"                    'messages': raw_data['raw_completions']['messages']\n",
"                    if raw_data['raw_completions'] is not None\n",
"                    else None,\n",
"                    'git_patch': raw_data['test_result'].get('git_patch', ''),\n",
"                    'tools': raw_data['raw_completions']['tools']\n",
"                    if raw_data['raw_completions'] is not None\n",
"                    and 'tools' in raw_data['raw_completions']\n",
"                    else None,\n",
"                }\n",
"            )\n",
"\n",
"# Convert to DataFrame after collecting all data\n",
"df = pd.DataFrame(data)\n",
"print(f'#total amount of data={len(df)}')\n",
"df = df[~df['messages'].isna()]\n",
"print(f'#total amount of data after removing nan={len(df)}')"
]
},
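{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Illustrative record schema.** The loader above only reads a handful of fields from each evaluation record. The cell below sketches, with made-up placeholder values, the shape the loop assumes (`report.resolved`, `raw_completions.messages` / `raw_completions.tools`, `test_result.git_patch`); it is an illustration inferred from the code, not real output."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical record shaped after the fields accessed above (all values are placeholders).\n",
"example_record = {\n",
"    'report': {'resolved': True},\n",
"    'raw_completions': {\n",
"        'messages': [\n",
"            {'role': 'user', 'content': 'Fix the failing test in foo.py'},\n",
"            {'role': 'assistant', 'content': 'Done.'},\n",
"        ],\n",
"        'tools': [{'type': 'function', 'function': {'name': 'some_tool', 'parameters': {}}}],\n",
"    },\n",
"    'test_result': {'git_patch': 'diff --git a/foo.py b/foo.py\\n...'},\n",
"}\n",
"# The same extraction the loading loop performs on every line:\n",
"print(example_record['report']['resolved'])\n",
"print(example_record['test_result'].get('git_patch', '')[:30])"
]
},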
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Filter"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def _contains_multiple_tool_calls(messages: list[dict]) -> bool:\n",
"    return any(\n",
"        message.get('tool_calls') and len(message['tool_calls']) > 1\n",
"        for message in messages\n",
"    )\n",
"\n",
"\n",
"df['contains_multiple_tool_calls'] = df['messages'].apply(_contains_multiple_tool_calls)\n",
"display(df.groupby(['contains_multiple_tool_calls'])['resolved'].sum())"
]
},
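{
"cell_type": "markdown",
"metadata": {},
"source": [
"**What the filter flags (illustrative).** An assistant turn whose OpenAI-style `tool_calls` list holds more than one entry makes `_contains_multiple_tool_calls` return `True`. The message below is a made-up sketch of such a turn, not a row from the data; the tool names, ids, and arguments are placeholders."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative assistant message carrying two tool calls (placeholder ids, names, arguments).\n",
"multi_call_message = {\n",
"    'role': 'assistant',\n",
"    'content': None,\n",
"    'tool_calls': [\n",
"        {'id': 'call_1', 'type': 'function', 'function': {'name': 'tool_a', 'arguments': '{}'}},\n",
"        {'id': 'call_2', 'type': 'function', 'function': {'name': 'tool_b', 'arguments': '{}'}},\n",
"    ],\n",
"}\n",
"print(_contains_multiple_tool_calls([multi_call_message]))  # expected: True"
]
},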
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import copy\n",
"\n",
"# Convert function calling messages to non-function calling messages\n",
"from openhands.llm.fn_call_converter import (\n",
"    FunctionCallConversionError,\n",
"    convert_fncall_messages_to_non_fncall_messages,\n",
"    convert_from_multiple_tool_calls_to_single_tool_call_messages,\n",
")\n",
"\n",
"total_failed = 0\n",
"\n",
"\n",
"def _convert_messages(messages: list[dict], tools: list[dict]) -> list[dict] | None:\n",
"    global total_failed\n",
"    message_copy = copy.deepcopy(messages)\n",
"    for message in message_copy:\n",
"        if message['content'] is None:\n",
"            message['content'] = ''\n",
"    try:\n",
"        return convert_fncall_messages_to_non_fncall_messages(\n",
"            message_copy, tools, add_in_context_learning_example=False\n",
"        )\n",
"    except FunctionCallConversionError:\n",
"        total_failed += 1\n",
"        # print(f'Failed to convert messages: {messages}\\nTools: {tools}')\n",
"        # traceback.print_exc()\n",
"        return None\n",
"\n",
"\n",
"df['converted_messages'] = df.apply(\n",
"    lambda row: convert_from_multiple_tool_calls_to_single_tool_call_messages(\n",
"        row['messages'], ignore_final_tool_result=True\n",
"    ),\n",
"    axis=1,\n",
")\n",
"df['nonfncall_messages'] = df.apply(\n",
"    lambda row: _convert_messages(row['converted_messages'], row['tools']), axis=1\n",
")\n",
"print('total nan', df['nonfncall_messages'].isna().sum())\n",
"df = df[~df['nonfncall_messages'].isna()]\n",
"print(f'Total failed: {total_failed}')"
]
},
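{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Quick sanity check (optional, illustrative).** After the conversion above, every turn should be a plain `role`/`content` message with the tool-call structure serialized into text. The cell below just previews the first few turns of the first converted conversation; the preview length is arbitrary."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Peek at the first converted conversation: role plus a short, single-line content preview.\n",
"sample = df['nonfncall_messages'].iloc[0]\n",
"for turn in sample[:5]:\n",
"    content = turn.get('content') or ''\n",
"    if not isinstance(content, str):\n",
"        content = str(content)\n",
"    print(turn['role'], '|', content[:80].replace('\\n', ' '))"
]
},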
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tokenization"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pandarallel import pandarallel\n",
"from transformers import AutoTokenizer\n",
"\n",
"os.environ['TOKENIZERS_PARALLELISM'] = 'false'\n",
"pandarallel.initialize(progress_bar=True, verbose=1, nb_workers=16)\n",
"tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2.5-7B-Instruct')\n",
"# Token count of each converted conversation under the model's chat template\n",
"df['n_tokens'] = df['nonfncall_messages'].parallel_apply(\n",
"    lambda x: len(tokenizer.apply_chat_template(x))\n",
")"
]
},
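{
"cell_type": "markdown",
"metadata": {},
"source": [
"**What `n_tokens` measures (illustrative).** `apply_chat_template` with its default `tokenize=True` renders a conversation with the model's chat template and returns token ids, so `len(...)` is the full conversation length in Qwen2.5 tokens. The toy two-turn conversation below is only an illustration with placeholder text."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Toy conversation (placeholder text) to show what is being counted.\n",
"toy_messages = [\n",
"    {'role': 'user', 'content': 'Hello'},\n",
"    {'role': 'assistant', 'content': 'Hi there!'},\n",
"]\n",
"toy_ids = tokenizer.apply_chat_template(toy_messages)\n",
"print(len(toy_ids), 'tokens')\n",
"print(tokenizer.decode(toy_ids))"
]
},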
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(f'BEFORE: #total={len(df)}')\n",
"df_selected = df[df['n_tokens'] < 131072]\n",
"print(f'AFTER (filtered to <128k tokens): #total={len(df_selected)}')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df_selected['n_tokens'].describe()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# ecdf of n_tokens\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"\n",
"display(df.groupby(['resolved'])['n_tokens'].describe())\n",
"sns.ecdfplot(x='n_tokens', data=df, hue='resolved')\n",
"plt.show()\n",
"\n",
"print(f'#total={len(df)}')\n",
"df_selected = df[df['n_tokens'] < 131072]\n",
"print(f'#selected={len(df_selected)}')\n",
"display(df_selected.groupby(['resolved'])['n_tokens'].describe())\n",
"sns.ecdfplot(x='n_tokens', data=df_selected, hue='resolved')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df_selected[~df_selected['resolved']]['n_tokens'].describe()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df_selected['resolved'].value_counts()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df_selected.groupby(['resolved'])['n_tokens'].describe()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Save Resolved Messages for SFT"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df_selected[df_selected['resolved']][['nonfncall_messages']].rename(\n",
"    columns={'nonfncall_messages': 'messages'}\n",
").to_json(\n",
"    os.path.join(\n",
"        'YOUR_OUTPUT_FOLDER',\n",
"        f'policy_traj_128k_swegym_{df_selected[\"resolved\"].value_counts()[True]}i.jsonl',\n",
"    ),\n",
"    lines=True,\n",
"    orient='records',\n",
")"
]
},
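{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Reading the export back (illustrative).** The cell above writes one JSON object per line with a single `messages` column, so it can be reloaded with `pd.read_json(..., lines=True)`. `YOUR_OUTPUT_FILE.jsonl` below is a placeholder for the path actually written above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical round-trip check: reload the exported JSONL and inspect one trajectory.\n",
"# Replace YOUR_OUTPUT_FILE.jsonl with the actual path written by the cell above.\n",
"df_check = pd.read_json('YOUR_OUTPUT_FILE.jsonl', lines=True)\n",
"print(len(df_check), 'trajectories')\n",
"print(df_check['messages'].iloc[0][0]['role'])"
]
}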
],
"metadata": {
"kernelspec": {
"display_name": "openhands-ai-CPy6G0pU-py3.12",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}