"""Script to run end-to-end evaluation on the benchmark. Utils and basic architecture credit to https://github.com/web-arena-x/webarena/blob/main/run.py. """ import argparse import datetime import json import logging import os import sys from tqdm # import tqdm import time import timeout_decorator from desktop_env.envs.desktop_env import DesktopEnv from mm_agents.agent import PromptAgent # Logger Configs {{{ # logger = logging.getLogger() logger.setLevel(logging.DEBUG) datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S") file_handler = logging.FileHandler(os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8") debug_handler = logging.FileHandler(os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8") stdout_handler = logging.StreamHandler(sys.stdout) sdebug_handler = logging.FileHandler(os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8") file_handler.setLevel(logging.INFO) debug_handler.setLevel(logging.DEBUG) stdout_handler.setLevel(logging.INFO) sdebug_handler.setLevel(logging.DEBUG) formatter = logging.Formatter( fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s") file_handler.setFormatter(formatter) debug_handler.setFormatter(formatter) stdout_handler.setFormatter(formatter) sdebug_handler.setFormatter(formatter) stdout_handler.addFilter(logging.Filter("desktopenv")) sdebug_handler.addFilter(logging.Filter("desktopenv")) logger.addHandler(file_handler) logger.addHandler(debug_handler) logger.addHandler(stdout_handler) logger.addHandler(sdebug_handler) # }}} Logger Configs # logger = logging.getLogger("desktopenv.experiment") # make sure each example won't exceed the time limit # def handler(signo, frame): # raise RuntimeError("Time limit exceeded!") # signal.signal(signal.SIGALRM, handler) def config() -> argparse.Namespace: parser = argparse.ArgumentParser( description="Run end-to-end evaluation on the benchmark" ) # environment config parser.add_argument("--path_to_vm", type=str, default=r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu\Ubuntu.vmx") parser.add_argument( "--headless", action="store_true", help="Run in headless machine" ) parser.add_argument("--action_space", type=str, default="pyautogui", help="Action type") parser.add_argument( "--observation_type", choices=[ "screenshot", "a11y_tree", "screenshot_a11y_tree", "som" ], default="som", help="Observation type", ) parser.add_argument("--screen_width", type=int, default=1920) parser.add_argument("--screen_height", type=int, default=1080) parser.add_argument("--sleep_after_execution", type=float, default=0.0) parser.add_argument("--max_steps", type=int, default=15) # agent config parser.add_argument("--max_trajectory_length", type=int, default=3) parser.add_argument("--test_config_base_dir", type=str, default="evaluation_examples") parser.add_argument("--example_time_limit", type=int, default=600) # lm config parser.add_argument("--model", type=str, default="gpt-4-vision-preview") parser.add_argument("--temperature", type=float, default=1.0) parser.add_argument("--top_p", type=float, default=0.9) parser.add_argument("--max_tokens", type=int, default=1500) parser.add_argument("--stop_token", type=str, default=None) # logging related parser.add_argument("--result_dir", type=str, default="./results") args = parser.parse_args() return args def test( args: argparse.Namespace, test_all_meta: dict ) -> None: scores = [] max_steps = args.max_steps time_limit = 

    # log args
    logger.info("Args: %s", args)

    agent = PromptAgent(
        model=args.model,
        max_tokens=args.max_tokens,
        action_space=args.action_space,
        observation_type=args.observation_type,
        max_trajectory_length=args.max_trajectory_length,
    )

    env = DesktopEnv(
        path_to_vm=args.path_to_vm,
        action_space=agent.action_space,
        screen_size=(args.screen_width, args.screen_height),
        headless=args.headless,
    )

    for domain in tqdm(test_all_meta, desc="Domain"):
        for example_id in tqdm(test_all_meta[domain], desc="Example", leave=False):
            # example setting
            config_file = os.path.join(args.test_config_base_dir, f"examples/{domain}/{example_id}.json")
            with open(config_file, "r", encoding="utf-8") as f:
                example = json.load(f)

            logger.info(f"[Domain]: {domain}")
            logger.info(f"[Example ID]: {example_id}")

            instruction = example["instruction"]
            logger.info(f"[Instruction]: {instruction}")

            example_result_dir = os.path.join(
                args.result_dir,
                args.action_space,
                args.observation_type,
                args.model,
                domain,
                example_id
            )
            os.makedirs(example_result_dir, exist_ok=True)

            @timeout_decorator.timeout(seconds=time_limit, timeout_exception=RuntimeError,
                                       exception_message="Time limit exceeded.")
            def run_single_example(agent, env, example, max_steps, instruction, args, example_result_dir, scores):
                agent.reset()
                obs = env.reset(task_config=example)
                done = False
                step_idx = 0
                env.controller.start_recording()

                while not done and step_idx < max_steps:
                    actions = agent.predict(
                        instruction,
                        obs
                    )
                    for action in actions:
                        # Capture the timestamp before executing the action
                        action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
                        logger.info("Step %d: %s", step_idx + 1, action)

                        obs, reward, done, info = env.step(action, args.sleep_after_execution)

                        logger.info("Reward: %.2f", reward)
                        logger.info("Done: %s", done)
                        logger.info("Info: %s", info)

                        # Save screenshot and trajectory information
                        with open(os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"),
                                  "wb") as _f:
                            with open(obs['screenshot'], "rb") as __f:
                                screenshot = __f.read()
                            _f.write(screenshot)

                        with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
                            f.write(json.dumps({
                                "step_num": step_idx + 1,
                                "action_timestamp": action_timestamp,
                                "action": action,
                                "reward": reward,
                                "done": done,
                                "info": info,
                                "screenshot_file": f"step_{step_idx + 1}_{action_timestamp}.png"
                            }))
                            f.write("\n")

                        if done:
                            logger.info("The episode is done.")
                            break
                    step_idx += 1

                result = env.evaluate()
                logger.info("Result: %.2f", result)
                scores.append(result)
                with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f:
                    f.write(f"{result}\n")
                env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))

            # example start running
            try:
                # signal.alarm(time_limit)
                run_single_example(agent, env, example, max_steps, instruction, args, example_result_dir, scores)
            except RuntimeError as e:
                logger.error(f"Error in example {domain}/{example_id}: {e}")
                # save info of this example and then continue
                env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
                with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
                    f.write(json.dumps({
                        "Error": f"Error in example {domain}/{example_id}: {e}"
                    }))
                    f.write("\n")
                continue
            except Exception as e:
                logger.error(f"Error in example {domain}/{example_id}: {e}")
                continue

    env.close()
    logger.info(f"Average score: {sum(scores) / len(scores)}")


def get_unfinished(action_space, use_model, observation_type, result_dir, total_file_json):
    target_dir = os.path.join(result_dir, action_space, observation_type, use_model)

    if not os.path.exists(target_dir):
        return total_file_json

    finished = {}
    for domain in os.listdir(target_dir):
        finished[domain] = []
        domain_path = os.path.join(target_dir, domain)
        if os.path.isdir(domain_path):
            for example_id in os.listdir(domain_path):
                example_path = os.path.join(domain_path, example_id)
                if os.path.isdir(example_path):
                    if "result.txt" not in os.listdir(example_path):
                        # the example is unfinished: remove its partial outputs so it is rerun from scratch
                        for file in os.listdir(example_path):
                            os.remove(os.path.join(example_path, file))
                    else:
                        finished[domain].append(example_id)

    if not finished:
        return total_file_json

    for domain, examples in finished.items():
        if domain in total_file_json:
            total_file_json[domain] = [x for x in total_file_json[domain] if x not in examples]

    return total_file_json


if __name__ == '__main__':
    ####### The complete version of the list of examples #######
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    args = config()

    with open("evaluation_examples/test_all.json", "r", encoding="utf-8") as f:
        test_all_meta = json.load(f)

    test_file_list = get_unfinished(
        args.action_space,
        args.model,
        args.observation_type,
        args.result_dir,
        test_all_meta
    )

    left_info = ""
    for domain in test_file_list:
        left_info += f"{domain}: {len(test_file_list[domain])}\n"
    logger.info(f"Left tasks:\n{left_info}")

    test(args, test_file_list)
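
# Example invocation (a sketch: the VM path and result directory below are
# placeholders, and the filename "run.py" is assumed; every flag maps to an
# argument defined in config() above):
#
#   python run.py --path_to_vm "/path/to/Ubuntu/Ubuntu.vmx" \
#       --observation_type screenshot_a11y_tree --model gpt-4-vision-preview \
#       --max_steps 15 --result_dir ./results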