From 51d644c88bcf0a3102bd4a0b79bb09144e8f3aea Mon Sep 17 00:00:00 2001 From: Timothyxxx <384084775@qq.com> Date: Fri, 15 Mar 2024 21:12:18 +0800 Subject: [PATCH] Merge --- run.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/run.py b/run.py index 953c6b7..c56e142 100644 --- a/run.py +++ b/run.py @@ -6,8 +6,8 @@ import datetime import json import logging import os -import sys import signal +import sys from desktop_env.envs.desktop_env import DesktopEnv from mm_agents.agent import PromptAgent @@ -46,11 +46,15 @@ logger.addHandler(sdebug_handler) logger = logging.getLogger("desktopenv.experiment") + # make sure each example won't exceed the time limit def handler(signo, frame): raise RuntimeError("Time limit exceeded!") + + signal.signal(signal.SIGALRM, handler) + def config() -> argparse.Namespace: parser = argparse.ArgumentParser( description="Run end-to-end evaluation on the benchmark" @@ -175,7 +179,7 @@ def test( # Save screenshot and trajectory information with open(os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"), - "wb") as _f: + "wb") as _f: with open(obs['screenshot'], "rb") as __f: screenshot = __f.read() _f.write(screenshot) @@ -245,6 +249,7 @@ def get_unfinished(action_space, use_model, observation_type, result_dir, total_ return total_file_json + if __name__ == '__main__': ####### The complete version of the list of examples ####### os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -253,7 +258,13 @@ if __name__ == '__main__': with open("evaluation_examples/test_all.json", "r", encoding="utf-8") as f: test_all_meta = json.load(f) - test_file_list = get_unfinished(args.action_space, args.model, args.observation_type, args.result_dir, test_all_meta) + test_file_list = get_unfinished( + args.action_space, + args.model, + args.observation_type, + args.result_dir, + test_all_meta + ) left_info = "" for domain in test_file_list: left_info += f"{domain}: {len(test_file_list[domain])}\n"