266 lines
8.6 KiB
Python
266 lines
8.6 KiB
Python
"""
|
||
任务数据采集入口
|
||
整合环境控制、轨迹录制、文件收集的完整流程
|
||
"""
|
||
import os
|
||
import sys
|
||
import argparse
|
||
import json
|
||
import logging
|
||
|
||
# 添加父目录到路径
|
||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
|
||
from core.jade_env import JadeEnv
|
||
from core.recorder import record_interactive
|
||
from utils.config_loader import load_config, get_vm_config, get_network_config
|
||
|
||
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
def load_default_config():
|
||
"""加载默认配置"""
|
||
try:
|
||
config = load_config()
|
||
vm_config = get_vm_config(config)
|
||
network_config = get_network_config(config)
|
||
|
||
return {
|
||
"vmx_path": vm_config.get('vmx_path'),
|
||
"snapshot_name": vm_config.get('snapshot_name', 'Jade_Ready'),
|
||
"vm_ip": network_config.get('vm_ip'),
|
||
"vm_password": vm_config.get('vm_password'),
|
||
"guest_username": vm_config.get('guest_username'),
|
||
"guest_password": vm_config.get('guest_password')
|
||
}
|
||
except Exception as e:
|
||
logger.warning(f"⚠️ 无法加载config.json: {e}")
|
||
logger.info(" 使用硬编码配置")
|
||
return {
|
||
"vmx_path": "/Volumes/Castor/虚拟机/Jade_Win_11.vmwarevm/Windows 11 64 位 ARM 2.vmx",
|
||
"snapshot_name": "Jade_Ready",
|
||
"vm_ip": "192.168.116.129",
|
||
"vm_password": "lizhanyuan",
|
||
"guest_username": "lzy",
|
||
"guest_password": "LIZHANYUAN"
|
||
}
|
||
|
||
|
||
def load_task_config(task_id, project_root="."):
|
||
"""加载任务配置文件"""
|
||
task_json_path = os.path.join(project_root, "tasks", task_id, "task.json")
|
||
|
||
if not os.path.exists(task_json_path):
|
||
logger.error(f"❌ 任务配置文件不存在: {task_json_path}")
|
||
logger.info(" 请先创建任务目录和task.json")
|
||
return None
|
||
|
||
with open(task_json_path, 'r', encoding='utf-8') as f:
|
||
return json.load(f)
|
||
|
||
|
||
def mode_reset(env, task_config, project_root="."):
|
||
"""
|
||
模式1: 重置环境并注入输入文件
|
||
"""
|
||
print("\n" + "=" * 60)
|
||
print("🔄 模式: 重置环境")
|
||
print("=" * 60)
|
||
|
||
# 1. 重置虚拟机
|
||
env.reset()
|
||
|
||
# 2. 注入输入文件
|
||
if 'input' in task_config:
|
||
input_config = task_config['input']
|
||
source_file = input_config.get('source_file')
|
||
|
||
if source_file:
|
||
# 处理相对路径(相对于任务目录)
|
||
if not os.path.isabs(source_file):
|
||
task_dir = os.path.join(project_root, "tasks", task_config['id'])
|
||
source_file = os.path.normpath(os.path.join(task_dir, source_file))
|
||
|
||
# 确保使用绝对路径
|
||
source_file = os.path.abspath(source_file)
|
||
|
||
if os.path.exists(source_file):
|
||
# 从Windows路径中提取文件名(处理反斜杠)
|
||
inject_to = input_config.get('inject_to', '')
|
||
if inject_to:
|
||
# 使用Windows路径分隔符分割
|
||
guest_filename = inject_to.split('\\')[-1]
|
||
else:
|
||
guest_filename = os.path.basename(source_file)
|
||
|
||
env.inject_file(source_file, guest_filename)
|
||
else:
|
||
logger.warning(f"⚠️ 输入文件不存在: {source_file}")
|
||
|
||
print("\n✅ 环境准备完成!")
|
||
print("=" * 60)
|
||
print("💡 下一步:开始录制操作")
|
||
print(f" python scripts/tools/collect_task.py {task_config['id']} --mode record")
|
||
print("=" * 60 + "\n")
|
||
|
||
|
||
def mode_record(env, task_config, project_root="."):
|
||
"""
|
||
模式2: 录制人类操作轨迹
|
||
"""
|
||
task_id = task_config['id']
|
||
output_dir = os.path.join(project_root, "tasks", task_id, "human_demo")
|
||
|
||
print("\n" + "=" * 60)
|
||
print("🎥 模式: 录制轨迹")
|
||
print("=" * 60)
|
||
print(f"任务: {task_config.get('instruction', 'N/A')}")
|
||
print("=" * 60)
|
||
|
||
# 创建录制器并开始录制
|
||
record_interactive(env, task_id, output_dir)
|
||
|
||
print("\n💡 下一步:处理坐标转换")
|
||
print(f" python scripts/tools/process_trajectory.py {task_id}")
|
||
print("=" * 60 + "\n")
|
||
|
||
|
||
def mode_collect(env, task_config, project_root="."):
|
||
"""
|
||
模式3: 收集输出文件到ground_truth
|
||
"""
|
||
print("\n" + "=" * 60)
|
||
print("📦 模式: 收集结果文件")
|
||
print("=" * 60)
|
||
|
||
task_id = task_config['id']
|
||
|
||
if 'output' in task_config:
|
||
output_config = task_config['output']
|
||
expected_file = output_config.get('expected_file')
|
||
|
||
if expected_file:
|
||
# 目标路径
|
||
gt_dir = os.path.join(project_root, "tasks", task_id, "ground_truth")
|
||
os.makedirs(gt_dir, exist_ok=True)
|
||
|
||
host_path = os.path.join(gt_dir, expected_file)
|
||
|
||
# 收集文件
|
||
env.collect_file(expected_file, host_path)
|
||
|
||
print(f"\n✅ 文件已保存到: {host_path}")
|
||
else:
|
||
logger.warning("⚠️ 任务配置中未指定expected_file")
|
||
else:
|
||
logger.warning("⚠️ 任务配置中未指定output")
|
||
|
||
print("\n" + "=" * 60)
|
||
print("💡 下一步:验证评测")
|
||
print(f" python scripts/tools/run_eval.py {task_id}")
|
||
print("=" * 60 + "\n")
|
||
|
||
|
||
def mode_full(env, task_config, project_root="."):
|
||
"""
|
||
模式4: 完整流程(reset + record + collect)
|
||
"""
|
||
print("\n" + "=" * 60)
|
||
print("🔄 模式: 完整采集流程")
|
||
print("=" * 60)
|
||
|
||
# Step 1: Reset
|
||
mode_reset(env, task_config, project_root)
|
||
|
||
# Step 2: Record
|
||
input("\n按Enter键开始录制...")
|
||
mode_record(env, task_config, project_root)
|
||
|
||
# Step 3: Collect
|
||
input("\n按Enter键收集结果...")
|
||
mode_collect(env, task_config, project_root)
|
||
|
||
print("\n✅ 完整采集流程完成!")
|
||
|
||
|
||
def main():
|
||
parser = argparse.ArgumentParser(
|
||
description="JADE Benchmark 任务数据采集工具",
|
||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||
epilog="""
|
||
使用示例:
|
||
# 完整流程(推荐)
|
||
python scripts/collect_task.py smoothing_001 --mode full
|
||
|
||
# 分步执行
|
||
python scripts/collect_task.py smoothing_001 --mode reset # 1. 重置并注入文件
|
||
python scripts/collect_task.py smoothing_001 --mode record # 2. 录制操作
|
||
python scripts/collect_task.py smoothing_001 --mode collect # 3. 收集结果
|
||
"""
|
||
)
|
||
|
||
parser.add_argument("task_id", help="任务ID(对应tasks/目录下的子目录名)")
|
||
parser.add_argument(
|
||
"--mode",
|
||
choices=["reset", "record", "collect", "full"],
|
||
default="full",
|
||
help="采集模式(默认:full)"
|
||
)
|
||
parser.add_argument("--project-root", default=".", help="项目根目录")
|
||
parser.add_argument("--vmx", help="虚拟机.vmx文件路径(覆盖默认配置)")
|
||
parser.add_argument("--snapshot", help="快照名称(覆盖默认配置)")
|
||
parser.add_argument("--vm-ip", help="虚拟机IP地址(覆盖默认配置)")
|
||
|
||
args = parser.parse_args()
|
||
|
||
# 加载任务配置
|
||
task_config = load_task_config(args.task_id, args.project_root)
|
||
if not task_config:
|
||
sys.exit(1)
|
||
|
||
# 加载并合并配置
|
||
config = load_default_config()
|
||
if args.vmx:
|
||
config['vmx_path'] = args.vmx
|
||
if args.snapshot:
|
||
config['snapshot_name'] = args.snapshot
|
||
if args.vm_ip:
|
||
config['vm_ip'] = args.vm_ip
|
||
|
||
# 初始化环境
|
||
try:
|
||
logger.info("初始化JADE环境...")
|
||
env = JadeEnv(
|
||
vmx_path=config['vmx_path'],
|
||
snapshot_name=config['snapshot_name'],
|
||
vm_ip=config['vm_ip'],
|
||
vm_password=config.get('vm_password'),
|
||
guest_username=config.get('guest_username'),
|
||
guest_password=config.get('guest_password')
|
||
)
|
||
|
||
# 执行对应模式
|
||
if args.mode == "reset":
|
||
mode_reset(env, task_config, args.project_root)
|
||
elif args.mode == "record":
|
||
mode_record(env, task_config, args.project_root)
|
||
elif args.mode == "collect":
|
||
mode_collect(env, task_config, args.project_root)
|
||
elif args.mode == "full":
|
||
mode_full(env, task_config, args.project_root)
|
||
|
||
except KeyboardInterrupt:
|
||
print("\n\n⏹ 操作已取消")
|
||
sys.exit(1)
|
||
except Exception as e:
|
||
logger.error(f"❌ 错误: {e}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
sys.exit(1)
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main()
|
||
|