Refactor experiments and agent implementation
@@ -53,8 +53,8 @@ class DesktopEnv(gym.Env):
    def __init__(
            self,
            path_to_vm: str,
            snapshot_name: str = "init_state",
            action_space: str = "computer_13",
            task_config: Dict[str, Any] = None,
            tmp_dir: str = "tmp",
            cache_dir: str = "cache",
            screen_size: Tuple[int] = (1920, 1080),
@@ -64,15 +64,6 @@ class DesktopEnv(gym.Env):
        Args:
            path_to_vm (str): path to the .vmx file
            action_space (str): "computer_13" | "pyautogui"

            task_config (Dict[str, Any]): the unified task configuration,
              including
                * base snapshot
                * task id (uuid)
                * instruction
                * setup config
                * evaluator config

            tmp_dir (str): temporary directory for trajectory artifacts such as
              the extracted screenshots
            cache_dir (str): cache directory for task-related artifacts such as
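As a usage sketch (not part of this commit): the "snapshot", "id", and "instruction" keys match fields referenced elsewhere in this diff, while the "config" and "evaluator" key names are assumptions for illustration.

from desktop_env.envs.desktop_env import DesktopEnv

task_config = {
    "snapshot": "init_state",                      # base snapshot to revert to
    "id": "94d95f96-9699-4208-98ba-3c3119edf9c2",  # task id (uuid), taken from the lists below
    "instruction": "Open a terminal and print the OS version.",
    "config": [],     # setup config (assumed key name)
    "evaluator": {},  # evaluator config (assumed key name)
}

env = DesktopEnv(
    path_to_vm="/path/to/Ubuntu.vmx",
    action_space="pyautogui",
    task_config=task_config,
)
observation = env.reset()  # reverts the VM to the snapshot and returns the first observation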
@@ -81,6 +72,7 @@ class DesktopEnv(gym.Env):

        # Initialize environment variables
        self.path_to_vm = os.path.abspath(os.path.expandvars(os.path.expanduser(path_to_vm)))
        self.snapshot_name = snapshot_name
        self.tmp_dir_base: str = tmp_dir
        self.cache_dir_base: str = cache_dir
        self.vm_screen_size = screen_size
@@ -88,16 +80,12 @@ class DesktopEnv(gym.Env):

        os.makedirs(self.tmp_dir_base, exist_ok=True)

        # task-aware state
        # todo: handle the logic of the snapshot directory
        self._set_task_info(task_config)

        # Initialize emulator and controller
        logger.info("Initializing...")
        self._start_emulator()
        self.vm_ip = self._get_vm_ip()
        self.controller = PythonController(vm_ip=self.vm_ip)
-       self.setup_controller = SetupController(vm_ip=self.vm_ip, cache_dir=self.cache_dir)
+       self.setup_controller = SetupController(vm_ip=self.vm_ip, cache_dir=self.cache_dir_base)

        # Meta info of the VM; move to the reset() function
        self.vm_platform: str = ""  # self.controller.get_vm_platform()
@@ -147,7 +135,7 @@ class DesktopEnv(gym.Env):
            raise Exception("Failed to get VM IP address!")

    def _save_state(self):
-       _execute_command(["vmrun", "-T", "ws", "snapshot", self.path_to_vm, self.snapshot_path])
+       _execute_command(["vmrun", "-T", "ws", "snapshot", self.path_to_vm, self.snapshot_name])

    def _get_screenshot(self):
        # random_uuid = str(uuid.uuid4())
@@ -167,7 +155,6 @@ class DesktopEnv(gym.Env):
        return screenshot_image_path

    def _set_task_info(self, task_config: Dict[str, Any]):
-       self.snapshot_path = task_config["snapshot"]  # todo: save the snapshot when the environment first starts, then revert to it on reset
        self.task_id: str = task_config["id"]
        self.cache_dir: str = os.path.join(self.cache_dir_base, self.task_id)
        os.makedirs(self.cache_dir, exist_ok=True)
@@ -239,8 +226,8 @@ class DesktopEnv(gym.Env):
        )
        os.makedirs(os.path.join(self.tmp_dir, "screenshots"))

-       logger.info("Reverting to snapshot {}...".format(self.snapshot_path))
-       _execute_command(["vmrun", "-T", "ws", "revertToSnapshot", self.path_to_vm, self.snapshot_path])
+       logger.info("Reverting to snapshot {}...".format(self.snapshot_name))
+       _execute_command(["vmrun", "-T", "ws", "revertToSnapshot", self.path_to_vm, self.snapshot_name])
        time.sleep(5)

        print(self.vm_screen_size)
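_execute_command itself is not shown in this diff; a plausible sketch, under the assumption that it is a thin wrapper over subprocess that raises on a non-zero exit status:

import subprocess
from typing import List

def _execute_command(command: List[str]) -> None:
    # e.g. ["vmrun", "-T", "ws", "snapshot", path_to_vm, snapshot_name]
    #  or  ["vmrun", "-T", "ws", "revertToSnapshot", path_to_vm, snapshot_name]
    result = subprocess.run(command, capture_output=True, text=True, timeout=60)
    if result.returncode != 0:
        raise RuntimeError(f"Command {command} failed: {result.stderr}")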
@@ -1,432 +0,0 @@
import datetime
import json
import logging
import os
import sys

import func_timeout

from desktop_env.envs.desktop_env import DesktopEnv
from mm_agents.gpt_4v_agent import GPT4v_Agent

# Logger Configs {{{ #
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")

file_handler = logging.FileHandler(os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8")
debug_handler = logging.FileHandler(os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8")
stdout_handler = logging.StreamHandler(sys.stdout)
sdebug_handler = logging.FileHandler(os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8")

file_handler.setLevel(logging.INFO)
debug_handler.setLevel(logging.DEBUG)
stdout_handler.setLevel(logging.INFO)
sdebug_handler.setLevel(logging.DEBUG)

formatter = logging.Formatter(
    fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s")
file_handler.setFormatter(formatter)
debug_handler.setFormatter(formatter)
stdout_handler.setFormatter(formatter)
sdebug_handler.setFormatter(formatter)

stdout_handler.addFilter(logging.Filter("desktopenv"))
sdebug_handler.addFilter(logging.Filter("desktopenv"))

logger.addHandler(file_handler)
logger.addHandler(debug_handler)
logger.addHandler(stdout_handler)
logger.addHandler(sdebug_handler)
# }}} Logger Configs #

logger = logging.getLogger("desktopenv.experiment")

PATH_TO_VM = r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu\Ubuntu.vmx"

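A note on the handler filters above: logging.Filter("desktopenv") passes only records whose logger name is "desktopenv" or a descendant such as "desktopenv.experiment", so stdout and the sdebug file see only the experiment's own messages, while the other two files also capture third-party logs. A quick illustration of that name-prefix semantics:

import logging

f = logging.Filter("desktopenv")
rec = lambda name: logging.LogRecord(name, logging.INFO, __file__, 0, "msg", None, None)
assert f.filter(rec("desktopenv.experiment"))       # descendant logger: passes
assert not f.filter(rec("urllib3.connectionpool"))  # unrelated logger: filtered out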
def run_one_example(example, agent, max_steps=10, example_trajectory_dir="exp_trajectory", recording=True):
    trajectory_recording_path = os.path.join(example_trajectory_dir, "trajectory.json")
    env = DesktopEnv(
        path_to_vm=PATH_TO_VM,
        action_space=agent.action_space,
        task_config=example
    )
    # reset the environment to a certain snapshot
    observation = env.reset()
    done = False
    step_num = 0

    if recording:
        # send a request to the server to start recording
        env.controller.start_recording()

    while not done and step_num < max_steps:
        actions = agent.predict(observation)
        step_num += 1
        for action in actions:
            # Capture the timestamp before executing the action
            action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
            logger.info("Step %d: %s", step_num, action)

            observation, reward, done, info = env.step(action)

            logger.info("Reward: %.2f", reward)
            logger.info("Done: %s", done)
            logger.info("Info: %s", info)

            # Save screenshot and trajectory information
            with open(os.path.join(example_trajectory_dir, f"step_{step_num}_{action_timestamp}.png"), "wb") as _f:
                with open(observation['screenshot'], "rb") as __f:
                    screenshot = __f.read()
                _f.write(screenshot)

            with open(trajectory_recording_path, "a") as f:
                f.write(json.dumps({
                    "step_num": step_num,
                    "action_timestamp": action_timestamp,
                    "action": action,
                    "reward": reward,
                    "done": done,
                    "info": info,
                    "screenshot_file": f"step_{step_num}_{action_timestamp}.png"
                }))
                f.write("\n")

            if done:
                logger.info("The episode is done.")
                break

    def stop_recording():
        try:
            env.controller.end_recording(os.path.join(example_trajectory_dir, "recording.mp4"))
        except Exception as e:
            print(f"An error occurred while stopping the recording: {e}")

    try:
        func_timeout.func_timeout(30, stop_recording)
    except func_timeout.exceptions.FunctionTimedOut:
        logger.info("Recording timed out.")

    result = env.evaluate()
    logger.info("Result: %.2f", result)

    with open(trajectory_recording_path, "a") as f:
        f.write(json.dumps({
            "result": result
        }))
        f.write("\n")

    # env.close()
    logger.info("Environment closed.")

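Since trajectory.json is written as JSON Lines (one object per step, then a final record with a "result" key), a reader sketch for the file layout produced above:

import json

def load_trajectory(path: str):
    steps, result = [], None
    with open(path, "r") as f:
        for line in f:
            record = json.loads(line)
            if "result" in record:
                result = record["result"]
            else:
                steps.append(record)
    return steps, result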
def main(example_class, example_id, gpt4_model="gpt-4-0125-preview"):
    action_space = "pyautogui"
    gemini_model = "gemini-pro-vision"

    logger.info("Running example %s/%s", example_class, example_id)
    logger.info("Using model %s", gpt4_model)
    # logger.info("Using model %s", gemini_model)

    with open(f"evaluation_examples/examples/{example_class}/{example_id}.json", "r", encoding="utf-8") as f:
        example = json.load(f)
    example["snapshot"] = "exp_v5"

    api_key = os.environ.get("OPENAI_API_KEY")
    agent = GPT4v_Agent(api_key=api_key, model=gpt4_model, instruction=example['instruction'], max_tokens=1000,
                        action_space=action_space, exp="a11y_tree")

    # api_key = os.environ.get("GENAI_API_KEY")
    # agent = GeminiPro_Agent(api_key=api_key, model=gemini_model, instruction=example['instruction'], action_space=action_space, exp="a11y_tree")

    root_trajectory_dir = "exp_trajectory"

    example_trajectory_dir = os.path.join(root_trajectory_dir, "a11y_tree", example_class, gpt4_model, example_id)
    # example_trajectory_dir = os.path.join(root_trajectory_dir, "a11y_tree", example_class, gemini_model, example_id)

    os.makedirs(example_trajectory_dir, exist_ok=True)

    run_one_example(example, agent, 15, example_trajectory_dir)

if __name__ == '__main__':
    os_list = [
        "94d95f96-9699-4208-98ba-3c3119edf9c2",
        "bedcedc4-4d72-425e-ad62-21960b11fe0d",
        "43c2d64c-bab5-4dcb-a30c-b888321c319a",
        "7688b85f-87a4-4e4a-b2f8-f3d6c3f29b82",
        "ec4e3f68-9ea4-4c18-a5c9-69f89d1178b3",
        "f9be0997-4b7c-45c5-b05c-4612b44a6118",
        "28cc3b7e-b194-4bc9-8353-d04c0f4d56d2",
        "5ea617a3-0e86-4ba6-aab2-dac9aa2e8d57",
        "e0df059f-28a6-4169-924f-b9623e7184cc",
        "ddc75b62-7311-4af8-bfb3-859558542b36",
        "b6781586-6346-41cd-935a-a6b1487918fc",
        "3ce045a0-877b-42aa-8d2c-b4a863336ab8",
        "a4d98375-215b-4a4d-aee9-3d4370fccc41",
        "13584542-872b-42d8-b299-866967b5c3ef",
        "23393935-50c7-4a86-aeea-2b78fd089c5c"
    ]

    # for example_id in os_list:
    #     try:
    #         main("os", example_id, gpt4_model="gpt-3.5-turbo-16k")
    #     except Exception as e:
    #         logger.error("An error occurred while running the example: %s", e)
    #         continue

    vlc_list = [
        "8ba5ae7a-5ae5-4eab-9fcc-5dd4fe3abf89",
        "8ba5ae7a-5ae5-4eab-9fcc-5dd4fe3abf89",
        "8f080098-ddb1-424c-b438-4e96e5e4786e",
        "bba3381f-b5eb-4439-bd9e-80c22218d5a7",
        "fba2c100-79e8-42df-ae74-b592418d54f4",
        "efcf0d81-0835-4880-b2fd-d866e8bc2294",
        "8d9fd4e2-6fdb-46b0-b9b9-02f06495c62f",
        "aa4b5023-aef6-4ed9-bdc9-705f59ab9ad6",
        "386dbd0e-0241-4a0a-b6a2-6704fba26b1c",
        "9195653c-f4aa-453d-aa95-787f6ccfaae9",
        "d06f0d4d-2cd5-4ede-8de9-598629438c6e",
        "a5bbbcd5-b398-4c91-83d4-55e1e31bbb81",
        "f3977615-2b45-4ac5-8bba-80c17dbe2a37",
        "215dfd39-f493-4bc3-a027-8a97d72c61bf"
    ]

    chrome_list = [
        "bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
        "7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3",
        "06fe7178-4491-4589-810f-2e2bc9502122",
        "e1e75309-3ddb-4d09-92ec-de869c928143",
        "35253b65-1c19-4304-8aa4-6884b8218fc0",
        "2ad9387a-65d8-4e33-ad5b-7580065a27ca",
        "7a5a7856-f1b6-42a4-ade9-1ca81ca0f263",
        "44ee5668-ecd5-4366-a6ce-c1c9b8d4e938",
        "2ae9ba84-3a0d-4d4c-8338-3a1478dc5fe3",
        "480bcfea-d68f-4aaa-a0a9-2589ef319381",
        "af630914-714e-4a24-a7bb-f9af687d3b91"
    ]

    calc_list = [
        "eb03d19a-b88d-4de4-8a64-ca0ac66f426b",
        "0bf05a7d-b28b-44d2-955a-50b41e24012a",
        "7a4e4bc8-922c-4c84-865c-25ba34136be1",
        "2bd59342-0664-4ccb-ba87-79379096cc08",
        "ecb0df7a-4e8d-4a03-b162-053391d3afaf",
        "7efeb4b1-3d19-4762-b163-63328d66303b",
        "4e6fcf72-daf3-439f-a232-c434ce416af6",
        "6054afcb-5bab-4702-90a0-b259b5d3217c",
        "abed40dc-063f-4598-8ba5-9fe749c0615d",
        "01b269ae-2111-4a07-81fd-3fcd711993b0",
        "8b1ce5f2-59d2-4dcc-b0b0-666a714b9a14",
        "0cecd4f3-74de-457b-ba94-29ad6b5dafb6",
        "4188d3a4-077d-46b7-9c86-23e1a036f6c1",
        "51b11269-2ca8-4b2a-9163-f21758420e78",
        "7e429b8d-a3f0-4ed0-9b58-08957d00b127",
        "347ef137-7eeb-4c80-a3bb-0951f26a8aff",
        "6e99a1ad-07d2-4b66-a1ce-ece6d99c20a5",
        "3aaa4e37-dc91-482e-99af-132a612d40f3",
        "37608790-6147-45d0-9f20-1137bb35703d",
        "f9584479-3d0d-4c79-affa-9ad7afdd8850",
        "d681960f-7bc3-4286-9913-a8812ba3261a",
        "21df9241-f8d7-4509-b7f1-37e501a823f7",
        "1334ca3e-f9e3-4db8-9ca7-b4c653be7d17",
        "357ef137-7eeb-4c80-a3bb-0951f26a8aff",
        "aa3a8974-2e85-438b-b29e-a64df44deb4b",
        "a01fbce3-2793-461f-ab86-43680ccbae25",
        "4f07fbe9-70de-4927-a4d5-bb28bc12c52c",
    ]

    # for example_id in calc_list:
    #     main("libreoffice_calc", example_id)

    impress_list = [
        "5d901039-a89c-4bfb-967b-bf66f4df075e",
        "550ce7e7-747b-495f-b122-acdc4d0b8e54",
        "455d3c66-7dc6-4537-a39a-36d3e9119df7",
        "af23762e-2bfd-4a1d-aada-20fa8de9ce07",
        "c59742c0-4323-4b9d-8a02-723c251deaa0",
        "ef9d12bd-bcee-4ba0-a40e-918400f43ddf",
        "9ec204e4-f0a3-42f8-8458-b772a6797cab",
        "0f84bef9-9790-432e-92b7-eece357603fb",
        "ce88f674-ab7a-43da-9201-468d38539e4a",
        "3b27600c-3668-4abd-8f84-7bcdebbccbdb",
        "a097acff-6266-4291-9fbd-137af7ecd439",
        "bf4e9888-f10f-47af-8dba-76413038b73c",
        "21760ecb-8f62-40d2-8d85-0cee5725cb72"
    ]
    # for example_id in impress_list:
    #     main("libreoffice_impress", example_id)

    thunderbird_list = [
        # "bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
        # "7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3",
        "12086550-11c0-466b-b367-1d9e75b3910e",
        "06fe7178-4491-4589-810f-2e2bc9502122",
        "6766f2b8-8a72-417f-a9e5-56fcaa735837",
        "e1e75309-3ddb-4d09-92ec-de869c928143",
        "3d1682a7-0fb0-49ae-a4dc-a73afd2d06d5",
        "35253b65-1c19-4304-8aa4-6884b8218fc0",
        "d088f539-cab4-4f9a-ac92-9999fc3a656e",
        "2ad9387a-65d8-4e33-ad5b-7580065a27ca",
        "480bcfea-d68f-4aaa-a0a9-2589ef319381",
        "030eeff7-b492-4218-b312-701ec99ee0cc",
        "94760984-3ff5-41ee-8347-cf1af709fea0",
        "99146c54-4f37-4ab8-9327-5f3291665e1e",
        "c9e7eaf2-b1a1-4efc-a982-721972fa9f02"
    ]
    # for example_id in thunderbird_list:
    #     main("thunderbird", example_id)

    gimp_list = [
        "7a4deb26-d57d-4ea9-9a73-630f66a7b568",
        "554785e9-4523-4e7a-b8e1-8016f565f56a",
        "77b8ab4d-994f-43ac-8930-8ca087d7c4b4",
        "f4aec372-4fb0-4df5-a52b-79e0e2a5d6ce",
        "d52d6308-ec58-42b7-a2c9-de80e4837b2b",
        "2a729ded-3296-423d-aec4-7dd55ed5fbb3",
        "b148e375-fe0b-4bec-90e7-38632b0d73c2",
        "a746add2-cab0-4740-ac36-c3769d9bfb46",
        "7b7617bd-57cc-468e-9c91-40c4ec2bcb3d",
        "d16c99dc-2a1e-46f2-b350-d97c86c85c15",
        "06ca5602-62ca-47f6-ad4f-da151cde54cc",
        "e2dd0213-26db-4349-abe5-d5667bfd725c",
        "f723c744-e62c-4ae6-98d1-750d3cd7d79d",
        "72f83cdc-bf76-4531-9a1b-eb893a13f8aa",
        "7767eef2-56a3-4cea-8c9f-48c070c7d65b",
        "734d6579-c07d-47a8-9ae2-13339795476b"
    ]

    # for example_id in gimp_list:
    #     try:
    #         main("gimp", example_id)
    #     except Exception as e:
    #         logger.error("An error occurred while running the example: %s", e)
    #         continue

    vs_code_list = [
        "0ed39f63-6049-43d4-ba4d-5fa2fe04a951",
        "53ad5833-3455-407b-bbc6-45b4c79ab8fb",
        "eabc805a-bfcf-4460-b250-ac92135819f6",
        "982d12a5-beab-424f-8d38-d2a48429e511",
        "4e60007a-f5be-4bfc-9723-c39affa0a6d3",
        "e2b5e914-ffe1-44d2-8e92-58f8c5d92bb2",
        "9439a27b-18ae-42d8-9778-5f68f891805e",
        "ea98c5d7-3cf9-4f9b-8ad3-366b58e0fcae",
        "930fdb3b-11a8-46fe-9bac-577332e2640e",
        "276cc624-87ea-4f08-ab93-f770e3790175",
        "9d425400-e9b2-4424-9a4b-d4c7abac4140"
    ]

    # for example_id in vs_code_list:
    #     try:
    #         main("vs_code", example_id)
    #     except Exception as e:
    #         logger.error("An error occurred while running the example: %s", e)
    #         continue

    from tqdm import tqdm

    # for example_id in tqdm(vlc_list):
    #     try:
    #         main("vlc", example_id, gpt4_model="gpt-3.5-turbo-16k")
    #     except Exception as e:
    #         print(f"An error occurred while running the example: {e}")
    #         continue

    chrome_list = [
        "bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
        "7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3",
        "06fe7178-4491-4589-810f-2e2bc9502122",
        "e1e75309-3ddb-4d09-92ec-de869c928143",
        "35253b65-1c19-4304-8aa4-6884b8218fc0",
        "2ad9387a-65d8-4e33-ad5b-7580065a27ca",
        "7a5a7856-f1b6-42a4-ade9-1ca81ca0f263",
        "44ee5668-ecd5-4366-a6ce-c1c9b8d4e938",
        "2ae9ba84-3a0d-4d4c-8338-3a1478dc5fe3",
        "480bcfea-d68f-4aaa-a0a9-2589ef319381",
        "af630914-714e-4a24-a7bb-f9af687d3b91"
    ]
    # for example_id in tqdm(chrome_list):
    #     try:
    #         main("chrome", example_id, gpt4_model="gpt-3.5-turbo-16k")
    #     except Exception as e:
    #         print(f"An error occurred while running the example: {e}")
    #         continue

    vs_code_list = [
        # "0ed39f63-6049-43d4-ba4d-5fa2fe04a951",
        # "53ad5833-3455-407b-bbc6-45b4c79ab8fb",
        # "eabc805a-bfcf-4460-b250-ac92135819f6",
        # "982d12a5-beab-424f-8d38-d2a48429e511",
        # "4e60007a-f5be-4bfc-9723-c39affa0a6d3",
        # "e2b5e914-ffe1-44d2-8e92-58f8c5d92bb2",
        # "9439a27b-18ae-42d8-9778-5f68f891805e",
        # "ea98c5d7-3cf9-4f9b-8ad3-366b58e0fcae",
        # "930fdb3b-11a8-46fe-9bac-577332e2640e",
        # "276cc624-87ea-4f08-ab93-f770e3790175",
        # "9d425400-e9b2-4424-9a4b-d4c7abac4140"
    ]

    for example_id in tqdm(vs_code_list):
        try:
            main("vs_code", example_id, gpt4_model="gpt-3.5-turbo-16k")
        except Exception as e:
            print(f"An error occurred while running the example: {e}")
            continue

    thunderbird_list = [
        # "bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
        # "7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3",
        "12086550-11c0-466b-b367-1d9e75b3910e",
        "06fe7178-4491-4589-810f-2e2bc9502122",
        "6766f2b8-8a72-417f-a9e5-56fcaa735837",
        "e1e75309-3ddb-4d09-92ec-de869c928143",
        "3d1682a7-0fb0-49ae-a4dc-a73afd2d06d5",
        "35253b65-1c19-4304-8aa4-6884b8218fc0",
        "d088f539-cab4-4f9a-ac92-9999fc3a656e",
        "2ad9387a-65d8-4e33-ad5b-7580065a27ca",
        "480bcfea-d68f-4aaa-a0a9-2589ef319381",
        "030eeff7-b492-4218-b312-701ec99ee0cc",
        "94760984-3ff5-41ee-8347-cf1af709fea0",
        "99146c54-4f37-4ab8-9327-5f3291665e1e",
        "c9e7eaf2-b1a1-4efc-a982-721972fa9f02"
    ]

    # for example_id in tqdm(thunderbird_list):
    #     try:
    #         main("thunderbird", example_id, gpt4_model="gpt-3.5-turbo-16k")
    #     except Exception as e:
    #         print(f"An error occurred while running the example: {e}")
    #         continue

    multiple_list = [
        # "f8cfa149-d1c1-4215-8dac-4a0932bad3c2",
        # "897e3b53-5d4d-444b-85cb-2cdc8a97d903",
        "2fe4b718-3bd7-46ec-bdce-b184f5653624",
        "3680a5ee-6870-426a-a997-eba929a0d25c",
        # "4e9f0faf-2ecc-4ae8-a804-28c9a75d1ddc",
        # "b52b40a5-ad70-4c53-b5b0-5650a8387052",
        # "46407397-a7d5-4c6b-92c6-dbe038b1457b",
        # "2b9493d7-49b8-493a-a71b-56cd1f4d6908",
        # "51f5801c-18b3-4f25-b0c3-02f85507a078",
        "58565672-7bfe-48ab-b828-db349231de6b",
        # "2c9fc0de-3ee7-45e1-a5df-c86206ad78b5",
        # "510f64c8-9bcc-4be1-8d30-638705850618",
        # "937087b6-f668-4ba6-9110-60682ee33441",
        # "ee9a3c83-f437-4879-8918-be5efbb9fac7",
        # "3680a5ee-6870-426a-a997-eba929a0d25c",
        # "e135df7c-7687-4ac0-a5f0-76b74438b53e",
        "ee9a3c83-f437-4879-8918-be5efbb9fac7",
        # "58565672-7bfe-48ab-b828-db349231de6b",
        # "2fe4b718-3bd7-46ec-bdce-b184f5653624"
    ]

    for example_id in multiple_list:
        try:
            main("multi_apps", example_id, gpt4_model="gpt-3.5-turbo-16k")
        except Exception as e:
            logger.error("An error occurred while running the example: %s", e)
            continue

@@ -1,604 +0,0 @@
"""Script to run end-to-end evaluation on the benchmark.
Utils and basic architecture credit to https://github.com/web-arena-x/webarena/blob/main/run.py.
"""
import datetime
import sys

import func_timeout
import argparse
import glob
import json
import logging
import os
import time
from pathlib import Path
import openai
import requests
import torch
from beartype import beartype

from desktop_env.envs.desktop_env import DesktopEnv
from mm_agents.agent import PromptAgent  # todo: change the name into PromptAgent

# Logger Configs {{{ #
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")

file_handler = logging.FileHandler(os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8")
debug_handler = logging.FileHandler(os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8")
stdout_handler = logging.StreamHandler(sys.stdout)
sdebug_handler = logging.FileHandler(os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8")

file_handler.setLevel(logging.INFO)
debug_handler.setLevel(logging.DEBUG)
stdout_handler.setLevel(logging.INFO)
sdebug_handler.setLevel(logging.DEBUG)

formatter = logging.Formatter(
    fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s")
file_handler.setFormatter(formatter)
debug_handler.setFormatter(formatter)
stdout_handler.setFormatter(formatter)
sdebug_handler.setFormatter(formatter)

stdout_handler.addFilter(logging.Filter("desktopenv"))
sdebug_handler.addFilter(logging.Filter("desktopenv"))

logger.addHandler(file_handler)
logger.addHandler(debug_handler)
logger.addHandler(stdout_handler)
logger.addHandler(sdebug_handler)
# }}} Logger Configs #

logger = logging.getLogger("desktopenv.experiment")

def run_one_example(example, agent, max_steps=10, example_trajectory_dir="exp_trajectory", recording=True,
                    max_time=600):
    trajectory_recording_path = os.path.join(example_trajectory_dir, "trajectory.json")
    env = DesktopEnv(
        path_to_vm=PATH_TO_VM,
        action_space=agent.action_space,
        task_config=example,
        headless=True
    )
    # reset the environment to a certain snapshot
    observation = env.reset()
    done = False
    step_num = 0

    if recording:
        # send a request to the server to start recording
        env.controller.start_recording()

    while not done and step_num < max_steps:
        actions = agent.predict(observation)
        step_num += 1
        for action in actions:
            # Capture the timestamp before executing the action
            action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
            logger.info("Step %d: %s", step_num, action)

            observation, reward, done, info = env.step(action)

            logger.info("Reward: %.2f", reward)
            logger.info("Done: %s", done)
            logger.info("Info: %s", info)

            # Save screenshot and trajectory information
            with open(os.path.join(example_trajectory_dir, f"step_{step_num}_{action_timestamp}.png"), "wb") as _f:
                with open(observation['screenshot'], "rb") as __f:
                    screenshot = __f.read()
                _f.write(screenshot)

            with open(trajectory_recording_path, "a") as f:
                f.write(json.dumps({
                    "step_num": step_num,
                    "action_timestamp": action_timestamp,
                    "action": action,
                    "reward": reward,
                    "done": done,
                    "info": info,
                    "screenshot_file": f"step_{step_num}_{action_timestamp}.png"
                }))
                f.write("\n")

            if done:
                logger.info("The episode is done.")
                break

    def stop_recording():
        try:
            env.controller.end_recording(os.path.join(example_trajectory_dir, "recording.mp4"))
        except Exception as e:
            print(f"An error occurred while stopping the recording: {e}")

    try:
        func_timeout.func_timeout(120, stop_recording)
        # todo: make sure we got the video file, check the bug
    except func_timeout.exceptions.FunctionTimedOut:
        logger.info("Recording timed out.")

    result = env.evaluate()
    logger.info("Result: %.2f", result)

    # fixme: change to write the result into a separate file
    with open(trajectory_recording_path, "a") as f:
        f.write(json.dumps({
            "result": result
        }))
        f.write("\n")

    # todo: append the result to the wandb for visualization

    # env.close()
    logger.info("Environment closed.")

def main(example_class, example_id, gpt4_model="gpt-4-vision-preview"):
    # todo: merge the main function into the run_one_example function
    # fixme: change all the settings like action_space, model, etc. to the argparser
    action_space = "pyautogui"
    gemini_model = "gemini-pro-vision"

    logger.info("Running example %s/%s", example_class, example_id)
    logger.info("Using model %s", gpt4_model)
    # logger.info("Using model %s", gemini_model)

    with open(f"evaluation_examples/examples/{example_class}/{example_id}.json", "r", encoding="utf-8") as f:
        example = json.load(f)
    example["snapshot"] = "exp_v5"

    api_key = os.environ.get("OPENAI_API_KEY")
    agent = PromptAgent(
        api_key=api_key,
        model=gpt4_model,
        instruction=example['instruction'],
        action_space=action_space,
        exp="screenshot"
    )

    root_trajectory_dir = "exp_trajectory"

    example_trajectory_dir = os.path.join(root_trajectory_dir, "screenshot", example_class, gpt4_model, example_id)
    # example_trajectory_dir = os.path.join(root_trajectory_dir, "screenshot", example_class, gemini_model, example_id)

    os.makedirs(example_trajectory_dir, exist_ok=True)

    if os.path.exists(os.path.join(example_trajectory_dir, "trajectory.json")):
        with open(os.path.join(example_trajectory_dir, "trajectory.json"), "r") as f:
            lines = f.readlines()
        # drop empty lines before inspecting the last record
        lines = [line.strip() for line in lines if line.strip() != ""]
        if len(lines) > 0:
            last_line = json.loads(lines[-1])
            if "result" in last_line:
                logger.info(
                    f"evaluation_examples/examples/{example_class}/{example_id}.json" + " has been evaluated. Skip.")
                return

    try:
        func_timeout.func_timeout(1200, run_one_example, args=(example, agent, 15, example_trajectory_dir))
    except Exception as e:
        print(f"An error occurred: {e}")
        with open(os.path.join(example_trajectory_dir, "trajectory.json"), "a") as f:
            f.write(json.dumps({
                "error": str(e)
            }))

def config() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Run end-to-end evaluation on the benchmark"
    )

    # environment config
    parser.add_argument("--path_to_vm", type=str, default="Ubuntu\\Ubuntu.vmx")
    parser.add_argument(
        "--headless", action="store_true", help="Run in headless machine"
    )
    parser.add_argument("--action_space", type=str, default="pyautogui", help="Action type")
    parser.add_argument(
        "--observation_type",
        choices=[
            "screenshot",
            "a11y_tree",
            "screenshot_a11y_tree",
            "som"
        ],
        default="a11y_tree",
        help="Observation type",
    )
    # parser.add_argument(
    #     "--current_viewport_only",
    #     action="store_true",
    #     help="Only use the current viewport for the observation",
    # )
    parser.add_argument("--screen_width", type=int, default=1920)
    parser.add_argument("--screen_height", type=int, default=1080)
    parser.add_argument("--save_trace_enabled", action="store_true")
    parser.add_argument("--sleep_after_execution", type=float, default=0.0)
    parser.add_argument("--max_steps", type=int, default=30)

    # agent config
    parser.add_argument("--agent_type", type=str, default="prompt")
    parser.add_argument(
        "--instruction_path",
        type=str,
        default="",
    )
    parser.add_argument(
        "--parsing_failure_th",
        help="When consecutive parsing failures exceed this threshold, the agent will terminate early.",
        type=int,
        default=3,
    )
    parser.add_argument(
        "--repeating_action_failure_th",
        help="When consecutive repeated actions exceed this threshold, the agent will terminate early.",
        type=int,
        default=5,
    )

    parser.add_argument("--test_config_base_dir", type=str)

    # lm config
    parser.add_argument("--provider", type=str, default="openai")
    parser.add_argument("--model", type=str, default="gpt-3.5-turbo-0613")
    parser.add_argument("--mode", type=str, default="chat")
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--context_length", type=int, default=0)
    parser.add_argument("--max_tokens", type=int, default=384)
    parser.add_argument("--stop_token", type=str, default=None)
    parser.add_argument(
        "--max_retry",
        type=int,
        help="max retry times to perform generations when parsing fails",
        default=1,
    )
    parser.add_argument(
        "--max_obs_length",
        type=int,
        help="when not zero, will truncate the observation to this length before feeding to the model",
        default=3840,
    )

    # example config
    parser.add_argument("--test_start_idx", type=int, default=0)
    parser.add_argument("--test_end_idx", type=int, default=378)

    # logging related
    parser.add_argument("--result_dir", type=str, default="")
    args = parser.parse_args()

    return args

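A quick sketch of exercising the parser above without a shell; patching sys.argv (and the "run.py" script name) is only for illustration, since the flags normally come from the command line:

import sys

sys.argv = ["run.py", "--observation_type", "a11y_tree", "--max_steps", "15"]
args = config()
assert args.action_space == "pyautogui"  # default value
assert args.max_steps == 15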
@beartype
def early_stop(
        trajectory, max_steps: int, thresholds: dict[str, int]
) -> tuple[bool, str]:
    """Check whether we need to stop early"""

    # reached the max step count
    num_steps = (len(trajectory) - 1) / 2
    if num_steps >= max_steps:
        return True, f"Reach max steps {max_steps}"

    # Case: parsing failure for k times
    k = thresholds["parsing_failure"]
    last_k_actions = trajectory[1::2][-k:]  # type: ignore[assignment]
    if len(last_k_actions) >= k:
        if all(
                [
                    action["action_type"] == ""
                    for action in last_k_actions
                ]
        ):
            return True, f"Failed to parse actions for {k} times"

    # Case: same action for k times
    k = thresholds["repeating_action"]
    last_k_actions = trajectory[1::2][-k:]  # type: ignore[assignment]
    action_seq = trajectory[1::2]  # type: ignore[assignment]

    if len(action_seq) == 0:
        return False, ""

    last_action = action_seq[-1]

    if last_action["action_type"] != ActionTypes.TYPE:
        if len(last_k_actions) >= k:
            if all(
                    [
                        is_equivalent(action, last_action)
                        for action in last_k_actions
                    ]
            ):
                return True, f"Same action for {k} times"

    else:
        # check the whole action sequence
        if (
                sum([is_equivalent(action, last_action) for action in action_seq])
                >= k
        ):
            return True, f"Same typing action for {k} times"

    return False, ""

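The (len(trajectory) - 1) / 2 step count relies on the trajectory alternating state, action, state, action, ..., starting and ending with a state (as the test() loop below maintains). A worked example:

# A trajectory of the form [s0, a1, s1, a2, s2] has (5 - 1) / 2 = 2 completed steps.
trajectory = ["s0", "a1", "s1", "a2", "s2"]
num_steps = (len(trajectory) - 1) / 2
assert num_steps == 2
assert trajectory[1::2] == ["a1", "a2"]  # the action subsequence sliced above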
@beartype
def test(
        args: argparse.Namespace,
        config_file_list: list[str]
) -> None:
    scores = []
    max_steps = args.max_steps

    early_stop_thresholds = {
        "parsing_failure": args.parsing_failure_th,
        "repeating_action": args.repeating_action_failure_th,
    }

    if args.observation_type in [
        "accessibility_tree_with_captioner",
        "image_som",
    ]:
        device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
        dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        caption_image_fn = image_utils.get_captioning_fn(
            device, dtype, args.captioning_model
        )
    else:
        caption_image_fn = None

    # Load a (possibly different) captioning model for running VQA evals.
    if (
            caption_image_fn
            and args.eval_captioning_model == args.captioning_model
    ):
        eval_caption_image_fn = caption_image_fn
    else:
        eval_caption_image_fn = image_utils.get_captioning_fn(
            args.eval_captioning_model_device,
            torch.float16
            if (
                    torch.cuda.is_available()
                    and args.eval_captioning_model_device == "cuda"
            )
            else torch.float32,
            args.eval_captioning_model,
        )

    agent = construct_agent(
        args,
        captioning_fn=caption_image_fn
        if args.observation_type == "accessibility_tree_with_captioner"
        else None,
    )  # NOTE: captioning_fn here is used for captioning input images.

    env = ScriptBrowserEnv(
        headless=not args.render,
        slow_mo=args.slow_mo,
        observation_type=args.observation_type,
        current_viewport_only=args.current_viewport_only,
        viewport_size={
            "width": args.viewport_width,
            "height": args.viewport_height,
        },
        save_trace_enabled=args.save_trace_enabled,
        sleep_after_execution=args.sleep_after_execution,
        # NOTE: captioning_fn here is used for LLM + captioning baselines.
        # This can be different from the captioning model used for evals.
        captioning_fn=caption_image_fn,
    )

    for config_file in config_file_list:
        try:
            render_helper = RenderHelper(
                config_file, args.result_dir, args.action_set_tag
            )

            # Load task.
            with open(config_file) as f:
                _c = json.load(f)
                intent = _c["intent"]
                task_id = _c["task_id"]
                image_paths = _c.get("image", None)
                images = []

                # Load input images for the task, if any.
                if image_paths is not None:
                    if isinstance(image_paths, str):
                        image_paths = [image_paths]
                    for image_path in image_paths:
                        # Load image either from the web or from a local path.
                        if image_path.startswith("http"):
                            input_image = Image.open(requests.get(image_path, stream=True).raw)
                        else:
                            input_image = Image.open(image_path)

                        images.append(input_image)

            logger.info(f"[Config file]: {config_file}")
            logger.info(f"[Intent]: {intent}")

            agent.reset(config_file)
            trajectory: Trajectory = []
            obs, info = env.reset(options={"config_file": config_file})
            state_info: StateInfo = {"observation": obs, "info": info}
            trajectory.append(state_info)

            meta_data = {"action_history": ["None"]}
            while True:
                early_stop_flag, stop_info = early_stop(
                    trajectory, max_steps, early_stop_thresholds
                )

                if early_stop_flag:
                    action = create_stop_action(f"Early stop: {stop_info}")
                else:
                    try:
                        action = agent.next_action(
                            trajectory,
                            intent,
                            images=images,
                            meta_data=meta_data,
                        )
                    except ValueError as e:
                        # get the error message
                        action = create_stop_action(f"ERROR: {str(e)}")

                trajectory.append(action)

                action_str = get_action_description(
                    action,
                    state_info["info"]["observation_metadata"],
                    action_set_tag=args.action_set_tag,
                    prompt_constructor=agent.prompt_constructor
                    if isinstance(agent, PromptAgent)
                    else None,
                )
                render_helper.render(
                    action, state_info, meta_data, args.render_screenshot
                )
                meta_data["action_history"].append(action_str)

                if action["action_type"] == ActionTypes.STOP:
                    break

                obs, _, terminated, _, info = env.step(action)
                state_info = {"observation": obs, "info": info}
                trajectory.append(state_info)

                if terminated:
                    # add an action placeholder
                    trajectory.append(create_stop_action(""))
                    break

            # NOTE: eval_caption_image_fn is used for running eval_vqa functions.
            evaluator = evaluator_router(
                config_file, captioning_fn=eval_caption_image_fn
            )
            score = evaluator(
                trajectory=trajectory,
                config_file=config_file,
                page=env.page,
                client=env.get_page_client(env.page),
            )

            scores.append(score)

            if score == 1:
                logger.info(f"[Result] (PASS) {config_file}")
            else:
                logger.info(f"[Result] (FAIL) {config_file}")

            if args.save_trace_enabled:
                env.save_trace(
                    Path(args.result_dir) / "traces" / f"{task_id}.zip"
                )
        except openai.OpenAIError as e:
            logger.info(f"[OpenAI Error] {repr(e)}")
        except Exception as e:
            logger.info(f"[Unhandled Error] {repr(e)}")
            import traceback

            # write to error file
            with open(Path(args.result_dir) / "error.txt", "a") as f:
                f.write(f"[Config file]: {config_file}\n")
                f.write(f"[Unhandled Error] {repr(e)}\n")
                f.write(traceback.format_exc())  # write stack trace to file

        render_helper.close()

    env.close()
    logger.info(f"Average score: {sum(scores) / len(scores)}")

def prepare(args: argparse.Namespace) -> None:
    # convert prompt python files to json
    from agent.prompts import to_json

    to_json.run()

    # prepare result dir
    result_dir = args.result_dir
    if not result_dir:
        result_dir = (
            f"cache/results_{time.strftime('%Y%m%d%H%M%S', time.localtime())}"
        )
    if not Path(result_dir).exists():
        Path(result_dir).mkdir(parents=True, exist_ok=True)
        args.result_dir = result_dir
        logger.info(f"Create result dir: {result_dir}")

    if not (Path(result_dir) / "traces").exists():
        (Path(result_dir) / "traces").mkdir(parents=True)

    # record the log file name
    with open(os.path.join(result_dir, "log_files.txt"), "a+") as f:
        f.write(f"{LOG_FILE_NAME}\n")

def get_unfinished(config_files: list[str], result_dir: str) -> list[str]:
    result_files = glob.glob(f"{result_dir}/*.html")
    task_ids = [
        os.path.basename(f).split(".")[0].split("_")[1] for f in result_files
    ]
    unfinished_configs = []
    for config_file in config_files:
        task_id = os.path.basename(config_file).split(".")[0]
        if task_id not in task_ids:
            unfinished_configs.append(config_file)
    return unfinished_configs

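The split(".")[0].split("_")[1] expression above implies result pages are named like "render_<task_id>.html"; the "render_" prefix is inferred from the parsing, not confirmed elsewhere in this file. For instance:

import os

f = "results/render_17.html"
task_id = os.path.basename(f).split(".")[0].split("_")[1]
assert task_id == "17"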
@beartype
def dump_config(args: argparse.Namespace) -> None:
    config_file = Path(args.result_dir) / "config.json"
    if not config_file.exists():
        with open(config_file, "w") as f:
            json.dump(vars(args), f, indent=4)
            logger.info(f"Dump config to {config_file}")

if __name__ == '__main__':
    ####### The complete version of the list of examples #######
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    args = config()
    args.sleep_after_execution = 5
    prepare(args)

    test_config_base_dir = args.test_config_base_dir

    test_file_list = []
    st_idx = args.test_start_idx
    ed_idx = args.test_end_idx
    for i in range(st_idx, ed_idx):
        test_file_list.append(os.path.join(test_config_base_dir, f"{i}.json"))
    test_file_list = get_unfinished(test_file_list, args.result_dir)
    print(f"Total {len(test_file_list)} tasks left")
    args.render = False
    args.render_screenshot = True
    args.save_trace_enabled = True

    args.current_viewport_only = True
    dump_config(args)

    test(args, test_file_list)

    # todo: add a recorder of the progress of the examples

    # todo: remove the useless example files

    with open("evaluation_examples/test_all.json", "r", encoding="utf-8") as f:
        test_all_meta = json.load(f)

    for domain in test_all_meta:
        for example_id in test_all_meta[domain]:
            main(domain, example_id, args.model)
@@ -1,361 +0,0 @@
import datetime
import json
import logging
import os
import sys

import func_timeout

from desktop_env.envs.desktop_env import DesktopEnv
from mm_agents.gpt_4v_agent import GPT4v_Agent

# Logger Configs {{{ #
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")

file_handler = logging.FileHandler(os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8")
debug_handler = logging.FileHandler(os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8")
stdout_handler = logging.StreamHandler(sys.stdout)
sdebug_handler = logging.FileHandler(os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8")

file_handler.setLevel(logging.INFO)
debug_handler.setLevel(logging.DEBUG)
stdout_handler.setLevel(logging.INFO)
sdebug_handler.setLevel(logging.DEBUG)

formatter = logging.Formatter(
    fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s")
file_handler.setFormatter(formatter)
debug_handler.setFormatter(formatter)
stdout_handler.setFormatter(formatter)
sdebug_handler.setFormatter(formatter)

stdout_handler.addFilter(logging.Filter("desktopenv"))
sdebug_handler.addFilter(logging.Filter("desktopenv"))

logger.addHandler(file_handler)
logger.addHandler(debug_handler)
logger.addHandler(stdout_handler)
logger.addHandler(sdebug_handler)
# }}} Logger Configs #

logger = logging.getLogger("desktopenv.experiment")

PATH_TO_VM = r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu2\Ubuntu2.vmx"


# PATH_TO_VM = "../../../../大文件/镜像/Ubuntu-1218/Ubuntu/Ubuntu.vmx"  # (path components "大文件/镜像" mean "large files/disk images")

def run_one_example(example, agent, max_steps=10, example_trajectory_dir="exp_trajectory", recording=True):
    trajectory_recording_path = os.path.join(example_trajectory_dir, "trajectory.json")
    env = DesktopEnv(
        path_to_vm=PATH_TO_VM,
        action_space=agent.action_space,
        task_config=example
    )
    # reset the environment to a certain snapshot
    observation = env.reset()
    done = False
    step_num = 0

    if recording:
        # send a request to the server to start recording
        env.controller.start_recording()

    while not done and step_num < max_steps:
        actions = agent.predict(observation)
        step_num += 1
        for action in actions:
            # Capture the timestamp before executing the action
            action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
            logger.info("Step %d: %s", step_num, action)

            observation, reward, done, info = env.step(action)

            logger.info("Reward: %.2f", reward)
            logger.info("Done: %s", done)
            logger.info("Info: %s", info)

            # Save screenshot and trajectory information
            with open(os.path.join(example_trajectory_dir, f"step_{step_num}_{action_timestamp}.png"), "wb") as _f:
                with open(observation['screenshot'], "rb") as __f:
                    screenshot = __f.read()
                _f.write(screenshot)

            with open(trajectory_recording_path, "a") as f:
                f.write(json.dumps({
                    "step_num": step_num,
                    "action_timestamp": action_timestamp,
                    "action": action,
                    "reward": reward,
                    "done": done,
                    "info": info,
                    "screenshot_file": f"step_{step_num}_{action_timestamp}.png"
                }))
                f.write("\n")

            if done:
                logger.info("The episode is done.")
                break

    def stop_recording():
        try:
            env.controller.end_recording(os.path.join(example_trajectory_dir, "recording.mp4"))
        except Exception as e:
            print(f"An error occurred while stopping the recording: {e}")

    try:
        func_timeout.func_timeout(30, stop_recording)
    except func_timeout.exceptions.FunctionTimedOut:
        logger.info("Recording timed out.")

    result = env.evaluate()
    logger.info("Result: %.2f", result)

    with open(trajectory_recording_path, "a") as f:
        f.write(json.dumps({
            "result": result
        }))
        f.write("\n")

    # env.close()
    logger.info("Environment closed.")

def main(example_class, example_id, gpt4_model="gpt-4-vision-preview"):
    action_space = "pyautogui"
    # example_class = "libreoffice_calc"
    # example_id = "7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3"
    # example_id = "01b269ae-2111-4a07-81fd-3fcd711993b0"
    gemini_model = "gemini-pro-vision"

    logger.info("Running example %s/%s", example_class, example_id)
    logger.info("Using model %s", gpt4_model)
    # logger.info("Using model %s", gemini_model)

    with open(f"evaluation_examples/examples/{example_class}/{example_id}.json", "r", encoding="utf-8") as f:
        example = json.load(f)
    example["snapshot"] = "exp_v5"
    # example["snapshot"] = "exp_setup4"
    # example["snapshot"] = "Snapshot 30"

    api_key = os.environ.get("OPENAI_API_KEY")
    agent = GPT4v_Agent(api_key=api_key, model=gpt4_model, instruction=example['instruction'],
                        action_space=action_space, exp="both")

    # api_key = os.environ.get("GENAI_API_KEY")
    # agent = GeminiPro_Agent(api_key=api_key, model=gemini_model, instruction=example['instruction'], action_space=action_space, exp="both")

    root_trajectory_dir = "exp_trajectory"

    example_trajectory_dir = os.path.join(root_trajectory_dir, "both", example_class, gpt4_model, example_id)
    # example_trajectory_dir = os.path.join(root_trajectory_dir, "both", example_class, gemini_model, example_id)

    os.makedirs(example_trajectory_dir, exist_ok=True)

    run_one_example(example, agent, 15, example_trajectory_dir)

if __name__ == '__main__':
    os_list = [
        "94d95f96-9699-4208-98ba-3c3119edf9c2",
        "bedcedc4-4d72-425e-ad62-21960b11fe0d",
        "43c2d64c-bab5-4dcb-a30c-b888321c319a",
        "7688b85f-87a4-4e4a-b2f8-f3d6c3f29b82",
        "ec4e3f68-9ea4-4c18-a5c9-69f89d1178b3",
        "f9be0997-4b7c-45c5-b05c-4612b44a6118",
        "28cc3b7e-b194-4bc9-8353-d04c0f4d56d2",
        "5ea617a3-0e86-4ba6-aab2-dac9aa2e8d57",
        "e0df059f-28a6-4169-924f-b9623e7184cc",
        "ddc75b62-7311-4af8-bfb3-859558542b36",
        "b6781586-6346-41cd-935a-a6b1487918fc",
        "3ce045a0-877b-42aa-8d2c-b4a863336ab8",
        "a4d98375-215b-4a4d-aee9-3d4370fccc41",
        "13584542-872b-42d8-b299-866967b5c3ef",
        "23393935-50c7-4a86-aeea-2b78fd089c5c"
    ]

    # for example_id in os_list:
    #     try:
    #         main("os", example_id)
    #     except Exception as e:
    #         logger.error("An error occurred while running the example: %s", e)
    #         continue

    calc_list = [
        "a9f325aa-8c05-4e4f-8341-9e4358565f4f",
        "ecb0df7a-4e8d-4a03-b162-053391d3afaf",
        "7efeb4b1-3d19-4762-b163-63328d66303b",
        "4e6fcf72-daf3-439f-a232-c434ce416af6",
        "6054afcb-5bab-4702-90a0-b259b5d3217c",
        "abed40dc-063f-4598-8ba5-9fe749c0615d",
        "01b269ae-2111-4a07-81fd-3fcd711993b0",
        "8b1ce5f2-59d2-4dcc-b0b0-666a714b9a14",
        "af2b02f7-acee-4be4-8b66-499fab394915",
        "da1d63b8-fa12-417b-ba18-f748e5f770f3",
        "636380ea-d5f6-4474-b6ca-b2ed578a20f1",
        "5ba77536-05c5-4aae-a9ff-6e298d094c3e",
        "4bc4eaf4-ca5e-4db2-8138-8d4e65af7c0b",
        "672a1b02-c62f-4ae2-acf0-37f5fb3052b0",
        "648fe544-16ba-44af-a587-12ccbe280ea6",
        "8985d1e4-5b99-4711-add4-88949ebb2308",
        "9e606842-2e27-43bf-b1d1-b43289c9589b",
        "fcb6e45b-25c4-4087-9483-03d714f473a9",
        "68c0c5b7-96f3-4e87-92a7-6c1b967fd2d2",
        "fff629ea-046e-4793-8eec-1a5a15c3eb35",
        "5c9a206c-bb00-4fb6-bb46-ee675c187df5",
        "e975ae74-79bd-4672-8d1c-dc841a85781d",
        "34a6938a-58da-4897-8639-9b90d6db5391",
        "b5a22759-b4eb-4bf2-aeed-ad14e8615f19",
        "2f9913a1-51ed-4db6-bfe0-7e1c95b3139e",
        "2558031e-401d-4579-8e00-3ecf540fb492",
        "0cecd4f3-74de-457b-ba94-29ad6b5dafb6",
        "4188d3a4-077d-46b7-9c86-23e1a036f6c1",
        "51b11269-2ca8-4b2a-9163-f21758420e78",
        "7e429b8d-a3f0-4ed0-9b58-08957d00b127",
        "347ef137-7eeb-4c80-a3bb-0951f26a8aff",
        "6e99a1ad-07d2-4b66-a1ce-ece6d99c20a5",
        "3aaa4e37-dc91-482e-99af-132a612d40f3",
        "37608790-6147-45d0-9f20-1137bb35703d",
        "f9584479-3d0d-4c79-affa-9ad7afdd8850",
        "d681960f-7bc3-4286-9913-a8812ba3261a",
        "21df9241-f8d7-4509-b7f1-37e501a823f7",
        "1334ca3e-f9e3-4db8-9ca7-b4c653be7d17",
        "357ef137-7eeb-4c80-a3bb-0951f26a8aff",
        "aa3a8974-2e85-438b-b29e-a64df44deb4b",
        "a01fbce3-2793-461f-ab86-43680ccbae25",
        "4f07fbe9-70de-4927-a4d5-bb28bc12c52c"
    ]

    # for example_id in calc_list:
    #     try:
    #         main("libreoffice_calc", example_id)
    #     except Exception as e:
    #         logger.error("An error occurred while running the example: %s", e)
    #         continue

    impress_list = [
        "5d901039-a89c-4bfb-967b-bf66f4df075e",
        "550ce7e7-747b-495f-b122-acdc4d0b8e54",
        "455d3c66-7dc6-4537-a39a-36d3e9119df7",
        "af23762e-2bfd-4a1d-aada-20fa8de9ce07",
        "c59742c0-4323-4b9d-8a02-723c251deaa0",
        "ef9d12bd-bcee-4ba0-a40e-918400f43ddf",
        "9ec204e4-f0a3-42f8-8458-b772a6797cab",
        "0f84bef9-9790-432e-92b7-eece357603fb",
        "ce88f674-ab7a-43da-9201-468d38539e4a",
        "3b27600c-3668-4abd-8f84-7bcdebbccbdb",
        "a097acff-6266-4291-9fbd-137af7ecd439",
        "bf4e9888-f10f-47af-8dba-76413038b73c",
        "21760ecb-8f62-40d2-8d85-0cee5725cb72"
    ]

    # for example_id in impress_list:
    #     try:
    #         main("libreoffice_impress", example_id)
    #     except Exception as e:
    #         logger.error("An error occurred while running the example: %s", e)
    #         continue

    vs_code_list = [
        "0ed39f63-6049-43d4-ba4d-5fa2fe04a951",
        "53ad5833-3455-407b-bbc6-45b4c79ab8fb",
        "eabc805a-bfcf-4460-b250-ac92135819f6",
        "982d12a5-beab-424f-8d38-d2a48429e511",
        "4e60007a-f5be-4bfc-9723-c39affa0a6d3",
        "e2b5e914-ffe1-44d2-8e92-58f8c5d92bb2",
        "9439a27b-18ae-42d8-9778-5f68f891805e",
        "ea98c5d7-3cf9-4f9b-8ad3-366b58e0fcae",
        "930fdb3b-11a8-46fe-9bac-577332e2640e",
        "276cc624-87ea-4f08-ab93-f770e3790175",
        "9d425400-e9b2-4424-9a4b-d4c7abac4140"
    ]

    # for example_id in vs_code_list:
    #     try:
    #         main("vs_code", example_id)
    #     except Exception as e:
    #         logger.error("An error occurred while running the example: %s", e)
    #         continue

    multiple_list = [
        "f8cfa149-d1c1-4215-8dac-4a0932bad3c2",
        "897e3b53-5d4d-444b-85cb-2cdc8a97d903",
        "4e9f0faf-2ecc-4ae8-a804-28c9a75d1ddc",
        "b52b40a5-ad70-4c53-b5b0-5650a8387052",
        "46407397-a7d5-4c6b-92c6-dbe038b1457b",
        "2b9493d7-49b8-493a-a71b-56cd1f4d6908",
        "51f5801c-18b3-4f25-b0c3-02f85507a078",
        "2c9fc0de-3ee7-45e1-a5df-c86206ad78b5",
        "510f64c8-9bcc-4be1-8d30-638705850618",
        "937087b6-f668-4ba6-9110-60682ee33441",
        "ee9a3c83-f437-4879-8918-be5efbb9fac7",
        "3680a5ee-6870-426a-a997-eba929a0d25c",
        "e135df7c-7687-4ac0-a5f0-76b74438b53e",
        "58565672-7bfe-48ab-b828-db349231de6b",
        "2fe4b718-3bd7-46ec-bdce-b184f5653624"
    ]

    # for example_id in multiple_list:
    #     try:
    #         main("multi_apps", example_id)
    #     except Exception as e:
    #         logger.error("An error occurred while running the example: %s", e)
    #         continue

    chrome_list = [
        # "bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
        "7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3",
        "06fe7178-4491-4589-810f-2e2bc9502122",
        "e1e75309-3ddb-4d09-92ec-de869c928143",
        "35253b65-1c19-4304-8aa4-6884b8218fc0",
        "2ad9387a-65d8-4e33-ad5b-7580065a27ca",
        "7a5a7856-f1b6-42a4-ade9-1ca81ca0f263",
        "44ee5668-ecd5-4366-a6ce-c1c9b8d4e938",
        "2ae9ba84-3a0d-4d4c-8338-3a1478dc5fe3",
        "480bcfea-d68f-4aaa-a0a9-2589ef319381",
        "af630914-714e-4a24-a7bb-f9af687d3b91"
    ]

    # for example_id in chrome_list:
    #     try:
    #         main("chrome", example_id)
    #     except Exception as e:
    #         logger.error("An error occurred while running the example: %s", e)
    #         continue


    writer_list = [
        "6ada715d-3aae-4a32-a6a7-429b2e43fb93",
        "ecc2413d-8a48-416e-a3a2-d30106ca36cb",
        "0e47de2a-32e0-456c-a366-8c607ef7a9d2",
        "4bcb1253-a636-4df4-8cb0-a35c04dfef31",
        "0810415c-bde4-4443-9047-d5f70165a697",
        "e528b65e-1107-4b8c-8988-490e4fece599",
        "66399b0d-8fda-4618-95c4-bfc6191617e9",
        "936321ce-5236-426a-9a20-e0e3c5dc536f",
        "3ef2b351-8a84-4ff2-8724-d86eae9b842e",
        "0b17a146-2934-46c7-8727-73ff6b6483e8",
        "0e763496-b6bb-4508-a427-fad0b6c3e195",
        "f178a4a9-d090-4b56-bc4c-4b72a61a035d",
        "adf5e2c3-64c7-4644-b7b6-d2f0167927e7",
        "0a0faba3-5580-44df-965d-f562a99b291c",
        "e246f6d8-78d7-44ac-b668-fcf47946cb50",
        "8472fece-c7dd-4241-8d65-9b3cd1a0b568",
        "88fe4b2d-3040-4c70-9a70-546a47764b48",
        "d53ff5ee-3b1a-431e-b2be-30ed2673079b",
        "72b810ef-4156-4d09-8f08-a0cf57e7cefe",
        "6f81754e-285d-4ce0-b59e-af7edb02d108",
        "b21acd93-60fd-4127-8a43-2f5178f4a830"
    ]

    for example_id in writer_list:
        try:
            main("libreoffice_writer", example_id)
        except Exception as e:
            logger.error("An error occurred while running the example: %s", e)
            continue

@@ -1,155 +0,0 @@
import ctypes
import datetime
import json
import logging
import os
import sys
import func_timeout

from desktop_env.envs.desktop_env import DesktopEnv
from mm_agents.gpt_4v_agent import GPT4v_Agent

# Logger Configs {{{ #
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")

file_handler = logging.FileHandler(os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8")
debug_handler = logging.FileHandler(os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8")
stdout_handler = logging.StreamHandler(sys.stdout)
sdebug_handler = logging.FileHandler(os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8")

file_handler.setLevel(logging.INFO)
debug_handler.setLevel(logging.DEBUG)
stdout_handler.setLevel(logging.INFO)
sdebug_handler.setLevel(logging.DEBUG)

formatter = logging.Formatter(
    fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s")
file_handler.setFormatter(formatter)
debug_handler.setFormatter(formatter)
stdout_handler.setFormatter(formatter)
sdebug_handler.setFormatter(formatter)

stdout_handler.addFilter(logging.Filter("desktopenv"))
sdebug_handler.addFilter(logging.Filter("desktopenv"))

logger.addHandler(file_handler)
logger.addHandler(debug_handler)
logger.addHandler(stdout_handler)
logger.addHandler(sdebug_handler)
# }}} Logger Configs #

logger = logging.getLogger("desktopenv.experiment")

PATH_TO_VM = r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu\Ubuntu.vmx"

def run_one_example(example, agent, max_steps=10, example_trajectory_dir="exp_trajectory", recording=True):
|
||||
trajectory_recording_path = os.path.join(example_trajectory_dir, "trajectory.json")
|
||||
env = DesktopEnv(
|
||||
path_to_vm=PATH_TO_VM,
|
||||
action_space=agent.action_space,
|
||||
task_config=example
|
||||
)
|
||||
# reset the environment to certain snapshot
|
||||
observation = env.reset()
|
||||
done = False
|
||||
step_num = 0
|
||||
|
||||
if recording:
|
||||
# send a request to the server to start recording
|
||||
env.controller.start_recording()
|
||||
|
||||
while not done and step_num < max_steps:
|
||||
actions = agent.predict(observation)
|
||||
step_num += 1
|
||||
for action in actions:
|
||||
# Capture the timestamp before executing the action
|
||||
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
||||
logger.info("Step %d: %s", step_num, action)
|
||||
|
||||
observation, reward, done, info = env.step(action)
|
||||
|
||||
logger.info("Reward: %.2f", reward)
|
||||
logger.info("Done: %s", done)
|
||||
logger.info("Info: %s", info)
|
||||
|
||||
# Save screenshot and trajectory information
|
||||
with open(os.path.join(example_trajectory_dir, f"step_{step_num}_{action_timestamp}.png"), "wb") as _f:
|
||||
with open(observation['screenshot'], "rb") as __f:
|
||||
screenshot = __f.read()
|
||||
_f.write(screenshot)
|
||||
|
||||
with open(trajectory_recording_path, "a") as f:
|
||||
f.write(json.dumps({
|
||||
"step_num": step_num,
|
||||
"action_timestamp": action_timestamp,
|
||||
"action": action,
|
||||
"reward": reward,
|
||||
"done": done,
|
||||
"info": info,
|
||||
"screenshot_file": f"step_{step_num}_{action_timestamp}.png"
|
||||
}))
|
||||
f.write("\n")
|
||||
|
||||
if done:
|
||||
logger.info("The episode is done.")
|
||||
break
|
||||
|
||||
def stop_recording():
|
||||
try:
|
||||
env.controller.end_recording(os.path.join(example_trajectory_dir, "recording.mp4"))
|
||||
except Exception as e:
|
||||
print(f"An error occurred while stopping the recording: {e}")
|
||||
|
||||
try:
|
||||
func_timeout.func_timeout(30, stop_recording)
|
||||
except func_timeout.exceptions.FunctionTimedOut:
|
||||
logger.info("Recording timed out.")
|
||||
|
||||
result = env.evaluate()
|
||||
logger.info("Result: %.2f", result)
|
||||
|
||||
with open(trajectory_recording_path, "a") as f:
|
||||
f.write(json.dumps({
|
||||
"result": result
|
||||
}))
|
||||
f.write("\n")
|
||||
|
||||
# env.close()
|
||||
logger.info("Environment closed.")
|
||||
|
||||
|
||||
def main(example_class, example_id):
|
||||
action_space = "pyautogui"
|
||||
gpt4_model = "gpt-4-vision-preview"
|
||||
gemini_model = "gemini-pro-vision"
|
||||
|
||||
with open(f"evaluation_examples/examples/{example_class}/{example_id}.json", "r", encoding="utf-8") as f:
|
||||
example = json.load(f)
|
||||
example["snapshot"] = "exp_v5"
|
||||
|
||||
api_key = os.environ.get("OPENAI_API_KEY")
|
||||
agent = GPT4v_Agent(api_key=api_key, model=gpt4_model, instruction=example['instruction'],
|
||||
action_space=action_space, exp="seeact")
|
||||
|
||||
# api_key = os.environ.get("GENAI_API_KEY")
|
||||
# agent = GeminiPro_Agent(api_key=api_key, model=gemini_model, instruction=example['instruction'], action_space=action_space)
|
||||
|
||||
root_trajectory_dir = "exp_trajectory"
|
||||
|
||||
example_trajectory_dir = os.path.join(root_trajectory_dir, "seeact", example_class, gpt4_model, example_id)
|
||||
# example_trajectory_dir = os.path.join(root_trajectory_dir, "seeact", example_class, gemini_model, example_id)
|
||||
|
||||
os.makedirs(example_trajectory_dir, exist_ok=True)
|
||||
|
||||
run_one_example(example, agent, 15, example_trajectory_dir)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
xx_list = [
|
||||
]
|
||||
for example_id in xx_list:
|
||||
main("xx", example_id)
|
||||
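run_one_example above wraps stop_recording in func_timeout so a hung recording call cannot stall the run for more than 30 seconds. A self-contained sketch of that pattern; slow_call is a stand-in for any call that may hang:

import time
import func_timeout

def slow_call():
    time.sleep(60)  # stand-in for a call that may never return, e.g. stopping a VM recording

try:
    # raises FunctionTimedOut if slow_call does not finish within 30 seconds
    func_timeout.func_timeout(30, slow_call)
except func_timeout.exceptions.FunctionTimedOut:
    print("call timed out; continuing with the rest of the run")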
@@ -1,261 +0,0 @@
# import ctypes
import datetime
import json
import logging
import os
import sys
import func_timeout

from desktop_env.envs.desktop_env import DesktopEnv
from mm_agents.gpt_4v_agent import GPT4v_Agent

# Logger Configs {{{ #
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")

file_handler = logging.FileHandler(os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8")
debug_handler = logging.FileHandler(os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8")
stdout_handler = logging.StreamHandler(sys.stdout)
sdebug_handler = logging.FileHandler(os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8")

file_handler.setLevel(logging.INFO)
debug_handler.setLevel(logging.DEBUG)
stdout_handler.setLevel(logging.INFO)
sdebug_handler.setLevel(logging.DEBUG)

formatter = logging.Formatter(
    fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s")
file_handler.setFormatter(formatter)
debug_handler.setFormatter(formatter)
stdout_handler.setFormatter(formatter)
sdebug_handler.setFormatter(formatter)

stdout_handler.addFilter(logging.Filter("desktopenv"))
sdebug_handler.addFilter(logging.Filter("desktopenv"))

logger.addHandler(file_handler)
logger.addHandler(debug_handler)
logger.addHandler(stdout_handler)
logger.addHandler(sdebug_handler)
# }}} Logger Configs #

logger = logging.getLogger("desktopenv.experiment")

PATH_TO_VM = r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu\Ubuntu.vmx"


def run_one_example(example, agent, max_steps=10, example_trajectory_dir="exp_trajectory", recording=True):
    trajectory_recording_path = os.path.join(example_trajectory_dir, "trajectory.json")
    env = DesktopEnv(
        path_to_vm=PATH_TO_VM,
        action_space=agent.action_space,
        task_config=example
    )
    # reset the environment to a certain snapshot
    observation = env.reset()
    done = False
    step_num = 0

    if recording:
        # send a request to the server to start recording
        env.controller.start_recording()

    while not done and step_num < max_steps:
        actions = agent.predict(observation)
        step_num += 1
        for action in actions:
            # Capture the timestamp before executing the action
            action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
            logger.info("Step %d: %s", step_num, action)

            observation, reward, done, info = env.step(action)

            logger.info("Reward: %.2f", reward)
            logger.info("Done: %s", done)
            logger.info("Info: %s", info)

            # Save screenshot and trajectory information
            with open(os.path.join(example_trajectory_dir, f"step_{step_num}_{action_timestamp}.png"), "wb") as _f:
                with open(observation['screenshot'], "rb") as __f:
                    screenshot = __f.read()
                _f.write(screenshot)

            with open(trajectory_recording_path, "a") as f:
                f.write(json.dumps({
                    "step_num": step_num,
                    "action_timestamp": action_timestamp,
                    "action": action,
                    "reward": reward,
                    "done": done,
                    "info": info,
                    "screenshot_file": f"step_{step_num}_{action_timestamp}.png"
                }))
                f.write("\n")

            if done:
                logger.info("The episode is done.")
                break

    def stop_recording():
        try:
            env.controller.end_recording(os.path.join(example_trajectory_dir, "recording.mp4"))
        except Exception as e:
            print(f"An error occurred while stopping the recording: {e}")

    try:
        func_timeout.func_timeout(30, stop_recording)
    except func_timeout.exceptions.FunctionTimedOut:
        logger.info("Recording timed out.")

    result = env.evaluate()
    logger.info("Result: %.2f", result)

    with open(trajectory_recording_path, "a") as f:
        f.write(json.dumps({
            "result": result
        }))
        f.write("\n")

    # env.close()
    logger.info("Environment closed.")


def main(example_class, example_id):
    action_space = "pyautogui"
    gpt4_model = "gpt-4-vision-preview"
    gemini_model = "gemini-pro-vision"

    with open(f"evaluation_examples/examples/{example_class}/{example_id}.json", "r", encoding="utf-8") as f:
        example = json.load(f)
    example["snapshot"] = "exp_v5"

    logger.info("TASK: %s/%s", example_class, example_id)

    api_key = os.environ.get("OPENAI_API_KEY")
    agent = GPT4v_Agent(api_key=api_key, model=gpt4_model, max_tokens=1000, instruction=example['instruction'],
                        action_space=action_space, exp="som")

    # api_key = os.environ.get("GENAI_API_KEY")
    # agent = GeminiPro_Agent(api_key=api_key, model=gemini_model, instruction=example['instruction'], action_space=action_space)

    root_trajectory_dir = "exp_trajectory"

    example_trajectory_dir = os.path.join(root_trajectory_dir, "som", example_class, gpt4_model, example_id)
    # example_trajectory_dir = os.path.join(root_trajectory_dir, "som", example_class, gemini_model, example_id)

    os.makedirs(example_trajectory_dir, exist_ok=True)

    run_one_example(example, agent, 15, example_trajectory_dir)


if __name__ == '__main__':
    from tqdm import tqdm
    # impress_list = [
    #     # "5d901039-a89c-4bfb-967b-bf66f4df075e",
    #     "550ce7e7-747b-495f-b122-acdc4d0b8e54",
    #     "455d3c66-7dc6-4537-a39a-36d3e9119df7",
    #     "af23762e-2bfd-4a1d-aada-20fa8de9ce07",
    #     "c59742c0-4323-4b9d-8a02-723c251deaa0",
    #     "ef9d12bd-bcee-4ba0-a40e-918400f43ddf",
    #     "9ec204e4-f0a3-42f8-8458-b772a6797cab",
    #     "0f84bef9-9790-432e-92b7-eece357603fb",
    #     "ce88f674-ab7a-43da-9201-468d38539e4a",
    #     "3b27600c-3668-4abd-8f84-7bcdebbccbdb",
    #     "a097acff-6266-4291-9fbd-137af7ecd439",
    #     "bf4e9888-f10f-47af-8dba-76413038b73c",
    #     "21760ecb-8f62-40d2-8d85-0cee5725cb72"
    # ]
    # for example_id in impress_list:
    #     main("libreoffice_impress", example_id)

    vlc_list = [
        "8ba5ae7a-5ae5-4eab-9fcc-5dd4fe3abf89",
        "8ba5ae7a-5ae5-4eab-9fcc-5dd4fe3abf89",
        "8f080098-ddb1-424c-b438-4e96e5e4786e",
        "bba3381f-b5eb-4439-bd9e-80c22218d5a7",
        "fba2c100-79e8-42df-ae74-b592418d54f4",
        "efcf0d81-0835-4880-b2fd-d866e8bc2294",
        "8d9fd4e2-6fdb-46b0-b9b9-02f06495c62f",
        "aa4b5023-aef6-4ed9-bdc9-705f59ab9ad6",
        "386dbd0e-0241-4a0a-b6a2-6704fba26b1c",
        "9195653c-f4aa-453d-aa95-787f6ccfaae9",
        "d06f0d4d-2cd5-4ede-8de9-598629438c6e",
        "a5bbbcd5-b398-4c91-83d4-55e1e31bbb81",
        "f3977615-2b45-4ac5-8bba-80c17dbe2a37",
        "215dfd39-f493-4bc3-a027-8a97d72c61bf"
    ]

    # for example_id in tqdm(vlc_list):
    #     try:
    #         main("vlc", example_id)
    #     except Exception as e:
    #         print(f"An error occurred while running the example: {e}")
    #         continue

    chrome_list = [
        "bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
        "7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3",
        "06fe7178-4491-4589-810f-2e2bc9502122",
        "e1e75309-3ddb-4d09-92ec-de869c928143",
        "35253b65-1c19-4304-8aa4-6884b8218fc0",
        "2ad9387a-65d8-4e33-ad5b-7580065a27ca",
        "7a5a7856-f1b6-42a4-ade9-1ca81ca0f263",
        "44ee5668-ecd5-4366-a6ce-c1c9b8d4e938",
        "2ae9ba84-3a0d-4d4c-8338-3a1478dc5fe3",
        "480bcfea-d68f-4aaa-a0a9-2589ef319381",
        "af630914-714e-4a24-a7bb-f9af687d3b91"
    ]
    for example_id in tqdm(chrome_list):
        try:
            main("chrome", example_id)
        except Exception as e:
            print(f"An error occurred while running the example: {e}")
            continue

    vs_code_list = [
        "0ed39f63-6049-43d4-ba4d-5fa2fe04a951",
        "53ad5833-3455-407b-bbc6-45b4c79ab8fb",
        "eabc805a-bfcf-4460-b250-ac92135819f6",
        "982d12a5-beab-424f-8d38-d2a48429e511",
        "4e60007a-f5be-4bfc-9723-c39affa0a6d3",
        "e2b5e914-ffe1-44d2-8e92-58f8c5d92bb2",
        "9439a27b-18ae-42d8-9778-5f68f891805e",
        "ea98c5d7-3cf9-4f9b-8ad3-366b58e0fcae",
        "930fdb3b-11a8-46fe-9bac-577332e2640e",
        "276cc624-87ea-4f08-ab93-f770e3790175",
        "9d425400-e9b2-4424-9a4b-d4c7abac4140"
    ]

    for example_id in tqdm(vs_code_list):
        try:
            main("vs_code", example_id)
        except Exception as e:
            print(f"An error occurred while running the example: {e}")
            continue

    thunderbird_list = [
        "bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
        "7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3",
        "12086550-11c0-466b-b367-1d9e75b3910e",
        "06fe7178-4491-4589-810f-2e2bc9502122",
        "6766f2b8-8a72-417f-a9e5-56fcaa735837",
        "e1e75309-3ddb-4d09-92ec-de869c928143",
        "3d1682a7-0fb0-49ae-a4dc-a73afd2d06d5",
        "35253b65-1c19-4304-8aa4-6884b8218fc0",
        "d088f539-cab4-4f9a-ac92-9999fc3a656e",
        "2ad9387a-65d8-4e33-ad5b-7580065a27ca",
        "480bcfea-d68f-4aaa-a0a9-2589ef319381",
        "030eeff7-b492-4218-b312-701ec99ee0cc",
        "94760984-3ff5-41ee-8347-cf1af709fea0",
        "99146c54-4f37-4ab8-9327-5f3291665e1e",
        "c9e7eaf2-b1a1-4efc-a982-721972fa9f02"
    ]

    for example_id in tqdm(thunderbird_list):
        try:
            main("thunderbird", example_id)
        except Exception as e:
            print(f"An error occurred while running the example: {e}")
            continue
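Both deleted runners load a task JSON, then override its snapshot before handing it to DesktopEnv as task_config. A minimal sketch of only the fields these scripts actually touch; real task files also carry setup and evaluator config, omitted here:

task_config = {
    "id": "6ada715d-3aae-4a32-a6a7-429b2e43fb93",   # matches the JSON filename stem
    "instruction": "task text shown to the agent",  # placeholder value
    "snapshot": "exp_v5",                           # overridden by the runner before env construction
}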
@@ -5,21 +5,20 @@ import os
import re
import time
import uuid
import xml.etree.ElementTree as ET
from http import HTTPStatus
from io import BytesIO
from typing import Dict, List
import xml.etree.ElementTree as ET

import backoff
import dashscope
import google.generativeai as genai
import openai
import requests
from PIL import Image
from openai import (
    APIConnectionError,
    APIError,
    RateLimitError
)
from vertexai.preview.generative_models import (
    HarmBlockThreshold,
    HarmCategory,
    Image,
)

from mm_agents.accessibility_tree_wrap.heuristic_retrieve import find_leaf_nodes, filter_nodes, draw_bounding_boxes
@@ -29,7 +28,6 @@ from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_S
    SYS_PROMPT_IN_SOM_A11Y_OUT_TAG, \
    SYS_PROMPT_SEEACT, ACTION_DESCRIPTION_PROMPT_SEEACT, ACTION_GROUNDING_PROMPT_SEEACT

import logging
# todo: cross-check with visualwebarena

logger = logging.getLogger("desktopenv.agent")
@@ -42,7 +40,7 @@ def encode_image(image_path):


def linearize_accessibility_tree(accessibility_tree):
    #leaf_nodes = find_leaf_nodes(accessibility_tree)
    # leaf_nodes = find_leaf_nodes(accessibility_tree)
    filtered_nodes = filter_nodes(ET.fromstring(accessibility_tree))

    linearized_accessibility_tree = "tag\tname\ttext\tposition\tsize\n"
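linearize_accessibility_tree starts from the tab-separated header above, so each filtered node presumably contributes one row in the same five columns. A hedged illustration of that row shape; the values here are invented:

# header: tag\tname\ttext\tposition\tsize
example_row = "push-button\tSave\t''\t(52, 11)\t(24, 24)"
tag, name, text, position, size = example_row.split("\t")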
@@ -172,60 +170,56 @@ def parse_code_from_som_string(input_string, masks):
class PromptAgent:
    def __init__(
            self,
            api_key,
            instruction,
            model="gpt-4-vision-preview",
            max_tokens=500,
            max_tokens=1500,
            top_p=0.9,
            temperature=0.5,
            action_space="computer_13",
            exp="screenshot_a11y_tree"
            # exp can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som", "seeact"]
            observation_type="screenshot_a11y_tree",
            # observation_type can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som", "seeact"]
            max_trajectory_length=3
    ):

        self.instruction = instruction
        self.model = model
        self.max_tokens = max_tokens
        self.top_p = top_p
        self.temperature = temperature
        self.action_space = action_space
        self.exp = exp
        self.max_trajectory_length = 3

        self.headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}"
        }
        self.observation_type = observation_type
        self.max_trajectory_length = max_trajectory_length

        self.thoughts = []
        self.actions = []
        self.observations = []

        if exp == "screenshot":
        if observation_type == "screenshot":
            if action_space == "computer_13":
                self.system_message = SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION
            elif action_space == "pyautogui":
                self.system_message = SYS_PROMPT_IN_SCREENSHOT_OUT_CODE
            else:
                raise ValueError("Invalid action space: " + action_space)
        elif exp == "a11y_tree":
        elif observation_type == "a11y_tree":
            if action_space == "computer_13":
                self.system_message = SYS_PROMPT_IN_A11Y_OUT_ACTION
            elif action_space == "pyautogui":
                self.system_message = SYS_PROMPT_IN_A11Y_OUT_CODE
            else:
                raise ValueError("Invalid action space: " + action_space)
        elif exp == "both":
        elif observation_type == "both":
            if action_space == "computer_13":
                self.system_message = SYS_PROMPT_IN_BOTH_OUT_ACTION
            elif action_space == "pyautogui":
                self.system_message = SYS_PROMPT_IN_BOTH_OUT_CODE
            else:
                raise ValueError("Invalid action space: " + action_space)
        elif exp == "som":
        elif observation_type == "som":
            if action_space == "computer_13":
                raise ValueError("Invalid action space: " + action_space)
            elif action_space == "pyautogui":
                self.system_message = SYS_PROMPT_IN_SOM_A11Y_OUT_TAG
            else:
                raise ValueError("Invalid action space: " + action_space)
        elif exp == "seeact":
        elif observation_type == "seeact":
            if action_space == "computer_13":
                raise ValueError("Invalid action space: " + action_space)
            elif action_space == "pyautogui":
@@ -233,15 +227,14 @@ class PromptAgent:
            else:
                raise ValueError("Invalid action space: " + action_space)
        else:
            raise ValueError("Invalid experiment type: " + exp)
            raise ValueError("Invalid experiment type: " + observation_type)

        self.system_message = self.system_message + "\nYou are asked to complete the following task: {}".format(
            self.instruction)
    def predict(self, obs: Dict) -> List:
    def predict(self, instruction: str, obs: Dict) -> List:
        """
        Predict the next action(s) based on the current observation.
        """
        self.system_message = self.system_message + "\nYou are asked to complete the following task: {}".format(
            instruction)

        # Prepare the payload for the API call
        messages = []
@@ -273,7 +266,7 @@ class PromptAgent:
        for previous_obs, previous_action, previous_thought in zip(_observations, _actions, _thoughts):

            # {{{1
            if self.exp == "both":
            if self.observation_type == "both":
                _screenshot = previous_obs["screenshot"]
                _linearized_accessibility_tree = previous_obs["accessibility_tree"]
                logger.debug("LINEAR AT: %s", _linearized_accessibility_tree)
@@ -295,7 +288,7 @@ class PromptAgent:
                        }
                    ]
                })
            elif self.exp in ["som", "seeact"]:
            elif self.observation_type in ["som", "seeact"]:
                _screenshot = previous_obs["screenshot"]
                _linearized_accessibility_tree = previous_obs["accessibility_tree"]
                logger.debug("LINEAR AT: %s", _linearized_accessibility_tree)
@@ -317,7 +310,7 @@ class PromptAgent:
                        }
                    ]
                })
            elif self.exp == "screenshot":
            elif self.observation_type == "screenshot":
                _screenshot = previous_obs["screenshot"]

                messages.append({
@@ -336,7 +329,7 @@ class PromptAgent:
                        }
                    ]
                })
            elif self.exp == "a11y_tree":
            elif self.observation_type == "a11y_tree":
                _linearized_accessibility_tree = previous_obs["accessibility_tree"]

                messages.append({
@@ -350,7 +343,7 @@ class PromptAgent:
                    ]
                })
            else:
                raise ValueError("Invalid experiment type: " + self.exp)  # 1}}}
                raise ValueError("Invalid observation type: " + self.observation_type)  # 1}}}

            messages.append({
                "role": "assistant",
@@ -363,11 +356,11 @@ class PromptAgent:
            })

        # {{{1
        if self.exp in ["screenshot", "both"]:
        if self.observation_type in ["screenshot", "both"]:
            base64_image = encode_image(obs["screenshot"])
            linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])

            if self.exp == "both":
            if self.observation_type == "both":
                self.observations.append({
                    "screenshot": base64_image,
                    "accessibility_tree": linearized_accessibility_tree
@@ -384,7 +377,7 @@ class PromptAgent:
                    {
                        "type": "text",
                        "text": "Given the screenshot as below. What's the next step that you will do to help with the task?"
                        if self.exp == "screenshot"
                        if self.observation_type == "screenshot"
                        else "Given the screenshot and info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format(
                            linearized_accessibility_tree)
                    },
@@ -397,7 +390,7 @@ class PromptAgent:
                    }
                ]
            })
        elif self.exp == "a11y_tree":
        elif self.observation_type == "a11y_tree":
            linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])

            self.observations.append({
@@ -415,7 +408,7 @@ class PromptAgent:
                    }
                ]
            })
        elif self.exp == "som":
        elif self.observation_type == "som":
            # Add som to the screenshot
            masks, drew_nodes, tagged_screenshot = tag_screenshot(obs["screenshot"], obs["accessibility_tree"])
            base64_image = encode_image(tagged_screenshot)
@@ -443,7 +436,7 @@ class PromptAgent:
                    }
                ]
            })
        elif self.exp == "seeact":
        elif self.observation_type == "seeact":
            # Add som to the screenshot
            masks, drew_nodes, tagged_screenshot = tag_screenshot(obs["screenshot"], obs["accessibility_tree"])
            base64_image = encode_image(tagged_screenshot)
@@ -471,21 +464,21 @@ class PromptAgent:
                ]
            })
        else:
            raise ValueError("Invalid experiment type: " + self.exp)  # 1}}}
            raise ValueError("Invalid observation type: " + self.observation_type)  # 1}}}

        with open("messages.json", "w") as f:
            f.write(json.dumps(messages, indent=4))

        # with open("messages.json", "w") as f:
        #     f.write(json.dumps(messages, indent=4))

        logger.info("Generating content with GPT model: %s", self.model)
        response = self.call_llm({
            "model": self.model,
            "messages": messages,
            "max_tokens": self.max_tokens
        })

        logger.debug("RESPONSE: %s", response)
        logger.info("RESPONSE: %s", response)

        if self.exp == "seeact":
        if self.observation_type == "seeact":
            messages.append({
                "role": "assistant",
                "content": [
@@ -507,12 +500,15 @@ class PromptAgent:
                ]
            })

            logger.info("Generating content with GPT model: %s", self.model)
            response = self.call_llm({
                "model": self.model,
                "messages": messages,
                "max_tokens": self.max_tokens
                "max_tokens": self.max_tokens,
                "top_p": self.top_p,
                "temperature": self.temperature
            })
            print(response)
            logger.info("RESPONSE: %s", response)

        try:
            actions = self.parse_actions(response, masks)
@@ -527,85 +523,90 @@ class PromptAgent:
    @backoff.on_exception(
        backoff.expo,
        (Exception),
        max_tries=10
        max_tries=5
    )
    def call_llm(self, payload):

        if self.model.startswith("gpt"):
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}"
            }
            logger.info("Generating content with GPT model: %s", self.model)
            response = requests.post(
                "https://api.openai.com/v1/chat/completions",
                headers=self.headers,
                headers=headers,
                json=payload
            )

            if response.status_code != 200:
                if response.json()['error']['code'] == "context_length_exceeded":
                    print("Context length exceeded. Retrying with a smaller context.")
                    logger.error("Context length exceeded. Retrying with a smaller context.")
                    payload["messages"] = payload["messages"][-1:]
                    retry_response = requests.post(
                        "https://api.openai.com/v1/chat/completions",
                        headers=self.headers,
                        headers=headers,
                        json=payload
                    )
                    if retry_response.status_code != 200:
                        print("Failed to call LLM: " + retry_response.text)
                        logger.error("Failed to call LLM: " + retry_response.text)
                        return ""

                print("Failed to call LLM: " + response.text)
                logger.error("Failed to call LLM: " + response.text)
                time.sleep(5)
                return ""
            else:
                return response.json()['choices'][0]['message']['content']

        elif self.model.startswith("mistral"):
            print("call mistral")
            messages = payload["messages"]
            max_tokens = payload["max_tokens"]

            mistral_messages = []

            for i, message in enumerate(messages):
                mistral_message = {
                    "role": message["role"],
                    "content": []
                }

                for part in message["content"]:
                    mistral_message['content'] = part['text'] if part['type'] == "text" else None

                mistral_messages.append(mistral_message)

            # our Mistral endpoint does not support system messages, so the system message is prepended to the first user message
            if mistral_messages[0]['role'] == "system":
                mistral_messages[1]['content'] = mistral_messages[0]['content'] + "\n" + mistral_messages[1]['content']
                mistral_messages.pop(0)

            # openai.api_base = "http://localhost:8000/v1"
            # openai.api_key = "test"
            # response = openai.ChatCompletion.create(
            #     messages=mistral_messages,
            #     model="Mixtral-8x7B-Instruct-v0.1"
            # )

            from openai import OpenAI
            TOGETHER_API_KEY = "d011650e7537797148fb6170ec1e0be7ae75160375686fae02277136078e90d2"

            client = OpenAI(api_key=TOGETHER_API_KEY,
                            base_url='https://api.together.xyz',
                            )
            logger.info("Generating content with Mistral model: %s", self.model)
            response = client.chat.completions.create(
                messages=mistral_messages,
                model="mistralai/Mixtral-8x7B-Instruct-v0.1",
                max_tokens=1024
            )

            try:
                # return response['choices'][0]['message']['content']
                return response.choices[0].message.content
            except Exception as e:
                print("Failed to call LLM: " + str(e))
                return ""
        # elif self.model.startswith("mistral"):
        #     print("Call mistral")
        #     messages = payload["messages"]
        #     max_tokens = payload["max_tokens"]
        #
        #     mistral_messages = []
        #
        #     for i, message in enumerate(messages):
        #         mistral_message = {
        #             "role": message["role"],
        #             "content": []
        #         }
        #
        #         for part in message["content"]:
        #             mistral_message['content'] = part['text'] if part['type'] == "text" else None
        #
        #         mistral_messages.append(mistral_message)
        #
        #         # our Mistral endpoint does not support system messages, so the system message is prepended to the first user message
        #     if mistral_messages[0]['role'] == "system":
        #         mistral_messages[1]['content'] = mistral_messages[0]['content'] + "\n" + mistral_messages[1]['content']
        #         mistral_messages.pop(0)
        #
        #     # openai.api_base = "http://localhost:8000/v1"
        #     # openai.api_key = "test"
        #     # response = openai.ChatCompletion.create(
        #     #     messages=mistral_messages,
        #     #     model="Mixtral-8x7B-Instruct-v0.1"
        #     # )
        #
        #     from openai import OpenAI
        #     TOGETHER_API_KEY = "d011650e7537797148fb6170ec1e0be7ae75160375686fae02277136078e90d2"
        #
        #     client = OpenAI(api_key=TOGETHER_API_KEY,
        #                     base_url='https://api.together.xyz',
        #                     )
        #     logger.info("Generating content with Mistral model: %s", self.model)
        #     response = client.chat.completions.create(
        #         messages=mistral_messages,
        #         model="mistralai/Mixtral-8x7B-Instruct-v0.1",
        #         max_tokens=1024
        #     )
        #
        #     try:
        #         # return response['choices'][0]['message']['content']
        #         return response.choices[0].message.content
        #     except Exception as e:
        #         print("Failed to call LLM: " + str(e))
        #         return ""

        elif self.model.startswith("gemini"):
            def encoded_img_to_pil_img(data_str):
@@ -617,6 +618,8 @@ class PromptAgent:

            messages = payload["messages"]
            max_tokens = payload["max_tokens"]
            top_p = payload["top_p"]
            temperature = payload["temperature"]

            gemini_messages = []
            for i, message in enumerate(messages):
@@ -662,7 +665,17 @@ class PromptAgent:
            response = genai.GenerativeModel(self.model).generate_content(
                gemini_messages,
                generation_config={
                    "max_output_tokens": max_tokens
                    "candidate_count": 1,
                    "max_output_tokens": max_tokens,
                    "top_p": top_p,
                    "temperature": temperature
                },
                safety_settings={
                    HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE,
                    HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
                    HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
                    HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
                    HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
                }
            )

@@ -673,6 +686,8 @@ class PromptAgent:
        elif self.model.startswith("qwen"):
            messages = payload["messages"]
            max_tokens = payload["max_tokens"]
            top_p = payload["top_p"]
            temperature = payload["temperature"]

            qwen_messages = []

@@ -683,13 +698,16 @@ class PromptAgent:
                }
                assert len(message["content"]) in [1, 2], "One text, or one text with one image"
                for part in message["content"]:
                    qwen_message['content'].append({"image": part['image_url']['url']}) if part['type'] == "image_url" else None
                    qwen_message['content'].append({"image": part['image_url']['url']}) if part[
                        'type'] == "image_url" else None
                    qwen_message['content'].append({"text": part['text']}) if part['type'] == "text" else None

                qwen_messages.append(qwen_message)

            response = dashscope.MultiModalConversation.call(model='qwen-vl-plus',
                                                             messages=messages)
            response = dashscope.MultiModalConversation.call(
                model='qwen-vl-plus',
                messages=messages,  # todo: add the hyperparameters
            )
            # A status code of HTTPStatus.OK indicates success; otherwise the request
            # failed, and the error code and message are available on the response's
            # `code` and `message` fields.
@@ -708,7 +726,7 @@ class PromptAgent:

    def parse_actions(self, response: str, masks=None):

        if self.exp in ["screenshot", "a11y_tree", "both"]:
        if self.observation_type in ["screenshot", "a11y_tree", "both"]:
            # parse from the response
            if self.action_space == "computer_13":
                actions = parse_actions_from_string(response)
@@ -720,7 +738,7 @@ class PromptAgent:
            self.actions.append(actions)

            return actions
        elif self.exp in ["som", "seeact"]:
        elif self.observation_type in ["som", "seeact"]:
            # parse from the response
            if self.action_space == "computer_13":
                raise ValueError("Invalid action space: " + self.action_space)
@@ -732,3 +750,8 @@ class PromptAgent:
            self.actions.append(actions)

            return actions

    def reset(self):
        self.thoughts = []
        self.actions = []
        self.observations = []
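Taken together, the renames above replace `exp` with `observation_type` throughout the agent. A minimal construction sketch; it follows how run.py below instantiates the class (the hunk above still shows api_key and instruction in the signature, while run.py omits them, so the keyword set here mirrors run.py and the values are illustrative):

agent = PromptAgent(
    model="gpt-4-vision-preview",
    max_tokens=1500,
    action_space="pyautogui",
    observation_type="a11y_tree",  # was `exp` before this refactor
    max_trajectory_length=3,
)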
218
run.py
Normal file
@@ -0,0 +1,218 @@
"""Script to run end-to-end evaluation on the benchmark.
Utils and basic architecture credit to https://github.com/web-arena-x/webarena/blob/main/run.py.
"""
import argparse
import datetime
import json
import logging
import os
import sys

from desktop_env.envs.desktop_env import DesktopEnv
from mm_agents.agent import PromptAgent

# Logger Configs {{{ #
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")

file_handler = logging.FileHandler(os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8")
debug_handler = logging.FileHandler(os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8")
stdout_handler = logging.StreamHandler(sys.stdout)
sdebug_handler = logging.FileHandler(os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8")

file_handler.setLevel(logging.INFO)
debug_handler.setLevel(logging.DEBUG)
stdout_handler.setLevel(logging.INFO)
sdebug_handler.setLevel(logging.DEBUG)

formatter = logging.Formatter(
    fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s")
file_handler.setFormatter(formatter)
debug_handler.setFormatter(formatter)
stdout_handler.setFormatter(formatter)
sdebug_handler.setFormatter(formatter)

stdout_handler.addFilter(logging.Filter("desktopenv"))
sdebug_handler.addFilter(logging.Filter("desktopenv"))

logger.addHandler(file_handler)
logger.addHandler(debug_handler)
logger.addHandler(stdout_handler)
logger.addHandler(sdebug_handler)
# }}} Logger Configs #

logger = logging.getLogger("desktopenv.experiment")


def config() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Run end-to-end evaluation on the benchmark"
    )

    # environment config
    parser.add_argument("--path_to_vm", type=str,
                        default=r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu\Ubuntu.vmx")
    parser.add_argument(
        "--headless", action="store_true", help="Run in headless machine"
    )
    parser.add_argument("--action_space", type=str, default="pyautogui", help="Action type")
    parser.add_argument(
        "--observation_type",
        choices=[
            "screenshot",
            "a11y_tree",
            "screenshot_a11y_tree",
            "som"
        ],
        default="a11y_tree",
        help="Observation type",
    )
    parser.add_argument("--screen_width", type=int, default=1920)
    parser.add_argument("--screen_height", type=int, default=1080)
    parser.add_argument("--sleep_after_execution", type=float, default=0.0)
    parser.add_argument("--max_steps", type=int, default=15)

    # agent config
    parser.add_argument("--max_trajectory_length", type=int, default=3)
    parser.add_argument("--test_config_base_dir", type=str, default="evaluation_examples")

    # lm config
    parser.add_argument("--model", type=str, default="gpt-4-vision-preview")
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--max_tokens", type=int, default=1500)
    parser.add_argument("--stop_token", type=str, default=None)

    # logging related
    parser.add_argument("--result_dir", type=str, default="./results")
    args = parser.parse_args()

    return args
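config() above defines the script's full CLI. A sketch of a typical invocation; the flag values are illustrative, not prescribed by the commit:

# python run.py --observation_type screenshot --model gpt-4-vision-preview \
#     --action_space pyautogui --max_steps 15 --result_dir ./results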
def test(
        args: argparse.Namespace,
        test_all_meta: dict
) -> None:
    scores = []
    max_steps = args.max_steps

    # log args
    logger.info("Args: %s", args)

    agent = PromptAgent(
        model=args.model,
        max_tokens=args.max_tokens,
        action_space=args.action_space,
        observation_type=args.observation_type,
        max_trajectory_length=args.max_trajectory_length,
    )

    env = DesktopEnv(
        path_to_vm=args.path_to_vm,
        action_space=agent.action_space,
        screen_size=(args.screen_width, args.screen_height),
        headless=args.headless,
    )

    for domain in test_all_meta:
        for example_id in test_all_meta[domain]:
            config_file = os.path.join(args.test_config_base_dir, f"examples/{domain}/{example_id}.json")
            with open(config_file, "r", encoding="utf-8") as f:
                example = json.load(f)

            logger.info(f"[Domain]: {domain}")
            logger.info(f"[Example ID]: {example_id}")

            instruction = example["instruction"]

            logger.info(f"[Instruction]: {instruction}")

            example_result_dir = os.path.join(
                args.result_dir,
                args.action_space,
                args.observation_type,
                args.model,
                domain,
                example_id
            )
            os.makedirs(example_result_dir, exist_ok=True)

            agent.reset()
            obs = env.reset(task_config=example)
            done = False
            step_idx = 0
            env.controller.start_recording()

            while not done and step_idx < max_steps:
                actions = agent.predict(
                    instruction,
                    obs
                )

                for action in actions:
                    step_idx += 1
                    # Capture the timestamp before executing the action
                    action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
                    logger.info("Step %d: %s", step_idx + 1, action)

                    observation, reward, done, info = env.step(action, args.sleep_after_execution)

                    logger.info("Reward: %.2f", reward)
                    logger.info("Done: %s", done)
                    logger.info("Info: %s", info)

                    # Save screenshot and trajectory information
                    with open(os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"),
                              "wb") as _f:
                        with open(observation['screenshot'], "rb") as __f:
                            screenshot = __f.read()
                        _f.write(screenshot)

                    with open(os.path.join(example_result_dir, "traj.json"), "a") as f:
                        f.write(json.dumps({
                            "step_num": step_idx + 1,
                            "action_timestamp": action_timestamp,
                            "action": action,
                            "reward": reward,
                            "done": done,
                            "info": info,
                            "screenshot_file": f"step_{step_idx + 1}_{action_timestamp}.png"
                        }))
                        f.write("\n")

                    if done:
                        logger.info("The episode is done.")
                        break

            result = env.evaluate()
            logger.info("Result: %.2f", result)
            scores.append(result)
            env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))

    env.close()
    logger.info(f"Average score: {sum(scores) / len(scores)}")


def get_unfinished(test_file_list, result_dir):
    finished = []
    for domain in os.listdir(result_dir):
        for example_id in os.listdir(os.path.join(result_dir, domain)):
            finished.append(f"{domain}/{example_id}")
    return [x for x in test_file_list if x not in finished]
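get_unfinished supports resuming an interrupted run: it compares "domain/example_id" strings against what already exists under result_dir. It expects a flat list, while test_all.json is a dict of domain -> example ids, and result_dir must point at the model-level directory that directly contains the domain folders. A sketch of flattening the dict before the call; the wiring is ours, matching the commented-out lines below:

test_file_list = [
    f"{domain}/{example_id}"
    for domain, example_ids in test_all_meta.items()
    for example_id in example_ids
]
remaining = get_unfinished(test_file_list, args.result_dir)
logger.info("Total %d tasks left", len(remaining))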
if __name__ == '__main__':
    ####### The complete version of the list of examples #######
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    args = config()

    # test_file_list = get_unfinished(args.test, args.result_dir)
    # logger.info(f"Total {len(test_file_list)} tasks left")

    with open("evaluation_examples/test_all.json", "r", encoding="utf-8") as f:
        test_all_meta = json.load(f)

    test(args, test_all_meta)