""" 任务数据采集入口 整合环境控制、轨迹录制、文件收集的完整流程 """ import os import sys import argparse import json import logging # 添加父目录到路径 sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from core.jade_env import JadeEnv from core.recorder import record_interactive from utils.config_loader import load_config, get_vm_config, get_network_config logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s') logger = logging.getLogger(__name__) def load_default_config(): """加载默认配置""" try: config = load_config() vm_config = get_vm_config(config) network_config = get_network_config(config) return { "vmx_path": vm_config.get('vmx_path'), "snapshot_name": vm_config.get('snapshot_name', 'Jade_Ready'), "vm_ip": network_config.get('vm_ip'), "vm_password": vm_config.get('vm_password'), "guest_username": vm_config.get('guest_username'), "guest_password": vm_config.get('guest_password') } except Exception as e: logger.warning(f"⚠️ 无法加载config.json: {e}") logger.info(" 使用硬编码配置") return { "vmx_path": "/Volumes/Castor/虚拟机/Jade_Win_11.vmwarevm/Windows 11 64 位 ARM 2.vmx", "snapshot_name": "Jade_Ready", "vm_ip": "192.168.116.129", "vm_password": "lizhanyuan", "guest_username": "lzy", "guest_password": "LIZHANYUAN" } def load_task_config(task_id, project_root="."): """加载任务配置文件""" task_json_path = os.path.join(project_root, "tasks", task_id, "task.json") if not os.path.exists(task_json_path): logger.error(f"❌ 任务配置文件不存在: {task_json_path}") logger.info(" 请先创建任务目录和task.json") return None with open(task_json_path, 'r', encoding='utf-8') as f: return json.load(f) def mode_reset(env, task_config, project_root="."): """ 模式1: 重置环境并注入输入文件 """ print("\n" + "=" * 60) print("🔄 模式: 重置环境") print("=" * 60) # 1. 重置虚拟机 env.reset() # 2. 注入输入文件 if 'input' in task_config: input_config = task_config['input'] source_file = input_config.get('source_file') if source_file: # 处理相对路径(相对于任务目录) if not os.path.isabs(source_file): task_dir = os.path.join(project_root, "tasks", task_config['id']) source_file = os.path.normpath(os.path.join(task_dir, source_file)) # 确保使用绝对路径 source_file = os.path.abspath(source_file) if os.path.exists(source_file): # 从Windows路径中提取文件名(处理反斜杠) inject_to = input_config.get('inject_to', '') if inject_to: # 使用Windows路径分隔符分割 guest_filename = inject_to.split('\\')[-1] else: guest_filename = os.path.basename(source_file) env.inject_file(source_file, guest_filename) else: logger.warning(f"⚠️ 输入文件不存在: {source_file}") print("\n✅ 环境准备完成!") print("=" * 60) print("💡 下一步:开始录制操作") print(f" python scripts/tools/collect_task.py {task_config['id']} --mode record") print("=" * 60 + "\n") def mode_record(env, task_config, project_root="."): """ 模式2: 录制人类操作轨迹 """ task_id = task_config['id'] output_dir = os.path.join(project_root, "tasks", task_id, "human_demo") print("\n" + "=" * 60) print("🎥 模式: 录制轨迹") print("=" * 60) print(f"任务: {task_config.get('instruction', 'N/A')}") print("=" * 60) # 创建录制器并开始录制 record_interactive(env, task_id, output_dir) print("\n💡 下一步:处理坐标转换") print(f" python scripts/tools/process_trajectory.py {task_id}") print("=" * 60 + "\n") def mode_collect(env, task_config, project_root="."): """ 模式3: 收集输出文件到ground_truth """ print("\n" + "=" * 60) print("📦 模式: 收集结果文件") print("=" * 60) task_id = task_config['id'] if 'output' in task_config: output_config = task_config['output'] expected_file = output_config.get('expected_file') if expected_file: # 目标路径 gt_dir = os.path.join(project_root, "tasks", task_id, "ground_truth") os.makedirs(gt_dir, exist_ok=True) host_path = os.path.join(gt_dir, expected_file) # 收集文件 env.collect_file(expected_file, host_path) print(f"\n✅ 文件已保存到: {host_path}") else: logger.warning("⚠️ 任务配置中未指定expected_file") else: logger.warning("⚠️ 任务配置中未指定output") print("\n" + "=" * 60) print("💡 下一步:验证评测") print(f" python scripts/tools/run_eval.py {task_id}") print("=" * 60 + "\n") def mode_full(env, task_config, project_root="."): """ 模式4: 完整流程(reset + record + collect) """ print("\n" + "=" * 60) print("🔄 模式: 完整采集流程") print("=" * 60) # Step 1: Reset mode_reset(env, task_config, project_root) # Step 2: Record input("\n按Enter键开始录制...") mode_record(env, task_config, project_root) # Step 3: Collect input("\n按Enter键收集结果...") mode_collect(env, task_config, project_root) print("\n✅ 完整采集流程完成!") def main(): parser = argparse.ArgumentParser( description="JADE Benchmark 任务数据采集工具", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" 使用示例: # 完整流程(推荐) python scripts/collect_task.py smoothing_001 --mode full # 分步执行 python scripts/collect_task.py smoothing_001 --mode reset # 1. 重置并注入文件 python scripts/collect_task.py smoothing_001 --mode record # 2. 录制操作 python scripts/collect_task.py smoothing_001 --mode collect # 3. 收集结果 """ ) parser.add_argument("task_id", help="任务ID(对应tasks/目录下的子目录名)") parser.add_argument( "--mode", choices=["reset", "record", "collect", "full"], default="full", help="采集模式(默认:full)" ) parser.add_argument("--project-root", default=".", help="项目根目录") parser.add_argument("--vmx", help="虚拟机.vmx文件路径(覆盖默认配置)") parser.add_argument("--snapshot", help="快照名称(覆盖默认配置)") parser.add_argument("--vm-ip", help="虚拟机IP地址(覆盖默认配置)") args = parser.parse_args() # 加载任务配置 task_config = load_task_config(args.task_id, args.project_root) if not task_config: sys.exit(1) # 加载并合并配置 config = load_default_config() if args.vmx: config['vmx_path'] = args.vmx if args.snapshot: config['snapshot_name'] = args.snapshot if args.vm_ip: config['vm_ip'] = args.vm_ip # 初始化环境 try: logger.info("初始化JADE环境...") env = JadeEnv( vmx_path=config['vmx_path'], snapshot_name=config['snapshot_name'], vm_ip=config['vm_ip'], vm_password=config.get('vm_password'), guest_username=config.get('guest_username'), guest_password=config.get('guest_password') ) # 执行对应模式 if args.mode == "reset": mode_reset(env, task_config, args.project_root) elif args.mode == "record": mode_record(env, task_config, args.project_root) elif args.mode == "collect": mode_collect(env, task_config, args.project_root) elif args.mode == "full": mode_full(env, task_config, args.project_root) except KeyboardInterrupt: print("\n\n⏹ 操作已取消") sys.exit(1) except Exception as e: logger.error(f"❌ 错误: {e}") import traceback traceback.print_exc() sys.exit(1) if __name__ == "__main__": main()