"""
Trajectory data post-processing.

Converts the recorded raw host coordinates into in-VM coordinates.
"""
|
||
import json
|
||
import os
|
||
import sys
|
||
import argparse
|
||
import logging
|
||
|
||
# Configure the root logger at import time so INFO-level progress messages
# are emitted, then create the module-level logger used throughout this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
||
|
||
|
||
def process_trajectory(task_id, project_root=".", force=False,
                       scale_x_adjust=1.0, scale_y_adjust=1.0,
                       offset_x=0, offset_y=0):
    """
    Convert the host coordinates of a recorded trajectory into VM coordinates.

    Reads ``<project_root>/tasks/<task_id>/human_demo/actions_raw.json``, adds a
    ``pos_vm`` entry to every action that has a non-empty ``pos_host``, records
    the conversion parameters under ``metadata['processed']`` and writes the
    result to ``actions.json`` in the same directory. A short sample of the
    converted click coordinates is printed afterwards.

    Args:
        task_id: Task ID (name of the directory under ``tasks``).
        project_root: Project root directory.
        force: Overwrite an existing ``actions.json`` when True.
        scale_x_adjust: Extra multiplier applied to the X scale factor.
        scale_y_adjust: Extra multiplier applied to the Y scale factor.
        offset_x: Offset added to every converted X coordinate.
        offset_y: Offset added to every converted Y coordinate.

    Returns:
        True when the processed file was written. False when the raw file is
        missing, is not valid JSON, lacks the ``metadata``/``actions`` keys,
        has a zero screenshot resolution, or when the output already exists
        and ``force`` is not set.
    """
    # Function-scope import: datetime is not among the file-level imports.
    from datetime import datetime

    # Paths
    task_dir = os.path.join(project_root, "tasks", task_id)
    human_demo_dir = os.path.join(task_dir, "human_demo")
    raw_path = os.path.join(human_demo_dir, "actions_raw.json")
    processed_path = os.path.join(human_demo_dir, "actions.json")

    # Check files
    if not os.path.exists(raw_path):
        logger.error(f"❌ 原始轨迹文件不存在: {raw_path}")
        logger.info(" 请先运行: python scripts/collect_task.py <task_id> --mode record")
        return False

    if os.path.exists(processed_path) and not force:
        logger.warning(f"⚠️ 处理后的文件已存在: {processed_path}")
        logger.info(" 使用 --force 参数强制覆盖")
        return False

    # Read the raw data
    logger.info(f"读取原始轨迹: {raw_path}")
    try:
        with open(raw_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except json.JSONDecodeError as exc:
        # Report a corrupt recording through the same bool contract as the
        # other failure paths instead of surfacing a traceback.
        logger.error(f"❌ 原始轨迹文件不是有效的JSON: {exc}")
        return False

    if not isinstance(data, dict) or 'metadata' not in data or 'actions' not in data:
        logger.error("❌ 原始轨迹文件缺少 metadata 或 actions 字段")
        return False

    metadata = data['metadata']
    actions = data['actions']

    # The task id is only logged here, so a missing key must not abort the run.
    logger.info(f"任务ID: {metadata.get('task_id', task_id)}")
    logger.info(f"动作数: {len(actions)}")

    # Resolve the scale factors before touching any action or writing output.
    scale = _resolve_scale(metadata, scale_x_adjust, scale_y_adjust,
                           offset_x, offset_y)
    if scale is None:
        return False
    scale_x, scale_y = scale

    # Convert coordinates
    converted_count = _convert_actions(actions, scale_x, scale_y,
                                       offset_x, offset_y)
    logger.info(f"✅ 坐标转换完成: {converted_count}/{len(actions)} 个动作")

    # Record the processing parameters in the metadata
    metadata['processed'] = {
        "processed_at": datetime.now().isoformat(),
        "scale_x": scale_x,
        "scale_y": scale_y,
        "offset_x": offset_x,
        "offset_y": offset_y,
        "converted_actions": converted_count
    }

    # Save the processed data
    logger.info(f"保存处理后的轨迹: {processed_path}")
    with open(processed_path, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=2, ensure_ascii=False)

    logger.info("✅ 处理完成!")

    # Print a sample of the converted clicks
    print("\n" + "=" * 60)
    print("📊 坐标转换示例(前5个点击):")
    print("-" * 60)

    _print_click_samples(actions)

    print("=" * 60)
    print("\n💡 下一步:可视化验证(可选)")
    print(f" python scripts/visualize_trajectory.py {task_id}")
    print("=" * 60 + "\n")

    return True


def _resolve_scale(metadata, scale_x_adjust, scale_y_adjust, offset_x, offset_y):
    """Return ``(scale_x, scale_y)``, or None when the screenshot resolution is zero."""
    if 'vm_resolution' not in metadata or 'vm_screenshot_resolution' not in metadata:
        logger.warning("⚠️ 元数据缺少分辨率信息,使用默认比例1.0")
        return 1.0 * scale_x_adjust, 1.0 * scale_y_adjust

    vm_w, vm_h = metadata['vm_resolution']
    screenshot_w, screenshot_h = metadata['vm_screenshot_resolution']

    # A zero dimension would make the ratio below divide by zero.
    if not screenshot_w or not screenshot_h:
        logger.error(f"❌ 截图分辨率无效: {screenshot_w}x{screenshot_h}")
        return None

    # Host click coordinates are in screenshot space; scale them up (or down)
    # to the VM's actual resolution.
    scale_x = (vm_w / screenshot_w) * scale_x_adjust
    scale_y = (vm_h / screenshot_h) * scale_y_adjust

    logger.info(f"VM分辨率: {vm_w}x{vm_h}")
    logger.info(f"截图分辨率: {screenshot_w}x{screenshot_h}")
    logger.info(f"转换比例: X={scale_x:.3f}, Y={scale_y:.3f}")

    if scale_x_adjust != 1.0 or scale_y_adjust != 1.0:
        logger.info(f"应用调整系数: X={scale_x_adjust}, Y={scale_y_adjust}")
    if offset_x != 0 or offset_y != 0:
        logger.info(f"应用偏移调整: X={offset_x}, Y={offset_y}")

    return scale_x, scale_y


def _convert_actions(actions, scale_x, scale_y, offset_x, offset_y):
    """Add ``pos_vm`` to each action with a non-empty ``pos_host``; return how many were converted."""
    converted = 0
    for action in actions:
        pos_host = action.get('pos_host')
        if not pos_host:
            continue
        host_x, host_y = pos_host
        # int() truncates toward zero, matching the stored integer pixel format.
        action['pos_vm'] = [
            int(host_x * scale_x + offset_x),
            int(host_y * scale_y + offset_y),
        ]
        converted += 1
    return converted


def _print_click_samples(actions):
    """Print the host/VM coordinate pairs of the first five converted click actions."""
    sample_limit = 5  # the printed heading announces exactly five samples
    shown = 0
    for action in actions:
        if action.get('type') != 'click':
            continue
        pos_host = action.get('pos_host')
        pos_vm = action.get('pos_vm')
        # Clicks recorded without a position were never converted; skip them
        # rather than failing after the output file has been written.
        if not pos_host or not pos_vm:
            continue
        host_x, host_y = pos_host
        vm_x, vm_y = pos_vm
        # Cast to int for fixed-width display
        print(f" Host({int(host_x):4d}, {int(host_y):4d}) → VM({int(vm_x):4d}, {int(vm_y):4d})")

        shown += 1
        if shown >= sample_limit:
            break
|
||
|
||
|
||
def main():
    """Parse the command-line options, run the conversion and exit 0 on success, 1 on failure."""
    # (flags, keyword arguments) for every option, in the order they are registered.
    option_table = (
        (("task_id",), {"help": "任务ID"}),
        (("--project-root",), {"default": ".", "help": "项目根目录"}),
        (("--force",), {"action": "store_true", "help": "强制覆盖已有文件"}),
        (("--scale-x",), {"type": float, "default": 1.0, "help": "X轴缩放调整系数"}),
        (("--scale-y",), {"type": float, "default": 1.0, "help": "Y轴缩放调整系数"}),
        (("--offset-x",), {"type": int, "default": 0, "help": "X轴偏移调整"}),
        (("--offset-y",), {"type": int, "default": 0, "help": "Y轴偏移调整"}),
    )

    cli = argparse.ArgumentParser(description="处理轨迹数据,转换坐标")
    for flags, settings in option_table:
        cli.add_argument(*flags, **settings)
    opts = cli.parse_args()

    ok = process_trajectory(
        task_id=opts.task_id,
        project_root=opts.project_root,
        force=opts.force,
        scale_x_adjust=opts.scale_x,
        scale_y_adjust=opts.scale_y,
        offset_x=opts.offset_x,
        offset_y=opts.offset_y,
    )

    # Map the boolean outcome onto the process exit status.
    if ok:
        sys.exit(0)
    sys.exit(1)
|
||
|
||
|
||
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
|
||
|