Initial commit

This commit is contained in:
2026-01-12 18:30:12 +08:00
commit 214e15c04c
102 changed files with 27857 additions and 0 deletions

View File

@@ -0,0 +1,265 @@
"""
任务数据采集入口
整合环境控制、轨迹录制、文件收集的完整流程
"""
import os
import sys
import argparse
import json
import logging
# 添加父目录到路径
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from core.jade_env import JadeEnv
from core.recorder import record_interactive
from utils.config_loader import load_config, get_vm_config, get_network_config
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
logger = logging.getLogger(__name__)
def load_default_config():
"""加载默认配置"""
try:
config = load_config()
vm_config = get_vm_config(config)
network_config = get_network_config(config)
return {
"vmx_path": vm_config.get('vmx_path'),
"snapshot_name": vm_config.get('snapshot_name', 'Jade_Ready'),
"vm_ip": network_config.get('vm_ip'),
"vm_password": vm_config.get('vm_password'),
"guest_username": vm_config.get('guest_username'),
"guest_password": vm_config.get('guest_password')
}
except Exception as e:
logger.warning(f"⚠️ 无法加载config.json: {e}")
logger.info(" 使用硬编码配置")
return {
"vmx_path": "/Volumes/Castor/虚拟机/Jade_Win_11.vmwarevm/Windows 11 64 位 ARM 2.vmx",
"snapshot_name": "Jade_Ready",
"vm_ip": "192.168.116.129",
"vm_password": "lizhanyuan",
"guest_username": "lzy",
"guest_password": "LIZHANYUAN"
}
def load_task_config(task_id, project_root="."):
"""加载任务配置文件"""
task_json_path = os.path.join(project_root, "tasks", task_id, "task.json")
if not os.path.exists(task_json_path):
logger.error(f"❌ 任务配置文件不存在: {task_json_path}")
logger.info(" 请先创建任务目录和task.json")
return None
with open(task_json_path, 'r', encoding='utf-8') as f:
return json.load(f)
def mode_reset(env, task_config, project_root="."):
"""
模式1: 重置环境并注入输入文件
"""
print("\n" + "=" * 60)
print("🔄 模式: 重置环境")
print("=" * 60)
# 1. 重置虚拟机
env.reset()
# 2. 注入输入文件
if 'input' in task_config:
input_config = task_config['input']
source_file = input_config.get('source_file')
if source_file:
# 处理相对路径(相对于任务目录)
if not os.path.isabs(source_file):
task_dir = os.path.join(project_root, "tasks", task_config['id'])
source_file = os.path.normpath(os.path.join(task_dir, source_file))
# 确保使用绝对路径
source_file = os.path.abspath(source_file)
if os.path.exists(source_file):
# 从Windows路径中提取文件名处理反斜杠
inject_to = input_config.get('inject_to', '')
if inject_to:
# 使用Windows路径分隔符分割
guest_filename = inject_to.split('\\')[-1]
else:
guest_filename = os.path.basename(source_file)
env.inject_file(source_file, guest_filename)
else:
logger.warning(f"⚠️ 输入文件不存在: {source_file}")
print("\n✅ 环境准备完成!")
print("=" * 60)
print("💡 下一步:开始录制操作")
print(f" python scripts/tools/collect_task.py {task_config['id']} --mode record")
print("=" * 60 + "\n")
def mode_record(env, task_config, project_root="."):
"""
模式2: 录制人类操作轨迹
"""
task_id = task_config['id']
output_dir = os.path.join(project_root, "tasks", task_id, "human_demo")
print("\n" + "=" * 60)
print("🎥 模式: 录制轨迹")
print("=" * 60)
print(f"任务: {task_config.get('instruction', 'N/A')}")
print("=" * 60)
# 创建录制器并开始录制
record_interactive(env, task_id, output_dir)
print("\n💡 下一步:处理坐标转换")
print(f" python scripts/tools/process_trajectory.py {task_id}")
print("=" * 60 + "\n")
def mode_collect(env, task_config, project_root="."):
"""
模式3: 收集输出文件到ground_truth
"""
print("\n" + "=" * 60)
print("📦 模式: 收集结果文件")
print("=" * 60)
task_id = task_config['id']
if 'output' in task_config:
output_config = task_config['output']
expected_file = output_config.get('expected_file')
if expected_file:
# 目标路径
gt_dir = os.path.join(project_root, "tasks", task_id, "ground_truth")
os.makedirs(gt_dir, exist_ok=True)
host_path = os.path.join(gt_dir, expected_file)
# 收集文件
env.collect_file(expected_file, host_path)
print(f"\n✅ 文件已保存到: {host_path}")
else:
logger.warning("⚠️ 任务配置中未指定expected_file")
else:
logger.warning("⚠️ 任务配置中未指定output")
print("\n" + "=" * 60)
print("💡 下一步:验证评测")
print(f" python scripts/tools/run_eval.py {task_id}")
print("=" * 60 + "\n")
def mode_full(env, task_config, project_root="."):
"""
模式4: 完整流程reset + record + collect
"""
print("\n" + "=" * 60)
print("🔄 模式: 完整采集流程")
print("=" * 60)
# Step 1: Reset
mode_reset(env, task_config, project_root)
# Step 2: Record
input("\n按Enter键开始录制...")
mode_record(env, task_config, project_root)
# Step 3: Collect
input("\n按Enter键收集结果...")
mode_collect(env, task_config, project_root)
print("\n✅ 完整采集流程完成!")
def main():
parser = argparse.ArgumentParser(
description="JADE Benchmark 任务数据采集工具",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
使用示例:
# 完整流程(推荐)
python scripts/collect_task.py smoothing_001 --mode full
# 分步执行
python scripts/collect_task.py smoothing_001 --mode reset # 1. 重置并注入文件
python scripts/collect_task.py smoothing_001 --mode record # 2. 录制操作
python scripts/collect_task.py smoothing_001 --mode collect # 3. 收集结果
"""
)
parser.add_argument("task_id", help="任务ID对应tasks/目录下的子目录名)")
parser.add_argument(
"--mode",
choices=["reset", "record", "collect", "full"],
default="full",
help="采集模式默认full"
)
parser.add_argument("--project-root", default=".", help="项目根目录")
parser.add_argument("--vmx", help="虚拟机.vmx文件路径覆盖默认配置")
parser.add_argument("--snapshot", help="快照名称(覆盖默认配置)")
parser.add_argument("--vm-ip", help="虚拟机IP地址覆盖默认配置")
args = parser.parse_args()
# 加载任务配置
task_config = load_task_config(args.task_id, args.project_root)
if not task_config:
sys.exit(1)
# 加载并合并配置
config = load_default_config()
if args.vmx:
config['vmx_path'] = args.vmx
if args.snapshot:
config['snapshot_name'] = args.snapshot
if args.vm_ip:
config['vm_ip'] = args.vm_ip
# 初始化环境
try:
logger.info("初始化JADE环境...")
env = JadeEnv(
vmx_path=config['vmx_path'],
snapshot_name=config['snapshot_name'],
vm_ip=config['vm_ip'],
vm_password=config.get('vm_password'),
guest_username=config.get('guest_username'),
guest_password=config.get('guest_password')
)
# 执行对应模式
if args.mode == "reset":
mode_reset(env, task_config, args.project_root)
elif args.mode == "record":
mode_record(env, task_config, args.project_root)
elif args.mode == "collect":
mode_collect(env, task_config, args.project_root)
elif args.mode == "full":
mode_full(env, task_config, args.project_root)
except KeyboardInterrupt:
print("\n\n⏹ 操作已取消")
sys.exit(1)
except Exception as e:
logger.error(f"❌ 错误: {e}")
import traceback
traceback.print_exc()
sys.exit(1)
if __name__ == "__main__":
main()