Files
sci-gui-agent-benchmark/scripts/tools/batch_create_tasks.py
2026-01-12 18:30:12 +08:00

214 lines
7.3 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
批量创建任务
从JSON定义文件批量创建任务
用法:
python scripts/tools/batch_create_tasks.py tasks/batch_definitions/basic_processing_tasks.json
"""
import json
import os
import sys
from pathlib import Path
# Put the project root on sys.path so the `scripts` package resolves when this
# file is run directly as a script. The root is three levels above this file.
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))
# This import must come after the sys.path tweak above.
from scripts.tools.init_task import init_task
def _apply_task_definition(task_config, task_def, desktop_dir):
    """Copy the I/O, evaluation and metadata fields of one definition into a task config.

    Mutates ``task_config`` in place.

    Args:
        task_config: Dict loaded from the task's ``task.json``.
        task_def: One entry of the definition file's ``tasks`` list.
        desktop_dir: Windows directory used as the prefix of the ``inject_to``
            and ``collect_from`` paths.
    """
    source_file = task_def.get("source_file", "DEMO01.MDI")
    # Bare names are rewritten to point into data/source; absolute paths and
    # paths already starting with "../" are kept as given.
    # NOTE(review): the "../../" prefix presumably resolves from tasks/<task_id>/ - confirm.
    if not os.path.isabs(source_file) and not source_file.startswith("../"):
        source_file = f"../../data/source/{source_file}"

    # Tolerate a trailing separator so the joined paths never contain "\\\\".
    desktop = desktop_dir.rstrip("\\/")
    output_filename = task_def.get("output_filename", "result.txt")

    task_config["input"] = {
        "source_file": source_file,
        "inject_to": f"{desktop}\\{os.path.basename(source_file)}",
    }
    task_config["output"] = {
        "expected_file": output_filename,
        "collect_from": f"{desktop}\\{output_filename}",
    }

    # setdefault: a generated config without an "evaluation" section should not
    # turn an otherwise valid task into a failure.
    task_config.setdefault("evaluation", {})["method"] = task_def.get(
        "evaluation_method", "xrd_data_compare"
    )

    # Optional metadata is copied only when present and non-empty.
    for key in ("tutorial_source", "notes"):
        if task_def.get(key):
            task_config[key] = task_def[key]


def batch_create_tasks(definition_file, project_root=".", force=False, skip_existing=True,
                       desktop_dir="C:\\Users\\lzy\\Desktop"):
    """Create every task listed in a JSON definition file.

    For each entry of the file's ``tasks`` list this calls ``init_task`` and
    then rewrites the generated ``tasks/<id>/task.json`` with the entry's
    input/output, evaluation method and optional metadata. A failure in one
    task is recorded and the batch moves on to the next task.

    Args:
        definition_file: Path to the JSON definition file. Its top-level keys
            ``tasks``, ``category`` and ``tutorial_source`` are read.
        project_root: Project root directory; tasks live in ``<root>/tasks/<id>``.
        force: When true, existing task directories are passed on to
            ``init_task`` with ``force=True`` instead of being skipped/failed.
        skip_existing: Only consulted when ``force`` is false. True records an
            existing task as skipped; False records it as failed.
        desktop_dir: Windows directory used to build the ``inject_to`` and
            ``collect_from`` paths written into each task.json.

    Returns:
        Dict with keys ``success`` (list of task ids), ``skipped`` (list of
        task ids) and ``failed`` (list of ``(task_id, reason)`` tuples).

    Raises:
        OSError: If the definition file cannot be opened.
        json.JSONDecodeError: If the definition file is not valid JSON.
    """
    import traceback

    with open(definition_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # "tasks": null is treated like a missing list.
    tasks = data.get("tasks") or []
    category = data.get("category", "unknown")
    tutorial_source = data.get("tutorial_source", "")

    print("=" * 60)
    print(f"📚 批量创建任务: {category}")
    if tutorial_source:
        print(f" 教程来源: {tutorial_source}")
    print(f" 任务数量: {len(tasks)}")
    print("=" * 60)

    results = {
        "success": [],
        "skipped": [],
        "failed": []
    }

    for i, task_def in enumerate(tasks, 1):
        # A definition without a usable id must not abort the whole batch:
        # record it as failed and keep going.
        task_id = task_def.get("id") if isinstance(task_def, dict) else None
        if not task_id:
            print(f"\n[{i}/{len(tasks)}] ❌ 任务定义缺少 id")
            results["failed"].append((f"<#{i}>", "缺少任务 id"))
            continue

        print(f"\n[{i}/{len(tasks)}] 处理任务: {task_id}")
        print("-" * 60)

        task_dir = os.path.join(project_root, "tasks", task_id)
        if os.path.exists(task_dir) and not force:
            if skip_existing:
                print(f"⏭️ 跳过(已存在): {task_id}")
                results["skipped"].append(task_id)
            else:
                print(f"⚠️ 任务已存在: {task_id}")
                print(" 使用 --force 强制覆盖,或设置 skip_existing=True")
                results["failed"].append((task_id, "已存在"))
            continue

        # Per-task boundary: anything that goes wrong below is reported with a
        # traceback and counted as a failure for this task only.
        try:
            success = init_task(
                task_id=task_id,
                project_root=project_root,
                force=force,
                category=task_def.get("category", category),
                difficulty=task_def.get("difficulty", "easy"),
                instruction=task_def.get("instruction", "")
            )
            if not success:
                results["failed"].append((task_id, "初始化失败"))
                continue

            task_json_path = os.path.join(task_dir, "task.json")
            if not os.path.exists(task_json_path):
                print(f"❌ task.json 未创建: {task_json_path}")
                results["failed"].append((task_id, "配置文件未创建"))
                continue

            with open(task_json_path, 'r', encoding='utf-8') as f:
                task_config = json.load(f)

            _apply_task_definition(task_config, task_def, desktop_dir)

            with open(task_json_path, 'w', encoding='utf-8') as f:
                json.dump(task_config, f, ensure_ascii=False, indent=2)

            print(f"✅ 任务创建成功: {task_id}")
            results["success"].append(task_id)
        except Exception as e:
            print(f"❌ 创建任务失败: {task_id}")
            print(f" 错误: {e}")
            traceback.print_exc()
            results["failed"].append((task_id, str(e)))

    # Summary
    print("\n" + "=" * 60)
    print("📊 批量创建总结")
    print("=" * 60)
    print(f"✅ 成功: {len(results['success'])}")
    for task_id in results["success"]:
        print(f" - {task_id}")
    if results["skipped"]:
        print(f"\n⏭️ 跳过: {len(results['skipped'])}")
        for task_id in results["skipped"]:
            print(f" - {task_id}")
    if results["failed"]:
        print(f"\n❌ 失败: {len(results['failed'])}")
        for task_id, reason in results["failed"]:
            print(f" - {task_id}: {reason}")
    print("=" * 60)

    return results
def main():
    """Command-line entry point.

    Parses the arguments, checks that the definition file exists, runs
    ``batch_create_tasks`` and terminates the process via ``sys.exit``:
    status 1 when the definition file is missing or any task failed,
    status 0 otherwise.
    """
    import argparse

    # Kept at column 0 because RawDescriptionHelpFormatter prints it verbatim.
    usage_examples = """
使用示例:
# 批量创建任务(跳过已存在的)
python scripts/tools/batch_create_tasks.py tasks/batch_definitions/basic_processing_tasks.json
# 强制覆盖已存在的任务
python scripts/tools/batch_create_tasks.py tasks/batch_definitions/basic_processing_tasks.json --force
# 不跳过已存在的任务(遇到已存在就失败)
python scripts/tools/batch_create_tasks.py tasks/batch_definitions/basic_processing_tasks.json --no-skip-existing
"""

    cli = argparse.ArgumentParser(
        description="批量创建任务",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=usage_examples,
    )
    cli.add_argument("definition_file", help="任务定义JSON文件路径")
    cli.add_argument("--project-root", default=".", help="项目根目录")
    cli.add_argument("--force", action="store_true", help="强制覆盖已存在的任务")
    cli.add_argument("--no-skip-existing", action="store_true", help="不跳过已存在的任务(遇到就失败)")
    opts = cli.parse_args()

    # Guard clause: nothing to do without a definition file.
    if not os.path.exists(opts.definition_file):
        print(f"❌ 定义文件不存在: {opts.definition_file}")
        sys.exit(1)

    outcome = batch_create_tasks(
        opts.definition_file,
        project_root=opts.project_root,
        force=opts.force,
        skip_existing=not opts.no_skip_existing,
    )

    # Non-zero exit status when at least one task failed.
    sys.exit(1 if outcome["failed"] else 0)


if __name__ == "__main__":
    main()