diff --git a/run.py b/run.py index c56e142..7118d5b 100644 --- a/run.py +++ b/run.py @@ -6,9 +6,10 @@ import datetime import json import logging import os -import signal import sys +from tqdm import tqdm + from desktop_env.envs.desktop_env import DesktopEnv from mm_agents.agent import PromptAgent @@ -52,7 +53,8 @@ def handler(signo, frame): raise RuntimeError("Time limit exceeded!") -signal.signal(signal.SIGALRM, handler) +# fixme: windows doesn't support signal +# signal.signal(signal.SIGALRM, handler) def config() -> argparse.Namespace: @@ -128,8 +130,8 @@ def test( headless=args.headless, ) - for domain in test_all_meta: - for example_id in test_all_meta[domain]: + for domain in tqdm(test_all_meta, desc="Domain"): + for example_id in tqdm(test_all_meta[domain], desc="Example", leave=False): # example setting config_file = os.path.join(args.test_config_base_dir, f"examples/{domain}/{example_id}.json") with open(config_file, "r", encoding="utf-8") as f: @@ -154,7 +156,7 @@ def test( # example start running try: - signal.alarm(time_limit) + # signal.alarm(time_limit) fixme: windows doesn't support signal agent.reset() obs = env.reset(task_config=example) done = False @@ -204,6 +206,8 @@ def test( result = env.evaluate() logger.info("Result: %.2f", result) scores.append(result) + with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f: + f.write(f"{result}\n") env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4")) except RuntimeError as e: logger.error(f"Error in example {domain}/{example_id}: {e}") @@ -224,6 +228,10 @@ def test( })) f.write("\n") continue + except Exception as e: + logger.error(f"Error in example {domain}/{example_id}: {e}") + continue + env.close() logger.info(f"Average score: {sum(scores) / len(scores)}") @@ -236,9 +244,13 @@ def get_unfinished(action_space, use_model, observation_type, result_dir, total_ finished = {} for domain in os.listdir(target_dir): + finished[domain] = [] domain_path = os.path.join(target_dir, domain) if os.path.isdir(domain_path): - finished[domain] = os.listdir(domain_path) + for example_id in os.listdir(domain_path): + example_path = os.path.join(domain_path, example_id) + if os.path.isdir(example_path) and "result.txt" in os.listdir(example_path): + finished[domain].append(example_id) if not finished: return total_file_json