202 lines
6.3 KiB
Python
202 lines
6.3 KiB
Python
import asyncio
|
|
from typing import List, Optional, Union, Dict, Any
|
|
import json
|
|
import os
|
|
import hashlib
|
|
from pathlib import Path
|
|
from omegaconf import DictConfig
|
|
from dataclasses import dataclass, asdict
|
|
import copy
|
|
import logging
|
|
import random
|
|
|
|
from prompts import COMPUTER_USE_PROMPT, COMPUTER_USE_PROMPT_WITH_CALL_USER
|
|
from log_config import setup_logging
|
|
|
|
# Set up the unified logging system
|
|
setup_logging()
|
|
logger = logging.getLogger(__name__)
|
|
|
|
class TaskLoader:
    """Load task definitions from a JSON task file and build ``TaskInfo`` items.

    The task file is a JSON object mapping a task type to a list of task ids.
    Each ``(task_type, task_id)`` pair is turned into a task via ``build_task``.
    When ``resume`` is enabled, tasks that already have a ``reward.txt`` in a
    matching run directory under ``storage_root`` are skipped.
    """

    # The annotation is quoted so it is not evaluated when the class is defined.
    def __init__(self, task_cfg: "DictConfig", storage_root):
        """Read paths and flags from ``task_cfg``.

        Args:
            task_cfg: Config providing ``task_file``, ``osworld_root`` and
                ``resume``; ``task_root`` is optional.
            storage_root: Directory holding one sub-directory per task run.
        """
        self.task_file = Path(task_cfg.task_file)
        # ``task_root`` is only needed by the legacy "latest JSON" lookup, so
        # it is optional; without it ``_find_latest_json`` reports no file.
        task_root = getattr(task_cfg, "task_root", None)
        self.task_root: Optional[Path] = Path(task_root) if task_root else None
        self.osworld_root = Path(task_cfg.osworld_root)

        self._latest_sha: Optional[str] = None
        self._tasks: List = []
        self.storage_root = storage_root
        self.resume = task_cfg.resume

    def poll_for_tasks(self) -> List[Dict]:
        """Reload the task file and return all pending tasks in random order.

        Returns:
            A shuffled list with one ``to_dict()`` result per loaded task.
        """
        self._maybe_refresh_dataset()

        tasks_list = [task.to_dict() for task in self._tasks]
        random.shuffle(tasks_list)
        return tasks_list

    def _maybe_refresh_dataset_bak(self):
        """Legacy refresh: reload only when the newest JSON in ``task_root`` changed.

        Returns:
            ``True`` if tasks were reloaded, ``False`` if there is no JSON file
            or its SHA-1 equals the one seen on the previous reload.
        """
        latest_json = self._find_latest_json()
        if latest_json is None:
            return False  # no json file

        sha = self._calc_sha1(latest_json)
        if sha == self._latest_sha:
            return False  # no change

        with open(latest_json, encoding="utf-8") as f:
            data = json.load(f)

        raw_tasks = self._expand_raw_tasks(data)
        self._tasks = [build_task(raw, self.osworld_root) for raw in raw_tasks]
        self._latest_sha = sha

        logger.info(f"当前任务文件: {str(latest_json)}")
        logger.info(f"任务总数: {len(raw_tasks)}")

        return True

    def _maybe_refresh_dataset(self):
        """Reload ``self.task_file`` and rebuild ``self._tasks``.

        With ``self.resume`` set, tasks already finished under
        ``self.storage_root`` are left out.

        Returns:
            Always ``True``.
        """
        latest_json = self.task_file
        print("Current tasks file: ", str(latest_json))

        with open(latest_json, encoding="utf-8") as f:
            data = json.load(f)

        raw_tasks = self._expand_raw_tasks(data)

        if self.resume:
            # Drop tasks that already have a recorded reward.
            remaining = self._filter_unfinished(raw_tasks)
            self._tasks = [build_task(raw, self.osworld_root) for raw in remaining]
            print(f"Total number of tasks: {len(raw_tasks)}, Remained:{len(remaining)}")
        else:
            self._tasks = [build_task(raw, self.osworld_root) for raw in raw_tasks]
            print(f"Total number of tasks: {len(raw_tasks)}")

        return True

    @staticmethod
    def _expand_raw_tasks(data: Dict) -> List[Dict]:
        """Flatten ``{task_type: [task_id, ...]}`` into a list of raw task dicts."""
        return [
            {"task_type": task_type, "task_id": task_id}
            for task_type, task_ids in data.items()
            for task_id in task_ids
        ]

    def _filter_unfinished(self, raw_tasks: List[Dict]) -> List[Dict]:
        """Return the raw tasks that have no finished run under ``storage_root``."""
        storage_root = Path(self.storage_root)
        # List the storage directory once instead of once per task. A missing
        # directory (e.g. the very first run) simply means nothing is finished.
        if storage_root.is_dir():
            run_dirs = [d for d in storage_root.iterdir() if d.is_dir()]
        else:
            run_dirs = []
        return [raw for raw in raw_tasks if not self._is_task_finished(raw, run_dirs)]

    @staticmethod
    def _is_task_finished(raw: Dict, run_dirs: List[Path]) -> bool:
        """Tell whether any run directory records a finished run of ``raw``.

        A directory counts when its name starts with the task id (several
        versions per task are allowed), its ``task_config.json`` records the
        same task type (and the same task id, when one is recorded), and it
        contains ``reward.txt``.
        """
        task_id = str(raw["task_id"])
        task_type_expected = raw["task_type"]

        for d in run_dirs:
            if not d.name.startswith(task_id):
                continue

            cfg_path = d / "task_config.json"
            if not cfg_path.exists():
                print("找不到config文件")
                continue

            try:
                with cfg_path.open("r", encoding="utf-8") as cf:
                    cfg = json.load(cf)
            except (OSError, ValueError):
                # Unreadable or corrupt config: ignore this directory.
                print("配置损坏,忽略此目录")
                continue

            recorded = cfg.get("raw") if isinstance(cfg, dict) else None
            if not isinstance(recorded, dict):
                continue

            # Different task_type => not the same task, skip this directory.
            if recorded.get("task_type") != task_type_expected:
                continue

            # The name check is only a prefix match, so a directory of another
            # task whose id shares the prefix must not count as this task.
            recorded_id = recorded.get("task_id")
            if recorded_id is not None and str(recorded_id) != task_id:
                continue

            # Same task: a reward file marks the run as finished.
            if (d / "reward.txt").exists():
                return True

        return False

    def _find_latest_json(self) -> Optional[Path]:
        """Return the most recently modified ``*.json`` in ``task_root``.

        Returns:
            The newest JSON path, or ``None`` when ``task_root`` is not
            configured or contains no JSON file.
        """
        task_root = getattr(self, "task_root", None)
        if task_root is None:
            return None
        files = list(Path(task_root).glob("*.json"))
        return max(files, key=lambda p: p.stat().st_mtime) if files else None

    @staticmethod
    def _calc_sha1(fp: Path, chunk_size=2<<20) -> str:
        """Return the hex SHA-1 digest of ``fp``, read in ``chunk_size``-byte chunks."""
        h = hashlib.sha1()
        with fp.open("rb") as f:
            for chunk in iter(lambda: f.read(chunk_size), b""):
                h.update(chunk)
        return h.hexdigest()
|
|
|
|
|
|
@dataclass
class TaskInfo:
    """A single prepared task: chat messages, instruction text and task config."""

    # Chat messages to send for this task.
    messages: List
    # Instruction text for this task.
    instruction: str
    # Task configuration dictionary.
    task_config: Dict

    def to_dict(self):
        """Return the fields of this task as a plain (recursively copied) dict."""
        as_plain_dict = asdict(self)
        return as_plain_dict
|
|
|
|
|
|
def build_task(raw: Dict, osworld_root: Path, use_call_user: bool = False) -> TaskInfo:
    """Load one task definition and wrap it into a ``TaskInfo``.

    The task JSON is read from ``<osworld_root>/<task_type>/<task_id>.json``.
    The ``raw`` identifiers are stored in the loaded config under the ``"raw"``
    key. When the config carries a ``human-ground-truth`` / ``single-action``
    plan, that plan is appended to the instruction.

    Args:
        raw: Dict with ``"task_type"`` and ``"task_id"``.
        osworld_root: Root directory containing one sub-directory per task type.
        use_call_user: Select ``COMPUTER_USE_PROMPT_WITH_CALL_USER`` instead of
            ``COMPUTER_USE_PROMPT`` as the prompt template.

    Returns:
        A ``TaskInfo`` holding the chat messages, the final instruction text
        and the loaded task config.

    Raises:
        KeyError: If ``raw`` or the task file lacks a required key.
        OSError: If the task file cannot be opened.
    """
    task_type = raw["task_type"]
    task_id = raw["task_id"]
    # Format the id instead of concatenating so non-string ids also work.
    task_path = os.path.join(osworld_root, task_type, f"{task_id}.json")
    with open(task_path, encoding="utf-8") as f:
        task_data = json.load(f)

    task_data["raw"] = {
        "task_type": task_type,
        "task_id": task_id,
    }

    instruction = task_data["instruction"]

    ground_truth = task_data.get("human-ground-truth")
    if isinstance(ground_truth, dict) and "single-action" in ground_truth:
        plan = ground_truth["single-action"]
        # A plan given as one string is used verbatim; joining it would put
        # every character on its own line.
        if isinstance(plan, str):
            plan_text = plan
        else:
            plan_text = "\n".join(str(step) for step in plan)
        instruction = instruction.strip() + "\nHere is an instruction to help you complete the task: \n" + plan_text

    system_prompt = COMPUTER_USE_PROMPT_WITH_CALL_USER if use_call_user else COMPUTER_USE_PROMPT
    # The formatted prompt template is sent as the user turn; the system turn
    # is a fixed generic message.
    messages = [
        {
            "role": "system",
            "content": "You are a helpful assistant.",
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": system_prompt.format(
                        instruction=instruction,
                        language="English",
                    ),
                }
            ],
        },
    ]

    return TaskInfo(
        messages=messages,
        instruction=instruction,
        task_config=task_data,
    )