feat: 新增 refine_metadata 脚本,更新 extract_instructions_v2
This commit is contained in:
@@ -85,6 +85,13 @@ SOFTWARE_CONFIG = {
|
||||
{"type": "sleep", "parameters": {"seconds": 5}}
|
||||
]
|
||||
},
|
||||
"jade": {
|
||||
"snapshot": "jade",
|
||||
"config": [
|
||||
{"type": "launch", "parameters": {"command": ["C:\\JADE\\jade 6.5\\MDI Jade 6.5\\jade6.5.exe"]}},
|
||||
{"type": "sleep", "parameters": {"seconds": 5}}
|
||||
]
|
||||
},
|
||||
}
|
||||
|
||||
# Default config for unknown software
|
||||
@@ -522,8 +529,12 @@ async def main():
|
||||
global stats, FORCE_REGENERATE
|
||||
stats = ProcessingStats()
|
||||
|
||||
# Parse --force flag
|
||||
# Parse --force flag and --software filter
|
||||
FORCE_REGENERATE = "--force" in sys.argv
|
||||
software_filter = None
|
||||
for i, arg in enumerate(sys.argv):
|
||||
if arg == "--software" and i + 1 < len(sys.argv):
|
||||
software_filter = sys.argv[i + 1]
|
||||
|
||||
if not API_KEY:
|
||||
logger.error("OPENAI_API_KEY environment variable not set.")
|
||||
@@ -536,9 +547,10 @@ async def main():
|
||||
logger.info(f"Please put software PDF tutorials into subfolders in: {INPUT_FOLDER}")
|
||||
return
|
||||
|
||||
# Find files
|
||||
# Find files (optionally filtered by --software)
|
||||
search_folder = INPUT_FOLDER / software_filter if software_filter else INPUT_FOLDER
|
||||
files = []
|
||||
for root, _, filenames in os.walk(INPUT_FOLDER):
|
||||
for root, _, filenames in os.walk(search_folder):
|
||||
for f in filenames:
|
||||
if Path(f).suffix.lower() in SUPPORTED_EXTENSIONS:
|
||||
files.append(os.path.join(root, f))
|
||||
|
||||
529
evaluation_examples/refine_metadata.py
Normal file
529
evaluation_examples/refine_metadata.py
Normal file
@@ -0,0 +1,529 @@
|
||||
#!/usr/bin/env python3
"""
refine_metadata.py — Refine metadata.steps of existing benchmark JSONs into atomic GUI operation steps.

Overwrites each task file in place; the original steps are preserved under
metadata.steps_original.

Usage:
    # Process every task for avogadro
    python refine_metadata.py --software avogadro

    # Process specific tasks
    python refine_metadata.py --tasks avogadro/building-organic-molecules_task1.json

    # All 7 target software packages
    python refine_metadata.py

    # Force-overwrite tasks that were already refined (those with steps_original)
    python refine_metadata.py --force

    # Preview mode (no API calls)
    python refine_metadata.py --dry-run

Environment variables:
    OPENAI_API_KEY  — API key (required)
    OPENAI_BASE_URL — API base URL (default https://api.openai.com/v1)
    EXTRACT_MODEL   — model name (default gemini-3.1-pro-preview)
"""

import os
import sys
import asyncio
import aiohttp
import base64
import logging
import json
import re
import shutil
import tempfile
import argparse
from pathlib import Path
from typing import List, Optional, Dict, Tuple
from dataclasses import dataclass, field
from datetime import datetime
import io
|
||||
|
||||
# ─── Configuration ───────────────────────────────────────────────────────────

SCRIPT_DIR = Path(__file__).parent
PROJECT_ROOT = SCRIPT_DIR.parent

# OpenAI-compatible chat-completions endpoint; key and model come from the environment.
API_BASE_URL = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
API_URL = f"{API_BASE_URL}/chat/completions"
API_KEY = os.getenv("OPENAI_API_KEY")
MODEL_NAME = os.getenv("EXTRACT_MODEL", "gemini-3.1-pro-preview")

# Request / retry tuning knobs.
MAX_CONCURRENT_REQUESTS = 3   # simultaneous in-flight API calls
MAX_RETRY_ATTEMPTS = 3        # total attempts per request
RETRY_DELAY = 5               # seconds before the first retry
RETRY_BACKOFF = 2             # exponential backoff multiplier between retries
MAX_IMAGES_PER_REQUEST = 20   # cap on tutorial page images attached per request
MAX_TOKENS = 8192             # completion token limit

EXAMPLES_FOLDER = SCRIPT_DIR / "examples"
INPUT_FOLDER = SCRIPT_DIR / "inputs"

# Only these 7 software packages are processed
TARGET_SOFTWARE = {"avogadro", "imagej", "jade", "origin", "ovito", "pymol", "vesta"}
|
||||
|
||||
# ─── Prompt ──────────────────────────────────────────────────────────────────
|
||||
|
||||
SYSTEM_PROMPT = """你是一位桌面科研软件 GUI 自动化专家。你的任务是为一个 **已确定的具体任务** 生成极其详细的、原子级 GUI 操作步骤。
|
||||
|
||||
## 背景
|
||||
这些步骤将作为"教程级操作指南",直接注入给 AI Agent 去操控桌面软件。Agent 会严格按步骤执行,因此:
|
||||
- **每一步必须是一个不可分割的原子 GUI 操作**(单击一个按钮、输入一个值、选择一个菜单项等)
|
||||
- 不能跳步或合并操作
|
||||
- 必须包含所有隐含的中间操作
|
||||
|
||||
## 你将收到
|
||||
1. **教程文档的截图**(PDF 教程页面图片)— 作为操作上下文参考
|
||||
2. **任务目标** (task_goal) — 一句话描述要完成的任务
|
||||
3. **当前粗粒度步骤** (current_steps) — 现有的概要步骤,但粒度太粗
|
||||
|
||||
## 你需要输出
|
||||
根据教程文档内容,将粗粒度步骤细化为**原子级 GUI 操作步骤**。
|
||||
|
||||
## 步骤编写规则
|
||||
|
||||
### 每步必须满足:
|
||||
1. **只包含一个明确的、不可分割的 GUI 原子动作**
|
||||
2. **必须指明**:操作的控件类型(菜单项/按钮/输入框/复选框/下拉菜单/树节点等)
|
||||
3. **必须指明**:控件的标签文字或位置描述
|
||||
4. **必须指明**:操作类型(单击/双击/右键/输入文字/勾选/选择等)
|
||||
5. **包含所有隐含步骤**:不要遗漏任何中间操作
|
||||
|
||||
### ❌ 错误(太粗):
|
||||
- "搜索 benzene 并插入" — 缺少展开文件夹、选中文件、点击插入等中间步骤
|
||||
- "设置平滑参数" — 没说具体修改哪个控件、输入什么值
|
||||
- "在对话框中配置选项" — 没有分解为逐个控件操作
|
||||
|
||||
### ✅ 正确(原子级):
|
||||
- "在\"筛选\"输入框中,将光标定位到输入框,清空已有内容,输入 benzene"
|
||||
- "在筛选结果的树形列表中,双击展开名为 aromatics 的文件夹"
|
||||
- "在展开的列表中,单击选中 benzene.cjson 文件"
|
||||
- "点击对话框底部的\"插入\"按钮"
|
||||
|
||||
## 输出格式
|
||||
只返回一个 JSON 对象:
|
||||
{
|
||||
"steps": "细化后的操作步骤字符串,每步带编号,用换行符分隔"
|
||||
}
|
||||
|
||||
不要添加任何其他文本。只输出 JSON。"""
|
||||
|
||||
|
||||
# ─── Logging ─────────────────────────────────────────────────────────────────

# Single stdout handler so progress lines interleave cleanly with shell output.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ─── Stats ───────────────────────────────────────────────────────────────────
|
||||
|
||||
@dataclass
class Stats:
    """Mutable counters for one refinement run, plus progress/summary logging."""

    total: int = 0
    completed: int = 0
    failed: int = 0
    skipped: int = 0
    retries: int = 0
    start_time: datetime = field(default_factory=datetime.now)
    failures: List[Tuple[str, str]] = field(default_factory=list)

    def log_progress(self):
        """Emit a one-line progress report covering every task handled so far."""
        done = self.completed + self.failed + self.skipped
        pct = 0 if not self.total else done / self.total * 100
        logger.info(f"进度: {done}/{self.total} ({pct:.0f}%) | 成功={self.completed} 失败={self.failed} 跳过={self.skipped}")

    def summary(self):
        """Log final tallies, elapsed wall time, and any recorded failures."""
        elapsed = (datetime.now() - self.start_time).total_seconds()
        logger.info("=" * 60)
        logger.info(f"完成! 总计={self.total} 成功={self.completed} 失败={self.failed} 跳过={self.skipped} 重试={self.retries}")
        logger.info(f"耗时: {int(elapsed//60)}m{int(elapsed%60)}s")
        if not self.failures:
            return
        logger.info("失败列表:")
        for path, err in self.failures:
            logger.info(f" {path}: {err[:100]}")
|
||||
|
||||
|
||||
# ─── Document → Images ──────────────────────────────────────────────────────
|
||||
|
||||
def convert_pdf_to_images(pdf_path: str) -> List[str]:
    """Render a PDF into a list of base64-encoded JPEG page images.

    Documents longer than MAX_IMAGES_PER_REQUEST pages are sampled evenly at
    reduced DPI/quality; shorter ones are rendered in full at higher fidelity.
    Returns an empty list when conversion fails for any reason.
    """
    try:
        from pdf2image import convert_from_path

        # Cheap low-DPI pass whose only purpose is learning the page count.
        probe = convert_from_path(pdf_path, dpi=36, fmt='jpeg')
        total_pages = len(probe)
        del probe

        def encode(img, quality):
            # JPEG-compress one page and return it as a base64 string.
            buf = io.BytesIO()
            img.save(buf, format='JPEG', quality=quality)
            return base64.b64encode(buf.getvalue()).decode()

        if total_pages > MAX_IMAGES_PER_REQUEST:
            # Evenly sample pages across the document (1-based page numbers).
            dpi, quality = 100, 80
            step = total_pages / MAX_IMAGES_PER_REQUEST
            selected = [int(step * i) + 1 for i in range(MAX_IMAGES_PER_REQUEST)]
            logger.info(f" 大 PDF ({total_pages}页): 采样 {len(selected)} 页 @ {dpi} DPI")
            images = []
            for pn in selected:
                page = convert_from_path(pdf_path, dpi=dpi, fmt='jpeg', first_page=pn, last_page=pn)
                if page:
                    images.append(encode(page[0], quality))
            return images

        # Small document: render every page at higher fidelity.
        dpi, quality = 150, 90
        return [encode(img, quality) for img in convert_from_path(pdf_path, dpi=dpi, fmt='jpeg')]
    except Exception as e:
        logger.error(f" PDF转图片失败: {e}")
        return []
|
||||
|
||||
|
||||
# ─── Image Cache ─────────────────────────────────────────────────────────────
|
||||
|
||||
# Per-process memo of already-converted PDFs (path -> base64 page images).
_pdf_image_cache: Dict[str, List[str]] = {}


def get_tutorial_images(pdf_path: str) -> List[str]:
    """Return (and memoize) the page images for a tutorial PDF.

    Conversion is expensive, so each PDF is converted at most once per run;
    subsequent calls for the same path hit the in-memory cache.
    """
    key = str(pdf_path)
    if key not in _pdf_image_cache:
        logger.info(f" 正在转换教程 PDF: {Path(pdf_path).name}")
        pages = convert_pdf_to_images(pdf_path)
        _pdf_image_cache[key] = pages
        logger.info(f" 得到 {len(pages)} 张图片")
    return _pdf_image_cache[key]
|
||||
|
||||
|
||||
# ─── API Call ────────────────────────────────────────────────────────────────
|
||||
|
||||
async def refine_steps_via_api(
    task_goal: str,
    current_steps: str,
    tutorial_images: List[str],
    session: aiohttp.ClientSession,
    stats: Stats
) -> Tuple[str, bool]:
    """Call the LLM to refine coarse steps into fine-grained atomic GUI steps.

    Args:
        task_goal: One-line description of the task to accomplish.
        current_steps: The existing coarse-grained step list to be refined.
        tutorial_images: Base64 JPEG pages of the tutorial PDF; only the first
            MAX_IMAGES_PER_REQUEST are attached.
        session: Shared aiohttp client session used for the HTTP call.
        stats: Run-wide counters; only ``retries`` is incremented here.

    Returns:
        ``(raw_model_content, True)`` on HTTP 200, or ``(error_message, False)``
        after a non-retryable error (413) or once all retry attempts fail.
    """

    user_text = f"""## 任务目标
{task_goal}

## 当前粗粒度步骤(需要细化)
{current_steps}

请根据以上教程截图内容,将粗粒度步骤细化为原子级 GUI 操作步骤。只输出 JSON。"""

    # Multimodal message body: the text prompt first, then each tutorial page
    # as an inline data-URL image (capped to keep the payload bounded).
    content = [{"type": "text", "text": user_text}]
    for img_b64 in tutorial_images[:MAX_IMAGES_PER_REQUEST]:
        content.append({
            "type": "image_url",
            "image_url": {"url": f"data:image/jpeg;base64,{img_b64}"}
        })

    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": content}
    ]

    for attempt in range(1, MAX_RETRY_ATTEMPTS + 1):
        try:
            headers = {
                "Authorization": f"Bearer {API_KEY}",
                "Content-Type": "application/json"
            }
            payload = {
                "model": MODEL_NAME,
                "messages": messages,
                "max_tokens": MAX_TOKENS,
            }

            async with session.post(API_URL, headers=headers, json=payload, timeout=240) as resp:
                if resp.status == 200:
                    data = await resp.json()
                    return data['choices'][0]['message']['content'], True
                else:
                    err = await resp.text()
                    logger.warning(f" API错误 ({resp.status}): {err[:200]}")
                    # 413 (payload too large) cannot improve on retry — bail out now.
                    if resp.status == 413:
                        return "Payload too large", False
        except asyncio.TimeoutError:
            logger.warning(f" API超时 (attempt {attempt})")
        except Exception as e:
            logger.warning(f" API异常: {e}")

        # Exponential backoff: RETRY_DELAY, RETRY_DELAY*RETRY_BACKOFF, ...
        if attempt < MAX_RETRY_ATTEMPTS:
            delay = RETRY_DELAY * (RETRY_BACKOFF ** (attempt - 1))
            logger.info(f" 重试 {attempt}/{MAX_RETRY_ATTEMPTS} (等待 {delay}s)")
            stats.retries += 1
            await asyncio.sleep(delay)

    return "Max retries exceeded", False
|
||||
|
||||
|
||||
def parse_refined_steps(raw_content: str) -> Optional[str]:
    """Extract the 'steps' string from the LLM's JSON response.

    Tries, in order: a JSON object inside a markdown code fence, a bare JSON
    object anywhere in the text, and finally a plain numbered list. Returns
    None when no plausible step list can be recovered.
    """
    fenced = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', raw_content, re.DOTALL)
    if fenced:
        candidate = fenced.group(1)
    else:
        bare = re.search(r'\{.*\}', raw_content, re.DOTALL)
        candidate = bare.group(0) if bare else raw_content

    try:
        steps = json.loads(candidate).get("steps", "")
    except json.JSONDecodeError:
        steps = ""
    # A real step list is comfortably longer than 20 characters.
    if steps and len(steps) > 20:
        return steps

    # Fallback: the model may have emitted numbered lines without JSON.
    if re.search(r'^\d+\.\s', raw_content, re.MULTILINE):
        numbered = [
            line.strip()
            for line in raw_content.strip().split('\n')
            if re.match(r'^\d+\.', line.strip())
        ]
        if len(numbered) >= 2:
            return '\n'.join(numbered)

    return None
|
||||
|
||||
|
||||
# ─── Task Processing ────────────────────────────────────────────────────────
|
||||
|
||||
def find_tutorial_pdf(software: str, task_id: str) -> Optional[Path]:
    """Locate the tutorial PDF that a task was generated from.

    Task ID format: {tutorial_stem}_task{N}
    Tutorial PDF:   inputs/{software}/{tutorial_stem}.pdf
    Returns None when the task id does not follow the pattern or no matching
    PDF exists.
    """
    match = re.match(r'^(.+?)_task\d+$', task_id)
    if match is None:
        return None
    stem = match.group(1)

    software_dir = INPUT_FOLDER / software
    direct = software_dir / f"{stem}.pdf"
    if direct.exists():
        return direct

    # Fallback: scan the folder for a same-stem PDF (covers odd suffix casing).
    if software_dir.exists():
        candidates = (p for p in software_dir.iterdir()
                      if p.suffix.lower() == '.pdf' and p.stem == stem)
        return next(candidates, None)
    return None
|
||||
|
||||
|
||||
async def process_task(
    task_json_path: Path,
    software: str,
    session: aiohttp.ClientSession,
    semaphore: asyncio.Semaphore,
    stats: Stats,
    force: bool = False,
    dry_run: bool = False
):
    """Process a single task JSON: refine its metadata.steps in place.

    Guard chain (each exit updates ``stats`` and logs progress):
    unreadable JSON -> failed; already refined (unless ``force``), missing
    instruction, missing tutorial PDF, or failed PDF conversion -> skipped.
    Otherwise the LLM is called and, on success, the file is rewritten with
    refined steps while the originals move to metadata.steps_original.

    Args:
        task_json_path: Path to the task JSON file (modified in place).
        software: Software folder name; used for logging and PDF lookup.
        session: Shared aiohttp session for API calls.
        semaphore: Concurrency limiter held for the whole task.
        stats: Run-wide counters, updated exactly once per task.
        force: Re-refine tasks that already have steps_original.
        dry_run: Log what would happen without calling the API or writing.
    """
    async with semaphore:
        rel_path = f"{software}/{task_json_path.name}"

        # Load task JSON
        try:
            with open(task_json_path, 'r', encoding='utf-8') as f:
                task_data = json.load(f)
        except Exception as e:
            stats.failed += 1
            stats.failures.append((rel_path, f"JSON读取失败: {e}"))
            stats.log_progress()
            return

        task_id = task_data.get("id", task_json_path.stem)
        instruction = task_data.get("instruction", "")
        metadata = task_data.get("metadata", {})
        current_steps = metadata.get("steps", "")

        # Check if already refined (has steps_original)
        if metadata.get("steps_original") and not force:
            logger.info(f"跳过 (已精炼): {rel_path}")
            stats.skipped += 1
            stats.log_progress()
            return

        if not instruction:
            logger.warning(f"跳过 (无instruction): {rel_path}")
            stats.skipped += 1
            stats.log_progress()
            return

        # Find tutorial PDF
        tutorial_pdf = find_tutorial_pdf(software, task_id)

        # Dry-run reports what would be processed and counts it as completed.
        if dry_run:
            logger.info(f"[DRY RUN] 将处理: {rel_path}")
            logger.info(f" 任务: {instruction[:80]}...")
            logger.info(f" 当前步骤: {current_steps[:80]}...")
            logger.info(f" 教程PDF: {tutorial_pdf or '未找到'}")
            stats.completed += 1
            stats.log_progress()
            return

        if not tutorial_pdf:
            logger.warning(f"跳过 (无教程PDF): {rel_path} (task_id={task_id})")
            stats.skipped += 1
            stats.log_progress()
            return

        # Get tutorial images (cached)
        images = get_tutorial_images(str(tutorial_pdf))
        if not images:
            logger.warning(f"跳过 (PDF转图片失败): {rel_path}")
            stats.skipped += 1
            stats.log_progress()
            return

        # Call API to refine steps
        logger.info(f"正在精炼: {rel_path}")
        raw_response, success = await refine_steps_via_api(
            instruction, current_steps, images, session, stats
        )

        if not success:
            stats.failed += 1
            stats.failures.append((rel_path, raw_response[:200]))
            stats.log_progress()
            return

        # Parse refined steps
        refined_steps = parse_refined_steps(raw_response)
        if not refined_steps:
            stats.failed += 1
            stats.failures.append((rel_path, f"无法解析LLM返回: {raw_response[:200]}"))
            stats.log_progress()
            return

        # Save original steps, update with refined steps
        task_data["metadata"]["steps_original"] = current_steps
        task_data["metadata"]["steps"] = refined_steps

        # Write back to original file
        with open(task_json_path, 'w', encoding='utf-8') as f:
            json.dump(task_data, f, ensure_ascii=False, indent=2)

        # Log step count change (non-empty lines before vs. after refinement)
        step_count_old = len([l for l in current_steps.split('\n') if l.strip()])
        step_count_new = len([l for l in refined_steps.split('\n') if l.strip()])
        logger.info(f" ✓ {rel_path}: {step_count_old}步 → {step_count_new}步")

        stats.completed += 1
        stats.log_progress()
|
||||
|
||||
|
||||
# ─── Main ────────────────────────────────────────────────────────────────────
|
||||
|
||||
def collect_tasks(software_filter: Optional[List[str]] = None, task_filter: Optional[List[str]] = None) -> List[Tuple[Path, str]]:
    """Collect all (task_json_path, software_name) pairs to process.

    Args:
        software_filter: If given, only these software directories are scanned.
        task_filter: If given, only these explicit "software/filename.json"
            entries are returned and ``software_filter`` is ignored.

    Returns:
        A list of (path, software) tuples; missing files/directories are
        logged and omitted rather than raising.
    """
    results = []

    # An explicit task list takes precedence over software-level filtering.
    if task_filter:
        for t in task_filter:
            parts = t.split('/')
            if len(parts) == 2:
                sw, fname = parts
                path = EXAMPLES_FOLDER / sw / fname
                if path.exists():
                    results.append((path, sw))
                else:
                    logger.warning(f"任务文件不存在: {t}")
            else:
                logger.warning(f"任务格式应为 software/filename.json: {t}")
        return results

    # Otherwise walk every requested (or all target) software directories.
    softwares = software_filter if software_filter else sorted(TARGET_SOFTWARE)

    for sw in softwares:
        sw_dir = EXAMPLES_FOLDER / sw
        if not sw_dir.exists():
            logger.warning(f"软件目录不存在: {sw}")
            continue
        for f in sorted(sw_dir.glob("*.json")):
            results.append((f, sw))

    return results
|
||||
|
||||
|
||||
async def main():
    """CLI entry point: parse args, collect tasks, refine them concurrently,
    then write a JSON processing report next to the examples."""
    parser = argparse.ArgumentParser(description="Refine benchmark metadata.steps to fine-grained GUI operations")
    parser.add_argument('--software', nargs='*', help='Only process these software names (e.g. avogadro jade)')
    parser.add_argument('--tasks', nargs='*', help='Only process specific tasks (e.g. avogadro/building-organic-molecules_task1.json)')
    parser.add_argument('--force', action='store_true', help='Overwrite already-refined files (those with steps_original)')
    parser.add_argument('--dry-run', action='store_true', help='Preview mode, no API calls')
    args = parser.parse_args()

    # Dry-run mode needs no credentials; everything else requires an API key.
    if not API_KEY and not args.dry_run:
        logger.error("请设置 OPENAI_API_KEY 环境变量")
        return

    # Collect tasks
    tasks = collect_tasks(args.software, args.tasks)
    if not tasks:
        logger.error("未找到需要处理的任务文件")
        return

    stats = Stats(total=len(tasks))

    # Run banner with the effective configuration.
    logger.info("=" * 60)
    logger.info("Refine Metadata — 细粒度步骤生成")
    logger.info("=" * 60)
    logger.info(f"模型: {MODEL_NAME}")
    logger.info(f"数据目录: {EXAMPLES_FOLDER}")
    logger.info(f"教程目录: {INPUT_FOLDER}")
    logger.info(f"任务数: {len(tasks)}")
    logger.info(f"并发数: {MAX_CONCURRENT_REQUESTS}")
    if args.force:
        logger.info("模式: 强制覆盖")
    if args.dry_run:
        logger.info("模式: DRY RUN (不调用API)")
    logger.info("=" * 60)

    # Show distribution of tasks per software
    from collections import Counter
    dist = Counter(sw for _, sw in tasks)
    for sw, cnt in sorted(dist.items()):
        logger.info(f" {sw}: {cnt} tasks")

    # Bound API concurrency across all task coroutines.
    semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)

    async with aiohttp.ClientSession() as session:
        coros = [
            process_task(path, sw, session, semaphore, stats, args.force, args.dry_run)
            for path, sw in tasks
        ]
        await asyncio.gather(*coros)

    stats.summary()

    # Save processing report
    report = {
        "timestamp": datetime.now().isoformat(),
        "model": MODEL_NAME,
        "total": stats.total,
        "completed": stats.completed,
        "failed": stats.failed,
        "skipped": stats.skipped,
        "failures": [{"file": f, "error": e} for f, e in stats.failures]
    }
    report_path = EXAMPLES_FOLDER / "refine_report.json"
    with open(report_path, 'w', encoding='utf-8') as f:
        json.dump(report, f, ensure_ascii=False, indent=2)
    logger.info(f"处理报告: {report_path}")


if __name__ == "__main__":
    asyncio.run(main())
|
||||
Reference in New Issue
Block a user