From c7e30044566c6654b5af687665e7ac8a7db998f5 Mon Sep 17 00:00:00 2001
From: Xiaochuan Li
Date: Tue, 23 Jul 2024 19:28:40 -0500
Subject: [PATCH] fix the auto-download bug: the default VMware path is now
 None, which triggers the automatic download (#58)

---
 run.py | 117 +++++++++++++++++++++++++++++++++++++--------------------
 1 file changed, 76 insertions(+), 41 deletions(-)

diff --git a/run.py b/run.py
index 7e71ff7..d85bb05 100644
--- a/run.py
+++ b/run.py
@@ -1,13 +1,13 @@
 """Script to run end-to-end evaluation on the benchmark.
 
 Utils and basic architecture credit to https://github.com/web-arena-x/webarena/blob/main/run.py.
 """
+
 import argparse
 import datetime
 import json
 import logging
 import os
 import sys
-# import wandb
 
 from tqdm import tqdm
@@ -15,16 +15,25 @@
 import lib_run_single
 from desktop_env.desktop_env import DesktopEnv
 from mm_agents.agent import PromptAgent
+# import wandb
+
+
 # Logger Configs {{{ #
 logger = logging.getLogger()
 logger.setLevel(logging.DEBUG)
 
 datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
 
-file_handler = logging.FileHandler(os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8")
-debug_handler = logging.FileHandler(os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8")
+file_handler = logging.FileHandler(
+    os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8"
+)
+debug_handler = logging.FileHandler(
+    os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8"
+)
 stdout_handler = logging.StreamHandler(sys.stdout)
-sdebug_handler = logging.FileHandler(os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8")
+sdebug_handler = logging.FileHandler(
+    os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8"
+)
 
 file_handler.setLevel(logging.INFO)
 debug_handler.setLevel(logging.DEBUG)
@@ -32,7 +41,8 @@ stdout_handler.setLevel(logging.INFO)
 sdebug_handler.setLevel(logging.DEBUG)
 
 formatter = logging.Formatter(
-    fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s")
+    fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
+)
 file_handler.setFormatter(formatter)
 debug_handler.setFormatter(formatter)
 stdout_handler.setFormatter(formatter)
@@ -45,30 +55,27 @@ logger.addHandler(file_handler)
 logger.addHandler(debug_handler)
 logger.addHandler(stdout_handler)
 logger.addHandler(sdebug_handler)
-# }}} Logger Configs # 
+# }}} Logger Configs #
 
 logger = logging.getLogger("desktopenv.experiment")
 
+
 def config() -> argparse.Namespace:
     parser = argparse.ArgumentParser(
         description="Run end-to-end evaluation on the benchmark"
     )
 
     # environment config
-    parser.add_argument("--path_to_vm", type=str,
-                        default=r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu\Ubuntu.vmx")
+    parser.add_argument("--path_to_vm", type=str, default=None)
     parser.add_argument(
         "--headless", action="store_true", help="Run in headless machine"
     )
-    parser.add_argument("--action_space", type=str, default="pyautogui", help="Action type")
+    parser.add_argument(
+        "--action_space", type=str, default="pyautogui", help="Action type"
+    )
     parser.add_argument(
         "--observation_type",
-        choices=[
-            "screenshot",
-            "a11y_tree",
-            "screenshot_a11y_tree",
-            "som"
-        ],
+        choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"],
         default="a11y_tree",
         help="Observation type",
     )
@@ -79,7 +86,9 @@ def config() -> argparse.Namespace:
 
     # agent config
     parser.add_argument("--max_trajectory_length", type=int, default=3)
-    parser.add_argument("--test_config_base_dir", type=str, default="evaluation_examples")
+    parser.add_argument(
+        "--test_config_base_dir", type=str, default="evaluation_examples"
+    )
 
     # lm config
     parser.add_argument("--model", type=str, default="gpt-4-0125-preview")
@@ -90,7 +99,9 @@ def config() -> argparse.Namespace:
 
     # example config
    parser.add_argument("--domain", type=str, default="all")
-    parser.add_argument("--test_all_meta_path", type=str, default="evaluation_examples/test_all.json")
+    parser.add_argument(
+        "--test_all_meta_path", type=str, default="evaluation_examples/test_all.json"
+    )
 
     # logging related
     parser.add_argument("--result_dir", type=str, default="./results")
@@ -99,18 +110,14 @@ def config() -> argparse.Namespace:
     return args
 
 
-def test(
-    args: argparse.Namespace,
-    test_all_meta: dict
-) -> None:
+def test(args: argparse.Namespace, test_all_meta: dict) -> None:
     scores = []
     max_steps = args.max_steps
 
     # log args
     logger.info("Args: %s", args)
     # set wandb project
-    cfg_args = \
-        {
+    cfg_args = {
         "path_to_vm": args.path_to_vm,
         "headless": args.headless,
         "action_space": args.action_space,
@@ -125,7 +132,7 @@ def test(
         "top_p": args.top_p,
         "max_tokens": args.max_tokens,
         "stop_token": args.stop_token,
-        "result_dir": args.result_dir
+        "result_dir": args.result_dir,
     }
 
     agent = PromptAgent(
@@ -143,12 +150,15 @@ def test(
         action_space=agent.action_space,
         screen_size=(args.screen_width, args.screen_height),
         headless=args.headless,
-        require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
+        require_a11y_tree=args.observation_type
+        in ["a11y_tree", "screenshot_a11y_tree", "som"],
     )
 
     for domain in tqdm(test_all_meta, desc="Domain"):
         for example_id in tqdm(test_all_meta[domain], desc="Example", leave=False):
-            config_file = os.path.join(args.test_config_base_dir, f"examples/{domain}/{example_id}.json")
+            config_file = os.path.join(
+                args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
+            )
             with open(config_file, "r", encoding="utf-8") as f:
                 example = json.load(f)
 
@@ -160,7 +170,9 @@ def test(
             logger.info(f"[Instruction]: {instruction}")
             # wandb each example config settings
             cfg_args["instruction"] = instruction
-            cfg_args["start_time"] = datetime.datetime.now().strftime("%Y:%m:%d-%H:%M:%S")
+            cfg_args["start_time"] = datetime.datetime.now().strftime(
+                "%Y:%m:%d-%H:%M:%S"
+            )
             # run.config.update(cfg_args)
 
             example_result_dir = os.path.join(
@@ -169,27 +181,41 @@ def test(
                 args.result_dir,
                 args.action_space,
                 args.observation_type,
                 args.model,
                 domain,
-                example_id
+                example_id,
             )
             os.makedirs(example_result_dir, exist_ok=True)
             # example start running
             try:
-                lib_run_single.run_single_example(agent, env, example, max_steps, instruction, args, example_result_dir,
-                                                  scores)
+                lib_run_single.run_single_example(
+                    agent,
+                    env,
+                    example,
+                    max_steps,
+                    instruction,
+                    args,
+                    example_result_dir,
+                    scores,
+                )
             except Exception as e:
                 logger.error(f"Exception in {domain}/{example_id}: {e}")
-                env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
+                env.controller.end_recording(
+                    os.path.join(example_result_dir, "recording.mp4")
+                )
                 with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
-                    f.write(json.dumps({
-                        "Error": f"Time limit exceeded in {domain}/{example_id}"
-                    }))
+                    f.write(
+                        json.dumps(
+                            {"Error": f"Time limit exceeded in {domain}/{example_id}"}
+                        )
+                    )
                    f.write("\n")
 
     env.close()
     logger.info(f"Average score: {sum(scores) / len(scores)}")
 
 
-def get_unfinished(action_space, use_model, observation_type, result_dir, total_file_json):
+def get_unfinished(
+    action_space, use_model, observation_type, result_dir, total_file_json
+):
     target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
 
     if not os.path.exists(target_dir):
@@ -217,7 +243,9 @@ def get_unfinished(action_space, use_model, observation_type, result_dir, total_
 
     for domain, examples in finished.items():
         if domain in total_file_json:
-            total_file_json[domain] = [x for x in total_file_json[domain] if x not in examples]
+            total_file_json[domain] = [
+                x for x in total_file_json[domain] if x not in examples
+            ]
 
     return total_file_json
 
@@ -239,7 +267,13 @@ def get_result(action_space, use_model, observation_type, result_dir, total_file
                     if "result.txt" in os.listdir(example_path):
                         # empty all files under example_id
                         try:
-                            all_result.append(float(open(os.path.join(example_path, "result.txt"), "r").read()))
+                            all_result.append(
+                                float(
+                                    open(
+                                        os.path.join(example_path, "result.txt"), "r"
+                                    ).read()
+                                )
+                            )
                         except:
                             all_result.append(0.0)
 
@@ -251,7 +285,7 @@ def get_result(action_space, use_model, observation_type, result_dir, total_file
         return all_result
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     ####### The complete version of the list of examples #######
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
     args = config()
@@ -267,17 +301,18 @@ if __name__ == '__main__':
         args.model,
         args.observation_type,
         args.result_dir,
-        test_all_meta
+        test_all_meta,
     )
     left_info = ""
     for domain in test_file_list:
         left_info += f"{domain}: {len(test_file_list[domain])}\n"
     logger.info(f"Left tasks:\n{left_info}")
 
-    get_result(args.action_space,
+    get_result(
+        args.action_space,
         args.model,
         args.observation_type,
         args.result_dir,
-        test_all_meta
+        test_all_meta,
     )
     test(args, test_file_list)
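
Note on the functional change: apart from the mechanical reformatting (line wrapping, trailing commas, double quotes), the only behavioral edit in this patch is the --path_to_vm default, which moves from a hard-coded, machine-specific Windows .vmx path to None. According to the commit message, a missing VM path is what lets the environment fall back to downloading and provisioning a VM image automatically; that download code is not part of this diff. The sketch below is only an illustration of that dispatch under that assumption, and the names resolve_vm_path, download_default_vm, and DEFAULT_VM_DIR are hypothetical, not taken from the repository:

# A minimal sketch (not the actual DesktopEnv implementation) of how a None
# VM path can act as the auto-download trigger described in the commit message.
# download_default_vm and DEFAULT_VM_DIR are hypothetical names.
import os
from typing import Optional

DEFAULT_VM_DIR = os.path.expanduser("~/vmware_vm_data")  # assumed cache location


def download_default_vm(target_dir: str) -> str:
    """Placeholder for the download/unpack step; returns the path to a .vmx file."""
    os.makedirs(target_dir, exist_ok=True)
    # ... fetch and unpack the prebuilt Ubuntu image here ...
    return os.path.join(target_dir, "Ubuntu", "Ubuntu.vmx")


def resolve_vm_path(path_to_vm: Optional[str]) -> str:
    # The old default pointed at one specific Windows machine, so the download
    # branch below was effectively unreachable; None now makes it the default.
    if path_to_vm is None:
        return download_default_vm(DEFAULT_VM_DIR)
    if not os.path.exists(path_to_vm):
        raise FileNotFoundError(f"--path_to_vm does not exist: {path_to_vm}")
    return path_to_vm


if __name__ == "__main__":
    # Omitting --path_to_vm yields None from argparse, which takes the download
    # branch; an explicit .vmx path keeps the old behavior.
    print(resolve_vm_path(None))

With this default in place, running run.py without --path_to_vm is expected to provision the VM automatically, while pointing --path_to_vm at an existing .vmx file skips the download entirely.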