refactor experiment code

Timothyxxx
2024-03-14 19:25:25 +08:00
parent 313521ac52
commit 71ca8fbe1c
2 changed files with 345 additions and 76 deletions


@@ -1,15 +1,24 @@
# todo: unify all the experiment python files into one file
import argparse
"""Script to run end-to-end evaluation on the benchmark.
Utils and basic architecture credit to https://github.com/web-arena-x/webarena/blob/main/run.py.
"""
import datetime
import json
import logging
import os
import sys
import func_timeout
import argparse
import glob
import json
import logging
import os
import time
from pathlib import Path
import openai
import requests
import torch
from beartype import beartype
from desktop_env.envs.desktop_env import DesktopEnv
from mm_agents.gpt_4v_agent import GPT4v_Agent # todo: change the name into PromptAgent
from mm_agents.agent import PromptAgent
# Logger Configs {{{ #
logger = logging.getLogger()
@@ -45,9 +54,6 @@ logger.addHandler(sdebug_handler)
logger = logging.getLogger("desktopenv.experiment")
# todo: move the PATH_TO_VM to the argparser
PATH_TO_VM = r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu\Ubuntu.vmx"
def run_one_example(example, agent, max_steps=10, example_trajectory_dir="exp_trajectory", recording=True,
max_time=600):
@@ -146,14 +152,13 @@ def main(example_class, example_id, gpt4_model="gpt-4-vision-preview"):
example["snapshot"] = "exp_v5"
api_key = os.environ.get("OPENAI_API_KEY")
agent = GPT4v_Agent(api_key=api_key,
model=gpt4_model,
instruction=example['instruction'],
action_space=action_space,
exp="screenshot")
#
# api_key = os.environ.get("GENAI_API_KEY")
# agent = GeminiPro_Agent(api_key=api_key, instruction=example['instruction'], action_space=action_space, exp="screenshot")
agent = PromptAgent(
api_key=api_key,
model=gpt4_model,
instruction=example['instruction'],
action_space=action_space,
exp="screenshot"
)
root_trajectory_dir = "exp_trajectory"
@@ -188,41 +193,33 @@ def config() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Run end-to-end evaluation on the benchmark"
)
parser.add_argument(
"--render", action="store_true", help="Render the browser"
)
# environment config
parser.add_argument("--path_to_vm", type=str, default="Ubuntu\\Ubuntu.vmx")
parser.add_argument(
"--slow_mo",
type=int,
default=0,
help="Slow down the browser by the specified amount",
)
parser.add_argument(
"--action_set_tag", default="id_accessibility_tree", help="Action type"
"--headless", action="store_true", help="Run in headless machine"
)
parser.add_argument("--action_space", type=str, default="pyautogui", help="Action type")
parser.add_argument(
"--observation_type",
choices=[
"accessibility_tree",
"accessibility_tree_with_captioner",
"html",
"image",
"image_som",
"screenshot",
"a11y_tree",
"screenshot_a11y_tree",
"som"
],
default="accessibility_tree",
help="Observation type",
)
parser.add_argument(
"--current_viewport_only",
action="store_true",
help="Only use the current viewport for the observation",
)
parser.add_argument("--viewport_width", type=int, default=1280)
parser.add_argument("--viewport_height", type=int, default=2048)
# parser.add_argument(
# "--current_viewport_only",
# action="store_true",
# help="Only use the current viewport for the observation",
# )
parser.add_argument("--screen_width", type=int, default=1920)
parser.add_argument("--screen_height", type=int, default=1080)
parser.add_argument("--save_trace_enabled", action="store_true")
parser.add_argument("--sleep_after_execution", type=float, default=0.0)
parser.add_argument("--max_steps", type=int, default=30)
# agent config
@@ -230,7 +227,7 @@ def config() -> argparse.Namespace:
parser.add_argument(
"--instruction_path",
type=str,
default="agents/prompts/state_action_agent.json",
default="",
)
parser.add_argument(
"--parsing_failure_th",
@@ -247,28 +244,6 @@ def config() -> argparse.Namespace:
parser.add_argument("--test_config_base_dir", type=str)
parser.add_argument(
"--eval_captioning_model_device",
type=str,
default="cpu",
choices=["cpu", "cuda"],
help="Device to run eval captioning model on. By default, runs it on CPU.",
)
parser.add_argument(
"--eval_captioning_model",
type=str,
default="Salesforce/blip2-flan-t5-xl",
choices=["Salesforce/blip2-flan-t5-xl"],
help="Captioning backbone for VQA-type evals.",
)
parser.add_argument(
"--captioning_model",
type=str,
default="Salesforce/blip2-flan-t5-xl",
choices=["Salesforce/blip2-flan-t5-xl", "llava-hf/llava-1.5-7b-hf"],
help="Captioning backbone for accessibility tree alt text.",
)
# lm config
parser.add_argument("--provider", type=str, default="openai")
parser.add_argument("--model", type=str, default="gpt-3.5-turbo-0613")
@@ -293,36 +268,330 @@ def config() -> argparse.Namespace:
# example config
parser.add_argument("--test_start_idx", type=int, default=0)
parser.add_argument("--test_end_idx", type=int, default=910)
parser.add_argument("--test_end_idx", type=int, default=378)
# logging related
parser.add_argument("--result_dir", type=str, default="")
args = parser.parse_args()
    # check whether the action space is compatible with the observation space
if (
args.action_set_tag == "id_accessibility_tree"
and args.observation_type
not in [
"accessibility_tree",
return args
@beartype
def early_stop(
trajectory, max_steps: int, thresholds: dict[str, int]
) -> tuple[bool, str]:
"""Check whether need to stop early"""
# reach the max step
num_steps = (len(trajectory) - 1) / 2
if num_steps >= max_steps:
return True, f"Reach max steps {max_steps}"
# Case: parsing failure for k times
k = thresholds["parsing_failure"]
last_k_actions = trajectory[1::2][-k:] # type: ignore[assignment]
if len(last_k_actions) >= k:
if all(
[
action["action_type"] == ""
for action in last_k_actions
]
):
return True, f"Failed to parse actions for {k} times"
# Case: same action for k times
k = thresholds["repeating_action"]
last_k_actions = trajectory[1::2][-k:] # type: ignore[assignment]
action_seq = trajectory[1::2] # type: ignore[assignment]
if len(action_seq) == 0:
return False, ""
last_action = action_seq[-1]
if last_action["action_type"] != ActionTypes.TYPE:
if len(last_k_actions) >= k:
if all(
[
is_equivalent(action, last_action)
for action in last_k_actions
]
):
return True, f"Same action for {k} times"
else:
# check the action sequence
if (
sum([is_equivalent(action, last_action) for action in action_seq])
>= k
):
return True, f"Same typing action for {k} times"
return False, ""
@beartype
def test(
args: argparse.Namespace,
config_file_list: list[str]
) -> None:
scores = []
max_steps = args.max_steps
early_stop_thresholds = {
"parsing_failure": args.parsing_failure_th,
"repeating_action": args.repeating_action_failure_th,
}
if args.observation_type in [
"accessibility_tree_with_captioner",
"image_som",
]
]:
device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
caption_image_fn = image_utils.get_captioning_fn(
device, dtype, args.captioning_model
)
else:
caption_image_fn = None
# Load a (possibly different) captioning model for running VQA evals.
if (
caption_image_fn
and args.eval_captioning_model == args.captioning_model
):
raise ValueError(
f"Action type {args.action_set_tag} is incompatible with the observation type {args.observation_type}"
eval_caption_image_fn = caption_image_fn
else:
eval_caption_image_fn = image_utils.get_captioning_fn(
args.eval_captioning_model_device,
torch.float16
if (
torch.cuda.is_available()
and args.eval_captioning_model_device == "cuda"
)
else torch.float32,
args.eval_captioning_model,
)
return args
agent = construct_agent(
args,
captioning_fn=caption_image_fn
if args.observation_type == "accessibility_tree_with_captioner"
else None,
) # NOTE: captioning_fn here is used for captioning input images.
env = ScriptBrowserEnv(
headless=not args.render,
slow_mo=args.slow_mo,
observation_type=args.observation_type,
current_viewport_only=args.current_viewport_only,
viewport_size={
"width": args.viewport_width,
"height": args.viewport_height,
},
save_trace_enabled=args.save_trace_enabled,
sleep_after_execution=args.sleep_after_execution,
# NOTE: captioning_fn here is used for LLM + captioning baselines.
# This can be different from the captioning model used for evals.
captioning_fn=caption_image_fn,
)
for config_file in config_file_list:
try:
render_helper = RenderHelper(
config_file, args.result_dir, args.action_set_tag
)
# Load task.
with open(config_file) as f:
_c = json.load(f)
intent = _c["intent"]
task_id = _c["task_id"]
image_paths = _c.get("image", None)
images = []
# Load input images for the task, if any.
if image_paths is not None:
if isinstance(image_paths, str):
image_paths = [image_paths]
for image_path in image_paths:
# Load image either from the web or from a local path.
if image_path.startswith("http"):
input_image = Image.open(requests.get(image_path, stream=True).raw)
else:
input_image = Image.open(image_path)
images.append(input_image)
logger.info(f"[Config file]: {config_file}")
logger.info(f"[Intent]: {intent}")
agent.reset(config_file)
trajectory: Trajectory = []
obs, info = env.reset(options={"config_file": config_file})
state_info: StateInfo = {"observation": obs, "info": info}
trajectory.append(state_info)
meta_data = {"action_history": ["None"]}
while True:
early_stop_flag, stop_info = early_stop(
trajectory, max_steps, early_stop_thresholds
)
if early_stop_flag:
action = create_stop_action(f"Early stop: {stop_info}")
else:
try:
action = agent.next_action(
trajectory,
intent,
images=images,
meta_data=meta_data,
)
except ValueError as e:
# get the error message
action = create_stop_action(f"ERROR: {str(e)}")
trajectory.append(action)
action_str = get_action_description(
action,
state_info["info"]["observation_metadata"],
action_set_tag=args.action_set_tag,
prompt_constructor=agent.prompt_constructor
if isinstance(agent, PromptAgent)
else None,
)
render_helper.render(
action, state_info, meta_data, args.render_screenshot
)
meta_data["action_history"].append(action_str)
if action["action_type"] == ActionTypes.STOP:
break
obs, _, terminated, _, info = env.step(action)
state_info = {"observation": obs, "info": info}
trajectory.append(state_info)
if terminated:
                    # add an action placeholder
trajectory.append(create_stop_action(""))
break
# NOTE: eval_caption_image_fn is used for running eval_vqa functions.
evaluator = evaluator_router(
config_file, captioning_fn=eval_caption_image_fn
)
score = evaluator(
trajectory=trajectory,
config_file=config_file,
page=env.page,
client=env.get_page_client(env.page),
)
scores.append(score)
if score == 1:
logger.info(f"[Result] (PASS) {config_file}")
else:
logger.info(f"[Result] (FAIL) {config_file}")
if args.save_trace_enabled:
env.save_trace(
Path(args.result_dir) / "traces" / f"{task_id}.zip"
)
except openai.OpenAIError as e:
logger.info(f"[OpenAI Error] {repr(e)}")
except Exception as e:
logger.info(f"[Unhandled Error] {repr(e)}]")
import traceback
# write to error file
with open(Path(args.result_dir) / "error.txt", "a") as f:
f.write(f"[Config file]: {config_file}\n")
f.write(f"[Unhandled Error] {repr(e)}\n")
f.write(traceback.format_exc()) # write stack trace to file
render_helper.close()
env.close()
logger.info(f"Average score: {sum(scores) / len(scores)}")
def prepare(args: argparse.Namespace) -> None:
# convert prompt python files to json
from agent.prompts import to_json
to_json.run()
# prepare result dir
result_dir = args.result_dir
if not result_dir:
result_dir = (
f"cache/results_{time.strftime('%Y%m%d%H%M%S', time.localtime())}"
)
if not Path(result_dir).exists():
Path(result_dir).mkdir(parents=True, exist_ok=True)
args.result_dir = result_dir
logger.info(f"Create result dir: {result_dir}")
if not (Path(result_dir) / "traces").exists():
(Path(result_dir) / "traces").mkdir(parents=True)
    # record the name of the log file
with open(os.path.join(result_dir, "log_files.txt"), "a+") as f:
f.write(f"{LOG_FILE_NAME}\n")
def get_unfinished(config_files: list[str], result_dir: str) -> list[str]:
result_files = glob.glob(f"{result_dir}/*.html")
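    # Result pages are assumed to be named like "render_<task_id>.html"
    # (a hypothetical naming scheme inferred from the split below).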
task_ids = [
os.path.basename(f).split(".")[0].split("_")[1] for f in result_files
]
unfinished_configs = []
for config_file in config_files:
task_id = os.path.basename(config_file).split(".")[0]
if task_id not in task_ids:
unfinished_configs.append(config_file)
return unfinished_configs
@beartype
def dump_config(args: argparse.Namespace) -> None:
config_file = Path(args.result_dir) / "config.json"
if not config_file.exists():
with open(config_file, "w") as f:
json.dump(vars(args), f, indent=4)
logger.info(f"Dump config to {config_file}")
if __name__ == '__main__':
    ####### The complete list of examples #######
os.environ["TOKENIZERS_PARALLELISM"] = "false"
args = config()
args.sleep_after_execution = 2.5
args.sleep_after_execution = 5
prepare(args)
test_config_base_dir = args.test_config_base_dir
test_file_list = []
st_idx = args.test_start_idx
ed_idx = args.test_end_idx
for i in range(st_idx, ed_idx):
test_file_list.append(os.path.join(test_config_base_dir, f"{i}.json"))
test_file_list = get_unfinished(test_file_list, args.result_dir)
print(f"Total {len(test_file_list)} tasks left")
args.render = False
args.render_screenshot = True
args.save_trace_enabled = True
args.current_viewport_only = True
dump_config(args)
test(args, test_file_list)
# todo: add a recorder for the progress of the examples
# todo: remove the useless example files


@@ -169,7 +169,7 @@ def parse_code_from_som_string(input_string, masks):
return actions
class GPT4v_Agent:
class PromptAgent:
def __init__(
self,
api_key,