530 lines
20 KiB
Python
530 lines
20 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
refine_metadata.py — 将现有 benchmark JSON 的 metadata.steps 精炼为原子级 GUI 操作步骤。
|
||
直接覆盖原文件,原始步骤保存到 metadata.steps_original。
|
||
|
||
用法:
|
||
# 处理 avogadro 的所有任务
|
||
python refine_metadata.py --software avogadro
|
||
|
||
# 处理指定任务
|
||
python refine_metadata.py --tasks avogadro/building-organic-molecules_task1.json
|
||
|
||
# 全部 7 个软件
|
||
python refine_metadata.py
|
||
|
||
# 强制覆盖已精炼过的 (有 steps_original 的)
|
||
python refine_metadata.py --force
|
||
|
||
# 预览模式 (不调用 API)
|
||
python refine_metadata.py --dry-run
|
||
|
||
环境变量:
|
||
OPENAI_API_KEY — API Key (必需)
|
||
OPENAI_BASE_URL — API Base URL (默认 https://api.openai.com/v1)
|
||
EXTRACT_MODEL — 模型名 (默认 gemini-3.1-pro-preview)
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import asyncio
|
||
import aiohttp
|
||
import base64
|
||
import logging
|
||
import json
|
||
import re
|
||
import shutil
|
||
import tempfile
|
||
import argparse
|
||
from pathlib import Path
|
||
from typing import List, Optional, Dict, Tuple
|
||
from dataclasses import dataclass, field
|
||
from datetime import datetime
|
||
import io
|
||
|
||
# ─── Configuration ───────────────────────────────────────────────────────────
|
||
|
||
# Paths anchored to this script's location so the tool works from any CWD.
SCRIPT_DIR = Path(__file__).parent
PROJECT_ROOT = SCRIPT_DIR.parent

# OpenAI-compatible chat-completions endpoint (all env-overridable).
API_BASE_URL = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
API_URL = f"{API_BASE_URL}/chat/completions"
API_KEY = os.getenv("OPENAI_API_KEY")
MODEL_NAME = os.getenv("EXTRACT_MODEL", "gemini-3.1-pro-preview")

MAX_CONCURRENT_REQUESTS = 3   # simultaneous in-flight API requests
MAX_RETRY_ATTEMPTS = 3        # attempts per API call before giving up
RETRY_DELAY = 5               # base retry delay in seconds
RETRY_BACKOFF = 2             # exponential backoff multiplier per attempt
MAX_IMAGES_PER_REQUEST = 20   # cap on tutorial page images sent per request
MAX_TOKENS = 8192             # completion token limit passed to the API

EXAMPLES_FOLDER = SCRIPT_DIR / "examples"  # benchmark task JSON files
INPUT_FOLDER = SCRIPT_DIR / "inputs"       # tutorial PDF sources

# Only these seven software suites are processed.
TARGET_SOFTWARE = {"avogadro", "imagej", "jade", "origin", "ovito", "pymol", "vesta"}
|
||
|
||
# ─── Prompt ──────────────────────────────────────────────────────────────────
|
||
|
||
SYSTEM_PROMPT = """你是一位桌面科研软件 GUI 自动化专家。你的任务是为一个 **已确定的具体任务** 生成极其详细的、原子级 GUI 操作步骤。
|
||
|
||
## 背景
|
||
这些步骤将作为"教程级操作指南",直接注入给 AI Agent 去操控桌面软件。Agent 会严格按步骤执行,因此:
|
||
- **每一步必须是一个不可分割的原子 GUI 操作**(单击一个按钮、输入一个值、选择一个菜单项等)
|
||
- 不能跳步或合并操作
|
||
- 必须包含所有隐含的中间操作
|
||
|
||
## 你将收到
|
||
1. **教程文档的截图**(PDF 教程页面图片)— 作为操作上下文参考
|
||
2. **任务目标** (task_goal) — 一句话描述要完成的任务
|
||
3. **当前粗粒度步骤** (current_steps) — 现有的概要步骤,但粒度太粗
|
||
|
||
## 你需要输出
|
||
根据教程文档内容,将粗粒度步骤细化为**原子级 GUI 操作步骤**。
|
||
|
||
## 步骤编写规则
|
||
|
||
### 每步必须满足:
|
||
1. **只包含一个明确的、不可分割的 GUI 原子动作**
|
||
2. **必须指明**:操作的控件类型(菜单项/按钮/输入框/复选框/下拉菜单/树节点等)
|
||
3. **必须指明**:控件的标签文字或位置描述
|
||
4. **必须指明**:操作类型(单击/双击/右键/输入文字/勾选/选择等)
|
||
5. **包含所有隐含步骤**:不要遗漏任何中间操作
|
||
|
||
### ❌ 错误(太粗):
|
||
- "搜索 benzene 并插入" — 缺少展开文件夹、选中文件、点击插入等中间步骤
|
||
- "设置平滑参数" — 没说具体修改哪个控件、输入什么值
|
||
- "在对话框中配置选项" — 没有分解为逐个控件操作
|
||
|
||
### ✅ 正确(原子级):
|
||
- "在\"筛选\"输入框中,将光标定位到输入框,清空已有内容,输入 benzene"
|
||
- "在筛选结果的树形列表中,双击展开名为 aromatics 的文件夹"
|
||
- "在展开的列表中,单击选中 benzene.cjson 文件"
|
||
- "点击对话框底部的\"插入\"按钮"
|
||
|
||
## 输出格式
|
||
只返回一个 JSON 对象:
|
||
{
|
||
"steps": "细化后的操作步骤字符串,每步带编号,用换行符分隔"
|
||
}
|
||
|
||
不要添加任何其他文本。只输出 JSON。"""
|
||
|
||
|
||
# ─── Logging ─────────────────────────────────────────────────────────────────
|
||
|
||
# Log to stdout (not stderr) so progress is captured when output is piped.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)
|
||
|
||
|
||
# ─── Stats ───────────────────────────────────────────────────────────────────
|
||
|
||
@dataclass
class Stats:
    """Mutable counters for one refinement run, plus per-file failure details."""
    total: int = 0
    completed: int = 0
    failed: int = 0
    skipped: int = 0
    retries: int = 0
    start_time: datetime = field(default_factory=datetime.now)
    failures: List[Tuple[str, str]] = field(default_factory=list)

    def log_progress(self):
        """Emit a one-line running progress summary."""
        done_count = self.completed + self.failed + self.skipped
        rate = (done_count / self.total * 100) if self.total else 0
        logger.info(f"进度: {done_count}/{self.total} ({rate:.0f}%) | 成功={self.completed} 失败={self.failed} 跳过={self.skipped}")

    def summary(self):
        """Emit the final report: totals, wall-clock time, and any failures."""
        wall = (datetime.now() - self.start_time).total_seconds()
        logger.info("=" * 60)
        logger.info(f"完成! 总计={self.total} 成功={self.completed} 失败={self.failed} 跳过={self.skipped} 重试={self.retries}")
        logger.info(f"耗时: {int(wall // 60)}m{int(wall % 60)}s")
        if not self.failures:
            return
        logger.info("失败列表:")
        for path, err in self.failures:
            logger.info(f"  {path}: {err[:100]}")
|
||
|
||
|
||
# ─── Document → Images ──────────────────────────────────────────────────────
|
||
|
||
def convert_pdf_to_images(pdf_path: str) -> List[str]:
    """Render a tutorial PDF into base64-encoded JPEG page images.

    Small PDFs (<= MAX_IMAGES_PER_REQUEST pages) are rendered in full at
    150 DPI / quality 90.  Larger PDFs are down-sampled to at most
    MAX_IMAGES_PER_REQUEST evenly-spaced pages at 100 DPI / quality 80
    to keep the request payload bounded.

    Returns an empty list on any conversion failure (missing poppler,
    corrupt PDF, ...); callers treat that as "skip this task".
    """
    try:
        from pdf2image import convert_from_path, pdfinfo_from_path

        # Cheap page count via poppler's pdfinfo; fall back to a low-DPI
        # full render (the old behavior) if pdfinfo is unavailable.
        try:
            total_pages = int(pdfinfo_from_path(pdf_path)["Pages"])
        except Exception:
            quick = convert_from_path(pdf_path, dpi=36, fmt='jpeg')
            total_pages = len(quick)
            del quick

        if total_pages > MAX_IMAGES_PER_REQUEST:
            dpi, quality = 100, 80
            step = total_pages / MAX_IMAGES_PER_REQUEST
            # 1-based page numbers, evenly spaced across the document.
            selected = [int(step * i) + 1 for i in range(MAX_IMAGES_PER_REQUEST)]
            logger.info(f"  大 PDF ({total_pages}页): 采样 {len(selected)} 页 @ {dpi} DPI")
            images = []
            for pn in selected:
                # Render one page at a time to bound memory usage.
                imgs = convert_from_path(pdf_path, dpi=dpi, fmt='jpeg', first_page=pn, last_page=pn)
                if imgs:
                    images.append(_jpeg_b64(imgs[0], quality))
            return images
        else:
            dpi, quality = 150, 90
            imgs = convert_from_path(pdf_path, dpi=dpi, fmt='jpeg')
            return [_jpeg_b64(img, quality) for img in imgs]
    except Exception as e:
        logger.error(f"  PDF转图片失败: {e}")
        return []


def _jpeg_b64(img, quality: int) -> str:
    """Encode a PIL image as a base64 JPEG string at the given quality."""
    buf = io.BytesIO()
    img.save(buf, format='JPEG', quality=quality)
    return base64.b64encode(buf.getvalue()).decode()
|
||
|
||
|
||
# ─── Image Cache ─────────────────────────────────────────────────────────────
|
||
|
||
# Process-wide cache: PDF path -> list of base64 page images.
_pdf_image_cache: Dict[str, List[str]] = {}


def get_tutorial_images(pdf_path: str) -> List[str]:
    """Return the page images for a tutorial PDF, converting it at most once.

    Results (including empty lists from failed conversions) are cached so
    that several tasks sharing one tutorial trigger a single conversion.
    """
    cache_key = str(pdf_path)
    cached = _pdf_image_cache.get(cache_key)
    if cached is None:
        logger.info(f"  正在转换教程 PDF: {Path(pdf_path).name}")
        cached = convert_pdf_to_images(pdf_path)
        _pdf_image_cache[cache_key] = cached
        logger.info(f"  得到 {len(cached)} 张图片")
    return cached
|
||
|
||
|
||
# ─── API Call ────────────────────────────────────────────────────────────────
|
||
|
||
async def refine_steps_via_api(
    task_goal: str,
    current_steps: str,
    tutorial_images: List[str],
    session: aiohttp.ClientSession,
    stats: Stats
) -> Tuple[str, bool]:
    """Ask the LLM to expand coarse steps into atomic GUI steps.

    Sends the task goal, the current coarse steps, and up to
    MAX_IMAGES_PER_REQUEST tutorial page images as one multimodal
    chat-completions request, retrying with exponential backoff on
    transient failures (non-200 responses, timeouts, exceptions).

    Returns:
        (raw model reply, True) on success;
        (error message, False) on a non-retryable HTTP 413 or once
        MAX_RETRY_ATTEMPTS attempts are exhausted.
    """
    user_text = f"""## 任务目标
{task_goal}

## 当前粗粒度步骤(需要细化)
{current_steps}

请根据以上教程截图内容,将粗粒度步骤细化为原子级 GUI 操作步骤。只输出 JSON。"""

    # Multimodal user message: the text part first, then each page image.
    content = [{"type": "text", "text": user_text}]
    for img_b64 in tutorial_images[:MAX_IMAGES_PER_REQUEST]:
        content.append({
            "type": "image_url",
            "image_url": {"url": f"data:image/jpeg;base64,{img_b64}"}
        })

    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": content}
    ]

    # Request pieces are loop-invariant; build them once, outside the retries.
    headers = {
        "Authorization": f"Bearer {API_KEY}",
        "Content-Type": "application/json"
    }
    payload = {
        "model": MODEL_NAME,
        "messages": messages,
        "max_tokens": MAX_TOKENS,
    }
    # Explicit ClientTimeout: older aiohttp 3.x versions reject a bare
    # number for the per-request timeout parameter.
    request_timeout = aiohttp.ClientTimeout(total=240)

    for attempt in range(1, MAX_RETRY_ATTEMPTS + 1):
        try:
            async with session.post(API_URL, headers=headers, json=payload, timeout=request_timeout) as resp:
                if resp.status == 200:
                    data = await resp.json()
                    return data['choices'][0]['message']['content'], True
                else:
                    err = await resp.text()
                    logger.warning(f"  API错误 ({resp.status}): {err[:200]}")
                    if resp.status == 413:
                        # Payload too large: retrying the same payload can't help.
                        return "Payload too large", False
        except asyncio.TimeoutError:
            logger.warning(f"  API超时 (attempt {attempt})")
        except Exception as e:
            logger.warning(f"  API异常: {e}")

        if attempt < MAX_RETRY_ATTEMPTS:
            delay = RETRY_DELAY * (RETRY_BACKOFF ** (attempt - 1))
            logger.info(f"  重试 {attempt}/{MAX_RETRY_ATTEMPTS} (等待 {delay}s)")
            stats.retries += 1
            await asyncio.sleep(delay)

    return "Max retries exceeded", False
|
||
|
||
|
||
def parse_refined_steps(raw_content: str) -> Optional[str]:
    """Extract the refined 'steps' string from an LLM response.

    Tries, in order:
      1. a JSON object inside a ```json ...``` markdown code fence,
      2. the first bare {...} span in the text,
      3. a plain numbered list ("1. ...", "2. ...") as a last resort.

    A "steps" value returned as a list of strings (a common model
    deviation) is joined with newlines.  Returns None when nothing
    usable can be recovered; results of <= 20 chars are rejected as noise.
    """
    # Prefer an explicit markdown code fence.
    m = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', raw_content, re.DOTALL)
    if m:
        json_str = m.group(1)
    else:
        # Otherwise take the outermost-looking brace span, or the raw text.
        m = re.search(r'\{.*\}', raw_content, re.DOTALL)
        json_str = m.group(0) if m else raw_content

    try:
        obj = json.loads(json_str)
        # Guard against scalar JSON (e.g. a bare number): only dicts carry
        # a "steps" key; anything else would raise AttributeError on .get.
        steps = obj.get("steps", "") if isinstance(obj, dict) else ""
        if isinstance(steps, list):
            # Model returned one step per array element; normalize.
            steps = '\n'.join(str(s) for s in steps)
        if steps and len(steps) > 20:
            return steps
    except json.JSONDecodeError:
        pass

    # Fallback: the reply itself may already be a numbered step list.
    if re.search(r'^\d+\.\s', raw_content, re.MULTILINE):
        lines = [l.strip() for l in raw_content.strip().split('\n') if re.match(r'^\d+\.', l.strip())]
        if len(lines) >= 2:
            return '\n'.join(lines)

    return None
|
||
|
||
|
||
# ─── Task Processing ────────────────────────────────────────────────────────
|
||
|
||
def find_tutorial_pdf(software: str, task_id: str) -> Optional[Path]:
    """Locate the tutorial PDF backing a task.

    Task IDs follow ``{tutorial_stem}_task{N}``; the PDF is expected at
    ``inputs/{software}/{tutorial_stem}.pdf``.  If that exact file is
    missing, the software folder is scanned for a file with the same stem
    and a ``.pdf`` suffix in any letter case.  Returns None when the ID
    doesn't match the pattern or no PDF can be found.
    """
    match = re.match(r'^(.+?)_task\d+$', task_id)
    if match is None:
        return None
    stem = match.group(1)

    software_dir = INPUT_FOLDER / software
    candidate = software_dir / f"{stem}.pdf"
    if candidate.exists():
        return candidate

    # Tolerate suffix-case variants such as ``.PDF``.
    if software_dir.exists():
        for entry in software_dir.iterdir():
            if entry.stem == stem and entry.suffix.lower() == '.pdf':
                return entry
    return None
|
||
|
||
|
||
async def process_task(
    task_json_path: Path,
    software: str,
    session: aiohttp.ClientSession,
    semaphore: asyncio.Semaphore,
    stats: Stats,
    force: bool = False,
    dry_run: bool = False
):
    """Refine one task JSON's metadata.steps in place.

    Pipeline: load the JSON, skip if already refined (unless `force`),
    locate and rasterize the matching tutorial PDF, ask the LLM to expand
    the coarse steps, then rewrite the file with the original steps kept
    under metadata.steps_original.  Every outcome is recorded on `stats`;
    a single bad task never raises out of this coroutine.
    """
    # The semaphore bounds concurrent API calls / PDF conversions.
    async with semaphore:
        rel_path = f"{software}/{task_json_path.name}"

        # Load task JSON; unreadable files are counted as failures, not crashes.
        try:
            with open(task_json_path, 'r', encoding='utf-8') as f:
                task_data = json.load(f)
        except Exception as e:
            stats.failed += 1
            stats.failures.append((rel_path, f"JSON读取失败: {e}"))
            stats.log_progress()
            return

        task_id = task_data.get("id", task_json_path.stem)
        instruction = task_data.get("instruction", "")
        metadata = task_data.get("metadata", {})
        current_steps = metadata.get("steps", "")

        # Presence of steps_original marks a file refined by a previous run.
        if metadata.get("steps_original") and not force:
            logger.info(f"跳过 (已精炼): {rel_path}")
            stats.skipped += 1
            stats.log_progress()
            return

        # The instruction doubles as the task goal sent to the model.
        if not instruction:
            logger.warning(f"跳过 (无instruction): {rel_path}")
            stats.skipped += 1
            stats.log_progress()
            return

        # Find tutorial PDF (derived from the task_id naming convention).
        tutorial_pdf = find_tutorial_pdf(software, task_id)

        # Dry run: report what would happen; no conversion, no API call.
        if dry_run:
            logger.info(f"[DRY RUN] 将处理: {rel_path}")
            logger.info(f"  任务: {instruction[:80]}...")
            logger.info(f"  当前步骤: {current_steps[:80]}...")
            logger.info(f"  教程PDF: {tutorial_pdf or '未找到'}")
            stats.completed += 1
            stats.log_progress()
            return

        if not tutorial_pdf:
            logger.warning(f"跳过 (无教程PDF): {rel_path} (task_id={task_id})")
            stats.skipped += 1
            stats.log_progress()
            return

        # Get tutorial images (cached across tasks sharing one tutorial).
        images = get_tutorial_images(str(tutorial_pdf))
        if not images:
            logger.warning(f"跳过 (PDF转图片失败): {rel_path}")
            stats.skipped += 1
            stats.log_progress()
            return

        # Call API to refine steps.
        logger.info(f"正在精炼: {rel_path}")
        raw_response, success = await refine_steps_via_api(
            instruction, current_steps, images, session, stats
        )

        if not success:
            # raw_response carries the error message in the failure case.
            stats.failed += 1
            stats.failures.append((rel_path, raw_response[:200]))
            stats.log_progress()
            return

        # Parse refined steps out of the model's raw reply.
        refined_steps = parse_refined_steps(raw_response)
        if not refined_steps:
            stats.failed += 1
            stats.failures.append((rel_path, f"无法解析LLM返回: {raw_response[:200]}"))
            stats.log_progress()
            return

        # Save original steps, update with refined steps.
        task_data["metadata"]["steps_original"] = current_steps
        task_data["metadata"]["steps"] = refined_steps

        # Write back to the original file (overwrites in place by design).
        with open(task_json_path, 'w', encoding='utf-8') as f:
            json.dump(task_data, f, ensure_ascii=False, indent=2)

        # Log step count change, counting non-empty lines as steps.
        step_count_old = len([l for l in current_steps.split('\n') if l.strip()])
        step_count_new = len([l for l in refined_steps.split('\n') if l.strip()])
        logger.info(f"  ✓ {rel_path}: {step_count_old}步 → {step_count_new}步")

        stats.completed += 1
        stats.log_progress()
|
||
|
||
|
||
# ─── Main ────────────────────────────────────────────────────────────────────
|
||
|
||
def collect_tasks(software_filter: Optional[List[str]] = None,
                  task_filter: Optional[List[str]] = None) -> List[Tuple[Path, str]]:
    """Collect all (task_json_path, software_name) pairs to process.

    Args:
        software_filter: restrict the scan to these software names;
            None (the default) means all of TARGET_SOFTWARE.
        task_filter: explicit "software/filename.json" entries; when
            given, software_filter is ignored entirely.

    Returns:
        A list of (path to task JSON, software name) tuples.  Missing
        files, malformed entries, and absent software folders are warned
        about and dropped rather than raising.
    """
    results: List[Tuple[Path, str]] = []

    # Explicit task list takes precedence over any software filter.
    if task_filter:
        for t in task_filter:
            parts = t.split('/')
            if len(parts) == 2:
                sw, fname = parts
                path = EXAMPLES_FOLDER / sw / fname
                if path.exists():
                    results.append((path, sw))
                else:
                    logger.warning(f"任务文件不存在: {t}")
            else:
                logger.warning(f"任务格式应为 software/filename.json: {t}")
        return results

    # No explicit tasks: walk each software folder's *.json files.
    softwares = software_filter if software_filter else sorted(TARGET_SOFTWARE)

    for sw in softwares:
        sw_dir = EXAMPLES_FOLDER / sw
        if not sw_dir.exists():
            logger.warning(f"软件目录不存在: {sw}")
            continue
        for f in sorted(sw_dir.glob("*.json")):
            results.append((f, sw))

    return results
|
||
|
||
|
||
async def main():
    """CLI entry point: parse args, fan out all tasks concurrently, write a report."""
    parser = argparse.ArgumentParser(description="Refine benchmark metadata.steps to fine-grained GUI operations")
    parser.add_argument('--software', nargs='*', help='Only process these software names (e.g. avogadro jade)')
    parser.add_argument('--tasks', nargs='*', help='Only process specific tasks (e.g. avogadro/building-organic-molecules_task1.json)')
    parser.add_argument('--force', action='store_true', help='Overwrite already-refined files (those with steps_original)')
    parser.add_argument('--dry-run', action='store_true', help='Preview mode, no API calls')
    args = parser.parse_args()

    # Dry runs never hit the API, so the key is only required otherwise.
    if not API_KEY and not args.dry_run:
        logger.error("请设置 OPENAI_API_KEY 环境变量")
        return

    # Collect tasks (explicit --tasks wins over --software filtering).
    tasks = collect_tasks(args.software, args.tasks)
    if not tasks:
        logger.error("未找到需要处理的任务文件")
        return

    stats = Stats(total=len(tasks))

    # Banner with the effective configuration for this run.
    logger.info("=" * 60)
    logger.info("Refine Metadata — 细粒度步骤生成")
    logger.info("=" * 60)
    logger.info(f"模型: {MODEL_NAME}")
    logger.info(f"数据目录: {EXAMPLES_FOLDER}")
    logger.info(f"教程目录: {INPUT_FOLDER}")
    logger.info(f"任务数: {len(tasks)}")
    logger.info(f"并发数: {MAX_CONCURRENT_REQUESTS}")
    if args.force:
        logger.info("模式: 强制覆盖")
    if args.dry_run:
        logger.info("模式: DRY RUN (不调用API)")
    logger.info("=" * 60)

    # Show per-software task distribution.
    from collections import Counter
    dist = Counter(sw for _, sw in tasks)
    for sw, cnt in sorted(dist.items()):
        logger.info(f"  {sw}: {cnt} tasks")

    # One shared HTTP session; the semaphore caps concurrent requests.
    semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)

    async with aiohttp.ClientSession() as session:
        coros = [
            process_task(path, sw, session, semaphore, stats, args.force, args.dry_run)
            for path, sw in tasks
        ]
        await asyncio.gather(*coros)

    stats.summary()

    # Save processing report alongside the task data for later inspection.
    report = {
        "timestamp": datetime.now().isoformat(),
        "model": MODEL_NAME,
        "total": stats.total,
        "completed": stats.completed,
        "failed": stats.failed,
        "skipped": stats.skipped,
        "failures": [{"file": f, "error": e} for f, e in stats.failures]
    }
    report_path = EXAMPLES_FOLDER / "refine_report.json"
    with open(report_path, 'w', encoding='utf-8') as f:
        json.dump(report, f, ensure_ascii=False, indent=2)
    logger.info(f"处理报告: {report_path}")


if __name__ == "__main__":
    asyncio.run(main())
|