Merge branch 'main' of https://github.com/xlang-ai/DesktopEnv
This commit is contained in:
19
.vscode/launch.json
vendored
Normal file
19
.vscode/launch.json
vendored
Normal file
@@ -0,0 +1,19 @@
|
||||
{
|
||||
// Use IntelliSense to learn about possible attributes.
|
||||
// Hover to view descriptions of existing attributes.
|
||||
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Python Debugger: Current File with Arguments",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "${file}",
|
||||
"console": "integratedTerminal",
|
||||
"args": [
|
||||
"--path_to_vm", "/Users/lxc/Virtual Machines.localized/DesktopEnv-Ubuntu 64-bit Arm.vmwarevm/DesktopEnv-Ubuntu 64-bit Arm.vmx"
|
||||
// "--example_time_limit", "60"
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -21,10 +21,12 @@
|
||||
Please refer to [guidance](https://docs.google.com/document/d/1KBdeZwmZs2Vi_Wsnngb3Wf1-RiwMMpXTftwMqP2Ztak/edit#heading=h.uh0x0tkl7fuw)
|
||||
|
||||
2. Install the environment package, download the examples and the virtual machine image.
|
||||
For x86_64 Linux or Windows, you can install the environment package and download the examples and the virtual machine image by running the following commands:
|
||||
```bash
|
||||
pip install desktop-env
|
||||
gdown xxxx
|
||||
gdown xxxx
|
||||
vmrun -T ws start "Ubuntu/Ubuntu.vmx" nogui
|
||||
vmrun -T ws snapshot "Ubuntu/Ubuntu.vmx" "init_state"
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
@@ -53,8 +53,8 @@ class DesktopEnv(gym.Env):
|
||||
def __init__(
|
||||
self,
|
||||
path_to_vm: str,
|
||||
snapshot_name: str = "init_state",
|
||||
action_space: str = "computer_13",
|
||||
task_config: Dict[str, Any] = None,
|
||||
tmp_dir: str = "tmp",
|
||||
cache_dir: str = "cache",
|
||||
screen_size: Tuple[int] = (1920, 1080),
|
||||
@@ -64,15 +64,6 @@ class DesktopEnv(gym.Env):
|
||||
Args:
|
||||
path_to_vm (str): path to .vmx file
|
||||
action_space (str): "computer_13" | "pyautogui"
|
||||
|
||||
task_config (Dict[str, Any]): manages task configs integratedly,
|
||||
including
|
||||
* base snapshot
|
||||
* task id (uuid)
|
||||
* instruction
|
||||
* setup config
|
||||
* evaluator config
|
||||
|
||||
tmp_dir (str): temporary directory to store trajectory stuffs like
|
||||
the extracted screenshots
|
||||
cache_dir (str): cache directory to cache task-related stuffs like
|
||||
@@ -81,23 +72,20 @@ class DesktopEnv(gym.Env):
|
||||
|
||||
# Initialize environment variables
|
||||
self.path_to_vm = os.path.abspath(os.path.expandvars(os.path.expanduser(path_to_vm)))
|
||||
self.snapshot_name = snapshot_name
|
||||
self.tmp_dir_base: str = tmp_dir
|
||||
self.cache_dir_base: str = cache_dir
|
||||
self.vm_screen_size = screen_size
|
||||
self.vm_screen_size = screen_size # todo: add the logic to get the screen size from the VM
|
||||
self.headless = headless
|
||||
|
||||
os.makedirs(self.tmp_dir_base, exist_ok=True)
|
||||
|
||||
# task-aware stuffs
|
||||
# todo: handling the logic of snapshot directory
|
||||
self._set_task_info(task_config)
|
||||
|
||||
# Initialize emulator and controller
|
||||
logger.info("Initializing...")
|
||||
self._start_emulator()
|
||||
self.vm_ip = self._get_vm_ip()
|
||||
self.controller = PythonController(vm_ip=self.vm_ip)
|
||||
self.setup_controller = SetupController(vm_ip=self.vm_ip, cache_dir=self.cache_dir)
|
||||
self.setup_controller = SetupController(vm_ip=self.vm_ip, cache_dir=self.cache_dir_base)
|
||||
|
||||
# Meta info of the VM, move to the reset() function
|
||||
self.vm_platform: str = "" # self.controller.get_vm_platform()
|
||||
@@ -147,7 +135,7 @@ class DesktopEnv(gym.Env):
|
||||
raise Exception("Failed to get VM IP address!")
|
||||
|
||||
def _save_state(self):
    """Take a VMware snapshot of the VM at ``self.path_to_vm`` named ``self.snapshot_name``."""
    # Every vmrun token must be its own list element. The original wrote
    # "ws" "snapshot" without a comma, which Python concatenates into the
    # single argument "wssnapshot".
    _execute_command(["vmrun", "-T", "ws", "snapshot", self.path_to_vm, self.snapshot_name])
|
||||
|
||||
def _get_screenshot(self):
|
||||
# random_uuid = str(uuid.uuid4())
|
||||
@@ -167,7 +155,6 @@ class DesktopEnv(gym.Env):
|
||||
return screenshot_image_path
|
||||
|
||||
def _set_task_info(self, task_config: Dict[str, Any]):
|
||||
self.snapshot_path = task_config["snapshot"] # todo: save the snapshot when first start the environment, and then revert to it when reset
|
||||
self.task_id: str = task_config["id"]
|
||||
self.cache_dir: str = os.path.join(self.cache_dir_base, self.task_id)
|
||||
os.makedirs(self.cache_dir, exist_ok=True)
|
||||
@@ -187,7 +174,7 @@ class DesktopEnv(gym.Env):
|
||||
if isinstance(self.evaluator["func"], list) \
|
||||
else getattr(metrics, self.evaluator["func"])
|
||||
self.metric_conj: str = self.evaluator.get("conj", "and") # take conjunction of multiple metrics
|
||||
if "result" in self.evaluator:
|
||||
if "result" in self.evaluator and len(self.evaluator["result"])>0:
|
||||
self.result_getter: Getter = [getattr(getters, "get_{:}".format(res["type"])) for res in
|
||||
self.evaluator["result"]] \
|
||||
if isinstance(self.evaluator["result"], list) \
|
||||
@@ -197,7 +184,7 @@ class DesktopEnv(gym.Env):
|
||||
if isinstance(self.metric, list) \
|
||||
else None
|
||||
|
||||
if "expected" in self.evaluator:
|
||||
if "expected" in self.evaluator and len(self.evaluator["expected"])>0:
|
||||
self.expected_getter: Getter = [getattr(getters, "get_{:}".format(exp["type"])) if exp else None for exp in
|
||||
self.evaluator["expected"]] \
|
||||
if isinstance(self.evaluator["expected"], list) \
|
||||
@@ -239,8 +226,8 @@ class DesktopEnv(gym.Env):
|
||||
)
|
||||
os.makedirs(os.path.join(self.tmp_dir, "screenshots"))
|
||||
|
||||
logger.info("Reverting to snapshot to {}...".format(self.snapshot_path))
|
||||
_execute_command(["vmrun", "-T", "ws", "revertToSnapshot", self.path_to_vm, self.snapshot_path])
|
||||
logger.info("Reverting to snapshot to {}...".format(self.snapshot_name))
|
||||
_execute_command(["vmrun", "-T", "ws", "revertToSnapshot", self.path_to_vm, self.snapshot_name])
|
||||
time.sleep(5)
|
||||
|
||||
print(self.vm_screen_size)
|
||||
|
||||
@@ -284,6 +284,15 @@ def _create_atspi_node(node: Accessible, depth: int = 0, flag: Optional[str] = N
|
||||
text = text.replace("\ufffc", "").replace("\ufffd", "")
|
||||
# }}} Text #
|
||||
|
||||
# Image {{{ #
|
||||
try:
|
||||
node.queryImage()
|
||||
except NotImplementedError:
|
||||
pass
|
||||
else:
|
||||
attribute_dict["image"] = "true"
|
||||
# }}} Image #
|
||||
|
||||
# Selection {{{ #
|
||||
try:
|
||||
node.querySelection()
|
||||
|
||||
16
desktop_env/server/osbench_server.service
Normal file
16
desktop_env/server/osbench_server.service
Normal file
@@ -0,0 +1,16 @@
|
||||
[Unit]
|
||||
Description=OSBench Server
|
||||
StartLimitIntervalSec=60
|
||||
StartLimitBurst=4
|
||||
After=network.target auditd.service
|
||||
|
||||
[Service]
|
||||
ExecStart=/usr/bin/python3 /home/user/main.py
|
||||
User=user
|
||||
WorkingDirectory=/home/user
|
||||
Restart=on-failure
|
||||
RestartSec=1
|
||||
Environment="DISPLAY=:1"
|
||||
|
||||
[Install]
|
||||
WantedBy=graphical.target
|
||||
16
desktop_env/server/osbench_server@.service
Normal file
16
desktop_env/server/osbench_server@.service
Normal file
@@ -0,0 +1,16 @@
|
||||
[Unit]
|
||||
Description=OSBench Server
|
||||
StartLimitIntervalSec=60
|
||||
StartLimitBurst=4
|
||||
After=network.target auditd.service
|
||||
|
||||
[Service]
|
||||
ExecStart=/usr/bin/python3 /home/user/main.py
|
||||
User=user
|
||||
WorkingDirectory=/home/user
|
||||
Restart=on-failure
|
||||
RestartSec=1
|
||||
Environment="DISPLAY=%i"
|
||||
|
||||
[Install]
|
||||
WantedBy=graphical.target
|
||||
@@ -10,10 +10,6 @@
|
||||
"libreoffice_calc"
|
||||
],
|
||||
"evaluator": {
|
||||
"func": "infeasible",
|
||||
"expected": {
|
||||
},
|
||||
"result": {
|
||||
}
|
||||
"func": "infeasible"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,10 +10,6 @@
|
||||
"libreoffice_calc"
|
||||
],
|
||||
"evaluator": {
|
||||
"func": "infeasible",
|
||||
"expected": {
|
||||
},
|
||||
"result": {
|
||||
}
|
||||
"func": "infeasible"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
19
evaluation_examples/examples/multi_apps/demo.py
Normal file
19
evaluation_examples/examples/multi_apps/demo.py
Normal file
@@ -0,0 +1,19 @@
|
||||
import pandas as pd
|
||||
|
||||
file_path = "/Users/lxc/Downloads/Speedtest.csv"

# Earlier experiment: find the value in the second cell of the second row of the csv
# with open(file_path, "r") as f:
#     for i, line in enumerate(f):
#         if i == 1:
#             data = line.split(",")[1]
#             break
# print(data)

# Print every value of the TEST_DATE column.
# The column is addressed by name, so the first CSV row must be parsed as the
# header. With header=None pandas labels the columns 0..N-1 and
# reader['TEST_DATE'] raises KeyError.
with open(file_path, "r") as f:
    reader = pd.read_csv(f, sep=',')

# Alternative lookup when the column name only starts with "TEST_DATE":
# for column in reader.columns:
#     if column.startswith("TEST_DATE"):
#         data_col = column
#         break
for data in reader['TEST_DATE']:
    print(data)
|
||||
@@ -103,7 +103,6 @@
|
||||
"1e8df695-bd1b-45b3-b557-e7d599cf7597",
|
||||
"ecb0df7a-4e8d-4a03-b162-053391d3afaf",
|
||||
"8b1ce5f2-59d2-4dcc-b0b0-666a714b9a14",
|
||||
"7b802dad-6e0f-4204-9815-d4e3f57627d8",
|
||||
"a01fbce3-2793-461f-ab86-43680ccbae25",
|
||||
"0326d92d-d218-48a8-9ca1-981cd6d064c7",
|
||||
"0a2e43bf-b26c-4631-a966-af9dfa12c9e5",
|
||||
@@ -380,7 +379,6 @@
|
||||
"9439a27b-18ae-42d8-9778-5f68f891805e",
|
||||
"ae506c68-352c-4094-9caa-ee9d42052317",
|
||||
"ea98c5d7-3cf9-4f9b-8ad3-366b58e0fcae",
|
||||
"c714dcee-cad3-4e12-8f3c-12bdcfcdb048",
|
||||
"930fdb3b-11a8-46fe-9bac-577332e2640e",
|
||||
"276cc624-87ea-4f08-ab93-f770e3790175",
|
||||
"9d425400-e9b2-4424-9a4b-d4c7abac4140",
|
||||
|
||||
102
evaluation_examples/test_small.json
Normal file
102
evaluation_examples/test_small.json
Normal file
@@ -0,0 +1,102 @@
|
||||
{
|
||||
"chrome": [
|
||||
"bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
|
||||
"7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3"
|
||||
],
|
||||
"gimp": [
|
||||
"7a4deb26-d57d-4ea9-9a73-630f66a7b568",
|
||||
"554785e9-4523-4e7a-b8e1-8016f565f56a"
|
||||
],
|
||||
"libreoffice_calc": [
|
||||
"357ef137-7eeb-4c80-a3bb-0951f26a8aff",
|
||||
"42e0a640-4f19-4b28-973d-729602b5a4a7"
|
||||
],
|
||||
"libreoffice_impress": [
|
||||
"5d901039-a89c-4bfb-967b-bf66f4df075e",
|
||||
"550ce7e7-747b-495f-b122-acdc4d0b8e54"
|
||||
],
|
||||
"libreoffice_writer": [
|
||||
"0810415c-bde4-4443-9047-d5f70165a697",
|
||||
"0a0faba3-5580-44df-965d-f562a99b291c"
|
||||
],
|
||||
"multi_apps": [
|
||||
"2b9493d7-49b8-493a-a71b-56cd1f4d6908",
|
||||
"46407397-a7d5-4c6b-92c6-dbe038b1457b",
|
||||
"4e9f0faf-2ecc-4ae8-a804-28c9a75d1ddc",
|
||||
"510f64c8-9bcc-4be1-8d30-638705850618",
|
||||
"897e3b53-5d4d-444b-85cb-2cdc8a97d903",
|
||||
"c867c42d-a52d-4a24-8ae3-f75d256b5618",
|
||||
"e135df7c-7687-4ac0-a5f0-76b74438b53e",
|
||||
"f7dfbef3-7697-431c-883a-db8583a4e4f9",
|
||||
"6d72aad6-187a-4392-a4c4-ed87269c51cf",
|
||||
"f918266a-b3e0-4914-865d-4faa564f1aef",
|
||||
"da52d699-e8d2-4dc5-9191-a2199e0b6a9b",
|
||||
"74d5859f-ed66-4d3e-aa0e-93d7a592ce41",
|
||||
"b5062e3e-641c-4e3a-907b-ac864d2e7652",
|
||||
"48d05431-6cd5-4e76-82eb-12b60d823f7d",
|
||||
"eb303e01-261e-4972-8c07-c9b4e7a4922a",
|
||||
"d1acdb87-bb67-4f30-84aa-990e56a09c92",
|
||||
"deec51c9-3b1e-4b9e-993c-4776f20e8bb2",
|
||||
"8e116af7-7db7-4e35-a68b-b0939c066c78",
|
||||
"185f29bd-5da0-40a6-b69c-ba7f4e0324ef",
|
||||
"2c1ebcd7-9c6d-4c9a-afad-900e381ecd5e",
|
||||
"3a93cae4-ad3e-403e-8c12-65303b271818",
|
||||
"1f18aa87-af6f-41ef-9853-cdb8f32ebdea",
|
||||
"26150609-0da3-4a7d-8868-0faf9c5f01bb",
|
||||
"7e287123-70ca-47b9-8521-47db09b69b14",
|
||||
"e2392362-125e-4f76-a2ee-524b183a3412",
|
||||
"26660ad1-6ebb-4f59-8cba-a8432dfe8d38",
|
||||
"a82b78bb-7fde-4cb3-94a4-035baf10bcf0",
|
||||
"36037439-2044-4b50-b9d1-875b5a332143",
|
||||
"716a6079-22da-47f1-ba73-c9d58f986a38",
|
||||
"a74b607e-6bb5-4ea8-8a7c-5d97c7bbcd2a",
|
||||
"6f4073b8-d8ea-4ade-8a18-c5d1d5d5aa9a",
|
||||
"da922383-bfa4-4cd3-bbad-6bebab3d7742",
|
||||
"2373b66a-092d-44cb-bfd7-82e86e7a3b4d",
|
||||
"81c425f5-78f3-4771-afd6-3d2973825947",
|
||||
"227d2f97-562b-4ccb-ae47-a5ec9e142fbb",
|
||||
"20236825-b5df-46e7-89bf-62e1d640a897",
|
||||
"02ce9a50-7af2-47ed-8596-af0c230501f8",
|
||||
"4c26e3f3-3a14-4d86-b44a-d3cedebbb487",
|
||||
"09a37c51-e625-49f4-a514-20a773797a8a",
|
||||
"3e3fc409-bff3-4905-bf16-c968eee3f807",
|
||||
"415ef462-bed3-493a-ac36-ca8c6d23bf1b",
|
||||
"9f3bb592-209d-43bc-bb47-d77d9df56504",
|
||||
"dd60633f-2c72-42ba-8547-6f2c8cb0fdb0",
|
||||
"3f05f3b9-29ba-4b6b-95aa-2204697ffc06",
|
||||
"f8369178-fafe-40c2-adc4-b9b08a125456",
|
||||
"778efd0a-153f-4842-9214-f05fc176b877",
|
||||
"47f7c0ce-a5fb-4100-a5e6-65cd0e7429e5",
|
||||
"c2751594-0cd5-4088-be1b-b5f2f9ec97c4",
|
||||
"48c46dc7-fe04-4505-ade7-723cba1aa6f6",
|
||||
"42d25c08-fb87-4927-8b65-93631280a26f",
|
||||
"bb7db4c2-30b5-4be7-8dd7-b8c4ec7d3108",
|
||||
"3c8f201a-009d-4bbe-8b65-a6f8b35bb57f",
|
||||
"d68204bf-11c1-4b13-b48b-d303c73d4bf6",
|
||||
"91190194-f406-4cd6-b3f9-c43fac942b22",
|
||||
"7f35355e-02a6-45b5-b140-f0be698bcf85",
|
||||
"98e8e339-5f91-4ed2-b2b2-12647cb134f4",
|
||||
"df67aebb-fb3a-44fd-b75b-51b6012df509",
|
||||
"5df7b33a-9f77-4101-823e-02f863e1c1ae",
|
||||
"22a4636f-8179-4357-8e87-d1743ece1f81",
|
||||
"236833a3-5704-47fc-888c-4f298f09f799"
|
||||
],
|
||||
"os": [
|
||||
"5ea617a3-0e86-4ba6-aab2-dac9aa2e8d57",
|
||||
"5812b315-e7bd-4265-b51f-863c02174c28",
|
||||
"43c2d64c-bab5-4dcb-a30c-b888321c319a",
|
||||
"7688b85f-87a4-4e4a-b2f8-f3d6c3f29b82"
|
||||
],
|
||||
"thunderbird": [
|
||||
"bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
|
||||
"7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3"
|
||||
],
|
||||
"vlc": [
|
||||
"59f21cfb-0120-4326-b255-a5b827b38967",
|
||||
"8f080098-ddb1-424c-b438-4e96e5e4786e"
|
||||
],
|
||||
"vs_code": [
|
||||
"0ed39f63-6049-43d4-ba4d-5fa2fe04a951",
|
||||
"53ad5833-3455-407b-bbc6-45b4c79ab8fb"
|
||||
]
|
||||
}
|
||||
@@ -1,432 +0,0 @@
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
import func_timeout
|
||||
|
||||
from desktop_env.envs.desktop_env import DesktopEnv
|
||||
from mm_agents.gpt_4v_agent import GPT4v_Agent
|
||||
|
||||
#  Logger Configs {{{ #
# Configure the root logger: everything at DEBUG and above reaches the
# handlers, and each handler then applies its own level / name filter.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

# Timestamp shared by all log files created by this run.
datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")

# logging.FileHandler does not create missing parent directories; without
# this the script dies with FileNotFoundError on a fresh checkout.
os.makedirs("logs", exist_ok=True)

file_handler = logging.FileHandler(os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8")
debug_handler = logging.FileHandler(os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8")
stdout_handler = logging.StreamHandler(sys.stdout)
sdebug_handler = logging.FileHandler(os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8")

file_handler.setLevel(logging.INFO)
debug_handler.setLevel(logging.DEBUG)
stdout_handler.setLevel(logging.INFO)
sdebug_handler.setLevel(logging.DEBUG)

formatter = logging.Formatter(
    fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s")
file_handler.setFormatter(formatter)
debug_handler.setFormatter(formatter)
stdout_handler.setFormatter(formatter)
sdebug_handler.setFormatter(formatter)

# stdout and the "sdebug" file only receive records from the "desktopenv"
# logger hierarchy; the other two files receive records from every logger.
stdout_handler.addFilter(logging.Filter("desktopenv"))
sdebug_handler.addFilter(logging.Filter("desktopenv"))

logger.addHandler(file_handler)
logger.addHandler(debug_handler)
logger.addHandler(stdout_handler)
logger.addHandler(sdebug_handler)
#  }}} Logger Configs #

# From here on `logger` is this script's own logger, not the root logger.
logger = logging.getLogger("desktopenv.experiment")

# Path of the .vmx file handed to DesktopEnv (machine-specific).
PATH_TO_VM = r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu\Ubuntu.vmx"
|
||||
|
||||
|
||||
def run_one_example(example, agent, max_steps=10, example_trajectory_dir="exp_trajectory", recording=True):
    """Run one task episode in a new ``DesktopEnv`` and log its trajectory.

    Args:
        example: task configuration, passed to ``DesktopEnv`` as ``task_config``.
        agent: object exposing ``action_space`` and ``predict(observation)``;
            ``predict`` must return an iterable of actions.
        max_steps: maximum number of ``agent.predict`` rounds.
        example_trajectory_dir: directory that receives the per-action
            screenshots, ``trajectory.json`` and ``recording.mp4``. It is not
            created here and must already exist.
        recording: when true, screen recording is started before the first
            step and stopped (with a 30 s timeout) after the last one.
    """
    trajectory_recording_path = os.path.join(example_trajectory_dir, "trajectory.json")
    env = DesktopEnv(
        path_to_vm=PATH_TO_VM,
        action_space=agent.action_space,
        task_config=example
    )
    # reset the environment to certain snapshot
    observation = env.reset()
    done = False
    step_num = 0

    if recording:
        # send a request to the server to start recording
        env.controller.start_recording()

    while not done and step_num < max_steps:
        actions = agent.predict(observation)
        step_num += 1
        for action in actions:
            # Capture the timestamp before executing the action
            action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
            logger.info("Step %d: %s", step_num, action)

            observation, reward, done, info = env.step(action)

            logger.info("Reward: %.2f", reward)
            logger.info("Done: %s", done)
            logger.info("Info: %s", info)

            # Save screenshot and trajectory information.
            # Read the source first so that a missing screenshot does not
            # leave an empty destination file behind.
            screenshot_file = f"step_{step_num}_{action_timestamp}.png"
            with open(observation['screenshot'], "rb") as src:
                screenshot = src.read()
            with open(os.path.join(example_trajectory_dir, screenshot_file), "wb") as dst:
                dst.write(screenshot)

            # trajectory.json is a JSON-lines file: one object per action.
            with open(trajectory_recording_path, "a") as f:
                f.write(json.dumps({
                    "step_num": step_num,
                    "action_timestamp": action_timestamp,
                    "action": action,
                    "reward": reward,
                    "done": done,
                    "info": info,
                    "screenshot_file": screenshot_file
                }))
                f.write("\n")

            if done:
                # Remaining actions of this prediction are dropped; the
                # enclosing while loop then stops on `done`.
                logger.info("The episode is done.")
                break

    def stop_recording():
        # Best effort: a failure to fetch the recording must not abort the
        # evaluation that follows.
        try:
            env.controller.end_recording(os.path.join(example_trajectory_dir, "recording.mp4"))
        except Exception as e:
            logger.error("An error occurred while stopping the recording: %s", e)

    # Only stop a recording that was actually started.
    if recording:
        try:
            func_timeout.func_timeout(30, stop_recording)
        except func_timeout.exceptions.FunctionTimedOut:
            logger.info("Recording timed out.")

    result = env.evaluate()
    logger.info("Result: %.2f", result)

    # The final line of trajectory.json carries the evaluation result.
    with open(trajectory_recording_path, "a") as f:
        f.write(json.dumps({
            "result": result
        }))
        f.write("\n")

    # NOTE(review): env.close() is disabled, so the VM keeps running even
    # though the message below says otherwise. Confirm this is intended.
    # env.close()
    logger.info("Environment closed.")
|
||||
|
||||
|
||||
def main(example_class, example_id, gpt4_model="gpt-4-0125-preview"):
    """Load one evaluation example and run it with a ``GPT4v_Agent``.

    Args:
        example_class: sub-directory of ``evaluation_examples/examples``
            that contains the example.
        example_id: file stem of the example's JSON config.
        gpt4_model: model name given to the agent; also used as a path
            component of the trajectory directory.
    """
    action_space = "pyautogui"
    gemini_model = "gemini-pro-vision"

    logger.info("Running example %s/%s", example_class, example_id)
    logger.info("Using model %s", gpt4_model)
    # logger.info("Using model %s", gemini_model)

    config_path = f"evaluation_examples/examples/{example_class}/{example_id}.json"
    with open(config_path, "r", encoding="utf-8") as config_file:
        task_config = json.load(config_file)
    # Every example is run against the same snapshot, whatever its config says.
    task_config["snapshot"] = "exp_v5"

    openai_key = os.environ.get("OPENAI_API_KEY")
    agent = GPT4v_Agent(
        api_key=openai_key,
        model=gpt4_model,
        instruction=task_config['instruction'],
        max_tokens=1000,
        action_space=action_space,
        exp="a11y_tree",
    )

    # Gemini alternative (disabled):
    # api_key = os.environ.get("GENAI_API_KEY")
    # agent = GeminiPro_Agent(api_key=api_key, model=gemini_model, instruction=example['instruction'], action_space=action_space, exp="a11y_tree")

    # Layout: exp_trajectory/a11y_tree/<class>/<model>/<example id>
    # (for Gemini, use gemini_model in place of gpt4_model)
    example_trajectory_dir = os.path.join(
        "exp_trajectory", "a11y_tree", example_class, gpt4_model, example_id
    )
    os.makedirs(example_trajectory_dir, exist_ok=True)

    run_one_example(task_config, agent, 15, example_trajectory_dir)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Script entry point. Several per-application lists of example ids are
    # defined below, but only two loops are active: the one over
    # `vs_code_list` (empty by the time it runs) and the one over
    # `multiple_list`. All other loops are commented out.
    os_list = [
        "94d95f96-9699-4208-98ba-3c3119edf9c2",
        "bedcedc4-4d72-425e-ad62-21960b11fe0d",
        "43c2d64c-bab5-4dcb-a30c-b888321c319a",
        "7688b85f-87a4-4e4a-b2f8-f3d6c3f29b82",
        "ec4e3f68-9ea4-4c18-a5c9-69f89d1178b3",
        "f9be0997-4b7c-45c5-b05c-4612b44a6118",
        "28cc3b7e-b194-4bc9-8353-d04c0f4d56d2",
        "5ea617a3-0e86-4ba6-aab2-dac9aa2e8d57",
        "e0df059f-28a6-4169-924f-b9623e7184cc",
        "ddc75b62-7311-4af8-bfb3-859558542b36",
        "b6781586-6346-41cd-935a-a6b1487918fc",
        "3ce045a0-877b-42aa-8d2c-b4a863336ab8",
        "a4d98375-215b-4a4d-aee9-3d4370fccc41",
        "13584542-872b-42d8-b299-866967b5c3ef",
        "23393935-50c7-4a86-aeea-2b78fd089c5c"
    ]

    # for example_id in os_list:
    #     try:
    #         main("os", example_id, gpt4_model="gpt-3.5-turbo-16k")
    #     except Exception as e:
    #         logger.error("An error occurred while running the example: %s", e)
    #         continue

    # NOTE(review): the first two entries of vlc_list are identical.
    vlc_list = [
        "8ba5ae7a-5ae5-4eab-9fcc-5dd4fe3abf89",
        "8ba5ae7a-5ae5-4eab-9fcc-5dd4fe3abf89",
        "8f080098-ddb1-424c-b438-4e96e5e4786e",
        "bba3381f-b5eb-4439-bd9e-80c22218d5a7",
        "fba2c100-79e8-42df-ae74-b592418d54f4",
        "efcf0d81-0835-4880-b2fd-d866e8bc2294",
        "8d9fd4e2-6fdb-46b0-b9b9-02f06495c62f",
        "aa4b5023-aef6-4ed9-bdc9-705f59ab9ad6",
        "386dbd0e-0241-4a0a-b6a2-6704fba26b1c",
        "9195653c-f4aa-453d-aa95-787f6ccfaae9",
        "d06f0d4d-2cd5-4ede-8de9-598629438c6e",
        "a5bbbcd5-b398-4c91-83d4-55e1e31bbb81",
        "f3977615-2b45-4ac5-8bba-80c17dbe2a37",
        "215dfd39-f493-4bc3-a027-8a97d72c61bf"
    ]

    # chrome_list is assigned again further down with the same contents.
    chrome_list = [
        "bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
        "7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3",
        "06fe7178-4491-4589-810f-2e2bc9502122",
        "e1e75309-3ddb-4d09-92ec-de869c928143",
        "35253b65-1c19-4304-8aa4-6884b8218fc0",
        "2ad9387a-65d8-4e33-ad5b-7580065a27ca",
        "7a5a7856-f1b6-42a4-ade9-1ca81ca0f263",
        "44ee5668-ecd5-4366-a6ce-c1c9b8d4e938",
        "2ae9ba84-3a0d-4d4c-8338-3a1478dc5fe3",
        "480bcfea-d68f-4aaa-a0a9-2589ef319381",
        "af630914-714e-4a24-a7bb-f9af687d3b91"
    ]

    calc_list = [
        "eb03d19a-b88d-4de4-8a64-ca0ac66f426b",
        "0bf05a7d-b28b-44d2-955a-50b41e24012a",
        "7a4e4bc8-922c-4c84-865c-25ba34136be1",
        "2bd59342-0664-4ccb-ba87-79379096cc08",
        "ecb0df7a-4e8d-4a03-b162-053391d3afaf",
        "7efeb4b1-3d19-4762-b163-63328d66303b",
        "4e6fcf72-daf3-439f-a232-c434ce416af6",
        "6054afcb-5bab-4702-90a0-b259b5d3217c",
        "abed40dc-063f-4598-8ba5-9fe749c0615d",
        "01b269ae-2111-4a07-81fd-3fcd711993b0",
        "8b1ce5f2-59d2-4dcc-b0b0-666a714b9a14",
        "0cecd4f3-74de-457b-ba94-29ad6b5dafb6",
        "4188d3a4-077d-46b7-9c86-23e1a036f6c1",
        "51b11269-2ca8-4b2a-9163-f21758420e78",
        "7e429b8d-a3f0-4ed0-9b58-08957d00b127",
        "347ef137-7eeb-4c80-a3bb-0951f26a8aff",
        "6e99a1ad-07d2-4b66-a1ce-ece6d99c20a5",
        "3aaa4e37-dc91-482e-99af-132a612d40f3",
        "37608790-6147-45d0-9f20-1137bb35703d",
        "f9584479-3d0d-4c79-affa-9ad7afdd8850",
        "d681960f-7bc3-4286-9913-a8812ba3261a",
        "21df9241-f8d7-4509-b7f1-37e501a823f7",
        "1334ca3e-f9e3-4db8-9ca7-b4c653be7d17",
        "357ef137-7eeb-4c80-a3bb-0951f26a8aff",
        "aa3a8974-2e85-438b-b29e-a64df44deb4b",
        "a01fbce3-2793-461f-ab86-43680ccbae25",
        "4f07fbe9-70de-4927-a4d5-bb28bc12c52c",
    ]

    # for example_id in calc_list:
    #     main("libreoffice_calc", example_id)

    impress_list = [
        "5d901039-a89c-4bfb-967b-bf66f4df075e",
        "550ce7e7-747b-495f-b122-acdc4d0b8e54",
        "455d3c66-7dc6-4537-a39a-36d3e9119df7",
        "af23762e-2bfd-4a1d-aada-20fa8de9ce07",
        "c59742c0-4323-4b9d-8a02-723c251deaa0",
        "ef9d12bd-bcee-4ba0-a40e-918400f43ddf",
        "9ec204e4-f0a3-42f8-8458-b772a6797cab",
        "0f84bef9-9790-432e-92b7-eece357603fb",
        "ce88f674-ab7a-43da-9201-468d38539e4a",
        "3b27600c-3668-4abd-8f84-7bcdebbccbdb",
        "a097acff-6266-4291-9fbd-137af7ecd439",
        "bf4e9888-f10f-47af-8dba-76413038b73c",
        "21760ecb-8f62-40d2-8d85-0cee5725cb72"
    ]
    # for example_id in impress_list:
    #     main("libreoffice_impress", example_id)

    # thunderbird_list is assigned again further down with the same contents.
    thunderbird_list = [
        # "bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
        # "7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3",
        "12086550-11c0-466b-b367-1d9e75b3910e",
        "06fe7178-4491-4589-810f-2e2bc9502122",
        "6766f2b8-8a72-417f-a9e5-56fcaa735837",
        "e1e75309-3ddb-4d09-92ec-de869c928143",
        "3d1682a7-0fb0-49ae-a4dc-a73afd2d06d5",
        "35253b65-1c19-4304-8aa4-6884b8218fc0",
        "d088f539-cab4-4f9a-ac92-9999fc3a656e",
        "2ad9387a-65d8-4e33-ad5b-7580065a27ca",
        "480bcfea-d68f-4aaa-a0a9-2589ef319381",
        "030eeff7-b492-4218-b312-701ec99ee0cc",
        "94760984-3ff5-41ee-8347-cf1af709fea0",
        "99146c54-4f37-4ab8-9327-5f3291665e1e",
        "c9e7eaf2-b1a1-4efc-a982-721972fa9f02"
    ]
    # for example_id in thunderbird_list:
    #     main("thunderbird", example_id)

    gimp_list = [
        "7a4deb26-d57d-4ea9-9a73-630f66a7b568",
        "554785e9-4523-4e7a-b8e1-8016f565f56a",
        "77b8ab4d-994f-43ac-8930-8ca087d7c4b4",
        "f4aec372-4fb0-4df5-a52b-79e0e2a5d6ce",
        "d52d6308-ec58-42b7-a2c9-de80e4837b2b",
        "2a729ded-3296-423d-aec4-7dd55ed5fbb3",
        "b148e375-fe0b-4bec-90e7-38632b0d73c2",
        "a746add2-cab0-4740-ac36-c3769d9bfb46",
        "7b7617bd-57cc-468e-9c91-40c4ec2bcb3d",
        "d16c99dc-2a1e-46f2-b350-d97c86c85c15",
        "06ca5602-62ca-47f6-ad4f-da151cde54cc",
        "e2dd0213-26db-4349-abe5-d5667bfd725c",
        "f723c744-e62c-4ae6-98d1-750d3cd7d79d",
        "72f83cdc-bf76-4531-9a1b-eb893a13f8aa",
        "7767eef2-56a3-4cea-8c9f-48c070c7d65b",
        "734d6579-c07d-47a8-9ae2-13339795476b"
    ]

    # for example_id in gimp_list:
    #     try:
    #         main("gimp", example_id)
    #     except Exception as e:
    #         logger.error("An error occurred while running the example: %s", e)
    #         continue

    # This full vs_code_list is replaced by an empty list before the only
    # loop that iterates it.
    vs_code_list = [
        "0ed39f63-6049-43d4-ba4d-5fa2fe04a951",
        "53ad5833-3455-407b-bbc6-45b4c79ab8fb",
        "eabc805a-bfcf-4460-b250-ac92135819f6",
        "982d12a5-beab-424f-8d38-d2a48429e511",
        "4e60007a-f5be-4bfc-9723-c39affa0a6d3",
        "e2b5e914-ffe1-44d2-8e92-58f8c5d92bb2",
        "9439a27b-18ae-42d8-9778-5f68f891805e",
        "ea98c5d7-3cf9-4f9b-8ad3-366b58e0fcae",
        "930fdb3b-11a8-46fe-9bac-577332e2640e",
        "276cc624-87ea-4f08-ab93-f770e3790175",
        "9d425400-e9b2-4424-9a4b-d4c7abac4140"
    ]

    # for example_id in vs_code_list:
    #     try:
    #         main("vs_code", example_id)
    #     except Exception as e:
    #         logger.error("An error occurred while running the example: %s", e)
    #         continue

    # Imported here rather than at the top of the file; only the loops
    # below use it.
    from tqdm import tqdm

    # for example_id in tqdm(vlc_list):
    #     try:
    #         main("vlc", example_id, gpt4_model="gpt-3.5-turbo-16k")
    #     except Exception as e:
    #         print(f"An error occurred while running the example: {e}")
    #         continue

    # Second assignment of chrome_list (same ids as the first).
    chrome_list = [
        "bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
        "7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3",
        "06fe7178-4491-4589-810f-2e2bc9502122",
        "e1e75309-3ddb-4d09-92ec-de869c928143",
        "35253b65-1c19-4304-8aa4-6884b8218fc0",
        "2ad9387a-65d8-4e33-ad5b-7580065a27ca",
        "7a5a7856-f1b6-42a4-ade9-1ca81ca0f263",
        "44ee5668-ecd5-4366-a6ce-c1c9b8d4e938",
        "2ae9ba84-3a0d-4d4c-8338-3a1478dc5fe3",
        "480bcfea-d68f-4aaa-a0a9-2589ef319381",
        "af630914-714e-4a24-a7bb-f9af687d3b91"
    ]
    # for example_id in tqdm(chrome_list):
    #     try:
    #         main("chrome", example_id, gpt4_model="gpt-3.5-turbo-16k")
    #     except Exception as e:
    #         print(f"An error occurred while running the example: {e}")
    #         continue

    # Every id is commented out, so vs_code_list is now an empty list and
    # the loop right below never calls main().
    vs_code_list = [
        # "0ed39f63-6049-43d4-ba4d-5fa2fe04a951",
        # "53ad5833-3455-407b-bbc6-45b4c79ab8fb",
        # "eabc805a-bfcf-4460-b250-ac92135819f6",
        # "982d12a5-beab-424f-8d38-d2a48429e511",
        # "4e60007a-f5be-4bfc-9723-c39affa0a6d3",
        # "e2b5e914-ffe1-44d2-8e92-58f8c5d92bb2",
        # "9439a27b-18ae-42d8-9778-5f68f891805e",
        # "ea98c5d7-3cf9-4f9b-8ad3-366b58e0fcae",
        # "930fdb3b-11a8-46fe-9bac-577332e2640e",
        # "276cc624-87ea-4f08-ab93-f770e3790175",
        # "9d425400-e9b2-4424-9a4b-d4c7abac4140"
    ]

    for example_id in tqdm(vs_code_list):
        try:
            main("vs_code", example_id, gpt4_model="gpt-3.5-turbo-16k")
        except Exception as e:
            # A failing example is reported and skipped so the rest still run.
            print(f"An error occurred while running the example: {e}")
            continue

    # Second assignment of thunderbird_list (same ids as the first).
    thunderbird_list = [
        # "bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
        # "7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3",
        "12086550-11c0-466b-b367-1d9e75b3910e",
        "06fe7178-4491-4589-810f-2e2bc9502122",
        "6766f2b8-8a72-417f-a9e5-56fcaa735837",
        "e1e75309-3ddb-4d09-92ec-de869c928143",
        "3d1682a7-0fb0-49ae-a4dc-a73afd2d06d5",
        "35253b65-1c19-4304-8aa4-6884b8218fc0",
        "d088f539-cab4-4f9a-ac92-9999fc3a656e",
        "2ad9387a-65d8-4e33-ad5b-7580065a27ca",
        "480bcfea-d68f-4aaa-a0a9-2589ef319381",
        "030eeff7-b492-4218-b312-701ec99ee0cc",
        "94760984-3ff5-41ee-8347-cf1af709fea0",
        "99146c54-4f37-4ab8-9327-5f3291665e1e",
        "c9e7eaf2-b1a1-4efc-a982-721972fa9f02"
    ]

    # for example_id in tqdm(thunderbird_list):
    #     try:
    #         main("thunderbird", example_id, gpt4_model="gpt-3.5-turbo-16k")
    #     except Exception as e:
    #         print(f"An error occurred while running the example: {e}")
    #         continue

    # Only the four uncommented ids below are run.
    multiple_list = [
        # "f8cfa149-d1c1-4215-8dac-4a0932bad3c2",
        # "897e3b53-5d4d-444b-85cb-2cdc8a97d903",
        "2fe4b718-3bd7-46ec-bdce-b184f5653624",
        "3680a5ee-6870-426a-a997-eba929a0d25c",
        # "4e9f0faf-2ecc-4ae8-a804-28c9a75d1ddc",
        # "b52b40a5-ad70-4c53-b5b0-5650a8387052",
        # "46407397-a7d5-4c6b-92c6-dbe038b1457b",
        # "2b9493d7-49b8-493a-a71b-56cd1f4d6908",
        # "51f5801c-18b3-4f25-b0c3-02f85507a078",
        "58565672-7bfe-48ab-b828-db349231de6b",
        # "2c9fc0de-3ee7-45e1-a5df-c86206ad78b5",
        # "510f64c8-9bcc-4be1-8d30-638705850618",
        # "937087b6-f668-4ba6-9110-60682ee33441",
        # "ee9a3c83-f437-4879-8918-be5efbb9fac7",
        # "3680a5ee-6870-426a-a997-eba929a0d25c",
        # "e135df7c-7687-4ac0-a5f0-76b74438b53e",
        "ee9a3c83-f437-4879-8918-be5efbb9fac7",
        # "58565672-7bfe-48ab-b828-db349231de6b",
        # "2fe4b718-3bd7-46ec-bdce-b184f5653624"
    ]

    for example_id in multiple_list:
        try:
            main("multi_apps", example_id, gpt4_model="gpt-3.5-turbo-16k")
        except Exception as e:
            # A failing example is logged and skipped so the rest still run.
            logger.error("An error occurred while running the example: %s", e)
            continue
|
||||
|
||||
@@ -1,335 +0,0 @@
|
||||
# todo: unify all the experiment python files into one file
|
||||
import argparse
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
import func_timeout
|
||||
|
||||
from desktop_env.envs.desktop_env import DesktopEnv
|
||||
from mm_agents.gpt_4v_agent import GPT4v_Agent # todo: change the name into PromptAgent
|
||||
|
||||
# Logger Configs {{{ #
# Configure the root logger so records from this script and from imported
# modules reach the handlers attached below.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")

# logging.FileHandler opens its file as soon as it is constructed, so the
# directory must exist beforehand; otherwise the script dies with
# FileNotFoundError before anything else runs.
os.makedirs("logs", exist_ok=True)

file_handler = logging.FileHandler(os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8")
debug_handler = logging.FileHandler(os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8")
stdout_handler = logging.StreamHandler(sys.stdout)
sdebug_handler = logging.FileHandler(os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8")

file_handler.setLevel(logging.INFO)
debug_handler.setLevel(logging.DEBUG)
stdout_handler.setLevel(logging.INFO)
sdebug_handler.setLevel(logging.DEBUG)

# The format embeds ANSI colour escapes; they are written to the log files too.
formatter = logging.Formatter(
    fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s")
file_handler.setFormatter(formatter)
debug_handler.setFormatter(formatter)
stdout_handler.setFormatter(formatter)
sdebug_handler.setFormatter(formatter)

# stdout and the "sdebug" file only receive records from the "desktopenv"
# logger hierarchy; the other two handlers receive everything.
stdout_handler.addFilter(logging.Filter("desktopenv"))
sdebug_handler.addFilter(logging.Filter("desktopenv"))

logger.addHandler(file_handler)
logger.addHandler(debug_handler)
logger.addHandler(stdout_handler)
logger.addHandler(sdebug_handler)
# }}} Logger Configs #

# From here on `logger` is this script's own logger, not the root logger.
logger = logging.getLogger("desktopenv.experiment")

# todo: move the PATH_TO_VM to the argparser
PATH_TO_VM = r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu\Ubuntu.vmx"
|
||||
|
||||
|
||||
def run_one_example(example, agent, max_steps=10, example_trajectory_dir="exp_trajectory", recording=True,
                    max_time=600):
    """Run one task in a new ``DesktopEnv``, log the trajectory and evaluate it.

    For every executed action a screenshot is copied into
    ``example_trajectory_dir`` and one JSON line is appended to
    ``trajectory.json`` there; a final ``{"result": ...}`` line is appended
    after evaluation.

    Args:
        example: task configuration handed to ``DesktopEnv`` as ``task_config``.
        agent: object exposing ``action_space`` and ``predict(observation)``,
            where ``predict`` returns an iterable of actions.
        max_steps: maximum number of ``agent.predict`` calls.
        example_trajectory_dir: directory (must already exist) receiving the
            screenshots, ``trajectory.json`` and ``recording.mp4``.
        recording: whether to ask the environment controller to record the
            episode.
        max_time: kept for interface compatibility; this function does not
            use it.
    """
    trajectory_recording_path = os.path.join(example_trajectory_dir, "trajectory.json")
    env = DesktopEnv(
        path_to_vm=PATH_TO_VM,
        action_space=agent.action_space,
        task_config=example,
        headless=True
    )
    # reset the environment to certain snapshot
    observation = env.reset()
    done = False
    step_num = 0

    if recording:
        # send a request to the server to start recording
        env.controller.start_recording()

    while not done and step_num < max_steps:
        actions = agent.predict(observation)
        step_num += 1
        for action in actions:
            # Capture the timestamp before executing the action
            action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
            logger.info("Step %d: %s", step_num, action)

            observation, reward, done, info = env.step(action)

            logger.info("Reward: %.2f", reward)
            logger.info("Done: %s", done)
            logger.info("Info: %s", info)

            # Save screenshot and trajectory information. The source is read
            # first so a failed read does not leave an empty destination file.
            screenshot_file = f"step_{step_num}_{action_timestamp}.png"
            with open(observation['screenshot'], "rb") as src:
                screenshot = src.read()
            with open(os.path.join(example_trajectory_dir, screenshot_file), "wb") as dst:
                dst.write(screenshot)

            with open(trajectory_recording_path, "a") as f:
                f.write(json.dumps({
                    "step_num": step_num,
                    "action_timestamp": action_timestamp,
                    "action": action,
                    "reward": reward,
                    "done": done,
                    "info": info,
                    "screenshot_file": screenshot_file
                }))
                f.write("\n")

            if done:
                logger.info("The episode is done.")
                break

    def stop_recording():
        try:
            env.controller.end_recording(os.path.join(example_trajectory_dir, "recording.mp4"))
        except Exception as e:
            logger.error("An error occurred while stopping the recording: %s", e)

    # Only stop a recording that was actually started.
    if recording:
        try:
            func_timeout.func_timeout(120, stop_recording)
            # todo: make sure we got the video file, check the bug
        except func_timeout.exceptions.FunctionTimedOut:
            logger.info("Recording timed out.")

    result = env.evaluate()
    logger.info("Result: %.2f", result)

    # fixme: change to write the result into a separate file
    with open(trajectory_recording_path, "a") as f:
        f.write(json.dumps({
            "result": result
        }))
        f.write("\n")

    # todo: append the result to the wandb for visualization

    # env.close()
    # NOTE: env.close() is commented out above, so nothing is closed here even
    # though the message says so.
    logger.info("Environment closed.")
|
||||
|
||||
|
||||
def main(example_class, example_id, gpt4_model="gpt-4-vision-preview"):
    """Run one example with the screenshot-only GPT-4V agent, skipping finished ones.

    The example is loaded from
    ``evaluation_examples/examples/<example_class>/<example_id>.json``. If the
    last non-empty line of the example's ``trajectory.json`` already holds a
    ``"result"`` key, the example is skipped. Otherwise ``run_one_example`` is
    run under a 1200 second limit, and any failure (including the timeout) is
    appended to ``trajectory.json`` as an ``{"error": ...}`` line.

    Args:
        example_class: example domain (sub-directory of the examples folder).
        example_id: example identifier (JSON file name without extension).
        gpt4_model: model name passed to ``GPT4v_Agent``; also used as a
            path component of the trajectory directory.
    """
    # todo: merge the main function into the run_one_example function
    # fixme: change all the settings like action_space, model, etc. to the argparser
    action_space = "pyautogui"
    example_path = f"evaluation_examples/examples/{example_class}/{example_id}.json"

    logger.info("Running example %s/%s", example_class, example_id)
    logger.info("Using model %s", gpt4_model)

    with open(example_path, "r", encoding="utf-8") as f:
        example = json.load(f)
    example["snapshot"] = "exp_v5"

    api_key = os.environ.get("OPENAI_API_KEY")
    agent = GPT4v_Agent(api_key=api_key,
                        model=gpt4_model,
                        instruction=example['instruction'],
                        action_space=action_space,
                        exp="screenshot")

    root_trajectory_dir = "exp_trajectory"
    example_trajectory_dir = os.path.join(root_trajectory_dir, "screenshot", example_class, gpt4_model, example_id)
    os.makedirs(example_trajectory_dir, exist_ok=True)
    trajectory_path = os.path.join(example_trajectory_dir, "trajectory.json")

    if os.path.exists(trajectory_path):
        with open(trajectory_path, "r") as f:
            # ignore blank lines so a trailing newline does not hide the last record
            lines = [line.strip() for line in f if line.strip() != ""]
        if lines:
            try:
                last_line = json.loads(lines[-1])
            except json.JSONDecodeError:
                # A truncated or corrupted last record means the previous run
                # did not finish; run the example again.
                last_line = {}
            if "result" in last_line:
                logger.info("%s has been evaluated. Skip.", example_path)
                return

    try:
        func_timeout.func_timeout(1200, run_one_example, args=(example, agent, 15, example_trajectory_dir))
    except (Exception, func_timeout.exceptions.FunctionTimedOut) as e:
        # FunctionTimedOut derives from BaseException, so it must be named
        # explicitly or the timeout escapes this handler.
        logger.error("An error occurred: %s", e)
        with open(trajectory_path, "a") as f:
            f.write(json.dumps({
                "error": str(e)
            }))
            # keep the file one-record-per-line for later appends
            f.write("\n")
|
||||
|
||||
|
||||
def config(argv=None) -> argparse.Namespace:
    """Build the command-line parser and return the parsed arguments.

    Args:
        argv: optional list of argument strings. When ``None`` (the default)
            ``argparse`` reads ``sys.argv[1:]``, as before.

    Returns:
        The populated ``argparse.Namespace``.

    Raises:
        ValueError: if ``--action_set_tag`` is ``id_accessibility_tree`` while
            ``--observation_type`` is not one of ``accessibility_tree``,
            ``accessibility_tree_with_captioner`` or ``image_som``.
    """
    parser = argparse.ArgumentParser(
        description="Run end-to-end evaluation on the benchmark"
    )
    parser.add_argument(
        "--render", action="store_true", help="Render the browser"
    )

    parser.add_argument(
        "--slow_mo",
        type=int,
        default=0,
        help="Slow down the browser by the specified amount",
    )
    parser.add_argument(
        "--action_set_tag", default="id_accessibility_tree", help="Action type"
    )
    parser.add_argument(
        "--observation_type",
        choices=[
            "accessibility_tree",
            "accessibility_tree_with_captioner",
            "html",
            "image",
            "image_som",
        ],
        default="accessibility_tree",
        help="Observation type",
    )
    parser.add_argument(
        "--current_viewport_only",
        action="store_true",
        help="Only use the current viewport for the observation",
    )
    parser.add_argument("--viewport_width", type=int, default=1280)
    parser.add_argument("--viewport_height", type=int, default=2048)
    parser.add_argument("--save_trace_enabled", action="store_true")
    parser.add_argument("--sleep_after_execution", type=float, default=0.0)

    parser.add_argument("--max_steps", type=int, default=30)

    # agent config
    parser.add_argument("--agent_type", type=str, default="prompt")
    parser.add_argument(
        "--instruction_path",
        type=str,
        default="agents/prompts/state_action_agent.json",
    )
    parser.add_argument(
        "--parsing_failure_th",
        help="When consecutive parsing failures exceed this threshold, the agent will terminate early.",
        type=int,
        default=3,
    )
    parser.add_argument(
        "--repeating_action_failure_th",
        help="When consecutive repeated actions exceed this threshold, the agent will terminate early.",
        type=int,
        default=5,
    )

    parser.add_argument("--test_config_base_dir", type=str)

    parser.add_argument(
        "--eval_captioning_model_device",
        type=str,
        default="cpu",
        choices=["cpu", "cuda"],
        help="Device to run eval captioning model on. By default, runs it on CPU.",
    )
    parser.add_argument(
        "--eval_captioning_model",
        type=str,
        default="Salesforce/blip2-flan-t5-xl",
        choices=["Salesforce/blip2-flan-t5-xl"],
        help="Captioning backbone for VQA-type evals.",
    )
    parser.add_argument(
        "--captioning_model",
        type=str,
        default="Salesforce/blip2-flan-t5-xl",
        choices=["Salesforce/blip2-flan-t5-xl", "llava-hf/llava-1.5-7b-hf"],
        help="Captioning backbone for accessibility tree alt text.",
    )

    # lm config
    parser.add_argument("--provider", type=str, default="openai")
    parser.add_argument("--model", type=str, default="gpt-3.5-turbo-0613")
    parser.add_argument("--mode", type=str, default="chat")
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--context_length", type=int, default=0)
    parser.add_argument("--max_tokens", type=int, default=384)
    parser.add_argument("--stop_token", type=str, default=None)
    parser.add_argument(
        "--max_retry",
        type=int,
        help="max retry times to perform generations when parsing fails",
        default=1,
    )
    parser.add_argument(
        "--max_obs_length",
        type=int,
        help="when not zero, will truncate the observation to this length before feeding to the model",
        default=3840,
    )

    # example config
    parser.add_argument("--test_start_idx", type=int, default=0)
    parser.add_argument("--test_end_idx", type=int, default=910)

    # logging related
    parser.add_argument("--result_dir", type=str, default="")
    args = parser.parse_args(argv)

    # check whether the action space is compatible with the observation space
    compatible_observation_types = [
        "accessibility_tree",
        "accessibility_tree_with_captioner",
        "image_som",
    ]
    if (
        args.action_set_tag == "id_accessibility_tree"
        and args.observation_type not in compatible_observation_types
    ):
        raise ValueError(
            f"Action type {args.action_set_tag} is incompatible with the observation type {args.observation_type}"
        )

    return args
|
||||
|
||||
|
||||
if __name__ == '__main__':
    ####### The complete version of the list of examples #######
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    args = config()
    args.sleep_after_execution = 2.5
    # NOTE: a call to `prepare(args)` used to sit here, but no `prepare` is
    # defined or imported in this file, so it raised NameError before any
    # example could run. It has been dropped.

    # todo: add recorder of the progress of the examples

    # todo: remove the useless example files

    # test_all.json maps each domain name to a list of example ids.
    with open("evaluation_examples/test_all.json", "r", encoding="utf-8") as f:
        test_all_meta = json.load(f)

    for domain, example_ids in test_all_meta.items():
        for example_id in example_ids:
            main(domain, example_id, args.model)
|
||||
@@ -1,361 +0,0 @@
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
import func_timeout
|
||||
|
||||
from desktop_env.envs.desktop_env import DesktopEnv
|
||||
from mm_agents.gpt_4v_agent import GPT4v_Agent
|
||||
|
||||
# Logger Configs {{{ #
# Configure the root logger so records from this script and from imported
# modules reach the handlers attached below.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")

# logging.FileHandler opens its file as soon as it is constructed, so the
# directory must exist beforehand; otherwise the script dies with
# FileNotFoundError before anything else runs.
os.makedirs("logs", exist_ok=True)

file_handler = logging.FileHandler(os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8")
debug_handler = logging.FileHandler(os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8")
stdout_handler = logging.StreamHandler(sys.stdout)
sdebug_handler = logging.FileHandler(os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8")

file_handler.setLevel(logging.INFO)
debug_handler.setLevel(logging.DEBUG)
stdout_handler.setLevel(logging.INFO)
sdebug_handler.setLevel(logging.DEBUG)

# The format embeds ANSI colour escapes; they are written to the log files too.
formatter = logging.Formatter(
    fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s")
file_handler.setFormatter(formatter)
debug_handler.setFormatter(formatter)
stdout_handler.setFormatter(formatter)
sdebug_handler.setFormatter(formatter)

# stdout and the "sdebug" file only receive records from the "desktopenv"
# logger hierarchy; the other two handlers receive everything.
stdout_handler.addFilter(logging.Filter("desktopenv"))
sdebug_handler.addFilter(logging.Filter("desktopenv"))

logger.addHandler(file_handler)
logger.addHandler(debug_handler)
logger.addHandler(stdout_handler)
logger.addHandler(sdebug_handler)
# }}} Logger Configs #

# From here on `logger` is this script's own logger, not the root logger.
logger = logging.getLogger("desktopenv.experiment")

PATH_TO_VM = r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu2\Ubuntu2.vmx"


# PATH_TO_VM = "../../../../大文件/镜像/Ubuntu-1218/Ubuntu/Ubuntu.vmx"
|
||||
|
||||
def run_one_example(example, agent, max_steps=10, example_trajectory_dir="exp_trajectory", recording=True):
    """Run one task in a new ``DesktopEnv``, log the trajectory and evaluate it.

    For every executed action a screenshot is copied into
    ``example_trajectory_dir`` and one JSON line is appended to
    ``trajectory.json`` there; a final ``{"result": ...}`` line is appended
    after evaluation.

    Args:
        example: task configuration handed to ``DesktopEnv`` as ``task_config``.
        agent: object exposing ``action_space`` and ``predict(observation)``,
            where ``predict`` returns an iterable of actions.
        max_steps: maximum number of ``agent.predict`` calls.
        example_trajectory_dir: directory (must already exist) receiving the
            screenshots, ``trajectory.json`` and ``recording.mp4``.
        recording: whether to ask the environment controller to record the
            episode.
    """
    trajectory_recording_path = os.path.join(example_trajectory_dir, "trajectory.json")
    env = DesktopEnv(
        path_to_vm=PATH_TO_VM,
        action_space=agent.action_space,
        task_config=example
    )
    # reset the environment to certain snapshot
    observation = env.reset()
    done = False
    step_num = 0

    if recording:
        # send a request to the server to start recording
        env.controller.start_recording()

    while not done and step_num < max_steps:
        actions = agent.predict(observation)
        step_num += 1
        for action in actions:
            # Capture the timestamp before executing the action
            action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
            logger.info("Step %d: %s", step_num, action)

            observation, reward, done, info = env.step(action)

            logger.info("Reward: %.2f", reward)
            logger.info("Done: %s", done)
            logger.info("Info: %s", info)

            # Save screenshot and trajectory information. The source is read
            # first so a failed read does not leave an empty destination file.
            screenshot_file = f"step_{step_num}_{action_timestamp}.png"
            with open(observation['screenshot'], "rb") as src:
                screenshot = src.read()
            with open(os.path.join(example_trajectory_dir, screenshot_file), "wb") as dst:
                dst.write(screenshot)

            with open(trajectory_recording_path, "a") as f:
                f.write(json.dumps({
                    "step_num": step_num,
                    "action_timestamp": action_timestamp,
                    "action": action,
                    "reward": reward,
                    "done": done,
                    "info": info,
                    "screenshot_file": screenshot_file
                }))
                f.write("\n")

            if done:
                logger.info("The episode is done.")
                break

    def stop_recording():
        try:
            env.controller.end_recording(os.path.join(example_trajectory_dir, "recording.mp4"))
        except Exception as e:
            logger.error("An error occurred while stopping the recording: %s", e)

    # Only stop a recording that was actually started.
    if recording:
        try:
            func_timeout.func_timeout(30, stop_recording)
        except func_timeout.exceptions.FunctionTimedOut:
            logger.info("Recording timed out.")

    result = env.evaluate()
    logger.info("Result: %.2f", result)

    with open(trajectory_recording_path, "a") as f:
        f.write(json.dumps({
            "result": result
        }))
        f.write("\n")

    # env.close()
    # NOTE: env.close() is commented out above, so nothing is closed here even
    # though the message says so.
    logger.info("Environment closed.")
|
||||
|
||||
|
||||
def main(example_class, example_id, gpt4_model="gpt-4-vision-preview"):
    """Run one example with the GPT-4V agent in the ``"both"`` experiment setting.

    Loads ``evaluation_examples/examples/<example_class>/<example_id>.json``,
    forces its ``"snapshot"`` to ``"exp_v5"``, builds the agent and runs
    ``run_one_example`` for up to 15 steps. Output goes to
    ``exp_trajectory/both/<example_class>/<gpt4_model>/<example_id>``.

    Args:
        example_class: example domain (sub-directory of the examples folder).
        example_id: example identifier (JSON file name without extension).
        gpt4_model: model name passed to ``GPT4v_Agent``; also used as a
            path component of the trajectory directory.
    """
    logger.info("Running example %s/%s", example_class, example_id)
    logger.info("Using model %s", gpt4_model)

    config_path = f"evaluation_examples/examples/{example_class}/{example_id}.json"
    with open(config_path, "r", encoding="utf-8") as config_file:
        example = json.load(config_file)
    example["snapshot"] = "exp_v5"

    # A Gemini agent (GENAI_API_KEY / "gemini-pro-vision") used to be selectable
    # here by editing the source; only the GPT-4V agent is wired up.
    agent = GPT4v_Agent(
        api_key=os.environ.get("OPENAI_API_KEY"),
        model=gpt4_model,
        instruction=example['instruction'],
        action_space="pyautogui",
        exp="both",
    )

    example_trajectory_dir = os.path.join("exp_trajectory", "both", example_class, gpt4_model, example_id)
    os.makedirs(example_trajectory_dir, exist_ok=True)

    run_one_example(example, agent, 15, example_trajectory_dir)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Example ids grouped by domain. Only `writer_list` is executed below; the
    # other lists are kept so a different domain can be run by swapping the
    # list and the domain name in the final loop.
    os_list = [
        "94d95f96-9699-4208-98ba-3c3119edf9c2",
        "bedcedc4-4d72-425e-ad62-21960b11fe0d",
        "43c2d64c-bab5-4dcb-a30c-b888321c319a",
        "7688b85f-87a4-4e4a-b2f8-f3d6c3f29b82",
        "ec4e3f68-9ea4-4c18-a5c9-69f89d1178b3",
        "f9be0997-4b7c-45c5-b05c-4612b44a6118",
        "28cc3b7e-b194-4bc9-8353-d04c0f4d56d2",
        "5ea617a3-0e86-4ba6-aab2-dac9aa2e8d57",
        "e0df059f-28a6-4169-924f-b9623e7184cc",
        "ddc75b62-7311-4af8-bfb3-859558542b36",
        "b6781586-6346-41cd-935a-a6b1487918fc",
        "3ce045a0-877b-42aa-8d2c-b4a863336ab8",
        "a4d98375-215b-4a4d-aee9-3d4370fccc41",
        "13584542-872b-42d8-b299-866967b5c3ef",
        "23393935-50c7-4a86-aeea-2b78fd089c5c"
    ]

    calc_list = [
        "a9f325aa-8c05-4e4f-8341-9e4358565f4f",
        "ecb0df7a-4e8d-4a03-b162-053391d3afaf",
        "7efeb4b1-3d19-4762-b163-63328d66303b",
        "4e6fcf72-daf3-439f-a232-c434ce416af6",
        "6054afcb-5bab-4702-90a0-b259b5d3217c",
        "abed40dc-063f-4598-8ba5-9fe749c0615d",
        "01b269ae-2111-4a07-81fd-3fcd711993b0",
        "8b1ce5f2-59d2-4dcc-b0b0-666a714b9a14",
        "af2b02f7-acee-4be4-8b66-499fab394915",
        "da1d63b8-fa12-417b-ba18-f748e5f770f3",
        "636380ea-d5f6-4474-b6ca-b2ed578a20f1",
        "5ba77536-05c5-4aae-a9ff-6e298d094c3e",
        "4bc4eaf4-ca5e-4db2-8138-8d4e65af7c0b",
        "672a1b02-c62f-4ae2-acf0-37f5fb3052b0",
        "648fe544-16ba-44af-a587-12ccbe280ea6",
        "8985d1e4-5b99-4711-add4-88949ebb2308",
        "9e606842-2e27-43bf-b1d1-b43289c9589b",
        "fcb6e45b-25c4-4087-9483-03d714f473a9",
        "68c0c5b7-96f3-4e87-92a7-6c1b967fd2d2",
        "fff629ea-046e-4793-8eec-1a5a15c3eb35",
        "5c9a206c-bb00-4fb6-bb46-ee675c187df5",
        "e975ae74-79bd-4672-8d1c-dc841a85781d",
        "34a6938a-58da-4897-8639-9b90d6db5391",
        "b5a22759-b4eb-4bf2-aeed-ad14e8615f19",
        "2f9913a1-51ed-4db6-bfe0-7e1c95b3139e",
        "2558031e-401d-4579-8e00-3ecf540fb492",
        "0cecd4f3-74de-457b-ba94-29ad6b5dafb6",
        "4188d3a4-077d-46b7-9c86-23e1a036f6c1",
        "51b11269-2ca8-4b2a-9163-f21758420e78",
        "7e429b8d-a3f0-4ed0-9b58-08957d00b127",
        "347ef137-7eeb-4c80-a3bb-0951f26a8aff",
        "6e99a1ad-07d2-4b66-a1ce-ece6d99c20a5",
        "3aaa4e37-dc91-482e-99af-132a612d40f3",
        "37608790-6147-45d0-9f20-1137bb35703d",
        "f9584479-3d0d-4c79-affa-9ad7afdd8850",
        "d681960f-7bc3-4286-9913-a8812ba3261a",
        "21df9241-f8d7-4509-b7f1-37e501a823f7",
        "1334ca3e-f9e3-4db8-9ca7-b4c653be7d17",
        "357ef137-7eeb-4c80-a3bb-0951f26a8aff",
        "aa3a8974-2e85-438b-b29e-a64df44deb4b",
        "a01fbce3-2793-461f-ab86-43680ccbae25",
        "4f07fbe9-70de-4927-a4d5-bb28bc12c52c"
    ]

    impress_list = [
        "5d901039-a89c-4bfb-967b-bf66f4df075e",
        "550ce7e7-747b-495f-b122-acdc4d0b8e54",
        "455d3c66-7dc6-4537-a39a-36d3e9119df7",
        "af23762e-2bfd-4a1d-aada-20fa8de9ce07",
        "c59742c0-4323-4b9d-8a02-723c251deaa0",
        "ef9d12bd-bcee-4ba0-a40e-918400f43ddf",
        "9ec204e4-f0a3-42f8-8458-b772a6797cab",
        "0f84bef9-9790-432e-92b7-eece357603fb",
        "ce88f674-ab7a-43da-9201-468d38539e4a",
        "3b27600c-3668-4abd-8f84-7bcdebbccbdb",
        "a097acff-6266-4291-9fbd-137af7ecd439",
        "bf4e9888-f10f-47af-8dba-76413038b73c",
        "21760ecb-8f62-40d2-8d85-0cee5725cb72"
    ]

    vs_code_list = [
        "0ed39f63-6049-43d4-ba4d-5fa2fe04a951",
        "53ad5833-3455-407b-bbc6-45b4c79ab8fb",
        "eabc805a-bfcf-4460-b250-ac92135819f6",
        "982d12a5-beab-424f-8d38-d2a48429e511",
        "4e60007a-f5be-4bfc-9723-c39affa0a6d3",
        "e2b5e914-ffe1-44d2-8e92-58f8c5d92bb2",
        "9439a27b-18ae-42d8-9778-5f68f891805e",
        "ea98c5d7-3cf9-4f9b-8ad3-366b58e0fcae",
        "930fdb3b-11a8-46fe-9bac-577332e2640e",
        "276cc624-87ea-4f08-ab93-f770e3790175",
        "9d425400-e9b2-4424-9a4b-d4c7abac4140"
    ]

    multiple_list = [
        "f8cfa149-d1c1-4215-8dac-4a0932bad3c2",
        "897e3b53-5d4d-444b-85cb-2cdc8a97d903",
        "4e9f0faf-2ecc-4ae8-a804-28c9a75d1ddc",
        "b52b40a5-ad70-4c53-b5b0-5650a8387052",
        "46407397-a7d5-4c6b-92c6-dbe038b1457b",
        "2b9493d7-49b8-493a-a71b-56cd1f4d6908",
        "51f5801c-18b3-4f25-b0c3-02f85507a078",
        "2c9fc0de-3ee7-45e1-a5df-c86206ad78b5",
        "510f64c8-9bcc-4be1-8d30-638705850618",
        "937087b6-f668-4ba6-9110-60682ee33441",
        "ee9a3c83-f437-4879-8918-be5efbb9fac7",
        "3680a5ee-6870-426a-a997-eba929a0d25c",
        "e135df7c-7687-4ac0-a5f0-76b74438b53e",
        "58565672-7bfe-48ab-b828-db349231de6b",
        "2fe4b718-3bd7-46ec-bdce-b184f5653624"
    ]

    chrome_list = [
        # "bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
        "7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3",
        "06fe7178-4491-4589-810f-2e2bc9502122",
        "e1e75309-3ddb-4d09-92ec-de869c928143",
        "35253b65-1c19-4304-8aa4-6884b8218fc0",
        "2ad9387a-65d8-4e33-ad5b-7580065a27ca",
        "7a5a7856-f1b6-42a4-ade9-1ca81ca0f263",
        "44ee5668-ecd5-4366-a6ce-c1c9b8d4e938",
        "2ae9ba84-3a0d-4d4c-8338-3a1478dc5fe3",
        "480bcfea-d68f-4aaa-a0a9-2589ef319381",
        "af630914-714e-4a24-a7bb-f9af687d3b91"
    ]

    writer_list = [
        "6ada715d-3aae-4a32-a6a7-429b2e43fb93",
        "ecc2413d-8a48-416e-a3a2-d30106ca36cb",
        "0e47de2a-32e0-456c-a366-8c607ef7a9d2",
        "4bcb1253-a636-4df4-8cb0-a35c04dfef31",
        "0810415c-bde4-4443-9047-d5f70165a697",
        "e528b65e-1107-4b8c-8988-490e4fece599",
        "66399b0d-8fda-4618-95c4-bfc6191617e9",
        "936321ce-5236-426a-9a20-e0e3c5dc536f",
        "3ef2b351-8a84-4ff2-8724-d86eae9b842e",
        "0b17a146-2934-46c7-8727-73ff6b6483e8",
        "0e763496-b6bb-4508-a427-fad0b6c3e195",
        "f178a4a9-d090-4b56-bc4c-4b72a61a035d",
        "adf5e2c3-64c7-4644-b7b6-d2f0167927e7",
        "0a0faba3-5580-44df-965d-f562a99b291c",
        "e246f6d8-78d7-44ac-b668-fcf47946cb50",
        "8472fece-c7dd-4241-8d65-9b3cd1a0b568",
        "88fe4b2d-3040-4c70-9a70-546a47764b48",
        "d53ff5ee-3b1a-431e-b2be-30ed2673079b",
        "72b810ef-4156-4d09-8f08-a0cf57e7cefe",
        "6f81754e-285d-4ce0-b59e-af7edb02d108",
        "b21acd93-60fd-4127-8a43-2f5178f4a830"
    ]

    # Run every Writer example; a failure is logged and the loop moves on to
    # the next example.
    for example_id in writer_list:
        try:
            main("libreoffice_writer", example_id)
        except Exception as e:
            logger.error("An error occurred while running the example: %s", e)
|
||||
|
||||
|
||||
@@ -1,155 +0,0 @@
|
||||
import ctypes
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import func_timeout
|
||||
|
||||
from desktop_env.envs.desktop_env import DesktopEnv
|
||||
from mm_agents.gpt_4v_agent import GPT4v_Agent
|
||||
|
||||
# Logger Configs {{{ #
# Configure the root logger so records from this script and from imported
# modules reach the handlers attached below.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")

# logging.FileHandler opens its file as soon as it is constructed, so the
# directory must exist beforehand; otherwise the script dies with
# FileNotFoundError before anything else runs.
os.makedirs("logs", exist_ok=True)

file_handler = logging.FileHandler(os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8")
debug_handler = logging.FileHandler(os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8")
stdout_handler = logging.StreamHandler(sys.stdout)
sdebug_handler = logging.FileHandler(os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8")

file_handler.setLevel(logging.INFO)
debug_handler.setLevel(logging.DEBUG)
stdout_handler.setLevel(logging.INFO)
sdebug_handler.setLevel(logging.DEBUG)

# The format embeds ANSI colour escapes; they are written to the log files too.
formatter = logging.Formatter(
    fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s")
file_handler.setFormatter(formatter)
debug_handler.setFormatter(formatter)
stdout_handler.setFormatter(formatter)
sdebug_handler.setFormatter(formatter)

# stdout and the "sdebug" file only receive records from the "desktopenv"
# logger hierarchy; the other two handlers receive everything.
stdout_handler.addFilter(logging.Filter("desktopenv"))
sdebug_handler.addFilter(logging.Filter("desktopenv"))

logger.addHandler(file_handler)
logger.addHandler(debug_handler)
logger.addHandler(stdout_handler)
logger.addHandler(sdebug_handler)
# }}} Logger Configs #

# From here on `logger` is this script's own logger, not the root logger.
logger = logging.getLogger("desktopenv.experiment")

PATH_TO_VM = r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu\Ubuntu.vmx"
|
||||
|
||||
|
||||
def run_one_example(example, agent, max_steps=10, example_trajectory_dir="exp_trajectory", recording=True):
    """Run one task in a new ``DesktopEnv``, log the trajectory and evaluate it.

    For every executed action a screenshot is copied into
    ``example_trajectory_dir`` and one JSON line is appended to
    ``trajectory.json`` there; a final ``{"result": ...}`` line is appended
    after evaluation.

    Args:
        example: task configuration handed to ``DesktopEnv`` as ``task_config``.
        agent: object exposing ``action_space`` and ``predict(observation)``,
            where ``predict`` returns an iterable of actions.
        max_steps: maximum number of ``agent.predict`` calls.
        example_trajectory_dir: directory (must already exist) receiving the
            screenshots, ``trajectory.json`` and ``recording.mp4``.
        recording: whether to ask the environment controller to record the
            episode.
    """
    trajectory_recording_path = os.path.join(example_trajectory_dir, "trajectory.json")
    env = DesktopEnv(
        path_to_vm=PATH_TO_VM,
        action_space=agent.action_space,
        task_config=example
    )
    # reset the environment to certain snapshot
    observation = env.reset()
    done = False
    step_num = 0

    if recording:
        # send a request to the server to start recording
        env.controller.start_recording()

    while not done and step_num < max_steps:
        actions = agent.predict(observation)
        step_num += 1
        for action in actions:
            # Capture the timestamp before executing the action
            action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
            logger.info("Step %d: %s", step_num, action)

            observation, reward, done, info = env.step(action)

            logger.info("Reward: %.2f", reward)
            logger.info("Done: %s", done)
            logger.info("Info: %s", info)

            # Save screenshot and trajectory information. The source is read
            # first so a failed read does not leave an empty destination file.
            screenshot_file = f"step_{step_num}_{action_timestamp}.png"
            with open(observation['screenshot'], "rb") as src:
                screenshot = src.read()
            with open(os.path.join(example_trajectory_dir, screenshot_file), "wb") as dst:
                dst.write(screenshot)

            with open(trajectory_recording_path, "a") as f:
                f.write(json.dumps({
                    "step_num": step_num,
                    "action_timestamp": action_timestamp,
                    "action": action,
                    "reward": reward,
                    "done": done,
                    "info": info,
                    "screenshot_file": screenshot_file
                }))
                f.write("\n")

            if done:
                logger.info("The episode is done.")
                break

    def stop_recording():
        try:
            env.controller.end_recording(os.path.join(example_trajectory_dir, "recording.mp4"))
        except Exception as e:
            logger.error("An error occurred while stopping the recording: %s", e)

    # Only stop a recording that was actually started.
    if recording:
        try:
            func_timeout.func_timeout(30, stop_recording)
        except func_timeout.exceptions.FunctionTimedOut:
            logger.info("Recording timed out.")

    result = env.evaluate()
    logger.info("Result: %.2f", result)

    with open(trajectory_recording_path, "a") as f:
        f.write(json.dumps({
            "result": result
        }))
        f.write("\n")

    # env.close()
    # NOTE: env.close() is commented out above, so nothing is closed here even
    # though the message says so.
    logger.info("Environment closed.")
|
||||
|
||||
|
||||
def main(example_class, example_id):
    """Run one example with the GPT-4V agent in the ``"seeact"`` experiment setting.

    Loads ``evaluation_examples/examples/<example_class>/<example_id>.json``,
    forces its ``"snapshot"`` to ``"exp_v5"``, builds the agent and runs
    ``run_one_example`` for up to 15 steps. Output goes to
    ``exp_trajectory/seeact/<example_class>/gpt-4-vision-preview/<example_id>``.

    Args:
        example_class: example domain (sub-directory of the examples folder).
        example_id: example identifier (JSON file name without extension).
    """
    gpt4_model = "gpt-4-vision-preview"

    config_path = f"evaluation_examples/examples/{example_class}/{example_id}.json"
    with open(config_path, "r", encoding="utf-8") as config_file:
        example = json.load(config_file)
    example["snapshot"] = "exp_v5"

    # A Gemini agent (GENAI_API_KEY / "gemini-pro-vision") used to be selectable
    # here by editing the source; only the GPT-4V agent is wired up.
    agent = GPT4v_Agent(
        api_key=os.environ.get("OPENAI_API_KEY"),
        model=gpt4_model,
        instruction=example['instruction'],
        action_space="pyautogui",
        exp="seeact",
    )

    example_trajectory_dir = os.path.join("exp_trajectory", "seeact", example_class, gpt4_model, example_id)
    os.makedirs(example_trajectory_dir, exist_ok=True)

    run_one_example(example, agent, 15, example_trajectory_dir)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Placeholder driver: fill in the example ids (and replace the "xx"
    # domain) to run this experiment. With the list empty nothing is executed.
    xx_list = []
    for example_id in xx_list:
        main("xx", example_id)
|
||||
@@ -1,261 +0,0 @@
|
||||
#import ctypes
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import func_timeout
|
||||
|
||||
from desktop_env.envs.desktop_env import DesktopEnv
|
||||
from mm_agents.gpt_4v_agent import GPT4v_Agent
|
||||
|
||||
# Logger Configs {{{ #
# The root logger collects everything at DEBUG; each handler below applies its
# own level (and, for two of them, a name filter).
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")

# FileHandler opens its file immediately and does not create missing parent
# directories, so make sure "logs" exists before building the handlers.
os.makedirs("logs", exist_ok=True)

file_handler = logging.FileHandler(os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8")
debug_handler = logging.FileHandler(os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8")
stdout_handler = logging.StreamHandler(sys.stdout)
sdebug_handler = logging.FileHandler(os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8")

file_handler.setLevel(logging.INFO)
debug_handler.setLevel(logging.DEBUG)
stdout_handler.setLevel(logging.INFO)
sdebug_handler.setLevel(logging.DEBUG)

# ANSI-colored prefix: timestamp, level, module/line and process name.
formatter = logging.Formatter(
    fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s")
file_handler.setFormatter(formatter)
debug_handler.setFormatter(formatter)
stdout_handler.setFormatter(formatter)
sdebug_handler.setFormatter(formatter)

# stdout and the "sdebug" file only receive records from the "desktopenv" logger hierarchy.
stdout_handler.addFilter(logging.Filter("desktopenv"))
sdebug_handler.addFilter(logging.Filter("desktopenv"))

logger.addHandler(file_handler)
logger.addHandler(debug_handler)
logger.addHandler(stdout_handler)
logger.addHandler(sdebug_handler)
# }}} Logger Configs #

# From here on `logger` is the experiment logger, not the root logger.
logger = logging.getLogger("desktopenv.experiment")
|
||||
|
||||
# Path of the VMware .vmx file that run_one_example passes to DesktopEnv as path_to_vm.
# NOTE(review): machine-specific absolute path — adjust for the local setup.
PATH_TO_VM = r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu\Ubuntu.vmx"
|
||||
|
||||
|
||||
def run_one_example(example, agent, max_steps=10, example_trajectory_dir="exp_trajectory", recording=True):
    """Run ``agent`` on one task in a new ``DesktopEnv`` and persist the trajectory.

    For every executed action the current screenshot is copied to
    ``step_<n>_<timestamp>.png`` and one JSON line describing the step is
    appended to ``trajectory.json``; after the episode a final JSON line with
    the evaluation result is appended to the same file.

    Args:
        example: task configuration, passed to ``DesktopEnv`` as ``task_config``.
        agent: object providing ``action_space`` and ``predict(observation)``,
            where ``predict`` returns an iterable of actions.
        max_steps: maximum number of ``agent.predict`` calls.
        example_trajectory_dir: directory (must already exist) that receives the
            screenshots, ``trajectory.json`` and, if enabled, ``recording.mp4``.
        recording: when true, a screen recording is started before the first
            step and stopped/saved after the last one.
    """
    trajectory_recording_path = os.path.join(example_trajectory_dir, "trajectory.json")
    env = DesktopEnv(
        path_to_vm=PATH_TO_VM,
        action_space=agent.action_space,
        task_config=example
    )
    # reset the environment to certain snapshot
    observation = env.reset()
    done = False
    step_num = 0

    if recording:
        # send a request to the server to start recording
        env.controller.start_recording()

    while not done and step_num < max_steps:
        actions = agent.predict(observation)
        step_num += 1
        for action in actions:
            # Capture the timestamp before executing the action
            action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
            screenshot_file = f"step_{step_num}_{action_timestamp}.png"
            logger.info("Step %d: %s", step_num, action)

            observation, reward, done, info = env.step(action)

            logger.info("Reward: %.2f", reward)
            logger.info("Done: %s", done)
            logger.info("Info: %s", info)

            # Save screenshot and trajectory information. The source is read
            # first so a failed read does not leave an empty PNG behind.
            with open(observation['screenshot'], "rb") as src:
                screenshot = src.read()
            with open(os.path.join(example_trajectory_dir, screenshot_file), "wb") as dst:
                dst.write(screenshot)

            with open(trajectory_recording_path, "a") as f:
                f.write(json.dumps({
                    "step_num": step_num,
                    "action_timestamp": action_timestamp,
                    "action": action,
                    "reward": reward,
                    "done": done,
                    "info": info,
                    "screenshot_file": screenshot_file
                }))
                f.write("\n")

            if done:
                # Remaining actions of this prediction are dropped; the outer
                # loop ends because `done` is now true.
                logger.info("The episode is done.")
                break

    if recording:
        # Only stop a recording that was actually started above.
        def stop_recording():
            try:
                env.controller.end_recording(os.path.join(example_trajectory_dir, "recording.mp4"))
            except Exception as e:
                print(f"An error occurred while stopping the recording: {e}")

        # Saving the recording may hang; give up after 30 seconds.
        try:
            func_timeout.func_timeout(30, stop_recording)
        except func_timeout.exceptions.FunctionTimedOut:
            logger.info("Recording timed out.")

    result = env.evaluate()
    logger.info("Result: %.2f", result)

    with open(trajectory_recording_path, "a") as f:
        f.write(json.dumps({
            "result": result
        }))
        f.write("\n")

    # NOTE(review): env.close() is disabled, so the message below is logged even
    # though the environment is not closed here — confirm this is intended.
    # env.close()
    logger.info("Environment closed.")
|
||||
|
||||
|
||||
def main(example_class, example_id):
    """Run a single evaluation example with the GPT-4V agent in the "som" setting.

    The example is loaded from
    ``evaluation_examples/examples/<example_class>/<example_id>.json``, its
    ``snapshot`` entry is overwritten with ``"exp_v5"``, and the episode is
    executed through ``run_one_example`` with a budget of 15 steps. Artifacts
    are written below ``exp_trajectory/som/<example_class>/<model>/<example_id>``.
    """
    action_space = "pyautogui"
    gpt4_model = "gpt-4-vision-preview"
    gemini_model = "gemini-pro-vision"  # only referenced by the disabled Gemini variant

    example_path = f"evaluation_examples/examples/{example_class}/{example_id}.json"
    with open(example_path, "r", encoding="utf-8") as example_file:
        example = json.load(example_file)
    example["snapshot"] = "exp_v5"

    logger.info("TASK: %s/%s", example_class, example_id)

    # The OpenAI key comes from the environment; it is None when the variable is unset.
    agent = GPT4v_Agent(
        api_key=os.environ.get("OPENAI_API_KEY"),
        model=gpt4_model,
        max_tokens=1000,
        instruction=example['instruction'],
        action_space=action_space,
        exp="som",
    )

    # One output directory per (experiment, class, model, example).
    example_trajectory_dir = os.path.join("exp_trajectory", "som", example_class, gpt4_model, example_id)
    os.makedirs(example_trajectory_dir, exist_ok=True)

    run_one_example(example, agent, 15, example_trajectory_dir)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    from tqdm import tqdm

    def _run_suite(example_class, example_ids):
        """Run every example of one class; a failing example is reported and skipped."""
        for example_id in tqdm(example_ids):
            try:
                main(example_class, example_id)
            except Exception as e:
                print(f"An error occurred while running the example: {e}")
                continue

    # Disabled suite (libreoffice_impress), kept for reference:
    # impress_list = [
    #     # "5d901039-a89c-4bfb-967b-bf66f4df075e",
    #     "550ce7e7-747b-495f-b122-acdc4d0b8e54",
    #     "455d3c66-7dc6-4537-a39a-36d3e9119df7",
    #     "af23762e-2bfd-4a1d-aada-20fa8de9ce07",
    #     "c59742c0-4323-4b9d-8a02-723c251deaa0",
    #     "ef9d12bd-bcee-4ba0-a40e-918400f43ddf",
    #     "9ec204e4-f0a3-42f8-8458-b772a6797cab",
    #     "0f84bef9-9790-432e-92b7-eece357603fb",
    #     "ce88f674-ab7a-43da-9201-468d38539e4a",
    #     "3b27600c-3668-4abd-8f84-7bcdebbccbdb",
    #     "a097acff-6266-4291-9fbd-137af7ecd439",
    #     "bf4e9888-f10f-47af-8dba-76413038b73c",
    #     "21760ecb-8f62-40d2-8d85-0cee5725cb72"
    # ]
    # _run_suite("libreoffice_impress", impress_list)

    # Defined but currently not run (its loop is disabled).
    # NOTE(review): the first id appears twice in this list.
    vlc_list = [
        "8ba5ae7a-5ae5-4eab-9fcc-5dd4fe3abf89",
        "8ba5ae7a-5ae5-4eab-9fcc-5dd4fe3abf89",
        "8f080098-ddb1-424c-b438-4e96e5e4786e",
        "bba3381f-b5eb-4439-bd9e-80c22218d5a7",
        "fba2c100-79e8-42df-ae74-b592418d54f4",
        "efcf0d81-0835-4880-b2fd-d866e8bc2294",
        "8d9fd4e2-6fdb-46b0-b9b9-02f06495c62f",
        "aa4b5023-aef6-4ed9-bdc9-705f59ab9ad6",
        "386dbd0e-0241-4a0a-b6a2-6704fba26b1c",
        "9195653c-f4aa-453d-aa95-787f6ccfaae9",
        "d06f0d4d-2cd5-4ede-8de9-598629438c6e",
        "a5bbbcd5-b398-4c91-83d4-55e1e31bbb81",
        "f3977615-2b45-4ac5-8bba-80c17dbe2a37",
        "215dfd39-f493-4bc3-a027-8a97d72c61bf"
    ]
    # _run_suite("vlc", vlc_list)

    chrome_list = [
        "bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
        "7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3",
        "06fe7178-4491-4589-810f-2e2bc9502122",
        "e1e75309-3ddb-4d09-92ec-de869c928143",
        "35253b65-1c19-4304-8aa4-6884b8218fc0",
        "2ad9387a-65d8-4e33-ad5b-7580065a27ca",
        "7a5a7856-f1b6-42a4-ade9-1ca81ca0f263",
        "44ee5668-ecd5-4366-a6ce-c1c9b8d4e938",
        "2ae9ba84-3a0d-4d4c-8338-3a1478dc5fe3",
        "480bcfea-d68f-4aaa-a0a9-2589ef319381",
        "af630914-714e-4a24-a7bb-f9af687d3b91"
    ]
    _run_suite("chrome", chrome_list)

    vs_code_list = [
        "0ed39f63-6049-43d4-ba4d-5fa2fe04a951",
        "53ad5833-3455-407b-bbc6-45b4c79ab8fb",
        "eabc805a-bfcf-4460-b250-ac92135819f6",
        "982d12a5-beab-424f-8d38-d2a48429e511",
        "4e60007a-f5be-4bfc-9723-c39affa0a6d3",
        "e2b5e914-ffe1-44d2-8e92-58f8c5d92bb2",
        "9439a27b-18ae-42d8-9778-5f68f891805e",
        "ea98c5d7-3cf9-4f9b-8ad3-366b58e0fcae",
        "930fdb3b-11a8-46fe-9bac-577332e2640e",
        "276cc624-87ea-4f08-ab93-f770e3790175",
        "9d425400-e9b2-4424-9a4b-d4c7abac4140"
    ]
    _run_suite("vs_code", vs_code_list)

    thunderbird_list = [
        "bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
        "7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3",
        "12086550-11c0-466b-b367-1d9e75b3910e",
        "06fe7178-4491-4589-810f-2e2bc9502122",
        "6766f2b8-8a72-417f-a9e5-56fcaa735837",
        "e1e75309-3ddb-4d09-92ec-de869c928143",
        "3d1682a7-0fb0-49ae-a4dc-a73afd2d06d5",
        "35253b65-1c19-4304-8aa4-6884b8218fc0",
        "d088f539-cab4-4f9a-ac92-9999fc3a656e",
        "2ad9387a-65d8-4e33-ad5b-7580065a27ca",
        "480bcfea-d68f-4aaa-a0a9-2589ef319381",
        "030eeff7-b492-4218-b312-701ec99ee0cc",
        "94760984-3ff5-41ee-8347-cf1af709fea0",
        "99146c54-4f37-4ab8-9327-5f3291665e1e",
        "c9e7eaf2-b1a1-4efc-a982-721972fa9f02"
    ]
    _run_suite("thunderbird", thunderbird_list)
|
||||
69
lib_run_single.py
Normal file
69
lib_run_single.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
||||
from wrapt_timeout_decorator import *
|
||||
|
||||
logger = logging.getLogger("desktopenv.experiment")
|
||||
|
||||
# Read the run settings from settings.json in the current working directory.
with open("./settings.json", "r") as file:
    # Parse the whole file as JSON.
    data = json.load(file)
# Value handed to the @timeout decorator that wraps run_single_example.
time_limit = data["time_limit"]
|
||||
|
||||
|
||||
@timeout(time_limit, use_signals=False)
def run_single_example(agent, env, example, max_steps, instruction, args, example_result_dir, scores):
    """Run one task episode, store its trajectory and append the score to ``scores``.

    The agent and the environment are reset, a screen recording is started and
    the agent is queried up to ``max_steps`` times. Each executed action yields
    a screenshot ``step_<n>_<timestamp>.png`` and one JSON line in
    ``traj.jsonl``. Afterwards the evaluation result is logged, appended to
    ``scores``, written to ``result.txt`` and the recording is saved as
    ``recording.mp4`` — all inside ``example_result_dir``.

    The whole call is wrapped by ``timeout(time_limit, use_signals=False)``.
    """
    agent.reset()
    obs = env.reset(task_config=example)
    done = False
    step_idx = 0
    env.controller.start_recording()

    trajectory_path = os.path.join(example_result_dir, "traj.jsonl")

    while not done and step_idx < max_steps:
        # Step numbers are 1-based in logs and file names.
        step_no = step_idx + 1
        predicted_actions = agent.predict(instruction, obs)
        for action in predicted_actions:
            # Timestamp is taken before the action is executed.
            action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
            logger.info("Step %d: %s", step_no, action)

            obs, reward, done, info = env.step(action, args.sleep_after_execution)

            logger.info("Reward: %.2f", reward)
            logger.info("Done: %s", done)
            logger.info("Info: %s", info)

            # Copy the screenshot produced by this step next to the trajectory.
            screenshot_name = f"step_{step_no}_{action_timestamp}.png"
            with open(os.path.join(example_result_dir, screenshot_name), "wb") as target:
                with open(obs['screenshot'], "rb") as source:
                    target.write(source.read())

            # One JSON object per executed action.
            step_record = {
                "step_num": step_no,
                "action_timestamp": action_timestamp,
                "action": action,
                "reward": reward,
                "done": done,
                "info": info,
                "screenshot_file": screenshot_name
            }
            with open(trajectory_path, "a") as traj_file:
                traj_file.write(json.dumps(step_record))
                traj_file.write("\n")

            if done:
                logger.info("The episode is done.")
                break
        step_idx += 1

    result = env.evaluate()
    logger.info("Result: %.2f", result)
    scores.append(result)
    with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as result_file:
        result_file.write(f"{result}\n")
    env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
|
||||
36
main.py
36
main.py
@@ -47,38 +47,38 @@ def human_agent():
|
||||
Runs the Gym environment with human input.
|
||||
"""
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-p', '--path', type=str, required=True, help="Path to the virtual machine .vmx file.")
|
||||
parser.add_argument('-s', '--snapshot', type=str, help="Name of the snapshot to restore.")
|
||||
parser.add_argument('-p', '--path', type=str, default=r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu3\Ubuntu3.vmx", help="Path to the virtual machine .vmx file.")
|
||||
parser.add_argument('-s', '--snapshot', type=str, default='init_state', help="Name of the snapshot to restore.")
|
||||
parser.add_argument('-e', '--example', type=str, help="Path to the example json file.")
|
||||
args = parser.parse_args(sys.argv[1:])
|
||||
|
||||
example_path = args.example if args.example is not None and os.path.exists(args.example) else \
|
||||
'evaluation_examples/examples/libreoffice_writer/6a33f9b9-0a56-4844-9c3f-96ec3ffb3ba2.json'
|
||||
with open(example_path, "r") as f:
|
||||
'evaluation_examples/examples/multi_apps/5990457f-2adb-467b-a4af-5c857c92d762.json'
|
||||
with open(example_path, "r", encoding="utf-8") as f:
|
||||
example = json.load(f)
|
||||
# change to your customized snapshot
|
||||
if args.snapshot is not None: example["snapshot"] = args.snapshot
|
||||
if args.snapshot is not None:
|
||||
example['snapshot'] = args.snapshot
|
||||
|
||||
assert os.path.exists(args.path), "The specified path to the .vmx file does not exist."
|
||||
env = DesktopEnv(
|
||||
path_to_vm=args.path,
|
||||
action_space="computer_13",
|
||||
task_config=example
|
||||
snapshot_name=args.snapshot,
|
||||
action_space="computer_13"
|
||||
)
|
||||
# reset the environment to certain snapshot
|
||||
observation = env.reset()
|
||||
logger.info('\x1b[32m[TASK INSTRUCTION]: \x1b[32;3m%s\x1b[0m', example["instruction"])
|
||||
observation = env.reset(task_config=example)
|
||||
done = False
|
||||
logger.info('\x1b[32m[TASK INSTRUCTION]: \x1b[32;3m%s\x1b[0m', example["instruction"])
|
||||
|
||||
trajectory = [
|
||||
# {
|
||||
# "action_type": "MOVE_TO",
|
||||
# "parameters": {
|
||||
# "x": 754,
|
||||
# "y": 1057
|
||||
# }
|
||||
# },
|
||||
# {"action_type": "CLICK", "parameters": {"button": "right", "num_clicks": 1}}
|
||||
{
|
||||
"action_type": "MOVE_TO", #
|
||||
"parameters": {
|
||||
"x": 754,
|
||||
"y": 1057
|
||||
}
|
||||
},
|
||||
{"action_type": "CLICK", "parameters": {"button": "right", "num_clicks": 1}}
|
||||
]
|
||||
|
||||
for i in range(len(trajectory)):
|
||||
|
||||
@@ -26,7 +26,7 @@ def find_leaf_nodes(xlm_file_str):
|
||||
|
||||
state_ns = "uri:deskat:state.at-spi.gnome.org"
|
||||
component_ns = "uri:deskat:component.at-spi.gnome.org"
|
||||
def judge_node(node: ET, platform="ubuntu") -> bool:
|
||||
def judge_node(node: ET, platform="ubuntu", check_image=False) -> bool:
|
||||
keeps: bool = node.tag.startswith("document")\
|
||||
or node.tag.endswith("item")\
|
||||
or node.tag.endswith("button")\
|
||||
@@ -55,23 +55,25 @@ def judge_node(node: ET, platform="ubuntu") -> bool:
|
||||
or platform=="windows"\
|
||||
and node.get("{{{:}}}visible".format(state_ns), "false")=="true"\
|
||||
)\
|
||||
and ( node.get("{{{:}}}enabled".format(state_ns), "false")=="true"\
|
||||
or node.get("{{{:}}}editable".format(state_ns), "false")=="true"\
|
||||
or node.get("{{{:}}}expandable".format(state_ns), "false")=="true"\
|
||||
or node.get("{{{:}}}checkable".format(state_ns), "false")=="true"
|
||||
)\
|
||||
and (node.get("name", "") != "" or node.text is not None and len(node.text)>0)
|
||||
and ( node.get("{{{:}}}enabled".format(state_ns), "false")=="true"\
|
||||
or node.get("{{{:}}}editable".format(state_ns), "false")=="true"\
|
||||
or node.get("{{{:}}}expandable".format(state_ns), "false")=="true"\
|
||||
or node.get("{{{:}}}checkable".format(state_ns), "false")=="true"
|
||||
)\
|
||||
and ( node.get("name", "") != "" or node.text is not None and len(node.text)>0\
|
||||
or check_image and node.get("image", "false")=="true"
|
||||
)
|
||||
|
||||
coordinates: Tuple[int, int] = eval(node.get("{{{:}}}screencoord".format(component_ns), "(-1, -1)"))
|
||||
sizes: Tuple[int, int] = eval(node.get("{{{:}}}size".format(component_ns), "(-1, -1)"))
|
||||
keeps = keeps and coordinates[0]>0 and coordinates[1]>0 and sizes[0]>0 and sizes[1]>0
|
||||
return keeps
|
||||
|
||||
def filter_nodes(root: ET, platform="ubuntu"):
|
||||
def filter_nodes(root: ET, platform="ubuntu", check_image=False):
|
||||
filtered_nodes = []
|
||||
|
||||
for node in root.iter():
|
||||
if judge_node(node, platform):
|
||||
if judge_node(node, platform, check_image):
|
||||
filtered_nodes.append(node)
|
||||
#print(ET.tostring(node, encoding="unicode"))
|
||||
|
||||
@@ -155,12 +157,12 @@ def print_nodes_with_indent(nodes, indent=0):
|
||||
|
||||
if __name__ == '__main__':
|
||||
import json
|
||||
with open('4.json', 'r', encoding='utf-8') as f:
|
||||
xml_file_str = json.load(f)["AT"]
|
||||
with open('selection_sorted(imaged).xml', 'r', encoding='utf-8') as f:
|
||||
xml_file_str = f.read()
|
||||
filtered_nodes = filter_nodes(ET.fromstring(xml_file_str))
|
||||
print(len(filtered_nodes))
|
||||
masks = draw_bounding_boxes( filtered_nodes, '4.png'
|
||||
, '4.a.png'
|
||||
masks = draw_bounding_boxes( filtered_nodes, 'selection_sorted(imaged).png'
|
||||
, 'selection_sorted(imaged).ai.png'
|
||||
)
|
||||
|
||||
# print(masks)
|
||||
|
||||
@@ -5,32 +5,23 @@ import os
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
import openai
|
||||
import xml.etree.ElementTree as ET
|
||||
from http import HTTPStatus
|
||||
from io import BytesIO
|
||||
from typing import Dict, List
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
from google.api_core.exceptions import InvalidArgument
|
||||
import backoff
|
||||
import dashscope
|
||||
import google.generativeai as genai
|
||||
import openai
|
||||
import requests
|
||||
from PIL import Image
|
||||
from openai import (
|
||||
APIConnectionError,
|
||||
APIError,
|
||||
RateLimitError
|
||||
)
|
||||
|
||||
from mm_agents.accessibility_tree_wrap.heuristic_retrieve import find_leaf_nodes, filter_nodes, draw_bounding_boxes
|
||||
from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION, \
|
||||
SYS_PROMPT_IN_A11Y_OUT_CODE, SYS_PROMPT_IN_A11Y_OUT_ACTION, \
|
||||
SYS_PROMPT_IN_BOTH_OUT_CODE, SYS_PROMPT_IN_BOTH_OUT_ACTION, \
|
||||
SYS_PROMPT_IN_SOM_A11Y_OUT_TAG, \
|
||||
SYS_PROMPT_SEEACT, ACTION_DESCRIPTION_PROMPT_SEEACT, ACTION_GROUNDING_PROMPT_SEEACT
|
||||
|
||||
import logging
|
||||
# todo: cross-check with visualwebarena
|
||||
SYS_PROMPT_IN_SOM_OUT_TAG
|
||||
|
||||
logger = logging.getLogger("desktopenv.agent")
|
||||
|
||||
@@ -42,10 +33,10 @@ def encode_image(image_path):
|
||||
|
||||
|
||||
def linearize_accessibility_tree(accessibility_tree):
|
||||
#leaf_nodes = find_leaf_nodes(accessibility_tree)
|
||||
# leaf_nodes = find_leaf_nodes(accessibility_tree)
|
||||
filtered_nodes = filter_nodes(ET.fromstring(accessibility_tree))
|
||||
|
||||
linearized_accessibility_tree = "tag\tname\ttext\tposition\tsize\n"
|
||||
linearized_accessibility_tree = "tag\tname\ttext\tposition (top-left x&y)\tsize (w&h)\n"
|
||||
# Linearize the accessibility tree nodes into a table format
|
||||
|
||||
for node in filtered_nodes:
|
||||
@@ -73,7 +64,8 @@ def tag_screenshot(screenshot, accessibility_tree):
|
||||
uuid_str = str(uuid.uuid4())
|
||||
os.makedirs("tmp/images", exist_ok=True)
|
||||
tagged_screenshot_file_path = os.path.join("tmp/images", uuid_str + ".png")
|
||||
nodes = filter_nodes(find_leaf_nodes(accessibility_tree))
|
||||
# nodes = filter_nodes(find_leaf_nodes(accessibility_tree))
|
||||
nodes = filter_nodes(ET.fromstring(accessibility_tree), check_image=True)
|
||||
# Make tag screenshot
|
||||
marks, drew_nodes = draw_bounding_boxes(nodes, screenshot, tagged_screenshot_file_path)
|
||||
|
||||
@@ -169,79 +161,66 @@ def parse_code_from_som_string(input_string, masks):
|
||||
return actions
|
||||
|
||||
|
||||
class GPT4v_Agent:
|
||||
class PromptAgent:
|
||||
def __init__(
|
||||
self,
|
||||
api_key,
|
||||
instruction,
|
||||
model="gpt-4-vision-preview",
|
||||
max_tokens=500,
|
||||
max_tokens=1500,
|
||||
top_p=0.9,
|
||||
temperature=0.5,
|
||||
action_space="computer_13",
|
||||
exp="screenshot_a11y_tree"
|
||||
# exp can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som", "seeact"]
|
||||
observation_type="screenshot_a11y_tree",
|
||||
# observation_type can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"]
|
||||
max_trajectory_length=3
|
||||
):
|
||||
|
||||
self.instruction = instruction
|
||||
self.model = model
|
||||
self.max_tokens = max_tokens
|
||||
self.top_p = top_p
|
||||
self.temperature = temperature
|
||||
self.action_space = action_space
|
||||
self.exp = exp
|
||||
self.max_trajectory_length = 3
|
||||
|
||||
self.headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {api_key}"
|
||||
}
|
||||
self.observation_type = observation_type
|
||||
self.max_trajectory_length = max_trajectory_length
|
||||
|
||||
self.thoughts = []
|
||||
self.actions = []
|
||||
self.observations = []
|
||||
|
||||
if exp == "screenshot":
|
||||
if observation_type == "screenshot":
|
||||
if action_space == "computer_13":
|
||||
self.system_message = SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION
|
||||
elif action_space == "pyautogui":
|
||||
self.system_message = SYS_PROMPT_IN_SCREENSHOT_OUT_CODE
|
||||
else:
|
||||
raise ValueError("Invalid action space: " + action_space)
|
||||
elif exp == "a11y_tree":
|
||||
elif observation_type == "a11y_tree":
|
||||
if action_space == "computer_13":
|
||||
self.system_message = SYS_PROMPT_IN_A11Y_OUT_ACTION
|
||||
elif action_space == "pyautogui":
|
||||
self.system_message = SYS_PROMPT_IN_A11Y_OUT_CODE
|
||||
else:
|
||||
raise ValueError("Invalid action space: " + action_space)
|
||||
elif exp == "both":
|
||||
elif observation_type == "screenshot_a11y_tree":
|
||||
if action_space == "computer_13":
|
||||
self.system_message = SYS_PROMPT_IN_BOTH_OUT_ACTION
|
||||
elif action_space == "pyautogui":
|
||||
self.system_message = SYS_PROMPT_IN_BOTH_OUT_CODE
|
||||
else:
|
||||
raise ValueError("Invalid action space: " + action_space)
|
||||
elif exp == "som":
|
||||
elif observation_type == "som":
|
||||
if action_space == "computer_13":
|
||||
raise ValueError("Invalid action space: " + action_space)
|
||||
elif action_space == "pyautogui":
|
||||
self.system_message = SYS_PROMPT_IN_SOM_A11Y_OUT_TAG
|
||||
else:
|
||||
raise ValueError("Invalid action space: " + action_space)
|
||||
elif exp == "seeact":
|
||||
if action_space == "computer_13":
|
||||
raise ValueError("Invalid action space: " + action_space)
|
||||
elif action_space == "pyautogui":
|
||||
self.system_message = SYS_PROMPT_SEEACT
|
||||
self.system_message = SYS_PROMPT_IN_SOM_OUT_TAG
|
||||
else:
|
||||
raise ValueError("Invalid action space: " + action_space)
|
||||
else:
|
||||
raise ValueError("Invalid experiment type: " + exp)
|
||||
raise ValueError("Invalid experiment type: " + observation_type)
|
||||
|
||||
self.system_message = self.system_message + "\nYou are asked to complete the following task: {}".format(
|
||||
self.instruction)
|
||||
|
||||
def predict(self, obs: Dict) -> List:
|
||||
def predict(self, instruction: str, obs: Dict) -> List:
|
||||
"""
|
||||
Predict the next action(s) based on the current observation.
|
||||
"""
|
||||
system_message = self.system_message + "\nYou are asked to complete the following task: {}".format(instruction)
|
||||
|
||||
# Prepare the payload for the API call
|
||||
messages = []
|
||||
@@ -252,7 +231,7 @@ class GPT4v_Agent:
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": self.system_message
|
||||
"text": system_message
|
||||
},
|
||||
]
|
||||
})
|
||||
@@ -273,7 +252,7 @@ class GPT4v_Agent:
|
||||
for previous_obs, previous_action, previous_thought in zip(_observations, _actions, _thoughts):
|
||||
|
||||
# {{{1
|
||||
if self.exp == "both":
|
||||
if self.observation_type == "screenshot_a11y_tree":
|
||||
_screenshot = previous_obs["screenshot"]
|
||||
_linearized_accessibility_tree = previous_obs["accessibility_tree"]
|
||||
logger.debug("LINEAR AT: %s", _linearized_accessibility_tree)
|
||||
@@ -295,18 +274,15 @@ class GPT4v_Agent:
|
||||
}
|
||||
]
|
||||
})
|
||||
elif self.exp in ["som", "seeact"]:
|
||||
elif self.observation_type in ["som"]:
|
||||
_screenshot = previous_obs["screenshot"]
|
||||
_linearized_accessibility_tree = previous_obs["accessibility_tree"]
|
||||
logger.debug("LINEAR AT: %s", _linearized_accessibility_tree)
|
||||
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Given the tagged screenshot and info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format(
|
||||
_linearized_accessibility_tree)
|
||||
"text": "Given the tagged screenshot as below. What's the next step that you will do to help with the task?"
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
@@ -317,7 +293,7 @@ class GPT4v_Agent:
|
||||
}
|
||||
]
|
||||
})
|
||||
elif self.exp == "screenshot":
|
||||
elif self.observation_type == "screenshot":
|
||||
_screenshot = previous_obs["screenshot"]
|
||||
|
||||
messages.append({
|
||||
@@ -336,7 +312,7 @@ class GPT4v_Agent:
|
||||
}
|
||||
]
|
||||
})
|
||||
elif self.exp == "a11y_tree":
|
||||
elif self.observation_type == "a11y_tree":
|
||||
_linearized_accessibility_tree = previous_obs["accessibility_tree"]
|
||||
|
||||
messages.append({
|
||||
@@ -350,7 +326,7 @@ class GPT4v_Agent:
|
||||
]
|
||||
})
|
||||
else:
|
||||
raise ValueError("Invalid experiment type: " + self.exp) # 1}}}
|
||||
raise ValueError("Invalid observation_type type: " + self.observation_type) # 1}}}
|
||||
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
@@ -363,11 +339,11 @@ class GPT4v_Agent:
|
||||
})
|
||||
|
||||
# {{{1
|
||||
if self.exp in ["screenshot", "both"]:
|
||||
if self.observation_type in ["screenshot", "screenshot_a11y_tree"]:
|
||||
base64_image = encode_image(obs["screenshot"])
|
||||
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])
|
||||
|
||||
if self.exp == "both":
|
||||
if self.observation_type == "screenshot_a11y_tree":
|
||||
self.observations.append({
|
||||
"screenshot": base64_image,
|
||||
"accessibility_tree": linearized_accessibility_tree
|
||||
@@ -384,7 +360,7 @@ class GPT4v_Agent:
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Given the screenshot as below. What's the next step that you will do to help with the task?"
|
||||
if self.exp == "screenshot"
|
||||
if self.observation_type == "screenshot"
|
||||
else "Given the screenshot and info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format(
|
||||
linearized_accessibility_tree)
|
||||
},
|
||||
@@ -397,7 +373,7 @@ class GPT4v_Agent:
|
||||
}
|
||||
]
|
||||
})
|
||||
elif self.exp == "a11y_tree":
|
||||
elif self.observation_type == "a11y_tree":
|
||||
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])
|
||||
|
||||
self.observations.append({
|
||||
@@ -415,15 +391,13 @@ class GPT4v_Agent:
|
||||
}
|
||||
]
|
||||
})
|
||||
elif self.exp == "som":
|
||||
elif self.observation_type == "som":
|
||||
# Add som to the screenshot
|
||||
masks, drew_nodes, tagged_screenshot = tag_screenshot(obs["screenshot"], obs["accessibility_tree"])
|
||||
base64_image = encode_image(tagged_screenshot)
|
||||
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])
|
||||
|
||||
self.observations.append({
|
||||
"screenshot": base64_image,
|
||||
"accessibility_tree": linearized_accessibility_tree
|
||||
"screenshot": base64_image
|
||||
})
|
||||
|
||||
messages.append({
|
||||
@@ -431,35 +405,7 @@ class GPT4v_Agent:
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Given the tagged screenshot and info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format(
|
||||
linearized_accessibility_tree)
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/png;base64,{base64_image}",
|
||||
"detail": "high"
|
||||
}
|
||||
}
|
||||
]
|
||||
})
|
||||
elif self.exp == "seeact":
|
||||
# Add som to the screenshot
|
||||
masks, drew_nodes, tagged_screenshot = tag_screenshot(obs["screenshot"], obs["accessibility_tree"])
|
||||
base64_image = encode_image(tagged_screenshot)
|
||||
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])
|
||||
|
||||
self.observations.append({
|
||||
"screenshot": base64_image,
|
||||
"accessibility_tree": linearized_accessibility_tree
|
||||
})
|
||||
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": ACTION_DESCRIPTION_PROMPT_SEEACT.format(linearized_accessibility_tree)
|
||||
"text": "Given the tagged screenshot as below. What's the next step that you will do to help with the task?"
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
@@ -471,53 +417,26 @@ class GPT4v_Agent:
|
||||
]
|
||||
})
|
||||
else:
|
||||
raise ValueError("Invalid experiment type: " + self.exp) # 1}}}
|
||||
|
||||
with open("messages.json", "w") as f:
|
||||
f.write(json.dumps(messages, indent=4))
|
||||
raise ValueError("Invalid observation_type type: " + self.observation_type) # 1}}}
|
||||
|
||||
# with open("messages.json", "w") as f:
|
||||
# f.write(json.dumps(messages, indent=4))
|
||||
|
||||
logger.info("Generating content with GPT model: %s", self.model)
|
||||
response = self.call_llm({
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"max_tokens": self.max_tokens
|
||||
"max_tokens": self.max_tokens,
|
||||
"top_p": self.top_p,
|
||||
"temperature": self.temperature
|
||||
})
|
||||
|
||||
logger.debug("RESPONSE: %s", response)
|
||||
|
||||
if self.exp == "seeact":
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": response
|
||||
}
|
||||
]
|
||||
})
|
||||
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "{}\n\nWhat's the next step that you will do to help with the task?".format(
|
||||
ACTION_GROUNDING_PROMPT_SEEACT)
|
||||
}
|
||||
]
|
||||
})
|
||||
|
||||
response = self.call_llm({
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"max_tokens": self.max_tokens
|
||||
})
|
||||
print(response)
|
||||
logger.info("RESPONSE: %s", response)
|
||||
|
||||
try:
|
||||
actions = self.parse_actions(response, masks)
|
||||
self.thoughts.append(response)
|
||||
except Exception as e:
|
||||
except ValueError as e:
|
||||
print("Failed to parse action from response", e)
|
||||
actions = None
|
||||
self.thoughts.append("")
|
||||
@@ -526,86 +445,160 @@ class GPT4v_Agent:
|
||||
|
||||
@backoff.on_exception(
|
||||
backoff.expo,
|
||||
(Exception),
|
||||
max_tries=10
|
||||
# here you should add more model exceptions as you want,
|
||||
# but you are forbidden to add "Exception", that is, a common type of exception
|
||||
# because we want to catch this kind of Exception in the outside to ensure each example won't exceed the time limit
|
||||
(openai.RateLimitError,
|
||||
openai.BadRequestError,
|
||||
openai.InternalServerError,
|
||||
InvalidArgument),
|
||||
max_tries=5
|
||||
)
|
||||
def call_llm(self, payload):
|
||||
|
||||
if self.model.startswith("gpt"):
|
||||
logger.info("Generating content with GPT model: %s", self.model)
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}"
|
||||
}
|
||||
# logger.info("Generating content with GPT model: %s", self.model)
|
||||
response = requests.post(
|
||||
"https://api.openai.com/v1/chat/completions",
|
||||
headers=self.headers,
|
||||
headers=headers,
|
||||
json=payload
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
if response.json()['error']['code'] == "context_length_exceeded":
|
||||
print("Context length exceeded. Retrying with a smaller context.")
|
||||
payload["messages"] = payload["messages"][-1:]
|
||||
logger.error("Context length exceeded. Retrying with a smaller context.")
|
||||
payload["messages"] = [payload["messages"][0]] + payload["messages"][-1:]
|
||||
retry_response = requests.post(
|
||||
"https://api.openai.com/v1/chat/completions",
|
||||
headers=self.headers,
|
||||
headers=headers,
|
||||
json=payload
|
||||
)
|
||||
if retry_response.status_code != 200:
|
||||
print("Failed to call LLM: " + retry_response.text)
|
||||
logger.error(
|
||||
"Failed to call LLM even after attempt on shortening the history: " + retry_response.text)
|
||||
return ""
|
||||
|
||||
print("Failed to call LLM: " + response.text)
|
||||
logger.error("Failed to call LLM: " + response.text)
|
||||
time.sleep(5)
|
||||
return ""
|
||||
else:
|
||||
return response.json()['choices'][0]['message']['content']
|
||||
|
||||
elif self.model.startswith("mistral"):
|
||||
print("call mistral")
|
||||
elif self.model.startswith("claude"):
|
||||
messages = payload["messages"]
|
||||
max_tokens = payload["max_tokens"]
|
||||
top_p = payload["top_p"]
|
||||
temperature = payload["temperature"]
|
||||
|
||||
misrtal_messages = []
|
||||
|
||||
claude_messages = []
|
||||
|
||||
for i, message in enumerate(messages):
|
||||
mistral_message = {
|
||||
claude_message = {
|
||||
"role": message["role"],
|
||||
"content": []
|
||||
}
|
||||
|
||||
assert len(message["content"]) in [1, 2], "One text, or one text with one image"
|
||||
for part in message["content"]:
|
||||
mistral_message['content'] = part['text'] if part['type'] == "text" else None
|
||||
|
||||
if part['type'] == "image_url":
|
||||
image_source = {}
|
||||
image_source["type"] = "base64"
|
||||
image_source["media_type"] = "image/png"
|
||||
image_source["data"] = part['image_url']['url'].replace("data:image/png;base64,", "")
|
||||
claude_message['content'].append({"type": "image", "source": image_source})
|
||||
|
||||
if part['type'] == "text":
|
||||
claude_message['content'].append({"type": "text", "text": part['text']})
|
||||
|
||||
claude_messages.append(claude_message)
|
||||
|
||||
misrtal_messages.append(mistral_message)
|
||||
# Claude does not support a system message in our endpoint, so we prepend it to the first user message
|
||||
if claude_messages[0]['role'] == "system":
|
||||
claude_system_message_item = claude_messages[0]['content'][0]
|
||||
claude_messages[1]['content'].insert(0, claude_system_message_item)
|
||||
claude_messages.pop(0)
|
||||
|
||||
# the mistral not support system message in our endpoint, so we concatenate it at the first user message
|
||||
if misrtal_messages[0]['role'] == "system":
|
||||
misrtal_messages[1]['content'] = misrtal_messages[0]['content'] + "\n" + misrtal_messages[1]['content']
|
||||
misrtal_messages.pop(0)
|
||||
|
||||
# openai.api_base = "http://localhost:8000/v1"
|
||||
# openai.api_key = "test"
|
||||
# response = openai.ChatCompletion.create(
|
||||
# messages=misrtal_messages,
|
||||
# model="Mixtral-8x7B-Instruct-v0.1"
|
||||
# )
|
||||
headers = {
|
||||
"x-api-key": os.environ["ANTHROPIC_API_KEY"],
|
||||
"anthropic-version": "2023-06-01",
|
||||
"content-type": "application/json"
|
||||
}
|
||||
|
||||
from openai import OpenAI
|
||||
TOGETHER_API_KEY = "d011650e7537797148fb6170ec1e0be7ae75160375686fae02277136078e90d2"
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"max_tokens": max_tokens,
|
||||
"messages": claude_messages
|
||||
}
|
||||
|
||||
client = OpenAI(api_key=TOGETHER_API_KEY,
|
||||
base_url='https://api.together.xyz',
|
||||
)
|
||||
logger.info("Generating content with Mistral model: %s", self.model)
|
||||
response = client.chat.completions.create(
|
||||
messages=misrtal_messages,
|
||||
model="mistralai/Mixtral-8x7B-Instruct-v0.1",
|
||||
max_tokens=1024
|
||||
response = requests.post(
|
||||
"https://api.anthropic.com/v1/messages",
|
||||
headers=headers,
|
||||
json=payload
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
|
||||
try:
|
||||
# return response['choices'][0]['message']['content']
|
||||
return response.choices[0].message.content
|
||||
except Exception as e:
|
||||
print("Failed to call LLM: " + str(e))
|
||||
logger.error("Failed to call LLM: " + response.text)
|
||||
time.sleep(5)
|
||||
return ""
|
||||
else:
|
||||
return response.json()['content'][0]['text']
|
||||
|
||||
|
||||
# elif self.model.startswith("mistral"):
|
||||
# print("Call mistral")
|
||||
# messages = payload["messages"]
|
||||
# max_tokens = payload["max_tokens"]
|
||||
#
|
||||
# misrtal_messages = []
|
||||
#
|
||||
# for i, message in enumerate(messages):
|
||||
# mistral_message = {
|
||||
# "role": message["role"],
|
||||
# "content": []
|
||||
# }
|
||||
#
|
||||
# for part in message["content"]:
|
||||
# mistral_message['content'] = part['text'] if part['type'] == "text" else None
|
||||
#
|
||||
# misrtal_messages.append(mistral_message)
|
||||
#
|
||||
# # the mistral not support system message in our endpoint, so we concatenate it at the first user message
|
||||
# if misrtal_messages[0]['role'] == "system":
|
||||
# misrtal_messages[1]['content'] = misrtal_messages[0]['content'] + "\n" + misrtal_messages[1]['content']
|
||||
# misrtal_messages.pop(0)
|
||||
#
|
||||
# # openai.api_base = "http://localhost:8000/v1"
|
||||
# # openai.api_key = "test"
|
||||
# # response = openai.ChatCompletion.create(
|
||||
# # messages=misrtal_messages,
|
||||
# # model="Mixtral-8x7B-Instruct-v0.1"
|
||||
# # )
|
||||
#
|
||||
# from openai import OpenAI
|
||||
# TOGETHER_API_KEY = "d011650e7537797148fb6170ec1e0be7ae75160375686fae02277136078e90d2"
|
||||
#
|
||||
# client = OpenAI(api_key=TOGETHER_API_KEY,
|
||||
# base_url='https://api.together.xyz',
|
||||
# )
|
||||
# logger.info("Generating content with Mistral model: %s", self.model)
|
||||
# response = client.chat.completions.create(
|
||||
# messages=misrtal_messages,
|
||||
# model="mistralai/Mixtral-8x7B-Instruct-v0.1",
|
||||
# max_tokens=1024
|
||||
# )
|
||||
#
|
||||
# try:
|
||||
# # return response['choices'][0]['message']['content']
|
||||
# return response.choices[0].message.content
|
||||
# except Exception as e:
|
||||
# print("Failed to call LLM: " + str(e))
|
||||
# return ""
|
||||
|
||||
elif self.model.startswith("gemini"):
|
||||
def encoded_img_to_pil_img(data_str):
|
||||
@@ -617,6 +610,8 @@ class GPT4v_Agent:
|
||||
|
||||
messages = payload["messages"]
|
||||
max_tokens = payload["max_tokens"]
|
||||
top_p = payload["top_p"]
|
||||
temperature = payload["temperature"]
|
||||
|
||||
gemini_messages = []
|
||||
for i, message in enumerate(messages):
|
||||
@@ -653,8 +648,9 @@ class GPT4v_Agent:
|
||||
for message in gemini_messages:
|
||||
message_history_str += "<|" + message['role'] + "|>\n" + message['parts'][0] + "\n"
|
||||
gemini_messages = [{"role": "user", "parts": [message_history_str, gemini_messages[-1]['parts'][1]]}]
|
||||
# gemini_messages[-1]['parts'][1].save("output.png", "PNG")
|
||||
|
||||
print(gemini_messages)
|
||||
# print(gemini_messages)
|
||||
api_key = os.environ.get("GENAI_API_KEY")
|
||||
assert api_key is not None, "Please set the GENAI_API_KEY environment variable"
|
||||
genai.configure(api_key=api_key)
|
||||
@@ -662,7 +658,16 @@ class GPT4v_Agent:
|
||||
response = genai.GenerativeModel(self.model).generate_content(
|
||||
gemini_messages,
|
||||
generation_config={
|
||||
"max_output_tokens": max_tokens
|
||||
"candidate_count": 1,
|
||||
"max_output_tokens": max_tokens,
|
||||
"top_p": top_p,
|
||||
"temperature": temperature
|
||||
},
|
||||
safety_settings={
|
||||
"harassment": "block_none",
|
||||
"hate": "block_none",
|
||||
"sex": "block_none",
|
||||
"danger": "block_none"
|
||||
}
|
||||
)
|
||||
|
||||
@@ -673,6 +678,8 @@ class GPT4v_Agent:
|
||||
elif self.model.startswith("qwen"):
|
||||
messages = payload["messages"]
|
||||
max_tokens = payload["max_tokens"]
|
||||
top_p = payload["top_p"]
|
||||
temperature = payload["temperature"]
|
||||
|
||||
qwen_messages = []
|
||||
|
||||
@@ -683,13 +690,16 @@ class GPT4v_Agent:
|
||||
}
|
||||
assert len(message["content"]) in [1, 2], "One text, or one text with one image"
|
||||
for part in message["content"]:
|
||||
qwen_message['content'].append({"image": part['image_url']['url']}) if part['type'] == "image_url" else None
|
||||
qwen_message['content'].append({"image": part['image_url']['url']}) if part[
|
||||
'type'] == "image_url" else None
|
||||
qwen_message['content'].append({"text": part['text']}) if part['type'] == "text" else None
|
||||
|
||||
qwen_messages.append(qwen_message)
|
||||
|
||||
response = dashscope.MultiModalConversation.call(model='qwen-vl-plus',
|
||||
messages=messages)
|
||||
response = dashscope.MultiModalConversation.call(
|
||||
model='qwen-vl-plus',
|
||||
messages=messages, # todo: add the hyperparameters
|
||||
)
|
||||
# The response status_code is HTTPStatus.OK indicate success,
|
||||
# otherwise indicate request is failed, you can get error code
|
||||
# and message from code and message.
|
||||
@@ -708,7 +718,7 @@ class GPT4v_Agent:
|
||||
|
||||
def parse_actions(self, response: str, masks=None):
|
||||
|
||||
if self.exp in ["screenshot", "a11y_tree", "both"]:
|
||||
if self.observation_type in ["screenshot", "a11y_tree", "screenshot_a11y_tree"]:
|
||||
# parse from the response
|
||||
if self.action_space == "computer_13":
|
||||
actions = parse_actions_from_string(response)
|
||||
@@ -720,7 +730,7 @@ class GPT4v_Agent:
|
||||
self.actions.append(actions)
|
||||
|
||||
return actions
|
||||
elif self.exp in ["som", "seeact"]:
|
||||
elif self.observation_type in ["som"]:
|
||||
# parse from the response
|
||||
if self.action_space == "computer_13":
|
||||
raise ValueError("Invalid action space: " + self.action_space)
|
||||
@@ -732,3 +742,8 @@ class GPT4v_Agent:
|
||||
self.actions.append(actions)
|
||||
|
||||
return actions
|
||||
|
||||
def reset(self):
|
||||
self.thoughts = []
|
||||
self.actions = []
|
||||
self.observations = []
|
||||
@@ -798,10 +798,10 @@ You MUST choose and ONLY CHOOSE from the action space above, otherwise your acti
|
||||
You CAN predict multiple actions at one step, but you should only return one action for each step.
|
||||
""".strip()
|
||||
|
||||
SYS_PROMPT_IN_SOM_A11Y_OUT_TAG = """
|
||||
SYS_PROMPT_IN_SOM_OUT_TAG = """
|
||||
You are an agent which follow my instruction and perform desktop computer tasks as instructed.
|
||||
You have good knowledge of computer and good internet connection and assume your code will run on a computer for controlling the mouse and keyboard.
|
||||
For each step, you will get an observation of the desktop by 1) a screenshot; and 2) accessibility tree, which is based on AT-SPI library.
|
||||
For each step, you will get an observation of the desktop by a screenshot with interact-able elements marked with numerical tags. And you will predict the action of the computer based on the image.
|
||||
|
||||
You are required to use `pyautogui` to perform the action grounded to the observation, but DONOT use the `pyautogui.locateCenterOnScreen` function to locate the element you want to operate with since we have no image of the element you want to operate with. DONOT USE `pyautogui.screenshot()` to make screenshot.
|
||||
You can replace x, y in the code with the tag of the element you want to operate with. such as:
|
||||
|
||||
245
run.py
Normal file
245
run.py
Normal file
@@ -0,0 +1,245 @@
|
||||
"""Script to run end-to-end evaluation on the benchmark.
|
||||
Utils and basic architecture credit to https://github.com/web-arena-x/webarena/blob/main/run.py.
|
||||
"""
|
||||
import argparse
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
import lib_run_single
|
||||
from desktop_env.envs.desktop_env import DesktopEnv
|
||||
from mm_agents.agent import PromptAgent
|
||||
|
||||
# Logger Configs {{{ #
# The root logger accepts everything; each handler then filters by its own
# level (and, for two of them, by logger name).
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

# One timestamp per process so the log files of a run share a suffix.
datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")

# FileHandler opens its file on construction and does not create missing
# parent directories, so make sure "logs" exists before building the handlers.
os.makedirs("logs", exist_ok=True)

file_handler = logging.FileHandler(os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8")
debug_handler = logging.FileHandler(os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8")
stdout_handler = logging.StreamHandler(sys.stdout)
sdebug_handler = logging.FileHandler(os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8")

file_handler.setLevel(logging.INFO)
debug_handler.setLevel(logging.DEBUG)
stdout_handler.setLevel(logging.INFO)
sdebug_handler.setLevel(logging.DEBUG)

# NOTE: the format embeds ANSI colour escapes; they are written to the log
# files as well as to stdout.
formatter = logging.Formatter(
    fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s")
file_handler.setFormatter(formatter)
debug_handler.setFormatter(formatter)
stdout_handler.setFormatter(formatter)
sdebug_handler.setFormatter(formatter)

# Only records from the "desktopenv" logger hierarchy reach stdout and the
# "sdebug" file; the other two files receive records from every logger.
stdout_handler.addFilter(logging.Filter("desktopenv"))
sdebug_handler.addFilter(logging.Filter("desktopenv"))

logger.addHandler(file_handler)
logger.addHandler(debug_handler)
logger.addHandler(stdout_handler)
logger.addHandler(sdebug_handler)
# }}} Logger Configs #

# From here on `logger` is this script's own logger, not the root logger.
logger = logging.getLogger("desktopenv.experiment")
|
||||
|
||||
|
||||
def config() -> argparse.Namespace:
    """Parse the command-line options of the evaluation script.

    Options are read from ``sys.argv`` and cover four groups: the desktop
    environment, the agent, the language model, and where results are stored.

    Returns:
        argparse.Namespace: the parsed options, with defaults filled in.
    """
    cli = argparse.ArgumentParser(description="Run end-to-end evaluation on the benchmark")

    # environment config
    cli.add_argument(
        "--path_to_vm",
        type=str,
        default=r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu\Ubuntu.vmx",
    )
    cli.add_argument("--headless", action="store_true", help="Run in headless machine")
    cli.add_argument("--action_space", type=str, default="pyautogui", help="Action type")
    cli.add_argument(
        "--observation_type",
        choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"],
        default="som",
        help="Observation type",
    )
    cli.add_argument("--screen_width", type=int, default=1920)
    cli.add_argument("--screen_height", type=int, default=1080)
    cli.add_argument("--sleep_after_execution", type=float, default=0.0)
    cli.add_argument("--max_steps", type=int, default=15)

    # agent config
    cli.add_argument("--max_trajectory_length", type=int, default=3)
    cli.add_argument("--test_config_base_dir", type=str, default="evaluation_examples")

    # lm config
    cli.add_argument("--model", type=str, default="gpt-4-vision-preview")
    cli.add_argument("--temperature", type=float, default=1.0)
    cli.add_argument("--top_p", type=float, default=0.9)
    cli.add_argument("--max_tokens", type=int, default=1500)
    cli.add_argument("--stop_token", type=str, default=None)

    # logging related
    cli.add_argument("--result_dir", type=str, default="./results")

    return cli.parse_args()
|
||||
|
||||
|
||||
def test(
        args: argparse.Namespace,
        test_all_meta: dict
) -> None:
    """Run every example listed in ``test_all_meta`` and log the average score.

    One ``PromptAgent`` and one ``DesktopEnv`` are created and shared by all
    examples. For each example the task config is loaded from
    ``<test_config_base_dir>/examples/<domain>/<example_id>.json`` and the run
    is delegated to ``lib_run_single.run_single_example``. Per-example output
    goes to
    ``<result_dir>/<action_space>/<observation_type>/<model>/<domain>/<example_id>``.

    A failing example does not stop the run: the exception is logged with its
    traceback, the screen recording is stopped on a best-effort basis, and an
    ``{"Error": ...}`` line is appended to the example's ``traj.jsonl``.

    Args:
        args: parsed command-line options (see ``config``).
        test_all_meta: mapping of domain name to the example ids to run.
    """
    # Passed to run_single_example for every example; presumably that helper
    # appends each example's score to it — confirm in lib_run_single.
    scores = []
    max_steps = args.max_steps

    # log args
    logger.info("Args: %s", args)

    agent = PromptAgent(
        model=args.model,
        max_tokens=args.max_tokens,
        action_space=args.action_space,
        observation_type=args.observation_type,
        max_trajectory_length=args.max_trajectory_length,
    )

    env = DesktopEnv(
        path_to_vm=args.path_to_vm,
        action_space=agent.action_space,
        screen_size=(args.screen_width, args.screen_height),
        headless=args.headless,
    )

    try:
        for domain in tqdm(test_all_meta, desc="Domain"):
            for example_id in tqdm(test_all_meta[domain], desc="Example", leave=False):
                # example setting
                config_file = os.path.join(args.test_config_base_dir, f"examples/{domain}/{example_id}.json")
                with open(config_file, "r", encoding="utf-8") as f:
                    example = json.load(f)

                logger.info(f"[Domain]: {domain}")
                logger.info(f"[Example ID]: {example_id}")

                instruction = example["instruction"]

                logger.info(f"[Instruction]: {instruction}")

                example_result_dir = os.path.join(
                    args.result_dir,
                    args.action_space,
                    args.observation_type,
                    args.model,
                    domain,
                    example_id
                )
                os.makedirs(example_result_dir, exist_ok=True)

                # example start running
                try:
                    lib_run_single.run_single_example(agent, env, example, max_steps, instruction, args,
                                                      example_result_dir, scores)
                except Exception as e:
                    # Not every failure is a timeout, so report what was
                    # actually raised instead of a fixed "time limit" message.
                    error_message = f"Example {domain}/{example_id} aborted: {e!r}"
                    logger.exception(error_message)

                    # The environment may be what failed; do not let a broken
                    # recorder prevent the remaining examples from running.
                    try:
                        env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
                    except Exception:
                        logger.exception(f"Failed to stop recording for {domain}/{example_id}")

                    with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
                        f.write(json.dumps({"Error": error_message}))
                        f.write("\n")
    finally:
        # Always release the environment, even if the loop is interrupted.
        env.close()

    if scores:
        logger.info(f"Average score: {sum(scores) / len(scores)}")
    else:
        # Avoid dividing by zero when no example produced a score.
        logger.info("Average score: N/A (no example produced a score)")
|
||||
|
||||
|
||||
def get_unfinished(action_space, use_model, observation_type, result_dir, total_file_json):
    """Drop already-finished examples from ``total_file_json``.

    The results of a run live under
    ``<result_dir>/<action_space>/<observation_type>/<use_model>/<domain>/<example_id>``.
    An example counts as finished when its directory contains ``result.txt``.
    Example directories without ``result.txt`` are emptied (files and
    sub-directories alike) so the example can be re-run from a clean state.

    Args:
        action_space: action-space component of the results path.
        use_model: model-name component of the results path.
        observation_type: observation-type component of the results path.
        result_dir: root directory of all results.
        total_file_json: mapping of domain name to a list of example ids.

    Returns:
        The same ``total_file_json`` object, modified in place so that each
        domain lists only the examples that still have to be run. It is
        returned untouched when the results directory does not exist.
    """
    import shutil  # local import: only needed to remove nested directories

    target_dir = os.path.join(result_dir, action_space, observation_type, use_model)

    if not os.path.exists(target_dir):
        return total_file_json

    finished = {}
    for domain in os.listdir(target_dir):
        domain_path = os.path.join(target_dir, domain)
        if not os.path.isdir(domain_path):
            continue

        finished[domain] = set()
        for example_id in os.listdir(domain_path):
            example_path = os.path.join(domain_path, example_id)
            if not os.path.isdir(example_path):
                continue

            entries = os.listdir(example_path)
            if "result.txt" in entries:
                finished[domain].add(example_id)
                continue

            # Unfinished example: wipe its leftovers. os.remove() cannot
            # delete a directory, so nested directories go through rmtree.
            for entry in entries:
                entry_path = os.path.join(example_path, entry)
                if os.path.isdir(entry_path) and not os.path.islink(entry_path):
                    shutil.rmtree(entry_path)
                else:
                    os.remove(entry_path)

    for domain, examples in finished.items():
        if domain in total_file_json:
            total_file_json[domain] = [x for x in total_file_json[domain] if x not in examples]

    return total_file_json
|
||||
|
||||
|
||||
def get_result(action_space, use_model, observation_type, result_dir, total_file_json):
    """Collect the scores of all finished examples and print the success rate.

    Looks under
    ``<result_dir>/<action_space>/<observation_type>/<use_model>/<domain>/<example_id>``
    and reads every ``result.txt`` found there as a float.

    Args:
        action_space: action-space component of the results path.
        use_model: model-name component of the results path.
        observation_type: observation-type component of the results path.
        result_dir: root directory of all results.
        total_file_json: unused; kept so the signature stays unchanged.

    Returns:
        list[float]: one score per finished example, in directory-listing
        order. The list is empty when the results directory is missing or
        holds no ``result.txt`` yet.

    Raises:
        ValueError: if a ``result.txt`` does not contain a number.
    """
    target_dir = os.path.join(result_dir, action_space, observation_type, use_model)

    all_result = []

    # Always return a list, so callers get one type whether or not results exist.
    if not os.path.exists(target_dir):
        return all_result

    for domain in os.listdir(target_dir):
        domain_path = os.path.join(target_dir, domain)
        if not os.path.isdir(domain_path):
            continue
        for example_id in os.listdir(domain_path):
            result_path = os.path.join(domain_path, example_id, "result.txt")
            if not os.path.isfile(result_path):
                continue
            # Read through a context manager so the handle is closed promptly.
            with open(result_path, "r") as f:
                all_result.append(float(f.read()))

    if all_result:
        print("Success Rate:", sum(all_result) / len(all_result) * 100, "%")
    else:
        # Nothing finished yet: avoid dividing by zero.
        print("Success Rate: N/A (no finished examples)")

    return all_result
|
||||
|
||||
|
||||
if __name__ == '__main__':
    ####### The complete version of the list of examples #######
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    args = config()

    # Full task list: a mapping of domain name to example ids.
    with open("evaluation_examples/test_all.json", "r", encoding="utf-8") as meta_file:
        test_all_meta = json.load(meta_file)

    # Remove the examples that already have a result. get_unfinished edits
    # test_all_meta in place and returns that same object.
    test_file_list = get_unfinished(
        action_space=args.action_space,
        use_model=args.model,
        observation_type=args.observation_type,
        result_dir=args.result_dir,
        total_file_json=test_all_meta,
    )
    left_info = "".join(
        f"{domain}: {len(remaining)}\n" for domain, remaining in test_file_list.items()
    )
    logger.info(f"Left tasks:\n{left_info}")

    # Print the success rate over the examples finished so far.
    get_result(
        action_space=args.action_space,
        use_model=args.model,
        observation_type=args.observation_type,
        result_dir=args.result_dir,
        total_file_json=test_all_meta,
    )

    # NOTE(review): the evaluation run itself is currently disabled.
    # test(args, test_all_meta)
|
||||
3
settings.json
Normal file
3
settings.json
Normal file
@@ -0,0 +1,3 @@
|
||||
{
|
||||
"time_limit": "1200"
|
||||
}
|
||||
Reference in New Issue
Block a user