# Source: sci-gui-agent-benchmark/mm_agents/dart_gui/task_loader.py
# Snapshot: 2025-11-07 21:50:01 +08:00 (202 lines, 6.3 KiB, Python)

import asyncio
from typing import List, Optional, Union, Dict, Any
import json
import os
import hashlib
from pathlib import Path
from omegaconf import DictConfig
from dataclasses import dataclass, asdict
import copy
import logging
import random
from prompts import COMPUTER_USE_PROMPT, COMPUTER_USE_PROMPT_WITH_CALL_USER
from log_config import setup_logging
# Apply the project-wide logging configuration once, at import time,
# then obtain this module's named logger.
setup_logging()
logger = logging.getLogger(name=__name__)
class TaskLoader:
    """Load task descriptors from a JSON index file and build task dicts.

    The index file maps each ``task_type`` to a list of ``task_id`` values.
    Every entry is turned into a TaskInfo by ``build_task``. When ``resume``
    is enabled, tasks that already have a finished run under ``storage_root``
    are skipped.
    """

    def __init__(self, task_cfg: DictConfig, storage_root):
        """
        Args:
            task_cfg: config that provides ``task_file``, ``osworld_root`` and
                ``resume``. The optional ``task_root`` directory is only used
                by the legacy ``_maybe_refresh_dataset_bak`` path.
            storage_root: directory holding per-run result folders. It is
                scanned only when ``resume`` is true.
        """
        self.task_file = Path(task_cfg.task_file)
        # task_root is optional: only the legacy "newest JSON in a directory"
        # mode needs it, and it may be absent from the config.
        task_root = getattr(task_cfg, "task_root", None)
        self.task_root: Optional[Path] = Path(task_root) if task_root else None
        self.osworld_root = Path(task_cfg.osworld_root)
        self._latest_sha: Optional[str] = None
        self._tasks: List["TaskInfo"] = []
        self.storage_root = storage_root
        self.resume = task_cfg.resume

    def poll_for_tasks(self) -> List[Dict]:
        """Reload the task file and return all (remaining) tasks, shuffled.

        The task file is re-read on every call; there is no change detection
        here. With ``resume`` on, already-finished tasks are excluded. The
        order is randomized with the module-level ``random`` state.

        Returns:
            A list of ``TaskInfo.to_dict()`` results. It is empty when the
            file lists no tasks or all of them are finished.
        """
        self._maybe_refresh_dataset()
        tasks_list = [task.to_dict() for task in self._tasks]
        random.shuffle(tasks_list)
        return tasks_list

    @staticmethod
    def _load_raw_tasks(task_file: Path) -> List[Dict]:
        """Flatten a ``{task_type: [task_id, ...]}`` JSON index into raw-task dicts."""
        with open(task_file, encoding="utf-8") as f:
            data = json.load(f)
        return [
            {"task_type": task_type, "task_id": task_id}
            for task_type, task_ids in data.items()
            for task_id in task_ids
        ]

    def _maybe_refresh_dataset_bak(self):
        """Legacy loader: reload from the newest JSON in ``task_root`` if it changed.

        Returns:
            True if tasks were (re)loaded. False if there is no JSON file, or
            if the newest file's SHA-1 equals the last one loaded.
        """
        latest_json = self._find_latest_json()
        if latest_json is None:
            return False  # no JSON file to load
        sha = self._calc_sha1(latest_json)
        if sha == self._latest_sha:
            return False  # content unchanged since the last load
        raw_tasks = self._load_raw_tasks(latest_json)
        self._tasks = [build_task(raw, self.osworld_root) for raw in raw_tasks]
        self._latest_sha = sha
        logger.info("当前任务文件: %s", latest_json)
        logger.info("任务总数: %d", len(raw_tasks))
        return True

    def _maybe_refresh_dataset(self):
        """Load ``task_file`` unconditionally and rebuild ``self._tasks``.

        With ``resume`` enabled, tasks that already finished under
        ``storage_root`` are filtered out first.

        Returns:
            Always True.
        """
        latest_json = self.task_file
        logger.info("Current tasks file: %s", latest_json)
        raw_tasks = self._load_raw_tasks(latest_json)
        if self.resume:
            remaining = self._filter_unfinished(raw_tasks)
            self._tasks = [build_task(raw, self.osworld_root) for raw in remaining]
            logger.info(
                "Total number of tasks: %d, Remained:%d", len(raw_tasks), len(remaining)
            )
        else:
            self._tasks = [build_task(raw, self.osworld_root) for raw in raw_tasks]
            logger.info("Total number of tasks: %d", len(raw_tasks))
        return True

    def _filter_unfinished(self, raw_tasks: List[Dict]) -> List[Dict]:
        """Return the raw tasks that have no finished run under ``storage_root``."""
        storage_root = Path(self.storage_root)
        if not storage_root.is_dir():
            # Nothing has been stored yet (e.g. first run): every task is pending.
            return list(raw_tasks)
        # List the run directories once instead of re-scanning for every task.
        run_dirs = [d for d in storage_root.iterdir() if d.is_dir()]
        return [raw for raw in raw_tasks if not self._is_task_finished(raw, run_dirs)]

    @staticmethod
    def _is_task_finished(raw: Dict, run_dirs: List[Path]) -> bool:
        """Tell whether any run directory holds a finished run of ``raw``.

        A directory counts only if all of the following hold:
          * its name starts with the task id;
          * its ``task_config.json`` records the same ``raw.task_type``, and
            the same ``raw.task_id`` when one is recorded;
          * it contains ``reward.txt``.

        NOTE(review): this assumes run folders store the task config produced
        by ``build_task`` (which adds the ``"raw"`` key). Verify against the
        code that writes ``storage_root``.
        """
        task_id = str(raw["task_id"])
        task_type_expected = raw["task_type"]
        for d in run_dirs:
            if not d.name.startswith(task_id):
                continue
            cfg_path = d / "task_config.json"
            if not cfg_path.exists():
                logger.warning("找不到config文件")
                continue
            try:
                with cfg_path.open("r", encoding="utf-8") as cf:
                    cfg = json.load(cf)
            except (OSError, ValueError):
                # Unreadable or malformed config (ValueError covers JSON and
                # Unicode decode errors): ignore this directory.
                logger.warning("配置损坏,忽略此目录")
                continue
            recorded = cfg.get("raw") if isinstance(cfg, dict) else None
            if not isinstance(recorded, dict):
                continue  # no usable task identity recorded
            if recorded.get("task_type") != task_type_expected:
                continue  # different task type => not the same task
            recorded_id = recorded.get("task_id")
            if recorded_id is not None and str(recorded_id) != task_id:
                continue  # name prefix matched a different task id
            if (d / "reward.txt").exists():
                return True  # a finished run exists; no need to check further
        return False

    def _find_latest_json(self) -> Optional[Path]:
        """Return the most recently modified ``*.json`` in ``task_root``, or None."""
        if self.task_root is None:
            return None
        files = list(self.task_root.glob("*.json"))
        return max(files, key=lambda p: p.stat().st_mtime) if files else None

    @staticmethod
    def _calc_sha1(fp: Path, chunk_size=2 << 20) -> str:
        """Return the hex SHA-1 of ``fp``'s contents, read in 2 MiB chunks by default."""
        h = hashlib.sha1()
        with fp.open("rb") as f:
            for chunk in iter(lambda: f.read(chunk_size), b""):
                h.update(chunk)
        return h.hexdigest()
@dataclass
class TaskInfo:
    """One prepared task: its chat messages, instruction text and task config."""

    # Chat-style message dicts handed to the agent.
    messages: List
    # Final instruction text used in the prompt.
    instruction: str
    # Loaded task JSON (build_task adds a "raw" entry with type/id).
    task_config: Dict

    def to_dict(self) -> Dict:
        """Convert to a plain dict using ``dataclasses.asdict``.

        ``asdict`` recursively copies the nested containers.
        """
        plain = asdict(self)
        return plain
def build_task(raw: Dict, osworld_root: Path, use_call_user: bool = False) -> TaskInfo:
    """Build a TaskInfo for one raw task entry.

    The function:
      * reads ``<osworld_root>/<task_type>/<task_id>.json``;
      * records the raw identifiers under the config's ``"raw"`` key;
      * appends the ``human-ground-truth`` / ``single-action`` plan to the
        instruction when present;
      * renders the user prompt from the selected template.

    Args:
        raw: dict with ``"task_type"`` and ``"task_id"`` keys.
        osworld_root: root directory with one sub-directory per task type.
        use_call_user: use ``COMPUTER_USE_PROMPT_WITH_CALL_USER`` instead of
            ``COMPUTER_USE_PROMPT``.

    Returns:
        TaskInfo holding the chat messages, the (possibly augmented)
        instruction, and the loaded task config.

    Raises:
        FileNotFoundError: if the task JSON does not exist.
        KeyError: if ``raw`` or the task JSON lacks a required key.
    """
    task_type = raw["task_type"]
    task_id = raw["task_id"]
    # f-string so non-string ids (e.g. ints in the index JSON) still form a file name.
    task_path = Path(osworld_root) / task_type / f"{task_id}.json"
    # Explicit UTF-8: instructions may contain non-ASCII text, and the locale
    # default encoding differs across platforms.
    with open(task_path, encoding="utf-8") as f:
        task_data = json.load(f)
    task_data["raw"] = {
        "task_type": task_type,
        "task_id": task_id,
    }
    instruction = task_data["instruction"]
    if "human-ground-truth" in task_data and "single-action" in task_data["human-ground-truth"]:
        plan = task_data["human-ground-truth"]["single-action"]
        # Accept a list of steps or one pre-joined string. Joining a str
        # would put a newline between every character.
        plan_text = plan if isinstance(plan, str) else "\n".join(plan)
        instruction = instruction.strip() + "\nHere is an instruction to help you complete the task: \n" + plan_text
    system_prompt = COMPUTER_USE_PROMPT_WITH_CALL_USER if use_call_user else COMPUTER_USE_PROMPT
    messages = [
        {
            "role": "system",
            "content": "You are a helpful assistant.",
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": system_prompt.format(
                        instruction=instruction,
                        language="English",
                    ),
                }
            ],
        },
    ]
    return TaskInfo(
        messages=messages,
        instruction=instruction,
        task_config=task_data,
    )