Refactor experiments and agent implementation
This commit is contained in:
@@ -53,8 +53,8 @@ class DesktopEnv(gym.Env):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
path_to_vm: str,
|
path_to_vm: str,
|
||||||
|
snapshot_name: str ="init_state",
|
||||||
action_space: str = "computer_13",
|
action_space: str = "computer_13",
|
||||||
task_config: Dict[str, Any] = None,
|
|
||||||
tmp_dir: str = "tmp",
|
tmp_dir: str = "tmp",
|
||||||
cache_dir: str = "cache",
|
cache_dir: str = "cache",
|
||||||
screen_size: Tuple[int] = (1920, 1080),
|
screen_size: Tuple[int] = (1920, 1080),
|
||||||
@@ -64,15 +64,6 @@ class DesktopEnv(gym.Env):
|
|||||||
Args:
|
Args:
|
||||||
path_to_vm (str): path to .vmx file
|
path_to_vm (str): path to .vmx file
|
||||||
action_space (str): "computer_13" | "pyautogui"
|
action_space (str): "computer_13" | "pyautogui"
|
||||||
|
|
||||||
task_config (Dict[str, Any]): manages task configs integratedly,
|
|
||||||
including
|
|
||||||
* base snapshot
|
|
||||||
* task id (uuid)
|
|
||||||
* instruction
|
|
||||||
* setup config
|
|
||||||
* evaluator config
|
|
||||||
|
|
||||||
tmp_dir (str): temporary directory to store trajectory stuffs like
|
tmp_dir (str): temporary directory to store trajectory stuffs like
|
||||||
the extracted screenshots
|
the extracted screenshots
|
||||||
cache_dir (str): cache directory to cache task-related stuffs like
|
cache_dir (str): cache directory to cache task-related stuffs like
|
||||||
@@ -81,6 +72,7 @@ class DesktopEnv(gym.Env):
|
|||||||
|
|
||||||
# Initialize environment variables
|
# Initialize environment variables
|
||||||
self.path_to_vm = os.path.abspath(os.path.expandvars(os.path.expanduser(path_to_vm)))
|
self.path_to_vm = os.path.abspath(os.path.expandvars(os.path.expanduser(path_to_vm)))
|
||||||
|
self.snapshot_name = snapshot_name
|
||||||
self.tmp_dir_base: str = tmp_dir
|
self.tmp_dir_base: str = tmp_dir
|
||||||
self.cache_dir_base: str = cache_dir
|
self.cache_dir_base: str = cache_dir
|
||||||
self.vm_screen_size = screen_size
|
self.vm_screen_size = screen_size
|
||||||
@@ -88,16 +80,12 @@ class DesktopEnv(gym.Env):
|
|||||||
|
|
||||||
os.makedirs(self.tmp_dir_base, exist_ok=True)
|
os.makedirs(self.tmp_dir_base, exist_ok=True)
|
||||||
|
|
||||||
# task-aware stuffs
|
|
||||||
# todo: handling the logic of snapshot directory
|
|
||||||
self._set_task_info(task_config)
|
|
||||||
|
|
||||||
# Initialize emulator and controller
|
# Initialize emulator and controller
|
||||||
logger.info("Initializing...")
|
logger.info("Initializing...")
|
||||||
self._start_emulator()
|
self._start_emulator()
|
||||||
self.vm_ip = self._get_vm_ip()
|
self.vm_ip = self._get_vm_ip()
|
||||||
self.controller = PythonController(vm_ip=self.vm_ip)
|
self.controller = PythonController(vm_ip=self.vm_ip)
|
||||||
self.setup_controller = SetupController(vm_ip=self.vm_ip, cache_dir=self.cache_dir)
|
self.setup_controller = SetupController(vm_ip=self.vm_ip, cache_dir=self.cache_dir_base)
|
||||||
|
|
||||||
# Meta info of the VM, move to the reset() function
|
# Meta info of the VM, move to the reset() function
|
||||||
self.vm_platform: str = "" # self.controller.get_vm_platform()
|
self.vm_platform: str = "" # self.controller.get_vm_platform()
|
||||||
@@ -147,7 +135,7 @@ class DesktopEnv(gym.Env):
|
|||||||
raise Exception("Failed to get VM IP address!")
|
raise Exception("Failed to get VM IP address!")
|
||||||
|
|
||||||
def _save_state(self):
|
def _save_state(self):
|
||||||
_execute_command(["vmrun", "-T", "ws" "snapshot", self.path_to_vm, self.snapshot_path])
|
_execute_command(["vmrun", "-T", "ws" "snapshot", self.path_to_vm, self.snapshot_name])
|
||||||
|
|
||||||
def _get_screenshot(self):
|
def _get_screenshot(self):
|
||||||
# random_uuid = str(uuid.uuid4())
|
# random_uuid = str(uuid.uuid4())
|
||||||
@@ -167,7 +155,6 @@ class DesktopEnv(gym.Env):
|
|||||||
return screenshot_image_path
|
return screenshot_image_path
|
||||||
|
|
||||||
def _set_task_info(self, task_config: Dict[str, Any]):
|
def _set_task_info(self, task_config: Dict[str, Any]):
|
||||||
self.snapshot_path = task_config["snapshot"] # todo: save the snapshot when first start the environment, and then revert to it when reset
|
|
||||||
self.task_id: str = task_config["id"]
|
self.task_id: str = task_config["id"]
|
||||||
self.cache_dir: str = os.path.join(self.cache_dir_base, self.task_id)
|
self.cache_dir: str = os.path.join(self.cache_dir_base, self.task_id)
|
||||||
os.makedirs(self.cache_dir, exist_ok=True)
|
os.makedirs(self.cache_dir, exist_ok=True)
|
||||||
@@ -239,8 +226,8 @@ class DesktopEnv(gym.Env):
|
|||||||
)
|
)
|
||||||
os.makedirs(os.path.join(self.tmp_dir, "screenshots"))
|
os.makedirs(os.path.join(self.tmp_dir, "screenshots"))
|
||||||
|
|
||||||
logger.info("Reverting to snapshot to {}...".format(self.snapshot_path))
|
logger.info("Reverting to snapshot to {}...".format(self.snapshot_name))
|
||||||
_execute_command(["vmrun", "-T", "ws", "revertToSnapshot", self.path_to_vm, self.snapshot_path])
|
_execute_command(["vmrun", "-T", "ws", "revertToSnapshot", self.path_to_vm, self.snapshot_name])
|
||||||
time.sleep(5)
|
time.sleep(5)
|
||||||
|
|
||||||
print(self.vm_screen_size)
|
print(self.vm_screen_size)
|
||||||
|
|||||||
@@ -1,432 +0,0 @@
|
|||||||
import datetime
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
|
|
||||||
import func_timeout
|
|
||||||
|
|
||||||
from desktop_env.envs.desktop_env import DesktopEnv
|
|
||||||
from mm_agents.gpt_4v_agent import GPT4v_Agent
|
|
||||||
|
|
||||||
# Logger Configs {{{ #
|
|
||||||
logger = logging.getLogger()
|
|
||||||
logger.setLevel(logging.DEBUG)
|
|
||||||
|
|
||||||
datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
|
||||||
|
|
||||||
file_handler = logging.FileHandler(os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8")
|
|
||||||
debug_handler = logging.FileHandler(os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8")
|
|
||||||
stdout_handler = logging.StreamHandler(sys.stdout)
|
|
||||||
sdebug_handler = logging.FileHandler(os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8")
|
|
||||||
|
|
||||||
file_handler.setLevel(logging.INFO)
|
|
||||||
debug_handler.setLevel(logging.DEBUG)
|
|
||||||
stdout_handler.setLevel(logging.INFO)
|
|
||||||
sdebug_handler.setLevel(logging.DEBUG)
|
|
||||||
|
|
||||||
formatter = logging.Formatter(
|
|
||||||
fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s")
|
|
||||||
file_handler.setFormatter(formatter)
|
|
||||||
debug_handler.setFormatter(formatter)
|
|
||||||
stdout_handler.setFormatter(formatter)
|
|
||||||
sdebug_handler.setFormatter(formatter)
|
|
||||||
|
|
||||||
stdout_handler.addFilter(logging.Filter("desktopenv"))
|
|
||||||
sdebug_handler.addFilter(logging.Filter("desktopenv"))
|
|
||||||
|
|
||||||
logger.addHandler(file_handler)
|
|
||||||
logger.addHandler(debug_handler)
|
|
||||||
logger.addHandler(stdout_handler)
|
|
||||||
logger.addHandler(sdebug_handler)
|
|
||||||
# }}} Logger Configs #
|
|
||||||
|
|
||||||
logger = logging.getLogger("desktopenv.experiment")
|
|
||||||
|
|
||||||
PATH_TO_VM = r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu\Ubuntu.vmx"
|
|
||||||
|
|
||||||
|
|
||||||
def run_one_example(example, agent, max_steps=10, example_trajectory_dir="exp_trajectory", recording=True):
|
|
||||||
trajectory_recording_path = os.path.join(example_trajectory_dir, "trajectory.json")
|
|
||||||
env = DesktopEnv(
|
|
||||||
path_to_vm=PATH_TO_VM,
|
|
||||||
action_space=agent.action_space,
|
|
||||||
task_config=example
|
|
||||||
)
|
|
||||||
# reset the environment to certain snapshot
|
|
||||||
observation = env.reset()
|
|
||||||
done = False
|
|
||||||
step_num = 0
|
|
||||||
|
|
||||||
if recording:
|
|
||||||
# send a request to the server to start recording
|
|
||||||
env.controller.start_recording()
|
|
||||||
|
|
||||||
while not done and step_num < max_steps:
|
|
||||||
actions = agent.predict(observation)
|
|
||||||
step_num += 1
|
|
||||||
for action in actions:
|
|
||||||
# Capture the timestamp before executing the action
|
|
||||||
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
|
||||||
logger.info("Step %d: %s", step_num, action)
|
|
||||||
|
|
||||||
observation, reward, done, info = env.step(action)
|
|
||||||
|
|
||||||
logger.info("Reward: %.2f", reward)
|
|
||||||
logger.info("Done: %s", done)
|
|
||||||
logger.info("Info: %s", info)
|
|
||||||
|
|
||||||
# Save screenshot and trajectory information
|
|
||||||
with open(os.path.join(example_trajectory_dir, f"step_{step_num}_{action_timestamp}.png"), "wb") as _f:
|
|
||||||
with open(observation['screenshot'], "rb") as __f:
|
|
||||||
screenshot = __f.read()
|
|
||||||
_f.write(screenshot)
|
|
||||||
|
|
||||||
with open(trajectory_recording_path, "a") as f:
|
|
||||||
f.write(json.dumps({
|
|
||||||
"step_num": step_num,
|
|
||||||
"action_timestamp": action_timestamp,
|
|
||||||
"action": action,
|
|
||||||
"reward": reward,
|
|
||||||
"done": done,
|
|
||||||
"info": info,
|
|
||||||
"screenshot_file": f"step_{step_num}_{action_timestamp}.png"
|
|
||||||
}))
|
|
||||||
f.write("\n")
|
|
||||||
|
|
||||||
if done:
|
|
||||||
logger.info("The episode is done.")
|
|
||||||
break
|
|
||||||
|
|
||||||
def stop_recording():
|
|
||||||
try:
|
|
||||||
env.controller.end_recording(os.path.join(example_trajectory_dir, "recording.mp4"))
|
|
||||||
except Exception as e:
|
|
||||||
print(f"An error occurred while stopping the recording: {e}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
func_timeout.func_timeout(30, stop_recording)
|
|
||||||
except func_timeout.exceptions.FunctionTimedOut:
|
|
||||||
logger.info("Recording timed out.")
|
|
||||||
|
|
||||||
result = env.evaluate()
|
|
||||||
logger.info("Result: %.2f", result)
|
|
||||||
|
|
||||||
with open(trajectory_recording_path, "a") as f:
|
|
||||||
f.write(json.dumps({
|
|
||||||
"result": result
|
|
||||||
}))
|
|
||||||
f.write("\n")
|
|
||||||
|
|
||||||
# env.close()
|
|
||||||
logger.info("Environment closed.")
|
|
||||||
|
|
||||||
|
|
||||||
def main(example_class, example_id, gpt4_model="gpt-4-0125-preview"):
|
|
||||||
action_space = "pyautogui"
|
|
||||||
gemini_model = "gemini-pro-vision"
|
|
||||||
|
|
||||||
logger.info("Running example %s/%s", example_class, example_id)
|
|
||||||
logger.info("Using model %s", gpt4_model)
|
|
||||||
# logger.info("Using model %s", gemini_model)
|
|
||||||
|
|
||||||
with open(f"evaluation_examples/examples/{example_class}/{example_id}.json", "r", encoding="utf-8") as f:
|
|
||||||
example = json.load(f)
|
|
||||||
example["snapshot"] = "exp_v5"
|
|
||||||
|
|
||||||
api_key = os.environ.get("OPENAI_API_KEY")
|
|
||||||
agent = GPT4v_Agent(api_key=api_key, model=gpt4_model, instruction=example['instruction'], max_tokens=1000,
|
|
||||||
action_space=action_space, exp="a11y_tree")
|
|
||||||
|
|
||||||
# api_key = os.environ.get("GENAI_API_KEY")
|
|
||||||
# agent = GeminiPro_Agent(api_key=api_key, model=gemini_model, instruction=example['instruction'], action_space=action_space, exp="a11y_tree")
|
|
||||||
|
|
||||||
root_trajectory_dir = "exp_trajectory"
|
|
||||||
|
|
||||||
example_trajectory_dir = os.path.join(root_trajectory_dir, "a11y_tree", example_class, gpt4_model, example_id)
|
|
||||||
# example_trajectory_dir = os.path.join(root_trajectory_dir, "a11y_tree", example_class, gemini_model, example_id)
|
|
||||||
|
|
||||||
os.makedirs(example_trajectory_dir, exist_ok=True)
|
|
||||||
|
|
||||||
run_one_example(example, agent, 15, example_trajectory_dir)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
os_list = [
|
|
||||||
"94d95f96-9699-4208-98ba-3c3119edf9c2",
|
|
||||||
"bedcedc4-4d72-425e-ad62-21960b11fe0d",
|
|
||||||
"43c2d64c-bab5-4dcb-a30c-b888321c319a",
|
|
||||||
"7688b85f-87a4-4e4a-b2f8-f3d6c3f29b82",
|
|
||||||
"ec4e3f68-9ea4-4c18-a5c9-69f89d1178b3",
|
|
||||||
"f9be0997-4b7c-45c5-b05c-4612b44a6118",
|
|
||||||
"28cc3b7e-b194-4bc9-8353-d04c0f4d56d2",
|
|
||||||
"5ea617a3-0e86-4ba6-aab2-dac9aa2e8d57",
|
|
||||||
"e0df059f-28a6-4169-924f-b9623e7184cc",
|
|
||||||
"ddc75b62-7311-4af8-bfb3-859558542b36",
|
|
||||||
"b6781586-6346-41cd-935a-a6b1487918fc",
|
|
||||||
"3ce045a0-877b-42aa-8d2c-b4a863336ab8",
|
|
||||||
"a4d98375-215b-4a4d-aee9-3d4370fccc41",
|
|
||||||
"13584542-872b-42d8-b299-866967b5c3ef",
|
|
||||||
"23393935-50c7-4a86-aeea-2b78fd089c5c"
|
|
||||||
]
|
|
||||||
|
|
||||||
# for example_id in os_list:
|
|
||||||
# try:
|
|
||||||
# main("os", example_id, gpt4_model="gpt-3.5-turbo-16k")
|
|
||||||
# except Exception as e:
|
|
||||||
# logger.error("An error occurred while running the example: %s", e)
|
|
||||||
# continue
|
|
||||||
|
|
||||||
vlc_list = [
|
|
||||||
"8ba5ae7a-5ae5-4eab-9fcc-5dd4fe3abf89",
|
|
||||||
"8ba5ae7a-5ae5-4eab-9fcc-5dd4fe3abf89",
|
|
||||||
"8f080098-ddb1-424c-b438-4e96e5e4786e",
|
|
||||||
"bba3381f-b5eb-4439-bd9e-80c22218d5a7",
|
|
||||||
"fba2c100-79e8-42df-ae74-b592418d54f4",
|
|
||||||
"efcf0d81-0835-4880-b2fd-d866e8bc2294",
|
|
||||||
"8d9fd4e2-6fdb-46b0-b9b9-02f06495c62f",
|
|
||||||
"aa4b5023-aef6-4ed9-bdc9-705f59ab9ad6",
|
|
||||||
"386dbd0e-0241-4a0a-b6a2-6704fba26b1c",
|
|
||||||
"9195653c-f4aa-453d-aa95-787f6ccfaae9",
|
|
||||||
"d06f0d4d-2cd5-4ede-8de9-598629438c6e",
|
|
||||||
"a5bbbcd5-b398-4c91-83d4-55e1e31bbb81",
|
|
||||||
"f3977615-2b45-4ac5-8bba-80c17dbe2a37",
|
|
||||||
"215dfd39-f493-4bc3-a027-8a97d72c61bf"
|
|
||||||
]
|
|
||||||
|
|
||||||
chrome_list = [
|
|
||||||
"bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
|
|
||||||
"7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3",
|
|
||||||
"06fe7178-4491-4589-810f-2e2bc9502122",
|
|
||||||
"e1e75309-3ddb-4d09-92ec-de869c928143",
|
|
||||||
"35253b65-1c19-4304-8aa4-6884b8218fc0",
|
|
||||||
"2ad9387a-65d8-4e33-ad5b-7580065a27ca",
|
|
||||||
"7a5a7856-f1b6-42a4-ade9-1ca81ca0f263",
|
|
||||||
"44ee5668-ecd5-4366-a6ce-c1c9b8d4e938",
|
|
||||||
"2ae9ba84-3a0d-4d4c-8338-3a1478dc5fe3",
|
|
||||||
"480bcfea-d68f-4aaa-a0a9-2589ef319381",
|
|
||||||
"af630914-714e-4a24-a7bb-f9af687d3b91"
|
|
||||||
]
|
|
||||||
|
|
||||||
calc_list = [
|
|
||||||
"eb03d19a-b88d-4de4-8a64-ca0ac66f426b",
|
|
||||||
"0bf05a7d-b28b-44d2-955a-50b41e24012a",
|
|
||||||
"7a4e4bc8-922c-4c84-865c-25ba34136be1",
|
|
||||||
"2bd59342-0664-4ccb-ba87-79379096cc08",
|
|
||||||
"ecb0df7a-4e8d-4a03-b162-053391d3afaf",
|
|
||||||
"7efeb4b1-3d19-4762-b163-63328d66303b",
|
|
||||||
"4e6fcf72-daf3-439f-a232-c434ce416af6",
|
|
||||||
"6054afcb-5bab-4702-90a0-b259b5d3217c",
|
|
||||||
"abed40dc-063f-4598-8ba5-9fe749c0615d",
|
|
||||||
"01b269ae-2111-4a07-81fd-3fcd711993b0",
|
|
||||||
"8b1ce5f2-59d2-4dcc-b0b0-666a714b9a14",
|
|
||||||
"0cecd4f3-74de-457b-ba94-29ad6b5dafb6",
|
|
||||||
"4188d3a4-077d-46b7-9c86-23e1a036f6c1",
|
|
||||||
"51b11269-2ca8-4b2a-9163-f21758420e78",
|
|
||||||
"7e429b8d-a3f0-4ed0-9b58-08957d00b127",
|
|
||||||
"347ef137-7eeb-4c80-a3bb-0951f26a8aff",
|
|
||||||
"6e99a1ad-07d2-4b66-a1ce-ece6d99c20a5",
|
|
||||||
"3aaa4e37-dc91-482e-99af-132a612d40f3",
|
|
||||||
"37608790-6147-45d0-9f20-1137bb35703d",
|
|
||||||
"f9584479-3d0d-4c79-affa-9ad7afdd8850",
|
|
||||||
"d681960f-7bc3-4286-9913-a8812ba3261a",
|
|
||||||
"21df9241-f8d7-4509-b7f1-37e501a823f7",
|
|
||||||
"1334ca3e-f9e3-4db8-9ca7-b4c653be7d17",
|
|
||||||
"357ef137-7eeb-4c80-a3bb-0951f26a8aff",
|
|
||||||
"aa3a8974-2e85-438b-b29e-a64df44deb4b",
|
|
||||||
"a01fbce3-2793-461f-ab86-43680ccbae25",
|
|
||||||
"4f07fbe9-70de-4927-a4d5-bb28bc12c52c",
|
|
||||||
]
|
|
||||||
|
|
||||||
# for example_id in calc_list:
|
|
||||||
# main("libreoffice_calc", example_id)
|
|
||||||
|
|
||||||
impress_list = [
|
|
||||||
"5d901039-a89c-4bfb-967b-bf66f4df075e",
|
|
||||||
"550ce7e7-747b-495f-b122-acdc4d0b8e54",
|
|
||||||
"455d3c66-7dc6-4537-a39a-36d3e9119df7",
|
|
||||||
"af23762e-2bfd-4a1d-aada-20fa8de9ce07",
|
|
||||||
"c59742c0-4323-4b9d-8a02-723c251deaa0",
|
|
||||||
"ef9d12bd-bcee-4ba0-a40e-918400f43ddf",
|
|
||||||
"9ec204e4-f0a3-42f8-8458-b772a6797cab",
|
|
||||||
"0f84bef9-9790-432e-92b7-eece357603fb",
|
|
||||||
"ce88f674-ab7a-43da-9201-468d38539e4a",
|
|
||||||
"3b27600c-3668-4abd-8f84-7bcdebbccbdb",
|
|
||||||
"a097acff-6266-4291-9fbd-137af7ecd439",
|
|
||||||
"bf4e9888-f10f-47af-8dba-76413038b73c",
|
|
||||||
"21760ecb-8f62-40d2-8d85-0cee5725cb72"
|
|
||||||
]
|
|
||||||
# for example_id in impress_list:
|
|
||||||
# main("libreoffice_impress", example_id)
|
|
||||||
|
|
||||||
thunderbird_list = [
|
|
||||||
# "bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
|
|
||||||
# "7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3",
|
|
||||||
"12086550-11c0-466b-b367-1d9e75b3910e",
|
|
||||||
"06fe7178-4491-4589-810f-2e2bc9502122",
|
|
||||||
"6766f2b8-8a72-417f-a9e5-56fcaa735837",
|
|
||||||
"e1e75309-3ddb-4d09-92ec-de869c928143",
|
|
||||||
"3d1682a7-0fb0-49ae-a4dc-a73afd2d06d5",
|
|
||||||
"35253b65-1c19-4304-8aa4-6884b8218fc0",
|
|
||||||
"d088f539-cab4-4f9a-ac92-9999fc3a656e",
|
|
||||||
"2ad9387a-65d8-4e33-ad5b-7580065a27ca",
|
|
||||||
"480bcfea-d68f-4aaa-a0a9-2589ef319381",
|
|
||||||
"030eeff7-b492-4218-b312-701ec99ee0cc",
|
|
||||||
"94760984-3ff5-41ee-8347-cf1af709fea0",
|
|
||||||
"99146c54-4f37-4ab8-9327-5f3291665e1e",
|
|
||||||
"c9e7eaf2-b1a1-4efc-a982-721972fa9f02"
|
|
||||||
]
|
|
||||||
# for example_id in thunderbird_list:
|
|
||||||
# main("thunderbird", example_id)
|
|
||||||
|
|
||||||
gimp_list = [
|
|
||||||
"7a4deb26-d57d-4ea9-9a73-630f66a7b568",
|
|
||||||
"554785e9-4523-4e7a-b8e1-8016f565f56a",
|
|
||||||
"77b8ab4d-994f-43ac-8930-8ca087d7c4b4",
|
|
||||||
"f4aec372-4fb0-4df5-a52b-79e0e2a5d6ce",
|
|
||||||
"d52d6308-ec58-42b7-a2c9-de80e4837b2b",
|
|
||||||
"2a729ded-3296-423d-aec4-7dd55ed5fbb3",
|
|
||||||
"b148e375-fe0b-4bec-90e7-38632b0d73c2",
|
|
||||||
"a746add2-cab0-4740-ac36-c3769d9bfb46",
|
|
||||||
"7b7617bd-57cc-468e-9c91-40c4ec2bcb3d",
|
|
||||||
"d16c99dc-2a1e-46f2-b350-d97c86c85c15",
|
|
||||||
"06ca5602-62ca-47f6-ad4f-da151cde54cc",
|
|
||||||
"e2dd0213-26db-4349-abe5-d5667bfd725c",
|
|
||||||
"f723c744-e62c-4ae6-98d1-750d3cd7d79d",
|
|
||||||
"72f83cdc-bf76-4531-9a1b-eb893a13f8aa",
|
|
||||||
"7767eef2-56a3-4cea-8c9f-48c070c7d65b",
|
|
||||||
"734d6579-c07d-47a8-9ae2-13339795476b"
|
|
||||||
]
|
|
||||||
|
|
||||||
# for example_id in gimp_list:
|
|
||||||
# try:
|
|
||||||
# main("gimp", example_id)
|
|
||||||
# except Exception as e:
|
|
||||||
# logger.error("An error occurred while running the example: %s", e)
|
|
||||||
# continue
|
|
||||||
|
|
||||||
vs_code_list = [
|
|
||||||
"0ed39f63-6049-43d4-ba4d-5fa2fe04a951",
|
|
||||||
"53ad5833-3455-407b-bbc6-45b4c79ab8fb",
|
|
||||||
"eabc805a-bfcf-4460-b250-ac92135819f6",
|
|
||||||
"982d12a5-beab-424f-8d38-d2a48429e511",
|
|
||||||
"4e60007a-f5be-4bfc-9723-c39affa0a6d3",
|
|
||||||
"e2b5e914-ffe1-44d2-8e92-58f8c5d92bb2",
|
|
||||||
"9439a27b-18ae-42d8-9778-5f68f891805e",
|
|
||||||
"ea98c5d7-3cf9-4f9b-8ad3-366b58e0fcae",
|
|
||||||
"930fdb3b-11a8-46fe-9bac-577332e2640e",
|
|
||||||
"276cc624-87ea-4f08-ab93-f770e3790175",
|
|
||||||
"9d425400-e9b2-4424-9a4b-d4c7abac4140"
|
|
||||||
]
|
|
||||||
|
|
||||||
# for example_id in vs_code_list:
|
|
||||||
# try:
|
|
||||||
# main("vs_code", example_id)
|
|
||||||
# except Exception as e:
|
|
||||||
# logger.error("An error occurred while running the example: %s", e)
|
|
||||||
# continue
|
|
||||||
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
# for example_id in tqdm(vlc_list):
|
|
||||||
# try:
|
|
||||||
# main("vlc", example_id, gpt4_model="gpt-3.5-turbo-16k")
|
|
||||||
# except Exception as e:
|
|
||||||
# print(f"An error occurred while running the example: {e}")
|
|
||||||
# continue
|
|
||||||
|
|
||||||
chrome_list = [
|
|
||||||
"bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
|
|
||||||
"7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3",
|
|
||||||
"06fe7178-4491-4589-810f-2e2bc9502122",
|
|
||||||
"e1e75309-3ddb-4d09-92ec-de869c928143",
|
|
||||||
"35253b65-1c19-4304-8aa4-6884b8218fc0",
|
|
||||||
"2ad9387a-65d8-4e33-ad5b-7580065a27ca",
|
|
||||||
"7a5a7856-f1b6-42a4-ade9-1ca81ca0f263",
|
|
||||||
"44ee5668-ecd5-4366-a6ce-c1c9b8d4e938",
|
|
||||||
"2ae9ba84-3a0d-4d4c-8338-3a1478dc5fe3",
|
|
||||||
"480bcfea-d68f-4aaa-a0a9-2589ef319381",
|
|
||||||
"af630914-714e-4a24-a7bb-f9af687d3b91"
|
|
||||||
]
|
|
||||||
# for example_id in tqdm(chrome_list):
|
|
||||||
# try:
|
|
||||||
# main("chrome", example_id, gpt4_model="gpt-3.5-turbo-16k")
|
|
||||||
# except Exception as e:
|
|
||||||
# print(f"An error occurred while running the example: {e}")
|
|
||||||
# continue
|
|
||||||
|
|
||||||
vs_code_list = [
|
|
||||||
# "0ed39f63-6049-43d4-ba4d-5fa2fe04a951",
|
|
||||||
# "53ad5833-3455-407b-bbc6-45b4c79ab8fb",
|
|
||||||
# "eabc805a-bfcf-4460-b250-ac92135819f6",
|
|
||||||
# "982d12a5-beab-424f-8d38-d2a48429e511",
|
|
||||||
# "4e60007a-f5be-4bfc-9723-c39affa0a6d3",
|
|
||||||
# "e2b5e914-ffe1-44d2-8e92-58f8c5d92bb2",
|
|
||||||
# "9439a27b-18ae-42d8-9778-5f68f891805e",
|
|
||||||
# "ea98c5d7-3cf9-4f9b-8ad3-366b58e0fcae",
|
|
||||||
# "930fdb3b-11a8-46fe-9bac-577332e2640e",
|
|
||||||
# "276cc624-87ea-4f08-ab93-f770e3790175",
|
|
||||||
# "9d425400-e9b2-4424-9a4b-d4c7abac4140"
|
|
||||||
]
|
|
||||||
|
|
||||||
for example_id in tqdm(vs_code_list):
|
|
||||||
try:
|
|
||||||
main("vs_code", example_id, gpt4_model="gpt-3.5-turbo-16k")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"An error occurred while running the example: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
thunderbird_list = [
|
|
||||||
# "bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
|
|
||||||
# "7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3",
|
|
||||||
"12086550-11c0-466b-b367-1d9e75b3910e",
|
|
||||||
"06fe7178-4491-4589-810f-2e2bc9502122",
|
|
||||||
"6766f2b8-8a72-417f-a9e5-56fcaa735837",
|
|
||||||
"e1e75309-3ddb-4d09-92ec-de869c928143",
|
|
||||||
"3d1682a7-0fb0-49ae-a4dc-a73afd2d06d5",
|
|
||||||
"35253b65-1c19-4304-8aa4-6884b8218fc0",
|
|
||||||
"d088f539-cab4-4f9a-ac92-9999fc3a656e",
|
|
||||||
"2ad9387a-65d8-4e33-ad5b-7580065a27ca",
|
|
||||||
"480bcfea-d68f-4aaa-a0a9-2589ef319381",
|
|
||||||
"030eeff7-b492-4218-b312-701ec99ee0cc",
|
|
||||||
"94760984-3ff5-41ee-8347-cf1af709fea0",
|
|
||||||
"99146c54-4f37-4ab8-9327-5f3291665e1e",
|
|
||||||
"c9e7eaf2-b1a1-4efc-a982-721972fa9f02"
|
|
||||||
]
|
|
||||||
|
|
||||||
# for example_id in tqdm(thunderbird_list):
|
|
||||||
# try:
|
|
||||||
# main("thunderbird", example_id, gpt4_model="gpt-3.5-turbo-16k")
|
|
||||||
# except Exception as e:
|
|
||||||
# print(f"An error occurred while running the example: {e}")
|
|
||||||
# continue
|
|
||||||
|
|
||||||
multiple_list = [
|
|
||||||
# "f8cfa149-d1c1-4215-8dac-4a0932bad3c2",
|
|
||||||
# "897e3b53-5d4d-444b-85cb-2cdc8a97d903",
|
|
||||||
"2fe4b718-3bd7-46ec-bdce-b184f5653624",
|
|
||||||
"3680a5ee-6870-426a-a997-eba929a0d25c",
|
|
||||||
# "4e9f0faf-2ecc-4ae8-a804-28c9a75d1ddc",
|
|
||||||
# "b52b40a5-ad70-4c53-b5b0-5650a8387052",
|
|
||||||
# "46407397-a7d5-4c6b-92c6-dbe038b1457b",
|
|
||||||
# "2b9493d7-49b8-493a-a71b-56cd1f4d6908",
|
|
||||||
# "51f5801c-18b3-4f25-b0c3-02f85507a078",
|
|
||||||
"58565672-7bfe-48ab-b828-db349231de6b",
|
|
||||||
# "2c9fc0de-3ee7-45e1-a5df-c86206ad78b5",
|
|
||||||
# "510f64c8-9bcc-4be1-8d30-638705850618",
|
|
||||||
# "937087b6-f668-4ba6-9110-60682ee33441",
|
|
||||||
# "ee9a3c83-f437-4879-8918-be5efbb9fac7",
|
|
||||||
# "3680a5ee-6870-426a-a997-eba929a0d25c",
|
|
||||||
# "e135df7c-7687-4ac0-a5f0-76b74438b53e",
|
|
||||||
"ee9a3c83-f437-4879-8918-be5efbb9fac7",
|
|
||||||
# "58565672-7bfe-48ab-b828-db349231de6b",
|
|
||||||
# "2fe4b718-3bd7-46ec-bdce-b184f5653624"
|
|
||||||
]
|
|
||||||
|
|
||||||
for example_id in multiple_list:
|
|
||||||
try:
|
|
||||||
main("multi_apps", example_id, gpt4_model="gpt-3.5-turbo-16k")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("An error occurred while running the example: %s", e)
|
|
||||||
continue
|
|
||||||
|
|
||||||
@@ -1,604 +0,0 @@
|
|||||||
"""Script to run end-to-end evaluation on the benchmark.
|
|
||||||
Utils and basic architecture credit to https://github.com/web-arena-x/webarena/blob/main/run.py.
|
|
||||||
"""
|
|
||||||
import datetime
|
|
||||||
import sys
|
|
||||||
|
|
||||||
import func_timeout
|
|
||||||
import argparse
|
|
||||||
import glob
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import time
|
|
||||||
from pathlib import Path
|
|
||||||
import openai
|
|
||||||
import requests
|
|
||||||
import torch
|
|
||||||
from beartype import beartype
|
|
||||||
|
|
||||||
from desktop_env.envs.desktop_env import DesktopEnv
|
|
||||||
from mm_agents.agent import PromptAgent # todo: change the name into PromptAgent
|
|
||||||
|
|
||||||
# Logger Configs {{{ #
|
|
||||||
logger = logging.getLogger()
|
|
||||||
logger.setLevel(logging.DEBUG)
|
|
||||||
|
|
||||||
datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
|
||||||
|
|
||||||
file_handler = logging.FileHandler(os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8")
|
|
||||||
debug_handler = logging.FileHandler(os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8")
|
|
||||||
stdout_handler = logging.StreamHandler(sys.stdout)
|
|
||||||
sdebug_handler = logging.FileHandler(os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8")
|
|
||||||
|
|
||||||
file_handler.setLevel(logging.INFO)
|
|
||||||
debug_handler.setLevel(logging.DEBUG)
|
|
||||||
stdout_handler.setLevel(logging.INFO)
|
|
||||||
sdebug_handler.setLevel(logging.DEBUG)
|
|
||||||
|
|
||||||
formatter = logging.Formatter(
|
|
||||||
fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s")
|
|
||||||
file_handler.setFormatter(formatter)
|
|
||||||
debug_handler.setFormatter(formatter)
|
|
||||||
stdout_handler.setFormatter(formatter)
|
|
||||||
sdebug_handler.setFormatter(formatter)
|
|
||||||
|
|
||||||
stdout_handler.addFilter(logging.Filter("desktopenv"))
|
|
||||||
sdebug_handler.addFilter(logging.Filter("desktopenv"))
|
|
||||||
|
|
||||||
logger.addHandler(file_handler)
|
|
||||||
logger.addHandler(debug_handler)
|
|
||||||
logger.addHandler(stdout_handler)
|
|
||||||
logger.addHandler(sdebug_handler)
|
|
||||||
# }}} Logger Configs #
|
|
||||||
|
|
||||||
logger = logging.getLogger("desktopenv.experiment")
|
|
||||||
|
|
||||||
|
|
||||||
def run_one_example(example, agent, max_steps=10, example_trajectory_dir="exp_trajectory", recording=True,
|
|
||||||
max_time=600):
|
|
||||||
trajectory_recording_path = os.path.join(example_trajectory_dir, "trajectory.json")
|
|
||||||
env = DesktopEnv(
|
|
||||||
path_to_vm=PATH_TO_VM,
|
|
||||||
action_space=agent.action_space,
|
|
||||||
task_config=example,
|
|
||||||
headless=True
|
|
||||||
)
|
|
||||||
# reset the environment to certain snapshot
|
|
||||||
observation = env.reset()
|
|
||||||
done = False
|
|
||||||
step_num = 0
|
|
||||||
|
|
||||||
if recording:
|
|
||||||
# send a request to the server to start recording
|
|
||||||
env.controller.start_recording()
|
|
||||||
|
|
||||||
while not done and step_num < max_steps:
|
|
||||||
actions = agent.predict(observation)
|
|
||||||
step_num += 1
|
|
||||||
for action in actions:
|
|
||||||
# Capture the timestamp before executing the action
|
|
||||||
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
|
||||||
logger.info("Step %d: %s", step_num, action)
|
|
||||||
|
|
||||||
observation, reward, done, info = env.step(action)
|
|
||||||
|
|
||||||
logger.info("Reward: %.2f", reward)
|
|
||||||
logger.info("Done: %s", done)
|
|
||||||
logger.info("Info: %s", info)
|
|
||||||
|
|
||||||
# Save screenshot and trajectory information
|
|
||||||
with open(os.path.join(example_trajectory_dir, f"step_{step_num}_{action_timestamp}.png"), "wb") as _f:
|
|
||||||
with open(observation['screenshot'], "rb") as __f:
|
|
||||||
screenshot = __f.read()
|
|
||||||
_f.write(screenshot)
|
|
||||||
|
|
||||||
with open(trajectory_recording_path, "a") as f:
|
|
||||||
f.write(json.dumps({
|
|
||||||
"step_num": step_num,
|
|
||||||
"action_timestamp": action_timestamp,
|
|
||||||
"action": action,
|
|
||||||
"reward": reward,
|
|
||||||
"done": done,
|
|
||||||
"info": info,
|
|
||||||
"screenshot_file": f"step_{step_num}_{action_timestamp}.png"
|
|
||||||
}))
|
|
||||||
f.write("\n")
|
|
||||||
|
|
||||||
if done:
|
|
||||||
logger.info("The episode is done.")
|
|
||||||
break
|
|
||||||
|
|
||||||
def stop_recording():
|
|
||||||
try:
|
|
||||||
env.controller.end_recording(os.path.join(example_trajectory_dir, "recording.mp4"))
|
|
||||||
except Exception as e:
|
|
||||||
print(f"An error occurred while stopping the recording: {e}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
func_timeout.func_timeout(120, stop_recording)
|
|
||||||
# todo: make sure we got the video file, check the bug
|
|
||||||
except func_timeout.exceptions.FunctionTimedOut:
|
|
||||||
logger.info("Recording timed out.")
|
|
||||||
|
|
||||||
result = env.evaluate()
|
|
||||||
logger.info("Result: %.2f", result)
|
|
||||||
|
|
||||||
# fixme: change to write the result into a separate file
|
|
||||||
with open(trajectory_recording_path, "a") as f:
|
|
||||||
f.write(json.dumps({
|
|
||||||
"result": result
|
|
||||||
}))
|
|
||||||
f.write("\n")
|
|
||||||
|
|
||||||
# todo: append the result to the wandb for visualization
|
|
||||||
|
|
||||||
# env.close()
|
|
||||||
logger.info("Environment closed.")
|
|
||||||
|
|
||||||
|
|
||||||
def main(example_class, example_id, gpt4_model="gpt-4-vision-preview"):
|
|
||||||
# todo: merge the main function into the run_one_example function
|
|
||||||
# fixme: change all the settings like action_space, model, etc. to the argparser
|
|
||||||
action_space = "pyautogui"
|
|
||||||
gemini_model = "gemini-pro-vision"
|
|
||||||
|
|
||||||
logger.info("Running example %s/%s", example_class, example_id)
|
|
||||||
logger.info("Using model %s", gpt4_model)
|
|
||||||
# logger.info("Using model %s", gemini_model)
|
|
||||||
|
|
||||||
with open(f"evaluation_examples/examples/{example_class}/{example_id}.json", "r", encoding="utf-8") as f:
|
|
||||||
example = json.load(f)
|
|
||||||
example["snapshot"] = "exp_v5"
|
|
||||||
|
|
||||||
api_key = os.environ.get("OPENAI_API_KEY")
|
|
||||||
agent = PromptAgent(
|
|
||||||
api_key=api_key,
|
|
||||||
model=gpt4_model,
|
|
||||||
instruction=example['instruction'],
|
|
||||||
action_space=action_space,
|
|
||||||
exp="screenshot"
|
|
||||||
)
|
|
||||||
|
|
||||||
root_trajectory_dir = "exp_trajectory"
|
|
||||||
|
|
||||||
example_trajectory_dir = os.path.join(root_trajectory_dir, "screenshot", example_class, gpt4_model, example_id)
|
|
||||||
# example_trajectory_dir = os.path.join(root_trajectory_dir, "screenshot", example_class, gemini_model, example_id)
|
|
||||||
|
|
||||||
os.makedirs(example_trajectory_dir, exist_ok=True)
|
|
||||||
|
|
||||||
if os.path.exists(os.path.join(example_trajectory_dir, "trajectory.json")):
|
|
||||||
with open(os.path.join(example_trajectory_dir, "trajectory.json"), "r") as f:
|
|
||||||
lines = f.readlines()
|
|
||||||
# strip the last line if it is empty
|
|
||||||
lines = [line.strip() for line in lines if line.strip() != ""]
|
|
||||||
if len(lines) > 0:
|
|
||||||
last_line = json.loads(lines[-1])
|
|
||||||
if "result" in last_line:
|
|
||||||
logger.info(
|
|
||||||
f"evaluation_examples/examples/{example_class}/{example_id}.json" + "has been evaluated. Skip.")
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
func_timeout.func_timeout(1200, run_one_example, args=(example, agent, 15, example_trajectory_dir))
|
|
||||||
except Exception as e:
|
|
||||||
print(f"An error occurred: {e}")
|
|
||||||
with open(os.path.join(example_trajectory_dir, "trajectory.json"), "a") as f:
|
|
||||||
f.write(json.dumps({
|
|
||||||
"error": str(e)
|
|
||||||
}))
|
|
||||||
|
|
||||||
|
|
||||||
def config(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Build the CLI argument parser for the benchmark runner and parse it.

    Args:
        argv: optional explicit argument list. When omitted (the default),
            arguments are read from ``sys.argv`` exactly as before, so the
            parameter is backward-compatible; passing a list makes the
            function testable without touching the process arguments.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        description="Run end-to-end evaluation on the benchmark"
    )

    # environment config
    parser.add_argument("--path_to_vm", type=str, default="Ubuntu\\Ubuntu.vmx")
    parser.add_argument(
        "--headless", action="store_true", help="Run in headless machine"
    )
    parser.add_argument("--action_space", type=str, default="pyautogui", help="Action type")
    parser.add_argument(
        "--observation_type",
        choices=[
            "screenshot",
            "a11y_tree",
            "screenshot_a11y_tree",
            "som"
        ],
        # BUG FIX: the previous default "accessibility_tree" was not one of
        # the declared choices, so the default could never be reproduced
        # from the command line; use the valid a11y choice instead.
        default="a11y_tree",
        help="Observation type",
    )
    # parser.add_argument(
    #     "--current_viewport_only",
    #     action="store_true",
    #     help="Only use the current viewport for the observation",
    # )
    parser.add_argument("--screen_width", type=int, default=1920)
    parser.add_argument("--screen_height", type=int, default=1080)
    parser.add_argument("--save_trace_enabled", action="store_true")
    parser.add_argument("--sleep_after_execution", type=float, default=0.0)
    parser.add_argument("--max_steps", type=int, default=30)

    # agent config
    parser.add_argument("--agent_type", type=str, default="prompt")
    parser.add_argument(
        "--instruction_path",
        type=str,
        default="",
    )
    parser.add_argument(
        "--parsing_failure_th",
        help="When consecutive parsing failures exceed this threshold, the agent will terminate early.",
        type=int,
        default=3,
    )
    parser.add_argument(
        "--repeating_action_failure_th",
        help="When consecutive repeated actions exceed this threshold, the agent will terminate early.",
        type=int,
        default=5,
    )

    parser.add_argument("--test_config_base_dir", type=str)

    # lm config
    parser.add_argument("--provider", type=str, default="openai")
    parser.add_argument("--model", type=str, default="gpt-3.5-turbo-0613")
    parser.add_argument("--mode", type=str, default="chat")
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--context_length", type=int, default=0)
    parser.add_argument("--max_tokens", type=int, default=384)
    parser.add_argument("--stop_token", type=str, default=None)
    parser.add_argument(
        "--max_retry",
        type=int,
        help="max retry times to perform generations when parsing fails",
        default=1,
    )
    parser.add_argument(
        "--max_obs_length",
        type=int,
        help="when not zero, will truncate the observation to this length before feeding to the model",
        default=3840,
    )

    # example config
    parser.add_argument("--test_start_idx", type=int, default=0)
    parser.add_argument("--test_end_idx", type=int, default=378)

    # logging related
    parser.add_argument("--result_dir", type=str, default="")
    args = parser.parse_args(argv)

    return args
|
|
||||||
|
|
||||||
|
|
||||||
@beartype
def early_stop(
    trajectory, max_steps: int, thresholds: dict[str, int]
) -> tuple[bool, str]:
    """Decide whether the episode should be terminated before completion.

    Returns a ``(should_stop, reason)`` pair. The episode stops early when
    the step budget is exhausted, when the last few actions all failed to
    parse, or when the agent keeps issuing the same action over and over.
    """
    # The trajectory alternates state/action entries; actions sit at the
    # odd indices.
    actions = trajectory[1::2]  # type: ignore[assignment]

    # Step budget exhausted?
    if (len(trajectory) - 1) / 2 >= max_steps:
        return True, f"Reach max steps {max_steps}"

    # Consecutive parsing failures (empty action_type) at the tail.
    k = thresholds["parsing_failure"]
    tail = actions[-k:]
    if len(tail) >= k and all(a["action_type"] == "" for a in tail):
        return True, f"Failed to parse actions for {k} times"

    # Repeated-action detection.
    k = thresholds["repeating_action"]
    if not actions:
        return False, ""
    last_action = actions[-1]

    if last_action["action_type"] != ActionTypes.TYPE:
        # Non-typing actions: look only at the most recent k actions.
        tail = actions[-k:]
        if len(tail) >= k and all(
            is_equivalent(a, last_action) for a in tail
        ):
            return True, f"Same action for {k} times"
    elif sum(is_equivalent(a, last_action) for a in actions) >= k:
        # Typing actions: count repeats over the whole action sequence.
        return True, f"Same typing action for {k} times"

    return False, ""
|
|
||||||
|
|
||||||
|
|
||||||
@beartype
def test(
    args: argparse.Namespace,
    config_file_list: list[str]
) -> None:
    """Run the agent on every task config in `config_file_list` and score it.

    For each task: build a render helper, load the task intent and optional
    input images, roll out the agent in the browser environment until it
    stops (or `early_stop` fires), evaluate the trajectory, and log
    PASS/FAIL. Per-task failures are caught and recorded so one bad task
    does not abort the whole run.
    """
    scores = []
    max_steps = args.max_steps

    early_stop_thresholds = {
        "parsing_failure": args.parsing_failure_th,
        "repeating_action": args.repeating_action_failure_th,
    }

    # Only SoM / captioner observation modes need a captioning model.
    if args.observation_type in [
        "accessibility_tree_with_captioner",
        "image_som",
    ]:
        device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
        dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        caption_image_fn = image_utils.get_captioning_fn(
            device, dtype, args.captioning_model
        )
    else:
        caption_image_fn = None

    # Load a (possibly different) captioning model for running VQA evals.
    if (
        caption_image_fn
        and args.eval_captioning_model == args.captioning_model
    ):
        eval_caption_image_fn = caption_image_fn
    else:
        eval_caption_image_fn = image_utils.get_captioning_fn(
            args.eval_captioning_model_device,
            torch.float16
            if (
                torch.cuda.is_available()
                and args.eval_captioning_model_device == "cuda"
            )
            else torch.float32,
            args.eval_captioning_model,
        )

    agent = construct_agent(
        args,
        captioning_fn=caption_image_fn
        if args.observation_type == "accessibility_tree_with_captioner"
        else None,
    )  # NOTE: captioning_fn here is used for captioning input images.

    env = ScriptBrowserEnv(
        headless=not args.render,
        slow_mo=args.slow_mo,
        observation_type=args.observation_type,
        current_viewport_only=args.current_viewport_only,
        viewport_size={
            "width": args.viewport_width,
            "height": args.viewport_height,
        },
        save_trace_enabled=args.save_trace_enabled,
        sleep_after_execution=args.sleep_after_execution,
        # NOTE: captioning_fn here is used for LLM + captioning baselines.
        # This can be different from the captioning model used for evals.
        captioning_fn=caption_image_fn,
    )

    for config_file in config_file_list:
        # BUG FIX: initialize before the try block so the close() below does
        # not raise NameError when RenderHelper(...) itself fails.
        render_helper = None
        try:
            render_helper = RenderHelper(
                config_file, args.result_dir, args.action_set_tag
            )

            # Load task.
            with open(config_file) as f:
                _c = json.load(f)
                intent = _c["intent"]
                task_id = _c["task_id"]
                image_paths = _c.get("image", None)
                images = []

                # Load input images for the task, if any.
                if image_paths is not None:
                    if isinstance(image_paths, str):
                        image_paths = [image_paths]
                    for image_path in image_paths:
                        # Load image either from the web or from a local path.
                        if image_path.startswith("http"):
                            input_image = Image.open(requests.get(image_path, stream=True).raw)
                        else:
                            input_image = Image.open(image_path)

                        images.append(input_image)

            logger.info(f"[Config file]: {config_file}")
            logger.info(f"[Intent]: {intent}")

            agent.reset(config_file)
            trajectory: Trajectory = []
            obs, info = env.reset(options={"config_file": config_file})
            state_info: StateInfo = {"observation": obs, "info": info}
            trajectory.append(state_info)

            meta_data = {"action_history": ["None"]}
            while True:
                early_stop_flag, stop_info = early_stop(
                    trajectory, max_steps, early_stop_thresholds
                )

                if early_stop_flag:
                    action = create_stop_action(f"Early stop: {stop_info}")
                else:
                    try:
                        action = agent.next_action(
                            trajectory,
                            intent,
                            images=images,
                            meta_data=meta_data,
                        )
                    except ValueError as e:
                        # get the error message
                        action = create_stop_action(f"ERROR: {str(e)}")

                trajectory.append(action)

                action_str = get_action_description(
                    action,
                    state_info["info"]["observation_metadata"],
                    action_set_tag=args.action_set_tag,
                    prompt_constructor=agent.prompt_constructor
                    if isinstance(agent, PromptAgent)
                    else None,
                )
                render_helper.render(
                    action, state_info, meta_data, args.render_screenshot
                )
                meta_data["action_history"].append(action_str)

                if action["action_type"] == ActionTypes.STOP:
                    break

                obs, _, terminated, _, info = env.step(action)
                state_info = {"observation": obs, "info": info}
                trajectory.append(state_info)

                if terminated:
                    # add a action place holder
                    trajectory.append(create_stop_action(""))
                    break

            # NOTE: eval_caption_image_fn is used for running eval_vqa functions.
            evaluator = evaluator_router(
                config_file, captioning_fn=eval_caption_image_fn
            )
            score = evaluator(
                trajectory=trajectory,
                config_file=config_file,
                page=env.page,
                client=env.get_page_client(env.page),
            )

            scores.append(score)

            if score == 1:
                logger.info(f"[Result] (PASS) {config_file}")
            else:
                logger.info(f"[Result] (FAIL) {config_file}")

            if args.save_trace_enabled:
                env.save_trace(
                    Path(args.result_dir) / "traces" / f"{task_id}.zip"
                )
        except openai.OpenAIError as e:
            logger.info(f"[OpenAI Error] {repr(e)}")
        except Exception as e:
            logger.info(f"[Unhandled Error] {repr(e)}]")
            import traceback

            # write to error file
            with open(Path(args.result_dir) / "error.txt", "a") as f:
                f.write(f"[Config file]: {config_file}\n")
                f.write(f"[Unhandled Error] {repr(e)}\n")
                f.write(traceback.format_exc())  # write stack trace to file

        if render_helper is not None:
            render_helper.close()

    env.close()
    # BUG FIX: guard against an empty task list (ZeroDivisionError).
    if scores:
        logger.info(f"Average score: {sum(scores) / len(scores)}")
    else:
        logger.info("No tasks were scored.")
|
|
||||||
|
|
||||||
|
|
||||||
def prepare(args: argparse.Namespace) -> None:
    """Set up run prerequisites: prompt JSONs, result dir, trace dir, log index."""
    # Regenerate the JSON prompt files from their Python sources.
    from agent.prompts import to_json

    to_json.run()

    # Fall back to a timestamped cache directory when none was given.
    result_dir = args.result_dir
    if not result_dir:
        timestamp = time.strftime('%Y%m%d%H%M%S', time.localtime())
        result_dir = f"cache/results_{timestamp}"

    result_path = Path(result_dir)
    if not result_path.exists():
        result_path.mkdir(parents=True, exist_ok=True)
        args.result_dir = result_dir
        logger.info(f"Create result dir: {result_dir}")

    traces_path = result_path / "traces"
    if not traces_path.exists():
        traces_path.mkdir(parents=True)

    # Record this process's log file name in the run's log index.
    with open(os.path.join(result_dir, "log_files.txt"), "a+") as f:
        f.write(f"{LOG_FILE_NAME}\n")
|
|
||||||
|
|
||||||
|
|
||||||
def get_unfinished(config_files: list[str], result_dir: str) -> list[str]:
    """Return the config files whose task has no rendered result yet.

    A finished task leaves a ``<prefix>_<task_id>.html`` file in
    ``result_dir``; any config whose basename (minus extension) is not among
    those task ids is still pending.
    """
    finished_ids = {
        os.path.basename(path).split(".")[0].split("_")[1]
        for path in glob.glob(f"{result_dir}/*.html")
    }
    return [
        cfg
        for cfg in config_files
        if os.path.basename(cfg).split(".")[0] not in finished_ids
    ]
|
|
||||||
|
|
||||||
|
|
||||||
@beartype
def dump_config(args: argparse.Namespace) -> None:
    """Persist the parsed CLI arguments as <result_dir>/config.json (write-once)."""
    config_file = Path(args.result_dir) / "config.json"
    # Never clobber the config of a run that is being resumed.
    if config_file.exists():
        return
    with open(config_file, "w") as f:
        json.dump(vars(args), f, indent=4)
        logger.info(f"Dump config to {config_file}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
    ####### The complete version of the list of examples #######
    # Avoid fork-related warnings/deadlocks from HuggingFace tokenizers.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    args = config()
    # Hard-coded override: always wait 5s after each executed action.
    args.sleep_after_execution = 5
    prepare(args)

    test_config_base_dir = args.test_config_base_dir

    # Task configs are expected at <base_dir>/<idx>.json for idx in [start, end).
    test_file_list = []
    st_idx = args.test_start_idx
    ed_idx = args.test_end_idx
    for i in range(st_idx, ed_idx):
        test_file_list.append(os.path.join(test_config_base_dir, f"{i}.json"))
    # Skip tasks that already have a rendered result in result_dir.
    test_file_list = get_unfinished(test_file_list, args.result_dir)
    print(f"Total {len(test_file_list)} tasks left")
    # More hard-coded run settings that override whatever the CLI said.
    args.render = False
    args.render_screenshot = True
    args.save_trace_enabled = True

    args.current_viewport_only = True
    dump_config(args)

    test(args, test_file_list)

    # todo: add recorder of the progress of the examples

    # todo: remove the useless example files

    # NOTE(review): the block below looks like residue from a different
    # (desktop-env) driver — it calls main(domain, example_id, args.model),
    # which is not fully defined in this view; confirm it is intentional.
    with open("evaluation_examples/test_all.json", "r", encoding="utf-8") as f:
        test_all_meta = json.load(f)

    for domain in test_all_meta:
        for example_id in test_all_meta[domain]:
            main(domain, example_id, args.model)
|
|
||||||
@@ -1,361 +0,0 @@
|
|||||||
import datetime
import json
import logging
import os
import sys

import func_timeout

from desktop_env.envs.desktop_env import DesktopEnv
from mm_agents.gpt_4v_agent import GPT4v_Agent

# Logger Configs {{{ #
# Root logger collects everything at DEBUG; per-handler levels below filter.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

# Timestamp shared by all log files written by this process.
datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")

# Four sinks: INFO file, DEBUG file, stdout, and a DEBUG file restricted to
# the "desktopenv" logger namespace.
# NOTE(review): assumes a "logs" directory already exists — FileHandler
# raises otherwise; confirm against the launch script.
file_handler = logging.FileHandler(os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8")
debug_handler = logging.FileHandler(os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8")
stdout_handler = logging.StreamHandler(sys.stdout)
sdebug_handler = logging.FileHandler(os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8")

file_handler.setLevel(logging.INFO)
debug_handler.setLevel(logging.DEBUG)
stdout_handler.setLevel(logging.INFO)
sdebug_handler.setLevel(logging.DEBUG)

# ANSI color codes in the format string colorize terminal output (they also
# end up verbatim in the log files).
formatter = logging.Formatter(
    fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s")
file_handler.setFormatter(formatter)
debug_handler.setFormatter(formatter)
stdout_handler.setFormatter(formatter)
sdebug_handler.setFormatter(formatter)

# Only records from loggers under "desktopenv" reach stdout / sdebug.
stdout_handler.addFilter(logging.Filter("desktopenv"))
sdebug_handler.addFilter(logging.Filter("desktopenv"))

logger.addHandler(file_handler)
logger.addHandler(debug_handler)
logger.addHandler(stdout_handler)
logger.addHandler(sdebug_handler)
# }}} Logger Configs #

# Module-level logger used by the experiment code below.
logger = logging.getLogger("desktopenv.experiment")

# Absolute path to the VMware .vmx file of the Ubuntu guest.
PATH_TO_VM = r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu2\Ubuntu2.vmx"


# PATH_TO_VM = "../../../../大文件/镜像/Ubuntu-1218/Ubuntu/Ubuntu.vmx"
|
|
||||||
|
|
||||||
def run_one_example(example, agent, max_steps=10, example_trajectory_dir="exp_trajectory", recording=True):
    """Run a single benchmark example end-to-end and persist its trajectory.

    Creates a DesktopEnv for `example`, lets `agent` act for at most
    `max_steps` prediction rounds, writes one JSON line plus a screenshot
    per executed action into `example_trajectory_dir`, optionally captures
    a screen recording, and appends the final evaluation result as the last
    JSON line of trajectory.json.
    """
    # trajectory.json is appended to in JSON-Lines style (one object per line).
    trajectory_recording_path = os.path.join(example_trajectory_dir, "trajectory.json")
    env = DesktopEnv(
        path_to_vm=PATH_TO_VM,
        action_space=agent.action_space,
        task_config=example
    )
    # reset the environment to certain snapshot
    observation = env.reset()
    done = False
    step_num = 0

    if recording:
        # send a request to the server to start recording
        env.controller.start_recording()

    while not done and step_num < max_steps:
        # One prediction round may yield several low-level actions.
        actions = agent.predict(observation)
        step_num += 1
        for action in actions:
            # Capture the timestamp before executing the action
            action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
            logger.info("Step %d: %s", step_num, action)

            observation, reward, done, info = env.step(action)

            logger.info("Reward: %.2f", reward)
            logger.info("Done: %s", done)
            logger.info("Info: %s", info)

            # Save screenshot and trajectory information
            # NOTE(review): observation['screenshot'] is treated as a file
            # path here (opened with "rb") — confirm against DesktopEnv.
            with open(os.path.join(example_trajectory_dir, f"step_{step_num}_{action_timestamp}.png"), "wb") as _f:
                with open(observation['screenshot'], "rb") as __f:
                    screenshot = __f.read()
                _f.write(screenshot)

            with open(trajectory_recording_path, "a") as f:
                f.write(json.dumps({
                    "step_num": step_num,
                    "action_timestamp": action_timestamp,
                    "action": action,
                    "reward": reward,
                    "done": done,
                    "info": info,
                    "screenshot_file": f"step_{step_num}_{action_timestamp}.png"
                }))
                f.write("\n")

            if done:
                logger.info("The episode is done.")
                break

    def stop_recording():
        # Best-effort: a failure to save the recording must not sink the run.
        try:
            env.controller.end_recording(os.path.join(example_trajectory_dir, "recording.mp4"))
        except Exception as e:
            print(f"An error occurred while stopping the recording: {e}")

    # Give the recording retrieval at most 30 seconds.
    try:
        func_timeout.func_timeout(30, stop_recording)
    except func_timeout.exceptions.FunctionTimedOut:
        logger.info("Recording timed out.")

    result = env.evaluate()
    logger.info("Result: %.2f", result)

    # Append the final score as the last JSON line of the trajectory file.
    with open(trajectory_recording_path, "a") as f:
        f.write(json.dumps({
            "result": result
        }))
        f.write("\n")

    # env.close()
    logger.info("Environment closed.")
|
|
||||||
|
|
||||||
|
|
||||||
def main(example_class, example_id, gpt4_model="gpt-4-vision-preview"):
    """Run one benchmark example with the GPT-4V agent in "both" observation mode.

    Loads the example config from evaluation_examples/, pins the VM snapshot,
    builds the agent, prepares the trajectory directory, and delegates to
    run_one_example with a 15-step budget.
    """
    action_space = "pyautogui"
    # example_class = "libreoffice_calc"
    # example_id = "7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3"
    # example_id = "01b269ae-2111-4a07-81fd-3fcd711993b0"
    gemini_model = "gemini-pro-vision"

    logger.info("Running example %s/%s", example_class, example_id)
    logger.info("Using model %s", gpt4_model)
    # logger.info("Using model %s", gemini_model)

    with open(f"evaluation_examples/examples/{example_class}/{example_id}.json", "r", encoding="utf-8") as f:
        example = json.load(f)
    # Pin the VM snapshot this experiment expects, overriding the config file.
    example["snapshot"] = "exp_v5"
    # example["snapshot"] = "exp_setup4"
    # example["snapshot"] = "Snapshot 30"

    api_key = os.environ.get("OPENAI_API_KEY")
    agent = GPT4v_Agent(api_key=api_key, model=gpt4_model, instruction=example['instruction'],
                        action_space=action_space, exp="both")

    # api_key = os.environ.get("GENAI_API_KEY")
    # agent = GeminiPro_Agent(api_key=api_key, model=gemini_model, instruction=example['instruction'], action_space=action_space, exp="both")

    root_trajectory_dir = "exp_trajectory"

    # Trajectories are grouped as <root>/both/<class>/<model>/<example-id>.
    example_trajectory_dir = os.path.join(root_trajectory_dir, "both", example_class, gpt4_model, example_id)
    # example_trajectory_dir = os.path.join(root_trajectory_dir, "both", example_class, gemini_model, example_id)

    os.makedirs(example_trajectory_dir, exist_ok=True)

    run_one_example(example, agent, 15, example_trajectory_dir)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
    # Example IDs for the "os" domain of the benchmark.
    os_list = [
        "94d95f96-9699-4208-98ba-3c3119edf9c2",
        "bedcedc4-4d72-425e-ad62-21960b11fe0d",
        "43c2d64c-bab5-4dcb-a30c-b888321c319a",
        "7688b85f-87a4-4e4a-b2f8-f3d6c3f29b82",
        "ec4e3f68-9ea4-4c18-a5c9-69f89d1178b3",
        "f9be0997-4b7c-45c5-b05c-4612b44a6118",
        "28cc3b7e-b194-4bc9-8353-d04c0f4d56d2",
        "5ea617a3-0e86-4ba6-aab2-dac9aa2e8d57",
        "e0df059f-28a6-4169-924f-b9623e7184cc",
        "ddc75b62-7311-4af8-bfb3-859558542b36",
        "b6781586-6346-41cd-935a-a6b1487918fc",
        "3ce045a0-877b-42aa-8d2c-b4a863336ab8",
        "a4d98375-215b-4a4d-aee9-3d4370fccc41",
        "13584542-872b-42d8-b299-866967b5c3ef",
        "23393935-50c7-4a86-aeea-2b78fd089c5c"
    ]

    # for example_id in os_list:
    #     try:
    #         main("os", example_id)
    #     except Exception as e:
    #         logger.error("An error occurred while running the example: %s", e)
    #         continue

    # Example IDs for the LibreOffice Calc domain.
    calc_list = [
        "a9f325aa-8c05-4e4f-8341-9e4358565f4f",
        "ecb0df7a-4e8d-4a03-b162-053391d3afaf",
        "7efeb4b1-3d19-4762-b163-63328d66303b",
        "4e6fcf72-daf3-439f-a232-c434ce416af6",
        "6054afcb-5bab-4702-90a0-b259b5d3217c",
        "abed40dc-063f-4598-8ba5-9fe749c0615d",
        "01b269ae-2111-4a07-81fd-3fcd711993b0",
        "8b1ce5f2-59d2-4dcc-b0b0-666a714b9a14",
        "af2b02f7-acee-4be4-8b66-499fab394915",
        "da1d63b8-fa12-417b-ba18-f748e5f770f3",
        "636380ea-d5f6-4474-b6ca-b2ed578a20f1",
        "5ba77536-05c5-4aae-a9ff-6e298d094c3e",
        "4bc4eaf4-ca5e-4db2-8138-8d4e65af7c0b",
        "672a1b02-c62f-4ae2-acf0-37f5fb3052b0",
        "648fe544-16ba-44af-a587-12ccbe280ea6",
        "8985d1e4-5b99-4711-add4-88949ebb2308",
        "9e606842-2e27-43bf-b1d1-b43289c9589b",
        "fcb6e45b-25c4-4087-9483-03d714f473a9",
        "68c0c5b7-96f3-4e87-92a7-6c1b967fd2d2",
        "fff629ea-046e-4793-8eec-1a5a15c3eb35",
        "5c9a206c-bb00-4fb6-bb46-ee675c187df5",
        "e975ae74-79bd-4672-8d1c-dc841a85781d",
        "34a6938a-58da-4897-8639-9b90d6db5391",
        "b5a22759-b4eb-4bf2-aeed-ad14e8615f19",
        "2f9913a1-51ed-4db6-bfe0-7e1c95b3139e",
        "2558031e-401d-4579-8e00-3ecf540fb492",
        "0cecd4f3-74de-457b-ba94-29ad6b5dafb6",
        "4188d3a4-077d-46b7-9c86-23e1a036f6c1",
        "51b11269-2ca8-4b2a-9163-f21758420e78",
        "7e429b8d-a3f0-4ed0-9b58-08957d00b127",
        "347ef137-7eeb-4c80-a3bb-0951f26a8aff",
        "6e99a1ad-07d2-4b66-a1ce-ece6d99c20a5",
        "3aaa4e37-dc91-482e-99af-132a612d40f3",
        "37608790-6147-45d0-9f20-1137bb35703d",
        "f9584479-3d0d-4c79-affa-9ad7afdd8850",
        "d681960f-7bc3-4286-9913-a8812ba3261a",
        "21df9241-f8d7-4509-b7f1-37e501a823f7",
        "1334ca3e-f9e3-4db8-9ca7-b4c653be7d17",
        # NOTE(review): this ID differs from "347ef137-..." above by a
        # single digit — possible typo; verify both examples exist.
        "357ef137-7eeb-4c80-a3bb-0951f26a8aff",
        "aa3a8974-2e85-438b-b29e-a64df44deb4b",
        "a01fbce3-2793-461f-ab86-43680ccbae25",
        "4f07fbe9-70de-4927-a4d5-bb28bc12c52c"
    ]

    # for example_id in calc_list:
    #     try:
    #         main("libreoffice_calc", example_id)
    #     except Exception as e:
    #         logger.error("An error occurred while running the example: %s", e)
    #         continue

    # Example IDs for the LibreOffice Impress domain.
    impress_list = [
        "5d901039-a89c-4bfb-967b-bf66f4df075e",
        "550ce7e7-747b-495f-b122-acdc4d0b8e54",
        "455d3c66-7dc6-4537-a39a-36d3e9119df7",
        "af23762e-2bfd-4a1d-aada-20fa8de9ce07",
        "c59742c0-4323-4b9d-8a02-723c251deaa0",
        "ef9d12bd-bcee-4ba0-a40e-918400f43ddf",
        "9ec204e4-f0a3-42f8-8458-b772a6797cab",
        "0f84bef9-9790-432e-92b7-eece357603fb",
        "ce88f674-ab7a-43da-9201-468d38539e4a",
        "3b27600c-3668-4abd-8f84-7bcdebbccbdb",
        "a097acff-6266-4291-9fbd-137af7ecd439",
        "bf4e9888-f10f-47af-8dba-76413038b73c",
        "21760ecb-8f62-40d2-8d85-0cee5725cb72"
    ]

    # for example_id in impress_list:
    #     try:
    #         main("libreoffice_impress", example_id)
    #     except Exception as e:
    #         logger.error("An error occurred while running the example: %s", e)
    #         continue

    # Example IDs for the VS Code domain.
    vs_code_list = [
        "0ed39f63-6049-43d4-ba4d-5fa2fe04a951",
        "53ad5833-3455-407b-bbc6-45b4c79ab8fb",
        "eabc805a-bfcf-4460-b250-ac92135819f6",
        "982d12a5-beab-424f-8d38-d2a48429e511",
        "4e60007a-f5be-4bfc-9723-c39affa0a6d3",
        "e2b5e914-ffe1-44d2-8e92-58f8c5d92bb2",
        "9439a27b-18ae-42d8-9778-5f68f891805e",
        "ea98c5d7-3cf9-4f9b-8ad3-366b58e0fcae",
        "930fdb3b-11a8-46fe-9bac-577332e2640e",
        "276cc624-87ea-4f08-ab93-f770e3790175",
        "9d425400-e9b2-4424-9a4b-d4c7abac4140"
    ]

    # for example_id in vs_code_list:
    #     try:
    #         main("vs_code", example_id)
    #     except Exception as e:
    #         logger.error("An error occurred while running the example: %s", e)
    #         continue

    # Example IDs for the multi-application domain.
    multiple_list = [
        "f8cfa149-d1c1-4215-8dac-4a0932bad3c2",
        "897e3b53-5d4d-444b-85cb-2cdc8a97d903",
        "4e9f0faf-2ecc-4ae8-a804-28c9a75d1ddc",
        "b52b40a5-ad70-4c53-b5b0-5650a8387052",
        "46407397-a7d5-4c6b-92c6-dbe038b1457b",
        "2b9493d7-49b8-493a-a71b-56cd1f4d6908",
        "51f5801c-18b3-4f25-b0c3-02f85507a078",
        "2c9fc0de-3ee7-45e1-a5df-c86206ad78b5",
        "510f64c8-9bcc-4be1-8d30-638705850618",
        "937087b6-f668-4ba6-9110-60682ee33441",
        "ee9a3c83-f437-4879-8918-be5efbb9fac7",
        "3680a5ee-6870-426a-a997-eba929a0d25c",
        "e135df7c-7687-4ac0-a5f0-76b74438b53e",
        "58565672-7bfe-48ab-b828-db349231de6b",
        "2fe4b718-3bd7-46ec-bdce-b184f5653624"
    ]

    # for example_id in multiple_list:
    #     try:
    #         main("multi_apps", example_id)
    #     except Exception as e:
    #         logger.error("An error occurred while running the example: %s", e)
    #         continue

    # Example IDs for the Chrome domain.
    chrome_list = [
        # "bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
        "7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3",
        "06fe7178-4491-4589-810f-2e2bc9502122",
        "e1e75309-3ddb-4d09-92ec-de869c928143",
        "35253b65-1c19-4304-8aa4-6884b8218fc0",
        "2ad9387a-65d8-4e33-ad5b-7580065a27ca",
        "7a5a7856-f1b6-42a4-ade9-1ca81ca0f263",
        "44ee5668-ecd5-4366-a6ce-c1c9b8d4e938",
        "2ae9ba84-3a0d-4d4c-8338-3a1478dc5fe3",
        "480bcfea-d68f-4aaa-a0a9-2589ef319381",
        "af630914-714e-4a24-a7bb-f9af687d3b91"
    ]

    # for example_id in chrome_list:
    #     try:
    #         main("chrome", example_id)
    #     except Exception as e:
    #         logger.error("An error occurred while running the example: %s", e)
    #         continue


    # Example IDs for the LibreOffice Writer domain.
    writer_list = [
        "6ada715d-3aae-4a32-a6a7-429b2e43fb93",
        "ecc2413d-8a48-416e-a3a2-d30106ca36cb",
        "0e47de2a-32e0-456c-a366-8c607ef7a9d2",
        "4bcb1253-a636-4df4-8cb0-a35c04dfef31",
        "0810415c-bde4-4443-9047-d5f70165a697",
        "e528b65e-1107-4b8c-8988-490e4fece599",
        "66399b0d-8fda-4618-95c4-bfc6191617e9",
        "936321ce-5236-426a-9a20-e0e3c5dc536f",
        "3ef2b351-8a84-4ff2-8724-d86eae9b842e",
        "0b17a146-2934-46c7-8727-73ff6b6483e8",
        "0e763496-b6bb-4508-a427-fad0b6c3e195",
        "f178a4a9-d090-4b56-bc4c-4b72a61a035d",
        "adf5e2c3-64c7-4644-b7b6-d2f0167927e7",
        "0a0faba3-5580-44df-965d-f562a99b291c",
        "e246f6d8-78d7-44ac-b668-fcf47946cb50",
        "8472fece-c7dd-4241-8d65-9b3cd1a0b568",
        "88fe4b2d-3040-4c70-9a70-546a47764b48",
        "d53ff5ee-3b1a-431e-b2be-30ed2673079b",
        "72b810ef-4156-4d09-8f08-a0cf57e7cefe",
        "6f81754e-285d-4ce0-b59e-af7edb02d108",
        "b21acd93-60fd-4127-8a43-2f5178f4a830"
    ]

    # Only the Writer examples are actually executed in this version; the
    # loops for the other domains above are kept commented out for
    # selective runs.
    for example_id in writer_list:
        try:
            main("libreoffice_writer", example_id)
        except Exception as e:
            # Log and keep going so one failing example does not end the batch.
            logger.error("An error occurred while running the example: %s", e)
            continue
|
|
||||||
|
|
||||||
|
|
||||||
@@ -1,155 +0,0 @@
|
|||||||
import ctypes
import datetime
import json
import logging
import os
import sys
import func_timeout

from desktop_env.envs.desktop_env import DesktopEnv
from mm_agents.gpt_4v_agent import GPT4v_Agent

# Logger Configs {{{ #
# NOTE(review): this header duplicates the logger setup of the sibling
# experiment script (same handlers, same format) — consider extracting it.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

# Timestamp shared by all log files written by this process.
datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")

# Four sinks: INFO file, DEBUG file, stdout, and a DEBUG file restricted to
# the "desktopenv" logger namespace; assumes a "logs" directory exists.
file_handler = logging.FileHandler(os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8")
debug_handler = logging.FileHandler(os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8")
stdout_handler = logging.StreamHandler(sys.stdout)
sdebug_handler = logging.FileHandler(os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8")

file_handler.setLevel(logging.INFO)
debug_handler.setLevel(logging.DEBUG)
stdout_handler.setLevel(logging.INFO)
sdebug_handler.setLevel(logging.DEBUG)

# ANSI color codes colorize terminal output (and appear verbatim in files).
formatter = logging.Formatter(
    fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s")
file_handler.setFormatter(formatter)
debug_handler.setFormatter(formatter)
stdout_handler.setFormatter(formatter)
sdebug_handler.setFormatter(formatter)

# Only records from loggers under "desktopenv" reach stdout / sdebug.
stdout_handler.addFilter(logging.Filter("desktopenv"))
sdebug_handler.addFilter(logging.Filter("desktopenv"))

logger.addHandler(file_handler)
logger.addHandler(debug_handler)
logger.addHandler(stdout_handler)
logger.addHandler(sdebug_handler)
# }}} Logger Configs #

# Module-level logger used by the experiment code below.
logger = logging.getLogger("desktopenv.experiment")

# Absolute path to the VMware .vmx file of the Ubuntu guest.
PATH_TO_VM = r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu\Ubuntu.vmx"
|
|
||||||
|
|
||||||
|
|
||||||
def run_one_example(example, agent, max_steps=10, example_trajectory_dir="exp_trajectory", recording=True):
|
|
||||||
trajectory_recording_path = os.path.join(example_trajectory_dir, "trajectory.json")
|
|
||||||
env = DesktopEnv(
|
|
||||||
path_to_vm=PATH_TO_VM,
|
|
||||||
action_space=agent.action_space,
|
|
||||||
task_config=example
|
|
||||||
)
|
|
||||||
# reset the environment to certain snapshot
|
|
||||||
observation = env.reset()
|
|
||||||
done = False
|
|
||||||
step_num = 0
|
|
||||||
|
|
||||||
if recording:
|
|
||||||
# send a request to the server to start recording
|
|
||||||
env.controller.start_recording()
|
|
||||||
|
|
||||||
while not done and step_num < max_steps:
|
|
||||||
actions = agent.predict(observation)
|
|
||||||
step_num += 1
|
|
||||||
for action in actions:
|
|
||||||
# Capture the timestamp before executing the action
|
|
||||||
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
|
||||||
logger.info("Step %d: %s", step_num, action)
|
|
||||||
|
|
||||||
observation, reward, done, info = env.step(action)
|
|
||||||
|
|
||||||
logger.info("Reward: %.2f", reward)
|
|
||||||
logger.info("Done: %s", done)
|
|
||||||
logger.info("Info: %s", info)
|
|
||||||
|
|
||||||
# Save screenshot and trajectory information
|
|
||||||
with open(os.path.join(example_trajectory_dir, f"step_{step_num}_{action_timestamp}.png"), "wb") as _f:
|
|
||||||
with open(observation['screenshot'], "rb") as __f:
|
|
||||||
screenshot = __f.read()
|
|
||||||
_f.write(screenshot)
|
|
||||||
|
|
||||||
with open(trajectory_recording_path, "a") as f:
|
|
||||||
f.write(json.dumps({
|
|
||||||
"step_num": step_num,
|
|
||||||
"action_timestamp": action_timestamp,
|
|
||||||
"action": action,
|
|
||||||
"reward": reward,
|
|
||||||
"done": done,
|
|
||||||
"info": info,
|
|
||||||
"screenshot_file": f"step_{step_num}_{action_timestamp}.png"
|
|
||||||
}))
|
|
||||||
f.write("\n")
|
|
||||||
|
|
||||||
if done:
|
|
||||||
logger.info("The episode is done.")
|
|
||||||
break
|
|
||||||
|
|
||||||
def stop_recording():
|
|
||||||
try:
|
|
||||||
env.controller.end_recording(os.path.join(example_trajectory_dir, "recording.mp4"))
|
|
||||||
except Exception as e:
|
|
||||||
print(f"An error occurred while stopping the recording: {e}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
func_timeout.func_timeout(30, stop_recording)
|
|
||||||
except func_timeout.exceptions.FunctionTimedOut:
|
|
||||||
logger.info("Recording timed out.")
|
|
||||||
|
|
||||||
result = env.evaluate()
|
|
||||||
logger.info("Result: %.2f", result)
|
|
||||||
|
|
||||||
with open(trajectory_recording_path, "a") as f:
|
|
||||||
f.write(json.dumps({
|
|
||||||
"result": result
|
|
||||||
}))
|
|
||||||
f.write("\n")
|
|
||||||
|
|
||||||
# env.close()
|
|
||||||
logger.info("Environment closed.")
|
|
||||||
|
|
||||||
|
|
||||||
def main(example_class, example_id):
|
|
||||||
action_space = "pyautogui"
|
|
||||||
gpt4_model = "gpt-4-vision-preview"
|
|
||||||
gemini_model = "gemini-pro-vision"
|
|
||||||
|
|
||||||
with open(f"evaluation_examples/examples/{example_class}/{example_id}.json", "r", encoding="utf-8") as f:
|
|
||||||
example = json.load(f)
|
|
||||||
example["snapshot"] = "exp_v5"
|
|
||||||
|
|
||||||
api_key = os.environ.get("OPENAI_API_KEY")
|
|
||||||
agent = GPT4v_Agent(api_key=api_key, model=gpt4_model, instruction=example['instruction'],
|
|
||||||
action_space=action_space, exp="seeact")
|
|
||||||
|
|
||||||
# api_key = os.environ.get("GENAI_API_KEY")
|
|
||||||
# agent = GeminiPro_Agent(api_key=api_key, model=gemini_model, instruction=example['instruction'], action_space=action_space)
|
|
||||||
|
|
||||||
root_trajectory_dir = "exp_trajectory"
|
|
||||||
|
|
||||||
example_trajectory_dir = os.path.join(root_trajectory_dir, "seeact", example_class, gpt4_model, example_id)
|
|
||||||
# example_trajectory_dir = os.path.join(root_trajectory_dir, "seeact", example_class, gemini_model, example_id)
|
|
||||||
|
|
||||||
os.makedirs(example_trajectory_dir, exist_ok=True)
|
|
||||||
|
|
||||||
run_one_example(example, agent, 15, example_trajectory_dir)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
xx_list = [
|
|
||||||
]
|
|
||||||
for example_id in xx_list:
|
|
||||||
main("xx", example_id)
|
|
||||||
@@ -1,261 +0,0 @@
|
|||||||
#import ctypes
|
|
||||||
import datetime
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import func_timeout
|
|
||||||
|
|
||||||
from desktop_env.envs.desktop_env import DesktopEnv
|
|
||||||
from mm_agents.gpt_4v_agent import GPT4v_Agent
|
|
||||||
|
|
||||||
# Logger Configs {{{ #
|
|
||||||
logger = logging.getLogger()
|
|
||||||
logger.setLevel(logging.DEBUG)
|
|
||||||
|
|
||||||
datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
|
||||||
|
|
||||||
file_handler = logging.FileHandler(os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8")
|
|
||||||
debug_handler = logging.FileHandler(os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8")
|
|
||||||
stdout_handler = logging.StreamHandler(sys.stdout)
|
|
||||||
sdebug_handler = logging.FileHandler(os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8")
|
|
||||||
|
|
||||||
file_handler.setLevel(logging.INFO)
|
|
||||||
debug_handler.setLevel(logging.DEBUG)
|
|
||||||
stdout_handler.setLevel(logging.INFO)
|
|
||||||
sdebug_handler.setLevel(logging.DEBUG)
|
|
||||||
|
|
||||||
formatter = logging.Formatter(
|
|
||||||
fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s")
|
|
||||||
file_handler.setFormatter(formatter)
|
|
||||||
debug_handler.setFormatter(formatter)
|
|
||||||
stdout_handler.setFormatter(formatter)
|
|
||||||
sdebug_handler.setFormatter(formatter)
|
|
||||||
|
|
||||||
stdout_handler.addFilter(logging.Filter("desktopenv"))
|
|
||||||
sdebug_handler.addFilter(logging.Filter("desktopenv"))
|
|
||||||
|
|
||||||
logger.addHandler(file_handler)
|
|
||||||
logger.addHandler(debug_handler)
|
|
||||||
logger.addHandler(stdout_handler)
|
|
||||||
logger.addHandler(sdebug_handler)
|
|
||||||
# }}} Logger Configs #
|
|
||||||
|
|
||||||
logger = logging.getLogger("desktopenv.experiment")
|
|
||||||
|
|
||||||
PATH_TO_VM = r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu\Ubuntu.vmx"
|
|
||||||
|
|
||||||
|
|
||||||
def run_one_example(example, agent, max_steps=10, example_trajectory_dir="exp_trajectory", recording=True):
|
|
||||||
trajectory_recording_path = os.path.join(example_trajectory_dir, "trajectory.json")
|
|
||||||
env = DesktopEnv(
|
|
||||||
path_to_vm=PATH_TO_VM,
|
|
||||||
action_space=agent.action_space,
|
|
||||||
task_config=example
|
|
||||||
)
|
|
||||||
# reset the environment to certain snapshot
|
|
||||||
observation = env.reset()
|
|
||||||
done = False
|
|
||||||
step_num = 0
|
|
||||||
|
|
||||||
if recording:
|
|
||||||
# send a request to the server to start recording
|
|
||||||
env.controller.start_recording()
|
|
||||||
|
|
||||||
while not done and step_num < max_steps:
|
|
||||||
actions = agent.predict(observation)
|
|
||||||
step_num += 1
|
|
||||||
for action in actions:
|
|
||||||
# Capture the timestamp before executing the action
|
|
||||||
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
|
||||||
logger.info("Step %d: %s", step_num, action)
|
|
||||||
|
|
||||||
observation, reward, done, info = env.step(action)
|
|
||||||
|
|
||||||
logger.info("Reward: %.2f", reward)
|
|
||||||
logger.info("Done: %s", done)
|
|
||||||
logger.info("Info: %s", info)
|
|
||||||
|
|
||||||
# Save screenshot and trajectory information
|
|
||||||
with open(os.path.join(example_trajectory_dir, f"step_{step_num}_{action_timestamp}.png"), "wb") as _f:
|
|
||||||
with open(observation['screenshot'], "rb") as __f:
|
|
||||||
screenshot = __f.read()
|
|
||||||
_f.write(screenshot)
|
|
||||||
|
|
||||||
with open(trajectory_recording_path, "a") as f:
|
|
||||||
f.write(json.dumps({
|
|
||||||
"step_num": step_num,
|
|
||||||
"action_timestamp": action_timestamp,
|
|
||||||
"action": action,
|
|
||||||
"reward": reward,
|
|
||||||
"done": done,
|
|
||||||
"info": info,
|
|
||||||
"screenshot_file": f"step_{step_num}_{action_timestamp}.png"
|
|
||||||
}))
|
|
||||||
f.write("\n")
|
|
||||||
|
|
||||||
if done:
|
|
||||||
logger.info("The episode is done.")
|
|
||||||
break
|
|
||||||
|
|
||||||
def stop_recording():
|
|
||||||
try:
|
|
||||||
env.controller.end_recording(os.path.join(example_trajectory_dir, "recording.mp4"))
|
|
||||||
except Exception as e:
|
|
||||||
print(f"An error occurred while stopping the recording: {e}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
func_timeout.func_timeout(30, stop_recording)
|
|
||||||
except func_timeout.exceptions.FunctionTimedOut:
|
|
||||||
logger.info("Recording timed out.")
|
|
||||||
|
|
||||||
result = env.evaluate()
|
|
||||||
logger.info("Result: %.2f", result)
|
|
||||||
|
|
||||||
with open(trajectory_recording_path, "a") as f:
|
|
||||||
f.write(json.dumps({
|
|
||||||
"result": result
|
|
||||||
}))
|
|
||||||
f.write("\n")
|
|
||||||
|
|
||||||
# env.close()
|
|
||||||
logger.info("Environment closed.")
|
|
||||||
|
|
||||||
|
|
||||||
def main(example_class, example_id):
|
|
||||||
action_space = "pyautogui"
|
|
||||||
gpt4_model = "gpt-4-vision-preview"
|
|
||||||
gemini_model = "gemini-pro-vision"
|
|
||||||
|
|
||||||
with open(f"evaluation_examples/examples/{example_class}/{example_id}.json", "r", encoding="utf-8") as f:
|
|
||||||
example = json.load(f)
|
|
||||||
example["snapshot"] = "exp_v5"
|
|
||||||
|
|
||||||
logger.info("TASK: %s/%s", example_class, example_id)
|
|
||||||
|
|
||||||
api_key = os.environ.get("OPENAI_API_KEY")
|
|
||||||
agent = GPT4v_Agent(api_key=api_key, model=gpt4_model, max_tokens=1000, instruction=example['instruction'],
|
|
||||||
action_space=action_space, exp="som")
|
|
||||||
|
|
||||||
# api_key = os.environ.get("GENAI_API_KEY")
|
|
||||||
# agent = GeminiPro_Agent(api_key=api_key, model=gemini_model, instruction=example['instruction'], action_space=action_space)
|
|
||||||
|
|
||||||
root_trajectory_dir = "exp_trajectory"
|
|
||||||
|
|
||||||
example_trajectory_dir = os.path.join(root_trajectory_dir, "som", example_class, gpt4_model, example_id)
|
|
||||||
# example_trajectory_dir = os.path.join(root_trajectory_dir, "som", example_class, gemini_model, example_id)
|
|
||||||
|
|
||||||
os.makedirs(example_trajectory_dir, exist_ok=True)
|
|
||||||
|
|
||||||
run_one_example(example, agent, 15, example_trajectory_dir)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
from tqdm import tqdm
|
|
||||||
# impress_list = [
|
|
||||||
# # "5d901039-a89c-4bfb-967b-bf66f4df075e",
|
|
||||||
# "550ce7e7-747b-495f-b122-acdc4d0b8e54",
|
|
||||||
# "455d3c66-7dc6-4537-a39a-36d3e9119df7",
|
|
||||||
# "af23762e-2bfd-4a1d-aada-20fa8de9ce07",
|
|
||||||
# "c59742c0-4323-4b9d-8a02-723c251deaa0",
|
|
||||||
# "ef9d12bd-bcee-4ba0-a40e-918400f43ddf",
|
|
||||||
# "9ec204e4-f0a3-42f8-8458-b772a6797cab",
|
|
||||||
# "0f84bef9-9790-432e-92b7-eece357603fb",
|
|
||||||
# "ce88f674-ab7a-43da-9201-468d38539e4a",
|
|
||||||
# "3b27600c-3668-4abd-8f84-7bcdebbccbdb",
|
|
||||||
# "a097acff-6266-4291-9fbd-137af7ecd439",
|
|
||||||
# "bf4e9888-f10f-47af-8dba-76413038b73c",
|
|
||||||
# "21760ecb-8f62-40d2-8d85-0cee5725cb72"
|
|
||||||
# ]
|
|
||||||
# for example_id in impress_list:
|
|
||||||
# main("libreoffice_impress", example_id)
|
|
||||||
|
|
||||||
vlc_list = [
|
|
||||||
"8ba5ae7a-5ae5-4eab-9fcc-5dd4fe3abf89",
|
|
||||||
"8ba5ae7a-5ae5-4eab-9fcc-5dd4fe3abf89",
|
|
||||||
"8f080098-ddb1-424c-b438-4e96e5e4786e",
|
|
||||||
"bba3381f-b5eb-4439-bd9e-80c22218d5a7",
|
|
||||||
"fba2c100-79e8-42df-ae74-b592418d54f4",
|
|
||||||
"efcf0d81-0835-4880-b2fd-d866e8bc2294",
|
|
||||||
"8d9fd4e2-6fdb-46b0-b9b9-02f06495c62f",
|
|
||||||
"aa4b5023-aef6-4ed9-bdc9-705f59ab9ad6",
|
|
||||||
"386dbd0e-0241-4a0a-b6a2-6704fba26b1c",
|
|
||||||
"9195653c-f4aa-453d-aa95-787f6ccfaae9",
|
|
||||||
"d06f0d4d-2cd5-4ede-8de9-598629438c6e",
|
|
||||||
"a5bbbcd5-b398-4c91-83d4-55e1e31bbb81",
|
|
||||||
"f3977615-2b45-4ac5-8bba-80c17dbe2a37",
|
|
||||||
"215dfd39-f493-4bc3-a027-8a97d72c61bf"
|
|
||||||
]
|
|
||||||
|
|
||||||
# for example_id in tqdm(vlc_list):
|
|
||||||
# try:
|
|
||||||
# main("vlc", example_id)
|
|
||||||
# except Exception as e:
|
|
||||||
# print(f"An error occurred while running the example: {e}")
|
|
||||||
# continue
|
|
||||||
|
|
||||||
chrome_list = [
|
|
||||||
"bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
|
|
||||||
"7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3",
|
|
||||||
"06fe7178-4491-4589-810f-2e2bc9502122",
|
|
||||||
"e1e75309-3ddb-4d09-92ec-de869c928143",
|
|
||||||
"35253b65-1c19-4304-8aa4-6884b8218fc0",
|
|
||||||
"2ad9387a-65d8-4e33-ad5b-7580065a27ca",
|
|
||||||
"7a5a7856-f1b6-42a4-ade9-1ca81ca0f263",
|
|
||||||
"44ee5668-ecd5-4366-a6ce-c1c9b8d4e938",
|
|
||||||
"2ae9ba84-3a0d-4d4c-8338-3a1478dc5fe3",
|
|
||||||
"480bcfea-d68f-4aaa-a0a9-2589ef319381",
|
|
||||||
"af630914-714e-4a24-a7bb-f9af687d3b91"
|
|
||||||
]
|
|
||||||
for example_id in tqdm(chrome_list):
|
|
||||||
try:
|
|
||||||
main("chrome", example_id)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"An error occurred while running the example: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
vs_code_list = [
|
|
||||||
"0ed39f63-6049-43d4-ba4d-5fa2fe04a951",
|
|
||||||
"53ad5833-3455-407b-bbc6-45b4c79ab8fb",
|
|
||||||
"eabc805a-bfcf-4460-b250-ac92135819f6",
|
|
||||||
"982d12a5-beab-424f-8d38-d2a48429e511",
|
|
||||||
"4e60007a-f5be-4bfc-9723-c39affa0a6d3",
|
|
||||||
"e2b5e914-ffe1-44d2-8e92-58f8c5d92bb2",
|
|
||||||
"9439a27b-18ae-42d8-9778-5f68f891805e",
|
|
||||||
"ea98c5d7-3cf9-4f9b-8ad3-366b58e0fcae",
|
|
||||||
"930fdb3b-11a8-46fe-9bac-577332e2640e",
|
|
||||||
"276cc624-87ea-4f08-ab93-f770e3790175",
|
|
||||||
"9d425400-e9b2-4424-9a4b-d4c7abac4140"
|
|
||||||
]
|
|
||||||
|
|
||||||
for example_id in tqdm(vs_code_list):
|
|
||||||
try:
|
|
||||||
main("vs_code", example_id)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"An error occurred while running the example: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
thunderbird_list = [
|
|
||||||
"bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
|
|
||||||
"7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3",
|
|
||||||
"12086550-11c0-466b-b367-1d9e75b3910e",
|
|
||||||
"06fe7178-4491-4589-810f-2e2bc9502122",
|
|
||||||
"6766f2b8-8a72-417f-a9e5-56fcaa735837",
|
|
||||||
"e1e75309-3ddb-4d09-92ec-de869c928143",
|
|
||||||
"3d1682a7-0fb0-49ae-a4dc-a73afd2d06d5",
|
|
||||||
"35253b65-1c19-4304-8aa4-6884b8218fc0",
|
|
||||||
"d088f539-cab4-4f9a-ac92-9999fc3a656e",
|
|
||||||
"2ad9387a-65d8-4e33-ad5b-7580065a27ca",
|
|
||||||
"480bcfea-d68f-4aaa-a0a9-2589ef319381",
|
|
||||||
"030eeff7-b492-4218-b312-701ec99ee0cc",
|
|
||||||
"94760984-3ff5-41ee-8347-cf1af709fea0",
|
|
||||||
"99146c54-4f37-4ab8-9327-5f3291665e1e",
|
|
||||||
"c9e7eaf2-b1a1-4efc-a982-721972fa9f02"
|
|
||||||
]
|
|
||||||
|
|
||||||
for example_id in tqdm(thunderbird_list):
|
|
||||||
try:
|
|
||||||
main("thunderbird", example_id)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"An error occurred while running the example: {e}")
|
|
||||||
continue
|
|
||||||
@@ -5,21 +5,20 @@ import os
|
|||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
|
import xml.etree.ElementTree as ET
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
import xml.etree.ElementTree as ET
|
|
||||||
|
|
||||||
import backoff
|
import backoff
|
||||||
import dashscope
|
import dashscope
|
||||||
import google.generativeai as genai
|
import google.generativeai as genai
|
||||||
import openai
|
|
||||||
import requests
|
import requests
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from openai import (
|
from vertexai.preview.generative_models import (
|
||||||
APIConnectionError,
|
HarmBlockThreshold,
|
||||||
APIError,
|
HarmCategory,
|
||||||
RateLimitError
|
Image,
|
||||||
)
|
)
|
||||||
|
|
||||||
from mm_agents.accessibility_tree_wrap.heuristic_retrieve import find_leaf_nodes, filter_nodes, draw_bounding_boxes
|
from mm_agents.accessibility_tree_wrap.heuristic_retrieve import find_leaf_nodes, filter_nodes, draw_bounding_boxes
|
||||||
@@ -29,7 +28,6 @@ from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_S
|
|||||||
SYS_PROMPT_IN_SOM_A11Y_OUT_TAG, \
|
SYS_PROMPT_IN_SOM_A11Y_OUT_TAG, \
|
||||||
SYS_PROMPT_SEEACT, ACTION_DESCRIPTION_PROMPT_SEEACT, ACTION_GROUNDING_PROMPT_SEEACT
|
SYS_PROMPT_SEEACT, ACTION_DESCRIPTION_PROMPT_SEEACT, ACTION_GROUNDING_PROMPT_SEEACT
|
||||||
|
|
||||||
import logging
|
|
||||||
# todo: cross-check with visualwebarena
|
# todo: cross-check with visualwebarena
|
||||||
|
|
||||||
logger = logging.getLogger("desktopenv.agent")
|
logger = logging.getLogger("desktopenv.agent")
|
||||||
@@ -42,7 +40,7 @@ def encode_image(image_path):
|
|||||||
|
|
||||||
|
|
||||||
def linearize_accessibility_tree(accessibility_tree):
|
def linearize_accessibility_tree(accessibility_tree):
|
||||||
#leaf_nodes = find_leaf_nodes(accessibility_tree)
|
# leaf_nodes = find_leaf_nodes(accessibility_tree)
|
||||||
filtered_nodes = filter_nodes(ET.fromstring(accessibility_tree))
|
filtered_nodes = filter_nodes(ET.fromstring(accessibility_tree))
|
||||||
|
|
||||||
linearized_accessibility_tree = "tag\tname\ttext\tposition\tsize\n"
|
linearized_accessibility_tree = "tag\tname\ttext\tposition\tsize\n"
|
||||||
@@ -172,60 +170,56 @@ def parse_code_from_som_string(input_string, masks):
|
|||||||
class PromptAgent:
|
class PromptAgent:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
api_key,
|
|
||||||
instruction,
|
|
||||||
model="gpt-4-vision-preview",
|
model="gpt-4-vision-preview",
|
||||||
max_tokens=500,
|
max_tokens=1500,
|
||||||
|
top_p=0.9,
|
||||||
|
temperature=0.5,
|
||||||
action_space="computer_13",
|
action_space="computer_13",
|
||||||
exp="screenshot_a11y_tree"
|
observation_type="screenshot_a11y_tree",
|
||||||
# exp can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som", "seeact"]
|
# observation_type can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som", "seeact"]
|
||||||
|
max_trajectory_length=3
|
||||||
):
|
):
|
||||||
|
|
||||||
self.instruction = instruction
|
|
||||||
self.model = model
|
self.model = model
|
||||||
self.max_tokens = max_tokens
|
self.max_tokens = max_tokens
|
||||||
|
self.top_p = top_p
|
||||||
|
self.temperature = temperature
|
||||||
self.action_space = action_space
|
self.action_space = action_space
|
||||||
self.exp = exp
|
self.observation_type = observation_type
|
||||||
self.max_trajectory_length = 3
|
self.max_trajectory_length = max_trajectory_length
|
||||||
|
|
||||||
self.headers = {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
"Authorization": f"Bearer {api_key}"
|
|
||||||
}
|
|
||||||
|
|
||||||
self.thoughts = []
|
self.thoughts = []
|
||||||
self.actions = []
|
self.actions = []
|
||||||
self.observations = []
|
self.observations = []
|
||||||
|
|
||||||
if exp == "screenshot":
|
if observation_type == "screenshot":
|
||||||
if action_space == "computer_13":
|
if action_space == "computer_13":
|
||||||
self.system_message = SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION
|
self.system_message = SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION
|
||||||
elif action_space == "pyautogui":
|
elif action_space == "pyautogui":
|
||||||
self.system_message = SYS_PROMPT_IN_SCREENSHOT_OUT_CODE
|
self.system_message = SYS_PROMPT_IN_SCREENSHOT_OUT_CODE
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid action space: " + action_space)
|
raise ValueError("Invalid action space: " + action_space)
|
||||||
elif exp == "a11y_tree":
|
elif observation_type == "a11y_tree":
|
||||||
if action_space == "computer_13":
|
if action_space == "computer_13":
|
||||||
self.system_message = SYS_PROMPT_IN_A11Y_OUT_ACTION
|
self.system_message = SYS_PROMPT_IN_A11Y_OUT_ACTION
|
||||||
elif action_space == "pyautogui":
|
elif action_space == "pyautogui":
|
||||||
self.system_message = SYS_PROMPT_IN_A11Y_OUT_CODE
|
self.system_message = SYS_PROMPT_IN_A11Y_OUT_CODE
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid action space: " + action_space)
|
raise ValueError("Invalid action space: " + action_space)
|
||||||
elif exp == "both":
|
elif observation_type == "both":
|
||||||
if action_space == "computer_13":
|
if action_space == "computer_13":
|
||||||
self.system_message = SYS_PROMPT_IN_BOTH_OUT_ACTION
|
self.system_message = SYS_PROMPT_IN_BOTH_OUT_ACTION
|
||||||
elif action_space == "pyautogui":
|
elif action_space == "pyautogui":
|
||||||
self.system_message = SYS_PROMPT_IN_BOTH_OUT_CODE
|
self.system_message = SYS_PROMPT_IN_BOTH_OUT_CODE
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid action space: " + action_space)
|
raise ValueError("Invalid action space: " + action_space)
|
||||||
elif exp == "som":
|
elif observation_type == "som":
|
||||||
if action_space == "computer_13":
|
if action_space == "computer_13":
|
||||||
raise ValueError("Invalid action space: " + action_space)
|
raise ValueError("Invalid action space: " + action_space)
|
||||||
elif action_space == "pyautogui":
|
elif action_space == "pyautogui":
|
||||||
self.system_message = SYS_PROMPT_IN_SOM_A11Y_OUT_TAG
|
self.system_message = SYS_PROMPT_IN_SOM_A11Y_OUT_TAG
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid action space: " + action_space)
|
raise ValueError("Invalid action space: " + action_space)
|
||||||
elif exp == "seeact":
|
elif observation_type == "seeact":
|
||||||
if action_space == "computer_13":
|
if action_space == "computer_13":
|
||||||
raise ValueError("Invalid action space: " + action_space)
|
raise ValueError("Invalid action space: " + action_space)
|
||||||
elif action_space == "pyautogui":
|
elif action_space == "pyautogui":
|
||||||
@@ -233,15 +227,14 @@ class PromptAgent:
|
|||||||
else:
|
else:
|
||||||
raise ValueError("Invalid action space: " + action_space)
|
raise ValueError("Invalid action space: " + action_space)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid experiment type: " + exp)
|
raise ValueError("Invalid experiment type: " + observation_type)
|
||||||
|
|
||||||
self.system_message = self.system_message + "\nYou are asked to complete the following task: {}".format(
|
def predict(self, instruction: str, obs: Dict) -> List:
|
||||||
self.instruction)
|
|
||||||
|
|
||||||
def predict(self, obs: Dict) -> List:
|
|
||||||
"""
|
"""
|
||||||
Predict the next action(s) based on the current observation.
|
Predict the next action(s) based on the current observation.
|
||||||
"""
|
"""
|
||||||
|
self.system_message = self.system_message + "\nYou are asked to complete the following task: {}".format(
|
||||||
|
instruction)
|
||||||
|
|
||||||
# Prepare the payload for the API call
|
# Prepare the payload for the API call
|
||||||
messages = []
|
messages = []
|
||||||
@@ -273,7 +266,7 @@ class PromptAgent:
|
|||||||
for previous_obs, previous_action, previous_thought in zip(_observations, _actions, _thoughts):
|
for previous_obs, previous_action, previous_thought in zip(_observations, _actions, _thoughts):
|
||||||
|
|
||||||
# {{{1
|
# {{{1
|
||||||
if self.exp == "both":
|
if self.observation_type == "both":
|
||||||
_screenshot = previous_obs["screenshot"]
|
_screenshot = previous_obs["screenshot"]
|
||||||
_linearized_accessibility_tree = previous_obs["accessibility_tree"]
|
_linearized_accessibility_tree = previous_obs["accessibility_tree"]
|
||||||
logger.debug("LINEAR AT: %s", _linearized_accessibility_tree)
|
logger.debug("LINEAR AT: %s", _linearized_accessibility_tree)
|
||||||
@@ -295,7 +288,7 @@ class PromptAgent:
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
})
|
})
|
||||||
elif self.exp in ["som", "seeact"]:
|
elif self.observation_type in ["som", "seeact"]:
|
||||||
_screenshot = previous_obs["screenshot"]
|
_screenshot = previous_obs["screenshot"]
|
||||||
_linearized_accessibility_tree = previous_obs["accessibility_tree"]
|
_linearized_accessibility_tree = previous_obs["accessibility_tree"]
|
||||||
logger.debug("LINEAR AT: %s", _linearized_accessibility_tree)
|
logger.debug("LINEAR AT: %s", _linearized_accessibility_tree)
|
||||||
@@ -317,7 +310,7 @@ class PromptAgent:
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
})
|
})
|
||||||
elif self.exp == "screenshot":
|
elif self.observation_type == "screenshot":
|
||||||
_screenshot = previous_obs["screenshot"]
|
_screenshot = previous_obs["screenshot"]
|
||||||
|
|
||||||
messages.append({
|
messages.append({
|
||||||
@@ -336,7 +329,7 @@ class PromptAgent:
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
})
|
})
|
||||||
elif self.exp == "a11y_tree":
|
elif self.observation_type == "a11y_tree":
|
||||||
_linearized_accessibility_tree = previous_obs["accessibility_tree"]
|
_linearized_accessibility_tree = previous_obs["accessibility_tree"]
|
||||||
|
|
||||||
messages.append({
|
messages.append({
|
||||||
@@ -350,7 +343,7 @@ class PromptAgent:
|
|||||||
]
|
]
|
||||||
})
|
})
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid experiment type: " + self.exp) # 1}}}
|
raise ValueError("Invalid observation_type type: " + self.observation_type) # 1}}}
|
||||||
|
|
||||||
messages.append({
|
messages.append({
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
@@ -363,11 +356,11 @@ class PromptAgent:
|
|||||||
})
|
})
|
||||||
|
|
||||||
# {{{1
|
# {{{1
|
||||||
if self.exp in ["screenshot", "both"]:
|
if self.observation_type in ["screenshot", "both"]:
|
||||||
base64_image = encode_image(obs["screenshot"])
|
base64_image = encode_image(obs["screenshot"])
|
||||||
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])
|
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])
|
||||||
|
|
||||||
if self.exp == "both":
|
if self.observation_type == "both":
|
||||||
self.observations.append({
|
self.observations.append({
|
||||||
"screenshot": base64_image,
|
"screenshot": base64_image,
|
||||||
"accessibility_tree": linearized_accessibility_tree
|
"accessibility_tree": linearized_accessibility_tree
|
||||||
@@ -384,7 +377,7 @@ class PromptAgent:
|
|||||||
{
|
{
|
||||||
"type": "text",
|
"type": "text",
|
||||||
"text": "Given the screenshot as below. What's the next step that you will do to help with the task?"
|
"text": "Given the screenshot as below. What's the next step that you will do to help with the task?"
|
||||||
if self.exp == "screenshot"
|
if self.observation_type == "screenshot"
|
||||||
else "Given the screenshot and info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format(
|
else "Given the screenshot and info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format(
|
||||||
linearized_accessibility_tree)
|
linearized_accessibility_tree)
|
||||||
},
|
},
|
||||||
@@ -397,7 +390,7 @@ class PromptAgent:
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
})
|
})
|
||||||
elif self.exp == "a11y_tree":
|
elif self.observation_type == "a11y_tree":
|
||||||
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])
|
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])
|
||||||
|
|
||||||
self.observations.append({
|
self.observations.append({
|
||||||
@@ -415,7 +408,7 @@ class PromptAgent:
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
})
|
})
|
||||||
elif self.exp == "som":
|
elif self.observation_type == "som":
|
||||||
# Add som to the screenshot
|
# Add som to the screenshot
|
||||||
masks, drew_nodes, tagged_screenshot = tag_screenshot(obs["screenshot"], obs["accessibility_tree"])
|
masks, drew_nodes, tagged_screenshot = tag_screenshot(obs["screenshot"], obs["accessibility_tree"])
|
||||||
base64_image = encode_image(tagged_screenshot)
|
base64_image = encode_image(tagged_screenshot)
|
||||||
@@ -443,7 +436,7 @@ class PromptAgent:
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
})
|
})
|
||||||
elif self.exp == "seeact":
|
elif self.observation_type == "seeact":
|
||||||
# Add som to the screenshot
|
# Add som to the screenshot
|
||||||
masks, drew_nodes, tagged_screenshot = tag_screenshot(obs["screenshot"], obs["accessibility_tree"])
|
masks, drew_nodes, tagged_screenshot = tag_screenshot(obs["screenshot"], obs["accessibility_tree"])
|
||||||
base64_image = encode_image(tagged_screenshot)
|
base64_image = encode_image(tagged_screenshot)
|
||||||
@@ -471,21 +464,21 @@ class PromptAgent:
|
|||||||
]
|
]
|
||||||
})
|
})
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid experiment type: " + self.exp) # 1}}}
|
raise ValueError("Invalid observation_type type: " + self.observation_type) # 1}}}
|
||||||
|
|
||||||
with open("messages.json", "w") as f:
|
|
||||||
f.write(json.dumps(messages, indent=4))
|
|
||||||
|
|
||||||
|
# with open("messages.json", "w") as f:
|
||||||
|
# f.write(json.dumps(messages, indent=4))
|
||||||
|
|
||||||
|
logger.info("Generating content with GPT model: %s", self.model)
|
||||||
response = self.call_llm({
|
response = self.call_llm({
|
||||||
"model": self.model,
|
"model": self.model,
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
"max_tokens": self.max_tokens
|
"max_tokens": self.max_tokens
|
||||||
})
|
})
|
||||||
|
|
||||||
logger.debug("RESPONSE: %s", response)
|
logger.info("RESPONSE: %s", response)
|
||||||
|
|
||||||
if self.exp == "seeact":
|
if self.observation_type == "seeact":
|
||||||
messages.append({
|
messages.append({
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"content": [
|
"content": [
|
||||||
@@ -507,12 +500,15 @@ class PromptAgent:
|
|||||||
]
|
]
|
||||||
})
|
})
|
||||||
|
|
||||||
|
logger.info("Generating content with GPT model: %s", self.model)
|
||||||
response = self.call_llm({
|
response = self.call_llm({
|
||||||
"model": self.model,
|
"model": self.model,
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
"max_tokens": self.max_tokens
|
"max_tokens": self.max_tokens,
|
||||||
|
"top_p": self.top_p,
|
||||||
|
"temperature": self.temperature
|
||||||
})
|
})
|
||||||
print(response)
|
logger.info("RESPONSE: %s", response)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
actions = self.parse_actions(response, masks)
|
actions = self.parse_actions(response, masks)
|
||||||
@@ -527,85 +523,90 @@ class PromptAgent:
|
|||||||
@backoff.on_exception(
|
@backoff.on_exception(
|
||||||
backoff.expo,
|
backoff.expo,
|
||||||
(Exception),
|
(Exception),
|
||||||
max_tries=10
|
max_tries=5
|
||||||
)
|
)
|
||||||
def call_llm(self, payload):
|
def call_llm(self, payload):
|
||||||
|
|
||||||
if self.model.startswith("gpt"):
|
if self.model.startswith("gpt"):
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}"
|
||||||
|
}
|
||||||
logger.info("Generating content with GPT model: %s", self.model)
|
logger.info("Generating content with GPT model: %s", self.model)
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
"https://api.openai.com/v1/chat/completions",
|
"https://api.openai.com/v1/chat/completions",
|
||||||
headers=self.headers,
|
headers=headers,
|
||||||
json=payload
|
json=payload
|
||||||
)
|
)
|
||||||
|
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
if response.json()['error']['code'] == "context_length_exceeded":
|
if response.json()['error']['code'] == "context_length_exceeded":
|
||||||
print("Context length exceeded. Retrying with a smaller context.")
|
logger.error("Context length exceeded. Retrying with a smaller context.")
|
||||||
payload["messages"] = payload["messages"][-1:]
|
payload["messages"] = payload["messages"][-1:]
|
||||||
retry_response = requests.post(
|
retry_response = requests.post(
|
||||||
"https://api.openai.com/v1/chat/completions",
|
"https://api.openai.com/v1/chat/completions",
|
||||||
headers=self.headers,
|
headers=headers,
|
||||||
json=payload
|
json=payload
|
||||||
)
|
)
|
||||||
if retry_response.status_code != 200:
|
if retry_response.status_code != 200:
|
||||||
print("Failed to call LLM: " + retry_response.text)
|
logger.error("Failed to call LLM: " + retry_response.text)
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
print("Failed to call LLM: " + response.text)
|
logger.error("Failed to call LLM: " + response.text)
|
||||||
time.sleep(5)
|
time.sleep(5)
|
||||||
return ""
|
return ""
|
||||||
else:
|
else:
|
||||||
return response.json()['choices'][0]['message']['content']
|
return response.json()['choices'][0]['message']['content']
|
||||||
|
|
||||||
elif self.model.startswith("mistral"):
|
# elif self.model.startswith("mistral"):
|
||||||
print("call mistral")
|
# print("Call mistral")
|
||||||
messages = payload["messages"]
|
# messages = payload["messages"]
|
||||||
max_tokens = payload["max_tokens"]
|
# max_tokens = payload["max_tokens"]
|
||||||
|
#
|
||||||
misrtal_messages = []
|
# misrtal_messages = []
|
||||||
|
#
|
||||||
for i, message in enumerate(messages):
|
# for i, message in enumerate(messages):
|
||||||
mistral_message = {
|
# mistral_message = {
|
||||||
"role": message["role"],
|
# "role": message["role"],
|
||||||
"content": []
|
# "content": []
|
||||||
}
|
# }
|
||||||
|
#
|
||||||
for part in message["content"]:
|
# for part in message["content"]:
|
||||||
mistral_message['content'] = part['text'] if part['type'] == "text" else None
|
# mistral_message['content'] = part['text'] if part['type'] == "text" else None
|
||||||
|
#
|
||||||
misrtal_messages.append(mistral_message)
|
# misrtal_messages.append(mistral_message)
|
||||||
|
#
|
||||||
# the mistral not support system message in our endpoint, so we concatenate it at the first user message
|
# # the mistral not support system message in our endpoint, so we concatenate it at the first user message
|
||||||
if misrtal_messages[0]['role'] == "system":
|
# if misrtal_messages[0]['role'] == "system":
|
||||||
misrtal_messages[1]['content'] = misrtal_messages[0]['content'] + "\n" + misrtal_messages[1]['content']
|
# misrtal_messages[1]['content'] = misrtal_messages[0]['content'] + "\n" + misrtal_messages[1]['content']
|
||||||
misrtal_messages.pop(0)
|
# misrtal_messages.pop(0)
|
||||||
|
#
|
||||||
# openai.api_base = "http://localhost:8000/v1"
|
# # openai.api_base = "http://localhost:8000/v1"
|
||||||
# openai.api_key = "test"
|
# # openai.api_key = "test"
|
||||||
# response = openai.ChatCompletion.create(
|
# # response = openai.ChatCompletion.create(
|
||||||
# messages=misrtal_messages,
|
# # messages=misrtal_messages,
|
||||||
# model="Mixtral-8x7B-Instruct-v0.1"
|
# # model="Mixtral-8x7B-Instruct-v0.1"
|
||||||
# )
|
# # )
|
||||||
|
#
|
||||||
from openai import OpenAI
|
# from openai import OpenAI
|
||||||
TOGETHER_API_KEY = "d011650e7537797148fb6170ec1e0be7ae75160375686fae02277136078e90d2"
|
# TOGETHER_API_KEY = "d011650e7537797148fb6170ec1e0be7ae75160375686fae02277136078e90d2"
|
||||||
|
#
|
||||||
client = OpenAI(api_key=TOGETHER_API_KEY,
|
# client = OpenAI(api_key=TOGETHER_API_KEY,
|
||||||
base_url='https://api.together.xyz',
|
# base_url='https://api.together.xyz',
|
||||||
)
|
# )
|
||||||
logger.info("Generating content with Mistral model: %s", self.model)
|
# logger.info("Generating content with Mistral model: %s", self.model)
|
||||||
response = client.chat.completions.create(
|
# response = client.chat.completions.create(
|
||||||
messages=misrtal_messages,
|
# messages=misrtal_messages,
|
||||||
model="mistralai/Mixtral-8x7B-Instruct-v0.1",
|
# model="mistralai/Mixtral-8x7B-Instruct-v0.1",
|
||||||
max_tokens=1024
|
# max_tokens=1024
|
||||||
)
|
# )
|
||||||
|
#
|
||||||
try:
|
# try:
|
||||||
# return response['choices'][0]['message']['content']
|
# # return response['choices'][0]['message']['content']
|
||||||
return response.choices[0].message.content
|
# return response.choices[0].message.content
|
||||||
except Exception as e:
|
# except Exception as e:
|
||||||
print("Failed to call LLM: " + str(e))
|
# print("Failed to call LLM: " + str(e))
|
||||||
return ""
|
# return ""
|
||||||
|
|
||||||
elif self.model.startswith("gemini"):
|
elif self.model.startswith("gemini"):
|
||||||
def encoded_img_to_pil_img(data_str):
|
def encoded_img_to_pil_img(data_str):
|
||||||
@@ -617,6 +618,8 @@ class PromptAgent:
|
|||||||
|
|
||||||
messages = payload["messages"]
|
messages = payload["messages"]
|
||||||
max_tokens = payload["max_tokens"]
|
max_tokens = payload["max_tokens"]
|
||||||
|
top_p = payload["top_p"]
|
||||||
|
temperature = payload["temperature"]
|
||||||
|
|
||||||
gemini_messages = []
|
gemini_messages = []
|
||||||
for i, message in enumerate(messages):
|
for i, message in enumerate(messages):
|
||||||
@@ -662,7 +665,17 @@ class PromptAgent:
|
|||||||
response = genai.GenerativeModel(self.model).generate_content(
|
response = genai.GenerativeModel(self.model).generate_content(
|
||||||
gemini_messages,
|
gemini_messages,
|
||||||
generation_config={
|
generation_config={
|
||||||
"max_output_tokens": max_tokens
|
"candidate_count": 1,
|
||||||
|
"max_output_tokens": max_tokens,
|
||||||
|
"top_p": top_p,
|
||||||
|
"temperature": temperature
|
||||||
|
},
|
||||||
|
safety_settings={
|
||||||
|
HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE,
|
||||||
|
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
|
||||||
|
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
|
||||||
|
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
|
||||||
|
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -673,6 +686,8 @@ class PromptAgent:
|
|||||||
elif self.model.startswith("qwen"):
|
elif self.model.startswith("qwen"):
|
||||||
messages = payload["messages"]
|
messages = payload["messages"]
|
||||||
max_tokens = payload["max_tokens"]
|
max_tokens = payload["max_tokens"]
|
||||||
|
top_p = payload["top_p"]
|
||||||
|
temperature = payload["temperature"]
|
||||||
|
|
||||||
qwen_messages = []
|
qwen_messages = []
|
||||||
|
|
||||||
@@ -683,13 +698,16 @@ class PromptAgent:
|
|||||||
}
|
}
|
||||||
assert len(message["content"]) in [1, 2], "One text, or one text with one image"
|
assert len(message["content"]) in [1, 2], "One text, or one text with one image"
|
||||||
for part in message["content"]:
|
for part in message["content"]:
|
||||||
qwen_message['content'].append({"image": part['image_url']['url']}) if part['type'] == "image_url" else None
|
qwen_message['content'].append({"image": part['image_url']['url']}) if part[
|
||||||
|
'type'] == "image_url" else None
|
||||||
qwen_message['content'].append({"text": part['text']}) if part['type'] == "text" else None
|
qwen_message['content'].append({"text": part['text']}) if part['type'] == "text" else None
|
||||||
|
|
||||||
qwen_messages.append(qwen_message)
|
qwen_messages.append(qwen_message)
|
||||||
|
|
||||||
response = dashscope.MultiModalConversation.call(model='qwen-vl-plus',
|
response = dashscope.MultiModalConversation.call(
|
||||||
messages=messages)
|
model='qwen-vl-plus',
|
||||||
|
messages=messages, # todo: add the hyperparameters
|
||||||
|
)
|
||||||
# The response status_code is HTTPStatus.OK indicate success,
|
# The response status_code is HTTPStatus.OK indicate success,
|
||||||
# otherwise indicate request is failed, you can get error code
|
# otherwise indicate request is failed, you can get error code
|
||||||
# and message from code and message.
|
# and message from code and message.
|
||||||
@@ -708,7 +726,7 @@ class PromptAgent:
|
|||||||
|
|
||||||
def parse_actions(self, response: str, masks=None):
|
def parse_actions(self, response: str, masks=None):
|
||||||
|
|
||||||
if self.exp in ["screenshot", "a11y_tree", "both"]:
|
if self.observation_type in ["screenshot", "a11y_tree", "both"]:
|
||||||
# parse from the response
|
# parse from the response
|
||||||
if self.action_space == "computer_13":
|
if self.action_space == "computer_13":
|
||||||
actions = parse_actions_from_string(response)
|
actions = parse_actions_from_string(response)
|
||||||
@@ -720,7 +738,7 @@ class PromptAgent:
|
|||||||
self.actions.append(actions)
|
self.actions.append(actions)
|
||||||
|
|
||||||
return actions
|
return actions
|
||||||
elif self.exp in ["som", "seeact"]:
|
elif self.observation_type in ["som", "seeact"]:
|
||||||
# parse from the response
|
# parse from the response
|
||||||
if self.action_space == "computer_13":
|
if self.action_space == "computer_13":
|
||||||
raise ValueError("Invalid action space: " + self.action_space)
|
raise ValueError("Invalid action space: " + self.action_space)
|
||||||
@@ -732,3 +750,8 @@ class PromptAgent:
|
|||||||
self.actions.append(actions)
|
self.actions.append(actions)
|
||||||
|
|
||||||
return actions
|
return actions
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.thoughts = []
|
||||||
|
self.actions = []
|
||||||
|
self.observations = []
|
||||||
|
|||||||
218
run.py
Normal file
218
run.py
Normal file
@@ -0,0 +1,218 @@
|
|||||||
|
"""Script to run end-to-end evaluation on the benchmark.
|
||||||
|
Utils and basic architecture credit to https://github.com/web-arena-x/webarena/blob/main/run.py.
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
import datetime
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from desktop_env.envs.desktop_env import DesktopEnv
|
||||||
|
from mm_agents.agent import PromptAgent
|
||||||
|
|
||||||
|
# Logger Configs {{{ #
|
||||||
|
logger = logging.getLogger()
|
||||||
|
logger.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
|
datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
||||||
|
|
||||||
|
file_handler = logging.FileHandler(os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8")
|
||||||
|
debug_handler = logging.FileHandler(os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8")
|
||||||
|
stdout_handler = logging.StreamHandler(sys.stdout)
|
||||||
|
sdebug_handler = logging.FileHandler(os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8")
|
||||||
|
|
||||||
|
file_handler.setLevel(logging.INFO)
|
||||||
|
debug_handler.setLevel(logging.DEBUG)
|
||||||
|
stdout_handler.setLevel(logging.INFO)
|
||||||
|
sdebug_handler.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
|
formatter = logging.Formatter(
|
||||||
|
fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s")
|
||||||
|
file_handler.setFormatter(formatter)
|
||||||
|
debug_handler.setFormatter(formatter)
|
||||||
|
stdout_handler.setFormatter(formatter)
|
||||||
|
sdebug_handler.setFormatter(formatter)
|
||||||
|
|
||||||
|
stdout_handler.addFilter(logging.Filter("desktopenv"))
|
||||||
|
sdebug_handler.addFilter(logging.Filter("desktopenv"))
|
||||||
|
|
||||||
|
logger.addHandler(file_handler)
|
||||||
|
logger.addHandler(debug_handler)
|
||||||
|
logger.addHandler(stdout_handler)
|
||||||
|
logger.addHandler(sdebug_handler)
|
||||||
|
# }}} Logger Configs #
|
||||||
|
|
||||||
|
logger = logging.getLogger("desktopenv.experiment")
|
||||||
|
|
||||||
|
|
||||||
|
def config() -> argparse.Namespace:
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Run end-to-end evaluation on the benchmark"
|
||||||
|
)
|
||||||
|
|
||||||
|
# environment config
|
||||||
|
parser.add_argument("--path_to_vm", type=str,
|
||||||
|
default=r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu\Ubuntu.vmx")
|
||||||
|
parser.add_argument(
|
||||||
|
"--headless", action="store_true", help="Run in headless machine"
|
||||||
|
)
|
||||||
|
parser.add_argument("--action_space", type=str, default="pyautogui", help="Action type")
|
||||||
|
parser.add_argument(
|
||||||
|
"--observation_type",
|
||||||
|
choices=[
|
||||||
|
"screenshot",
|
||||||
|
"a11y_tree",
|
||||||
|
"screenshot_a11y_tree",
|
||||||
|
"som"
|
||||||
|
],
|
||||||
|
default="a11y_tree",
|
||||||
|
help="Observation type",
|
||||||
|
)
|
||||||
|
parser.add_argument("--screen_width", type=int, default=1920)
|
||||||
|
parser.add_argument("--screen_height", type=int, default=1080)
|
||||||
|
parser.add_argument("--sleep_after_execution", type=float, default=0.0)
|
||||||
|
parser.add_argument("--max_steps", type=int, default=15)
|
||||||
|
|
||||||
|
# agent config
|
||||||
|
parser.add_argument("--max_trajectory_length", type=int, default=3)
|
||||||
|
parser.add_argument("--test_config_base_dir", type=str, default="evaluation_examples")
|
||||||
|
|
||||||
|
# lm config
|
||||||
|
parser.add_argument("--model", type=str, default="gpt-4-vision-preview")
|
||||||
|
parser.add_argument("--temperature", type=float, default=1.0)
|
||||||
|
parser.add_argument("--top_p", type=float, default=0.9)
|
||||||
|
parser.add_argument("--max_tokens", type=int, default=1500)
|
||||||
|
parser.add_argument("--stop_token", type=str, default=None)
|
||||||
|
|
||||||
|
# logging related
|
||||||
|
parser.add_argument("--result_dir", type=str, default="./results")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
def test(
|
||||||
|
args: argparse.Namespace,
|
||||||
|
test_all_meta: dict
|
||||||
|
) -> None:
|
||||||
|
scores = []
|
||||||
|
max_steps = args.max_steps
|
||||||
|
|
||||||
|
# log args
|
||||||
|
logger.info("Args: %s", args)
|
||||||
|
|
||||||
|
agent = PromptAgent(
|
||||||
|
model=args.model,
|
||||||
|
max_tokens=args.max_tokens,
|
||||||
|
action_space=args.action_space,
|
||||||
|
observation_type=args.observation_type,
|
||||||
|
max_trajectory_length=args.max_trajectory_length,
|
||||||
|
)
|
||||||
|
|
||||||
|
env = DesktopEnv(
|
||||||
|
path_to_vm=args.path_to_vm,
|
||||||
|
action_space=agent.action_space,
|
||||||
|
screen_size=(args.screen_width, args.screen_height),
|
||||||
|
headless=args.headless,
|
||||||
|
)
|
||||||
|
|
||||||
|
for domain in test_all_meta:
|
||||||
|
for example_id in test_all_meta[domain]:
|
||||||
|
config_file = os.path.join(args.test_config_base_dir, f"examples/{domain}/{example_id}.json")
|
||||||
|
with open(config_file, "r", encoding="utf-8") as f:
|
||||||
|
example = json.load(f)
|
||||||
|
|
||||||
|
logger.info(f"[Domain]: {domain}")
|
||||||
|
logger.info(f"[Example ID]: {example_id}")
|
||||||
|
|
||||||
|
instruction = example["instruction"]
|
||||||
|
|
||||||
|
logger.info(f"[Instruction]: {instruction}")
|
||||||
|
|
||||||
|
example_result_dir = os.path.join(
|
||||||
|
args.result_dir,
|
||||||
|
args.action_space,
|
||||||
|
args.observation_type,
|
||||||
|
args.model,
|
||||||
|
domain,
|
||||||
|
example_id
|
||||||
|
)
|
||||||
|
os.makedirs(example_result_dir, exist_ok=True)
|
||||||
|
|
||||||
|
agent.reset()
|
||||||
|
obs = env.reset(task_config=example)
|
||||||
|
done = False
|
||||||
|
step_idx = 0
|
||||||
|
env.controller.start_recording()
|
||||||
|
|
||||||
|
while not done and step_idx < max_steps:
|
||||||
|
actions = agent.predict(
|
||||||
|
instruction,
|
||||||
|
obs
|
||||||
|
)
|
||||||
|
|
||||||
|
for action in actions:
|
||||||
|
step_idx += 1
|
||||||
|
# Capture the timestamp before executing the action
|
||||||
|
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
||||||
|
logger.info("Step %d: %s", step_idx + 1, action)
|
||||||
|
|
||||||
|
observation, reward, done, info = env.step(action, args.sleep_after_execution)
|
||||||
|
|
||||||
|
logger.info("Reward: %.2f", reward)
|
||||||
|
logger.info("Done: %s", done)
|
||||||
|
logger.info("Info: %s", info)
|
||||||
|
|
||||||
|
# Save screenshot and trajectory information
|
||||||
|
with open(os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"),
|
||||||
|
"wb") as _f:
|
||||||
|
with open(observation['screenshot'], "rb") as __f:
|
||||||
|
screenshot = __f.read()
|
||||||
|
_f.write(screenshot)
|
||||||
|
|
||||||
|
with open(os.path.join(example_result_dir, "traj.json"), "a") as f:
|
||||||
|
f.write(json.dumps({
|
||||||
|
"step_num": step_idx + 1,
|
||||||
|
"action_timestamp": action_timestamp,
|
||||||
|
"action": action,
|
||||||
|
"reward": reward,
|
||||||
|
"done": done,
|
||||||
|
"info": info,
|
||||||
|
"screenshot_file": f"step_{step_idx + 1}_{action_timestamp}.png"
|
||||||
|
}))
|
||||||
|
f.write("\n")
|
||||||
|
|
||||||
|
if done:
|
||||||
|
logger.info("The episode is done.")
|
||||||
|
break
|
||||||
|
|
||||||
|
result = env.evaluate()
|
||||||
|
logger.info("Result: %.2f", result)
|
||||||
|
scores.append(result)
|
||||||
|
env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
|
||||||
|
|
||||||
|
env.close()
|
||||||
|
logger.info(f"Average score: {sum(scores) / len(scores)}")
|
||||||
|
|
||||||
|
|
||||||
|
def get_unfinished(test_file_list, result_dir):
|
||||||
|
finished = []
|
||||||
|
for domain in os.listdir(result_dir):
|
||||||
|
for example_id in os.listdir(os.path.join(result_dir, domain)):
|
||||||
|
finished.append(f"{domain}/{example_id}")
|
||||||
|
return [x for x in test_file_list if x not in finished]
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
####### The complete version of the list of examples #######
|
||||||
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
args = config()
|
||||||
|
|
||||||
|
# test_file_list = get_unfinished(args.test, args.result_dir)
|
||||||
|
# logger.info(f"Total {len(test_file_list)} tasks left")
|
||||||
|
|
||||||
|
with open("evaluation_examples/test_all.json", "r", encoding="utf-8") as f:
|
||||||
|
test_all_meta = json.load(f)
|
||||||
|
|
||||||
|
test(args, test_all_meta)
|
||||||
Reference in New Issue
Block a user