Initial commit
This commit is contained in:
213
scripts/tools/batch_create_tasks.py
Executable file
213
scripts/tools/batch_create_tasks.py
Executable file
@@ -0,0 +1,213 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
批量创建任务
|
||||
从JSON定义文件批量创建任务
|
||||
用法:
|
||||
python scripts/tools/batch_create_tasks.py tasks/batch_definitions/basic_processing_tasks.json
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# 添加项目根目录到路径
|
||||
project_root = Path(__file__).parent.parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from scripts.tools.init_task import init_task
|
||||
|
||||
|
||||
def batch_create_tasks(definition_file, project_root=".", force=False, skip_existing=True):
|
||||
"""
|
||||
批量创建任务
|
||||
|
||||
Args:
|
||||
definition_file: 任务定义JSON文件路径
|
||||
project_root: 项目根目录
|
||||
force: 是否强制覆盖已存在的任务
|
||||
skip_existing: 是否跳过已存在的任务(与force互斥)
|
||||
"""
|
||||
# 读取任务定义
|
||||
with open(definition_file, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
|
||||
tasks = data.get("tasks", [])
|
||||
category = data.get("category", "unknown")
|
||||
tutorial_source = data.get("tutorial_source", "")
|
||||
|
||||
print("=" * 60)
|
||||
print(f"📚 批量创建任务: {category}")
|
||||
if tutorial_source:
|
||||
print(f" 教程来源: {tutorial_source}")
|
||||
print(f" 任务数量: {len(tasks)}")
|
||||
print("=" * 60)
|
||||
|
||||
results = {
|
||||
"success": [],
|
||||
"skipped": [],
|
||||
"failed": []
|
||||
}
|
||||
|
||||
for i, task_def in enumerate(tasks, 1):
|
||||
task_id = task_def["id"]
|
||||
print(f"\n[{i}/{len(tasks)}] 处理任务: {task_id}")
|
||||
print("-" * 60)
|
||||
|
||||
# 检查是否已存在
|
||||
task_dir = os.path.join(project_root, "tasks", task_id)
|
||||
if os.path.exists(task_dir) and not force:
|
||||
if skip_existing:
|
||||
print(f"⏭️ 跳过(已存在): {task_id}")
|
||||
results["skipped"].append(task_id)
|
||||
continue
|
||||
else:
|
||||
print(f"⚠️ 任务已存在: {task_id}")
|
||||
print(" 使用 --force 强制覆盖,或设置 skip_existing=True")
|
||||
results["failed"].append((task_id, "已存在"))
|
||||
continue
|
||||
|
||||
try:
|
||||
# 调用 init_task 创建任务结构
|
||||
success = init_task(
|
||||
task_id=task_id,
|
||||
project_root=project_root,
|
||||
force=force,
|
||||
category=task_def.get("category", category),
|
||||
difficulty=task_def.get("difficulty", "easy"),
|
||||
instruction=task_def.get("instruction", "")
|
||||
)
|
||||
|
||||
if not success:
|
||||
results["failed"].append((task_id, "初始化失败"))
|
||||
continue
|
||||
|
||||
# 更新 task.json
|
||||
task_json_path = os.path.join(task_dir, "task.json")
|
||||
if os.path.exists(task_json_path):
|
||||
with open(task_json_path, 'r', encoding='utf-8') as f:
|
||||
task_config = json.load(f)
|
||||
|
||||
# 更新输入输出配置
|
||||
source_file = task_def.get("source_file", "DEMO01.MDI")
|
||||
if not os.path.isabs(source_file) and not source_file.startswith("../"):
|
||||
source_file = f"../../data/source/{source_file}"
|
||||
|
||||
filename = os.path.basename(source_file)
|
||||
inject_to = f"C:\\Users\\lzy\\Desktop\\{filename}"
|
||||
|
||||
output_filename = task_def.get("output_filename", "result.txt")
|
||||
collect_from = f"C:\\Users\\lzy\\Desktop\\{output_filename}"
|
||||
|
||||
task_config["input"] = {
|
||||
"source_file": source_file,
|
||||
"inject_to": inject_to
|
||||
}
|
||||
task_config["output"] = {
|
||||
"expected_file": output_filename,
|
||||
"collect_from": collect_from
|
||||
}
|
||||
|
||||
# 更新评测方法
|
||||
eval_method = task_def.get("evaluation_method", "xrd_data_compare")
|
||||
task_config["evaluation"]["method"] = eval_method
|
||||
|
||||
# 添加教程来源
|
||||
if task_def.get("tutorial_source"):
|
||||
task_config["tutorial_source"] = task_def["tutorial_source"]
|
||||
|
||||
# 添加备注
|
||||
if task_def.get("notes"):
|
||||
task_config["notes"] = task_def["notes"]
|
||||
|
||||
# 保存
|
||||
with open(task_json_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(task_config, f, ensure_ascii=False, indent=2)
|
||||
|
||||
print(f"✅ 任务创建成功: {task_id}")
|
||||
results["success"].append(task_id)
|
||||
else:
|
||||
print(f"❌ task.json 未创建: {task_json_path}")
|
||||
results["failed"].append((task_id, "配置文件未创建"))
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 创建任务失败: {task_id}")
|
||||
print(f" 错误: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
results["failed"].append((task_id, str(e)))
|
||||
|
||||
# 打印总结
|
||||
print("\n" + "=" * 60)
|
||||
print("📊 批量创建总结")
|
||||
print("=" * 60)
|
||||
print(f"✅ 成功: {len(results['success'])}")
|
||||
if results["success"]:
|
||||
for task_id in results["success"]:
|
||||
print(f" - {task_id}")
|
||||
|
||||
if results["skipped"]:
|
||||
print(f"\n⏭️ 跳过: {len(results['skipped'])}")
|
||||
for task_id in results["skipped"]:
|
||||
print(f" - {task_id}")
|
||||
|
||||
if results["failed"]:
|
||||
print(f"\n❌ 失败: {len(results['failed'])}")
|
||||
for task_id, reason in results["failed"]:
|
||||
print(f" - {task_id}: {reason}")
|
||||
|
||||
print("=" * 60)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def main():
    """Command-line entry point: parse arguments, run the batch, set the exit code."""
    import argparse

    usage_examples = """
使用示例:
# 批量创建任务(跳过已存在的)
python scripts/tools/batch_create_tasks.py tasks/batch_definitions/basic_processing_tasks.json

# 强制覆盖已存在的任务
python scripts/tools/batch_create_tasks.py tasks/batch_definitions/basic_processing_tasks.json --force

# 不跳过已存在的任务(遇到已存在就失败)
python scripts/tools/batch_create_tasks.py tasks/batch_definitions/basic_processing_tasks.json --no-skip-existing
"""

    cli = argparse.ArgumentParser(
        description="批量创建任务",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=usage_examples,
    )
    cli.add_argument("definition_file", help="任务定义JSON文件路径")
    cli.add_argument("--project-root", default=".", help="项目根目录")
    cli.add_argument("--force", action="store_true", help="强制覆盖已存在的任务")
    cli.add_argument("--no-skip-existing", action="store_true", help="不跳过已存在的任务(遇到就失败)")
    options = cli.parse_args()

    # Fail early when the definition file is missing.
    if not os.path.exists(options.definition_file):
        print(f"❌ 定义文件不存在: {options.definition_file}")
        sys.exit(1)

    outcome = batch_create_tasks(
        options.definition_file,
        options.project_root,
        force=options.force,
        skip_existing=not options.no_skip_existing,
    )

    # Non-zero exit status when at least one task failed.
    sys.exit(1 if outcome["failed"] else 0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
265
scripts/tools/collect_task.py
Normal file
265
scripts/tools/collect_task.py
Normal file
@@ -0,0 +1,265 @@
|
||||
"""
|
||||
任务数据采集入口
|
||||
整合环境控制、轨迹录制、文件收集的完整流程
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
|
||||
# 添加父目录到路径
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from core.jade_env import JadeEnv
|
||||
from core.recorder import record_interactive
|
||||
from utils.config_loader import load_config, get_vm_config, get_network_config
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_default_config():
    """
    Build the VM connection settings used by the collection tool.

    The settings are read through ``load_config`` / ``get_vm_config`` /
    ``get_network_config``. If anything in that path raises, a warning is
    logged and a hard-coded fallback dictionary is returned instead.

    Returns:
        dict with the keys ``vmx_path``, ``snapshot_name``, ``vm_ip``,
        ``vm_password``, ``guest_username`` and ``guest_password``.
    """
    try:
        loaded = load_config()
        vm_section = get_vm_config(loaded)
        network_section = get_network_config(loaded)

        return dict(
            vmx_path=vm_section.get('vmx_path'),
            snapshot_name=vm_section.get('snapshot_name', 'Jade_Ready'),
            vm_ip=network_section.get('vm_ip'),
            vm_password=vm_section.get('vm_password'),
            guest_username=vm_section.get('guest_username'),
            guest_password=vm_section.get('guest_password'),
        )
    except Exception as e:
        logger.warning(f"⚠️ 无法加载config.json: {e}")
        logger.info(" 使用硬编码配置")
        # NOTE(review): the fallback embeds machine-specific paths and
        # plaintext credentials in source; consider moving them to config.
        return dict(
            vmx_path="/Volumes/Castor/虚拟机/Jade_Win_11.vmwarevm/Windows 11 64 位 ARM 2.vmx",
            snapshot_name="Jade_Ready",
            vm_ip="192.168.116.129",
            vm_password="lizhanyuan",
            guest_username="lzy",
            guest_password="LIZHANYUAN",
        )
|
||||
|
||||
|
||||
def load_task_config(task_id, project_root="."):
    """
    Load ``tasks/<task_id>/task.json`` below ``project_root``.

    Returns:
        The parsed JSON content, or ``None`` (after logging an error) when
        the file does not exist.
    """
    config_file = os.path.join(project_root, "tasks", task_id, "task.json")

    if os.path.exists(config_file):
        with open(config_file, 'r', encoding='utf-8') as handle:
            return json.load(handle)

    logger.error(f"❌ 任务配置文件不存在: {config_file}")
    logger.info(" 请先创建任务目录和task.json")
    return None
|
||||
|
||||
|
||||
def mode_reset(env, task_config, project_root="."):
    """
    Mode 1: reset the environment and inject the task's input file.

    Calls ``env.reset()``, then, when the task configuration names an input
    ``source_file`` that exists on the host, calls
    ``env.inject_file(host_path, guest_filename)``.

    Args:
        env: Environment object providing ``reset()`` and ``inject_file()``.
        task_config: Parsed task.json; ``id`` is required, ``input`` optional.
        project_root: Project root used to resolve relative source paths
            (they are relative to ``tasks/<id>/``).
    """
    # ntpath splits on both "\" and "/", independent of the host OS, which
    # is what a Windows guest path needs.
    import ntpath

    print("\n" + "=" * 60)
    print("🔄 模式: 重置环境")
    print("=" * 60)

    # 1. Reset the virtual machine.
    env.reset()

    # 2. Inject the input file, if one is configured.
    if 'input' in task_config:
        input_config = task_config['input']
        source_file = input_config.get('source_file')

        if source_file:
            # Relative paths are relative to the task directory.
            if not os.path.isabs(source_file):
                task_dir = os.path.join(project_root, "tasks", task_config['id'])
                source_file = os.path.normpath(os.path.join(task_dir, source_file))
            source_file = os.path.abspath(source_file)

            if os.path.exists(source_file):
                # The guest file name is the last component of the Windows
                # target path; fall back to the host file name when the
                # target is missing or ends with a separator.
                inject_to = input_config.get('inject_to', '')
                guest_filename = ntpath.basename(inject_to) if inject_to else ''
                if not guest_filename:
                    guest_filename = os.path.basename(source_file)

                env.inject_file(source_file, guest_filename)
            else:
                logger.warning(f"⚠️ 输入文件不存在: {source_file}")

    print("\n✅ 环境准备完成!")
    print("=" * 60)
    print("💡 下一步:开始录制操作")
    print(f" python scripts/tools/collect_task.py {task_config['id']} --mode record")
    print("=" * 60 + "\n")
|
||||
|
||||
|
||||
def mode_record(env, task_config, project_root="."):
    """
    Mode 2: record a human operation trajectory.

    Prints the task instruction, then hands over to ``record_interactive``
    with ``tasks/<id>/human_demo`` as the output directory.
    """
    separator = "=" * 60
    task_id = task_config['id']
    demo_dir = os.path.join(project_root, "tasks", task_id, "human_demo")

    print("\n" + separator)
    print("🎥 模式: 录制轨迹")
    print(separator)
    print(f"任务: {task_config.get('instruction', 'N/A')}")
    print(separator)

    # Start the interactive recorder.
    record_interactive(env, task_id, demo_dir)

    print("\n💡 下一步:处理坐标转换")
    print(f" python scripts/tools/process_trajectory.py {task_id}")
    print(separator + "\n")
|
||||
|
||||
|
||||
def mode_collect(env, task_config, project_root="."):
    """
    Mode 3: collect the task's output file into ``ground_truth``.

    When the configuration has ``output.expected_file``, the directory
    ``tasks/<id>/ground_truth`` is created and
    ``env.collect_file(expected_file, host_path)`` is called; otherwise a
    warning is logged.
    """
    separator = "=" * 60
    print("\n" + separator)
    print("📦 模式: 收集结果文件")
    print(separator)

    task_id = task_config['id']

    if 'output' not in task_config:
        logger.warning("⚠️ 任务配置中未指定output")
    else:
        expected_file = task_config['output'].get('expected_file')
        if not expected_file:
            logger.warning("⚠️ 任务配置中未指定expected_file")
        else:
            # Destination on the host.
            ground_truth_dir = os.path.join(project_root, "tasks", task_id, "ground_truth")
            os.makedirs(ground_truth_dir, exist_ok=True)
            host_path = os.path.join(ground_truth_dir, expected_file)

            # Fetch the file from the guest.
            env.collect_file(expected_file, host_path)
            print(f"\n✅ 文件已保存到: {host_path}")

    print("\n" + separator)
    print("💡 下一步:验证评测")
    print(f" python scripts/tools/run_eval.py {task_id}")
    print(separator + "\n")
|
||||
|
||||
|
||||
def mode_full(env, task_config, project_root="."):
    """
    Mode 4: run reset, record and collect in sequence.

    The user is asked to press Enter before the recording step and again
    before the collection step.
    """
    separator = "=" * 60
    print("\n" + separator)
    print("🔄 模式: 完整采集流程")
    print(separator)

    # Step 1: reset (no confirmation needed).
    mode_reset(env, task_config, project_root)

    # Steps 2 and 3: each waits for the user before running.
    confirmed_steps = (
        ("\n按Enter键开始录制...", mode_record),
        ("\n按Enter键收集结果...", mode_collect),
    )
    for prompt, step in confirmed_steps:
        input(prompt)
        step(env, task_config, project_root)

    print("\n✅ 完整采集流程完成!")
|
||||
|
||||
|
||||
def main():
    """
    Command-line entry point of the task data collection tool.

    Parses the arguments, loads the task configuration, merges the VM
    settings with the command-line overrides, builds a ``JadeEnv`` and runs
    the selected mode. Exits with status 1 when the task configuration is
    missing, the user interrupts, or any step raises.
    """
    parser = argparse.ArgumentParser(
        description="JADE Benchmark 任务数据采集工具",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        # The script lives in scripts/tools/, matching the hints printed by
        # the individual modes.
        epilog="""
使用示例:
# 完整流程(推荐)
python scripts/tools/collect_task.py smoothing_001 --mode full

# 分步执行
python scripts/tools/collect_task.py smoothing_001 --mode reset # 1. 重置并注入文件
python scripts/tools/collect_task.py smoothing_001 --mode record # 2. 录制操作
python scripts/tools/collect_task.py smoothing_001 --mode collect # 3. 收集结果
"""
    )

    parser.add_argument("task_id", help="任务ID(对应tasks/目录下的子目录名)")
    parser.add_argument(
        "--mode",
        choices=["reset", "record", "collect", "full"],
        default="full",
        help="采集模式(默认:full)"
    )
    parser.add_argument("--project-root", default=".", help="项目根目录")
    parser.add_argument("--vmx", help="虚拟机.vmx文件路径(覆盖默认配置)")
    parser.add_argument("--snapshot", help="快照名称(覆盖默认配置)")
    parser.add_argument("--vm-ip", help="虚拟机IP地址(覆盖默认配置)")

    args = parser.parse_args()

    # Load the task configuration; load_task_config logs the reason itself.
    task_config = load_task_config(args.task_id, args.project_root)
    if not task_config:
        sys.exit(1)

    # Command-line options override the loaded VM settings.
    config = load_default_config()
    if args.vmx:
        config['vmx_path'] = args.vmx
    if args.snapshot:
        config['snapshot_name'] = args.snapshot
    if args.vm_ip:
        config['vm_ip'] = args.vm_ip

    try:
        logger.info("初始化JADE环境...")
        env = JadeEnv(
            vmx_path=config['vmx_path'],
            snapshot_name=config['snapshot_name'],
            vm_ip=config['vm_ip'],
            vm_password=config.get('vm_password'),
            guest_username=config.get('guest_username'),
            guest_password=config.get('guest_password')
        )

        # argparse restricts --mode to exactly these keys.
        handlers = {
            "reset": mode_reset,
            "record": mode_record,
            "collect": mode_collect,
            "full": mode_full,
        }
        handlers[args.mode](env, task_config, args.project_root)

    except KeyboardInterrupt:
        print("\n\n⏹ 操作已取消")
        sys.exit(1)
    except Exception as e:
        # Top-level boundary: report, show the traceback, exit non-zero.
        logger.error(f"❌ 错误: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
223
scripts/tools/detect_vm_ip.py
Executable file
223
scripts/tools/detect_vm_ip.py
Executable file
@@ -0,0 +1,223 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
自动检测VM的IP地址
|
||||
使用vmrun getGuestIPAddress命令获取VM的当前IP
|
||||
"""
|
||||
import subprocess
|
||||
import sys
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
try:
|
||||
import requests
|
||||
except ImportError:
|
||||
requests = None
|
||||
|
||||
# 添加项目根目录到路径
|
||||
project_root = Path(__file__).parent.parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from scripts.utils.config_loader import load_config
|
||||
|
||||
|
||||
def get_vm_ip(vmx_path, vm_password=None):
    """
    Query the guest IP address of a VM through ``vmrun getGuestIPAddress``.

    Args:
        vmx_path: Path of the virtual machine's .vmx file.
        vm_password: Password passed to vmrun via ``-vp`` (optional).

    Returns:
        str: The IP address printed by vmrun, or ``None`` when the command
        fails, times out (10 s), raises, or prints nothing.
    """
    vmrun = "/Applications/VMware Fusion.app/Contents/Library/vmrun"

    # vmrun -T fusion [-vp <password>] getGuestIPAddress <vmx>
    password_args = ["-vp", vm_password] if vm_password else []
    cmd = [vmrun, "-T", "fusion", *password_args, "getGuestIPAddress", vmx_path]

    try:
        completed = subprocess.run(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            timeout=10,
        )

        if completed.returncode != 0:
            error_msg = completed.stderr or completed.stdout
            print(f"❌ 获取IP失败: {error_msg}")
            return None

        address = completed.stdout.strip()
        if not address:
            print("⚠️ vmrun返回空IP地址")
            return None
        return address

    except subprocess.TimeoutExpired:
        print("❌ 获取IP超时")
        return None
    except Exception as e:
        print(f"❌ 获取IP异常: {e}")
        return None
|
||||
|
||||
|
||||
def update_config_ip(new_ip, config_path="config.json"):
    """
    Store a new VM IP address in the project's config file.

    Args:
        new_ip: New IP address.
        config_path: Config file path, joined onto the module-level
            ``project_root``.

    Returns:
        bool: True when the file already holds ``new_ip`` or was updated,
        False when the file is missing or reading/writing fails.
    """
    config_path = os.path.join(project_root, config_path)

    if not os.path.exists(config_path):
        print(f"❌ 配置文件不存在: {config_path}")
        return False

    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)

        old_ip = config.get("network", {}).get("vm_ip", "未知")

        if old_ip == new_ip:
            print(f"✅ IP地址未变化: {new_ip}")
            return True

        # Create the "network" section when it is absent, so the write is
        # as tolerant as the read above.
        config.setdefault("network", {})["vm_ip"] = new_ip

        with open(config_path, 'w', encoding='utf-8') as f:
            json.dump(config, f, indent=2, ensure_ascii=False)

        print(f"✅ 已更新IP地址: {old_ip} → {new_ip}")
        return True

    except Exception as e:
        print(f"❌ 更新配置文件失败: {e}")
        return False
|
||||
|
||||
|
||||
def test_http_connection(ip, port=5000, timeout=5):
    """
    Check whether the HTTP service inside the VM answers.

    Sends ``GET http://<ip>:<port>/screen_info`` with proxies disabled.

    Args:
        ip: IP address of the VM.
        port: Port number.
        timeout: Request timeout in seconds.

    Returns:
        bool: True only for an HTTP 200 response; False for any other
        status, any error, or when ``requests`` is not installed.
    """
    try:
        import requests
    except ImportError:
        print("⚠️ requests模块未安装,跳过HTTP连接测试")
        return False

    url = f"http://{ip}:{port}/screen_info"
    no_proxy = {'http': None, 'https': None}  # bypass any configured proxy

    try:
        response = requests.get(url, timeout=timeout, proxies=no_proxy)
    except requests.exceptions.Timeout:
        print(f"⚠️ HTTP服务连接超时: {url}")
        return False
    except requests.exceptions.ConnectionError as e:
        print(f"⚠️ HTTP服务连接失败: {url}")
        print(f" 错误: {str(e)[:100]}")
        return False
    except Exception as e:
        print(f"⚠️ HTTP连接异常: {e}")
        return False

    if response.status_code != 200:
        print(f"⚠️ HTTP服务响应异常: 状态码 {response.status_code}")
        return False

    print(f"✅ HTTP服务连接成功: {url}")
    return True
|
||||
|
||||
|
||||
def main():
    """
    Detect the VM's current IP, probe its HTTP service and offer to update
    the config file when the IP has changed.

    Exits with status 1 when the configuration cannot be loaded or no IP
    address can be obtained.
    """
    separator = "=" * 60
    print(separator)
    print("🔍 检测VM IP地址")
    print(separator)

    # Load the configuration.
    try:
        config = load_config()
        vmx_path = config["vmware"]["vmx_path"]
        vm_password = config["vmware"].get("vm_password")
        current_ip = config["network"].get("vm_ip", "未知")
    except Exception as e:
        print(f"❌ 加载配置失败: {e}")
        sys.exit(1)

    print("\n📋 当前配置:")
    print(f" VM路径: {os.path.basename(vmx_path)}")
    print(f" 当前IP: {current_ip}")
    print(f" 端口: {config['network'].get('agent_server_port', 5000)}")

    # Ask vmrun for the guest IP.
    print("\n🔍 正在获取VM IP地址...")
    vm_ip = get_vm_ip(vmx_path, vm_password)

    if not vm_ip:
        for line in (
            "\n❌ 无法获取VM IP地址",
            " 可能原因:",
            " 1. VM未运行",
            " 2. VM网络未配置",
            " 3. vmrun命令执行失败",
        ):
            print(line)
        sys.exit(1)

    print(f"✅ 检测到VM IP: {vm_ip}")

    # Probe the HTTP service on the detected address.
    port = config["network"].get("agent_server_port", 5000)
    print(f"\n🔗 测试HTTP连接 (端口 {port})...")
    http_ok = test_http_connection(vm_ip, port)

    if vm_ip == current_ip:
        print("\n✅ IP地址未变化")
        if not http_ok:
            print("\n⚠️ 但HTTP服务不可用,请检查VM中的agent_server.py")
    else:
        print(f"\n⚠️ IP地址已变化: {current_ip} → {vm_ip}")

        if not http_ok:
            print("\n⚠️ HTTP服务不可用,请检查:")
            print(" 1. VM中是否运行了 agent_server.py?")
            print(f" 2. 端口 {port} 是否被占用?")
            print(" 3. 防火墙是否阻止了连接?")
        else:
            # Only offer to persist the new IP when the service is reachable.
            print("\n❓ 是否更新配置文件? (y/n): ", end="")
            try:
                choice = input().strip().lower()
                if choice != 'y':
                    print("\n⏭️ 跳过更新")
                elif update_config_ip(vm_ip):
                    print("\n✅ 配置已更新!")
                else:
                    print("\n❌ 配置更新失败")
            except KeyboardInterrupt:
                print("\n\n⚠️ 用户取消")

    print(separator)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
271
scripts/tools/extract_task_from_tutorial.py
Executable file
271
scripts/tools/extract_task_from_tutorial.py
Executable file
@@ -0,0 +1,271 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
从教程信息快速生成任务定义
|
||||
用法:
|
||||
python scripts/tools/extract_task_from_tutorial.py
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# 添加项目根目录到路径
|
||||
project_root = Path(__file__).parent.parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from scripts.tools.init_task import init_task
|
||||
|
||||
|
||||
# 任务类别和难度映射
|
||||
# Menu number (as typed by the user) -> task category name.
CATEGORY_MAP = {
    str(number): name
    for number, name in enumerate(
        (
            "basic_processing",
            "peak_analysis",
            "phase_identification",
            "crystal_parameters",
            "calibration",
            "advanced_analysis",
        ),
        1,
    )
}

# Menu number (as typed by the user) -> difficulty level.
DIFFICULTY_MAP = {
    str(number): level
    for number, level in enumerate(("easy", "medium", "hard"), 1)
}

# Instruction templates per category. Each template is filled with
# str.format(filename=..., output=...); categories without an entry here
# have no templates.
TASK_TEMPLATES = {
    "basic_processing": {
        "open_file": "请打开桌面上的 {filename} 文件。",
        "smooth": (
            "请打开桌面上的 {filename} 文件,进行平滑处理 (Smoothing),"
            "然后将处理后的曲线导出为 ASCII (.txt) 文件并命名为 {output}。"
        ),
        "background": (
            "请打开桌面上的 {filename} 文件,进行背景扣除 (Background Removal),"
            "然后将处理后的曲线导出为 ASCII (.txt) 文件并命名为 {output}。"
        ),
        "export": (
            "请打开桌面上的 {filename} 文件,"
            "将当前曲线导出为 ASCII (.txt) 文件并命名为 {output}。"
        ),
    },
    "peak_analysis": {
        "peak_search": (
            "请打开桌面上的 {filename} 文件,进行寻峰操作 (Peak Search),"
            "并导出寻峰结果文件 {output}。"
        ),
        "peak_separation": (
            "请打开桌面上的 {filename} 文件,进行多峰分离操作 (Peak Separation),"
            "并导出结果文件 {output}。"
        ),
        "peak_fitting": (
            "请打开桌面上的 {filename} 文件,进行峰形拟合 (Peak Fitting),"
            "并导出结果文件 {output}。"
        ),
    },
    "phase_identification": {
        "phase_search": (
            "请打开桌面上的 {filename} 文件,进行物相检索 (Phase Search),"
            "并导出检索结果文件 {output}。"
        ),
        "quantitative": (
            "请打开桌面上的 {filename} 文件,进行物相定量分析 (Quantitative Analysis),"
            "并导出结果文件 {output}。"
        ),
    },
    "crystal_parameters": {
        "lattice_constant": (
            "请打开桌面上的 {filename} 文件,精确测定晶格常数 (Lattice Constant),"
            "并导出结果文件 {output}。"
        ),
        "crystal_size": (
            "请打开桌面上的 {filename} 文件,使用Scherrer公式计算晶粒大小 (Crystal Size),"
            "并导出结果文件 {output}。"
        ),
        "stress": (
            "请打开桌面上的 {filename} 文件,进行残余应力分析 (Stress Analysis),"
            "并导出结果文件 {output}。"
        ),
        "crystallinity": (
            "请打开桌面上的 {filename} 文件,计算结晶化度 (Crystallinity),"
            "并导出结果文件 {output}。"
        ),
    },
}
|
||||
|
||||
|
||||
def print_category_menu():
    """Print the numbered list of task categories."""
    menu_lines = (
        "\n📚 任务类别:",
        " 1. basic_processing (基础处理)",
        " 2. peak_analysis (峰分析)",
        " 3. phase_identification (物相检索)",
        " 4. crystal_parameters (晶体参数)",
        " 5. calibration (校正)",
        " 6. advanced_analysis (高级分析)",
    )
    for line in menu_lines:
        print(line)
|
||||
|
||||
|
||||
def print_difficulty_menu():
    """Print the numbered list of difficulty levels."""
    menu_lines = (
        "\n📊 难度等级:",
        " 1. easy (简单,3-5步操作)",
        " 2. medium (中等,5-10步操作)",
        " 3. hard (困难,10+步操作)",
    )
    for line in menu_lines:
        print(line)
|
||||
|
||||
|
||||
def get_user_input():
    """
    Interactively collect the information needed to define a task.

    Prompts for the task id, category, difficulty, source file, output file
    name, an optional instruction template and the tutorial source.

    Returns:
        dict with the keys ``task_id``, ``category``, ``difficulty``,
        ``instruction``, ``source_file``, ``output_filename`` and
        ``tutorial_source``; or ``None`` when the task id or instruction is
        empty or the category/difficulty choice is invalid.
    """
    print("=" * 60)
    print("🎯 从教程提取任务 - 快速生成工具")
    print("=" * 60)

    # Task id
    task_id = input("\n📝 任务ID (例如: peak_search_001): ").strip()
    if not task_id:
        print("❌ 任务ID不能为空")
        return None

    # Category
    print_category_menu()
    category_choice = input("\n选择类别 (1-6): ").strip()
    category = CATEGORY_MAP.get(category_choice)
    if not category:
        print("❌ 无效的类别选择")
        return None

    # Difficulty
    print_difficulty_menu()
    difficulty_choice = input("\n选择难度 (1-3): ").strip()
    difficulty = DIFFICULTY_MAP.get(difficulty_choice)
    if not difficulty:
        print("❌ 无效的难度选择")
        return None

    # Input file (defaults to DEMO01.MDI)
    print("\n📁 输入文件配置:")
    source_file = input(" 源文件路径 (相对于data/source/, 例如: DEMO01.MDI): ").strip()
    if not source_file:
        source_file = "DEMO01.MDI"

    # Output file (defaults to result.txt)
    print("\n📤 输出文件配置:")
    output_filename = input(" 输出文件名 (例如: result.txt): ").strip()
    if not output_filename:
        output_filename = "result.txt"

    # Optional instruction template, offered when the category has any.
    task_type = None
    templates = TASK_TEMPLATES.get(category)
    if templates:
        template_keys = list(templates)
        print(f"\n📋 可用任务模板 ({category}):")
        for i, key in enumerate(template_keys, 1):
            print(f" {i}. {key}")

        use_template = input("\n使用模板? (y/n, 默认n): ").strip().lower()
        if use_template == 'y':
            template_choice = input(f"选择模板 (1-{len(template_keys)}): ").strip()
            try:
                index = int(template_choice)
            except ValueError:
                index = 0
            # Explicit range check: "0" or a negative number would otherwise
            # index from the end of the list and pick the wrong template.
            if 1 <= index <= len(template_keys):
                task_type = template_keys[index - 1]
            else:
                print("⚠️ 无效的模板选择,将使用自定义指令")

    # Instruction
    if task_type:
        instruction = templates[task_type].format(
            filename=os.path.basename(source_file),
            output=output_filename
        )
        print(f"\n✅ 生成的指令 (模板): {instruction}")
        confirm = input("使用此指令? (y/n, 默认y): ").strip().lower()
        if confirm == 'n':
            instruction = input("\n📝 自定义指令: ").strip()
    else:
        instruction = input("\n📝 任务指令 (中文描述): ").strip()

    if not instruction:
        print("❌ 指令不能为空")
        return None

    # Tutorial source (optional)
    tutorial_source = input("\n📚 教程来源 (可选,例如: 教程(1)): ").strip()

    return {
        "task_id": task_id,
        "category": category,
        "difficulty": difficulty,
        "instruction": instruction,
        "source_file": source_file,
        "output_filename": output_filename,
        "tutorial_source": tutorial_source,
    }
|
||||
|
||||
|
||||
def create_task_from_info(info):
    """
    Create a task from the dictionary produced by ``get_user_input``.

    Calls ``init_task`` to create the task skeleton below the module-level
    ``project_root`` and then writes the input/output settings (and the
    tutorial source, when given) into the generated ``task.json``.

    Args:
        info: dict with ``task_id``, ``category``, ``difficulty``,
            ``instruction``, ``source_file``, ``output_filename`` and
            optionally ``tutorial_source``.

    Returns:
        bool: True when the task was created and configured, False when
        initialisation failed, task.json is missing, or an error occurred.
    """
    task_id = info["task_id"]
    category = info["category"]
    difficulty = info["difficulty"]
    instruction = info["instruction"]

    # Bare file names are taken to live in data/source/ (addressed relative
    # to the task directory). Absolute paths and paths that are already
    # "../"-relative are kept, the same rule batch_create_tasks applies.
    source_file = info["source_file"]
    if not os.path.isabs(source_file) and not source_file.startswith("../"):
        source_file = f"../../data/source/{source_file}"

    # Paths inside the VM.
    filename = os.path.basename(source_file)
    inject_to = f"C:\\Users\\lzy\\Desktop\\{filename}"

    output_filename = info["output_filename"]
    collect_from = f"C:\\Users\\lzy\\Desktop\\{output_filename}"

    print(f"\n🚀 正在创建任务: {task_id}")
    print(f" 类别: {category}")
    print(f" 难度: {difficulty}")
    print(f" 源文件: {source_file}")
    print(f" 输出文件: {output_filename}")

    try:
        created = init_task(
            task_id=task_id,
            category=category,
            difficulty=difficulty,
            instruction=instruction,
            project_root=str(project_root)
        )
        # init_task returns False when the task directory already exists;
        # stop here so an existing task.json is not silently rewritten.
        if not created:
            print(f"❌ 任务初始化失败(任务可能已存在): {task_id}")
            return False

        task_json_path = project_root / "tasks" / task_id / "task.json"
        if not task_json_path.exists():
            print(f"❌ 任务目录创建失败: {task_json_path}")
            return False

        with open(task_json_path, 'r', encoding='utf-8') as f:
            task_config = json.load(f)

        # Input/output settings.
        task_config["input"] = {
            "source_file": source_file,
            "inject_to": inject_to
        }
        task_config["output"] = {
            "expected_file": output_filename,
            "collect_from": collect_from
        }

        # Tutorial source, when given.
        if info.get("tutorial_source"):
            task_config["tutorial_source"] = info["tutorial_source"]

        with open(task_json_path, 'w', encoding='utf-8') as f:
            json.dump(task_config, f, ensure_ascii=False, indent=2)

        print("\n✅ 任务创建成功!")
        print(f" 任务目录: tasks/{task_id}/")
        print(f" 配置文件: tasks/{task_id}/task.json")
        print("\n📝 下一步:")
        print(" 1. 检查并完善 task.json")
        print(f" 2. 运行: python scripts/tools/collect_task.py {task_id} --mode full")

        return True

    except Exception as e:
        print(f"❌ 创建任务时出错: {e}")
        import traceback
        traceback.print_exc()
        return False
|
||||
|
||||
|
||||
def main():
    """Run the interactive prompt and create the task it describes."""
    try:
        collected = get_user_input()
        if not collected:
            print("\n❌ 任务创建取消")
        else:
            create_task_from_info(collected)
    except KeyboardInterrupt:
        print("\n\n⚠️ 用户取消操作")
    except Exception as e:
        print(f"\n❌ 发生错误: {e}")
        import traceback
        traceback.print_exc()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
178
scripts/tools/init_task.py
Normal file
178
scripts/tools/init_task.py
Normal file
@@ -0,0 +1,178 @@
|
||||
"""
|
||||
任务初始化工具
|
||||
快速创建新任务的目录结构和配置文件模板
|
||||
"""
|
||||
import os
|
||||
import json
|
||||
import sys
|
||||
import argparse
|
||||
|
||||
|
||||
# Default content of a new task.json; init_task fills in "id" and any
# category/difficulty/instruction it is given.
TASK_JSON_TEMPLATE = dict(
    id="",
    # basic_processing, peak_analysis, phase_identification, compound_tasks
    category="basic_processing",
    # easy, medium, hard
    difficulty="easy",
    instruction="请填写任务指令",
    # Host file to inject and its destination inside the VM.
    input=dict(
        source_file="../../data/source/DEMO01.MDI",
        inject_to="C:\\Users\\lzy\\Desktop\\DEMO01.MDI",
    ),
    # File expected from the task and where it is found inside the VM.
    output=dict(
        expected_file="result.txt",
        collect_from="C:\\Users\\lzy\\Desktop\\result.txt",
    ),
    # Evaluation settings.
    evaluation=dict(
        method="xrd_data_compare",
        ground_truth="ground_truth/result.txt",
        target_output="agent_output/result.txt",
        tolerance=1e-4,
    ),
)
|
||||
|
||||
|
||||
def init_task(task_id, project_root=".", force=False, category=None, difficulty=None, instruction=None):
    """
    Create the directory structure and template files of a new task.

    Creates ``tasks/<task_id>/`` with the ``ground_truth``, ``human_demo``,
    ``human_demo/screens`` and ``agent_output`` sub-directories, writes
    ``task.json`` from ``TASK_JSON_TEMPLATE`` and writes a ``README.md``.

    Args:
        task_id: Task id (also the directory name).
        project_root: Project root directory.
        force: Proceed even when the task directory already exists; task.json
            and README.md are rewritten.
        category: Task category (optional, overrides the template value).
        difficulty: Task difficulty (optional, overrides the template value).
        instruction: Task instruction (optional, overrides the template value).

    Returns:
        bool: True on success, False when the task directory already exists
        and ``force`` is false.
    """
    # Function-scope import: deepcopy keeps the nested template dicts from
    # being shared with the returned/written configuration.
    import copy

    task_dir = os.path.join(project_root, "tasks", task_id)

    # Refuse to touch an existing task unless forced.
    if os.path.exists(task_dir) and not force:
        print(f"❌ 任务目录已存在: {task_dir}")
        print(" 使用 --force 参数强制覆盖")
        return False

    print(f"创建任务: {task_id}")
    print("=" * 60)

    # Directory structure
    directories = [
        task_dir,
        os.path.join(task_dir, "ground_truth"),
        os.path.join(task_dir, "human_demo"),
        os.path.join(task_dir, "human_demo", "screens"),
        os.path.join(task_dir, "agent_output")
    ]

    for directory in directories:
        os.makedirs(directory, exist_ok=True)
        print(f"✅ 创建目录: {os.path.relpath(directory, project_root)}")

    # task.json
    task_config = copy.deepcopy(TASK_JSON_TEMPLATE)
    task_config["id"] = task_id

    if category:
        task_config["category"] = category
    if difficulty:
        task_config["difficulty"] = difficulty
    if instruction:
        task_config["instruction"] = instruction

    task_json_path = os.path.join(task_dir, "task.json")
    with open(task_json_path, 'w', encoding='utf-8') as f:
        json.dump(task_config, f, indent=2, ensure_ascii=False)

    print(f"✅ 创建配置: {os.path.relpath(task_json_path, project_root)}")

    # README. The commands point at scripts/tools/, where the tools live and
    # which is the path the tools themselves print in their hints.
    readme_content = f"""# 任务: {task_id}

## 任务信息
- **ID**: {task_id}
- **类别**: {task_config['category']}
- **难度**: {task_config['difficulty']}

## 指令
{task_config['instruction']}

## 数据采集状态
- [ ] 环境重置与文件注入
- [ ] 操作轨迹录制
- [ ] 结果文件收集
- [ ] 坐标转换处理
- [ ] 评测验证

## 采集命令
```bash
# 完整流程
python scripts/tools/collect_task.py {task_id} --mode full

# 分步执行
python scripts/tools/collect_task.py {task_id} --mode reset
python scripts/tools/collect_task.py {task_id} --mode record
python scripts/tools/collect_task.py {task_id} --mode collect
python scripts/tools/process_trajectory.py {task_id}
python scripts/tools/run_eval.py {task_id}
```

## 文件结构
```
{task_id}/
├── task.json # 任务配置
├── ground_truth/ # 标准答案输出
├── human_demo/ # 人类操作轨迹
│ ├── actions_raw.json # 原始轨迹(未转换坐标)
│ ├── actions.json # 处理后轨迹(已转换坐标)
│ └── screens/ # 截图序列
└── agent_output/ # Agent输出(评测时使用)
```
"""

    readme_path = os.path.join(task_dir, "README.md")
    with open(readme_path, 'w', encoding='utf-8') as f:
        f.write(readme_content)

    print(f"✅ 创建说明: {os.path.relpath(readme_path, project_root)}")

    print("=" * 60)
    print("✅ 任务初始化完成!")
    print("\n📝 下一步:")
    print(f" 1. 编辑任务配置: {task_json_path}")
    print(f" 2. 确保输入文件存在:例如 {task_config['input']['source_file']}")
    print(f" 3. 开始数据采集: python scripts/tools/collect_task.py {task_id}")
    print("=" * 60 + "\n")

    return True
|
||||
|
||||
|
||||
def main():
    """Command-line entry point: parse the options and initialise one task.

    Exits with status 0 when ``init_task`` reports success, 1 otherwise.
    """
    usage_examples = """
使用示例:
# 创建新任务
python scripts/init_task.py smoothing_001

# 强制覆盖已存在的任务
python scripts/init_task.py smoothing_001 --force
"""

    cli = argparse.ArgumentParser(
        description="初始化新任务",
        # Raw formatter keeps the line breaks of the usage examples intact.
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=usage_examples,
    )
    cli.add_argument("task_id", help="任务ID(建议格式: category_序号)")
    cli.add_argument("--project-root", default=".", help="项目根目录")
    cli.add_argument("--force", action="store_true", help="强制覆盖已存在的任务")

    opts = cli.parse_args()

    created = init_task(opts.task_id, opts.project_root, opts.force)
    exit_code = 0 if created else 1
    sys.exit(exit_code)
|
||||
|
||||
|
||||
# Run the command-line interface only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()
|
||||
|
||||
165
scripts/tools/process_trajectory.py
Normal file
165
scripts/tools/process_trajectory.py
Normal file
@@ -0,0 +1,165 @@
|
||||
"""
Trajectory post-processing.

Converts the recorded raw host coordinates into in-VM coordinates.
"""
import json
import os
import sys
import argparse
import logging

# Emit log records of level INFO and above for this script.
logging.basicConfig(level=logging.INFO)
# Module-level logger shared by the functions below.
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def process_trajectory(task_id, project_root=".", force=False,
                       scale_x_adjust=1.0, scale_y_adjust=1.0,
                       offset_x=0, offset_y=0):
    """Convert the host coordinates of a recorded trajectory to VM coordinates.

    Reads ``tasks/<task_id>/human_demo/actions_raw.json`` under
    *project_root*, adds a ``pos_vm`` entry to every action that carries a
    non-empty ``pos_host``, stores the conversion parameters under
    ``metadata['processed']`` and writes the result to ``actions.json`` in
    the same directory. A short preview of converted clicks is printed.

    Args:
        task_id: Task identifier (directory name below ``tasks/``).
        project_root: Project root directory.
        force: Overwrite an already existing ``actions.json``.
        scale_x_adjust: Extra multiplier applied to the X scale factor.
        scale_y_adjust: Extra multiplier applied to the Y scale factor.
        offset_x: Offset added to the scaled X coordinate.
        offset_y: Offset added to the scaled Y coordinate.

    Returns:
        bool: True when the processed file was written; False when the raw
        file is missing, or the processed file exists and *force* is false.
    """
    # Function-scope import so the block does not depend on a file-level one.
    from datetime import datetime

    # Paths
    task_dir = os.path.join(project_root, "tasks", task_id)
    human_demo_dir = os.path.join(task_dir, "human_demo")
    raw_path = os.path.join(human_demo_dir, "actions_raw.json")
    processed_path = os.path.join(human_demo_dir, "actions.json")

    # Check files
    if not os.path.exists(raw_path):
        logger.error(f"❌ 原始轨迹文件不存在: {raw_path}")
        logger.info(" 请先运行: python scripts/collect_task.py <task_id> --mode record")
        return False

    if os.path.exists(processed_path) and not force:
        logger.warning(f"⚠️ 处理后的文件已存在: {processed_path}")
        logger.info(" 使用 --force 参数强制覆盖")
        return False

    # Load the raw recording
    logger.info(f"读取原始轨迹: {raw_path}")
    with open(raw_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    metadata = data['metadata']
    actions = data['actions']

    # Fall back to the requested id: a missing metadata field must not abort
    # the run just because of a log line.
    logger.info(f"任务ID: {metadata.get('task_id', task_id)}")
    logger.info(f"动作数: {len(actions)}")

    # Resolution information
    if 'vm_resolution' in metadata and 'vm_screenshot_resolution' in metadata:
        vm_w, vm_h = metadata['vm_resolution']
        screenshot_w, screenshot_h = metadata['vm_screenshot_resolution']

        # Host-side click coordinates correspond to screenshot coordinates
        # and have to be mapped onto the actual VM resolution.
        scale_x = (vm_w / screenshot_w) * scale_x_adjust
        scale_y = (vm_h / screenshot_h) * scale_y_adjust

        logger.info(f"VM分辨率: {vm_w}x{vm_h}")
        logger.info(f"截图分辨率: {screenshot_w}x{screenshot_h}")
        logger.info(f"转换比例: X={scale_x:.3f}, Y={scale_y:.3f}")

        if scale_x_adjust != 1.0 or scale_y_adjust != 1.0:
            logger.info(f"应用调整系数: X={scale_x_adjust}, Y={scale_y_adjust}")
        if offset_x != 0 or offset_y != 0:
            logger.info(f"应用偏移调整: X={offset_x}, Y={offset_y}")
    else:
        logger.warning("⚠️ 元数据缺少分辨率信息,使用默认比例1.0")
        scale_x = 1.0 * scale_x_adjust
        scale_y = 1.0 * scale_y_adjust

    # Convert coordinates; actions without a usable host position are left
    # untouched (they get no 'pos_vm' key).
    converted_count = 0
    for action in actions:
        host_pos = action.get('pos_host')
        if not host_pos:
            continue
        host_x, host_y = host_pos
        action['pos_vm'] = [
            int(host_x * scale_x + offset_x),
            int(host_y * scale_y + offset_y),
        ]
        converted_count += 1

    logger.info(f"✅ 坐标转换完成: {converted_count}/{len(actions)} 个动作")

    # Record how the data was processed
    metadata['processed'] = {
        "processed_at": datetime.now().isoformat(),
        "scale_x": scale_x,
        "scale_y": scale_y,
        "offset_x": offset_x,
        "offset_y": offset_y,
        "converted_actions": converted_count
    }

    # Save the processed data
    logger.info(f"保存处理后的轨迹: {processed_path}")
    with open(processed_path, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=2, ensure_ascii=False)

    logger.info("✅ 处理完成!")

    # Preview of the first five converted clicks
    print("\n" + "=" * 60)
    print("📊 坐标转换示例(前5个点击):")
    print("-" * 60)

    click_count = 0
    for action in actions:
        # Skip actions without a type or without a usable host position:
        # they were not converted, so indexing them would raise after the
        # output file has already been written.
        if action.get('type') != 'click' or not action.get('pos_host'):
            continue
        host_x, host_y = action['pos_host']
        vm_x, vm_y = action.get('pos_vm') or (0, 0)
        # Shown as integers
        print(f" Host({int(host_x):4d}, {int(host_y):4d}) → VM({int(vm_x):4d}, {int(vm_y):4d})")

        click_count += 1
        if click_count >= 5:
            break

    print("=" * 60)
    print("\n💡 下一步:可视化验证(可选)")
    print(f" python scripts/visualize_trajectory.py {task_id}")
    print("=" * 60 + "\n")

    return True
|
||||
|
||||
|
||||
def main():
    """Command-line entry point: parse the options and convert one trajectory.

    Exits with status 0 when ``process_trajectory`` reports success,
    1 otherwise.
    """
    cli = argparse.ArgumentParser(description="处理轨迹数据,转换坐标")
    cli.add_argument("task_id", help="任务ID")
    cli.add_argument("--project-root", default=".", help="项目根目录")
    cli.add_argument("--force", action="store_true", help="强制覆盖已有文件")
    # Scale options first, then offset options, to keep the help order.
    for axis in ("x", "y"):
        cli.add_argument(f"--scale-{axis}", type=float, default=1.0,
                         help=f"{axis.upper()}轴缩放调整系数")
    for axis in ("x", "y"):
        cli.add_argument(f"--offset-{axis}", type=int, default=0,
                         help=f"{axis.upper()}轴偏移调整")

    opts = cli.parse_args()

    ok = process_trajectory(
        opts.task_id,
        opts.project_root,
        opts.force,
        opts.scale_x,
        opts.scale_y,
        opts.offset_x,
        opts.offset_y,
    )

    exit_code = 0 if ok else 1
    sys.exit(exit_code)
|
||||
|
||||
|
||||
# Run the command-line interface only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()
|
||||
|
||||
261
scripts/tools/run_eval.py
Normal file
261
scripts/tools/run_eval.py
Normal file
@@ -0,0 +1,261 @@
|
||||
"""
Evaluation entry script.

Supports evaluating a single task or a batch of tasks.
"""
import os
import sys
import json
import argparse
import logging
from datetime import datetime

# Add the parent directory to the import path: two dirname() calls resolve
# to the directory above the one that contains this file.
# NOTE(review): presumably this is where the `core` package lives - confirm.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from core.evaluator import evaluate

# Log records are printed as "<LEVEL>: <message>" from INFO upwards.
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
# Module-level logger shared by the functions below.
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def evaluate_task(task_id, project_root=".", verbose=True):
    """Evaluate a single task.

    Loads ``tasks/<task_id>/task.json`` under *project_root*, resolves the
    ground-truth and agent-output paths from its ``evaluation`` section and
    passes them to ``evaluate`` with the mode that matches the configured
    method.

    Args:
        task_id: Task identifier (directory name below ``tasks/``).
        project_root: Project root directory.
        verbose: Print the task header and the result to stdout.

    Returns:
        tuple: ``(score, message, details)``. ``details`` is an empty dict
        when the evaluation could not be started (missing configuration or
        files) or when it raised an exception.
    """
    task_dir = os.path.join(project_root, "tasks", task_id)
    task_json_path = os.path.join(task_dir, "task.json")

    # Check the task configuration
    if not os.path.exists(task_json_path):
        logger.error(f"❌ 任务配置不存在: {task_json_path}")
        return 0, "任务配置不存在", {}

    # Load the task configuration
    with open(task_json_path, 'r', encoding='utf-8') as f:
        task_config = json.load(f)

    if verbose:
        print("\n" + "=" * 60)
        print(f"📝 评测任务: {task_id}")
        print("=" * 60)
        print(f"类别: {task_config.get('category', 'N/A')}")
        print(f"难度: {task_config.get('difficulty', 'N/A')}")
        print(f"指令: {task_config.get('instruction', 'N/A')}")
        print("=" * 60)

    # Evaluation settings
    eval_config = task_config.get('evaluation', {})
    method = eval_config.get('method', 'xrd_data_compare')
    tolerance = eval_config.get('tolerance', 1e-4)

    # Build the file paths
    gt_rel = eval_config.get('ground_truth', '')
    agent_rel = eval_config.get('target_output', '')
    gt_path = os.path.join(task_dir, gt_rel)
    agent_path = os.path.join(task_dir, agent_rel)

    # An empty entry makes the join resolve to the task directory itself,
    # which exists; reject it explicitly instead of evaluating a directory.
    if not gt_rel or not os.path.exists(gt_path):
        logger.error(f"❌ Ground truth文件不存在: {gt_path}")
        return 0, "Ground truth文件不存在", {}

    if not agent_rel or not os.path.exists(agent_path):
        logger.error(f"❌ Agent输出文件不存在: {agent_path}")
        return 0, "Agent输出文件不存在", {}

    # Configured method name -> mode argument of evaluate()
    modes = {
        'xrd_data_compare': "xrd_data",
        'peak_report_compare': "peak_report",
    }

    # Run the evaluation
    try:
        mode = modes.get(method)
        if mode is None:
            logger.warning(f"⚠️ 未知的评测方法: {method}")
            score, message = 0, f"未知的评测方法: {method}"
        else:
            score, message = evaluate(gt_path, agent_path, tolerance, mode=mode)

        details = {
            "task_id": task_id,
            "method": method,
            "ground_truth": gt_path,
            "agent_output": agent_path,
            "tolerance": tolerance,
            "timestamp": datetime.now().isoformat()
        }

        if verbose:
            print(f"\n📊 评测结果:")
            print(f" Score: {score}")
            print(f" {message}")
            print("=" * 60 + "\n")

        return score, message, details

    except Exception as e:
        # Boundary handler: a failing evaluator must not abort a batch run,
        # so the failure is logged with its traceback and scored as 0.
        logger.error(f"❌ 评测失败: {e}")
        import traceback
        traceback.print_exc()
        return 0, f"评测失败: {str(e)}", {}
|
||||
|
||||
|
||||
def evaluate_batch(task_ids, project_root=".", output_file=None):
    """Evaluate several tasks and print summary statistics.

    Each task is evaluated with ``evaluate_task`` in non-verbose mode. A task
    counts as passed when its score equals 1.

    Args:
        task_ids: Iterable of task identifiers.
        project_root: Project root directory.
        output_file: Optional path; when given, the statistics and the
            per-task results are written there as JSON.

    Returns:
        list: One dict per task with ``task_id``, ``score``, ``message`` and
        the detail fields reported by ``evaluate_task``.
    """
    # Materialise once so len() works for any iterable.
    task_ids = list(task_ids)
    total = len(task_ids)

    print("\n" + "=" * 60)
    print("📊 批量评测")
    print("=" * 60)
    print(f"任务数: {total}")
    print("=" * 60 + "\n")

    results = []
    total_score = 0

    for i, task_id in enumerate(task_ids, 1):
        print(f"\n[{i}/{total}] 评测: {task_id}")
        score, message, details = evaluate_task(task_id, project_root, verbose=False)

        results.append({
            "task_id": task_id,
            "score": score,
            "message": message,
            **details
        })
        total_score += score

        status = "✅ 通过" if score == 1 else "❌ 失败"
        print(f" {status}: {message}")

    # Statistics; both ratios are guarded so an empty batch reports zeros
    # instead of raising ZeroDivisionError.
    pass_count = sum(1 for r in results if r['score'] == 1)
    pass_rate = pass_count / total * 100 if total else 0
    average_score = total_score / total if total else 0

    print("\n" + "=" * 60)
    print("📈 评测统计")
    print("=" * 60)
    print(f"总任务数: {total}")
    print(f"通过数: {pass_count}")
    print(f"失败数: {total - pass_count}")
    print(f"通过率: {pass_rate:.1f}%")
    print(f"平均分: {average_score:.2f}")
    print("=" * 60 + "\n")

    # Save the results
    if output_file:
        output_data = {
            "timestamp": datetime.now().isoformat(),
            "total_tasks": total,
            "pass_count": pass_count,
            "pass_rate": pass_rate,
            "results": results
        }

        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(output_data, f, indent=2, ensure_ascii=False)

        print(f"📄 详细结果已保存到: {output_file}\n")

    return results
|
||||
|
||||
|
||||
def discover_tasks(project_root="."):
    """Discover all tasks below ``<project_root>/tasks``.

    A task is any entry of the ``tasks`` directory that contains a
    ``task.json`` file.

    Args:
        project_root: Project root directory.

    Returns:
        list: Sorted task identifiers; empty when ``tasks`` is missing or is
        not a directory.
    """
    tasks_dir = os.path.join(project_root, "tasks")

    # isdir rather than exists: listing a regular file named "tasks" would
    # raise NotADirectoryError.
    if not os.path.isdir(tasks_dir):
        return []

    # task.json can only be found inside a directory, so this single check
    # also filters out plain files sitting next to the task directories.
    return sorted(
        name
        for name in os.listdir(tasks_dir)
        if os.path.isfile(os.path.join(tasks_dir, name, "task.json"))
    )
|
||||
|
||||
|
||||
def main():
    """Command-line entry point of the evaluation tool.

    Evaluates the task ids given on the command line, or every discovered
    task with ``--all``. A single task without ``--output`` is evaluated
    verbosely; every other combination goes through ``evaluate_batch`` so
    that ``--output`` is always honoured.

    Exits with status 0 when every evaluated task scored 1, and with 1
    otherwise (including: no tasks found, no arguments, interruption, or an
    unexpected error).
    """
    parser = argparse.ArgumentParser(
        description="JADE Benchmark 评测工具",
        # Raw formatter keeps the line breaks of the usage examples intact.
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
使用示例:
# 评测单个任务
python scripts/run_eval.py smoothing_001

# 评测多个任务
python scripts/run_eval.py smoothing_001 peak_search_001

# 评测所有任务
python scripts/run_eval.py --all

# 保存结果到文件
python scripts/run_eval.py --all --output results.json
"""
    )

    parser.add_argument("task_ids", nargs="*", help="任务ID列表")
    parser.add_argument("--all", action="store_true", help="评测所有任务")
    parser.add_argument("--project-root", default=".", help="项目根目录")
    parser.add_argument("--output", help="结果输出文件(JSON格式)")

    args = parser.parse_args()

    # Decide which tasks to evaluate
    if args.all:
        task_ids = discover_tasks(args.project_root)
        if not task_ids:
            logger.error("❌ 未找到任何任务")
            sys.exit(1)
        logger.info(f"发现 {len(task_ids)} 个任务")
    elif args.task_ids:
        task_ids = args.task_ids
    else:
        parser.print_help()
        sys.exit(1)

    # Run the evaluation. sys.exit() raises SystemExit, which is not an
    # Exception subclass and therefore passes through the handlers below.
    try:
        if len(task_ids) == 1 and not args.output:
            # Single task, nothing to save: verbose evaluation
            score, _message, _details = evaluate_task(task_ids[0], args.project_root)
            sys.exit(0 if score == 1 else 1)

        # Batch evaluation; also used for one task when --output is given,
        # because only this path writes the result file.
        results = evaluate_batch(task_ids, args.project_root, args.output)

        # Exit code: 0 when every task passed, 1 otherwise
        all_pass = all(r['score'] == 1 for r in results)
        sys.exit(0 if all_pass else 1)

    except KeyboardInterrupt:
        print("\n\n⏹ 评测已取消")
        sys.exit(1)
    except Exception as e:
        # Top-level boundary: report the error with its traceback and fail.
        logger.error(f"❌ 错误: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)
|
||||
|
||||
|
||||
# Run the command-line interface only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()
|
||||
|
||||
Reference in New Issue
Block a user