Files
sci-gui-agent-benchmark/manual_examine.py
2025-07-13 12:41:27 +00:00

317 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from __future__ import annotations
import argparse
import datetime
import json
import logging
import os
import sys
import signal
import time
from typing import List, Dict
from tqdm import tqdm
from desktop_env.desktop_env import DesktopEnv
# Global variables for signal handling, shared by main() and signal_handler():
# the DesktopEnv currently in use (None until main() creates it).
active_environment = None
# Set once signal_handler() has begun shutting down, so that further signals
# return immediately instead of repeating the shutdown sequence.
is_terminating = False

# Load environment variables from ./.env (relative to the current working
# directory) when that file exists. The path is passed explicitly so the file
# loaded is the one the existence check looked at: without an argument,
# python-dotenv locates the file via find_dotenv(), which by default starts
# from the calling script's directory rather than the working directory.
if os.path.exists(".env"):
    from dotenv import load_dotenv

    load_dotenv(".env")
# Logger Configs {{{ #
def config() -> argparse.Namespace:
    """Define the command-line interface and parse ``sys.argv``.

    ``--domain`` and ``--example_id`` are mandatory; every other option has a
    default. argparse prints usage and exits the process on invalid or
    missing options.

    Returns:
        An argparse.Namespace with one attribute per option defined below.
    """
    cli = argparse.ArgumentParser(
        description="Manual examination of benchmark tasks"
    )
    add = cli.add_argument

    # Environment / VM options.
    add("--path_to_vm", type=str, default=None)
    add("--headless", action="store_true", help="Run in headless machine")
    add("--action_space", type=str, default="pyautogui", help="Action type")
    add(
        "--observation_type",
        choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"],
        default="screenshot",
        help="Observation type",
    )
    add("--screen_width", type=int, default=1920)
    add("--screen_height", type=int, default=1080)
    add("--sleep_after_execution", type=float, default=0.0)
    add("--max_steps", type=int, default=15)

    # Agent options.
    add("--max_trajectory_length", type=int, default=3)
    add("--test_config_base_dir", type=str, default="evaluation_examples")

    # Which task to examine.
    add("--domain", type=str, required=True, help="Specific domain to examine")
    add("--example_id", type=str, required=True, help="Specific example ID to examine")
    add("--test_all_meta_path", type=str, default="evaluation_examples/test_all.json")

    # Output and logging.
    add("--result_dir", type=str, default="./results_manual")
    add(
        "--log_level",
        type=str,
        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
        default="INFO",
        help="Set the logging level",
    )

    # Cloud / VM provider options.
    add("--region", type=str, default="us-east-1", help="AWS region for the VM")
    add(
        "--provider_name",
        type=str,
        default="aws",
        choices=["aws", "virtualbox", "vmware", "docker", "azure"],
        help="Provider name",
    )
    add("--client_password", type=str, default="", help="Client password")

    return cli.parse_args()
args = config()  # Get command line arguments first: --log_level is needed below

# Configure the root logger so records from all propagating loggers reach the
# handlers below. Its level is --log_level, so records below that level are
# discarded before any handler sees them.
logger = logging.getLogger()
log_level = getattr(logging, args.log_level.upper())
logger.setLevel(log_level)

# logging.FileHandler opens its file immediately and does not create parent
# directories; without this, running from a checkout that has no ./logs
# directory fails at import time with FileNotFoundError.
os.makedirs("logs", exist_ok=True)

datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
file_handler = logging.FileHandler(
    os.path.join("logs", "manual-{:}.log".format(datetime_str)), encoding="utf-8"
)
debug_handler = logging.FileHandler(
    os.path.join("logs", "manual-debug-{:}.log".format(datetime_str)), encoding="utf-8"
)
stdout_handler = logging.StreamHandler(sys.stdout)

file_handler.setLevel(logging.INFO)
# NOTE(review): because the root logger's level is --log_level, this handler
# only receives DEBUG records when --log_level DEBUG is passed.
debug_handler.setLevel(logging.DEBUG)
stdout_handler.setLevel(log_level)

# NOTE: the format embeds ANSI colour escapes; they end up in the log files
# as well as on the terminal.
formatter = logging.Formatter(
    fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
)
file_handler.setFormatter(formatter)
debug_handler.setFormatter(formatter)
stdout_handler.setFormatter(formatter)

# Only records from the "desktopenv" logger hierarchy are echoed to stdout;
# the two log files receive records from every logger that propagates to root.
stdout_handler.addFilter(logging.Filter("desktopenv"))

logger.addHandler(file_handler)
logger.addHandler(debug_handler)
logger.addHandler(stdout_handler)
# }}} Logger Configs #

# Logger used by the rest of this script. It is a child of "desktopenv", so
# its records pass the stdout filter above.
logger = logging.getLogger("desktopenv.experiment")
def setup_example_logger(example, example_result_dir):
    """Return the per-example logger ``desktopenv.example.<id>`` writing to ``runtime.log``.

    The logger level is set to DEBUG and a UTF-8 FileHandler on
    ``<example_result_dir>/runtime.log`` is attached. ``logging.getLogger``
    returns the same object for the same name, so the handler is attached only
    when one for that file is not already present; otherwise each repeated call
    would add a handler and duplicate every log line.

    Args:
        example: Task config dict; only ``example["id"]`` is read.
        example_result_dir: Existing directory that receives ``runtime.log``.

    Returns:
        The configured logging.Logger.
    """
    runtime_logger = logging.getLogger(f"desktopenv.example.{example['id']}")
    runtime_logger.setLevel(logging.DEBUG)

    # FileHandler stores os.path.abspath(filename) in .baseFilename, so compare
    # against the absolute path.
    log_path = os.path.abspath(os.path.join(example_result_dir, "runtime.log"))
    already_attached = any(
        isinstance(handler, logging.FileHandler) and handler.baseFilename == log_path
        for handler in runtime_logger.handlers
    )
    if not already_attached:
        # Explicit UTF-8: this script logs non-ASCII (Chinese) messages, and the
        # platform default encoding is not guaranteed to support them.
        runtime_logger.addHandler(logging.FileHandler(log_path, encoding="utf-8"))
    return runtime_logger
def _save_screenshot(obs, path):
    """Write the screenshot bytes held in ``obs["screenshot"]`` to *path*.

    Returns True if the file was written. If the observation holds None
    instead of image bytes, logs a warning and returns False rather than
    raising TypeError from ``f.write(None)``. Losing a whole manual run over
    a missing PNG would be worse than skipping the image.
    """
    screenshot = obs["screenshot"]
    if screenshot is None:
        logger.warning(f"截图为空,跳过保存: {path}")
        return False
    with open(path, "wb") as f:
        f.write(screenshot)
    return True


def _end_recording(env, path):
    """Stop the screen recording and save it to *path*; log (do not raise) on failure."""
    try:
        env.controller.end_recording(path)
    except Exception:
        logger.exception("结束录制时出错")


def run_manual_examination(env, example, instruction, args, example_result_dir, *, startup_wait=60):
    """Let an operator perform one task by hand in the VM, then evaluate and record it.

    Steps:
    1. Reset *env* with the task config and wait *startup_wait* seconds.
    2. Save an initial screenshot and ``task_info.json``, then start screen
       recording.
    3. Block until the operator presses Enter.
    4. Save a final screenshot, call ``env.evaluate()``, and write
       ``result.txt`` and ``execution_log.jsonl``.

    The recording is always stopped and saved as ``recording.mp4``, including
    when the operator aborts or evaluation raises.

    Args:
        env: The DesktopEnv to drive.
        example: Task config dict; passed to ``env.reset`` and stored in
            ``task_info.json``.
        instruction: Task instruction shown to the operator.
        args: Parsed CLI namespace; ``args.domain`` and ``args.example_id`` are
            recorded in the output.
        example_result_dir: Existing directory that receives all artifacts.
        startup_wait: Seconds to sleep after ``env.reset`` before observing.

    Returns:
        The value returned by ``env.evaluate()``, or None if the operator
        aborted at the prompt (Ctrl+C or EOF on stdin).
    """
    # Attaches the per-example runtime.log handler; the returned logger is not used here.
    setup_example_logger(example, example_result_dir)

    # Reset the environment with the task config and give the VM time to settle.
    env.reset(task_config=example)
    logger.info(f"环境正在初始化请等待{startup_wait}秒...")
    time.sleep(startup_wait)  # Wait for the environment to be ready

    # Initial observation and screenshot.
    obs = env._get_obs()
    initial_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
    _save_screenshot(
        obs, os.path.join(example_result_dir, f"initial_state_{initial_timestamp}.png")
    )

    # Record task metadata.
    with open(os.path.join(example_result_dir, "task_info.json"), "w", encoding="utf-8") as f:
        json.dump({
            "domain": args.domain,
            "example_id": args.example_id,
            "instruction": instruction,
            "initial_timestamp": initial_timestamp,
            "example_config": example
        }, f, indent=2, ensure_ascii=False)

    env.controller.start_recording()
    recording_path = os.path.join(example_result_dir, "recording.mp4")
    try:
        logger.info("="*80)
        logger.info(f"任务域: {args.domain}")
        logger.info(f"样例ID: {args.example_id}")
        logger.info(f"任务指令: {instruction}")
        logger.info("="*80)
        logger.info("环境已准备就绪!")
        logger.info("请在虚拟机中手动执行任务...")
        logger.info("完成后请按回车键继续进行评估...")
        logger.info("="*80)

        # Block until the operator has finished the task by hand.
        try:
            input("按回车键开始评估...")
        except (KeyboardInterrupt, EOFError):
            # EOFError: stdin was closed (e.g. Ctrl+D), so there is no operator
            # to wait for. KeyboardInterrupt only arrives here when SIGINT has
            # not been routed to a custom handler (main() installs one).
            logger.info("用户中断操作")
            return None

        logger.info("开始评估...")

        # Final observation and screenshot.
        final_obs = env._get_obs()
        final_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
        _save_screenshot(
            final_obs, os.path.join(example_result_dir, f"final_state_{final_timestamp}.png")
        )

        result = env.evaluate()
        logger.info(f"评估结果: {result:.2f}")

        with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f:
            f.write(f"{result}\n")

        with open(os.path.join(example_result_dir, "execution_log.jsonl"), "w", encoding="utf-8") as f:
            f.write(json.dumps({
                "type": "manual_execution",
                "initial_timestamp": initial_timestamp,
                "final_timestamp": final_timestamp,
                "result": result,
                "initial_screenshot": f"initial_state_{initial_timestamp}.png",
                "final_screenshot": f"final_state_{final_timestamp}.png"
            }, ensure_ascii=False))
            f.write("\n")

        return result
    finally:
        # Stop the recording on every exit path: completion, abort, or exception.
        _end_recording(env, recording_path)
def signal_handler(signum, frame):
    """Handle SIGINT/SIGTERM: close the active environment once, then exit with status 0.

    Only the first invocation does any work; signals arriving while shutdown
    is already in progress return immediately. The global
    ``active_environment`` is cleared before closing so that main()'s final
    cleanup does not close the same environment a second time.

    Args:
        signum: The signal number received.
        frame: The interrupted stack frame (unused; required by ``signal.signal``).

    Raises:
        SystemExit: Always, with code 0, after the shutdown sequence.
    """
    global is_terminating, active_environment

    # Avoid running the shutdown sequence twice (e.g. repeated Ctrl+C).
    if is_terminating:
        return
    is_terminating = True

    logger.info(f"接收到信号 {signum}。正在优雅关闭...")

    # Take ownership of the environment and clear the global first, so no
    # other cleanup path closes it again.
    env, active_environment = active_environment, None
    if env is not None:
        try:
            logger.info("正在关闭环境...")
            env.close()
            logger.info("环境已成功关闭")
        except Exception as e:
            logger.error(f"关闭环境时出错: {e}")

    logger.info("关闭完成。退出程序。")
    sys.exit(0)
def main():
    """Script entry point: examine one benchmark task by hand.

    Steps:
    1. Parse the CLI and load
       ``<test_config_base_dir>/examples/<domain>/<example_id>.json``.
    2. Create a ``DesktopEnv`` for the requested ``--provider_name`` /
       ``--region``.
    3. Run run_manual_examination().

    The environment is closed on every exit path. An unexpected exception is
    logged and the process exits with status 1.
    """
    global active_environment

    # Route Ctrl+C / SIGTERM to signal_handler, which closes the environment
    # and exits. Consequently SIGINT no longer raises KeyboardInterrupt here.
    signal.signal(signal.SIGINT, signal_handler)  # Handle Ctrl+C
    signal.signal(signal.SIGTERM, signal_handler)  # Handle termination signal

    try:
        args = config()
        logger.info("参数: %s", args)

        # Load the requested task definition.
        config_file = os.path.join(
            args.test_config_base_dir, f"examples/{args.domain}/{args.example_id}.json"
        )
        if not os.path.exists(config_file):
            logger.error(f"配置文件不存在: {config_file}")
            return
        with open(config_file, "r", encoding="utf-8") as f:
            example = json.load(f)

        # Results: <result_dir>/<action_space>/<observation_type>/manual_examination/<domain>/<example_id>
        example_result_dir = os.path.join(
            args.result_dir,
            args.action_space,
            args.observation_type,
            "manual_examination",
            args.domain,
            args.example_id,
        )
        os.makedirs(example_result_dir, exist_ok=True)

        # Honour --provider_name and --region rather than hard-coding AWS / us-east-1.
        env_kwargs = {
            "path_to_vm": args.path_to_vm,
            "action_space": args.action_space,
            "provider_name": args.provider_name,
            "region": args.region,
            "screen_size": (args.screen_width, args.screen_height),
            "headless": args.headless,
            "os_type": "Ubuntu",
            "require_a11y_tree": args.observation_type
            in ["a11y_tree", "screenshot_a11y_tree", "som"],
            "enable_proxy": True,
            "client_password": args.client_password,
        }
        if args.provider_name == "aws":
            # IMAGE_ID_MAP gives the snapshot to boot for each AWS region.
            from desktop_env.providers.aws.manager import IMAGE_ID_MAP

            env_kwargs["snapshot_name"] = IMAGE_ID_MAP[args.region]
        # NOTE(review): for non-AWS providers snapshot_name is omitted, relying
        # on DesktopEnv's own default for it — confirm.
        active_environment = DesktopEnv(**env_kwargs)

        result = run_manual_examination(
            active_environment,
            example,
            example["instruction"],
            args,
            example_result_dir
        )

        if result is not None:
            logger.info(f"手动检查完成。最终结果: {result:.2f}")
        else:
            logger.info("手动检查被中断")

    except KeyboardInterrupt:
        # Normally unreachable: SIGINT is handled by signal_handler (installed above).
        logger.info("主进程接收到KeyboardInterrupt")
    except Exception as e:
        logger.error(f"主进程中的意外错误: {e}", exc_info=True)
        # Exit non-zero so callers can tell the run failed; the finally block
        # below still closes the environment before the process exits.
        sys.exit(1)
    finally:
        logger.info("主进程最终清理...")
        # If signal_handler ran, it has already closed (or tried to close) the
        # environment; closing it again here would be a double close.
        if active_environment is not None and not is_terminating:
            try:
                logger.info("在最终清理中关闭环境...")
                active_environment.close()
                logger.info("在最终清理中环境已成功关闭")
            except Exception as e:
                logger.error(f"最终环境清理期间出错: {e}")
        active_environment = None
if __name__ == "__main__":
    # Disable HuggingFace tokenizers parallelism to avoid its warnings.
    os.environ.update(TOKENIZERS_PARALLELISM="false")
    main()