Initial commit
This commit is contained in:
265
scripts/tools/collect_task.py
Normal file
265
scripts/tools/collect_task.py
Normal file
@@ -0,0 +1,265 @@
|
||||
"""
|
||||
任务数据采集入口
|
||||
整合环境控制、轨迹录制、文件收集的完整流程
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
|
||||
# Add the parent of this file's directory to sys.path so the `core` and `utils` imports below resolve.
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from core.jade_env import JadeEnv
|
||||
from core.recorder import record_interactive
|
||||
from utils.config_loader import load_config, get_vm_config, get_network_config
|
||||
|
||||
# Configure logging once at import time: INFO level, "LEVEL: message" format.
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
|
||||
# Module-level logger used by the functions in this file.
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_default_config():
    """Build the VM/network settings dict used to construct the environment.

    The values are read through ``load_config`` / ``get_vm_config`` /
    ``get_network_config``. If anything in that path raises, a warning is
    logged and a hard-coded fallback is returned instead.

    Returns:
        dict with the keys ``vmx_path``, ``snapshot_name``, ``vm_ip``,
        ``vm_password``, ``guest_username`` and ``guest_password``.
    """
    try:
        loaded = load_config()
        vm_section = get_vm_config(loaded)
        net_section = get_network_config(loaded)

        # Built inside the try on purpose: a failing ``.get`` (e.g. a section
        # that is not a mapping) must also trigger the fallback below.
        settings = dict(
            vmx_path=vm_section.get('vmx_path'),
            snapshot_name=vm_section.get('snapshot_name', 'Jade_Ready'),
            vm_ip=net_section.get('vm_ip'),
            vm_password=vm_section.get('vm_password'),
            guest_username=vm_section.get('guest_username'),
            guest_password=vm_section.get('guest_password'),
        )
    except Exception as exc:
        logger.warning(f"⚠️ 无法加载config.json: {exc}")
        logger.info(" 使用硬编码配置")
        # NOTE(review): machine-specific path and plaintext credentials are
        # committed here; consider moving them out of source control.
        settings = dict(
            vmx_path="/Volumes/Castor/虚拟机/Jade_Win_11.vmwarevm/Windows 11 64 位 ARM 2.vmx",
            snapshot_name="Jade_Ready",
            vm_ip="192.168.116.129",
            vm_password="lizhanyuan",
            guest_username="lzy",
            guest_password="LIZHANYUAN",
        )
    return settings
|
||||
|
||||
|
||||
def load_task_config(task_id, project_root="."):
    """Load ``<project_root>/tasks/<task_id>/task.json``.

    Args:
        task_id: Name of the task sub-directory under ``tasks/``.
        project_root: Directory that contains the ``tasks/`` folder.

    Returns:
        The parsed JSON content, or ``None`` when the file is missing,
        cannot be read, or is not valid JSON. An error is logged in each
        of those cases so the caller can simply test the result.
    """
    task_json_path = os.path.join(project_root, "tasks", task_id, "task.json")

    if not os.path.exists(task_json_path):
        logger.error(f"❌ 任务配置文件不存在: {task_json_path}")
        logger.info(" 请先创建任务目录和task.json")
        return None

    try:
        with open(task_json_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        # Report the problem the same way as a missing file instead of
        # letting a raw traceback escape to the command line.
        logger.error(f"❌ 无法读取任务配置文件 {task_json_path}: {e}")
        return None
|
||||
|
||||
|
||||
def _guest_filename(inject_to, source_file):
    """Pick the file name passed to the guest for an injected input file."""
    import ntpath  # function-scope import keeps this block self-contained

    # ntpath applies Windows path rules on any host OS and accepts both
    # "\\" and "/" as separators.
    guest_name = ntpath.basename(inject_to) if inject_to else ''
    # An empty name means inject_to was unset or ended with a separator;
    # fall back to the host file's own name.
    return guest_name or os.path.basename(source_file)


def mode_reset(env, task_config, project_root="."):
    """Mode 1: reset the environment and inject the task's input file.

    Args:
        env: Environment object; ``reset()`` and
            ``inject_file(host_path, guest_filename)`` are called on it.
        task_config: Parsed task.json. ``id`` is required; the optional
            ``input`` section may provide ``source_file`` (host path,
            relative paths are resolved against the task directory) and
            ``inject_to`` (Windows path whose last component names the
            file in the guest).
        project_root: Directory that contains the ``tasks/`` folder.

    Raises:
        KeyError: if ``task_config`` has no ``id``.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("🔄 模式: 重置环境")
    print(banner)

    # 1. Reset the environment.
    env.reset()

    # 2. Inject the input file, if the task declares one.
    input_config = task_config.get('input') or {}
    source_file = input_config.get('source_file')

    if source_file:
        # Relative paths are relative to the task directory.
        if not os.path.isabs(source_file):
            task_dir = os.path.join(project_root, "tasks", task_config['id'])
            source_file = os.path.normpath(os.path.join(task_dir, source_file))

        # Always hand an absolute host path to the environment.
        source_file = os.path.abspath(source_file)

        if os.path.exists(source_file):
            guest_filename = _guest_filename(
                input_config.get('inject_to', ''), source_file
            )
            env.inject_file(source_file, guest_filename)
        else:
            logger.warning(f"⚠️ 输入文件不存在: {source_file}")

    print("\n✅ 环境准备完成!")
    print(banner)
    print("💡 下一步:开始录制操作")
    print(f" python scripts/tools/collect_task.py {task_config['id']} --mode record")
    print(banner + "\n")
|
||||
|
||||
|
||||
def mode_record(env, task_config, project_root="."):
    """Mode 2: record a human demonstration for the task.

    Args:
        env: Environment object, passed through to ``record_interactive``.
        task_config: Parsed task.json; ``id`` is required, ``instruction``
            is shown in the banner when present.
        project_root: Directory that contains the ``tasks/`` folder.

    The recording is written under ``tasks/<id>/human_demo``.
    """
    task_id = task_config['id']
    demo_dir = os.path.join(project_root, "tasks", task_id, "human_demo")
    rule = "=" * 60
    instruction = task_config.get('instruction', 'N/A')

    print(f"\n{rule}")
    print("🎥 模式: 录制轨迹")
    print(rule)
    print(f"任务: {instruction}")
    print(rule)

    # Hand control to the interactive recorder.
    record_interactive(env, task_id, demo_dir)

    print("\n💡 下一步:处理坐标转换")
    print(f" python scripts/tools/process_trajectory.py {task_id}")
    print(f"{rule}\n")
|
||||
|
||||
|
||||
def mode_collect(env, task_config, project_root="."):
    """Mode 3: collect the task's output file into ``ground_truth``.

    Args:
        env: Environment object; ``collect_file(guest_file, host_path)`` is
            called on it.
        task_config: Parsed task.json. ``id`` is required; the file to fetch
            is read from ``output.expected_file``.
        project_root: Directory that contains the ``tasks/`` folder.

    The file is stored at ``tasks/<id>/ground_truth/<expected_file>``.
    A warning is logged when the task declares no output, or when the file
    is not present on the host after collection.

    Raises:
        KeyError: if ``task_config`` has no ``id``.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("📦 模式: 收集结果文件")
    print(banner)

    task_id = task_config['id']

    if 'output' not in task_config:
        logger.warning("⚠️ 任务配置中未指定output")
    else:
        expected_file = task_config['output'].get('expected_file')
        if not expected_file:
            logger.warning("⚠️ 任务配置中未指定expected_file")
        else:
            gt_dir = os.path.join(project_root, "tasks", task_id, "ground_truth")
            os.makedirs(gt_dir, exist_ok=True)

            host_path = os.path.join(gt_dir, expected_file)
            # expected_file may contain sub-directories; make sure the
            # destination folder exists before copying into it.
            os.makedirs(os.path.dirname(host_path), exist_ok=True)

            env.collect_file(expected_file, host_path)

            # How collect_file signals failure is not visible from here, so
            # confirm the file actually arrived before reporting success.
            if os.path.exists(host_path):
                print(f"\n✅ 文件已保存到: {host_path}")
            else:
                logger.warning(f"⚠️ 文件收集失败,未找到: {host_path}")

    print("\n" + banner)
    print("💡 下一步:验证评测")
    print(f" python scripts/tools/run_eval.py {task_id}")
    print(banner + "\n")
|
||||
|
||||
|
||||
def mode_full(env, task_config, project_root="."):
    """Mode 4: run reset, record and collect back to back.

    The operator is asked to press Enter before the recording stage and
    again before the collection stage.

    Args:
        env: Environment object, passed to each stage.
        task_config: Parsed task.json, passed to each stage.
        project_root: Directory that contains the ``tasks/`` folder.
    """
    rule = "=" * 60
    print(f"\n{rule}")
    print("🔄 模式: 完整采集流程")
    print(rule)

    # (prompt shown before the stage, stage function); None means no prompt.
    pipeline = (
        (None, mode_reset),
        ("\n按Enter键开始录制...", mode_record),
        ("\n按Enter键收集结果...", mode_collect),
    )
    for prompt, stage in pipeline:
        if prompt is not None:
            input(prompt)
        stage(env, task_config, project_root)

    print("\n✅ 完整采集流程完成!")
|
||||
|
||||
|
||||
def main():
    """Command-line entry point.

    Parses the arguments, loads the task and VM configuration (command-line
    options override the loaded defaults), constructs the environment and
    runs the selected mode. Exits with status 1 when the task configuration
    cannot be loaded, when the user interrupts, or when any step raises.
    """
    parser = argparse.ArgumentParser(
        description="JADE Benchmark 任务数据采集工具",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
使用示例:
  # 完整流程(推荐)
  python scripts/tools/collect_task.py smoothing_001 --mode full

  # 分步执行
  python scripts/tools/collect_task.py smoothing_001 --mode reset    # 1. 重置并注入文件
  python scripts/tools/collect_task.py smoothing_001 --mode record   # 2. 录制操作
  python scripts/tools/collect_task.py smoothing_001 --mode collect  # 3. 收集结果
        """
    )

    parser.add_argument("task_id", help="任务ID(对应tasks/目录下的子目录名)")
    parser.add_argument(
        "--mode",
        choices=["reset", "record", "collect", "full"],
        default="full",
        help="采集模式(默认:full)"
    )
    parser.add_argument("--project-root", default=".", help="项目根目录")
    parser.add_argument("--vmx", help="虚拟机.vmx文件路径(覆盖默认配置)")
    parser.add_argument("--snapshot", help="快照名称(覆盖默认配置)")
    parser.add_argument("--vm-ip", help="虚拟机IP地址(覆盖默认配置)")

    args = parser.parse_args()

    # Load the task definition.
    task_config = load_task_config(args.task_id, args.project_root)
    if not task_config:
        sys.exit(1)

    # The mode functions build paths from task_config['id']; when task.json
    # omits it, the directory name given on the command line is the id.
    if isinstance(task_config, dict):
        task_config.setdefault('id', args.task_id)

    # Load defaults, then apply command-line overrides.
    config = load_default_config()
    if args.vmx:
        config['vmx_path'] = args.vmx
    if args.snapshot:
        config['snapshot_name'] = args.snapshot
    if args.vm_ip:
        config['vm_ip'] = args.vm_ip

    # argparse ``choices`` guarantees args.mode is one of these keys.
    handlers = {
        "reset": mode_reset,
        "record": mode_record,
        "collect": mode_collect,
        "full": mode_full,
    }

    try:
        logger.info("初始化JADE环境...")
        env = JadeEnv(
            vmx_path=config['vmx_path'],
            snapshot_name=config['snapshot_name'],
            vm_ip=config['vm_ip'],
            vm_password=config.get('vm_password'),
            guest_username=config.get('guest_username'),
            guest_password=config.get('guest_password')
        )

        # Run the selected mode.
        handlers[args.mode](env, task_config, args.project_root)

    except KeyboardInterrupt:
        print("\n\n⏹ 操作已取消")
        sys.exit(1)
    except Exception as e:
        # Top-level boundary: report the error with its traceback, then
        # exit with a failure status.
        logger.error(f"❌ 错误: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()
|
||||
|
||||
Reference in New Issue
Block a user