From 71ca8fbe1c40e52edf09fde0cf7bf7d836e52693 Mon Sep 17 00:00:00 2001 From: Timothyxxx <384084775@qq.com> Date: Thu, 14 Mar 2024 19:25:25 +0800 Subject: [PATCH] refactor on exp code --- experiment_screenshot.py | 419 +++++++++++++++++++----- mm_agents/{gpt_4v_agent.py => agent.py} | 2 +- 2 files changed, 345 insertions(+), 76 deletions(-) rename mm_agents/{gpt_4v_agent.py => agent.py} (99%) diff --git a/experiment_screenshot.py b/experiment_screenshot.py index b426401..ffcbdf2 100644 --- a/experiment_screenshot.py +++ b/experiment_screenshot.py @@ -1,15 +1,24 @@ -# todo: unifiy all the experiments python file into one file -import argparse +"""Script to run end-to-end evaluation on the benchmark. +Utils and basic architecture credit to https://github.com/web-arena-x/webarena/blob/main/run.py. +""" import datetime -import json -import logging -import os import sys import func_timeout +import argparse +import glob +import json +import logging +import os +import time +from pathlib import Path +import openai +import requests +import torch +from beartype import beartype from desktop_env.envs.desktop_env import DesktopEnv -from mm_agents.gpt_4v_agent import GPT4v_Agent # todo: change the name into PromptAgent +from mm_agents.agent import PromptAgent # todo: change the name into PromptAgent # Logger Configs {{{ # logger = logging.getLogger() @@ -45,9 +54,6 @@ logger.addHandler(sdebug_handler) logger = logging.getLogger("desktopenv.experiment") -# todo: move the PATH_TO_VM to the argparser -PATH_TO_VM = r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu\Ubuntu.vmx" - def run_one_example(example, agent, max_steps=10, example_trajectory_dir="exp_trajectory", recording=True, max_time=600): @@ -146,14 +152,13 @@ def main(example_class, example_id, gpt4_model="gpt-4-vision-preview"): example["snapshot"] = "exp_v5" api_key = os.environ.get("OPENAI_API_KEY") - agent = GPT4v_Agent(api_key=api_key, - model=gpt4_model, - instruction=example['instruction'], - action_space=action_space, - exp="screenshot") - # - # api_key = os.environ.get("GENAI_API_KEY") - # agent = GeminiPro_Agent(api_key=api_key, instruction=example['instruction'], action_space=action_space, exp="screenshot") + agent = PromptAgent( + api_key=api_key, + model=gpt4_model, + instruction=example['instruction'], + action_space=action_space, + exp="screenshot" + ) root_trajectory_dir = "exp_trajectory" @@ -188,41 +193,33 @@ def config() -> argparse.Namespace: parser = argparse.ArgumentParser( description="Run end-to-end evaluation on the benchmark" ) - parser.add_argument( - "--render", action="store_true", help="Render the browser" - ) + # environment config + parser.add_argument("--path_to_vm", type=str, default="Ubuntu\\Ubuntu.vmx") parser.add_argument( - "--slow_mo", - type=int, - default=0, - help="Slow down the browser by the specified amount", - ) - parser.add_argument( - "--action_set_tag", default="id_accessibility_tree", help="Action type" + "--headless", action="store_true", help="Run in headless machine" ) + parser.add_argument("--action_space", type=str, default="pyautogui", help="Action type") parser.add_argument( "--observation_type", choices=[ - "accessibility_tree", - "accessibility_tree_with_captioner", - "html", - "image", - "image_som", + "screenshot", + "a11y_tree", + "screenshot_a11y_tree", + "som" ], default="accessibility_tree", help="Observation type", ) - parser.add_argument( - "--current_viewport_only", - action="store_true", - help="Only use the current viewport for the observation", - ) - 
parser.add_argument("--viewport_width", type=int, default=1280) - parser.add_argument("--viewport_height", type=int, default=2048) + # parser.add_argument( + # "--current_viewport_only", + # action="store_true", + # help="Only use the current viewport for the observation", + # ) + parser.add_argument("--screen_width", type=int, default=1920) + parser.add_argument("--screen_height", type=int, default=1080) parser.add_argument("--save_trace_enabled", action="store_true") parser.add_argument("--sleep_after_execution", type=float, default=0.0) - parser.add_argument("--max_steps", type=int, default=30) # agent config @@ -230,7 +227,7 @@ def config() -> argparse.Namespace: parser.add_argument( "--instruction_path", type=str, - default="agents/prompts/state_action_agent.json", + default="", ) parser.add_argument( "--parsing_failure_th", @@ -247,28 +244,6 @@ def config() -> argparse.Namespace: parser.add_argument("--test_config_base_dir", type=str) - parser.add_argument( - "--eval_captioning_model_device", - type=str, - default="cpu", - choices=["cpu", "cuda"], - help="Device to run eval captioning model on. By default, runs it on CPU.", - ) - parser.add_argument( - "--eval_captioning_model", - type=str, - default="Salesforce/blip2-flan-t5-xl", - choices=["Salesforce/blip2-flan-t5-xl"], - help="Captioning backbone for VQA-type evals.", - ) - parser.add_argument( - "--captioning_model", - type=str, - default="Salesforce/blip2-flan-t5-xl", - choices=["Salesforce/blip2-flan-t5-xl", "llava-hf/llava-1.5-7b-hf"], - help="Captioning backbone for accessibility tree alt text.", - ) - # lm config parser.add_argument("--provider", type=str, default="openai") parser.add_argument("--model", type=str, default="gpt-3.5-turbo-0613") @@ -293,36 +268,330 @@ def config() -> argparse.Namespace: # example config parser.add_argument("--test_start_idx", type=int, default=0) - parser.add_argument("--test_end_idx", type=int, default=910) + parser.add_argument("--test_end_idx", type=int, default=378) # logging related parser.add_argument("--result_dir", type=str, default="") args = parser.parse_args() - # check the whether the action space is compatible with the observation space - if ( - args.action_set_tag == "id_accessibility_tree" - and args.observation_type - not in [ - "accessibility_tree", + return args + + +@beartype +def early_stop( + trajectory, max_steps: int, thresholds: dict[str, int] +) -> tuple[bool, str]: + """Check whether need to stop early""" + + # reach the max step + num_steps = (len(trajectory) - 1) / 2 + if num_steps >= max_steps: + return True, f"Reach max steps {max_steps}" + + # Case: parsing failure for k times + k = thresholds["parsing_failure"] + last_k_actions = trajectory[1::2][-k:] # type: ignore[assignment] + if len(last_k_actions) >= k: + if all( + [ + action["action_type"] == "" + for action in last_k_actions + ] + ): + return True, f"Failed to parse actions for {k} times" + + # Case: same action for k times + k = thresholds["repeating_action"] + last_k_actions = trajectory[1::2][-k:] # type: ignore[assignment] + action_seq = trajectory[1::2] # type: ignore[assignment] + + if len(action_seq) == 0: + return False, "" + + last_action = action_seq[-1] + + if last_action["action_type"] != ActionTypes.TYPE: + if len(last_k_actions) >= k: + if all( + [ + is_equivalent(action, last_action) + for action in last_k_actions + ] + ): + return True, f"Same action for {k} times" + + else: + # check the action sequence + if ( + sum([is_equivalent(action, last_action) for action in action_seq]) + >= k + 
): + return True, f"Same typing action for {k} times" + + return False, "" + + +@beartype +def test( + args: argparse.Namespace, + config_file_list: list[str] +) -> None: + scores = [] + max_steps = args.max_steps + + early_stop_thresholds = { + "parsing_failure": args.parsing_failure_th, + "repeating_action": args.repeating_action_failure_th, + } + + if args.observation_type in [ "accessibility_tree_with_captioner", "image_som", - ] + ]: + device = torch.device("cuda") if torch.cuda.is_available() else "cpu" + dtype = torch.float16 if torch.cuda.is_available() else torch.float32 + caption_image_fn = image_utils.get_captioning_fn( + device, dtype, args.captioning_model + ) + else: + caption_image_fn = None + + # Load a (possibly different) captioning model for running VQA evals. + if ( + caption_image_fn + and args.eval_captioning_model == args.captioning_model ): - raise ValueError( - f"Action type {args.action_set_tag} is incompatible with the observation type {args.observation_type}" + eval_caption_image_fn = caption_image_fn + else: + eval_caption_image_fn = image_utils.get_captioning_fn( + args.eval_captioning_model_device, + torch.float16 + if ( + torch.cuda.is_available() + and args.eval_captioning_model_device == "cuda" + ) + else torch.float32, + args.eval_captioning_model, ) - return args + agent = construct_agent( + args, + captioning_fn=caption_image_fn + if args.observation_type == "accessibility_tree_with_captioner" + else None, + ) # NOTE: captioning_fn here is used for captioning input images. + + env = ScriptBrowserEnv( + headless=not args.render, + slow_mo=args.slow_mo, + observation_type=args.observation_type, + current_viewport_only=args.current_viewport_only, + viewport_size={ + "width": args.viewport_width, + "height": args.viewport_height, + }, + save_trace_enabled=args.save_trace_enabled, + sleep_after_execution=args.sleep_after_execution, + # NOTE: captioning_fn here is used for LLM + captioning baselines. + # This can be different from the captioning model used for evals. + captioning_fn=caption_image_fn, + ) + + for config_file in config_file_list: + try: + render_helper = RenderHelper( + config_file, args.result_dir, args.action_set_tag + ) + + # Load task. + with open(config_file) as f: + _c = json.load(f) + intent = _c["intent"] + task_id = _c["task_id"] + image_paths = _c.get("image", None) + images = [] + + # Load input images for the task, if any. + if image_paths is not None: + if isinstance(image_paths, str): + image_paths = [image_paths] + for image_path in image_paths: + # Load image either from the web or from a local path. 
+ if image_path.startswith("http"): + input_image = Image.open(requests.get(image_path, stream=True).raw) + else: + input_image = Image.open(image_path) + + images.append(input_image) + + logger.info(f"[Config file]: {config_file}") + logger.info(f"[Intent]: {intent}") + + agent.reset(config_file) + trajectory: Trajectory = [] + obs, info = env.reset(options={"config_file": config_file}) + state_info: StateInfo = {"observation": obs, "info": info} + trajectory.append(state_info) + + meta_data = {"action_history": ["None"]} + while True: + early_stop_flag, stop_info = early_stop( + trajectory, max_steps, early_stop_thresholds + ) + + if early_stop_flag: + action = create_stop_action(f"Early stop: {stop_info}") + else: + try: + action = agent.next_action( + trajectory, + intent, + images=images, + meta_data=meta_data, + ) + except ValueError as e: + # get the error message + action = create_stop_action(f"ERROR: {str(e)}") + + trajectory.append(action) + + action_str = get_action_description( + action, + state_info["info"]["observation_metadata"], + action_set_tag=args.action_set_tag, + prompt_constructor=agent.prompt_constructor + if isinstance(agent, PromptAgent) + else None, + ) + render_helper.render( + action, state_info, meta_data, args.render_screenshot + ) + meta_data["action_history"].append(action_str) + + if action["action_type"] == ActionTypes.STOP: + break + + obs, _, terminated, _, info = env.step(action) + state_info = {"observation": obs, "info": info} + trajectory.append(state_info) + + if terminated: + # add a action place holder + trajectory.append(create_stop_action("")) + break + + # NOTE: eval_caption_image_fn is used for running eval_vqa functions. + evaluator = evaluator_router( + config_file, captioning_fn=eval_caption_image_fn + ) + score = evaluator( + trajectory=trajectory, + config_file=config_file, + page=env.page, + client=env.get_page_client(env.page), + ) + + scores.append(score) + + if score == 1: + logger.info(f"[Result] (PASS) {config_file}") + else: + logger.info(f"[Result] (FAIL) {config_file}") + + if args.save_trace_enabled: + env.save_trace( + Path(args.result_dir) / "traces" / f"{task_id}.zip" + ) + except openai.OpenAIError as e: + logger.info(f"[OpenAI Error] {repr(e)}") + except Exception as e: + logger.info(f"[Unhandled Error] {repr(e)}]") + import traceback + + # write to error file + with open(Path(args.result_dir) / "error.txt", "a") as f: + f.write(f"[Config file]: {config_file}\n") + f.write(f"[Unhandled Error] {repr(e)}\n") + f.write(traceback.format_exc()) # write stack trace to file + + render_helper.close() + + env.close() + logger.info(f"Average score: {sum(scores) / len(scores)}") + + +def prepare(args: argparse.Namespace) -> None: + # convert prompt python files to json + from agent.prompts import to_json + + to_json.run() + + # prepare result dir + result_dir = args.result_dir + if not result_dir: + result_dir = ( + f"cache/results_{time.strftime('%Y%m%d%H%M%S', time.localtime())}" + ) + if not Path(result_dir).exists(): + Path(result_dir).mkdir(parents=True, exist_ok=True) + args.result_dir = result_dir + logger.info(f"Create result dir: {result_dir}") + + if not (Path(result_dir) / "traces").exists(): + (Path(result_dir) / "traces").mkdir(parents=True) + + # log the log file + with open(os.path.join(result_dir, "log_files.txt"), "a+") as f: + f.write(f"{LOG_FILE_NAME}\n") + + +def get_unfinished(config_files: list[str], result_dir: str) -> list[str]: + result_files = glob.glob(f"{result_dir}/*.html") + task_ids = [ + 
os.path.basename(f).split(".")[0].split("_")[1] for f in result_files + ] + unfinished_configs = [] + for config_file in config_files: + task_id = os.path.basename(config_file).split(".")[0] + if task_id not in task_ids: + unfinished_configs.append(config_file) + return unfinished_configs + + +@beartype +def dump_config(args: argparse.Namespace) -> None: + config_file = Path(args.result_dir) / "config.json" + if not config_file.exists(): + with open(config_file, "w") as f: + json.dump(vars(args), f, indent=4) + logger.info(f"Dump config to {config_file}") if __name__ == '__main__': ####### The complete version of the list of examples ####### os.environ["TOKENIZERS_PARALLELISM"] = "false" args = config() - args.sleep_after_execution = 2.5 + args.sleep_after_execution = 5 prepare(args) + test_config_base_dir = args.test_config_base_dir + + test_file_list = [] + st_idx = args.test_start_idx + ed_idx = args.test_end_idx + for i in range(st_idx, ed_idx): + test_file_list.append(os.path.join(test_config_base_dir, f"{i}.json")) + test_file_list = get_unfinished(test_file_list, args.result_dir) + print(f"Total {len(test_file_list)} tasks left") + args.render = False + args.render_screenshot = True + args.save_trace_enabled = True + + args.current_viewport_only = True + dump_config(args) + + test(args, test_file_list) + # todo: add recorder of the progress of the examples # todo: remove the useless example files diff --git a/mm_agents/gpt_4v_agent.py b/mm_agents/agent.py similarity index 99% rename from mm_agents/gpt_4v_agent.py rename to mm_agents/agent.py index 7e9c400..4ad0445 100644 --- a/mm_agents/gpt_4v_agent.py +++ b/mm_agents/agent.py @@ -169,7 +169,7 @@ def parse_code_from_som_string(input_string, masks): return actions -class GPT4v_Agent: +class PromptAgent: def __init__( self, api_key,
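
A minimal usage sketch of the renamed class, illustrative only: it assumes the PromptAgent constructor takes the same keyword arguments that the hunks above pass (api_key, model, instruction, action_space, exp) and reuses the defaults from main() and config() ("gpt-4-vision-preview", "pyautogui", "screenshot"); the instruction string is a made-up example, not one of the benchmark tasks.

    import os
    from mm_agents.agent import PromptAgent  # renamed from mm_agents/gpt_4v_agent.py (GPT4v_Agent)

    # API key is read from the environment, as in experiment_screenshot.py.
    api_key = os.environ.get("OPENAI_API_KEY")

    # Hypothetical instruction; in the experiment script this comes from example['instruction'].
    agent = PromptAgent(
        api_key=api_key,
        model="gpt-4-vision-preview",  # default gpt4_model in main()
        instruction="Open the file manager and create a folder named 'report'.",
        action_space="pyautogui",      # default --action_space in config()
        exp="screenshot",              # observation mode used by this script
    )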