Files
sci-gui-agent-benchmark/scripts/tools/collect_task.py
2026-01-12 18:30:12 +08:00

266 lines
8.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
任务数据采集入口
整合环境控制、轨迹录制、文件收集的完整流程
"""
import os
import sys
import argparse
import json
import logging
# 添加父目录到路径
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from core.jade_env import JadeEnv
from core.recorder import record_interactive
from utils.config_loader import load_config, get_vm_config, get_network_config
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
logger = logging.getLogger(__name__)
def load_default_config():
    """Return the VM connection settings used by the collector.

    The settings are read from config.json through ``load_config``,
    ``get_vm_config`` and ``get_network_config``. If any step raises, a
    warning is logged and a hard-coded fallback dict is returned instead.

    Returns:
        dict with the keys ``vmx_path``, ``snapshot_name``, ``vm_ip``,
        ``vm_password``, ``guest_username`` and ``guest_password``.
        ``snapshot_name`` defaults to ``'Jade_Ready'``; other keys missing from
        the loaded config come back as ``None``.
    """
    try:
        cfg = load_config()
        vm_section = get_vm_config(cfg)
        net_section = get_network_config(cfg)
        settings = {
            "vmx_path": vm_section.get('vmx_path'),
            "snapshot_name": vm_section.get('snapshot_name', 'Jade_Ready'),
            "vm_ip": net_section.get('vm_ip'),
        }
        # The credential fields all live in the VM section.
        for field in ("vm_password", "guest_username", "guest_password"):
            settings[field] = vm_section.get(field)
        return settings
    except Exception as err:
        logger.warning(f"⚠️ 无法加载config.json: {err}")
        logger.info(" 使用硬编码配置")

    # Fallback used only when loading config.json failed above.
    # NOTE(review): this embeds machine-specific paths and plain-text
    # credentials in source; consider moving them to config.json / env vars.
    return {
        "vmx_path": "/Volumes/Castor/虚拟机/Jade_Win_11.vmwarevm/Windows 11 64 位 ARM 2.vmx",
        "snapshot_name": "Jade_Ready",
        "vm_ip": "192.168.116.129",
        "vm_password": "lizhanyuan",
        "guest_username": "lzy",
        "guest_password": "LIZHANYUAN",
    }
def load_task_config(task_id, project_root="."):
"""加载任务配置文件"""
task_json_path = os.path.join(project_root, "tasks", task_id, "task.json")
if not os.path.exists(task_json_path):
logger.error(f"❌ 任务配置文件不存在: {task_json_path}")
logger.info(" 请先创建任务目录和task.json")
return None
with open(task_json_path, 'r', encoding='utf-8') as f:
return json.load(f)
def _guest_basename(guest_path):
    r"""Return the final component of a Windows guest path.

    Windows accepts both ``\`` and ``/`` as separators, so both are honoured:
    ``C:\Users\me\a.txt`` and ``C:/Users/me/a.txt`` both give ``a.txt``.
    """
    return guest_path.replace('/', '\\').rsplit('\\', 1)[-1]


def mode_reset(env, task_config, project_root="."):
    r"""Mode 1: reset the VM and inject the task's input file.

    Args:
        env: environment object; ``env.reset()`` and
            ``env.inject_file(host_path, guest_filename)`` are called on it.
        task_config: parsed task.json dict. Uses ``id`` and the optional
            ``input.source_file`` / ``input.inject_to`` entries.
        project_root: directory that contains ``tasks/``.

    A relative ``source_file`` is resolved against
    ``<project_root>/tasks/<id>`` and made absolute. The guest-side filename
    is the last component of ``inject_to`` (``\`` or ``/`` separated) when it
    is set, otherwise the host file's basename. If the host file does not
    exist, only a warning is logged.
    """
    print("\n" + "=" * 60)
    print("🔄 模式: 重置环境")
    print("=" * 60)

    # 1. Reset the virtual machine.
    env.reset()

    # 2. Inject the input file, if the task declares one.
    # `or {}` also tolerates an explicit "input": null in task.json.
    input_config = task_config.get('input') or {}
    source_file = input_config.get('source_file')
    if source_file:
        # Relative paths are relative to the task directory.
        if not os.path.isabs(source_file):
            task_dir = os.path.join(project_root, "tasks", task_config['id'])
            source_file = os.path.normpath(os.path.join(task_dir, source_file))
        # Always hand an absolute host path to inject_file.
        source_file = os.path.abspath(source_file)
        if os.path.exists(source_file):
            inject_to = input_config.get('inject_to', '')
            if inject_to:
                guest_filename = _guest_basename(inject_to)
            else:
                guest_filename = os.path.basename(source_file)
            env.inject_file(source_file, guest_filename)
        else:
            logger.warning(f"⚠️ 输入文件不存在: {source_file}")

    print("\n✅ 环境准备完成!")
    print("=" * 60)
    print("💡 下一步:开始录制操作")
    print(f" python scripts/tools/collect_task.py {task_config['id']} --mode record")
    print("=" * 60 + "\n")
def mode_record(env, task_config, project_root="."):
    """Mode 2: record a human operation trajectory for the task.

    Prints a banner with the task's ``instruction`` (``'N/A'`` if absent),
    then hands control to ``record_interactive`` with the output directory
    ``<project_root>/tasks/<id>/human_demo``. Finally prints the follow-up
    command for the trajectory-processing step.
    """
    rule = "=" * 60
    current_task = task_config['id']
    demo_dir = os.path.join(project_root, "tasks", current_task, "human_demo")

    print("\n" + rule)
    print("🎥 模式: 录制轨迹")
    print(rule)
    print(f"任务: {task_config.get('instruction', 'N/A')}")
    print(rule)

    # Recording is delegated entirely to the recorder module.
    record_interactive(env, current_task, demo_dir)

    print("\n💡 下一步:处理坐标转换")
    print(f" python scripts/tools/process_trajectory.py {current_task}")
    print(rule + "\n")
def mode_collect(env, task_config, project_root="."):
    r"""Mode 3: copy the task's output file from the guest into ``ground_truth``.

    Args:
        env: environment object; ``env.collect_file(guest_file, host_path)``
            is called on it.
        task_config: parsed task.json dict. Uses ``id`` and
            ``output.expected_file``.
        project_root: directory that contains ``tasks/``.

    ``expected_file`` is passed to the guest side unchanged. The host copy is
    written to ``<project_root>/tasks/<id>/ground_truth/<name>``, where
    ``<name>`` is the last component of ``expected_file`` (``\`` or ``/``
    separated). The directory is created if needed. A warning is logged
    when ``output`` or ``expected_file`` is missing.
    """
    print("\n" + "=" * 60)
    print("📦 模式: 收集结果文件")
    print("=" * 60)
    task_id = task_config['id']
    if 'output' in task_config:
        # `or {}` tolerates an explicit "output": null in task.json.
        output_config = task_config['output'] or {}
        expected_file = output_config.get('expected_file')
        if expected_file:
            gt_dir = os.path.join(project_root, "tasks", task_id, "ground_truth")
            os.makedirs(gt_dir, exist_ok=True)
            # expected_file may be a guest (Windows) path. Joining it directly
            # would escape gt_dir for absolute paths (os.path.join discards
            # earlier parts) or embed backslashes in a POSIX filename, so keep
            # only its final component on the host side.
            host_name = expected_file.replace('\\', '/').rsplit('/', 1)[-1]
            host_path = os.path.join(gt_dir, host_name)
            env.collect_file(expected_file, host_path)
            print(f"\n✅ 文件已保存到: {host_path}")
        else:
            logger.warning("⚠️ 任务配置中未指定expected_file")
    else:
        logger.warning("⚠️ 任务配置中未指定output")
    print("\n" + "=" * 60)
    print("💡 下一步:验证评测")
    print(f" python scripts/tools/run_eval.py {task_id}")
    print("=" * 60 + "\n")
def mode_full(env, task_config, project_root="."):
    """Mode 4: run the full pipeline (reset, record, collect) in order.

    Before the record and collect stages the operator is prompted to press
    Enter, so each of those stages starts only when the operator is ready.
    """
    divider = "=" * 60
    print("\n" + divider)
    print("🔄 模式: 完整采集流程")
    print(divider)

    # (prompt shown before the stage, stage function); None = no prompt.
    stages = (
        (None, mode_reset),
        ("\n按Enter键开始录制...", mode_record),
        ("\n按Enter键收集结果...", mode_collect),
    )
    for prompt, stage in stages:
        if prompt is not None:
            input(prompt)
        stage(env, task_config, project_root)

    print("\n✅ 完整采集流程完成!")
def main():
    """Command-line entry point for the task collection tool.

    Parses arguments and loads ``tasks/<task_id>/task.json``. Merges the VM
    settings from ``load_default_config()`` with any ``--vmx`` /
    ``--snapshot`` / ``--vm-ip`` overrides, builds a ``JadeEnv`` and runs
    the selected mode.

    Exits with status 1 in three cases: the task config cannot be loaded,
    the user presses Ctrl+C, or any other exception escapes (the error and
    its traceback are logged).
    """
    parser = argparse.ArgumentParser(
        description="JADE Benchmark 任务数据采集工具",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        # Examples use the script's real location (scripts/tools/), matching
        # the "next step" hints the modes print.
        epilog="""
使用示例:
# 完整流程(推荐)
python scripts/tools/collect_task.py smoothing_001 --mode full
# 分步执行
python scripts/tools/collect_task.py smoothing_001 --mode reset # 1. 重置并注入文件
python scripts/tools/collect_task.py smoothing_001 --mode record # 2. 录制操作
python scripts/tools/collect_task.py smoothing_001 --mode collect # 3. 收集结果
"""
    )
    parser.add_argument("task_id", help="任务ID(对应tasks/目录下的子目录名)")
    parser.add_argument(
        "--mode",
        choices=["reset", "record", "collect", "full"],
        default="full",
        help="采集模式(默认full)"
    )
    parser.add_argument("--project-root", default=".", help="项目根目录")
    parser.add_argument("--vmx", help="虚拟机.vmx文件路径(覆盖默认配置)")
    parser.add_argument("--snapshot", help="快照名称(覆盖默认配置)")
    parser.add_argument("--vm-ip", help="虚拟机IP地址(覆盖默认配置)")
    args = parser.parse_args()

    # Load the task definition; load_task_config logs the reason on failure.
    task_config = load_task_config(args.task_id, args.project_root)
    if not task_config:
        sys.exit(1)

    # Non-empty CLI flags take precedence over the loaded/default settings.
    config = load_default_config()
    overrides = {
        "vmx_path": args.vmx,
        "snapshot_name": args.snapshot,
        "vm_ip": args.vm_ip,
    }
    for key, value in overrides.items():
        if value:
            config[key] = value

    try:
        logger.info("初始化JADE环境...")
        env = JadeEnv(
            vmx_path=config['vmx_path'],
            snapshot_name=config['snapshot_name'],
            vm_ip=config['vm_ip'],
            vm_password=config.get('vm_password'),
            guest_username=config.get('guest_username'),
            guest_password=config.get('guest_password')
        )
        # argparse `choices` guarantees args.mode is one of these keys.
        handlers = {
            "reset": mode_reset,
            "record": mode_record,
            "collect": mode_collect,
            "full": mode_full,
        }
        handlers[args.mode](env, task_config, args.project_root)
    except KeyboardInterrupt:
        print("\n\n⏹ 操作已取消")
        sys.exit(1)
    except Exception as e:
        # logger.exception logs the message together with the traceback.
        logger.exception("❌ 错误: %s", e)
        sys.exit(1)


if __name__ == "__main__":
    main()