Files
sci-gui-agent-benchmark/scripts/tools/process_trajectory.py
2026-01-12 18:30:12 +08:00

166 lines
5.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
轨迹数据后处理
将录制的原始Host坐标转换为VM内坐标
"""
import json
import os
import sys
import argparse
import logging
# Configure the root logger at import time so INFO-level progress messages are emitted.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by the functions in this file.
logger = logging.getLogger(__name__)
def process_trajectory(task_id: str, project_root: str = ".", force: bool = False,
                       scale_x_adjust: float = 1.0, scale_y_adjust: float = 1.0,
                       offset_x: int = 0, offset_y: int = 0) -> bool:
    """
    Convert the host coordinates of a recorded trajectory into VM coordinates.

    Reads ``<project_root>/tasks/<task_id>/human_demo/actions_raw.json``, adds a
    ``pos_vm`` entry to every action that carries a non-empty ``pos_host``,
    records the conversion parameters under ``metadata['processed']`` and
    writes the result to ``actions.json`` in the same directory. A preview of
    up to five converted clicks is printed to stdout afterwards.

    Args:
        task_id: Task identifier (name of the directory under ``tasks``).
        project_root: Project root directory.
        force: Overwrite an already existing ``actions.json``.
        scale_x_adjust: Extra multiplier applied to the X scale factor.
        scale_y_adjust: Extra multiplier applied to the Y scale factor.
        offset_x: Offset added to every converted X coordinate.
        offset_y: Offset added to every converted Y coordinate.

    Returns:
        True when the processed file was written. False when the raw file is
        missing, when the processed file already exists and ``force`` is not
        set, or when the recorded screenshot resolution is not positive.
    """
    from datetime import datetime

    # Paths
    task_dir = os.path.join(project_root, "tasks", task_id)
    human_demo_dir = os.path.join(task_dir, "human_demo")
    raw_path = os.path.join(human_demo_dir, "actions_raw.json")
    processed_path = os.path.join(human_demo_dir, "actions.json")

    # Check input/output files
    if not os.path.exists(raw_path):
        logger.error(f"❌ 原始轨迹文件不存在: {raw_path}")
        logger.info(" 请先运行: python scripts/collect_task.py <task_id> --mode record")
        return False
    if os.path.exists(processed_path) and not force:
        logger.warning(f"⚠️ 处理后的文件已存在: {processed_path}")
        logger.info(" 使用 --force 参数强制覆盖")
        return False

    # Load the raw recording
    logger.info(f"读取原始轨迹: {raw_path}")
    with open(raw_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    metadata = data['metadata']
    actions = data['actions']
    # The task id is only logged here, so a missing key must not abort the run.
    logger.info(f"任务ID: {metadata.get('task_id', task_id)}")
    logger.info(f"动作数: {len(actions)}")

    # Resolution information
    if 'vm_resolution' in metadata and 'vm_screenshot_resolution' in metadata:
        vm_w, vm_h = metadata['vm_resolution']
        screenshot_w, screenshot_h = metadata['vm_screenshot_resolution']
        # A zero or negative screenshot size would divide by zero or flip the
        # axes; refuse to continue before anything is written.
        if screenshot_w <= 0 or screenshot_h <= 0:
            logger.error(f"❌ 截图分辨率无效: {screenshot_w}x{screenshot_h}")
            return False
        # Host-side click coordinates are in screenshot space and have to be
        # scaled to the VM's actual resolution.
        scale_x = (vm_w / screenshot_w) * scale_x_adjust
        scale_y = (vm_h / screenshot_h) * scale_y_adjust
        logger.info(f"VM分辨率: {vm_w}x{vm_h}")
        logger.info(f"截图分辨率: {screenshot_w}x{screenshot_h}")
        logger.info(f"转换比例: X={scale_x:.3f}, Y={scale_y:.3f}")
        if scale_x_adjust != 1.0 or scale_y_adjust != 1.0:
            logger.info(f"应用调整系数: X={scale_x_adjust}, Y={scale_y_adjust}")
        if offset_x != 0 or offset_y != 0:
            logger.info(f"应用偏移调整: X={offset_x}, Y={offset_y}")
    else:
        logger.warning("⚠️ 元数据缺少分辨率信息使用默认比例1.0")
        scale_x = 1.0 * scale_x_adjust
        scale_y = 1.0 * scale_y_adjust

    # Convert coordinates
    converted_count = 0
    for action in actions:
        host_pos = action.get('pos_host')
        if not host_pos:
            # Actions without a position (or with a null/empty one) are kept as-is.
            continue
        host_x, host_y = host_pos
        vm_x = int(host_x * scale_x + offset_x)
        vm_y = int(host_y * scale_y + offset_y)
        action['pos_vm'] = [vm_x, vm_y]
        converted_count += 1
    logger.info(f"✅ 坐标转换完成: {converted_count}/{len(actions)} 个动作")

    # Record the processing parameters in the metadata
    metadata['processed'] = {
        "processed_at": datetime.now().isoformat(),
        "scale_x": scale_x,
        "scale_y": scale_y,
        "offset_x": offset_x,
        "offset_y": offset_y,
        "converted_actions": converted_count
    }

    # Save the processed data
    logger.info(f"保存处理后的轨迹: {processed_path}")
    with open(processed_path, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=2, ensure_ascii=False)
    logger.info("✅ 处理完成!")

    # Print a short preview of the conversion
    print("\n" + "=" * 60)
    print("📊 坐标转换示例前5个点击:")
    print("-" * 60)
    click_count = 0
    for action in actions:
        host_pos = action.get('pos_host')
        # Skip non-clicks, actions without a 'type', and clicks that carry no
        # usable host position; the preview must never fail after the file
        # has already been written.
        if action.get('type') != 'click' or not host_pos:
            continue
        host_x, host_y = host_pos
        vm_x, vm_y = action.get('pos_vm') or (0, 0)
        # Display as integers
        print(f" Host({int(host_x):4d}, {int(host_y):4d}) → VM({int(vm_x):4d}, {int(vm_y):4d})")
        click_count += 1
        if click_count >= 5:
            break
    print("=" * 60)
    print("\n💡 下一步:可视化验证(可选)")
    print(f" python scripts/visualize_trajectory.py {task_id}")
    print("=" * 60 + "\n")
    return True
def main():
    """Command-line entry point: parse the options, run the conversion, exit 0 on success and 1 on failure."""
    cli = argparse.ArgumentParser(description="处理轨迹数据,转换坐标")
    cli.add_argument("task_id", help="任务ID")
    cli.add_argument("--project-root", default=".", help="项目根目录")
    cli.add_argument("--force", action="store_true", help="强制覆盖已有文件")

    # The four numeric tuning options share the same shape, so register them from a table.
    numeric_options = (
        ("--scale-x", float, 1.0, "X轴缩放调整系数"),
        ("--scale-y", float, 1.0, "Y轴缩放调整系数"),
        ("--offset-x", int, 0, "X轴偏移调整"),
        ("--offset-y", int, 0, "Y轴偏移调整"),
    )
    for flag, value_type, fallback, help_text in numeric_options:
        cli.add_argument(flag, type=value_type, default=fallback, help=help_text)

    opts = cli.parse_args()

    call_kwargs = {
        "task_id": opts.task_id,
        "project_root": opts.project_root,
        "force": opts.force,
        "scale_x_adjust": opts.scale_x,
        "scale_y_adjust": opts.scale_y,
        "offset_x": opts.offset_x,
        "offset_y": opts.offset_y,
    }
    # Map the boolean result onto a process exit status.
    if process_trajectory(**call_kwargs):
        sys.exit(0)
    else:
        sys.exit(1)
# Run the command-line entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()