Refactor experiments and agent implementation

Timothyxxx
2024-03-14 22:32:49 +08:00
parent 71ca8fbe1c
commit 44ff027801
8 changed files with 359 additions and 1944 deletions

View File

@@ -53,8 +53,8 @@ class DesktopEnv(gym.Env):
def __init__(
self,
path_to_vm: str,
snapshot_name: str = "init_state",
action_space: str = "computer_13",
task_config: Dict[str, Any] = None,
tmp_dir: str = "tmp",
cache_dir: str = "cache",
screen_size: Tuple[int] = (1920, 1080),
@@ -64,15 +64,6 @@ class DesktopEnv(gym.Env):
Args:
path_to_vm (str): path to .vmx file
action_space (str): "computer_13" | "pyautogui"
task_config (Dict[str, Any]): manages the task configuration in one place,
including
* base snapshot
* task id (uuid)
* instruction
* setup config
* evaluator config
tmp_dir (str): temporary directory to store trajectory artifacts such as
the extracted screenshots
cache_dir (str): cache directory for task-related files such as
@@ -81,6 +72,7 @@ class DesktopEnv(gym.Env):
# Initialize environment variables
self.path_to_vm = os.path.abspath(os.path.expandvars(os.path.expanduser(path_to_vm)))
self.snapshot_name = snapshot_name
self.tmp_dir_base: str = tmp_dir
self.cache_dir_base: str = cache_dir
self.vm_screen_size = screen_size
@@ -88,16 +80,12 @@ class DesktopEnv(gym.Env):
os.makedirs(self.tmp_dir_base, exist_ok=True)
# task-aware setup
# todo: handle the snapshot directory logic
self._set_task_info(task_config)
# Initialize emulator and controller
logger.info("Initializing...")
self._start_emulator()
self.vm_ip = self._get_vm_ip()
self.controller = PythonController(vm_ip=self.vm_ip)
self.setup_controller = SetupController(vm_ip=self.vm_ip, cache_dir=self.cache_dir)
self.setup_controller = SetupController(vm_ip=self.vm_ip, cache_dir=self.cache_dir_base)
# Meta info of the VM, move to the reset() function
self.vm_platform: str = "" # self.controller.get_vm_platform()
@@ -147,7 +135,7 @@ class DesktopEnv(gym.Env):
raise Exception("Failed to get VM IP address!")
def _save_state(self):
_execute_command(["vmrun", "-T", "ws" "snapshot", self.path_to_vm, self.snapshot_path])
_execute_command(["vmrun", "-T", "ws" "snapshot", self.path_to_vm, self.snapshot_name])
def _get_screenshot(self):
# random_uuid = str(uuid.uuid4())
@@ -167,7 +155,6 @@ class DesktopEnv(gym.Env):
return screenshot_image_path
def _set_task_info(self, task_config: Dict[str, Any]):
self.snapshot_path = task_config["snapshot"] # todo: save the snapshot when first start the environment, and then revert to it when reset
self.task_id: str = task_config["id"]
self.cache_dir: str = os.path.join(self.cache_dir_base, self.task_id)
os.makedirs(self.cache_dir, exist_ok=True)
@@ -239,8 +226,8 @@ class DesktopEnv(gym.Env):
)
os.makedirs(os.path.join(self.tmp_dir, "screenshots"))
logger.info("Reverting to snapshot to {}...".format(self.snapshot_path))
_execute_command(["vmrun", "-T", "ws", "revertToSnapshot", self.path_to_vm, self.snapshot_path])
logger.info("Reverting to snapshot to {}...".format(self.snapshot_name))
_execute_command(["vmrun", "-T", "ws", "revertToSnapshot", self.path_to_vm, self.snapshot_name])
time.sleep(5)
print(self.vm_screen_size)
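Taken together, the hunks above appear to replace the per-task snapshot path (previously read from task_config["snapshot"]) with a constructor-level snapshot_name that defaults to "init_state". The following is an illustrative sketch, not part of this commit: it assumes the constructor signature shown above, and the .vmx path is a placeholder.

# Illustrative sketch (not from this commit) of constructing the refactored
# environment with an explicit snapshot name instead of a per-task snapshot path.
import json

from desktop_env.envs.desktop_env import DesktopEnv

# one of the "os" example ids listed further below in this commit
with open("evaluation_examples/examples/os/94d95f96-9699-4208-98ba-3c3119edf9c2.json", "r", encoding="utf-8") as f:
    task_config = json.load(f)

env = DesktopEnv(
    path_to_vm=r"C:\path\to\Ubuntu\Ubuntu.vmx",  # placeholder path
    snapshot_name="init_state",                  # snapshot is now a constructor argument
    action_space="pyautogui",                    # or "computer_13"
    task_config=task_config,
)
observation = env.reset()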

View File

@@ -1,432 +0,0 @@
import datetime
import json
import logging
import os
import sys
import func_timeout
from desktop_env.envs.desktop_env import DesktopEnv
from mm_agents.gpt_4v_agent import GPT4v_Agent
# Logger Configs {{{ #
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
file_handler = logging.FileHandler(os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8")
debug_handler = logging.FileHandler(os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8")
stdout_handler = logging.StreamHandler(sys.stdout)
sdebug_handler = logging.FileHandler(os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8")
file_handler.setLevel(logging.INFO)
debug_handler.setLevel(logging.DEBUG)
stdout_handler.setLevel(logging.INFO)
sdebug_handler.setLevel(logging.DEBUG)
formatter = logging.Formatter(
fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s")
file_handler.setFormatter(formatter)
debug_handler.setFormatter(formatter)
stdout_handler.setFormatter(formatter)
sdebug_handler.setFormatter(formatter)
stdout_handler.addFilter(logging.Filter("desktopenv"))
sdebug_handler.addFilter(logging.Filter("desktopenv"))
logger.addHandler(file_handler)
logger.addHandler(debug_handler)
logger.addHandler(stdout_handler)
logger.addHandler(sdebug_handler)
# }}} Logger Configs #
logger = logging.getLogger("desktopenv.experiment")
PATH_TO_VM = r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu\Ubuntu.vmx"
def run_one_example(example, agent, max_steps=10, example_trajectory_dir="exp_trajectory", recording=True):
trajectory_recording_path = os.path.join(example_trajectory_dir, "trajectory.json")
env = DesktopEnv(
path_to_vm=PATH_TO_VM,
action_space=agent.action_space,
task_config=example
)
# reset the environment to certain snapshot
observation = env.reset()
done = False
step_num = 0
if recording:
# send a request to the server to start recording
env.controller.start_recording()
while not done and step_num < max_steps:
actions = agent.predict(observation)
step_num += 1
for action in actions:
# Capture the timestamp before executing the action
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
logger.info("Step %d: %s", step_num, action)
observation, reward, done, info = env.step(action)
logger.info("Reward: %.2f", reward)
logger.info("Done: %s", done)
logger.info("Info: %s", info)
# Save screenshot and trajectory information
with open(os.path.join(example_trajectory_dir, f"step_{step_num}_{action_timestamp}.png"), "wb") as _f:
with open(observation['screenshot'], "rb") as __f:
screenshot = __f.read()
_f.write(screenshot)
with open(trajectory_recording_path, "a") as f:
f.write(json.dumps({
"step_num": step_num,
"action_timestamp": action_timestamp,
"action": action,
"reward": reward,
"done": done,
"info": info,
"screenshot_file": f"step_{step_num}_{action_timestamp}.png"
}))
f.write("\n")
if done:
logger.info("The episode is done.")
break
def stop_recording():
try:
env.controller.end_recording(os.path.join(example_trajectory_dir, "recording.mp4"))
except Exception as e:
print(f"An error occurred while stopping the recording: {e}")
try:
func_timeout.func_timeout(30, stop_recording)
except func_timeout.exceptions.FunctionTimedOut:
logger.info("Recording timed out.")
result = env.evaluate()
logger.info("Result: %.2f", result)
with open(trajectory_recording_path, "a") as f:
f.write(json.dumps({
"result": result
}))
f.write("\n")
# env.close()
logger.info("Environment closed.")
def main(example_class, example_id, gpt4_model="gpt-4-0125-preview"):
action_space = "pyautogui"
gemini_model = "gemini-pro-vision"
logger.info("Running example %s/%s", example_class, example_id)
logger.info("Using model %s", gpt4_model)
# logger.info("Using model %s", gemini_model)
with open(f"evaluation_examples/examples/{example_class}/{example_id}.json", "r", encoding="utf-8") as f:
example = json.load(f)
example["snapshot"] = "exp_v5"
api_key = os.environ.get("OPENAI_API_KEY")
agent = GPT4v_Agent(api_key=api_key, model=gpt4_model, instruction=example['instruction'], max_tokens=1000,
action_space=action_space, exp="a11y_tree")
# api_key = os.environ.get("GENAI_API_KEY")
# agent = GeminiPro_Agent(api_key=api_key, model=gemini_model, instruction=example['instruction'], action_space=action_space, exp="a11y_tree")
root_trajectory_dir = "exp_trajectory"
example_trajectory_dir = os.path.join(root_trajectory_dir, "a11y_tree", example_class, gpt4_model, example_id)
# example_trajectory_dir = os.path.join(root_trajectory_dir, "a11y_tree", example_class, gemini_model, example_id)
os.makedirs(example_trajectory_dir, exist_ok=True)
run_one_example(example, agent, 15, example_trajectory_dir)
if __name__ == '__main__':
os_list = [
"94d95f96-9699-4208-98ba-3c3119edf9c2",
"bedcedc4-4d72-425e-ad62-21960b11fe0d",
"43c2d64c-bab5-4dcb-a30c-b888321c319a",
"7688b85f-87a4-4e4a-b2f8-f3d6c3f29b82",
"ec4e3f68-9ea4-4c18-a5c9-69f89d1178b3",
"f9be0997-4b7c-45c5-b05c-4612b44a6118",
"28cc3b7e-b194-4bc9-8353-d04c0f4d56d2",
"5ea617a3-0e86-4ba6-aab2-dac9aa2e8d57",
"e0df059f-28a6-4169-924f-b9623e7184cc",
"ddc75b62-7311-4af8-bfb3-859558542b36",
"b6781586-6346-41cd-935a-a6b1487918fc",
"3ce045a0-877b-42aa-8d2c-b4a863336ab8",
"a4d98375-215b-4a4d-aee9-3d4370fccc41",
"13584542-872b-42d8-b299-866967b5c3ef",
"23393935-50c7-4a86-aeea-2b78fd089c5c"
]
# for example_id in os_list:
# try:
# main("os", example_id, gpt4_model="gpt-3.5-turbo-16k")
# except Exception as e:
# logger.error("An error occurred while running the example: %s", e)
# continue
vlc_list = [
"8ba5ae7a-5ae5-4eab-9fcc-5dd4fe3abf89",
"8ba5ae7a-5ae5-4eab-9fcc-5dd4fe3abf89",
"8f080098-ddb1-424c-b438-4e96e5e4786e",
"bba3381f-b5eb-4439-bd9e-80c22218d5a7",
"fba2c100-79e8-42df-ae74-b592418d54f4",
"efcf0d81-0835-4880-b2fd-d866e8bc2294",
"8d9fd4e2-6fdb-46b0-b9b9-02f06495c62f",
"aa4b5023-aef6-4ed9-bdc9-705f59ab9ad6",
"386dbd0e-0241-4a0a-b6a2-6704fba26b1c",
"9195653c-f4aa-453d-aa95-787f6ccfaae9",
"d06f0d4d-2cd5-4ede-8de9-598629438c6e",
"a5bbbcd5-b398-4c91-83d4-55e1e31bbb81",
"f3977615-2b45-4ac5-8bba-80c17dbe2a37",
"215dfd39-f493-4bc3-a027-8a97d72c61bf"
]
chrome_list = [
"bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
"7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3",
"06fe7178-4491-4589-810f-2e2bc9502122",
"e1e75309-3ddb-4d09-92ec-de869c928143",
"35253b65-1c19-4304-8aa4-6884b8218fc0",
"2ad9387a-65d8-4e33-ad5b-7580065a27ca",
"7a5a7856-f1b6-42a4-ade9-1ca81ca0f263",
"44ee5668-ecd5-4366-a6ce-c1c9b8d4e938",
"2ae9ba84-3a0d-4d4c-8338-3a1478dc5fe3",
"480bcfea-d68f-4aaa-a0a9-2589ef319381",
"af630914-714e-4a24-a7bb-f9af687d3b91"
]
calc_list = [
"eb03d19a-b88d-4de4-8a64-ca0ac66f426b",
"0bf05a7d-b28b-44d2-955a-50b41e24012a",
"7a4e4bc8-922c-4c84-865c-25ba34136be1",
"2bd59342-0664-4ccb-ba87-79379096cc08",
"ecb0df7a-4e8d-4a03-b162-053391d3afaf",
"7efeb4b1-3d19-4762-b163-63328d66303b",
"4e6fcf72-daf3-439f-a232-c434ce416af6",
"6054afcb-5bab-4702-90a0-b259b5d3217c",
"abed40dc-063f-4598-8ba5-9fe749c0615d",
"01b269ae-2111-4a07-81fd-3fcd711993b0",
"8b1ce5f2-59d2-4dcc-b0b0-666a714b9a14",
"0cecd4f3-74de-457b-ba94-29ad6b5dafb6",
"4188d3a4-077d-46b7-9c86-23e1a036f6c1",
"51b11269-2ca8-4b2a-9163-f21758420e78",
"7e429b8d-a3f0-4ed0-9b58-08957d00b127",
"347ef137-7eeb-4c80-a3bb-0951f26a8aff",
"6e99a1ad-07d2-4b66-a1ce-ece6d99c20a5",
"3aaa4e37-dc91-482e-99af-132a612d40f3",
"37608790-6147-45d0-9f20-1137bb35703d",
"f9584479-3d0d-4c79-affa-9ad7afdd8850",
"d681960f-7bc3-4286-9913-a8812ba3261a",
"21df9241-f8d7-4509-b7f1-37e501a823f7",
"1334ca3e-f9e3-4db8-9ca7-b4c653be7d17",
"357ef137-7eeb-4c80-a3bb-0951f26a8aff",
"aa3a8974-2e85-438b-b29e-a64df44deb4b",
"a01fbce3-2793-461f-ab86-43680ccbae25",
"4f07fbe9-70de-4927-a4d5-bb28bc12c52c",
]
# for example_id in calc_list:
# main("libreoffice_calc", example_id)
impress_list = [
"5d901039-a89c-4bfb-967b-bf66f4df075e",
"550ce7e7-747b-495f-b122-acdc4d0b8e54",
"455d3c66-7dc6-4537-a39a-36d3e9119df7",
"af23762e-2bfd-4a1d-aada-20fa8de9ce07",
"c59742c0-4323-4b9d-8a02-723c251deaa0",
"ef9d12bd-bcee-4ba0-a40e-918400f43ddf",
"9ec204e4-f0a3-42f8-8458-b772a6797cab",
"0f84bef9-9790-432e-92b7-eece357603fb",
"ce88f674-ab7a-43da-9201-468d38539e4a",
"3b27600c-3668-4abd-8f84-7bcdebbccbdb",
"a097acff-6266-4291-9fbd-137af7ecd439",
"bf4e9888-f10f-47af-8dba-76413038b73c",
"21760ecb-8f62-40d2-8d85-0cee5725cb72"
]
# for example_id in impress_list:
# main("libreoffice_impress", example_id)
thunderbird_list = [
# "bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
# "7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3",
"12086550-11c0-466b-b367-1d9e75b3910e",
"06fe7178-4491-4589-810f-2e2bc9502122",
"6766f2b8-8a72-417f-a9e5-56fcaa735837",
"e1e75309-3ddb-4d09-92ec-de869c928143",
"3d1682a7-0fb0-49ae-a4dc-a73afd2d06d5",
"35253b65-1c19-4304-8aa4-6884b8218fc0",
"d088f539-cab4-4f9a-ac92-9999fc3a656e",
"2ad9387a-65d8-4e33-ad5b-7580065a27ca",
"480bcfea-d68f-4aaa-a0a9-2589ef319381",
"030eeff7-b492-4218-b312-701ec99ee0cc",
"94760984-3ff5-41ee-8347-cf1af709fea0",
"99146c54-4f37-4ab8-9327-5f3291665e1e",
"c9e7eaf2-b1a1-4efc-a982-721972fa9f02"
]
# for example_id in thunderbird_list:
# main("thunderbird", example_id)
gimp_list = [
"7a4deb26-d57d-4ea9-9a73-630f66a7b568",
"554785e9-4523-4e7a-b8e1-8016f565f56a",
"77b8ab4d-994f-43ac-8930-8ca087d7c4b4",
"f4aec372-4fb0-4df5-a52b-79e0e2a5d6ce",
"d52d6308-ec58-42b7-a2c9-de80e4837b2b",
"2a729ded-3296-423d-aec4-7dd55ed5fbb3",
"b148e375-fe0b-4bec-90e7-38632b0d73c2",
"a746add2-cab0-4740-ac36-c3769d9bfb46",
"7b7617bd-57cc-468e-9c91-40c4ec2bcb3d",
"d16c99dc-2a1e-46f2-b350-d97c86c85c15",
"06ca5602-62ca-47f6-ad4f-da151cde54cc",
"e2dd0213-26db-4349-abe5-d5667bfd725c",
"f723c744-e62c-4ae6-98d1-750d3cd7d79d",
"72f83cdc-bf76-4531-9a1b-eb893a13f8aa",
"7767eef2-56a3-4cea-8c9f-48c070c7d65b",
"734d6579-c07d-47a8-9ae2-13339795476b"
]
# for example_id in gimp_list:
# try:
# main("gimp", example_id)
# except Exception as e:
# logger.error("An error occurred while running the example: %s", e)
# continue
vs_code_list = [
"0ed39f63-6049-43d4-ba4d-5fa2fe04a951",
"53ad5833-3455-407b-bbc6-45b4c79ab8fb",
"eabc805a-bfcf-4460-b250-ac92135819f6",
"982d12a5-beab-424f-8d38-d2a48429e511",
"4e60007a-f5be-4bfc-9723-c39affa0a6d3",
"e2b5e914-ffe1-44d2-8e92-58f8c5d92bb2",
"9439a27b-18ae-42d8-9778-5f68f891805e",
"ea98c5d7-3cf9-4f9b-8ad3-366b58e0fcae",
"930fdb3b-11a8-46fe-9bac-577332e2640e",
"276cc624-87ea-4f08-ab93-f770e3790175",
"9d425400-e9b2-4424-9a4b-d4c7abac4140"
]
# for example_id in vs_code_list:
# try:
# main("vs_code", example_id)
# except Exception as e:
# logger.error("An error occurred while running the example: %s", e)
# continue
from tqdm import tqdm
# for example_id in tqdm(vlc_list):
# try:
# main("vlc", example_id, gpt4_model="gpt-3.5-turbo-16k")
# except Exception as e:
# print(f"An error occurred while running the example: {e}")
# continue
chrome_list = [
"bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
"7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3",
"06fe7178-4491-4589-810f-2e2bc9502122",
"e1e75309-3ddb-4d09-92ec-de869c928143",
"35253b65-1c19-4304-8aa4-6884b8218fc0",
"2ad9387a-65d8-4e33-ad5b-7580065a27ca",
"7a5a7856-f1b6-42a4-ade9-1ca81ca0f263",
"44ee5668-ecd5-4366-a6ce-c1c9b8d4e938",
"2ae9ba84-3a0d-4d4c-8338-3a1478dc5fe3",
"480bcfea-d68f-4aaa-a0a9-2589ef319381",
"af630914-714e-4a24-a7bb-f9af687d3b91"
]
# for example_id in tqdm(chrome_list):
# try:
# main("chrome", example_id, gpt4_model="gpt-3.5-turbo-16k")
# except Exception as e:
# print(f"An error occurred while running the example: {e}")
# continue
vs_code_list = [
# "0ed39f63-6049-43d4-ba4d-5fa2fe04a951",
# "53ad5833-3455-407b-bbc6-45b4c79ab8fb",
# "eabc805a-bfcf-4460-b250-ac92135819f6",
# "982d12a5-beab-424f-8d38-d2a48429e511",
# "4e60007a-f5be-4bfc-9723-c39affa0a6d3",
# "e2b5e914-ffe1-44d2-8e92-58f8c5d92bb2",
# "9439a27b-18ae-42d8-9778-5f68f891805e",
# "ea98c5d7-3cf9-4f9b-8ad3-366b58e0fcae",
# "930fdb3b-11a8-46fe-9bac-577332e2640e",
# "276cc624-87ea-4f08-ab93-f770e3790175",
# "9d425400-e9b2-4424-9a4b-d4c7abac4140"
]
for example_id in tqdm(vs_code_list):
try:
main("vs_code", example_id, gpt4_model="gpt-3.5-turbo-16k")
except Exception as e:
print(f"An error occurred while running the example: {e}")
continue
thunderbird_list = [
# "bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
# "7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3",
"12086550-11c0-466b-b367-1d9e75b3910e",
"06fe7178-4491-4589-810f-2e2bc9502122",
"6766f2b8-8a72-417f-a9e5-56fcaa735837",
"e1e75309-3ddb-4d09-92ec-de869c928143",
"3d1682a7-0fb0-49ae-a4dc-a73afd2d06d5",
"35253b65-1c19-4304-8aa4-6884b8218fc0",
"d088f539-cab4-4f9a-ac92-9999fc3a656e",
"2ad9387a-65d8-4e33-ad5b-7580065a27ca",
"480bcfea-d68f-4aaa-a0a9-2589ef319381",
"030eeff7-b492-4218-b312-701ec99ee0cc",
"94760984-3ff5-41ee-8347-cf1af709fea0",
"99146c54-4f37-4ab8-9327-5f3291665e1e",
"c9e7eaf2-b1a1-4efc-a982-721972fa9f02"
]
# for example_id in tqdm(thunderbird_list):
# try:
# main("thunderbird", example_id, gpt4_model="gpt-3.5-turbo-16k")
# except Exception as e:
# print(f"An error occurred while running the example: {e}")
# continue
multiple_list = [
# "f8cfa149-d1c1-4215-8dac-4a0932bad3c2",
# "897e3b53-5d4d-444b-85cb-2cdc8a97d903",
"2fe4b718-3bd7-46ec-bdce-b184f5653624",
"3680a5ee-6870-426a-a997-eba929a0d25c",
# "4e9f0faf-2ecc-4ae8-a804-28c9a75d1ddc",
# "b52b40a5-ad70-4c53-b5b0-5650a8387052",
# "46407397-a7d5-4c6b-92c6-dbe038b1457b",
# "2b9493d7-49b8-493a-a71b-56cd1f4d6908",
# "51f5801c-18b3-4f25-b0c3-02f85507a078",
"58565672-7bfe-48ab-b828-db349231de6b",
# "2c9fc0de-3ee7-45e1-a5df-c86206ad78b5",
# "510f64c8-9bcc-4be1-8d30-638705850618",
# "937087b6-f668-4ba6-9110-60682ee33441",
# "ee9a3c83-f437-4879-8918-be5efbb9fac7",
# "3680a5ee-6870-426a-a997-eba929a0d25c",
# "e135df7c-7687-4ac0-a5f0-76b74438b53e",
"ee9a3c83-f437-4879-8918-be5efbb9fac7",
# "58565672-7bfe-48ab-b828-db349231de6b",
# "2fe4b718-3bd7-46ec-bdce-b184f5653624"
]
for example_id in multiple_list:
try:
main("multi_apps", example_id, gpt4_model="gpt-3.5-turbo-16k")
except Exception as e:
logger.error("An error occurred while running the example: %s", e)
continue
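For reference, run_one_example above writes trajectory.json in JSON Lines form: one object per step with the fields step_num, action_timestamp, action, reward, done, info and screenshot_file, plus a final {"result": ...} record appended after evaluation. The snippet below is a minimal reading-side sketch, not part of this commit; the helper name summarize_trajectory is hypothetical.

# Illustrative sketch (not from this commit): summarize a JSONL trajectory file.
import json

def summarize_trajectory(trajectory_path: str) -> None:
    steps, result = 0, None
    with open(trajectory_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            record = json.loads(line)
            if "step_num" in record:
                steps = max(steps, record["step_num"])
            if "result" in record:
                result = record["result"]
    print(f"steps recorded: {steps}, final result: {result}")

# e.g. summarize_trajectory("exp_trajectory/a11y_tree/os/gpt-4-0125-preview/<example_id>/trajectory.json")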

View File

@@ -1,604 +0,0 @@
"""Script to run end-to-end evaluation on the benchmark.
Utils and basic architecture credit to https://github.com/web-arena-x/webarena/blob/main/run.py.
"""
import datetime
import sys
import func_timeout
import argparse
import glob
import json
import logging
import os
import time
from pathlib import Path
import openai
import requests
import torch
from beartype import beartype
from desktop_env.envs.desktop_env import DesktopEnv
from mm_agents.agent import PromptAgent # todo: change the name into PromptAgent
# Logger Configs {{{ #
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
file_handler = logging.FileHandler(os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8")
debug_handler = logging.FileHandler(os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8")
stdout_handler = logging.StreamHandler(sys.stdout)
sdebug_handler = logging.FileHandler(os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8")
file_handler.setLevel(logging.INFO)
debug_handler.setLevel(logging.DEBUG)
stdout_handler.setLevel(logging.INFO)
sdebug_handler.setLevel(logging.DEBUG)
formatter = logging.Formatter(
fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s")
file_handler.setFormatter(formatter)
debug_handler.setFormatter(formatter)
stdout_handler.setFormatter(formatter)
sdebug_handler.setFormatter(formatter)
stdout_handler.addFilter(logging.Filter("desktopenv"))
sdebug_handler.addFilter(logging.Filter("desktopenv"))
logger.addHandler(file_handler)
logger.addHandler(debug_handler)
logger.addHandler(stdout_handler)
logger.addHandler(sdebug_handler)
# }}} Logger Configs #
logger = logging.getLogger("desktopenv.experiment")
def run_one_example(example, agent, max_steps=10, example_trajectory_dir="exp_trajectory", recording=True,
max_time=600):
trajectory_recording_path = os.path.join(example_trajectory_dir, "trajectory.json")
env = DesktopEnv(
path_to_vm=PATH_TO_VM,
action_space=agent.action_space,
task_config=example,
headless=True
)
# reset the environment to certain snapshot
observation = env.reset()
done = False
step_num = 0
if recording:
# send a request to the server to start recording
env.controller.start_recording()
while not done and step_num < max_steps:
actions = agent.predict(observation)
step_num += 1
for action in actions:
# Capture the timestamp before executing the action
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
logger.info("Step %d: %s", step_num, action)
observation, reward, done, info = env.step(action)
logger.info("Reward: %.2f", reward)
logger.info("Done: %s", done)
logger.info("Info: %s", info)
# Save screenshot and trajectory information
with open(os.path.join(example_trajectory_dir, f"step_{step_num}_{action_timestamp}.png"), "wb") as _f:
with open(observation['screenshot'], "rb") as __f:
screenshot = __f.read()
_f.write(screenshot)
with open(trajectory_recording_path, "a") as f:
f.write(json.dumps({
"step_num": step_num,
"action_timestamp": action_timestamp,
"action": action,
"reward": reward,
"done": done,
"info": info,
"screenshot_file": f"step_{step_num}_{action_timestamp}.png"
}))
f.write("\n")
if done:
logger.info("The episode is done.")
break
def stop_recording():
try:
env.controller.end_recording(os.path.join(example_trajectory_dir, "recording.mp4"))
except Exception as e:
print(f"An error occurred while stopping the recording: {e}")
try:
func_timeout.func_timeout(120, stop_recording)
# todo: make sure we got the video file, check the bug
except func_timeout.exceptions.FunctionTimedOut:
logger.info("Recording timed out.")
result = env.evaluate()
logger.info("Result: %.2f", result)
# fixme: change to write the result into a separate file
with open(trajectory_recording_path, "a") as f:
f.write(json.dumps({
"result": result
}))
f.write("\n")
# todo: append the result to the wandb for visualization
# env.close()
logger.info("Environment closed.")
def main(example_class, example_id, gpt4_model="gpt-4-vision-preview"):
# todo: merge the main function into the run_one_example function
# fixme: change all the settings like action_space, model, etc. to the argparser
action_space = "pyautogui"
gemini_model = "gemini-pro-vision"
logger.info("Running example %s/%s", example_class, example_id)
logger.info("Using model %s", gpt4_model)
# logger.info("Using model %s", gemini_model)
with open(f"evaluation_examples/examples/{example_class}/{example_id}.json", "r", encoding="utf-8") as f:
example = json.load(f)
example["snapshot"] = "exp_v5"
api_key = os.environ.get("OPENAI_API_KEY")
agent = PromptAgent(
api_key=api_key,
model=gpt4_model,
instruction=example['instruction'],
action_space=action_space,
exp="screenshot"
)
root_trajectory_dir = "exp_trajectory"
example_trajectory_dir = os.path.join(root_trajectory_dir, "screenshot", example_class, gpt4_model, example_id)
# example_trajectory_dir = os.path.join(root_trajectory_dir, "screenshot", example_class, gemini_model, example_id)
os.makedirs(example_trajectory_dir, exist_ok=True)
if os.path.exists(os.path.join(example_trajectory_dir, "trajectory.json")):
with open(os.path.join(example_trajectory_dir, "trajectory.json"), "r") as f:
lines = f.readlines()
# strip the last line if it is empty
lines = [line.strip() for line in lines if line.strip() != ""]
if len(lines) > 0:
last_line = json.loads(lines[-1])
if "result" in last_line:
logger.info(
f"evaluation_examples/examples/{example_class}/{example_id}.json" + "has been evaluated. Skip.")
return
try:
func_timeout.func_timeout(1200, run_one_example, args=(example, agent, 15, example_trajectory_dir))
except Exception as e:
print(f"An error occurred: {e}")
with open(os.path.join(example_trajectory_dir, "trajectory.json"), "a") as f:
f.write(json.dumps({
"error": str(e)
}))
def config() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Run end-to-end evaluation on the benchmark"
)
# environment config
parser.add_argument("--path_to_vm", type=str, default="Ubuntu\\Ubuntu.vmx")
parser.add_argument(
"--headless", action="store_true", help="Run in headless machine"
)
parser.add_argument("--action_space", type=str, default="pyautogui", help="Action type")
parser.add_argument(
"--observation_type",
choices=[
"screenshot",
"a11y_tree",
"screenshot_a11y_tree",
"som"
],
default="accessibility_tree",
help="Observation type",
)
# parser.add_argument(
# "--current_viewport_only",
# action="store_true",
# help="Only use the current viewport for the observation",
# )
parser.add_argument("--screen_width", type=int, default=1920)
parser.add_argument("--screen_height", type=int, default=1080)
parser.add_argument("--save_trace_enabled", action="store_true")
parser.add_argument("--sleep_after_execution", type=float, default=0.0)
parser.add_argument("--max_steps", type=int, default=30)
# agent config
parser.add_argument("--agent_type", type=str, default="prompt")
parser.add_argument(
"--instruction_path",
type=str,
default="",
)
parser.add_argument(
"--parsing_failure_th",
help="When consecutive parsing failures exceed this threshold, the agent will terminate early.",
type=int,
default=3,
)
parser.add_argument(
"--repeating_action_failure_th",
help="When consecutive repeated actions exceed this threshold, the agent will terminate early.",
type=int,
default=5,
)
parser.add_argument("--test_config_base_dir", type=str)
# lm config
parser.add_argument("--provider", type=str, default="openai")
parser.add_argument("--model", type=str, default="gpt-3.5-turbo-0613")
parser.add_argument("--mode", type=str, default="chat")
parser.add_argument("--temperature", type=float, default=1.0)
parser.add_argument("--top_p", type=float, default=0.9)
parser.add_argument("--context_length", type=int, default=0)
parser.add_argument("--max_tokens", type=int, default=384)
parser.add_argument("--stop_token", type=str, default=None)
parser.add_argument(
"--max_retry",
type=int,
help="max retry times to perform generations when parsing fails",
default=1,
)
parser.add_argument(
"--max_obs_length",
type=int,
help="when not zero, will truncate the observation to this length before feeding to the model",
default=3840,
)
# example config
parser.add_argument("--test_start_idx", type=int, default=0)
parser.add_argument("--test_end_idx", type=int, default=378)
# logging related
parser.add_argument("--result_dir", type=str, default="")
args = parser.parse_args()
return args
@beartype
def early_stop(
trajectory, max_steps: int, thresholds: dict[str, int]
) -> tuple[bool, str]:
"""Check whether need to stop early"""
# reach the max step
num_steps = (len(trajectory) - 1) / 2
if num_steps >= max_steps:
return True, f"Reach max steps {max_steps}"
# Case: parsing failure for k times
k = thresholds["parsing_failure"]
last_k_actions = trajectory[1::2][-k:] # type: ignore[assignment]
if len(last_k_actions) >= k:
if all(
[
action["action_type"] == ""
for action in last_k_actions
]
):
return True, f"Failed to parse actions for {k} times"
# Case: same action for k times
k = thresholds["repeating_action"]
last_k_actions = trajectory[1::2][-k:] # type: ignore[assignment]
action_seq = trajectory[1::2] # type: ignore[assignment]
if len(action_seq) == 0:
return False, ""
last_action = action_seq[-1]
if last_action["action_type"] != ActionTypes.TYPE:
if len(last_k_actions) >= k:
if all(
[
is_equivalent(action, last_action)
for action in last_k_actions
]
):
return True, f"Same action for {k} times"
else:
# check the action sequence
if (
sum([is_equivalent(action, last_action) for action in action_seq])
>= k
):
return True, f"Same typing action for {k} times"
return False, ""
@beartype
def test(
args: argparse.Namespace,
config_file_list: list[str]
) -> None:
scores = []
max_steps = args.max_steps
early_stop_thresholds = {
"parsing_failure": args.parsing_failure_th,
"repeating_action": args.repeating_action_failure_th,
}
if args.observation_type in [
"accessibility_tree_with_captioner",
"image_som",
]:
device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
caption_image_fn = image_utils.get_captioning_fn(
device, dtype, args.captioning_model
)
else:
caption_image_fn = None
# Load a (possibly different) captioning model for running VQA evals.
if (
caption_image_fn
and args.eval_captioning_model == args.captioning_model
):
eval_caption_image_fn = caption_image_fn
else:
eval_caption_image_fn = image_utils.get_captioning_fn(
args.eval_captioning_model_device,
torch.float16
if (
torch.cuda.is_available()
and args.eval_captioning_model_device == "cuda"
)
else torch.float32,
args.eval_captioning_model,
)
agent = construct_agent(
args,
captioning_fn=caption_image_fn
if args.observation_type == "accessibility_tree_with_captioner"
else None,
) # NOTE: captioning_fn here is used for captioning input images.
env = ScriptBrowserEnv(
headless=not args.render,
slow_mo=args.slow_mo,
observation_type=args.observation_type,
current_viewport_only=args.current_viewport_only,
viewport_size={
"width": args.viewport_width,
"height": args.viewport_height,
},
save_trace_enabled=args.save_trace_enabled,
sleep_after_execution=args.sleep_after_execution,
# NOTE: captioning_fn here is used for LLM + captioning baselines.
# This can be different from the captioning model used for evals.
captioning_fn=caption_image_fn,
)
for config_file in config_file_list:
try:
render_helper = RenderHelper(
config_file, args.result_dir, args.action_set_tag
)
# Load task.
with open(config_file) as f:
_c = json.load(f)
intent = _c["intent"]
task_id = _c["task_id"]
image_paths = _c.get("image", None)
images = []
# Load input images for the task, if any.
if image_paths is not None:
if isinstance(image_paths, str):
image_paths = [image_paths]
for image_path in image_paths:
# Load image either from the web or from a local path.
if image_path.startswith("http"):
input_image = Image.open(requests.get(image_path, stream=True).raw)
else:
input_image = Image.open(image_path)
images.append(input_image)
logger.info(f"[Config file]: {config_file}")
logger.info(f"[Intent]: {intent}")
agent.reset(config_file)
trajectory: Trajectory = []
obs, info = env.reset(options={"config_file": config_file})
state_info: StateInfo = {"observation": obs, "info": info}
trajectory.append(state_info)
meta_data = {"action_history": ["None"]}
while True:
early_stop_flag, stop_info = early_stop(
trajectory, max_steps, early_stop_thresholds
)
if early_stop_flag:
action = create_stop_action(f"Early stop: {stop_info}")
else:
try:
action = agent.next_action(
trajectory,
intent,
images=images,
meta_data=meta_data,
)
except ValueError as e:
# get the error message
action = create_stop_action(f"ERROR: {str(e)}")
trajectory.append(action)
action_str = get_action_description(
action,
state_info["info"]["observation_metadata"],
action_set_tag=args.action_set_tag,
prompt_constructor=agent.prompt_constructor
if isinstance(agent, PromptAgent)
else None,
)
render_helper.render(
action, state_info, meta_data, args.render_screenshot
)
meta_data["action_history"].append(action_str)
if action["action_type"] == ActionTypes.STOP:
break
obs, _, terminated, _, info = env.step(action)
state_info = {"observation": obs, "info": info}
trajectory.append(state_info)
if terminated:
# add an action placeholder
trajectory.append(create_stop_action(""))
break
# NOTE: eval_caption_image_fn is used for running eval_vqa functions.
evaluator = evaluator_router(
config_file, captioning_fn=eval_caption_image_fn
)
score = evaluator(
trajectory=trajectory,
config_file=config_file,
page=env.page,
client=env.get_page_client(env.page),
)
scores.append(score)
if score == 1:
logger.info(f"[Result] (PASS) {config_file}")
else:
logger.info(f"[Result] (FAIL) {config_file}")
if args.save_trace_enabled:
env.save_trace(
Path(args.result_dir) / "traces" / f"{task_id}.zip"
)
except openai.OpenAIError as e:
logger.info(f"[OpenAI Error] {repr(e)}")
except Exception as e:
logger.info(f"[Unhandled Error] {repr(e)}]")
import traceback
# write to error file
with open(Path(args.result_dir) / "error.txt", "a") as f:
f.write(f"[Config file]: {config_file}\n")
f.write(f"[Unhandled Error] {repr(e)}\n")
f.write(traceback.format_exc()) # write stack trace to file
render_helper.close()
env.close()
logger.info(f"Average score: {sum(scores) / len(scores)}")
def prepare(args: argparse.Namespace) -> None:
# convert prompt python files to json
from agent.prompts import to_json
to_json.run()
# prepare result dir
result_dir = args.result_dir
if not result_dir:
result_dir = (
f"cache/results_{time.strftime('%Y%m%d%H%M%S', time.localtime())}"
)
if not Path(result_dir).exists():
Path(result_dir).mkdir(parents=True, exist_ok=True)
args.result_dir = result_dir
logger.info(f"Create result dir: {result_dir}")
if not (Path(result_dir) / "traces").exists():
(Path(result_dir) / "traces").mkdir(parents=True)
# log the log file
with open(os.path.join(result_dir, "log_files.txt"), "a+") as f:
f.write(f"{LOG_FILE_NAME}\n")
def get_unfinished(config_files: list[str], result_dir: str) -> list[str]:
result_files = glob.glob(f"{result_dir}/*.html")
task_ids = [
os.path.basename(f).split(".")[0].split("_")[1] for f in result_files
]
unfinished_configs = []
for config_file in config_files:
task_id = os.path.basename(config_file).split(".")[0]
if task_id not in task_ids:
unfinished_configs.append(config_file)
return unfinished_configs
@beartype
def dump_config(args: argparse.Namespace) -> None:
config_file = Path(args.result_dir) / "config.json"
if not config_file.exists():
with open(config_file, "w") as f:
json.dump(vars(args), f, indent=4)
logger.info(f"Dump config to {config_file}")
if __name__ == '__main__':
####### The complete version of the list of examples #######
os.environ["TOKENIZERS_PARALLELISM"] = "false"
args = config()
args.sleep_after_execution = 5
prepare(args)
test_config_base_dir = args.test_config_base_dir
test_file_list = []
st_idx = args.test_start_idx
ed_idx = args.test_end_idx
for i in range(st_idx, ed_idx):
test_file_list.append(os.path.join(test_config_base_dir, f"{i}.json"))
test_file_list = get_unfinished(test_file_list, args.result_dir)
print(f"Total {len(test_file_list)} tasks left")
args.render = False
args.render_screenshot = True
args.save_trace_enabled = True
args.current_viewport_only = True
dump_config(args)
test(args, test_file_list)
# todo: add recorder of the progress of the examples
# todo: remove the useless example files
with open("evaluation_examples/test_all.json", "r", encoding="utf-8") as f:
test_all_meta = json.load(f)
for domain in test_all_meta:
for example_id in test_all_meta[domain]:
main(domain, example_id, args.model)

View File

@@ -1,361 +0,0 @@
import datetime
import json
import logging
import os
import sys
import func_timeout
from desktop_env.envs.desktop_env import DesktopEnv
from mm_agents.gpt_4v_agent import GPT4v_Agent
# Logger Configs {{{ #
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
file_handler = logging.FileHandler(os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8")
debug_handler = logging.FileHandler(os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8")
stdout_handler = logging.StreamHandler(sys.stdout)
sdebug_handler = logging.FileHandler(os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8")
file_handler.setLevel(logging.INFO)
debug_handler.setLevel(logging.DEBUG)
stdout_handler.setLevel(logging.INFO)
sdebug_handler.setLevel(logging.DEBUG)
formatter = logging.Formatter(
fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s")
file_handler.setFormatter(formatter)
debug_handler.setFormatter(formatter)
stdout_handler.setFormatter(formatter)
sdebug_handler.setFormatter(formatter)
stdout_handler.addFilter(logging.Filter("desktopenv"))
sdebug_handler.addFilter(logging.Filter("desktopenv"))
logger.addHandler(file_handler)
logger.addHandler(debug_handler)
logger.addHandler(stdout_handler)
logger.addHandler(sdebug_handler)
# }}} Logger Configs #
logger = logging.getLogger("desktopenv.experiment")
PATH_TO_VM = r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu2\Ubuntu2.vmx"
# PATH_TO_VM = "../../../../大文件/镜像/Ubuntu-1218/Ubuntu/Ubuntu.vmx"
def run_one_example(example, agent, max_steps=10, example_trajectory_dir="exp_trajectory", recording=True):
trajectory_recording_path = os.path.join(example_trajectory_dir, "trajectory.json")
env = DesktopEnv(
path_to_vm=PATH_TO_VM,
action_space=agent.action_space,
task_config=example
)
# reset the environment to certain snapshot
observation = env.reset()
done = False
step_num = 0
if recording:
# send a request to the server to start recording
env.controller.start_recording()
while not done and step_num < max_steps:
actions = agent.predict(observation)
step_num += 1
for action in actions:
# Capture the timestamp before executing the action
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
logger.info("Step %d: %s", step_num, action)
observation, reward, done, info = env.step(action)
logger.info("Reward: %.2f", reward)
logger.info("Done: %s", done)
logger.info("Info: %s", info)
# Save screenshot and trajectory information
with open(os.path.join(example_trajectory_dir, f"step_{step_num}_{action_timestamp}.png"), "wb") as _f:
with open(observation['screenshot'], "rb") as __f:
screenshot = __f.read()
_f.write(screenshot)
with open(trajectory_recording_path, "a") as f:
f.write(json.dumps({
"step_num": step_num,
"action_timestamp": action_timestamp,
"action": action,
"reward": reward,
"done": done,
"info": info,
"screenshot_file": f"step_{step_num}_{action_timestamp}.png"
}))
f.write("\n")
if done:
logger.info("The episode is done.")
break
def stop_recording():
try:
env.controller.end_recording(os.path.join(example_trajectory_dir, "recording.mp4"))
except Exception as e:
print(f"An error occurred while stopping the recording: {e}")
try:
func_timeout.func_timeout(30, stop_recording)
except func_timeout.exceptions.FunctionTimedOut:
logger.info("Recording timed out.")
result = env.evaluate()
logger.info("Result: %.2f", result)
with open(trajectory_recording_path, "a") as f:
f.write(json.dumps({
"result": result
}))
f.write("\n")
# env.close()
logger.info("Environment closed.")
def main(example_class, example_id, gpt4_model="gpt-4-vision-preview"):
action_space = "pyautogui"
# example_class = "libreoffice_calc"
# example_id = "7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3"
# example_id = "01b269ae-2111-4a07-81fd-3fcd711993b0"
gemini_model = "gemini-pro-vision"
logger.info("Running example %s/%s", example_class, example_id)
logger.info("Using model %s", gpt4_model)
# logger.info("Using model %s", gemini_model)
with open(f"evaluation_examples/examples/{example_class}/{example_id}.json", "r", encoding="utf-8") as f:
example = json.load(f)
example["snapshot"] = "exp_v5"
# example["snapshot"] = "exp_setup4"
# example["snapshot"] = "Snapshot 30"
api_key = os.environ.get("OPENAI_API_KEY")
agent = GPT4v_Agent(api_key=api_key, model=gpt4_model, instruction=example['instruction'],
action_space=action_space, exp="both")
# api_key = os.environ.get("GENAI_API_KEY")
# agent = GeminiPro_Agent(api_key=api_key, model=gemini_model, instruction=example['instruction'], action_space=action_space, exp="both")
root_trajectory_dir = "exp_trajectory"
example_trajectory_dir = os.path.join(root_trajectory_dir, "both", example_class, gpt4_model, example_id)
# example_trajectory_dir = os.path.join(root_trajectory_dir, "both", example_class, gemini_model, example_id)
os.makedirs(example_trajectory_dir, exist_ok=True)
run_one_example(example, agent, 15, example_trajectory_dir)
if __name__ == '__main__':
os_list = [
"94d95f96-9699-4208-98ba-3c3119edf9c2",
"bedcedc4-4d72-425e-ad62-21960b11fe0d",
"43c2d64c-bab5-4dcb-a30c-b888321c319a",
"7688b85f-87a4-4e4a-b2f8-f3d6c3f29b82",
"ec4e3f68-9ea4-4c18-a5c9-69f89d1178b3",
"f9be0997-4b7c-45c5-b05c-4612b44a6118",
"28cc3b7e-b194-4bc9-8353-d04c0f4d56d2",
"5ea617a3-0e86-4ba6-aab2-dac9aa2e8d57",
"e0df059f-28a6-4169-924f-b9623e7184cc",
"ddc75b62-7311-4af8-bfb3-859558542b36",
"b6781586-6346-41cd-935a-a6b1487918fc",
"3ce045a0-877b-42aa-8d2c-b4a863336ab8",
"a4d98375-215b-4a4d-aee9-3d4370fccc41",
"13584542-872b-42d8-b299-866967b5c3ef",
"23393935-50c7-4a86-aeea-2b78fd089c5c"
]
# for example_id in os_list:
# try:
# main("os", example_id)
# except Exception as e:
# logger.error("An error occurred while running the example: %s", e)
# continue
calc_list = [
"a9f325aa-8c05-4e4f-8341-9e4358565f4f",
"ecb0df7a-4e8d-4a03-b162-053391d3afaf",
"7efeb4b1-3d19-4762-b163-63328d66303b",
"4e6fcf72-daf3-439f-a232-c434ce416af6",
"6054afcb-5bab-4702-90a0-b259b5d3217c",
"abed40dc-063f-4598-8ba5-9fe749c0615d",
"01b269ae-2111-4a07-81fd-3fcd711993b0",
"8b1ce5f2-59d2-4dcc-b0b0-666a714b9a14",
"af2b02f7-acee-4be4-8b66-499fab394915",
"da1d63b8-fa12-417b-ba18-f748e5f770f3",
"636380ea-d5f6-4474-b6ca-b2ed578a20f1",
"5ba77536-05c5-4aae-a9ff-6e298d094c3e",
"4bc4eaf4-ca5e-4db2-8138-8d4e65af7c0b",
"672a1b02-c62f-4ae2-acf0-37f5fb3052b0",
"648fe544-16ba-44af-a587-12ccbe280ea6",
"8985d1e4-5b99-4711-add4-88949ebb2308",
"9e606842-2e27-43bf-b1d1-b43289c9589b",
"fcb6e45b-25c4-4087-9483-03d714f473a9",
"68c0c5b7-96f3-4e87-92a7-6c1b967fd2d2",
"fff629ea-046e-4793-8eec-1a5a15c3eb35",
"5c9a206c-bb00-4fb6-bb46-ee675c187df5",
"e975ae74-79bd-4672-8d1c-dc841a85781d",
"34a6938a-58da-4897-8639-9b90d6db5391",
"b5a22759-b4eb-4bf2-aeed-ad14e8615f19",
"2f9913a1-51ed-4db6-bfe0-7e1c95b3139e",
"2558031e-401d-4579-8e00-3ecf540fb492",
"0cecd4f3-74de-457b-ba94-29ad6b5dafb6",
"4188d3a4-077d-46b7-9c86-23e1a036f6c1",
"51b11269-2ca8-4b2a-9163-f21758420e78",
"7e429b8d-a3f0-4ed0-9b58-08957d00b127",
"347ef137-7eeb-4c80-a3bb-0951f26a8aff",
"6e99a1ad-07d2-4b66-a1ce-ece6d99c20a5",
"3aaa4e37-dc91-482e-99af-132a612d40f3",
"37608790-6147-45d0-9f20-1137bb35703d",
"f9584479-3d0d-4c79-affa-9ad7afdd8850",
"d681960f-7bc3-4286-9913-a8812ba3261a",
"21df9241-f8d7-4509-b7f1-37e501a823f7",
"1334ca3e-f9e3-4db8-9ca7-b4c653be7d17",
"357ef137-7eeb-4c80-a3bb-0951f26a8aff",
"aa3a8974-2e85-438b-b29e-a64df44deb4b",
"a01fbce3-2793-461f-ab86-43680ccbae25",
"4f07fbe9-70de-4927-a4d5-bb28bc12c52c"
]
# for example_id in calc_list:
# try:
# main("libreoffice_calc", example_id)
# except Exception as e:
# logger.error("An error occurred while running the example: %s", e)
# continue
impress_list = [
"5d901039-a89c-4bfb-967b-bf66f4df075e",
"550ce7e7-747b-495f-b122-acdc4d0b8e54",
"455d3c66-7dc6-4537-a39a-36d3e9119df7",
"af23762e-2bfd-4a1d-aada-20fa8de9ce07",
"c59742c0-4323-4b9d-8a02-723c251deaa0",
"ef9d12bd-bcee-4ba0-a40e-918400f43ddf",
"9ec204e4-f0a3-42f8-8458-b772a6797cab",
"0f84bef9-9790-432e-92b7-eece357603fb",
"ce88f674-ab7a-43da-9201-468d38539e4a",
"3b27600c-3668-4abd-8f84-7bcdebbccbdb",
"a097acff-6266-4291-9fbd-137af7ecd439",
"bf4e9888-f10f-47af-8dba-76413038b73c",
"21760ecb-8f62-40d2-8d85-0cee5725cb72"
]
# for example_id in impress_list:
# try:
# main("libreoffice_impress", example_id)
# except Exception as e:
# logger.error("An error occurred while running the example: %s", e)
# continue
vs_code_list = [
"0ed39f63-6049-43d4-ba4d-5fa2fe04a951",
"53ad5833-3455-407b-bbc6-45b4c79ab8fb",
"eabc805a-bfcf-4460-b250-ac92135819f6",
"982d12a5-beab-424f-8d38-d2a48429e511",
"4e60007a-f5be-4bfc-9723-c39affa0a6d3",
"e2b5e914-ffe1-44d2-8e92-58f8c5d92bb2",
"9439a27b-18ae-42d8-9778-5f68f891805e",
"ea98c5d7-3cf9-4f9b-8ad3-366b58e0fcae",
"930fdb3b-11a8-46fe-9bac-577332e2640e",
"276cc624-87ea-4f08-ab93-f770e3790175",
"9d425400-e9b2-4424-9a4b-d4c7abac4140"
]
# for example_id in vs_code_list:
# try:
# main("vs_code", example_id)
# except Exception as e:
# logger.error("An error occurred while running the example: %s", e)
# continue
multiple_list = [
"f8cfa149-d1c1-4215-8dac-4a0932bad3c2",
"897e3b53-5d4d-444b-85cb-2cdc8a97d903",
"4e9f0faf-2ecc-4ae8-a804-28c9a75d1ddc",
"b52b40a5-ad70-4c53-b5b0-5650a8387052",
"46407397-a7d5-4c6b-92c6-dbe038b1457b",
"2b9493d7-49b8-493a-a71b-56cd1f4d6908",
"51f5801c-18b3-4f25-b0c3-02f85507a078",
"2c9fc0de-3ee7-45e1-a5df-c86206ad78b5",
"510f64c8-9bcc-4be1-8d30-638705850618",
"937087b6-f668-4ba6-9110-60682ee33441",
"ee9a3c83-f437-4879-8918-be5efbb9fac7",
"3680a5ee-6870-426a-a997-eba929a0d25c",
"e135df7c-7687-4ac0-a5f0-76b74438b53e",
"58565672-7bfe-48ab-b828-db349231de6b",
"2fe4b718-3bd7-46ec-bdce-b184f5653624"
]
# for example_id in multiple_list:
# try:
# main("multi_apps", example_id)
# except Exception as e:
# logger.error("An error occurred while running the example: %s", e)
# continue
chrome_list = [
# "bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
"7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3",
"06fe7178-4491-4589-810f-2e2bc9502122",
"e1e75309-3ddb-4d09-92ec-de869c928143",
"35253b65-1c19-4304-8aa4-6884b8218fc0",
"2ad9387a-65d8-4e33-ad5b-7580065a27ca",
"7a5a7856-f1b6-42a4-ade9-1ca81ca0f263",
"44ee5668-ecd5-4366-a6ce-c1c9b8d4e938",
"2ae9ba84-3a0d-4d4c-8338-3a1478dc5fe3",
"480bcfea-d68f-4aaa-a0a9-2589ef319381",
"af630914-714e-4a24-a7bb-f9af687d3b91"
]
# for example_id in chrome_list:
# try:
# main("chrome", example_id)
# except Exception as e:
# logger.error("An error occurred while running the example: %s", e)
# continue
writer_list = [
"6ada715d-3aae-4a32-a6a7-429b2e43fb93",
"ecc2413d-8a48-416e-a3a2-d30106ca36cb",
"0e47de2a-32e0-456c-a366-8c607ef7a9d2",
"4bcb1253-a636-4df4-8cb0-a35c04dfef31",
"0810415c-bde4-4443-9047-d5f70165a697",
"e528b65e-1107-4b8c-8988-490e4fece599",
"66399b0d-8fda-4618-95c4-bfc6191617e9",
"936321ce-5236-426a-9a20-e0e3c5dc536f",
"3ef2b351-8a84-4ff2-8724-d86eae9b842e",
"0b17a146-2934-46c7-8727-73ff6b6483e8",
"0e763496-b6bb-4508-a427-fad0b6c3e195",
"f178a4a9-d090-4b56-bc4c-4b72a61a035d",
"adf5e2c3-64c7-4644-b7b6-d2f0167927e7",
"0a0faba3-5580-44df-965d-f562a99b291c",
"e246f6d8-78d7-44ac-b668-fcf47946cb50",
"8472fece-c7dd-4241-8d65-9b3cd1a0b568",
"88fe4b2d-3040-4c70-9a70-546a47764b48",
"d53ff5ee-3b1a-431e-b2be-30ed2673079b",
"72b810ef-4156-4d09-8f08-a0cf57e7cefe",
"6f81754e-285d-4ce0-b59e-af7edb02d108",
"b21acd93-60fd-4127-8a43-2f5178f4a830"
]
for example_id in writer_list:
try:
main("libreoffice_writer", example_id)
except Exception as e:
logger.error("An error occurred while running the example: %s", e)
continue

View File

@@ -1,155 +0,0 @@
import ctypes
import datetime
import json
import logging
import os
import sys
import func_timeout
from desktop_env.envs.desktop_env import DesktopEnv
from mm_agents.gpt_4v_agent import GPT4v_Agent
# Logger Configs {{{ #
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
file_handler = logging.FileHandler(os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8")
debug_handler = logging.FileHandler(os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8")
stdout_handler = logging.StreamHandler(sys.stdout)
sdebug_handler = logging.FileHandler(os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8")
file_handler.setLevel(logging.INFO)
debug_handler.setLevel(logging.DEBUG)
stdout_handler.setLevel(logging.INFO)
sdebug_handler.setLevel(logging.DEBUG)
formatter = logging.Formatter(
fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s")
file_handler.setFormatter(formatter)
debug_handler.setFormatter(formatter)
stdout_handler.setFormatter(formatter)
sdebug_handler.setFormatter(formatter)
stdout_handler.addFilter(logging.Filter("desktopenv"))
sdebug_handler.addFilter(logging.Filter("desktopenv"))
logger.addHandler(file_handler)
logger.addHandler(debug_handler)
logger.addHandler(stdout_handler)
logger.addHandler(sdebug_handler)
# }}} Logger Configs #
logger = logging.getLogger("desktopenv.experiment")
PATH_TO_VM = r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu\Ubuntu.vmx"
def run_one_example(example, agent, max_steps=10, example_trajectory_dir="exp_trajectory", recording=True):
trajectory_recording_path = os.path.join(example_trajectory_dir, "trajectory.json")
env = DesktopEnv(
path_to_vm=PATH_TO_VM,
action_space=agent.action_space,
task_config=example
)
# reset the environment to certain snapshot
observation = env.reset()
done = False
step_num = 0
if recording:
# send a request to the server to start recording
env.controller.start_recording()
while not done and step_num < max_steps:
actions = agent.predict(observation)
step_num += 1
for action in actions:
# Capture the timestamp before executing the action
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
logger.info("Step %d: %s", step_num, action)
observation, reward, done, info = env.step(action)
logger.info("Reward: %.2f", reward)
logger.info("Done: %s", done)
logger.info("Info: %s", info)
# Save screenshot and trajectory information
with open(os.path.join(example_trajectory_dir, f"step_{step_num}_{action_timestamp}.png"), "wb") as _f:
with open(observation['screenshot'], "rb") as __f:
screenshot = __f.read()
_f.write(screenshot)
with open(trajectory_recording_path, "a") as f:
f.write(json.dumps({
"step_num": step_num,
"action_timestamp": action_timestamp,
"action": action,
"reward": reward,
"done": done,
"info": info,
"screenshot_file": f"step_{step_num}_{action_timestamp}.png"
}))
f.write("\n")
if done:
logger.info("The episode is done.")
break
def stop_recording():
try:
env.controller.end_recording(os.path.join(example_trajectory_dir, "recording.mp4"))
except Exception as e:
print(f"An error occurred while stopping the recording: {e}")
try:
func_timeout.func_timeout(30, stop_recording)
except func_timeout.exceptions.FunctionTimedOut:
logger.info("Recording timed out.")
result = env.evaluate()
logger.info("Result: %.2f", result)
with open(trajectory_recording_path, "a") as f:
f.write(json.dumps({
"result": result
}))
f.write("\n")
# env.close()
logger.info("Environment closed.")
def main(example_class, example_id):
action_space = "pyautogui"
gpt4_model = "gpt-4-vision-preview"
gemini_model = "gemini-pro-vision"
with open(f"evaluation_examples/examples/{example_class}/{example_id}.json", "r", encoding="utf-8") as f:
example = json.load(f)
example["snapshot"] = "exp_v5"
api_key = os.environ.get("OPENAI_API_KEY")
agent = GPT4v_Agent(api_key=api_key, model=gpt4_model, instruction=example['instruction'],
action_space=action_space, exp="seeact")
# api_key = os.environ.get("GENAI_API_KEY")
# agent = GeminiPro_Agent(api_key=api_key, model=gemini_model, instruction=example['instruction'], action_space=action_space)
root_trajectory_dir = "exp_trajectory"
example_trajectory_dir = os.path.join(root_trajectory_dir, "seeact", example_class, gpt4_model, example_id)
# example_trajectory_dir = os.path.join(root_trajectory_dir, "seeact", example_class, gemini_model, example_id)
os.makedirs(example_trajectory_dir, exist_ok=True)
run_one_example(example, agent, 15, example_trajectory_dir)
if __name__ == '__main__':
xx_list = [
]
for example_id in xx_list:
main("xx", example_id)

View File

@@ -1,261 +0,0 @@
#import ctypes
import datetime
import json
import logging
import os
import sys
import func_timeout
from desktop_env.envs.desktop_env import DesktopEnv
from mm_agents.gpt_4v_agent import GPT4v_Agent
# Logger Configs {{{ #
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
file_handler = logging.FileHandler(os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8")
debug_handler = logging.FileHandler(os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8")
stdout_handler = logging.StreamHandler(sys.stdout)
sdebug_handler = logging.FileHandler(os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8")
file_handler.setLevel(logging.INFO)
debug_handler.setLevel(logging.DEBUG)
stdout_handler.setLevel(logging.INFO)
sdebug_handler.setLevel(logging.DEBUG)
formatter = logging.Formatter(
fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s")
file_handler.setFormatter(formatter)
debug_handler.setFormatter(formatter)
stdout_handler.setFormatter(formatter)
sdebug_handler.setFormatter(formatter)
stdout_handler.addFilter(logging.Filter("desktopenv"))
sdebug_handler.addFilter(logging.Filter("desktopenv"))
logger.addHandler(file_handler)
logger.addHandler(debug_handler)
logger.addHandler(stdout_handler)
logger.addHandler(sdebug_handler)
# }}} Logger Configs #
logger = logging.getLogger("desktopenv.experiment")
PATH_TO_VM = r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu\Ubuntu.vmx"
def run_one_example(example, agent, max_steps=10, example_trajectory_dir="exp_trajectory", recording=True):
trajectory_recording_path = os.path.join(example_trajectory_dir, "trajectory.json")
env = DesktopEnv(
path_to_vm=PATH_TO_VM,
action_space=agent.action_space,
task_config=example
)
# reset the environment to certain snapshot
observation = env.reset()
done = False
step_num = 0
if recording:
# send a request to the server to start recording
env.controller.start_recording()
while not done and step_num < max_steps:
actions = agent.predict(observation)
step_num += 1
for action in actions:
# Capture the timestamp before executing the action
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
logger.info("Step %d: %s", step_num, action)
observation, reward, done, info = env.step(action)
logger.info("Reward: %.2f", reward)
logger.info("Done: %s", done)
logger.info("Info: %s", info)
# Save screenshot and trajectory information
with open(os.path.join(example_trajectory_dir, f"step_{step_num}_{action_timestamp}.png"), "wb") as _f:
with open(observation['screenshot'], "rb") as __f:
screenshot = __f.read()
_f.write(screenshot)
with open(trajectory_recording_path, "a") as f:
f.write(json.dumps({
"step_num": step_num,
"action_timestamp": action_timestamp,
"action": action,
"reward": reward,
"done": done,
"info": info,
"screenshot_file": f"step_{step_num}_{action_timestamp}.png"
}))
f.write("\n")
if done:
logger.info("The episode is done.")
break
def stop_recording():
try:
env.controller.end_recording(os.path.join(example_trajectory_dir, "recording.mp4"))
except Exception as e:
print(f"An error occurred while stopping the recording: {e}")
try:
func_timeout.func_timeout(30, stop_recording)
except func_timeout.exceptions.FunctionTimedOut:
logger.info("Recording timed out.")
result = env.evaluate()
logger.info("Result: %.2f", result)
with open(trajectory_recording_path, "a") as f:
f.write(json.dumps({
"result": result
}))
f.write("\n")
# env.close()
logger.info("Environment closed.")
def main(example_class, example_id):
action_space = "pyautogui"
gpt4_model = "gpt-4-vision-preview"
gemini_model = "gemini-pro-vision"
with open(f"evaluation_examples/examples/{example_class}/{example_id}.json", "r", encoding="utf-8") as f:
example = json.load(f)
example["snapshot"] = "exp_v5"
logger.info("TASK: %s/%s", example_class, example_id)
api_key = os.environ.get("OPENAI_API_KEY")
agent = GPT4v_Agent(api_key=api_key, model=gpt4_model, max_tokens=1000, instruction=example['instruction'],
action_space=action_space, exp="som")
# api_key = os.environ.get("GENAI_API_KEY")
# agent = GeminiPro_Agent(api_key=api_key, model=gemini_model, instruction=example['instruction'], action_space=action_space)
root_trajectory_dir = "exp_trajectory"
example_trajectory_dir = os.path.join(root_trajectory_dir, "som", example_class, gpt4_model, example_id)
# example_trajectory_dir = os.path.join(root_trajectory_dir, "som", example_class, gemini_model, example_id)
os.makedirs(example_trajectory_dir, exist_ok=True)
run_one_example(example, agent, 15, example_trajectory_dir)
if __name__ == '__main__':
from tqdm import tqdm
# impress_list = [
# # "5d901039-a89c-4bfb-967b-bf66f4df075e",
# "550ce7e7-747b-495f-b122-acdc4d0b8e54",
# "455d3c66-7dc6-4537-a39a-36d3e9119df7",
# "af23762e-2bfd-4a1d-aada-20fa8de9ce07",
# "c59742c0-4323-4b9d-8a02-723c251deaa0",
# "ef9d12bd-bcee-4ba0-a40e-918400f43ddf",
# "9ec204e4-f0a3-42f8-8458-b772a6797cab",
# "0f84bef9-9790-432e-92b7-eece357603fb",
# "ce88f674-ab7a-43da-9201-468d38539e4a",
# "3b27600c-3668-4abd-8f84-7bcdebbccbdb",
# "a097acff-6266-4291-9fbd-137af7ecd439",
# "bf4e9888-f10f-47af-8dba-76413038b73c",
# "21760ecb-8f62-40d2-8d85-0cee5725cb72"
# ]
# for example_id in impress_list:
# main("libreoffice_impress", example_id)
vlc_list = [
"8ba5ae7a-5ae5-4eab-9fcc-5dd4fe3abf89",
"8ba5ae7a-5ae5-4eab-9fcc-5dd4fe3abf89",
"8f080098-ddb1-424c-b438-4e96e5e4786e",
"bba3381f-b5eb-4439-bd9e-80c22218d5a7",
"fba2c100-79e8-42df-ae74-b592418d54f4",
"efcf0d81-0835-4880-b2fd-d866e8bc2294",
"8d9fd4e2-6fdb-46b0-b9b9-02f06495c62f",
"aa4b5023-aef6-4ed9-bdc9-705f59ab9ad6",
"386dbd0e-0241-4a0a-b6a2-6704fba26b1c",
"9195653c-f4aa-453d-aa95-787f6ccfaae9",
"d06f0d4d-2cd5-4ede-8de9-598629438c6e",
"a5bbbcd5-b398-4c91-83d4-55e1e31bbb81",
"f3977615-2b45-4ac5-8bba-80c17dbe2a37",
"215dfd39-f493-4bc3-a027-8a97d72c61bf"
]
# for example_id in tqdm(vlc_list):
# try:
# main("vlc", example_id)
# except Exception as e:
# print(f"An error occurred while running the example: {e}")
# continue
chrome_list = [
"bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
"7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3",
"06fe7178-4491-4589-810f-2e2bc9502122",
"e1e75309-3ddb-4d09-92ec-de869c928143",
"35253b65-1c19-4304-8aa4-6884b8218fc0",
"2ad9387a-65d8-4e33-ad5b-7580065a27ca",
"7a5a7856-f1b6-42a4-ade9-1ca81ca0f263",
"44ee5668-ecd5-4366-a6ce-c1c9b8d4e938",
"2ae9ba84-3a0d-4d4c-8338-3a1478dc5fe3",
"480bcfea-d68f-4aaa-a0a9-2589ef319381",
"af630914-714e-4a24-a7bb-f9af687d3b91"
]
for example_id in tqdm(chrome_list):
try:
main("chrome", example_id)
except Exception as e:
print(f"An error occurred while running the example: {e}")
continue
vs_code_list = [
"0ed39f63-6049-43d4-ba4d-5fa2fe04a951",
"53ad5833-3455-407b-bbc6-45b4c79ab8fb",
"eabc805a-bfcf-4460-b250-ac92135819f6",
"982d12a5-beab-424f-8d38-d2a48429e511",
"4e60007a-f5be-4bfc-9723-c39affa0a6d3",
"e2b5e914-ffe1-44d2-8e92-58f8c5d92bb2",
"9439a27b-18ae-42d8-9778-5f68f891805e",
"ea98c5d7-3cf9-4f9b-8ad3-366b58e0fcae",
"930fdb3b-11a8-46fe-9bac-577332e2640e",
"276cc624-87ea-4f08-ab93-f770e3790175",
"9d425400-e9b2-4424-9a4b-d4c7abac4140"
]
for example_id in tqdm(vs_code_list):
try:
main("vs_code", example_id)
except Exception as e:
print(f"An error occurred while running the example: {e}")
continue
thunderbird_list = [
"bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
"7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3",
"12086550-11c0-466b-b367-1d9e75b3910e",
"06fe7178-4491-4589-810f-2e2bc9502122",
"6766f2b8-8a72-417f-a9e5-56fcaa735837",
"e1e75309-3ddb-4d09-92ec-de869c928143",
"3d1682a7-0fb0-49ae-a4dc-a73afd2d06d5",
"35253b65-1c19-4304-8aa4-6884b8218fc0",
"d088f539-cab4-4f9a-ac92-9999fc3a656e",
"2ad9387a-65d8-4e33-ad5b-7580065a27ca",
"480bcfea-d68f-4aaa-a0a9-2589ef319381",
"030eeff7-b492-4218-b312-701ec99ee0cc",
"94760984-3ff5-41ee-8347-cf1af709fea0",
"99146c54-4f37-4ab8-9327-5f3291665e1e",
"c9e7eaf2-b1a1-4efc-a982-721972fa9f02"
]
for example_id in tqdm(thunderbird_list):
try:
main("thunderbird", example_id)
except Exception as e:
print(f"An error occurred while running the example: {e}")
continue

View File

@@ -5,21 +5,20 @@ import os
import re
import time
import uuid
import xml.etree.ElementTree as ET
from http import HTTPStatus
from io import BytesIO
from typing import Dict, List
import xml.etree.ElementTree as ET
import backoff
import dashscope
import google.generativeai as genai
import openai
import requests
from PIL import Image
from openai import (
APIConnectionError,
APIError,
RateLimitError
from vertexai.preview.generative_models import (
HarmBlockThreshold,
HarmCategory,
Image,
)
from mm_agents.accessibility_tree_wrap.heuristic_retrieve import find_leaf_nodes, filter_nodes, draw_bounding_boxes
@@ -29,7 +28,6 @@ from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_S
SYS_PROMPT_IN_SOM_A11Y_OUT_TAG, \
SYS_PROMPT_SEEACT, ACTION_DESCRIPTION_PROMPT_SEEACT, ACTION_GROUNDING_PROMPT_SEEACT
import logging
# todo: cross-check with visualwebarena
logger = logging.getLogger("desktopenv.agent")
@@ -42,7 +40,7 @@ def encode_image(image_path):
def linearize_accessibility_tree(accessibility_tree):
#leaf_nodes = find_leaf_nodes(accessibility_tree)
# leaf_nodes = find_leaf_nodes(accessibility_tree)
filtered_nodes = filter_nodes(ET.fromstring(accessibility_tree))
linearized_accessibility_tree = "tag\tname\ttext\tposition\tsize\n"
@@ -172,60 +170,56 @@ def parse_code_from_som_string(input_string, masks):
class PromptAgent:
def __init__(
self,
api_key,
instruction,
model="gpt-4-vision-preview",
max_tokens=500,
max_tokens=1500,
top_p=0.9,
temperature=0.5,
action_space="computer_13",
exp="screenshot_a11y_tree"
# exp can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som", "seeact"]
observation_type="screenshot_a11y_tree",
# observation_type can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som", "seeact"]
max_trajectory_length=3
):
self.instruction = instruction
self.model = model
self.max_tokens = max_tokens
self.top_p = top_p
self.temperature = temperature
self.action_space = action_space
self.exp = exp
self.max_trajectory_length = 3
self.headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {api_key}"
}
self.observation_type = observation_type
self.max_trajectory_length = max_trajectory_length
self.thoughts = []
self.actions = []
self.observations = []
if exp == "screenshot":
if observation_type == "screenshot":
if action_space == "computer_13":
self.system_message = SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION
elif action_space == "pyautogui":
self.system_message = SYS_PROMPT_IN_SCREENSHOT_OUT_CODE
else:
raise ValueError("Invalid action space: " + action_space)
elif exp == "a11y_tree":
elif observation_type == "a11y_tree":
if action_space == "computer_13":
self.system_message = SYS_PROMPT_IN_A11Y_OUT_ACTION
elif action_space == "pyautogui":
self.system_message = SYS_PROMPT_IN_A11Y_OUT_CODE
else:
raise ValueError("Invalid action space: " + action_space)
elif exp == "both":
elif observation_type == "both":
if action_space == "computer_13":
self.system_message = SYS_PROMPT_IN_BOTH_OUT_ACTION
elif action_space == "pyautogui":
self.system_message = SYS_PROMPT_IN_BOTH_OUT_CODE
else:
raise ValueError("Invalid action space: " + action_space)
elif exp == "som":
elif observation_type == "som":
if action_space == "computer_13":
raise ValueError("Invalid action space: " + action_space)
elif action_space == "pyautogui":
self.system_message = SYS_PROMPT_IN_SOM_A11Y_OUT_TAG
else:
raise ValueError("Invalid action space: " + action_space)
elif exp == "seeact":
elif observation_type == "seeact":
if action_space == "computer_13":
raise ValueError("Invalid action space: " + action_space)
elif action_space == "pyautogui":
@@ -233,15 +227,14 @@ class PromptAgent:
else:
raise ValueError("Invalid action space: " + action_space)
else:
raise ValueError("Invalid experiment type: " + exp)
raise ValueError("Invalid experiment type: " + observation_type)
self.system_message = self.system_message + "\nYou are asked to complete the following task: {}".format(
self.instruction)
def predict(self, obs: Dict) -> List:
def predict(self, instruction: str, obs: Dict) -> List:
"""
Predict the next action(s) based on the current observation.
"""
self.system_message = self.system_message + "\nYou are asked to complete the following task: {}".format(
instruction)
# Prepare the payload for the API call
messages = []
@@ -273,7 +266,7 @@ class PromptAgent:
for previous_obs, previous_action, previous_thought in zip(_observations, _actions, _thoughts):
# {{{1
if self.exp == "both":
if self.observation_type == "both":
_screenshot = previous_obs["screenshot"]
_linearized_accessibility_tree = previous_obs["accessibility_tree"]
logger.debug("LINEAR AT: %s", _linearized_accessibility_tree)
@@ -295,7 +288,7 @@ class PromptAgent:
}
]
})
elif self.exp in ["som", "seeact"]:
elif self.observation_type in ["som", "seeact"]:
_screenshot = previous_obs["screenshot"]
_linearized_accessibility_tree = previous_obs["accessibility_tree"]
logger.debug("LINEAR AT: %s", _linearized_accessibility_tree)
@@ -317,7 +310,7 @@ class PromptAgent:
}
]
})
elif self.exp == "screenshot":
elif self.observation_type == "screenshot":
_screenshot = previous_obs["screenshot"]
messages.append({
@@ -336,7 +329,7 @@ class PromptAgent:
}
]
})
elif self.exp == "a11y_tree":
elif self.observation_type == "a11y_tree":
_linearized_accessibility_tree = previous_obs["accessibility_tree"]
messages.append({
@@ -350,7 +343,7 @@ class PromptAgent:
]
})
else:
raise ValueError("Invalid experiment type: " + self.exp) # 1}}}
raise ValueError("Invalid observation_type type: " + self.observation_type) # 1}}}
messages.append({
"role": "assistant",
@@ -363,11 +356,11 @@ class PromptAgent:
})
# {{{1
if self.exp in ["screenshot", "both"]:
if self.observation_type in ["screenshot", "both"]:
base64_image = encode_image(obs["screenshot"])
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])
if self.exp == "both":
if self.observation_type == "both":
self.observations.append({
"screenshot": base64_image,
"accessibility_tree": linearized_accessibility_tree
@@ -384,7 +377,7 @@ class PromptAgent:
{
"type": "text",
"text": "Given the screenshot as below. What's the next step that you will do to help with the task?"
if self.exp == "screenshot"
if self.observation_type == "screenshot"
else "Given the screenshot and info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format(
linearized_accessibility_tree)
},
@@ -397,7 +390,7 @@ class PromptAgent:
}
]
})
elif self.exp == "a11y_tree":
elif self.observation_type == "a11y_tree":
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])
self.observations.append({
@@ -415,7 +408,7 @@ class PromptAgent:
}
]
})
elif self.exp == "som":
elif self.observation_type == "som":
# Add som to the screenshot
masks, drew_nodes, tagged_screenshot = tag_screenshot(obs["screenshot"], obs["accessibility_tree"])
base64_image = encode_image(tagged_screenshot)
@@ -443,7 +436,7 @@ class PromptAgent:
}
]
})
elif self.exp == "seeact":
elif self.observation_type == "seeact":
# Add som to the screenshot
masks, drew_nodes, tagged_screenshot = tag_screenshot(obs["screenshot"], obs["accessibility_tree"])
base64_image = encode_image(tagged_screenshot)
@@ -471,21 +464,21 @@ class PromptAgent:
]
})
else:
raise ValueError("Invalid experiment type: " + self.exp) # 1}}}
with open("messages.json", "w") as f:
f.write(json.dumps(messages, indent=4))
raise ValueError("Invalid observation_type type: " + self.observation_type) # 1}}}
# with open("messages.json", "w") as f:
# f.write(json.dumps(messages, indent=4))
logger.info("Generating content with GPT model: %s", self.model)
response = self.call_llm({
"model": self.model,
"messages": messages,
"max_tokens": self.max_tokens
})
logger.debug("RESPONSE: %s", response)
logger.info("RESPONSE: %s", response)
if self.exp == "seeact":
if self.observation_type == "seeact":
messages.append({
"role": "assistant",
"content": [
@@ -507,12 +500,15 @@ class PromptAgent:
]
})
logger.info("Generating content with GPT model: %s", self.model)
response = self.call_llm({
"model": self.model,
"messages": messages,
"max_tokens": self.max_tokens
"max_tokens": self.max_tokens,
"top_p": self.top_p,
"temperature": self.temperature
})
print(response)
logger.info("RESPONSE: %s", response)
try:
actions = self.parse_actions(response, masks)
@@ -527,85 +523,90 @@ class PromptAgent:
@backoff.on_exception(
backoff.expo,
(Exception),
max_tries=10
max_tries=5
)
def call_llm(self, payload):
if self.model.startswith("gpt"):
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}"
}
logger.info("Generating content with GPT model: %s", self.model)
response = requests.post(
"https://api.openai.com/v1/chat/completions",
headers=self.headers,
headers=headers,
json=payload
)
if response.status_code != 200:
if response.json()['error']['code'] == "context_length_exceeded":
print("Context length exceeded. Retrying with a smaller context.")
logger.error("Context length exceeded. Retrying with a smaller context.")
payload["messages"] = payload["messages"][-1:]
retry_response = requests.post(
"https://api.openai.com/v1/chat/completions",
headers=self.headers,
headers=headers,
json=payload
)
if retry_response.status_code != 200:
print("Failed to call LLM: " + retry_response.text)
logger.error("Failed to call LLM: " + retry_response.text)
return ""
print("Failed to call LLM: " + response.text)
logger.error("Failed to call LLM: " + response.text)
time.sleep(5)
return ""
else:
return response.json()['choices'][0]['message']['content']
elif self.model.startswith("mistral"):
print("call mistral")
messages = payload["messages"]
max_tokens = payload["max_tokens"]
misrtal_messages = []
for i, message in enumerate(messages):
mistral_message = {
"role": message["role"],
"content": []
}
for part in message["content"]:
mistral_message['content'] = part['text'] if part['type'] == "text" else None
misrtal_messages.append(mistral_message)
# our Mistral endpoint does not support a system message, so we concatenate it into the first user message
if misrtal_messages[0]['role'] == "system":
misrtal_messages[1]['content'] = misrtal_messages[0]['content'] + "\n" + misrtal_messages[1]['content']
misrtal_messages.pop(0)
# openai.api_base = "http://localhost:8000/v1"
# openai.api_key = "test"
# response = openai.ChatCompletion.create(
# messages=misrtal_messages,
# model="Mixtral-8x7B-Instruct-v0.1"
# )
from openai import OpenAI
TOGETHER_API_KEY = "d011650e7537797148fb6170ec1e0be7ae75160375686fae02277136078e90d2"
client = OpenAI(api_key=TOGETHER_API_KEY,
base_url='https://api.together.xyz',
)
logger.info("Generating content with Mistral model: %s", self.model)
response = client.chat.completions.create(
messages=misrtal_messages,
model="mistralai/Mixtral-8x7B-Instruct-v0.1",
max_tokens=1024
)
try:
# return response['choices'][0]['message']['content']
return response.choices[0].message.content
except Exception as e:
print("Failed to call LLM: " + str(e))
return ""
# elif self.model.startswith("mistral"):
# print("Call mistral")
# messages = payload["messages"]
# max_tokens = payload["max_tokens"]
#
# misrtal_messages = []
#
# for i, message in enumerate(messages):
# mistral_message = {
# "role": message["role"],
# "content": []
# }
#
# for part in message["content"]:
# mistral_message['content'] = part['text'] if part['type'] == "text" else None
#
# misrtal_messages.append(mistral_message)
#
# # our Mistral endpoint does not support a system message, so we concatenate it into the first user message
# if misrtal_messages[0]['role'] == "system":
# misrtal_messages[1]['content'] = misrtal_messages[0]['content'] + "\n" + misrtal_messages[1]['content']
# misrtal_messages.pop(0)
#
# # openai.api_base = "http://localhost:8000/v1"
# # openai.api_key = "test"
# # response = openai.ChatCompletion.create(
# # messages=misrtal_messages,
# # model="Mixtral-8x7B-Instruct-v0.1"
# # )
#
# from openai import OpenAI
# TOGETHER_API_KEY = "d011650e7537797148fb6170ec1e0be7ae75160375686fae02277136078e90d2"
#
# client = OpenAI(api_key=TOGETHER_API_KEY,
# base_url='https://api.together.xyz',
# )
# logger.info("Generating content with Mistral model: %s", self.model)
# response = client.chat.completions.create(
# messages=misrtal_messages,
# model="mistralai/Mixtral-8x7B-Instruct-v0.1",
# max_tokens=1024
# )
#
# try:
# # return response['choices'][0]['message']['content']
# return response.choices[0].message.content
# except Exception as e:
# print("Failed to call LLM: " + str(e))
# return ""
elif self.model.startswith("gemini"):
def encoded_img_to_pil_img(data_str):
@@ -617,6 +618,8 @@ class PromptAgent:
messages = payload["messages"]
max_tokens = payload["max_tokens"]
top_p = payload["top_p"]
temperature = payload["temperature"]
gemini_messages = []
for i, message in enumerate(messages):
@@ -662,7 +665,17 @@ class PromptAgent:
response = genai.GenerativeModel(self.model).generate_content(
gemini_messages,
generation_config={
"max_output_tokens": max_tokens
"candidate_count": 1,
"max_output_tokens": max_tokens,
"top_p": top_p,
"temperature": temperature
},
safety_settings={
HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
}
)
@@ -673,6 +686,8 @@ class PromptAgent:
elif self.model.startswith("qwen"):
messages = payload["messages"]
max_tokens = payload["max_tokens"]
top_p = payload["top_p"]
temperature = payload["temperature"]
qwen_messages = []
@@ -683,13 +698,16 @@ class PromptAgent:
}
assert len(message["content"]) in [1, 2], "One text, or one text with one image"
for part in message["content"]:
qwen_message['content'].append({"image": part['image_url']['url']}) if part['type'] == "image_url" else None
qwen_message['content'].append({"image": part['image_url']['url']}) if part[
'type'] == "image_url" else None
qwen_message['content'].append({"text": part['text']}) if part['type'] == "text" else None
qwen_messages.append(qwen_message)
response = dashscope.MultiModalConversation.call(model='qwen-vl-plus',
messages=messages)
response = dashscope.MultiModalConversation.call(
model='qwen-vl-plus',
messages=qwen_messages,  # todo: add the hyperparameters
)
# A status_code of HTTPStatus.OK indicates success; otherwise the request
# failed, and the error code and message are available in response.code
# and response.message.
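# A minimal sketch of that check, assuming the usual DashScope response shape
# (the actual handling follows in code not shown here):
#   if response.status_code == HTTPStatus.OK:
#       return response.output.choices[0].message.content[0]["text"]
#   else:
#       logger.error("Failed to call LLM: %s (%s)", response.message, response.code)
#       return ""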
@@ -708,7 +726,7 @@ class PromptAgent:
def parse_actions(self, response: str, masks=None):
if self.exp in ["screenshot", "a11y_tree", "both"]:
if self.observation_type in ["screenshot", "a11y_tree", "both"]:
# parse from the response
if self.action_space == "computer_13":
actions = parse_actions_from_string(response)
@@ -720,7 +738,7 @@ class PromptAgent:
self.actions.append(actions)
return actions
elif self.exp in ["som", "seeact"]:
elif self.observation_type in ["som", "seeact"]:
# parse from the response
if self.action_space == "computer_13":
raise ValueError("Invalid action space: " + self.action_space)
@@ -732,3 +750,8 @@ class PromptAgent:
self.actions.append(actions)
return actions
def reset(self):
self.thoughts = []
self.actions = []
self.observations = []

218
run.py Normal file
View File

@@ -0,0 +1,218 @@
"""Script to run end-to-end evaluation on the benchmark.
Utils and basic architecture credit to https://github.com/web-arena-x/webarena/blob/main/run.py.
"""
import argparse
import datetime
import json
import logging
import os
import sys
from desktop_env.envs.desktop_env import DesktopEnv
from mm_agents.agent import PromptAgent
# Logger Configs {{{ #
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
file_handler = logging.FileHandler(os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8")
debug_handler = logging.FileHandler(os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8")
stdout_handler = logging.StreamHandler(sys.stdout)
sdebug_handler = logging.FileHandler(os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8")
file_handler.setLevel(logging.INFO)
debug_handler.setLevel(logging.DEBUG)
stdout_handler.setLevel(logging.INFO)
sdebug_handler.setLevel(logging.DEBUG)
formatter = logging.Formatter(
fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s")
file_handler.setFormatter(formatter)
debug_handler.setFormatter(formatter)
stdout_handler.setFormatter(formatter)
sdebug_handler.setFormatter(formatter)
stdout_handler.addFilter(logging.Filter("desktopenv"))
sdebug_handler.addFilter(logging.Filter("desktopenv"))
logger.addHandler(file_handler)
logger.addHandler(debug_handler)
logger.addHandler(stdout_handler)
logger.addHandler(sdebug_handler)
# }}} Logger Configs #
logger = logging.getLogger("desktopenv.experiment")
def config() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Run end-to-end evaluation on the benchmark"
)
# environment config
parser.add_argument("--path_to_vm", type=str,
default=r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu\Ubuntu.vmx")
parser.add_argument(
"--headless", action="store_true", help="Run in headless machine"
)
parser.add_argument("--action_space", type=str, default="pyautogui", help="Action type")
parser.add_argument(
"--observation_type",
choices=[
"screenshot",
"a11y_tree",
"screenshot_a11y_tree",
"som"
],
default="a11y_tree",
help="Observation type",
)
parser.add_argument("--screen_width", type=int, default=1920)
parser.add_argument("--screen_height", type=int, default=1080)
parser.add_argument("--sleep_after_execution", type=float, default=0.0)
parser.add_argument("--max_steps", type=int, default=15)
# agent config
parser.add_argument("--max_trajectory_length", type=int, default=3)
parser.add_argument("--test_config_base_dir", type=str, default="evaluation_examples")
# lm config
parser.add_argument("--model", type=str, default="gpt-4-vision-preview")
parser.add_argument("--temperature", type=float, default=1.0)
parser.add_argument("--top_p", type=float, default=0.9)
parser.add_argument("--max_tokens", type=int, default=1500)
parser.add_argument("--stop_token", type=str, default=None)
# logging related
parser.add_argument("--result_dir", type=str, default="./results")
args = parser.parse_args()
return args
def test(
args: argparse.Namespace,
test_all_meta: dict
) -> None:
scores = []
max_steps = args.max_steps
# log args
logger.info("Args: %s", args)
agent = PromptAgent(
model=args.model,
max_tokens=args.max_tokens,
action_space=args.action_space,
observation_type=args.observation_type,
max_trajectory_length=args.max_trajectory_length,
)
env = DesktopEnv(
path_to_vm=args.path_to_vm,
action_space=agent.action_space,
screen_size=(args.screen_width, args.screen_height),
headless=args.headless,
)
for domain in test_all_meta:
for example_id in test_all_meta[domain]:
config_file = os.path.join(args.test_config_base_dir, f"examples/{domain}/{example_id}.json")
with open(config_file, "r", encoding="utf-8") as f:
example = json.load(f)
logger.info(f"[Domain]: {domain}")
logger.info(f"[Example ID]: {example_id}")
instruction = example["instruction"]
logger.info(f"[Instruction]: {instruction}")
example_result_dir = os.path.join(
args.result_dir,
args.action_space,
args.observation_type,
args.model,
domain,
example_id
)
os.makedirs(example_result_dir, exist_ok=True)
agent.reset()
obs = env.reset(task_config=example)
done = False
step_idx = 0
env.controller.start_recording()
while not done and step_idx < max_steps:
actions = agent.predict(
instruction,
obs
)
for action in actions:
step_idx += 1
# Capture the timestamp before executing the action
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
logger.info("Step %d: %s", step_idx + 1, action)
observation, reward, done, info = env.step(action, args.sleep_after_execution)
logger.info("Reward: %.2f", reward)
logger.info("Done: %s", done)
logger.info("Info: %s", info)
# Save screenshot and trajectory information
with open(os.path.join(example_result_dir, f"step_{step_idx}_{action_timestamp}.png"),
"wb") as _f:
with open(observation['screenshot'], "rb") as __f:
screenshot = __f.read()
_f.write(screenshot)
with open(os.path.join(example_result_dir, "traj.json"), "a") as f:
f.write(json.dumps({
"step_num": step_idx + 1,
"action_timestamp": action_timestamp,
"action": action,
"reward": reward,
"done": done,
"info": info,
"screenshot_file": f"step_{step_idx + 1}_{action_timestamp}.png"
}))
f.write("\n")
if done:
logger.info("The episode is done.")
break
result = env.evaluate()
logger.info("Result: %.2f", result)
scores.append(result)
env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
env.close()
logger.info(f"Average score: {sum(scores) / len(scores)}")
def get_unfinished(test_file_list, result_dir):
finished = []
for domain in os.listdir(result_dir):
for example_id in os.listdir(os.path.join(result_dir, domain)):
finished.append(f"{domain}/{example_id}")
return [x for x in test_file_list if x not in finished]
if __name__ == '__main__':
####### The complete version of the list of examples #######
os.environ["TOKENIZERS_PARALLELISM"] = "false"
args = config()
# test_file_list = get_unfinished(args.test, args.result_dir)
# logger.info(f"Total {len(test_file_list)} tasks left")
with open("evaluation_examples/test_all.json", "r", encoding="utf-8") as f:
test_all_meta = json.load(f)
test(args, test_all_meta)