Refactor exp code
@@ -1,15 +1,24 @@
# todo: unify all the experiments python file into one file
import argparse
"""Script to run end-to-end evaluation on the benchmark.
Utils and basic architecture credit to https://github.com/web-arena-x/webarena/blob/main/run.py.
"""
import datetime
import json
import logging
import os
import sys

import func_timeout
import argparse
import glob
import json
import logging
import os
import time
from pathlib import Path

import openai
import requests
import torch
from beartype import beartype

from desktop_env.envs.desktop_env import DesktopEnv
from mm_agents.gpt_4v_agent import GPT4v_Agent  # todo: change the name into PromptAgent
from mm_agents.agent import PromptAgent  # todo: change the name into PromptAgent

# Logger Configs {{{ #
logger = logging.getLogger()
@@ -45,9 +54,6 @@ logger.addHandler(sdebug_handler)

logger = logging.getLogger("desktopenv.experiment")

# todo: move the PATH_TO_VM to the argparser
PATH_TO_VM = r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu\Ubuntu.vmx"


def run_one_example(example, agent, max_steps=10, example_trajectory_dir="exp_trajectory", recording=True,
                    max_time=600):
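# The body of run_one_example is not shown in this hunk. Illustrative sketch only,
# not part of this diff: one way the max_time budget could be enforced with the
# func_timeout package imported above; the agent.predict / env.step calls are
# hypothetical stand-ins, not the real interfaces.
def _run_with_time_limit(agent, env, instruction, max_steps, max_time):
    def _loop():
        for _ in range(max_steps):
            action = agent.predict(instruction, env.get_observation())  # hypothetical API
            done = env.step(action)  # hypothetical API
            if done:
                break

    try:
        # func_timeout raises FunctionTimedOut if _loop exceeds max_time seconds.
        func_timeout.func_timeout(max_time, _loop)
    except func_timeout.FunctionTimedOut:
        logger.info(f"Example did not finish within {max_time} seconds")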
@@ -146,14 +152,13 @@ def main(example_class, example_id, gpt4_model="gpt-4-vision-preview"):
    example["snapshot"] = "exp_v5"

    api_key = os.environ.get("OPENAI_API_KEY")
    agent = GPT4v_Agent(api_key=api_key,
                        model=gpt4_model,
                        instruction=example['instruction'],
                        action_space=action_space,
                        exp="screenshot")
    #
    # api_key = os.environ.get("GENAI_API_KEY")
    # agent = GeminiPro_Agent(api_key=api_key, instruction=example['instruction'], action_space=action_space, exp="screenshot")
    agent = PromptAgent(
        api_key=api_key,
        model=gpt4_model,
        instruction=example['instruction'],
        action_space=action_space,
        exp="screenshot"
    )

    root_trajectory_dir = "exp_trajectory"

@@ -188,41 +193,33 @@ def config() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Run end-to-end evaluation on the benchmark"
    )
    parser.add_argument(
        "--render", action="store_true", help="Render the browser"
    )

    # environment config
    parser.add_argument("--path_to_vm", type=str, default="Ubuntu\\Ubuntu.vmx")
    parser.add_argument(
        "--slow_mo",
        type=int,
        default=0,
        help="Slow down the browser by the specified amount",
    )
    parser.add_argument(
        "--action_set_tag", default="id_accessibility_tree", help="Action type"
        "--headless", action="store_true", help="Run in headless machine"
    )
    parser.add_argument("--action_space", type=str, default="pyautogui", help="Action type")
    parser.add_argument(
        "--observation_type",
        choices=[
            "accessibility_tree",
            "accessibility_tree_with_captioner",
            "html",
            "image",
            "image_som",
            "screenshot",
            "a11y_tree",
            "screenshot_a11y_tree",
            "som"
        ],
        default="accessibility_tree",
        help="Observation type",
    )
    parser.add_argument(
        "--current_viewport_only",
        action="store_true",
        help="Only use the current viewport for the observation",
    )
    parser.add_argument("--viewport_width", type=int, default=1280)
    parser.add_argument("--viewport_height", type=int, default=2048)
    # parser.add_argument(
    #     "--current_viewport_only",
    #     action="store_true",
    #     help="Only use the current viewport for the observation",
    # )
    parser.add_argument("--screen_width", type=int, default=1920)
    parser.add_argument("--screen_height", type=int, default=1080)
    parser.add_argument("--save_trace_enabled", action="store_true")
    parser.add_argument("--sleep_after_execution", type=float, default=0.0)

    parser.add_argument("--max_steps", type=int, default=30)

    # agent config
@@ -230,7 +227,7 @@ def config() -> argparse.Namespace:
    parser.add_argument(
        "--instruction_path",
        type=str,
        default="agents/prompts/state_action_agent.json",
        default="",
    )
    parser.add_argument(
        "--parsing_failure_th",
@@ -247,28 +244,6 @@ def config() -> argparse.Namespace:

    parser.add_argument("--test_config_base_dir", type=str)

    parser.add_argument(
        "--eval_captioning_model_device",
        type=str,
        default="cpu",
        choices=["cpu", "cuda"],
        help="Device to run eval captioning model on. By default, runs it on CPU.",
    )
    parser.add_argument(
        "--eval_captioning_model",
        type=str,
        default="Salesforce/blip2-flan-t5-xl",
        choices=["Salesforce/blip2-flan-t5-xl"],
        help="Captioning backbone for VQA-type evals.",
    )
    parser.add_argument(
        "--captioning_model",
        type=str,
        default="Salesforce/blip2-flan-t5-xl",
        choices=["Salesforce/blip2-flan-t5-xl", "llava-hf/llava-1.5-7b-hf"],
        help="Captioning backbone for accessibility tree alt text.",
    )

    # lm config
    parser.add_argument("--provider", type=str, default="openai")
    parser.add_argument("--model", type=str, default="gpt-3.5-turbo-0613")
@@ -293,36 +268,330 @@ def config() -> argparse.Namespace:

    # example config
    parser.add_argument("--test_start_idx", type=int, default=0)
    parser.add_argument("--test_end_idx", type=int, default=910)
    parser.add_argument("--test_end_idx", type=int, default=378)

    # logging related
    parser.add_argument("--result_dir", type=str, default="")
    args = parser.parse_args()

    # check whether the action space is compatible with the observation space
    if (
        args.action_set_tag == "id_accessibility_tree"
        and args.observation_type
        not in [
            "accessibility_tree",
    return args


@beartype
def early_stop(
    trajectory, max_steps: int, thresholds: dict[str, int]
) -> tuple[bool, str]:
    """Check whether need to stop early"""

    # reach the max step
    num_steps = (len(trajectory) - 1) / 2
    if num_steps >= max_steps:
        return True, f"Reach max steps {max_steps}"

    # Case: parsing failure for k times
    k = thresholds["parsing_failure"]
    last_k_actions = trajectory[1::2][-k:]  # type: ignore[assignment]
    if len(last_k_actions) >= k:
        if all(
            [
                action["action_type"] == ""
                for action in last_k_actions
            ]
        ):
            return True, f"Failed to parse actions for {k} times"

    # Case: same action for k times
    k = thresholds["repeating_action"]
    last_k_actions = trajectory[1::2][-k:]  # type: ignore[assignment]
    action_seq = trajectory[1::2]  # type: ignore[assignment]

    if len(action_seq) == 0:
        return False, ""

    last_action = action_seq[-1]

    if last_action["action_type"] != ActionTypes.TYPE:
        if len(last_k_actions) >= k:
            if all(
                [
                    is_equivalent(action, last_action)
                    for action in last_k_actions
                ]
            ):
                return True, f"Same action for {k} times"

    else:
        # check the action sequence
        if (
            sum([is_equivalent(action, last_action) for action in action_seq])
            >= k
        ):
            return True, f"Same typing action for {k} times"

    return False, ""
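# Illustrative usage sketch, not part of this diff: a trajectory interleaves state
# dicts and action dicts, so trajectory[1::2] above selects only the actions. With
# dict-shaped toy actions, three consecutive empty "action_type" values trip the
# parsing-failure threshold. The literal dicts below are stand-ins for real
# state/action objects.
_toy_trajectory = [
    {"observation": None},  # state 0
    {"action_type": ""},    # action 0: failed to parse
    {"observation": None},  # state 1
    {"action_type": ""},    # action 1: failed to parse
    {"observation": None},  # state 2
    {"action_type": ""},    # action 2: failed to parse
    {"observation": None},  # state 3
]
# early_stop(_toy_trajectory, 10, {"parsing_failure": 3, "repeating_action": 5})
# -> (True, "Failed to parse actions for 3 times")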


@beartype
def test(
    args: argparse.Namespace,
    config_file_list: list[str]
) -> None:
    scores = []
    max_steps = args.max_steps

    early_stop_thresholds = {
        "parsing_failure": args.parsing_failure_th,
        "repeating_action": args.repeating_action_failure_th,
    }

    if args.observation_type in [
        "accessibility_tree_with_captioner",
        "image_som",
        ]
    ]:
        device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
        dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        caption_image_fn = image_utils.get_captioning_fn(
            device, dtype, args.captioning_model
        )
    else:
        caption_image_fn = None

    # Load a (possibly different) captioning model for running VQA evals.
    if (
        caption_image_fn
        and args.eval_captioning_model == args.captioning_model
    ):
        raise ValueError(
            f"Action type {args.action_set_tag} is incompatible with the observation type {args.observation_type}"
        eval_caption_image_fn = caption_image_fn
    else:
        eval_caption_image_fn = image_utils.get_captioning_fn(
            args.eval_captioning_model_device,
            torch.float16
            if (
                torch.cuda.is_available()
                and args.eval_captioning_model_device == "cuda"
            )
            else torch.float32,
            args.eval_captioning_model,
        )

    return args
    agent = construct_agent(
        args,
        captioning_fn=caption_image_fn
        if args.observation_type == "accessibility_tree_with_captioner"
        else None,
    )  # NOTE: captioning_fn here is used for captioning input images.

    env = ScriptBrowserEnv(
        headless=not args.render,
        slow_mo=args.slow_mo,
        observation_type=args.observation_type,
        current_viewport_only=args.current_viewport_only,
        viewport_size={
            "width": args.viewport_width,
            "height": args.viewport_height,
        },
        save_trace_enabled=args.save_trace_enabled,
        sleep_after_execution=args.sleep_after_execution,
        # NOTE: captioning_fn here is used for LLM + captioning baselines.
        # This can be different from the captioning model used for evals.
        captioning_fn=caption_image_fn,
    )

    for config_file in config_file_list:
        try:
            render_helper = RenderHelper(
                config_file, args.result_dir, args.action_set_tag
            )

            # Load task.
            with open(config_file) as f:
                _c = json.load(f)
                intent = _c["intent"]
                task_id = _c["task_id"]
                image_paths = _c.get("image", None)
                images = []

                # Load input images for the task, if any.
                if image_paths is not None:
                    if isinstance(image_paths, str):
                        image_paths = [image_paths]
                    for image_path in image_paths:
                        # Load image either from the web or from a local path.
                        if image_path.startswith("http"):
                            input_image = Image.open(requests.get(image_path, stream=True).raw)
                        else:
                            input_image = Image.open(image_path)

                        images.append(input_image)

            logger.info(f"[Config file]: {config_file}")
            logger.info(f"[Intent]: {intent}")

            agent.reset(config_file)
            trajectory: Trajectory = []
            obs, info = env.reset(options={"config_file": config_file})
            state_info: StateInfo = {"observation": obs, "info": info}
            trajectory.append(state_info)

            meta_data = {"action_history": ["None"]}
            while True:
                early_stop_flag, stop_info = early_stop(
                    trajectory, max_steps, early_stop_thresholds
                )

                if early_stop_flag:
                    action = create_stop_action(f"Early stop: {stop_info}")
                else:
                    try:
                        action = agent.next_action(
                            trajectory,
                            intent,
                            images=images,
                            meta_data=meta_data,
                        )
                    except ValueError as e:
                        # get the error message
                        action = create_stop_action(f"ERROR: {str(e)}")

                trajectory.append(action)

                action_str = get_action_description(
                    action,
                    state_info["info"]["observation_metadata"],
                    action_set_tag=args.action_set_tag,
                    prompt_constructor=agent.prompt_constructor
                    if isinstance(agent, PromptAgent)
                    else None,
                )
                render_helper.render(
                    action, state_info, meta_data, args.render_screenshot
                )
                meta_data["action_history"].append(action_str)

                if action["action_type"] == ActionTypes.STOP:
                    break

                obs, _, terminated, _, info = env.step(action)
                state_info = {"observation": obs, "info": info}
                trajectory.append(state_info)

                if terminated:
                    # add a action place holder
                    trajectory.append(create_stop_action(""))
                    break

            # NOTE: eval_caption_image_fn is used for running eval_vqa functions.
            evaluator = evaluator_router(
                config_file, captioning_fn=eval_caption_image_fn
            )
            score = evaluator(
                trajectory=trajectory,
                config_file=config_file,
                page=env.page,
                client=env.get_page_client(env.page),
            )

            scores.append(score)

            if score == 1:
                logger.info(f"[Result] (PASS) {config_file}")
            else:
                logger.info(f"[Result] (FAIL) {config_file}")

            if args.save_trace_enabled:
                env.save_trace(
                    Path(args.result_dir) / "traces" / f"{task_id}.zip"
                )
        except openai.OpenAIError as e:
            logger.info(f"[OpenAI Error] {repr(e)}")
        except Exception as e:
            logger.info(f"[Unhandled Error] {repr(e)}]")
            import traceback

            # write to error file
            with open(Path(args.result_dir) / "error.txt", "a") as f:
                f.write(f"[Config file]: {config_file}\n")
                f.write(f"[Unhandled Error] {repr(e)}\n")
                f.write(traceback.format_exc())  # write stack trace to file

        render_helper.close()

    env.close()
    logger.info(f"Average score: {sum(scores) / len(scores)}")


def prepare(args: argparse.Namespace) -> None:
    # convert prompt python files to json
    from agent.prompts import to_json

    to_json.run()

    # prepare result dir
    result_dir = args.result_dir
    if not result_dir:
        result_dir = (
            f"cache/results_{time.strftime('%Y%m%d%H%M%S', time.localtime())}"
        )
    if not Path(result_dir).exists():
        Path(result_dir).mkdir(parents=True, exist_ok=True)
        args.result_dir = result_dir
        logger.info(f"Create result dir: {result_dir}")

    if not (Path(result_dir) / "traces").exists():
        (Path(result_dir) / "traces").mkdir(parents=True)

    # log the log file
    with open(os.path.join(result_dir, "log_files.txt"), "a+") as f:
        f.write(f"{LOG_FILE_NAME}\n")


def get_unfinished(config_files: list[str], result_dir: str) -> list[str]:
    result_files = glob.glob(f"{result_dir}/*.html")
    task_ids = [
        os.path.basename(f).split(".")[0].split("_")[1] for f in result_files
    ]
    unfinished_configs = []
    for config_file in config_files:
        task_id = os.path.basename(config_file).split(".")[0]
        if task_id not in task_ids:
            unfinished_configs.append(config_file)
    return unfinished_configs
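# Illustrative sketch, not part of this diff: the split() chains above assume result
# pages named like "<prefix>_<task_id>.html" (for example "render_7.html") and config
# files named "<task_id>.json", so both reduce to the same id string. The helper
# below repeats that extraction in isolation; the example path is hypothetical.
def _task_id_from_result_page(path: str) -> str:
    # "results/render_7.html" -> "7"
    return os.path.basename(path).split(".")[0].split("_")[1]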


@beartype
def dump_config(args: argparse.Namespace) -> None:
    config_file = Path(args.result_dir) / "config.json"
    if not config_file.exists():
        with open(config_file, "w") as f:
            json.dump(vars(args), f, indent=4)
            logger.info(f"Dump config to {config_file}")


if __name__ == '__main__':
    ####### The complete version of the list of examples #######
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    args = config()
    args.sleep_after_execution = 2.5
    args.sleep_after_execution = 5
    prepare(args)

    test_config_base_dir = args.test_config_base_dir

    test_file_list = []
    st_idx = args.test_start_idx
    ed_idx = args.test_end_idx
    for i in range(st_idx, ed_idx):
        test_file_list.append(os.path.join(test_config_base_dir, f"{i}.json"))
    test_file_list = get_unfinished(test_file_list, args.result_dir)
    print(f"Total {len(test_file_list)} tasks left")
    args.render = False
    args.render_screenshot = True
    args.save_trace_enabled = True

    args.current_viewport_only = True
    dump_config(args)

    test(args, test_file_list)

# todo: add recorder of the progress of the examples
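# Hedged sketch for the todo above, not part of this diff: one minimal progress
# recorder would append a JSON line per finished example to the result dir, which a
# get_unfinished-style resume pass could read back later. The function name and file
# name below are hypothetical.
def record_progress(result_dir: str, task_id: str, score: float) -> None:
    # Append one line per finished example: {"task_id": ..., "score": ...}
    with open(os.path.join(result_dir, "progress.jsonl"), "a") as f:
        f.write(json.dumps({"task_id": task_id, "score": score}) + "\n")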

# todo: remove the useless example files

@@ -169,7 +169,7 @@ def parse_code_from_som_string(input_string, masks):
    return actions


class GPT4v_Agent:
class PromptAgent:
    def __init__(
        self,
        api_key,