Files
sci-gui-agent-benchmark/experiment_screenshot.py
2024-03-14 13:16:49 +08:00

336 lines
12 KiB
Python

# todo: unify all the experiment Python files into one file
import argparse
import datetime
import json
import logging
import os
import sys
import func_timeout
from desktop_env.envs.desktop_env import DesktopEnv
from mm_agents.gpt_4v_agent import GPT4v_Agent # todo: change the name into PromptAgent
# Logger Configs {{{ #
# The root logger accepts everything; each handler below applies its own level
# (and, for two of them, a name filter) to decide what it actually emits.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

# One timestamp per process so the three log files of a run share a suffix.
datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")

# logging.FileHandler opens its target file immediately and does not create
# missing parent directories, so make sure "logs" exists before building any
# file handler (otherwise a fresh checkout fails with FileNotFoundError).
os.makedirs("logs", exist_ok=True)

file_handler = logging.FileHandler(os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8")
debug_handler = logging.FileHandler(os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8")
stdout_handler = logging.StreamHandler(sys.stdout)
sdebug_handler = logging.FileHandler(os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8")

file_handler.setLevel(logging.INFO)
debug_handler.setLevel(logging.DEBUG)
stdout_handler.setLevel(logging.INFO)
sdebug_handler.setLevel(logging.DEBUG)

# The format string embeds ANSI colour escapes; the same formatter is shared
# by the console handler and all file handlers.
formatter = logging.Formatter(
    fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s")
file_handler.setFormatter(formatter)
debug_handler.setFormatter(formatter)
stdout_handler.setFormatter(formatter)
sdebug_handler.setFormatter(formatter)

# Restrict the console and the "sdebug" file to records coming from the
# "desktopenv" logger hierarchy; the other two files receive every logger.
stdout_handler.addFilter(logging.Filter("desktopenv"))
sdebug_handler.addFilter(logging.Filter("desktopenv"))

logger.addHandler(file_handler)
logger.addHandler(debug_handler)
logger.addHandler(stdout_handler)
logger.addHandler(sdebug_handler)
# }}} Logger Configs #

# From here on `logger` is this script's own logger; its name sits under
# "desktopenv", so its records pass the name filters installed above.
logger = logging.getLogger("desktopenv.experiment")
# todo: move the PATH_TO_VM to the argparser
# Location of the virtual machine's .vmx file; run_one_example hands it to
# DesktopEnv as `path_to_vm`. It is hard-coded to an absolute path on one
# Windows user account, so it must be edited before running anywhere else.
PATH_TO_VM = r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu\Ubuntu.vmx"
def run_one_example(example, agent, max_steps=10, example_trajectory_dir="exp_trajectory", recording=True,
                    max_time=600):
    """Run one benchmark example in a DesktopEnv and record its trajectory.

    The agent is queried for actions until the environment reports ``done`` or
    ``max_steps`` predictions have been made. After every executed action the
    current screenshot is copied into ``example_trajectory_dir`` and one JSON
    line describing the step is appended to ``trajectory.json`` in the same
    directory. Finally the environment's evaluation result is logged and
    appended to ``trajectory.json`` as ``{"result": ...}``.

    Args:
        example: Task configuration, passed to ``DesktopEnv`` as ``task_config``.
        agent: Object providing ``action_space`` and ``predict(observation)``;
            ``predict`` must return an iterable of actions.
        max_steps: Maximum number of ``agent.predict`` calls.
        example_trajectory_dir: Directory receiving screenshots,
            ``trajectory.json`` and (when recording) ``recording.mp4``.
        recording: When true, start a screen recording before the first step
            and save it after the last one.
        max_time: Currently unused; kept for interface compatibility.

    Returns:
        None. All output goes to the log and to files in
        ``example_trajectory_dir``.
    """
    # Callers other than main() may pass a directory that does not exist yet.
    os.makedirs(example_trajectory_dir, exist_ok=True)
    trajectory_recording_path = os.path.join(example_trajectory_dir, "trajectory.json")

    env = DesktopEnv(
        path_to_vm=PATH_TO_VM,
        action_space=agent.action_space,
        task_config=example,
        headless=True
    )
    # reset the environment to certain snapshot
    observation = env.reset()
    done = False
    step_num = 0

    if recording:
        # send a request to the server to start recording
        env.controller.start_recording()

    while not done and step_num < max_steps:
        actions = agent.predict(observation)
        step_num += 1
        for action in actions:
            # Capture the timestamp before executing the action
            action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
            logger.info("Step %d: %s", step_num, action)

            observation, reward, done, info = env.step(action)

            logger.info("Reward: %.2f", reward)
            logger.info("Done: %s", done)
            logger.info("Info: %s", info)

            # Save screenshot and trajectory information.
            # observation['screenshot'] is opened as a file path here, so the
            # image bytes are copied next to the trajectory file.
            screenshot_file = f"step_{step_num}_{action_timestamp}.png"
            with open(observation['screenshot'], "rb") as src:
                screenshot = src.read()
            with open(os.path.join(example_trajectory_dir, screenshot_file), "wb") as dst:
                dst.write(screenshot)

            # trajectory.json is a JSON-lines file: one object per line.
            with open(trajectory_recording_path, "a") as f:
                f.write(json.dumps({
                    "step_num": step_num,
                    "action_timestamp": action_timestamp,
                    "action": action,
                    "reward": reward,
                    "done": done,
                    "info": info,
                    "screenshot_file": screenshot_file
                }))
                f.write("\n")

            if done:
                logger.info("The episode is done.")
                break

    def stop_recording():
        # Best effort: a failure to save the video must not prevent the
        # evaluation below from running.
        try:
            env.controller.end_recording(os.path.join(example_trajectory_dir, "recording.mp4"))
        except Exception as e:
            logger.error("An error occurred while stopping the recording: %s", e)

    # Only stop a recording that was actually started above.
    if recording:
        try:
            func_timeout.func_timeout(120, stop_recording)
            # todo: make sure we got the video file, check the bug
        except func_timeout.exceptions.FunctionTimedOut:
            logger.info("Recording timed out.")

    result = env.evaluate()
    logger.info("Result: %.2f", result)

    # fixme: change to write the result into a separate file
    with open(trajectory_recording_path, "a") as f:
        f.write(json.dumps({
            "result": result
        }))
        f.write("\n")

    # todo: append the result to the wandb for visualization

    # NOTE(review): env.close() is disabled, so the environment is NOT actually
    # closed here even though the message below says so — confirm whether the
    # call can be re-enabled.
    # env.close()
    logger.info("Environment closed.")
def main(example_class, example_id, gpt4_model="gpt-4-vision-preview"):
    """Run a single benchmark example with the screenshot-based GPT-4V agent.

    Loads ``evaluation_examples/examples/<example_class>/<example_id>.json``,
    builds the agent, and executes the example under a 1200-second time limit.
    Output is written below
    ``exp_trajectory/screenshot/<example_class>/<gpt4_model>/<example_id>``.
    If that directory already holds a ``trajectory.json`` whose last record
    contains ``"result"``, the example is skipped.

    Args:
        example_class: Sub-directory (domain) of the example.
        example_id: File name of the example without the ``.json`` suffix.
        gpt4_model: Model name forwarded to ``GPT4v_Agent``.

    Returns:
        None. Failures and timeouts are logged and appended to
        ``trajectory.json`` as ``{"error": ...}`` instead of being raised.
    """
    # todo: merge the main function into the run_one_example function
    # fixme: change all the settings like action_space, model, etc. to the argparser
    action_space = "pyautogui"
    # gemini_model = "gemini-pro-vision"
    time_limit = 1200  # seconds allowed for the whole example

    example_path = f"evaluation_examples/examples/{example_class}/{example_id}.json"

    logger.info("Running example %s/%s", example_class, example_id)
    logger.info("Using model %s", gpt4_model)
    # logger.info("Using model %s", gemini_model)

    with open(example_path, "r", encoding="utf-8") as f:
        example = json.load(f)
    example["snapshot"] = "exp_v5"

    api_key = os.environ.get("OPENAI_API_KEY")
    agent = GPT4v_Agent(api_key=api_key,
                        model=gpt4_model,
                        instruction=example['instruction'],
                        action_space=action_space,
                        exp="screenshot")
    #
    # api_key = os.environ.get("GENAI_API_KEY")
    # agent = GeminiPro_Agent(api_key=api_key, instruction=example['instruction'], action_space=action_space, exp="screenshot")

    root_trajectory_dir = "exp_trajectory"

    example_trajectory_dir = os.path.join(root_trajectory_dir, "screenshot", example_class, gpt4_model, example_id)
    # example_trajectory_dir = os.path.join(root_trajectory_dir, "screenshot", example_class, gemini_model, example_id)

    os.makedirs(example_trajectory_dir, exist_ok=True)
    trajectory_path = os.path.join(example_trajectory_dir, "trajectory.json")

    # Resume support: an example counts as finished when the last non-empty
    # line of its trajectory file is a record carrying a "result" key.
    if os.path.exists(trajectory_path):
        with open(trajectory_path, "r") as f:
            lines = [line.strip() for line in f if line.strip()]
        if lines:
            try:
                last_record = json.loads(lines[-1])
            except json.JSONDecodeError:
                # A truncated or corrupted last line means the previous run did
                # not finish cleanly; fall through and run the example again.
                last_record = None
            if isinstance(last_record, dict) and "result" in last_record:
                logger.info("%s has been evaluated. Skip.", example_path)
                return

    error_message = None
    try:
        func_timeout.func_timeout(time_limit, run_one_example, args=(example, agent, 15, example_trajectory_dir))
    except func_timeout.exceptions.FunctionTimedOut:
        # FunctionTimedOut derives from BaseException, so `except Exception`
        # alone would let a timeout abort every remaining example.
        error_message = f"Time limit of {time_limit} seconds exceeded"
        logger.error("An error occurred: %s", error_message)
    except Exception as e:
        error_message = str(e)
        logger.exception("An error occurred: %s", e)

    if error_message is not None:
        with open(trajectory_path, "a") as f:
            f.write(json.dumps({
                "error": error_message
            }))
            # Keep the file one-record-per-line so later appends and the
            # resume check above can still parse it.
            f.write("\n")
def config(argv=None) -> argparse.Namespace:
    """Parse the command-line options of the evaluation script.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) the arguments are taken from ``sys.argv[1:]``, which is
            the behaviour of calling ``config()`` without arguments.

    Returns:
        The populated ``argparse.Namespace``.

    Raises:
        ValueError: If ``--action_set_tag`` is ``id_accessibility_tree`` but
            ``--observation_type`` is not one of the observation types
            accepted for that action set.
    """
    parser = argparse.ArgumentParser(
        description="Run end-to-end evaluation on the benchmark"
    )
    parser.add_argument(
        "--render", action="store_true", help="Render the browser"
    )
    parser.add_argument(
        "--slow_mo",
        type=int,
        default=0,
        help="Slow down the browser by the specified amount",
    )
    parser.add_argument(
        "--action_set_tag", default="id_accessibility_tree", help="Action type"
    )
    parser.add_argument(
        "--observation_type",
        choices=[
            "accessibility_tree",
            "accessibility_tree_with_captioner",
            "html",
            "image",
            "image_som",
        ],
        default="accessibility_tree",
        help="Observation type",
    )
    parser.add_argument(
        "--current_viewport_only",
        action="store_true",
        help="Only use the current viewport for the observation",
    )
    parser.add_argument("--viewport_width", type=int, default=1280)
    parser.add_argument("--viewport_height", type=int, default=2048)
    parser.add_argument("--save_trace_enabled", action="store_true")
    parser.add_argument("--sleep_after_execution", type=float, default=0.0)
    parser.add_argument("--max_steps", type=int, default=30)

    # agent config
    parser.add_argument("--agent_type", type=str, default="prompt")
    parser.add_argument(
        "--instruction_path",
        type=str,
        default="agents/prompts/state_action_agent.json",
    )
    parser.add_argument(
        "--parsing_failure_th",
        help="When consecutive parsing failures exceed this threshold, the agent will terminate early.",
        type=int,
        default=3,
    )
    parser.add_argument(
        "--repeating_action_failure_th",
        help="When consecutive repeated actions exceed this threshold, the agent will terminate early.",
        type=int,
        default=5,
    )
    parser.add_argument("--test_config_base_dir", type=str)
    parser.add_argument(
        "--eval_captioning_model_device",
        type=str,
        default="cpu",
        choices=["cpu", "cuda"],
        help="Device to run eval captioning model on. By default, runs it on CPU.",
    )
    parser.add_argument(
        "--eval_captioning_model",
        type=str,
        default="Salesforce/blip2-flan-t5-xl",
        choices=["Salesforce/blip2-flan-t5-xl"],
        help="Captioning backbone for VQA-type evals.",
    )
    parser.add_argument(
        "--captioning_model",
        type=str,
        default="Salesforce/blip2-flan-t5-xl",
        choices=["Salesforce/blip2-flan-t5-xl", "llava-hf/llava-1.5-7b-hf"],
        help="Captioning backbone for accessibility tree alt text.",
    )

    # lm config
    parser.add_argument("--provider", type=str, default="openai")
    parser.add_argument("--model", type=str, default="gpt-3.5-turbo-0613")
    parser.add_argument("--mode", type=str, default="chat")
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--context_length", type=int, default=0)
    parser.add_argument("--max_tokens", type=int, default=384)
    parser.add_argument("--stop_token", type=str, default=None)
    parser.add_argument(
        "--max_retry",
        type=int,
        help="max retry times to perform generations when parsing fails",
        default=1,
    )
    parser.add_argument(
        "--max_obs_length",
        type=int,
        help="when not zero, will truncate the observation to this length before feeding to the model",
        default=3840,
    )

    # example config
    parser.add_argument("--test_start_idx", type=int, default=0)
    parser.add_argument("--test_end_idx", type=int, default=910)

    # logging related
    parser.add_argument("--result_dir", type=str, default="")

    # parse_args(None) falls back to sys.argv[1:], preserving the original
    # no-argument behaviour.
    args = parser.parse_args(argv)

    # check whether the action space is compatible with the observation space
    if (
        args.action_set_tag == "id_accessibility_tree"
        and args.observation_type
        not in [
            "accessibility_tree",
            "accessibility_tree_with_captioner",
            "image_som",
        ]
    ):
        raise ValueError(
            f"Action type {args.action_set_tag} is incompatible with the observation type {args.observation_type}"
        )

    return args
if __name__ == '__main__':
    # Script entry point: run every example listed in
    # evaluation_examples/test_all.json, domain by domain, with the model
    # selected via --model.
    ####### The complete version of the list of examples #######
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    args = config()
    args.sleep_after_execution = 2.5
    # NOTE: the former `prepare(args)` call was dropped — no `prepare` is
    # defined or imported in this file, so it raised NameError before any
    # example could run.

    # todo: add recorder of the progress of the examples

    # todo: remove the useless example files

    # test_all.json maps each domain name to the example ids to run for it.
    with open("evaluation_examples/test_all.json", "r", encoding="utf-8") as f:
        test_all_meta = json.load(f)

    for domain in test_all_meta:
        for example_id in test_all_meta[domain]:
            main(domain, example_id, args.model)