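"""Run a GPT-4V prompt agent over DesktopEnv evaluation examples.

The script loads a task definition from evaluation_examples/examples/, drives
the agent for a bounded number of steps inside the virtual machine referenced
by PATH_TO_VM, records per-step screenshots plus a JSON-lines trajectory.json
(and optionally a screen recording), and finally appends the evaluation
result. Set the OPENAI_API_KEY environment variable and adjust PATH_TO_VM
before running; the __main__ block iterates over every example listed in
evaluation_examples/test_all.json.
"""
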
# TODO: unify all of the experiment scripts into a single file
import argparse
import datetime
import json
import logging
import os
import sys

import func_timeout

from desktop_env.envs.desktop_env import DesktopEnv
from mm_agents.gpt_4v_agent import GPT4v_Agent  # TODO: rename this class to PromptAgent
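
# The root logger fans out to four handlers: two files that keep every module's
# records (normal-*.log at INFO, debug-*.log at DEBUG), plus stdout and
# sdebug-*.log, which are filtered down to the "desktopenv" logger hierarchy.
# The logs/ directory must already exist.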
# Logger Configs {{{ #
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")

file_handler = logging.FileHandler(os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8")
debug_handler = logging.FileHandler(os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8")
stdout_handler = logging.StreamHandler(sys.stdout)
sdebug_handler = logging.FileHandler(os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8")

file_handler.setLevel(logging.INFO)
debug_handler.setLevel(logging.DEBUG)
stdout_handler.setLevel(logging.INFO)
sdebug_handler.setLevel(logging.DEBUG)

formatter = logging.Formatter(
    fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s")
file_handler.setFormatter(formatter)
debug_handler.setFormatter(formatter)
stdout_handler.setFormatter(formatter)
sdebug_handler.setFormatter(formatter)

stdout_handler.addFilter(logging.Filter("desktopenv"))
sdebug_handler.addFilter(logging.Filter("desktopenv"))

logger.addHandler(file_handler)
logger.addHandler(debug_handler)
logger.addHandler(stdout_handler)
logger.addHandler(sdebug_handler)
# }}} Logger Configs #

logger = logging.getLogger("desktopenv.experiment")

# TODO: move PATH_TO_VM into the argument parser
PATH_TO_VM = r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu\Ubuntu.vmx"


def run_one_example(example, agent, max_steps=10, example_trajectory_dir="exp_trajectory", recording=True,
                    max_time=600):
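    """Run the agent on one task configuration and record the trajectory.

    Resets DesktopEnv to the example's snapshot, optionally starts a screen
    recording, then lets the agent act for at most `max_steps` prediction
    rounds. Each executed action is appended to
    <example_trajectory_dir>/trajectory.json as one JSON object per line
    (step_num, action_timestamp, action, reward, done, info, screenshot_file),
    alongside a step_<n>_<timestamp>.png screenshot. The evaluation result is
    appended to the same file at the end. Note that `max_time` is currently
    unused; the overall timeout is enforced by the caller via func_timeout.
    """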
    trajectory_recording_path = os.path.join(example_trajectory_dir, "trajectory.json")
    env = DesktopEnv(
        path_to_vm=PATH_TO_VM,
        action_space=agent.action_space,
        task_config=example,
        headless=True
    )
    # Reset the environment to the snapshot required by this example
    observation = env.reset()
    done = False
    step_num = 0

    if recording:
        # Send a request to the server to start recording
        env.controller.start_recording()

    while not done and step_num < max_steps:
        actions = agent.predict(observation)
        step_num += 1
        for action in actions:
            # Capture the timestamp before executing the action
            action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
            logger.info("Step %d: %s", step_num, action)

            observation, reward, done, info = env.step(action)

            logger.info("Reward: %.2f", reward)
            logger.info("Done: %s", done)
            logger.info("Info: %s", info)

            # Save the screenshot and the trajectory information
            with open(os.path.join(example_trajectory_dir, f"step_{step_num}_{action_timestamp}.png"), "wb") as _f:
                with open(observation['screenshot'], "rb") as __f:
                    screenshot = __f.read()
                _f.write(screenshot)

            with open(trajectory_recording_path, "a") as f:
                f.write(json.dumps({
                    "step_num": step_num,
                    "action_timestamp": action_timestamp,
                    "action": action,
                    "reward": reward,
                    "done": done,
                    "info": info,
                    "screenshot_file": f"step_{step_num}_{action_timestamp}.png"
                }))
                f.write("\n")

            if done:
                logger.info("The episode is done.")
                break

    def stop_recording():
        try:
            env.controller.end_recording(os.path.join(example_trajectory_dir, "recording.mp4"))
        except Exception as e:
            print(f"An error occurred while stopping the recording: {e}")

    if recording:
        # Allow at most two minutes for the recording to be finalized
        try:
            func_timeout.func_timeout(120, stop_recording)
            # TODO: make sure the video file was actually written; investigate the bug
        except func_timeout.exceptions.FunctionTimedOut:
            logger.info("Recording timed out.")

    result = env.evaluate()
    logger.info("Result: %.2f", result)

    # FIXME: write the result into a separate file rather than appending it to the trajectory
    with open(trajectory_recording_path, "a") as f:
        f.write(json.dumps({
            "result": result
        }))
        f.write("\n")

    # TODO: append the result to wandb for visualization

    env.close()
    logger.info("Environment closed.")


def main(example_class, example_id, gpt4_model="gpt-4-vision-preview"):
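    """Set up the agent for one example and hand off to run_one_example.

    Loads the task JSON for example_class/example_id, pins the VM snapshot to
    "exp_v5", builds a screenshot-based GPT4v_Agent from OPENAI_API_KEY, and
    skips the example if its trajectory.json already ends with a result.
    run_one_example is invoked with max_steps=15 under a 20-minute
    func_timeout.
    """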
    # TODO: merge this function into run_one_example
    # FIXME: move settings such as action_space and the model name into the argument parser
    action_space = "pyautogui"
    gemini_model = "gemini-pro-vision"

    logger.info("Running example %s/%s", example_class, example_id)
    logger.info("Using model %s", gpt4_model)
    # logger.info("Using model %s", gemini_model)

    with open(f"evaluation_examples/examples/{example_class}/{example_id}.json", "r", encoding="utf-8") as f:
        example = json.load(f)
    example["snapshot"] = "exp_v5"

    api_key = os.environ.get("OPENAI_API_KEY")
    agent = GPT4v_Agent(api_key=api_key,
                        model=gpt4_model,
                        instruction=example['instruction'],
                        action_space=action_space,
                        exp="screenshot")

    # api_key = os.environ.get("GENAI_API_KEY")
    # agent = GeminiPro_Agent(api_key=api_key, instruction=example['instruction'], action_space=action_space, exp="screenshot")

    root_trajectory_dir = "exp_trajectory"

    example_trajectory_dir = os.path.join(root_trajectory_dir, "screenshot", example_class, gpt4_model, example_id)
    # example_trajectory_dir = os.path.join(root_trajectory_dir, "screenshot", example_class, gemini_model, example_id)

    os.makedirs(example_trajectory_dir, exist_ok=True)

    if os.path.exists(os.path.join(example_trajectory_dir, "trajectory.json")):
        with open(os.path.join(example_trajectory_dir, "trajectory.json"), "r") as f:
            lines = f.readlines()
        # Drop empty lines before inspecting the last record
        lines = [line.strip() for line in lines if line.strip() != ""]
        if len(lines) > 0:
            last_line = json.loads(lines[-1])
            if "result" in last_line:
                logger.info(
                    f"evaluation_examples/examples/{example_class}/{example_id}.json" + " has been evaluated. Skip.")
                return

    try:
        func_timeout.func_timeout(1200, run_one_example, args=(example, agent, 15, example_trajectory_dir))
    except Exception as e:
        print(f"An error occurred: {e}")
        with open(os.path.join(example_trajectory_dir, "trajectory.json"), "a") as f:
            f.write(json.dumps({
                "error": str(e)
            }))
            f.write("\n")


def config() -> argparse.Namespace:
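    """Parse command-line arguments.

    Note: only --model (and the sleep_after_execution attribute, which the
    __main__ block overrides anyway) is consumed by this script; the remaining
    browser-oriented options are parsed but currently unused here.
    """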
    parser = argparse.ArgumentParser(
        description="Run end-to-end evaluation on the benchmark"
    )
    parser.add_argument(
        "--render", action="store_true", help="Render the browser"
    )

    parser.add_argument(
        "--slow_mo",
        type=int,
        default=0,
        help="Slow down the browser by the specified amount",
    )
    parser.add_argument(
        "--action_set_tag", default="id_accessibility_tree", help="Action type"
    )
    parser.add_argument(
        "--observation_type",
        choices=[
            "accessibility_tree",
            "accessibility_tree_with_captioner",
            "html",
            "image",
            "image_som",
        ],
        default="accessibility_tree",
        help="Observation type",
    )
    parser.add_argument(
        "--current_viewport_only",
        action="store_true",
        help="Only use the current viewport for the observation",
    )
    parser.add_argument("--viewport_width", type=int, default=1280)
    parser.add_argument("--viewport_height", type=int, default=2048)
    parser.add_argument("--save_trace_enabled", action="store_true")
    parser.add_argument("--sleep_after_execution", type=float, default=0.0)

    parser.add_argument("--max_steps", type=int, default=30)

    # agent config
    parser.add_argument("--agent_type", type=str, default="prompt")
    parser.add_argument(
        "--instruction_path",
        type=str,
        default="agents/prompts/state_action_agent.json",
    )
    parser.add_argument(
        "--parsing_failure_th",
        help="When consecutive parsing failures exceed this threshold, the agent will terminate early.",
        type=int,
        default=3,
    )
    parser.add_argument(
        "--repeating_action_failure_th",
        help="When consecutive repeated actions exceed this threshold, the agent will terminate early.",
        type=int,
        default=5,
    )

    parser.add_argument("--test_config_base_dir", type=str)

    parser.add_argument(
        "--eval_captioning_model_device",
        type=str,
        default="cpu",
        choices=["cpu", "cuda"],
        help="Device to run eval captioning model on. By default, runs it on CPU.",
    )
    parser.add_argument(
        "--eval_captioning_model",
        type=str,
        default="Salesforce/blip2-flan-t5-xl",
        choices=["Salesforce/blip2-flan-t5-xl"],
        help="Captioning backbone for VQA-type evals.",
    )
    parser.add_argument(
        "--captioning_model",
        type=str,
        default="Salesforce/blip2-flan-t5-xl",
        choices=["Salesforce/blip2-flan-t5-xl", "llava-hf/llava-1.5-7b-hf"],
        help="Captioning backbone for accessibility tree alt text.",
    )

    # lm config
    parser.add_argument("--provider", type=str, default="openai")
    parser.add_argument("--model", type=str, default="gpt-3.5-turbo-0613")
    parser.add_argument("--mode", type=str, default="chat")
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--context_length", type=int, default=0)
    parser.add_argument("--max_tokens", type=int, default=384)
    parser.add_argument("--stop_token", type=str, default=None)
    parser.add_argument(
        "--max_retry",
        type=int,
        help="max retry times to perform generations when parsing fails",
        default=1,
    )
    parser.add_argument(
        "--max_obs_length",
        type=int,
        help="when not zero, will truncate the observation to this length before feeding to the model",
        default=3840,
    )

    # example config
    parser.add_argument("--test_start_idx", type=int, default=0)
    parser.add_argument("--test_end_idx", type=int, default=910)

    # logging related
    parser.add_argument("--result_dir", type=str, default="")
    args = parser.parse_args()

    # Check whether the action space is compatible with the observation space
    if (
        args.action_set_tag == "id_accessibility_tree"
        and args.observation_type
        not in [
            "accessibility_tree",
            "accessibility_tree_with_captioner",
            "image_som",
        ]
    ):
        raise ValueError(
            f"Action type {args.action_set_tag} is incompatible with the observation type {args.observation_type}"
        )

    return args


if __name__ == '__main__':
    ####### The complete version of the list of examples #######
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    args = config()
    args.sleep_after_execution = 2.5
    # prepare(args)  # FIXME: prepare() is not defined or imported anywhere in this script

    # TODO: add a recorder of the progress across the examples

    # TODO: remove the unused example files

    with open("evaluation_examples/test_all.json", "r", encoding="utf-8") as f:
        test_all_meta = json.load(f)
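
    # test_all.json maps each task domain to a list of example IDs; the nested
    # loop below runs every example with the model selected via --model.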
    for domain in test_all_meta:
        for example_id in test_all_meta[domain]:
            main(domain, example_id, args.model)