Merge branch 'main' of https://github.com/xlang-ai/DesktopEnv
This commit is contained in:
19
.vscode/launch.json
vendored
Normal file
19
.vscode/launch.json
vendored
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
{
|
||||||
|
// Use IntelliSense to learn about possible attributes.
|
||||||
|
// Hover to view descriptions of existing attributes.
|
||||||
|
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||||
|
"version": "0.2.0",
|
||||||
|
"configurations": [
|
||||||
|
{
|
||||||
|
"name": "Python Debugger: Current File with Arguments",
|
||||||
|
"type": "debugpy",
|
||||||
|
"request": "launch",
|
||||||
|
"program": "${file}",
|
||||||
|
"console": "integratedTerminal",
|
||||||
|
"args": [
|
||||||
|
"--path_to_vm", "/Users/lxc/Virtual Machines.localized/DesktopEnv-Ubuntu 64-bit Arm.vmwarevm/DesktopEnv-Ubuntu 64-bit Arm.vmx"
|
||||||
|
// "--example_time_limit", "60"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
@@ -21,10 +21,12 @@
|
|||||||
Please refer to [guidance](https://docs.google.com/document/d/1KBdeZwmZs2Vi_Wsnngb3Wf1-RiwMMpXTftwMqP2Ztak/edit#heading=h.uh0x0tkl7fuw)
|
Please refer to [guidance](https://docs.google.com/document/d/1KBdeZwmZs2Vi_Wsnngb3Wf1-RiwMMpXTftwMqP2Ztak/edit#heading=h.uh0x0tkl7fuw)
|
||||||
|
|
||||||
2. Install the environment package, download the examples and the virtual machine image.
|
2. Install the environment package, download the examples and the virtual machine image.
|
||||||
|
For x86_64 Linux or Windows, you can install the environment package and download the examples and the virtual machine image by running the following commands:
|
||||||
```bash
|
```bash
|
||||||
pip install desktop-env
|
pip install desktop-env
|
||||||
gdown xxxx
|
gdown xxxx
|
||||||
gdown xxxx
|
vmrun -T ws start "Ubuntu/Ubuntu.vmx" nogui
|
||||||
|
vmrun -T ws snapshot "Ubuntu/Ubuntu.vmx" "init_state"
|
||||||
```
|
```
|
||||||
|
|
||||||
## Quick Start
|
## Quick Start
|
||||||
|
|||||||
@@ -53,8 +53,8 @@ class DesktopEnv(gym.Env):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
path_to_vm: str,
|
path_to_vm: str,
|
||||||
|
snapshot_name: str = "init_state",
|
||||||
action_space: str = "computer_13",
|
action_space: str = "computer_13",
|
||||||
task_config: Dict[str, Any] = None,
|
|
||||||
tmp_dir: str = "tmp",
|
tmp_dir: str = "tmp",
|
||||||
cache_dir: str = "cache",
|
cache_dir: str = "cache",
|
||||||
screen_size: Tuple[int] = (1920, 1080),
|
screen_size: Tuple[int] = (1920, 1080),
|
||||||
@@ -64,15 +64,6 @@ class DesktopEnv(gym.Env):
|
|||||||
Args:
|
Args:
|
||||||
path_to_vm (str): path to .vmx file
|
path_to_vm (str): path to .vmx file
|
||||||
action_space (str): "computer_13" | "pyautogui"
|
action_space (str): "computer_13" | "pyautogui"
|
||||||
|
|
||||||
task_config (Dict[str, Any]): manages task configs integratedly,
|
|
||||||
including
|
|
||||||
* base snapshot
|
|
||||||
* task id (uuid)
|
|
||||||
* instruction
|
|
||||||
* setup config
|
|
||||||
* evaluator config
|
|
||||||
|
|
||||||
tmp_dir (str): temporary directory to store trajectory stuffs like
|
tmp_dir (str): temporary directory to store trajectory stuffs like
|
||||||
the extracted screenshots
|
the extracted screenshots
|
||||||
cache_dir (str): cache directory to cache task-related stuffs like
|
cache_dir (str): cache directory to cache task-related stuffs like
|
||||||
@@ -81,23 +72,20 @@ class DesktopEnv(gym.Env):
|
|||||||
|
|
||||||
# Initialize environment variables
|
# Initialize environment variables
|
||||||
self.path_to_vm = os.path.abspath(os.path.expandvars(os.path.expanduser(path_to_vm)))
|
self.path_to_vm = os.path.abspath(os.path.expandvars(os.path.expanduser(path_to_vm)))
|
||||||
|
self.snapshot_name = snapshot_name
|
||||||
self.tmp_dir_base: str = tmp_dir
|
self.tmp_dir_base: str = tmp_dir
|
||||||
self.cache_dir_base: str = cache_dir
|
self.cache_dir_base: str = cache_dir
|
||||||
self.vm_screen_size = screen_size
|
self.vm_screen_size = screen_size # todo: add the logic to get the screen size from the VM
|
||||||
self.headless = headless
|
self.headless = headless
|
||||||
|
|
||||||
os.makedirs(self.tmp_dir_base, exist_ok=True)
|
os.makedirs(self.tmp_dir_base, exist_ok=True)
|
||||||
|
|
||||||
# task-aware stuffs
|
|
||||||
# todo: handling the logic of snapshot directory
|
|
||||||
self._set_task_info(task_config)
|
|
||||||
|
|
||||||
# Initialize emulator and controller
|
# Initialize emulator and controller
|
||||||
logger.info("Initializing...")
|
logger.info("Initializing...")
|
||||||
self._start_emulator()
|
self._start_emulator()
|
||||||
self.vm_ip = self._get_vm_ip()
|
self.vm_ip = self._get_vm_ip()
|
||||||
self.controller = PythonController(vm_ip=self.vm_ip)
|
self.controller = PythonController(vm_ip=self.vm_ip)
|
||||||
self.setup_controller = SetupController(vm_ip=self.vm_ip, cache_dir=self.cache_dir)
|
self.setup_controller = SetupController(vm_ip=self.vm_ip, cache_dir=self.cache_dir_base)
|
||||||
|
|
||||||
# Meta info of the VM, move to the reset() function
|
# Meta info of the VM, move to the reset() function
|
||||||
self.vm_platform: str = "" # self.controller.get_vm_platform()
|
self.vm_platform: str = "" # self.controller.get_vm_platform()
|
||||||
@@ -147,7 +135,7 @@ class DesktopEnv(gym.Env):
|
|||||||
raise Exception("Failed to get VM IP address!")
|
raise Exception("Failed to get VM IP address!")
|
||||||
|
|
||||||
def _save_state(self):
|
def _save_state(self):
|
||||||
_execute_command(["vmrun", "-T", "ws" "snapshot", self.path_to_vm, self.snapshot_path])
|
_execute_command(["vmrun", "-T", "ws" "snapshot", self.path_to_vm, self.snapshot_name])
|
||||||
|
|
||||||
def _get_screenshot(self):
|
def _get_screenshot(self):
|
||||||
# random_uuid = str(uuid.uuid4())
|
# random_uuid = str(uuid.uuid4())
|
||||||
@@ -167,7 +155,6 @@ class DesktopEnv(gym.Env):
|
|||||||
return screenshot_image_path
|
return screenshot_image_path
|
||||||
|
|
||||||
def _set_task_info(self, task_config: Dict[str, Any]):
|
def _set_task_info(self, task_config: Dict[str, Any]):
|
||||||
self.snapshot_path = task_config["snapshot"] # todo: save the snapshot when first start the environment, and then revert to it when reset
|
|
||||||
self.task_id: str = task_config["id"]
|
self.task_id: str = task_config["id"]
|
||||||
self.cache_dir: str = os.path.join(self.cache_dir_base, self.task_id)
|
self.cache_dir: str = os.path.join(self.cache_dir_base, self.task_id)
|
||||||
os.makedirs(self.cache_dir, exist_ok=True)
|
os.makedirs(self.cache_dir, exist_ok=True)
|
||||||
@@ -187,7 +174,7 @@ class DesktopEnv(gym.Env):
|
|||||||
if isinstance(self.evaluator["func"], list) \
|
if isinstance(self.evaluator["func"], list) \
|
||||||
else getattr(metrics, self.evaluator["func"])
|
else getattr(metrics, self.evaluator["func"])
|
||||||
self.metric_conj: str = self.evaluator.get("conj", "and") # take conjunction of multiple metrics
|
self.metric_conj: str = self.evaluator.get("conj", "and") # take conjunction of multiple metrics
|
||||||
if "result" in self.evaluator:
|
if "result" in self.evaluator and len(self.evaluator["result"])>0:
|
||||||
self.result_getter: Getter = [getattr(getters, "get_{:}".format(res["type"])) for res in
|
self.result_getter: Getter = [getattr(getters, "get_{:}".format(res["type"])) for res in
|
||||||
self.evaluator["result"]] \
|
self.evaluator["result"]] \
|
||||||
if isinstance(self.evaluator["result"], list) \
|
if isinstance(self.evaluator["result"], list) \
|
||||||
@@ -197,7 +184,7 @@ class DesktopEnv(gym.Env):
|
|||||||
if isinstance(self.metric, list) \
|
if isinstance(self.metric, list) \
|
||||||
else None
|
else None
|
||||||
|
|
||||||
if "expected" in self.evaluator:
|
if "expected" in self.evaluator and len(self.evaluator["expected"])>0:
|
||||||
self.expected_getter: Getter = [getattr(getters, "get_{:}".format(exp["type"])) if exp else None for exp in
|
self.expected_getter: Getter = [getattr(getters, "get_{:}".format(exp["type"])) if exp else None for exp in
|
||||||
self.evaluator["expected"]] \
|
self.evaluator["expected"]] \
|
||||||
if isinstance(self.evaluator["expected"], list) \
|
if isinstance(self.evaluator["expected"], list) \
|
||||||
@@ -239,8 +226,8 @@ class DesktopEnv(gym.Env):
|
|||||||
)
|
)
|
||||||
os.makedirs(os.path.join(self.tmp_dir, "screenshots"))
|
os.makedirs(os.path.join(self.tmp_dir, "screenshots"))
|
||||||
|
|
||||||
logger.info("Reverting to snapshot to {}...".format(self.snapshot_path))
|
logger.info("Reverting to snapshot to {}...".format(self.snapshot_name))
|
||||||
_execute_command(["vmrun", "-T", "ws", "revertToSnapshot", self.path_to_vm, self.snapshot_path])
|
_execute_command(["vmrun", "-T", "ws", "revertToSnapshot", self.path_to_vm, self.snapshot_name])
|
||||||
time.sleep(5)
|
time.sleep(5)
|
||||||
|
|
||||||
print(self.vm_screen_size)
|
print(self.vm_screen_size)
|
||||||
|
|||||||
@@ -284,6 +284,15 @@ def _create_atspi_node(node: Accessible, depth: int = 0, flag: Optional[str] = N
|
|||||||
text = text.replace("\ufffc", "").replace("\ufffd", "")
|
text = text.replace("\ufffc", "").replace("\ufffd", "")
|
||||||
# }}} Text #
|
# }}} Text #
|
||||||
|
|
||||||
|
# Image {{{ #
|
||||||
|
try:
|
||||||
|
node.queryImage()
|
||||||
|
except NotImplementedError:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
attribute_dict["image"] = "true"
|
||||||
|
# }}} Image #
|
||||||
|
|
||||||
# Selection {{{ #
|
# Selection {{{ #
|
||||||
try:
|
try:
|
||||||
node.querySelection()
|
node.querySelection()
|
||||||
|
|||||||
16
desktop_env/server/osbench_server.service
Normal file
16
desktop_env/server/osbench_server.service
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
[Unit]
|
||||||
|
Description=OSBench Server
|
||||||
|
StartLimitIntervalSec=60
|
||||||
|
StartLimitBurst=4
|
||||||
|
After=network.target auditd.service
|
||||||
|
|
||||||
|
[Service]
|
||||||
|
ExecStart=/usr/bin/python3 /home/user/main.py
|
||||||
|
User=user
|
||||||
|
WorkingDirectory=/home/user
|
||||||
|
Restart=on-failure
|
||||||
|
RestartSec=1
|
||||||
|
Environment="DISPLAY=:1"
|
||||||
|
|
||||||
|
[Install]
|
||||||
|
WantedBy=graphical.target
|
||||||
16
desktop_env/server/osbench_server@.service
Normal file
16
desktop_env/server/osbench_server@.service
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
[Unit]
|
||||||
|
Description=OSBench Server
|
||||||
|
StartLimitIntervalSec=60
|
||||||
|
StartLimitBurst=4
|
||||||
|
After=network.target auditd.service
|
||||||
|
|
||||||
|
[Service]
|
||||||
|
ExecStart=/usr/bin/python3 /home/user/main.py
|
||||||
|
User=user
|
||||||
|
WorkingDirectory=/home/user
|
||||||
|
Restart=on-failure
|
||||||
|
RestartSec=1
|
||||||
|
Environment="DISPLAY=%i"
|
||||||
|
|
||||||
|
[Install]
|
||||||
|
WantedBy=graphical.target
|
||||||
@@ -10,10 +10,6 @@
|
|||||||
"libreoffice_calc"
|
"libreoffice_calc"
|
||||||
],
|
],
|
||||||
"evaluator": {
|
"evaluator": {
|
||||||
"func": "infeasible",
|
"func": "infeasible"
|
||||||
"expected": {
|
|
||||||
},
|
|
||||||
"result": {
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,10 +10,6 @@
|
|||||||
"libreoffice_calc"
|
"libreoffice_calc"
|
||||||
],
|
],
|
||||||
"evaluator": {
|
"evaluator": {
|
||||||
"func": "infeasible",
|
"func": "infeasible"
|
||||||
"expected": {
|
|
||||||
},
|
|
||||||
"result": {
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
19
evaluation_examples/examples/multi_apps/demo.py
Normal file
19
evaluation_examples/examples/multi_apps/demo.py
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
file_path = "/Users/lxc/Downloads/Speedtest.csv"
|
||||||
|
# 找到csv第二行的第二个数据格里的值
|
||||||
|
# with open(file_path, "r") as f:
|
||||||
|
# for i, line in enumerate(f):
|
||||||
|
# if i == 1:
|
||||||
|
# data = line.split(",")[1]
|
||||||
|
# break
|
||||||
|
# print(data)
|
||||||
|
|
||||||
|
with open(file_path, "r") as f:
|
||||||
|
reader = pd.read_csv(f, sep=',', header=None)
|
||||||
|
# for column in reader.columns:
|
||||||
|
# if column.startswith("TEST_DATE"):
|
||||||
|
# data_col = column
|
||||||
|
# break
|
||||||
|
for data in reader['TEST_DATE']:
|
||||||
|
print(data)
|
||||||
@@ -103,7 +103,6 @@
|
|||||||
"1e8df695-bd1b-45b3-b557-e7d599cf7597",
|
"1e8df695-bd1b-45b3-b557-e7d599cf7597",
|
||||||
"ecb0df7a-4e8d-4a03-b162-053391d3afaf",
|
"ecb0df7a-4e8d-4a03-b162-053391d3afaf",
|
||||||
"8b1ce5f2-59d2-4dcc-b0b0-666a714b9a14",
|
"8b1ce5f2-59d2-4dcc-b0b0-666a714b9a14",
|
||||||
"7b802dad-6e0f-4204-9815-d4e3f57627d8",
|
|
||||||
"a01fbce3-2793-461f-ab86-43680ccbae25",
|
"a01fbce3-2793-461f-ab86-43680ccbae25",
|
||||||
"0326d92d-d218-48a8-9ca1-981cd6d064c7",
|
"0326d92d-d218-48a8-9ca1-981cd6d064c7",
|
||||||
"0a2e43bf-b26c-4631-a966-af9dfa12c9e5",
|
"0a2e43bf-b26c-4631-a966-af9dfa12c9e5",
|
||||||
@@ -380,7 +379,6 @@
|
|||||||
"9439a27b-18ae-42d8-9778-5f68f891805e",
|
"9439a27b-18ae-42d8-9778-5f68f891805e",
|
||||||
"ae506c68-352c-4094-9caa-ee9d42052317",
|
"ae506c68-352c-4094-9caa-ee9d42052317",
|
||||||
"ea98c5d7-3cf9-4f9b-8ad3-366b58e0fcae",
|
"ea98c5d7-3cf9-4f9b-8ad3-366b58e0fcae",
|
||||||
"c714dcee-cad3-4e12-8f3c-12bdcfcdb048",
|
|
||||||
"930fdb3b-11a8-46fe-9bac-577332e2640e",
|
"930fdb3b-11a8-46fe-9bac-577332e2640e",
|
||||||
"276cc624-87ea-4f08-ab93-f770e3790175",
|
"276cc624-87ea-4f08-ab93-f770e3790175",
|
||||||
"9d425400-e9b2-4424-9a4b-d4c7abac4140",
|
"9d425400-e9b2-4424-9a4b-d4c7abac4140",
|
||||||
|
|||||||
102
evaluation_examples/test_small.json
Normal file
102
evaluation_examples/test_small.json
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
{
|
||||||
|
"chrome": [
|
||||||
|
"bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
|
||||||
|
"7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3"
|
||||||
|
],
|
||||||
|
"gimp": [
|
||||||
|
"7a4deb26-d57d-4ea9-9a73-630f66a7b568",
|
||||||
|
"554785e9-4523-4e7a-b8e1-8016f565f56a"
|
||||||
|
],
|
||||||
|
"libreoffice_calc": [
|
||||||
|
"357ef137-7eeb-4c80-a3bb-0951f26a8aff",
|
||||||
|
"42e0a640-4f19-4b28-973d-729602b5a4a7"
|
||||||
|
],
|
||||||
|
"libreoffice_impress": [
|
||||||
|
"5d901039-a89c-4bfb-967b-bf66f4df075e",
|
||||||
|
"550ce7e7-747b-495f-b122-acdc4d0b8e54"
|
||||||
|
],
|
||||||
|
"libreoffice_writer": [
|
||||||
|
"0810415c-bde4-4443-9047-d5f70165a697",
|
||||||
|
"0a0faba3-5580-44df-965d-f562a99b291c"
|
||||||
|
],
|
||||||
|
"multi_apps": [
|
||||||
|
"2b9493d7-49b8-493a-a71b-56cd1f4d6908",
|
||||||
|
"46407397-a7d5-4c6b-92c6-dbe038b1457b",
|
||||||
|
"4e9f0faf-2ecc-4ae8-a804-28c9a75d1ddc",
|
||||||
|
"510f64c8-9bcc-4be1-8d30-638705850618",
|
||||||
|
"897e3b53-5d4d-444b-85cb-2cdc8a97d903",
|
||||||
|
"c867c42d-a52d-4a24-8ae3-f75d256b5618",
|
||||||
|
"e135df7c-7687-4ac0-a5f0-76b74438b53e",
|
||||||
|
"f7dfbef3-7697-431c-883a-db8583a4e4f9",
|
||||||
|
"6d72aad6-187a-4392-a4c4-ed87269c51cf",
|
||||||
|
"f918266a-b3e0-4914-865d-4faa564f1aef",
|
||||||
|
"da52d699-e8d2-4dc5-9191-a2199e0b6a9b",
|
||||||
|
"74d5859f-ed66-4d3e-aa0e-93d7a592ce41",
|
||||||
|
"b5062e3e-641c-4e3a-907b-ac864d2e7652",
|
||||||
|
"48d05431-6cd5-4e76-82eb-12b60d823f7d",
|
||||||
|
"eb303e01-261e-4972-8c07-c9b4e7a4922a",
|
||||||
|
"d1acdb87-bb67-4f30-84aa-990e56a09c92",
|
||||||
|
"deec51c9-3b1e-4b9e-993c-4776f20e8bb2",
|
||||||
|
"8e116af7-7db7-4e35-a68b-b0939c066c78",
|
||||||
|
"185f29bd-5da0-40a6-b69c-ba7f4e0324ef",
|
||||||
|
"2c1ebcd7-9c6d-4c9a-afad-900e381ecd5e",
|
||||||
|
"3a93cae4-ad3e-403e-8c12-65303b271818",
|
||||||
|
"1f18aa87-af6f-41ef-9853-cdb8f32ebdea",
|
||||||
|
"26150609-0da3-4a7d-8868-0faf9c5f01bb",
|
||||||
|
"7e287123-70ca-47b9-8521-47db09b69b14",
|
||||||
|
"e2392362-125e-4f76-a2ee-524b183a3412",
|
||||||
|
"26660ad1-6ebb-4f59-8cba-a8432dfe8d38",
|
||||||
|
"a82b78bb-7fde-4cb3-94a4-035baf10bcf0",
|
||||||
|
"36037439-2044-4b50-b9d1-875b5a332143",
|
||||||
|
"716a6079-22da-47f1-ba73-c9d58f986a38",
|
||||||
|
"a74b607e-6bb5-4ea8-8a7c-5d97c7bbcd2a",
|
||||||
|
"6f4073b8-d8ea-4ade-8a18-c5d1d5d5aa9a",
|
||||||
|
"da922383-bfa4-4cd3-bbad-6bebab3d7742",
|
||||||
|
"2373b66a-092d-44cb-bfd7-82e86e7a3b4d",
|
||||||
|
"81c425f5-78f3-4771-afd6-3d2973825947",
|
||||||
|
"227d2f97-562b-4ccb-ae47-a5ec9e142fbb",
|
||||||
|
"20236825-b5df-46e7-89bf-62e1d640a897",
|
||||||
|
"02ce9a50-7af2-47ed-8596-af0c230501f8",
|
||||||
|
"4c26e3f3-3a14-4d86-b44a-d3cedebbb487",
|
||||||
|
"09a37c51-e625-49f4-a514-20a773797a8a",
|
||||||
|
"3e3fc409-bff3-4905-bf16-c968eee3f807",
|
||||||
|
"415ef462-bed3-493a-ac36-ca8c6d23bf1b",
|
||||||
|
"9f3bb592-209d-43bc-bb47-d77d9df56504",
|
||||||
|
"dd60633f-2c72-42ba-8547-6f2c8cb0fdb0",
|
||||||
|
"3f05f3b9-29ba-4b6b-95aa-2204697ffc06",
|
||||||
|
"f8369178-fafe-40c2-adc4-b9b08a125456",
|
||||||
|
"778efd0a-153f-4842-9214-f05fc176b877",
|
||||||
|
"47f7c0ce-a5fb-4100-a5e6-65cd0e7429e5",
|
||||||
|
"c2751594-0cd5-4088-be1b-b5f2f9ec97c4",
|
||||||
|
"48c46dc7-fe04-4505-ade7-723cba1aa6f6",
|
||||||
|
"42d25c08-fb87-4927-8b65-93631280a26f",
|
||||||
|
"bb7db4c2-30b5-4be7-8dd7-b8c4ec7d3108",
|
||||||
|
"3c8f201a-009d-4bbe-8b65-a6f8b35bb57f",
|
||||||
|
"d68204bf-11c1-4b13-b48b-d303c73d4bf6",
|
||||||
|
"91190194-f406-4cd6-b3f9-c43fac942b22",
|
||||||
|
"7f35355e-02a6-45b5-b140-f0be698bcf85",
|
||||||
|
"98e8e339-5f91-4ed2-b2b2-12647cb134f4",
|
||||||
|
"df67aebb-fb3a-44fd-b75b-51b6012df509",
|
||||||
|
"5df7b33a-9f77-4101-823e-02f863e1c1ae",
|
||||||
|
"22a4636f-8179-4357-8e87-d1743ece1f81",
|
||||||
|
"236833a3-5704-47fc-888c-4f298f09f799"
|
||||||
|
],
|
||||||
|
"os": [
|
||||||
|
"5ea617a3-0e86-4ba6-aab2-dac9aa2e8d57",
|
||||||
|
"5812b315-e7bd-4265-b51f-863c02174c28",
|
||||||
|
"43c2d64c-bab5-4dcb-a30c-b888321c319a",
|
||||||
|
"7688b85f-87a4-4e4a-b2f8-f3d6c3f29b82"
|
||||||
|
],
|
||||||
|
"thunderbird": [
|
||||||
|
"bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
|
||||||
|
"7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3"
|
||||||
|
],
|
||||||
|
"vlc": [
|
||||||
|
"59f21cfb-0120-4326-b255-a5b827b38967",
|
||||||
|
"8f080098-ddb1-424c-b438-4e96e5e4786e"
|
||||||
|
],
|
||||||
|
"vs_code": [
|
||||||
|
"0ed39f63-6049-43d4-ba4d-5fa2fe04a951",
|
||||||
|
"53ad5833-3455-407b-bbc6-45b4c79ab8fb"
|
||||||
|
]
|
||||||
|
}
|
||||||
@@ -1,432 +0,0 @@
|
|||||||
import datetime
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
|
|
||||||
import func_timeout
|
|
||||||
|
|
||||||
from desktop_env.envs.desktop_env import DesktopEnv
|
|
||||||
from mm_agents.gpt_4v_agent import GPT4v_Agent
|
|
||||||
|
|
||||||
# Logger Configs {{{ #
|
|
||||||
logger = logging.getLogger()
|
|
||||||
logger.setLevel(logging.DEBUG)
|
|
||||||
|
|
||||||
datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
|
||||||
|
|
||||||
file_handler = logging.FileHandler(os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8")
|
|
||||||
debug_handler = logging.FileHandler(os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8")
|
|
||||||
stdout_handler = logging.StreamHandler(sys.stdout)
|
|
||||||
sdebug_handler = logging.FileHandler(os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8")
|
|
||||||
|
|
||||||
file_handler.setLevel(logging.INFO)
|
|
||||||
debug_handler.setLevel(logging.DEBUG)
|
|
||||||
stdout_handler.setLevel(logging.INFO)
|
|
||||||
sdebug_handler.setLevel(logging.DEBUG)
|
|
||||||
|
|
||||||
formatter = logging.Formatter(
|
|
||||||
fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s")
|
|
||||||
file_handler.setFormatter(formatter)
|
|
||||||
debug_handler.setFormatter(formatter)
|
|
||||||
stdout_handler.setFormatter(formatter)
|
|
||||||
sdebug_handler.setFormatter(formatter)
|
|
||||||
|
|
||||||
stdout_handler.addFilter(logging.Filter("desktopenv"))
|
|
||||||
sdebug_handler.addFilter(logging.Filter("desktopenv"))
|
|
||||||
|
|
||||||
logger.addHandler(file_handler)
|
|
||||||
logger.addHandler(debug_handler)
|
|
||||||
logger.addHandler(stdout_handler)
|
|
||||||
logger.addHandler(sdebug_handler)
|
|
||||||
# }}} Logger Configs #
|
|
||||||
|
|
||||||
logger = logging.getLogger("desktopenv.experiment")
|
|
||||||
|
|
||||||
PATH_TO_VM = r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu\Ubuntu.vmx"
|
|
||||||
|
|
||||||
|
|
||||||
def run_one_example(example, agent, max_steps=10, example_trajectory_dir="exp_trajectory", recording=True):
|
|
||||||
trajectory_recording_path = os.path.join(example_trajectory_dir, "trajectory.json")
|
|
||||||
env = DesktopEnv(
|
|
||||||
path_to_vm=PATH_TO_VM,
|
|
||||||
action_space=agent.action_space,
|
|
||||||
task_config=example
|
|
||||||
)
|
|
||||||
# reset the environment to certain snapshot
|
|
||||||
observation = env.reset()
|
|
||||||
done = False
|
|
||||||
step_num = 0
|
|
||||||
|
|
||||||
if recording:
|
|
||||||
# send a request to the server to start recording
|
|
||||||
env.controller.start_recording()
|
|
||||||
|
|
||||||
while not done and step_num < max_steps:
|
|
||||||
actions = agent.predict(observation)
|
|
||||||
step_num += 1
|
|
||||||
for action in actions:
|
|
||||||
# Capture the timestamp before executing the action
|
|
||||||
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
|
||||||
logger.info("Step %d: %s", step_num, action)
|
|
||||||
|
|
||||||
observation, reward, done, info = env.step(action)
|
|
||||||
|
|
||||||
logger.info("Reward: %.2f", reward)
|
|
||||||
logger.info("Done: %s", done)
|
|
||||||
logger.info("Info: %s", info)
|
|
||||||
|
|
||||||
# Save screenshot and trajectory information
|
|
||||||
with open(os.path.join(example_trajectory_dir, f"step_{step_num}_{action_timestamp}.png"), "wb") as _f:
|
|
||||||
with open(observation['screenshot'], "rb") as __f:
|
|
||||||
screenshot = __f.read()
|
|
||||||
_f.write(screenshot)
|
|
||||||
|
|
||||||
with open(trajectory_recording_path, "a") as f:
|
|
||||||
f.write(json.dumps({
|
|
||||||
"step_num": step_num,
|
|
||||||
"action_timestamp": action_timestamp,
|
|
||||||
"action": action,
|
|
||||||
"reward": reward,
|
|
||||||
"done": done,
|
|
||||||
"info": info,
|
|
||||||
"screenshot_file": f"step_{step_num}_{action_timestamp}.png"
|
|
||||||
}))
|
|
||||||
f.write("\n")
|
|
||||||
|
|
||||||
if done:
|
|
||||||
logger.info("The episode is done.")
|
|
||||||
break
|
|
||||||
|
|
||||||
def stop_recording():
|
|
||||||
try:
|
|
||||||
env.controller.end_recording(os.path.join(example_trajectory_dir, "recording.mp4"))
|
|
||||||
except Exception as e:
|
|
||||||
print(f"An error occurred while stopping the recording: {e}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
func_timeout.func_timeout(30, stop_recording)
|
|
||||||
except func_timeout.exceptions.FunctionTimedOut:
|
|
||||||
logger.info("Recording timed out.")
|
|
||||||
|
|
||||||
result = env.evaluate()
|
|
||||||
logger.info("Result: %.2f", result)
|
|
||||||
|
|
||||||
with open(trajectory_recording_path, "a") as f:
|
|
||||||
f.write(json.dumps({
|
|
||||||
"result": result
|
|
||||||
}))
|
|
||||||
f.write("\n")
|
|
||||||
|
|
||||||
# env.close()
|
|
||||||
logger.info("Environment closed.")
|
|
||||||
|
|
||||||
|
|
||||||
def main(example_class, example_id, gpt4_model="gpt-4-0125-preview"):
|
|
||||||
action_space = "pyautogui"
|
|
||||||
gemini_model = "gemini-pro-vision"
|
|
||||||
|
|
||||||
logger.info("Running example %s/%s", example_class, example_id)
|
|
||||||
logger.info("Using model %s", gpt4_model)
|
|
||||||
# logger.info("Using model %s", gemini_model)
|
|
||||||
|
|
||||||
with open(f"evaluation_examples/examples/{example_class}/{example_id}.json", "r", encoding="utf-8") as f:
|
|
||||||
example = json.load(f)
|
|
||||||
example["snapshot"] = "exp_v5"
|
|
||||||
|
|
||||||
api_key = os.environ.get("OPENAI_API_KEY")
|
|
||||||
agent = GPT4v_Agent(api_key=api_key, model=gpt4_model, instruction=example['instruction'], max_tokens=1000,
|
|
||||||
action_space=action_space, exp="a11y_tree")
|
|
||||||
|
|
||||||
# api_key = os.environ.get("GENAI_API_KEY")
|
|
||||||
# agent = GeminiPro_Agent(api_key=api_key, model=gemini_model, instruction=example['instruction'], action_space=action_space, exp="a11y_tree")
|
|
||||||
|
|
||||||
root_trajectory_dir = "exp_trajectory"
|
|
||||||
|
|
||||||
example_trajectory_dir = os.path.join(root_trajectory_dir, "a11y_tree", example_class, gpt4_model, example_id)
|
|
||||||
# example_trajectory_dir = os.path.join(root_trajectory_dir, "a11y_tree", example_class, gemini_model, example_id)
|
|
||||||
|
|
||||||
os.makedirs(example_trajectory_dir, exist_ok=True)
|
|
||||||
|
|
||||||
run_one_example(example, agent, 15, example_trajectory_dir)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
os_list = [
|
|
||||||
"94d95f96-9699-4208-98ba-3c3119edf9c2",
|
|
||||||
"bedcedc4-4d72-425e-ad62-21960b11fe0d",
|
|
||||||
"43c2d64c-bab5-4dcb-a30c-b888321c319a",
|
|
||||||
"7688b85f-87a4-4e4a-b2f8-f3d6c3f29b82",
|
|
||||||
"ec4e3f68-9ea4-4c18-a5c9-69f89d1178b3",
|
|
||||||
"f9be0997-4b7c-45c5-b05c-4612b44a6118",
|
|
||||||
"28cc3b7e-b194-4bc9-8353-d04c0f4d56d2",
|
|
||||||
"5ea617a3-0e86-4ba6-aab2-dac9aa2e8d57",
|
|
||||||
"e0df059f-28a6-4169-924f-b9623e7184cc",
|
|
||||||
"ddc75b62-7311-4af8-bfb3-859558542b36",
|
|
||||||
"b6781586-6346-41cd-935a-a6b1487918fc",
|
|
||||||
"3ce045a0-877b-42aa-8d2c-b4a863336ab8",
|
|
||||||
"a4d98375-215b-4a4d-aee9-3d4370fccc41",
|
|
||||||
"13584542-872b-42d8-b299-866967b5c3ef",
|
|
||||||
"23393935-50c7-4a86-aeea-2b78fd089c5c"
|
|
||||||
]
|
|
||||||
|
|
||||||
# for example_id in os_list:
|
|
||||||
# try:
|
|
||||||
# main("os", example_id, gpt4_model="gpt-3.5-turbo-16k")
|
|
||||||
# except Exception as e:
|
|
||||||
# logger.error("An error occurred while running the example: %s", e)
|
|
||||||
# continue
|
|
||||||
|
|
||||||
vlc_list = [
|
|
||||||
"8ba5ae7a-5ae5-4eab-9fcc-5dd4fe3abf89",
|
|
||||||
"8ba5ae7a-5ae5-4eab-9fcc-5dd4fe3abf89",
|
|
||||||
"8f080098-ddb1-424c-b438-4e96e5e4786e",
|
|
||||||
"bba3381f-b5eb-4439-bd9e-80c22218d5a7",
|
|
||||||
"fba2c100-79e8-42df-ae74-b592418d54f4",
|
|
||||||
"efcf0d81-0835-4880-b2fd-d866e8bc2294",
|
|
||||||
"8d9fd4e2-6fdb-46b0-b9b9-02f06495c62f",
|
|
||||||
"aa4b5023-aef6-4ed9-bdc9-705f59ab9ad6",
|
|
||||||
"386dbd0e-0241-4a0a-b6a2-6704fba26b1c",
|
|
||||||
"9195653c-f4aa-453d-aa95-787f6ccfaae9",
|
|
||||||
"d06f0d4d-2cd5-4ede-8de9-598629438c6e",
|
|
||||||
"a5bbbcd5-b398-4c91-83d4-55e1e31bbb81",
|
|
||||||
"f3977615-2b45-4ac5-8bba-80c17dbe2a37",
|
|
||||||
"215dfd39-f493-4bc3-a027-8a97d72c61bf"
|
|
||||||
]
|
|
||||||
|
|
||||||
chrome_list = [
|
|
||||||
"bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
|
|
||||||
"7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3",
|
|
||||||
"06fe7178-4491-4589-810f-2e2bc9502122",
|
|
||||||
"e1e75309-3ddb-4d09-92ec-de869c928143",
|
|
||||||
"35253b65-1c19-4304-8aa4-6884b8218fc0",
|
|
||||||
"2ad9387a-65d8-4e33-ad5b-7580065a27ca",
|
|
||||||
"7a5a7856-f1b6-42a4-ade9-1ca81ca0f263",
|
|
||||||
"44ee5668-ecd5-4366-a6ce-c1c9b8d4e938",
|
|
||||||
"2ae9ba84-3a0d-4d4c-8338-3a1478dc5fe3",
|
|
||||||
"480bcfea-d68f-4aaa-a0a9-2589ef319381",
|
|
||||||
"af630914-714e-4a24-a7bb-f9af687d3b91"
|
|
||||||
]
|
|
||||||
|
|
||||||
calc_list = [
|
|
||||||
"eb03d19a-b88d-4de4-8a64-ca0ac66f426b",
|
|
||||||
"0bf05a7d-b28b-44d2-955a-50b41e24012a",
|
|
||||||
"7a4e4bc8-922c-4c84-865c-25ba34136be1",
|
|
||||||
"2bd59342-0664-4ccb-ba87-79379096cc08",
|
|
||||||
"ecb0df7a-4e8d-4a03-b162-053391d3afaf",
|
|
||||||
"7efeb4b1-3d19-4762-b163-63328d66303b",
|
|
||||||
"4e6fcf72-daf3-439f-a232-c434ce416af6",
|
|
||||||
"6054afcb-5bab-4702-90a0-b259b5d3217c",
|
|
||||||
"abed40dc-063f-4598-8ba5-9fe749c0615d",
|
|
||||||
"01b269ae-2111-4a07-81fd-3fcd711993b0",
|
|
||||||
"8b1ce5f2-59d2-4dcc-b0b0-666a714b9a14",
|
|
||||||
"0cecd4f3-74de-457b-ba94-29ad6b5dafb6",
|
|
||||||
"4188d3a4-077d-46b7-9c86-23e1a036f6c1",
|
|
||||||
"51b11269-2ca8-4b2a-9163-f21758420e78",
|
|
||||||
"7e429b8d-a3f0-4ed0-9b58-08957d00b127",
|
|
||||||
"347ef137-7eeb-4c80-a3bb-0951f26a8aff",
|
|
||||||
"6e99a1ad-07d2-4b66-a1ce-ece6d99c20a5",
|
|
||||||
"3aaa4e37-dc91-482e-99af-132a612d40f3",
|
|
||||||
"37608790-6147-45d0-9f20-1137bb35703d",
|
|
||||||
"f9584479-3d0d-4c79-affa-9ad7afdd8850",
|
|
||||||
"d681960f-7bc3-4286-9913-a8812ba3261a",
|
|
||||||
"21df9241-f8d7-4509-b7f1-37e501a823f7",
|
|
||||||
"1334ca3e-f9e3-4db8-9ca7-b4c653be7d17",
|
|
||||||
"357ef137-7eeb-4c80-a3bb-0951f26a8aff",
|
|
||||||
"aa3a8974-2e85-438b-b29e-a64df44deb4b",
|
|
||||||
"a01fbce3-2793-461f-ab86-43680ccbae25",
|
|
||||||
"4f07fbe9-70de-4927-a4d5-bb28bc12c52c",
|
|
||||||
]
|
|
||||||
|
|
||||||
# for example_id in calc_list:
|
|
||||||
# main("libreoffice_calc", example_id)
|
|
||||||
|
|
||||||
impress_list = [
|
|
||||||
"5d901039-a89c-4bfb-967b-bf66f4df075e",
|
|
||||||
"550ce7e7-747b-495f-b122-acdc4d0b8e54",
|
|
||||||
"455d3c66-7dc6-4537-a39a-36d3e9119df7",
|
|
||||||
"af23762e-2bfd-4a1d-aada-20fa8de9ce07",
|
|
||||||
"c59742c0-4323-4b9d-8a02-723c251deaa0",
|
|
||||||
"ef9d12bd-bcee-4ba0-a40e-918400f43ddf",
|
|
||||||
"9ec204e4-f0a3-42f8-8458-b772a6797cab",
|
|
||||||
"0f84bef9-9790-432e-92b7-eece357603fb",
|
|
||||||
"ce88f674-ab7a-43da-9201-468d38539e4a",
|
|
||||||
"3b27600c-3668-4abd-8f84-7bcdebbccbdb",
|
|
||||||
"a097acff-6266-4291-9fbd-137af7ecd439",
|
|
||||||
"bf4e9888-f10f-47af-8dba-76413038b73c",
|
|
||||||
"21760ecb-8f62-40d2-8d85-0cee5725cb72"
|
|
||||||
]
|
|
||||||
# for example_id in impress_list:
|
|
||||||
# main("libreoffice_impress", example_id)
|
|
||||||
|
|
||||||
thunderbird_list = [
|
|
||||||
# "bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
|
|
||||||
# "7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3",
|
|
||||||
"12086550-11c0-466b-b367-1d9e75b3910e",
|
|
||||||
"06fe7178-4491-4589-810f-2e2bc9502122",
|
|
||||||
"6766f2b8-8a72-417f-a9e5-56fcaa735837",
|
|
||||||
"e1e75309-3ddb-4d09-92ec-de869c928143",
|
|
||||||
"3d1682a7-0fb0-49ae-a4dc-a73afd2d06d5",
|
|
||||||
"35253b65-1c19-4304-8aa4-6884b8218fc0",
|
|
||||||
"d088f539-cab4-4f9a-ac92-9999fc3a656e",
|
|
||||||
"2ad9387a-65d8-4e33-ad5b-7580065a27ca",
|
|
||||||
"480bcfea-d68f-4aaa-a0a9-2589ef319381",
|
|
||||||
"030eeff7-b492-4218-b312-701ec99ee0cc",
|
|
||||||
"94760984-3ff5-41ee-8347-cf1af709fea0",
|
|
||||||
"99146c54-4f37-4ab8-9327-5f3291665e1e",
|
|
||||||
"c9e7eaf2-b1a1-4efc-a982-721972fa9f02"
|
|
||||||
]
|
|
||||||
# for example_id in thunderbird_list:
|
|
||||||
# main("thunderbird", example_id)
|
|
||||||
|
|
||||||
gimp_list = [
|
|
||||||
"7a4deb26-d57d-4ea9-9a73-630f66a7b568",
|
|
||||||
"554785e9-4523-4e7a-b8e1-8016f565f56a",
|
|
||||||
"77b8ab4d-994f-43ac-8930-8ca087d7c4b4",
|
|
||||||
"f4aec372-4fb0-4df5-a52b-79e0e2a5d6ce",
|
|
||||||
"d52d6308-ec58-42b7-a2c9-de80e4837b2b",
|
|
||||||
"2a729ded-3296-423d-aec4-7dd55ed5fbb3",
|
|
||||||
"b148e375-fe0b-4bec-90e7-38632b0d73c2",
|
|
||||||
"a746add2-cab0-4740-ac36-c3769d9bfb46",
|
|
||||||
"7b7617bd-57cc-468e-9c91-40c4ec2bcb3d",
|
|
||||||
"d16c99dc-2a1e-46f2-b350-d97c86c85c15",
|
|
||||||
"06ca5602-62ca-47f6-ad4f-da151cde54cc",
|
|
||||||
"e2dd0213-26db-4349-abe5-d5667bfd725c",
|
|
||||||
"f723c744-e62c-4ae6-98d1-750d3cd7d79d",
|
|
||||||
"72f83cdc-bf76-4531-9a1b-eb893a13f8aa",
|
|
||||||
"7767eef2-56a3-4cea-8c9f-48c070c7d65b",
|
|
||||||
"734d6579-c07d-47a8-9ae2-13339795476b"
|
|
||||||
]
|
|
||||||
|
|
||||||
# for example_id in gimp_list:
|
|
||||||
# try:
|
|
||||||
# main("gimp", example_id)
|
|
||||||
# except Exception as e:
|
|
||||||
# logger.error("An error occurred while running the example: %s", e)
|
|
||||||
# continue
|
|
||||||
|
|
||||||
vs_code_list = [
|
|
||||||
"0ed39f63-6049-43d4-ba4d-5fa2fe04a951",
|
|
||||||
"53ad5833-3455-407b-bbc6-45b4c79ab8fb",
|
|
||||||
"eabc805a-bfcf-4460-b250-ac92135819f6",
|
|
||||||
"982d12a5-beab-424f-8d38-d2a48429e511",
|
|
||||||
"4e60007a-f5be-4bfc-9723-c39affa0a6d3",
|
|
||||||
"e2b5e914-ffe1-44d2-8e92-58f8c5d92bb2",
|
|
||||||
"9439a27b-18ae-42d8-9778-5f68f891805e",
|
|
||||||
"ea98c5d7-3cf9-4f9b-8ad3-366b58e0fcae",
|
|
||||||
"930fdb3b-11a8-46fe-9bac-577332e2640e",
|
|
||||||
"276cc624-87ea-4f08-ab93-f770e3790175",
|
|
||||||
"9d425400-e9b2-4424-9a4b-d4c7abac4140"
|
|
||||||
]
|
|
||||||
|
|
||||||
# for example_id in vs_code_list:
|
|
||||||
# try:
|
|
||||||
# main("vs_code", example_id)
|
|
||||||
# except Exception as e:
|
|
||||||
# logger.error("An error occurred while running the example: %s", e)
|
|
||||||
# continue
|
|
||||||
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
# for example_id in tqdm(vlc_list):
|
|
||||||
# try:
|
|
||||||
# main("vlc", example_id, gpt4_model="gpt-3.5-turbo-16k")
|
|
||||||
# except Exception as e:
|
|
||||||
# print(f"An error occurred while running the example: {e}")
|
|
||||||
# continue
|
|
||||||
|
|
||||||
chrome_list = [
|
|
||||||
"bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
|
|
||||||
"7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3",
|
|
||||||
"06fe7178-4491-4589-810f-2e2bc9502122",
|
|
||||||
"e1e75309-3ddb-4d09-92ec-de869c928143",
|
|
||||||
"35253b65-1c19-4304-8aa4-6884b8218fc0",
|
|
||||||
"2ad9387a-65d8-4e33-ad5b-7580065a27ca",
|
|
||||||
"7a5a7856-f1b6-42a4-ade9-1ca81ca0f263",
|
|
||||||
"44ee5668-ecd5-4366-a6ce-c1c9b8d4e938",
|
|
||||||
"2ae9ba84-3a0d-4d4c-8338-3a1478dc5fe3",
|
|
||||||
"480bcfea-d68f-4aaa-a0a9-2589ef319381",
|
|
||||||
"af630914-714e-4a24-a7bb-f9af687d3b91"
|
|
||||||
]
|
|
||||||
# for example_id in tqdm(chrome_list):
|
|
||||||
# try:
|
|
||||||
# main("chrome", example_id, gpt4_model="gpt-3.5-turbo-16k")
|
|
||||||
# except Exception as e:
|
|
||||||
# print(f"An error occurred while running the example: {e}")
|
|
||||||
# continue
|
|
||||||
|
|
||||||
vs_code_list = [
|
|
||||||
# "0ed39f63-6049-43d4-ba4d-5fa2fe04a951",
|
|
||||||
# "53ad5833-3455-407b-bbc6-45b4c79ab8fb",
|
|
||||||
# "eabc805a-bfcf-4460-b250-ac92135819f6",
|
|
||||||
# "982d12a5-beab-424f-8d38-d2a48429e511",
|
|
||||||
# "4e60007a-f5be-4bfc-9723-c39affa0a6d3",
|
|
||||||
# "e2b5e914-ffe1-44d2-8e92-58f8c5d92bb2",
|
|
||||||
# "9439a27b-18ae-42d8-9778-5f68f891805e",
|
|
||||||
# "ea98c5d7-3cf9-4f9b-8ad3-366b58e0fcae",
|
|
||||||
# "930fdb3b-11a8-46fe-9bac-577332e2640e",
|
|
||||||
# "276cc624-87ea-4f08-ab93-f770e3790175",
|
|
||||||
# "9d425400-e9b2-4424-9a4b-d4c7abac4140"
|
|
||||||
]
|
|
||||||
|
|
||||||
for example_id in tqdm(vs_code_list):
|
|
||||||
try:
|
|
||||||
main("vs_code", example_id, gpt4_model="gpt-3.5-turbo-16k")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"An error occurred while running the example: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
thunderbird_list = [
|
|
||||||
# "bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
|
|
||||||
# "7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3",
|
|
||||||
"12086550-11c0-466b-b367-1d9e75b3910e",
|
|
||||||
"06fe7178-4491-4589-810f-2e2bc9502122",
|
|
||||||
"6766f2b8-8a72-417f-a9e5-56fcaa735837",
|
|
||||||
"e1e75309-3ddb-4d09-92ec-de869c928143",
|
|
||||||
"3d1682a7-0fb0-49ae-a4dc-a73afd2d06d5",
|
|
||||||
"35253b65-1c19-4304-8aa4-6884b8218fc0",
|
|
||||||
"d088f539-cab4-4f9a-ac92-9999fc3a656e",
|
|
||||||
"2ad9387a-65d8-4e33-ad5b-7580065a27ca",
|
|
||||||
"480bcfea-d68f-4aaa-a0a9-2589ef319381",
|
|
||||||
"030eeff7-b492-4218-b312-701ec99ee0cc",
|
|
||||||
"94760984-3ff5-41ee-8347-cf1af709fea0",
|
|
||||||
"99146c54-4f37-4ab8-9327-5f3291665e1e",
|
|
||||||
"c9e7eaf2-b1a1-4efc-a982-721972fa9f02"
|
|
||||||
]
|
|
||||||
|
|
||||||
# for example_id in tqdm(thunderbird_list):
|
|
||||||
# try:
|
|
||||||
# main("thunderbird", example_id, gpt4_model="gpt-3.5-turbo-16k")
|
|
||||||
# except Exception as e:
|
|
||||||
# print(f"An error occurred while running the example: {e}")
|
|
||||||
# continue
|
|
||||||
|
|
||||||
multiple_list = [
|
|
||||||
# "f8cfa149-d1c1-4215-8dac-4a0932bad3c2",
|
|
||||||
# "897e3b53-5d4d-444b-85cb-2cdc8a97d903",
|
|
||||||
"2fe4b718-3bd7-46ec-bdce-b184f5653624",
|
|
||||||
"3680a5ee-6870-426a-a997-eba929a0d25c",
|
|
||||||
# "4e9f0faf-2ecc-4ae8-a804-28c9a75d1ddc",
|
|
||||||
# "b52b40a5-ad70-4c53-b5b0-5650a8387052",
|
|
||||||
# "46407397-a7d5-4c6b-92c6-dbe038b1457b",
|
|
||||||
# "2b9493d7-49b8-493a-a71b-56cd1f4d6908",
|
|
||||||
# "51f5801c-18b3-4f25-b0c3-02f85507a078",
|
|
||||||
"58565672-7bfe-48ab-b828-db349231de6b",
|
|
||||||
# "2c9fc0de-3ee7-45e1-a5df-c86206ad78b5",
|
|
||||||
# "510f64c8-9bcc-4be1-8d30-638705850618",
|
|
||||||
# "937087b6-f668-4ba6-9110-60682ee33441",
|
|
||||||
# "ee9a3c83-f437-4879-8918-be5efbb9fac7",
|
|
||||||
# "3680a5ee-6870-426a-a997-eba929a0d25c",
|
|
||||||
# "e135df7c-7687-4ac0-a5f0-76b74438b53e",
|
|
||||||
"ee9a3c83-f437-4879-8918-be5efbb9fac7",
|
|
||||||
# "58565672-7bfe-48ab-b828-db349231de6b",
|
|
||||||
# "2fe4b718-3bd7-46ec-bdce-b184f5653624"
|
|
||||||
]
|
|
||||||
|
|
||||||
for example_id in multiple_list:
|
|
||||||
try:
|
|
||||||
main("multi_apps", example_id, gpt4_model="gpt-3.5-turbo-16k")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("An error occurred while running the example: %s", e)
|
|
||||||
continue
|
|
||||||
|
|
||||||
@@ -1,335 +0,0 @@
|
|||||||
# todo: unifiy all the experiments python file into one file
|
|
||||||
import argparse
|
|
||||||
import datetime
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
|
|
||||||
import func_timeout
|
|
||||||
|
|
||||||
from desktop_env.envs.desktop_env import DesktopEnv
|
|
||||||
from mm_agents.gpt_4v_agent import GPT4v_Agent # todo: change the name into PromptAgent
|
|
||||||
|
|
||||||
# Logger Configs {{{ #
|
|
||||||
logger = logging.getLogger()
|
|
||||||
logger.setLevel(logging.DEBUG)
|
|
||||||
|
|
||||||
datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
|
||||||
|
|
||||||
file_handler = logging.FileHandler(os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8")
|
|
||||||
debug_handler = logging.FileHandler(os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8")
|
|
||||||
stdout_handler = logging.StreamHandler(sys.stdout)
|
|
||||||
sdebug_handler = logging.FileHandler(os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8")
|
|
||||||
|
|
||||||
file_handler.setLevel(logging.INFO)
|
|
||||||
debug_handler.setLevel(logging.DEBUG)
|
|
||||||
stdout_handler.setLevel(logging.INFO)
|
|
||||||
sdebug_handler.setLevel(logging.DEBUG)
|
|
||||||
|
|
||||||
formatter = logging.Formatter(
|
|
||||||
fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s")
|
|
||||||
file_handler.setFormatter(formatter)
|
|
||||||
debug_handler.setFormatter(formatter)
|
|
||||||
stdout_handler.setFormatter(formatter)
|
|
||||||
sdebug_handler.setFormatter(formatter)
|
|
||||||
|
|
||||||
stdout_handler.addFilter(logging.Filter("desktopenv"))
|
|
||||||
sdebug_handler.addFilter(logging.Filter("desktopenv"))
|
|
||||||
|
|
||||||
logger.addHandler(file_handler)
|
|
||||||
logger.addHandler(debug_handler)
|
|
||||||
logger.addHandler(stdout_handler)
|
|
||||||
logger.addHandler(sdebug_handler)
|
|
||||||
# }}} Logger Configs #
|
|
||||||
|
|
||||||
logger = logging.getLogger("desktopenv.experiment")
|
|
||||||
|
|
||||||
# todo: move the PATH_TO_VM to the argparser
|
|
||||||
PATH_TO_VM = r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu\Ubuntu.vmx"
|
|
||||||
|
|
||||||
|
|
||||||
def run_one_example(example, agent, max_steps=10, example_trajectory_dir="exp_trajectory", recording=True,
|
|
||||||
max_time=600):
|
|
||||||
trajectory_recording_path = os.path.join(example_trajectory_dir, "trajectory.json")
|
|
||||||
env = DesktopEnv(
|
|
||||||
path_to_vm=PATH_TO_VM,
|
|
||||||
action_space=agent.action_space,
|
|
||||||
task_config=example,
|
|
||||||
headless=True
|
|
||||||
)
|
|
||||||
# reset the environment to certain snapshot
|
|
||||||
observation = env.reset()
|
|
||||||
done = False
|
|
||||||
step_num = 0
|
|
||||||
|
|
||||||
if recording:
|
|
||||||
# send a request to the server to start recording
|
|
||||||
env.controller.start_recording()
|
|
||||||
|
|
||||||
while not done and step_num < max_steps:
|
|
||||||
actions = agent.predict(observation)
|
|
||||||
step_num += 1
|
|
||||||
for action in actions:
|
|
||||||
# Capture the timestamp before executing the action
|
|
||||||
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
|
||||||
logger.info("Step %d: %s", step_num, action)
|
|
||||||
|
|
||||||
observation, reward, done, info = env.step(action)
|
|
||||||
|
|
||||||
logger.info("Reward: %.2f", reward)
|
|
||||||
logger.info("Done: %s", done)
|
|
||||||
logger.info("Info: %s", info)
|
|
||||||
|
|
||||||
# Save screenshot and trajectory information
|
|
||||||
with open(os.path.join(example_trajectory_dir, f"step_{step_num}_{action_timestamp}.png"), "wb") as _f:
|
|
||||||
with open(observation['screenshot'], "rb") as __f:
|
|
||||||
screenshot = __f.read()
|
|
||||||
_f.write(screenshot)
|
|
||||||
|
|
||||||
with open(trajectory_recording_path, "a") as f:
|
|
||||||
f.write(json.dumps({
|
|
||||||
"step_num": step_num,
|
|
||||||
"action_timestamp": action_timestamp,
|
|
||||||
"action": action,
|
|
||||||
"reward": reward,
|
|
||||||
"done": done,
|
|
||||||
"info": info,
|
|
||||||
"screenshot_file": f"step_{step_num}_{action_timestamp}.png"
|
|
||||||
}))
|
|
||||||
f.write("\n")
|
|
||||||
|
|
||||||
if done:
|
|
||||||
logger.info("The episode is done.")
|
|
||||||
break
|
|
||||||
|
|
||||||
def stop_recording():
|
|
||||||
try:
|
|
||||||
env.controller.end_recording(os.path.join(example_trajectory_dir, "recording.mp4"))
|
|
||||||
except Exception as e:
|
|
||||||
print(f"An error occurred while stopping the recording: {e}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
func_timeout.func_timeout(120, stop_recording)
|
|
||||||
# todo: make sure we got the video file, check the bug
|
|
||||||
except func_timeout.exceptions.FunctionTimedOut:
|
|
||||||
logger.info("Recording timed out.")
|
|
||||||
|
|
||||||
result = env.evaluate()
|
|
||||||
logger.info("Result: %.2f", result)
|
|
||||||
|
|
||||||
# fixme: change to write the result into a separate file
|
|
||||||
with open(trajectory_recording_path, "a") as f:
|
|
||||||
f.write(json.dumps({
|
|
||||||
"result": result
|
|
||||||
}))
|
|
||||||
f.write("\n")
|
|
||||||
|
|
||||||
# todo: append the result to the wandb for visualization
|
|
||||||
|
|
||||||
# env.close()
|
|
||||||
logger.info("Environment closed.")
|
|
||||||
|
|
||||||
|
|
||||||
def main(example_class, example_id, gpt4_model="gpt-4-vision-preview"):
|
|
||||||
# todo: merge the main function into the run_one_example function
|
|
||||||
# fixme: change all the settings like action_space, model, etc. to the argparser
|
|
||||||
action_space = "pyautogui"
|
|
||||||
gemini_model = "gemini-pro-vision"
|
|
||||||
|
|
||||||
logger.info("Running example %s/%s", example_class, example_id)
|
|
||||||
logger.info("Using model %s", gpt4_model)
|
|
||||||
# logger.info("Using model %s", gemini_model)
|
|
||||||
|
|
||||||
with open(f"evaluation_examples/examples/{example_class}/{example_id}.json", "r", encoding="utf-8") as f:
|
|
||||||
example = json.load(f)
|
|
||||||
example["snapshot"] = "exp_v5"
|
|
||||||
|
|
||||||
api_key = os.environ.get("OPENAI_API_KEY")
|
|
||||||
agent = GPT4v_Agent(api_key=api_key,
|
|
||||||
model=gpt4_model,
|
|
||||||
instruction=example['instruction'],
|
|
||||||
action_space=action_space,
|
|
||||||
exp="screenshot")
|
|
||||||
#
|
|
||||||
# api_key = os.environ.get("GENAI_API_KEY")
|
|
||||||
# agent = GeminiPro_Agent(api_key=api_key, instruction=example['instruction'], action_space=action_space, exp="screenshot")
|
|
||||||
|
|
||||||
root_trajectory_dir = "exp_trajectory"
|
|
||||||
|
|
||||||
example_trajectory_dir = os.path.join(root_trajectory_dir, "screenshot", example_class, gpt4_model, example_id)
|
|
||||||
# example_trajectory_dir = os.path.join(root_trajectory_dir, "screenshot", example_class, gemini_model, example_id)
|
|
||||||
|
|
||||||
os.makedirs(example_trajectory_dir, exist_ok=True)
|
|
||||||
|
|
||||||
if os.path.exists(os.path.join(example_trajectory_dir, "trajectory.json")):
|
|
||||||
with open(os.path.join(example_trajectory_dir, "trajectory.json"), "r") as f:
|
|
||||||
lines = f.readlines()
|
|
||||||
# strip the last line if it is empty
|
|
||||||
lines = [line.strip() for line in lines if line.strip() != ""]
|
|
||||||
if len(lines) > 0:
|
|
||||||
last_line = json.loads(lines[-1])
|
|
||||||
if "result" in last_line:
|
|
||||||
logger.info(
|
|
||||||
f"evaluation_examples/examples/{example_class}/{example_id}.json" + "has been evaluated. Skip.")
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
func_timeout.func_timeout(1200, run_one_example, args=(example, agent, 15, example_trajectory_dir))
|
|
||||||
except Exception as e:
|
|
||||||
print(f"An error occurred: {e}")
|
|
||||||
with open(os.path.join(example_trajectory_dir, "trajectory.json"), "a") as f:
|
|
||||||
f.write(json.dumps({
|
|
||||||
"error": str(e)
|
|
||||||
}))
|
|
||||||
|
|
||||||
|
|
||||||
def config() -> argparse.Namespace:
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description="Run end-to-end evaluation on the benchmark"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--render", action="store_true", help="Render the browser"
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--slow_mo",
|
|
||||||
type=int,
|
|
||||||
default=0,
|
|
||||||
help="Slow down the browser by the specified amount",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--action_set_tag", default="id_accessibility_tree", help="Action type"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--observation_type",
|
|
||||||
choices=[
|
|
||||||
"accessibility_tree",
|
|
||||||
"accessibility_tree_with_captioner",
|
|
||||||
"html",
|
|
||||||
"image",
|
|
||||||
"image_som",
|
|
||||||
],
|
|
||||||
default="accessibility_tree",
|
|
||||||
help="Observation type",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--current_viewport_only",
|
|
||||||
action="store_true",
|
|
||||||
help="Only use the current viewport for the observation",
|
|
||||||
)
|
|
||||||
parser.add_argument("--viewport_width", type=int, default=1280)
|
|
||||||
parser.add_argument("--viewport_height", type=int, default=2048)
|
|
||||||
parser.add_argument("--save_trace_enabled", action="store_true")
|
|
||||||
parser.add_argument("--sleep_after_execution", type=float, default=0.0)
|
|
||||||
|
|
||||||
parser.add_argument("--max_steps", type=int, default=30)
|
|
||||||
|
|
||||||
# agent config
|
|
||||||
parser.add_argument("--agent_type", type=str, default="prompt")
|
|
||||||
parser.add_argument(
|
|
||||||
"--instruction_path",
|
|
||||||
type=str,
|
|
||||||
default="agents/prompts/state_action_agent.json",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--parsing_failure_th",
|
|
||||||
help="When consecutive parsing failures exceed this threshold, the agent will terminate early.",
|
|
||||||
type=int,
|
|
||||||
default=3,
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--repeating_action_failure_th",
|
|
||||||
help="When consecutive repeated actions exceed this threshold, the agent will terminate early.",
|
|
||||||
type=int,
|
|
||||||
default=5,
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument("--test_config_base_dir", type=str)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--eval_captioning_model_device",
|
|
||||||
type=str,
|
|
||||||
default="cpu",
|
|
||||||
choices=["cpu", "cuda"],
|
|
||||||
help="Device to run eval captioning model on. By default, runs it on CPU.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--eval_captioning_model",
|
|
||||||
type=str,
|
|
||||||
default="Salesforce/blip2-flan-t5-xl",
|
|
||||||
choices=["Salesforce/blip2-flan-t5-xl"],
|
|
||||||
help="Captioning backbone for VQA-type evals.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--captioning_model",
|
|
||||||
type=str,
|
|
||||||
default="Salesforce/blip2-flan-t5-xl",
|
|
||||||
choices=["Salesforce/blip2-flan-t5-xl", "llava-hf/llava-1.5-7b-hf"],
|
|
||||||
help="Captioning backbone for accessibility tree alt text.",
|
|
||||||
)
|
|
||||||
|
|
||||||
# lm config
|
|
||||||
parser.add_argument("--provider", type=str, default="openai")
|
|
||||||
parser.add_argument("--model", type=str, default="gpt-3.5-turbo-0613")
|
|
||||||
parser.add_argument("--mode", type=str, default="chat")
|
|
||||||
parser.add_argument("--temperature", type=float, default=1.0)
|
|
||||||
parser.add_argument("--top_p", type=float, default=0.9)
|
|
||||||
parser.add_argument("--context_length", type=int, default=0)
|
|
||||||
parser.add_argument("--max_tokens", type=int, default=384)
|
|
||||||
parser.add_argument("--stop_token", type=str, default=None)
|
|
||||||
parser.add_argument(
|
|
||||||
"--max_retry",
|
|
||||||
type=int,
|
|
||||||
help="max retry times to perform generations when parsing fails",
|
|
||||||
default=1,
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--max_obs_length",
|
|
||||||
type=int,
|
|
||||||
help="when not zero, will truncate the observation to this length before feeding to the model",
|
|
||||||
default=3840,
|
|
||||||
)
|
|
||||||
|
|
||||||
# example config
|
|
||||||
parser.add_argument("--test_start_idx", type=int, default=0)
|
|
||||||
parser.add_argument("--test_end_idx", type=int, default=910)
|
|
||||||
|
|
||||||
# logging related
|
|
||||||
parser.add_argument("--result_dir", type=str, default="")
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# check the whether the action space is compatible with the observation space
|
|
||||||
if (
|
|
||||||
args.action_set_tag == "id_accessibility_tree"
|
|
||||||
and args.observation_type
|
|
||||||
not in [
|
|
||||||
"accessibility_tree",
|
|
||||||
"accessibility_tree_with_captioner",
|
|
||||||
"image_som",
|
|
||||||
]
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
f"Action type {args.action_set_tag} is incompatible with the observation type {args.observation_type}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return args
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
####### The complete version of the list of examples #######
|
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
||||||
args = config()
|
|
||||||
args.sleep_after_execution = 2.5
|
|
||||||
prepare(args)
|
|
||||||
|
|
||||||
# todo: add recorder of the progress of the examples
|
|
||||||
|
|
||||||
# todo: remove the useless example files
|
|
||||||
|
|
||||||
with open("evaluation_examples/test_all.json", "r", encoding="utf-8") as f:
|
|
||||||
test_all_meta = json.load(f)
|
|
||||||
|
|
||||||
for domain in test_all_meta:
|
|
||||||
for example_id in test_all_meta[domain]:
|
|
||||||
main(domain, example_id, args.model)
|
|
||||||
@@ -1,361 +0,0 @@
|
|||||||
import datetime
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
|
|
||||||
import func_timeout
|
|
||||||
|
|
||||||
from desktop_env.envs.desktop_env import DesktopEnv
|
|
||||||
from mm_agents.gpt_4v_agent import GPT4v_Agent
|
|
||||||
|
|
||||||
# Logger Configs {{{ #
|
|
||||||
logger = logging.getLogger()
|
|
||||||
logger.setLevel(logging.DEBUG)
|
|
||||||
|
|
||||||
datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
|
||||||
|
|
||||||
file_handler = logging.FileHandler(os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8")
|
|
||||||
debug_handler = logging.FileHandler(os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8")
|
|
||||||
stdout_handler = logging.StreamHandler(sys.stdout)
|
|
||||||
sdebug_handler = logging.FileHandler(os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8")
|
|
||||||
|
|
||||||
file_handler.setLevel(logging.INFO)
|
|
||||||
debug_handler.setLevel(logging.DEBUG)
|
|
||||||
stdout_handler.setLevel(logging.INFO)
|
|
||||||
sdebug_handler.setLevel(logging.DEBUG)
|
|
||||||
|
|
||||||
formatter = logging.Formatter(
|
|
||||||
fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s")
|
|
||||||
file_handler.setFormatter(formatter)
|
|
||||||
debug_handler.setFormatter(formatter)
|
|
||||||
stdout_handler.setFormatter(formatter)
|
|
||||||
sdebug_handler.setFormatter(formatter)
|
|
||||||
|
|
||||||
stdout_handler.addFilter(logging.Filter("desktopenv"))
|
|
||||||
sdebug_handler.addFilter(logging.Filter("desktopenv"))
|
|
||||||
|
|
||||||
logger.addHandler(file_handler)
|
|
||||||
logger.addHandler(debug_handler)
|
|
||||||
logger.addHandler(stdout_handler)
|
|
||||||
logger.addHandler(sdebug_handler)
|
|
||||||
# }}} Logger Configs #
|
|
||||||
|
|
||||||
logger = logging.getLogger("desktopenv.experiment")
|
|
||||||
|
|
||||||
PATH_TO_VM = r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu2\Ubuntu2.vmx"
|
|
||||||
|
|
||||||
|
|
||||||
# PATH_TO_VM = "../../../../大文件/镜像/Ubuntu-1218/Ubuntu/Ubuntu.vmx"
|
|
||||||
|
|
||||||
def run_one_example(example, agent, max_steps=10, example_trajectory_dir="exp_trajectory", recording=True):
|
|
||||||
trajectory_recording_path = os.path.join(example_trajectory_dir, "trajectory.json")
|
|
||||||
env = DesktopEnv(
|
|
||||||
path_to_vm=PATH_TO_VM,
|
|
||||||
action_space=agent.action_space,
|
|
||||||
task_config=example
|
|
||||||
)
|
|
||||||
# reset the environment to certain snapshot
|
|
||||||
observation = env.reset()
|
|
||||||
done = False
|
|
||||||
step_num = 0
|
|
||||||
|
|
||||||
if recording:
|
|
||||||
# send a request to the server to start recording
|
|
||||||
env.controller.start_recording()
|
|
||||||
|
|
||||||
while not done and step_num < max_steps:
|
|
||||||
actions = agent.predict(observation)
|
|
||||||
step_num += 1
|
|
||||||
for action in actions:
|
|
||||||
# Capture the timestamp before executing the action
|
|
||||||
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
|
||||||
logger.info("Step %d: %s", step_num, action)
|
|
||||||
|
|
||||||
observation, reward, done, info = env.step(action)
|
|
||||||
|
|
||||||
logger.info("Reward: %.2f", reward)
|
|
||||||
logger.info("Done: %s", done)
|
|
||||||
logger.info("Info: %s", info)
|
|
||||||
|
|
||||||
# Save screenshot and trajectory information
|
|
||||||
with open(os.path.join(example_trajectory_dir, f"step_{step_num}_{action_timestamp}.png"), "wb") as _f:
|
|
||||||
with open(observation['screenshot'], "rb") as __f:
|
|
||||||
screenshot = __f.read()
|
|
||||||
_f.write(screenshot)
|
|
||||||
|
|
||||||
with open(trajectory_recording_path, "a") as f:
|
|
||||||
f.write(json.dumps({
|
|
||||||
"step_num": step_num,
|
|
||||||
"action_timestamp": action_timestamp,
|
|
||||||
"action": action,
|
|
||||||
"reward": reward,
|
|
||||||
"done": done,
|
|
||||||
"info": info,
|
|
||||||
"screenshot_file": f"step_{step_num}_{action_timestamp}.png"
|
|
||||||
}))
|
|
||||||
f.write("\n")
|
|
||||||
|
|
||||||
if done:
|
|
||||||
logger.info("The episode is done.")
|
|
||||||
break
|
|
||||||
|
|
||||||
def stop_recording():
|
|
||||||
try:
|
|
||||||
env.controller.end_recording(os.path.join(example_trajectory_dir, "recording.mp4"))
|
|
||||||
except Exception as e:
|
|
||||||
print(f"An error occurred while stopping the recording: {e}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
func_timeout.func_timeout(30, stop_recording)
|
|
||||||
except func_timeout.exceptions.FunctionTimedOut:
|
|
||||||
logger.info("Recording timed out.")
|
|
||||||
|
|
||||||
result = env.evaluate()
|
|
||||||
logger.info("Result: %.2f", result)
|
|
||||||
|
|
||||||
with open(trajectory_recording_path, "a") as f:
|
|
||||||
f.write(json.dumps({
|
|
||||||
"result": result
|
|
||||||
}))
|
|
||||||
f.write("\n")
|
|
||||||
|
|
||||||
# env.close()
|
|
||||||
logger.info("Environment closed.")
|
|
||||||
|
|
||||||
|
|
||||||
def main(example_class, example_id, gpt4_model="gpt-4-vision-preview"):
|
|
||||||
action_space = "pyautogui"
|
|
||||||
# example_class = "libreoffice_calc"
|
|
||||||
# example_id = "7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3"
|
|
||||||
# example_id = "01b269ae-2111-4a07-81fd-3fcd711993b0"
|
|
||||||
gemini_model = "gemini-pro-vision"
|
|
||||||
|
|
||||||
logger.info("Running example %s/%s", example_class, example_id)
|
|
||||||
logger.info("Using model %s", gpt4_model)
|
|
||||||
# logger.info("Using model %s", gemini_model)
|
|
||||||
|
|
||||||
with open(f"evaluation_examples/examples/{example_class}/{example_id}.json", "r", encoding="utf-8") as f:
|
|
||||||
example = json.load(f)
|
|
||||||
example["snapshot"] = "exp_v5"
|
|
||||||
# example["snapshot"] = "exp_setup4"
|
|
||||||
# example["snapshot"] = "Snapshot 30"
|
|
||||||
|
|
||||||
api_key = os.environ.get("OPENAI_API_KEY")
|
|
||||||
agent = GPT4v_Agent(api_key=api_key, model=gpt4_model, instruction=example['instruction'],
|
|
||||||
action_space=action_space, exp="both")
|
|
||||||
|
|
||||||
# api_key = os.environ.get("GENAI_API_KEY")
|
|
||||||
# agent = GeminiPro_Agent(api_key=api_key, model=gemini_model, instruction=example['instruction'], action_space=action_space, exp="both")
|
|
||||||
|
|
||||||
root_trajectory_dir = "exp_trajectory"
|
|
||||||
|
|
||||||
example_trajectory_dir = os.path.join(root_trajectory_dir, "both", example_class, gpt4_model, example_id)
|
|
||||||
# example_trajectory_dir = os.path.join(root_trajectory_dir, "both", example_class, gemini_model, example_id)
|
|
||||||
|
|
||||||
os.makedirs(example_trajectory_dir, exist_ok=True)
|
|
||||||
|
|
||||||
run_one_example(example, agent, 15, example_trajectory_dir)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
os_list = [
|
|
||||||
"94d95f96-9699-4208-98ba-3c3119edf9c2",
|
|
||||||
"bedcedc4-4d72-425e-ad62-21960b11fe0d",
|
|
||||||
"43c2d64c-bab5-4dcb-a30c-b888321c319a",
|
|
||||||
"7688b85f-87a4-4e4a-b2f8-f3d6c3f29b82",
|
|
||||||
"ec4e3f68-9ea4-4c18-a5c9-69f89d1178b3",
|
|
||||||
"f9be0997-4b7c-45c5-b05c-4612b44a6118",
|
|
||||||
"28cc3b7e-b194-4bc9-8353-d04c0f4d56d2",
|
|
||||||
"5ea617a3-0e86-4ba6-aab2-dac9aa2e8d57",
|
|
||||||
"e0df059f-28a6-4169-924f-b9623e7184cc",
|
|
||||||
"ddc75b62-7311-4af8-bfb3-859558542b36",
|
|
||||||
"b6781586-6346-41cd-935a-a6b1487918fc",
|
|
||||||
"3ce045a0-877b-42aa-8d2c-b4a863336ab8",
|
|
||||||
"a4d98375-215b-4a4d-aee9-3d4370fccc41",
|
|
||||||
"13584542-872b-42d8-b299-866967b5c3ef",
|
|
||||||
"23393935-50c7-4a86-aeea-2b78fd089c5c"
|
|
||||||
]
|
|
||||||
|
|
||||||
# for example_id in os_list:
|
|
||||||
# try:
|
|
||||||
# main("os", example_id)
|
|
||||||
# except Exception as e:
|
|
||||||
# logger.error("An error occurred while running the example: %s", e)
|
|
||||||
# continue
|
|
||||||
|
|
||||||
calc_list = [
|
|
||||||
"a9f325aa-8c05-4e4f-8341-9e4358565f4f",
|
|
||||||
"ecb0df7a-4e8d-4a03-b162-053391d3afaf",
|
|
||||||
"7efeb4b1-3d19-4762-b163-63328d66303b",
|
|
||||||
"4e6fcf72-daf3-439f-a232-c434ce416af6",
|
|
||||||
"6054afcb-5bab-4702-90a0-b259b5d3217c",
|
|
||||||
"abed40dc-063f-4598-8ba5-9fe749c0615d",
|
|
||||||
"01b269ae-2111-4a07-81fd-3fcd711993b0",
|
|
||||||
"8b1ce5f2-59d2-4dcc-b0b0-666a714b9a14",
|
|
||||||
"af2b02f7-acee-4be4-8b66-499fab394915",
|
|
||||||
"da1d63b8-fa12-417b-ba18-f748e5f770f3",
|
|
||||||
"636380ea-d5f6-4474-b6ca-b2ed578a20f1",
|
|
||||||
"5ba77536-05c5-4aae-a9ff-6e298d094c3e",
|
|
||||||
"4bc4eaf4-ca5e-4db2-8138-8d4e65af7c0b",
|
|
||||||
"672a1b02-c62f-4ae2-acf0-37f5fb3052b0",
|
|
||||||
"648fe544-16ba-44af-a587-12ccbe280ea6",
|
|
||||||
"8985d1e4-5b99-4711-add4-88949ebb2308",
|
|
||||||
"9e606842-2e27-43bf-b1d1-b43289c9589b",
|
|
||||||
"fcb6e45b-25c4-4087-9483-03d714f473a9",
|
|
||||||
"68c0c5b7-96f3-4e87-92a7-6c1b967fd2d2",
|
|
||||||
"fff629ea-046e-4793-8eec-1a5a15c3eb35",
|
|
||||||
"5c9a206c-bb00-4fb6-bb46-ee675c187df5",
|
|
||||||
"e975ae74-79bd-4672-8d1c-dc841a85781d",
|
|
||||||
"34a6938a-58da-4897-8639-9b90d6db5391",
|
|
||||||
"b5a22759-b4eb-4bf2-aeed-ad14e8615f19",
|
|
||||||
"2f9913a1-51ed-4db6-bfe0-7e1c95b3139e",
|
|
||||||
"2558031e-401d-4579-8e00-3ecf540fb492",
|
|
||||||
"0cecd4f3-74de-457b-ba94-29ad6b5dafb6",
|
|
||||||
"4188d3a4-077d-46b7-9c86-23e1a036f6c1",
|
|
||||||
"51b11269-2ca8-4b2a-9163-f21758420e78",
|
|
||||||
"7e429b8d-a3f0-4ed0-9b58-08957d00b127",
|
|
||||||
"347ef137-7eeb-4c80-a3bb-0951f26a8aff",
|
|
||||||
"6e99a1ad-07d2-4b66-a1ce-ece6d99c20a5",
|
|
||||||
"3aaa4e37-dc91-482e-99af-132a612d40f3",
|
|
||||||
"37608790-6147-45d0-9f20-1137bb35703d",
|
|
||||||
"f9584479-3d0d-4c79-affa-9ad7afdd8850",
|
|
||||||
"d681960f-7bc3-4286-9913-a8812ba3261a",
|
|
||||||
"21df9241-f8d7-4509-b7f1-37e501a823f7",
|
|
||||||
"1334ca3e-f9e3-4db8-9ca7-b4c653be7d17",
|
|
||||||
"357ef137-7eeb-4c80-a3bb-0951f26a8aff",
|
|
||||||
"aa3a8974-2e85-438b-b29e-a64df44deb4b",
|
|
||||||
"a01fbce3-2793-461f-ab86-43680ccbae25",
|
|
||||||
"4f07fbe9-70de-4927-a4d5-bb28bc12c52c"
|
|
||||||
]
|
|
||||||
|
|
||||||
# for example_id in calc_list:
|
|
||||||
# try:
|
|
||||||
# main("libreoffice_calc", example_id)
|
|
||||||
# except Exception as e:
|
|
||||||
# logger.error("An error occurred while running the example: %s", e)
|
|
||||||
# continue
|
|
||||||
|
|
||||||
impress_list = [
|
|
||||||
"5d901039-a89c-4bfb-967b-bf66f4df075e",
|
|
||||||
"550ce7e7-747b-495f-b122-acdc4d0b8e54",
|
|
||||||
"455d3c66-7dc6-4537-a39a-36d3e9119df7",
|
|
||||||
"af23762e-2bfd-4a1d-aada-20fa8de9ce07",
|
|
||||||
"c59742c0-4323-4b9d-8a02-723c251deaa0",
|
|
||||||
"ef9d12bd-bcee-4ba0-a40e-918400f43ddf",
|
|
||||||
"9ec204e4-f0a3-42f8-8458-b772a6797cab",
|
|
||||||
"0f84bef9-9790-432e-92b7-eece357603fb",
|
|
||||||
"ce88f674-ab7a-43da-9201-468d38539e4a",
|
|
||||||
"3b27600c-3668-4abd-8f84-7bcdebbccbdb",
|
|
||||||
"a097acff-6266-4291-9fbd-137af7ecd439",
|
|
||||||
"bf4e9888-f10f-47af-8dba-76413038b73c",
|
|
||||||
"21760ecb-8f62-40d2-8d85-0cee5725cb72"
|
|
||||||
]
|
|
||||||
|
|
||||||
# for example_id in impress_list:
|
|
||||||
# try:
|
|
||||||
# main("libreoffice_impress", example_id)
|
|
||||||
# except Exception as e:
|
|
||||||
# logger.error("An error occurred while running the example: %s", e)
|
|
||||||
# continue
|
|
||||||
|
|
||||||
vs_code_list = [
|
|
||||||
"0ed39f63-6049-43d4-ba4d-5fa2fe04a951",
|
|
||||||
"53ad5833-3455-407b-bbc6-45b4c79ab8fb",
|
|
||||||
"eabc805a-bfcf-4460-b250-ac92135819f6",
|
|
||||||
"982d12a5-beab-424f-8d38-d2a48429e511",
|
|
||||||
"4e60007a-f5be-4bfc-9723-c39affa0a6d3",
|
|
||||||
"e2b5e914-ffe1-44d2-8e92-58f8c5d92bb2",
|
|
||||||
"9439a27b-18ae-42d8-9778-5f68f891805e",
|
|
||||||
"ea98c5d7-3cf9-4f9b-8ad3-366b58e0fcae",
|
|
||||||
"930fdb3b-11a8-46fe-9bac-577332e2640e",
|
|
||||||
"276cc624-87ea-4f08-ab93-f770e3790175",
|
|
||||||
"9d425400-e9b2-4424-9a4b-d4c7abac4140"
|
|
||||||
]
|
|
||||||
|
|
||||||
# for example_id in vs_code_list:
|
|
||||||
# try:
|
|
||||||
# main("vs_code", example_id)
|
|
||||||
# except Exception as e:
|
|
||||||
# logger.error("An error occurred while running the example: %s", e)
|
|
||||||
# continue
|
|
||||||
|
|
||||||
multiple_list = [
|
|
||||||
"f8cfa149-d1c1-4215-8dac-4a0932bad3c2",
|
|
||||||
"897e3b53-5d4d-444b-85cb-2cdc8a97d903",
|
|
||||||
"4e9f0faf-2ecc-4ae8-a804-28c9a75d1ddc",
|
|
||||||
"b52b40a5-ad70-4c53-b5b0-5650a8387052",
|
|
||||||
"46407397-a7d5-4c6b-92c6-dbe038b1457b",
|
|
||||||
"2b9493d7-49b8-493a-a71b-56cd1f4d6908",
|
|
||||||
"51f5801c-18b3-4f25-b0c3-02f85507a078",
|
|
||||||
"2c9fc0de-3ee7-45e1-a5df-c86206ad78b5",
|
|
||||||
"510f64c8-9bcc-4be1-8d30-638705850618",
|
|
||||||
"937087b6-f668-4ba6-9110-60682ee33441",
|
|
||||||
"ee9a3c83-f437-4879-8918-be5efbb9fac7",
|
|
||||||
"3680a5ee-6870-426a-a997-eba929a0d25c",
|
|
||||||
"e135df7c-7687-4ac0-a5f0-76b74438b53e",
|
|
||||||
"58565672-7bfe-48ab-b828-db349231de6b",
|
|
||||||
"2fe4b718-3bd7-46ec-bdce-b184f5653624"
|
|
||||||
]
|
|
||||||
|
|
||||||
# for example_id in multiple_list:
|
|
||||||
# try:
|
|
||||||
# main("multi_apps", example_id)
|
|
||||||
# except Exception as e:
|
|
||||||
# logger.error("An error occurred while running the example: %s", e)
|
|
||||||
# continue
|
|
||||||
|
|
||||||
chrome_list = [
|
|
||||||
# "bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
|
|
||||||
"7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3",
|
|
||||||
"06fe7178-4491-4589-810f-2e2bc9502122",
|
|
||||||
"e1e75309-3ddb-4d09-92ec-de869c928143",
|
|
||||||
"35253b65-1c19-4304-8aa4-6884b8218fc0",
|
|
||||||
"2ad9387a-65d8-4e33-ad5b-7580065a27ca",
|
|
||||||
"7a5a7856-f1b6-42a4-ade9-1ca81ca0f263",
|
|
||||||
"44ee5668-ecd5-4366-a6ce-c1c9b8d4e938",
|
|
||||||
"2ae9ba84-3a0d-4d4c-8338-3a1478dc5fe3",
|
|
||||||
"480bcfea-d68f-4aaa-a0a9-2589ef319381",
|
|
||||||
"af630914-714e-4a24-a7bb-f9af687d3b91"
|
|
||||||
]
|
|
||||||
|
|
||||||
# for example_id in chrome_list:
|
|
||||||
# try:
|
|
||||||
# main("chrome", example_id)
|
|
||||||
# except Exception as e:
|
|
||||||
# logger.error("An error occurred while running the example: %s", e)
|
|
||||||
# continue
|
|
||||||
|
|
||||||
|
|
||||||
writer_list = [
|
|
||||||
"6ada715d-3aae-4a32-a6a7-429b2e43fb93",
|
|
||||||
"ecc2413d-8a48-416e-a3a2-d30106ca36cb",
|
|
||||||
"0e47de2a-32e0-456c-a366-8c607ef7a9d2",
|
|
||||||
"4bcb1253-a636-4df4-8cb0-a35c04dfef31",
|
|
||||||
"0810415c-bde4-4443-9047-d5f70165a697",
|
|
||||||
"e528b65e-1107-4b8c-8988-490e4fece599",
|
|
||||||
"66399b0d-8fda-4618-95c4-bfc6191617e9",
|
|
||||||
"936321ce-5236-426a-9a20-e0e3c5dc536f",
|
|
||||||
"3ef2b351-8a84-4ff2-8724-d86eae9b842e",
|
|
||||||
"0b17a146-2934-46c7-8727-73ff6b6483e8",
|
|
||||||
"0e763496-b6bb-4508-a427-fad0b6c3e195",
|
|
||||||
"f178a4a9-d090-4b56-bc4c-4b72a61a035d",
|
|
||||||
"adf5e2c3-64c7-4644-b7b6-d2f0167927e7",
|
|
||||||
"0a0faba3-5580-44df-965d-f562a99b291c",
|
|
||||||
"e246f6d8-78d7-44ac-b668-fcf47946cb50",
|
|
||||||
"8472fece-c7dd-4241-8d65-9b3cd1a0b568",
|
|
||||||
"88fe4b2d-3040-4c70-9a70-546a47764b48",
|
|
||||||
"d53ff5ee-3b1a-431e-b2be-30ed2673079b",
|
|
||||||
"72b810ef-4156-4d09-8f08-a0cf57e7cefe",
|
|
||||||
"6f81754e-285d-4ce0-b59e-af7edb02d108",
|
|
||||||
"b21acd93-60fd-4127-8a43-2f5178f4a830"
|
|
||||||
]
|
|
||||||
|
|
||||||
for example_id in writer_list:
|
|
||||||
try:
|
|
||||||
main("libreoffice_writer", example_id)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("An error occurred while running the example: %s", e)
|
|
||||||
continue
|
|
||||||
|
|
||||||
|
|
||||||
@@ -1,155 +0,0 @@
|
|||||||
import ctypes
|
|
||||||
import datetime
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import func_timeout
|
|
||||||
|
|
||||||
from desktop_env.envs.desktop_env import DesktopEnv
|
|
||||||
from mm_agents.gpt_4v_agent import GPT4v_Agent
|
|
||||||
|
|
||||||
# Logger Configs {{{ #
|
|
||||||
logger = logging.getLogger()
|
|
||||||
logger.setLevel(logging.DEBUG)
|
|
||||||
|
|
||||||
datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
|
||||||
|
|
||||||
file_handler = logging.FileHandler(os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8")
|
|
||||||
debug_handler = logging.FileHandler(os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8")
|
|
||||||
stdout_handler = logging.StreamHandler(sys.stdout)
|
|
||||||
sdebug_handler = logging.FileHandler(os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8")
|
|
||||||
|
|
||||||
file_handler.setLevel(logging.INFO)
|
|
||||||
debug_handler.setLevel(logging.DEBUG)
|
|
||||||
stdout_handler.setLevel(logging.INFO)
|
|
||||||
sdebug_handler.setLevel(logging.DEBUG)
|
|
||||||
|
|
||||||
formatter = logging.Formatter(
|
|
||||||
fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s")
|
|
||||||
file_handler.setFormatter(formatter)
|
|
||||||
debug_handler.setFormatter(formatter)
|
|
||||||
stdout_handler.setFormatter(formatter)
|
|
||||||
sdebug_handler.setFormatter(formatter)
|
|
||||||
|
|
||||||
stdout_handler.addFilter(logging.Filter("desktopenv"))
|
|
||||||
sdebug_handler.addFilter(logging.Filter("desktopenv"))
|
|
||||||
|
|
||||||
logger.addHandler(file_handler)
|
|
||||||
logger.addHandler(debug_handler)
|
|
||||||
logger.addHandler(stdout_handler)
|
|
||||||
logger.addHandler(sdebug_handler)
|
|
||||||
# }}} Logger Configs #
|
|
||||||
|
|
||||||
logger = logging.getLogger("desktopenv.experiment")
|
|
||||||
|
|
||||||
PATH_TO_VM = r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu\Ubuntu.vmx"
|
|
||||||
|
|
||||||
|
|
||||||
def run_one_example(example, agent, max_steps=10, example_trajectory_dir="exp_trajectory", recording=True):
|
|
||||||
trajectory_recording_path = os.path.join(example_trajectory_dir, "trajectory.json")
|
|
||||||
env = DesktopEnv(
|
|
||||||
path_to_vm=PATH_TO_VM,
|
|
||||||
action_space=agent.action_space,
|
|
||||||
task_config=example
|
|
||||||
)
|
|
||||||
# reset the environment to certain snapshot
|
|
||||||
observation = env.reset()
|
|
||||||
done = False
|
|
||||||
step_num = 0
|
|
||||||
|
|
||||||
if recording:
|
|
||||||
# send a request to the server to start recording
|
|
||||||
env.controller.start_recording()
|
|
||||||
|
|
||||||
while not done and step_num < max_steps:
|
|
||||||
actions = agent.predict(observation)
|
|
||||||
step_num += 1
|
|
||||||
for action in actions:
|
|
||||||
# Capture the timestamp before executing the action
|
|
||||||
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
|
||||||
logger.info("Step %d: %s", step_num, action)
|
|
||||||
|
|
||||||
observation, reward, done, info = env.step(action)
|
|
||||||
|
|
||||||
logger.info("Reward: %.2f", reward)
|
|
||||||
logger.info("Done: %s", done)
|
|
||||||
logger.info("Info: %s", info)
|
|
||||||
|
|
||||||
# Save screenshot and trajectory information
|
|
||||||
with open(os.path.join(example_trajectory_dir, f"step_{step_num}_{action_timestamp}.png"), "wb") as _f:
|
|
||||||
with open(observation['screenshot'], "rb") as __f:
|
|
||||||
screenshot = __f.read()
|
|
||||||
_f.write(screenshot)
|
|
||||||
|
|
||||||
with open(trajectory_recording_path, "a") as f:
|
|
||||||
f.write(json.dumps({
|
|
||||||
"step_num": step_num,
|
|
||||||
"action_timestamp": action_timestamp,
|
|
||||||
"action": action,
|
|
||||||
"reward": reward,
|
|
||||||
"done": done,
|
|
||||||
"info": info,
|
|
||||||
"screenshot_file": f"step_{step_num}_{action_timestamp}.png"
|
|
||||||
}))
|
|
||||||
f.write("\n")
|
|
||||||
|
|
||||||
if done:
|
|
||||||
logger.info("The episode is done.")
|
|
||||||
break
|
|
||||||
|
|
||||||
def stop_recording():
|
|
||||||
try:
|
|
||||||
env.controller.end_recording(os.path.join(example_trajectory_dir, "recording.mp4"))
|
|
||||||
except Exception as e:
|
|
||||||
print(f"An error occurred while stopping the recording: {e}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
func_timeout.func_timeout(30, stop_recording)
|
|
||||||
except func_timeout.exceptions.FunctionTimedOut:
|
|
||||||
logger.info("Recording timed out.")
|
|
||||||
|
|
||||||
result = env.evaluate()
|
|
||||||
logger.info("Result: %.2f", result)
|
|
||||||
|
|
||||||
with open(trajectory_recording_path, "a") as f:
|
|
||||||
f.write(json.dumps({
|
|
||||||
"result": result
|
|
||||||
}))
|
|
||||||
f.write("\n")
|
|
||||||
|
|
||||||
# env.close()
|
|
||||||
logger.info("Environment closed.")
|
|
||||||
|
|
||||||
|
|
||||||
def main(example_class, example_id):
|
|
||||||
action_space = "pyautogui"
|
|
||||||
gpt4_model = "gpt-4-vision-preview"
|
|
||||||
gemini_model = "gemini-pro-vision"
|
|
||||||
|
|
||||||
with open(f"evaluation_examples/examples/{example_class}/{example_id}.json", "r", encoding="utf-8") as f:
|
|
||||||
example = json.load(f)
|
|
||||||
example["snapshot"] = "exp_v5"
|
|
||||||
|
|
||||||
api_key = os.environ.get("OPENAI_API_KEY")
|
|
||||||
agent = GPT4v_Agent(api_key=api_key, model=gpt4_model, instruction=example['instruction'],
|
|
||||||
action_space=action_space, exp="seeact")
|
|
||||||
|
|
||||||
# api_key = os.environ.get("GENAI_API_KEY")
|
|
||||||
# agent = GeminiPro_Agent(api_key=api_key, model=gemini_model, instruction=example['instruction'], action_space=action_space)
|
|
||||||
|
|
||||||
root_trajectory_dir = "exp_trajectory"
|
|
||||||
|
|
||||||
example_trajectory_dir = os.path.join(root_trajectory_dir, "seeact", example_class, gpt4_model, example_id)
|
|
||||||
# example_trajectory_dir = os.path.join(root_trajectory_dir, "seeact", example_class, gemini_model, example_id)
|
|
||||||
|
|
||||||
os.makedirs(example_trajectory_dir, exist_ok=True)
|
|
||||||
|
|
||||||
run_one_example(example, agent, 15, example_trajectory_dir)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
xx_list = [
|
|
||||||
]
|
|
||||||
for example_id in xx_list:
|
|
||||||
main("xx", example_id)
|
|
||||||
@@ -1,261 +0,0 @@
|
|||||||
#import ctypes
|
|
||||||
import datetime
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import func_timeout
|
|
||||||
|
|
||||||
from desktop_env.envs.desktop_env import DesktopEnv
|
|
||||||
from mm_agents.gpt_4v_agent import GPT4v_Agent
|
|
||||||
|
|
||||||
# Logger Configs {{{ #
|
|
||||||
logger = logging.getLogger()
|
|
||||||
logger.setLevel(logging.DEBUG)
|
|
||||||
|
|
||||||
datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
|
||||||
|
|
||||||
file_handler = logging.FileHandler(os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8")
|
|
||||||
debug_handler = logging.FileHandler(os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8")
|
|
||||||
stdout_handler = logging.StreamHandler(sys.stdout)
|
|
||||||
sdebug_handler = logging.FileHandler(os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8")
|
|
||||||
|
|
||||||
file_handler.setLevel(logging.INFO)
|
|
||||||
debug_handler.setLevel(logging.DEBUG)
|
|
||||||
stdout_handler.setLevel(logging.INFO)
|
|
||||||
sdebug_handler.setLevel(logging.DEBUG)
|
|
||||||
|
|
||||||
formatter = logging.Formatter(
|
|
||||||
fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s")
|
|
||||||
file_handler.setFormatter(formatter)
|
|
||||||
debug_handler.setFormatter(formatter)
|
|
||||||
stdout_handler.setFormatter(formatter)
|
|
||||||
sdebug_handler.setFormatter(formatter)
|
|
||||||
|
|
||||||
stdout_handler.addFilter(logging.Filter("desktopenv"))
|
|
||||||
sdebug_handler.addFilter(logging.Filter("desktopenv"))
|
|
||||||
|
|
||||||
logger.addHandler(file_handler)
|
|
||||||
logger.addHandler(debug_handler)
|
|
||||||
logger.addHandler(stdout_handler)
|
|
||||||
logger.addHandler(sdebug_handler)
|
|
||||||
# }}} Logger Configs #
|
|
||||||
|
|
||||||
logger = logging.getLogger("desktopenv.experiment")
|
|
||||||
|
|
||||||
PATH_TO_VM = r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu\Ubuntu.vmx"
|
|
||||||
|
|
||||||
|
|
||||||
def run_one_example(example, agent, max_steps=10, example_trajectory_dir="exp_trajectory", recording=True):
|
|
||||||
trajectory_recording_path = os.path.join(example_trajectory_dir, "trajectory.json")
|
|
||||||
env = DesktopEnv(
|
|
||||||
path_to_vm=PATH_TO_VM,
|
|
||||||
action_space=agent.action_space,
|
|
||||||
task_config=example
|
|
||||||
)
|
|
||||||
# reset the environment to certain snapshot
|
|
||||||
observation = env.reset()
|
|
||||||
done = False
|
|
||||||
step_num = 0
|
|
||||||
|
|
||||||
if recording:
|
|
||||||
# send a request to the server to start recording
|
|
||||||
env.controller.start_recording()
|
|
||||||
|
|
||||||
while not done and step_num < max_steps:
|
|
||||||
actions = agent.predict(observation)
|
|
||||||
step_num += 1
|
|
||||||
for action in actions:
|
|
||||||
# Capture the timestamp before executing the action
|
|
||||||
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
|
||||||
logger.info("Step %d: %s", step_num, action)
|
|
||||||
|
|
||||||
observation, reward, done, info = env.step(action)
|
|
||||||
|
|
||||||
logger.info("Reward: %.2f", reward)
|
|
||||||
logger.info("Done: %s", done)
|
|
||||||
logger.info("Info: %s", info)
|
|
||||||
|
|
||||||
# Save screenshot and trajectory information
|
|
||||||
with open(os.path.join(example_trajectory_dir, f"step_{step_num}_{action_timestamp}.png"), "wb") as _f:
|
|
||||||
with open(observation['screenshot'], "rb") as __f:
|
|
||||||
screenshot = __f.read()
|
|
||||||
_f.write(screenshot)
|
|
||||||
|
|
||||||
with open(trajectory_recording_path, "a") as f:
|
|
||||||
f.write(json.dumps({
|
|
||||||
"step_num": step_num,
|
|
||||||
"action_timestamp": action_timestamp,
|
|
||||||
"action": action,
|
|
||||||
"reward": reward,
|
|
||||||
"done": done,
|
|
||||||
"info": info,
|
|
||||||
"screenshot_file": f"step_{step_num}_{action_timestamp}.png"
|
|
||||||
}))
|
|
||||||
f.write("\n")
|
|
||||||
|
|
||||||
if done:
|
|
||||||
logger.info("The episode is done.")
|
|
||||||
break
|
|
||||||
|
|
||||||
def stop_recording():
|
|
||||||
try:
|
|
||||||
env.controller.end_recording(os.path.join(example_trajectory_dir, "recording.mp4"))
|
|
||||||
except Exception as e:
|
|
||||||
print(f"An error occurred while stopping the recording: {e}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
func_timeout.func_timeout(30, stop_recording)
|
|
||||||
except func_timeout.exceptions.FunctionTimedOut:
|
|
||||||
logger.info("Recording timed out.")
|
|
||||||
|
|
||||||
result = env.evaluate()
|
|
||||||
logger.info("Result: %.2f", result)
|
|
||||||
|
|
||||||
with open(trajectory_recording_path, "a") as f:
|
|
||||||
f.write(json.dumps({
|
|
||||||
"result": result
|
|
||||||
}))
|
|
||||||
f.write("\n")
|
|
||||||
|
|
||||||
# env.close()
|
|
||||||
logger.info("Environment closed.")
|
|
||||||
|
|
||||||
|
|
||||||
def main(example_class, example_id):
|
|
||||||
action_space = "pyautogui"
|
|
||||||
gpt4_model = "gpt-4-vision-preview"
|
|
||||||
gemini_model = "gemini-pro-vision"
|
|
||||||
|
|
||||||
with open(f"evaluation_examples/examples/{example_class}/{example_id}.json", "r", encoding="utf-8") as f:
|
|
||||||
example = json.load(f)
|
|
||||||
example["snapshot"] = "exp_v5"
|
|
||||||
|
|
||||||
logger.info("TASK: %s/%s", example_class, example_id)
|
|
||||||
|
|
||||||
api_key = os.environ.get("OPENAI_API_KEY")
|
|
||||||
agent = GPT4v_Agent(api_key=api_key, model=gpt4_model, max_tokens=1000, instruction=example['instruction'],
|
|
||||||
action_space=action_space, exp="som")
|
|
||||||
|
|
||||||
# api_key = os.environ.get("GENAI_API_KEY")
|
|
||||||
# agent = GeminiPro_Agent(api_key=api_key, model=gemini_model, instruction=example['instruction'], action_space=action_space)
|
|
||||||
|
|
||||||
root_trajectory_dir = "exp_trajectory"
|
|
||||||
|
|
||||||
example_trajectory_dir = os.path.join(root_trajectory_dir, "som", example_class, gpt4_model, example_id)
|
|
||||||
# example_trajectory_dir = os.path.join(root_trajectory_dir, "som", example_class, gemini_model, example_id)
|
|
||||||
|
|
||||||
os.makedirs(example_trajectory_dir, exist_ok=True)
|
|
||||||
|
|
||||||
run_one_example(example, agent, 15, example_trajectory_dir)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
from tqdm import tqdm
|
|
||||||
# impress_list = [
|
|
||||||
# # "5d901039-a89c-4bfb-967b-bf66f4df075e",
|
|
||||||
# "550ce7e7-747b-495f-b122-acdc4d0b8e54",
|
|
||||||
# "455d3c66-7dc6-4537-a39a-36d3e9119df7",
|
|
||||||
# "af23762e-2bfd-4a1d-aada-20fa8de9ce07",
|
|
||||||
# "c59742c0-4323-4b9d-8a02-723c251deaa0",
|
|
||||||
# "ef9d12bd-bcee-4ba0-a40e-918400f43ddf",
|
|
||||||
# "9ec204e4-f0a3-42f8-8458-b772a6797cab",
|
|
||||||
# "0f84bef9-9790-432e-92b7-eece357603fb",
|
|
||||||
# "ce88f674-ab7a-43da-9201-468d38539e4a",
|
|
||||||
# "3b27600c-3668-4abd-8f84-7bcdebbccbdb",
|
|
||||||
# "a097acff-6266-4291-9fbd-137af7ecd439",
|
|
||||||
# "bf4e9888-f10f-47af-8dba-76413038b73c",
|
|
||||||
# "21760ecb-8f62-40d2-8d85-0cee5725cb72"
|
|
||||||
# ]
|
|
||||||
# for example_id in impress_list:
|
|
||||||
# main("libreoffice_impress", example_id)
|
|
||||||
|
|
||||||
vlc_list = [
|
|
||||||
"8ba5ae7a-5ae5-4eab-9fcc-5dd4fe3abf89",
|
|
||||||
"8ba5ae7a-5ae5-4eab-9fcc-5dd4fe3abf89",
|
|
||||||
"8f080098-ddb1-424c-b438-4e96e5e4786e",
|
|
||||||
"bba3381f-b5eb-4439-bd9e-80c22218d5a7",
|
|
||||||
"fba2c100-79e8-42df-ae74-b592418d54f4",
|
|
||||||
"efcf0d81-0835-4880-b2fd-d866e8bc2294",
|
|
||||||
"8d9fd4e2-6fdb-46b0-b9b9-02f06495c62f",
|
|
||||||
"aa4b5023-aef6-4ed9-bdc9-705f59ab9ad6",
|
|
||||||
"386dbd0e-0241-4a0a-b6a2-6704fba26b1c",
|
|
||||||
"9195653c-f4aa-453d-aa95-787f6ccfaae9",
|
|
||||||
"d06f0d4d-2cd5-4ede-8de9-598629438c6e",
|
|
||||||
"a5bbbcd5-b398-4c91-83d4-55e1e31bbb81",
|
|
||||||
"f3977615-2b45-4ac5-8bba-80c17dbe2a37",
|
|
||||||
"215dfd39-f493-4bc3-a027-8a97d72c61bf"
|
|
||||||
]
|
|
||||||
|
|
||||||
# for example_id in tqdm(vlc_list):
|
|
||||||
# try:
|
|
||||||
# main("vlc", example_id)
|
|
||||||
# except Exception as e:
|
|
||||||
# print(f"An error occurred while running the example: {e}")
|
|
||||||
# continue
|
|
||||||
|
|
||||||
chrome_list = [
|
|
||||||
"bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
|
|
||||||
"7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3",
|
|
||||||
"06fe7178-4491-4589-810f-2e2bc9502122",
|
|
||||||
"e1e75309-3ddb-4d09-92ec-de869c928143",
|
|
||||||
"35253b65-1c19-4304-8aa4-6884b8218fc0",
|
|
||||||
"2ad9387a-65d8-4e33-ad5b-7580065a27ca",
|
|
||||||
"7a5a7856-f1b6-42a4-ade9-1ca81ca0f263",
|
|
||||||
"44ee5668-ecd5-4366-a6ce-c1c9b8d4e938",
|
|
||||||
"2ae9ba84-3a0d-4d4c-8338-3a1478dc5fe3",
|
|
||||||
"480bcfea-d68f-4aaa-a0a9-2589ef319381",
|
|
||||||
"af630914-714e-4a24-a7bb-f9af687d3b91"
|
|
||||||
]
|
|
||||||
for example_id in tqdm(chrome_list):
|
|
||||||
try:
|
|
||||||
main("chrome", example_id)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"An error occurred while running the example: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
vs_code_list = [
|
|
||||||
"0ed39f63-6049-43d4-ba4d-5fa2fe04a951",
|
|
||||||
"53ad5833-3455-407b-bbc6-45b4c79ab8fb",
|
|
||||||
"eabc805a-bfcf-4460-b250-ac92135819f6",
|
|
||||||
"982d12a5-beab-424f-8d38-d2a48429e511",
|
|
||||||
"4e60007a-f5be-4bfc-9723-c39affa0a6d3",
|
|
||||||
"e2b5e914-ffe1-44d2-8e92-58f8c5d92bb2",
|
|
||||||
"9439a27b-18ae-42d8-9778-5f68f891805e",
|
|
||||||
"ea98c5d7-3cf9-4f9b-8ad3-366b58e0fcae",
|
|
||||||
"930fdb3b-11a8-46fe-9bac-577332e2640e",
|
|
||||||
"276cc624-87ea-4f08-ab93-f770e3790175",
|
|
||||||
"9d425400-e9b2-4424-9a4b-d4c7abac4140"
|
|
||||||
]
|
|
||||||
|
|
||||||
for example_id in tqdm(vs_code_list):
|
|
||||||
try:
|
|
||||||
main("vs_code", example_id)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"An error occurred while running the example: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
thunderbird_list = [
|
|
||||||
"bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
|
|
||||||
"7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3",
|
|
||||||
"12086550-11c0-466b-b367-1d9e75b3910e",
|
|
||||||
"06fe7178-4491-4589-810f-2e2bc9502122",
|
|
||||||
"6766f2b8-8a72-417f-a9e5-56fcaa735837",
|
|
||||||
"e1e75309-3ddb-4d09-92ec-de869c928143",
|
|
||||||
"3d1682a7-0fb0-49ae-a4dc-a73afd2d06d5",
|
|
||||||
"35253b65-1c19-4304-8aa4-6884b8218fc0",
|
|
||||||
"d088f539-cab4-4f9a-ac92-9999fc3a656e",
|
|
||||||
"2ad9387a-65d8-4e33-ad5b-7580065a27ca",
|
|
||||||
"480bcfea-d68f-4aaa-a0a9-2589ef319381",
|
|
||||||
"030eeff7-b492-4218-b312-701ec99ee0cc",
|
|
||||||
"94760984-3ff5-41ee-8347-cf1af709fea0",
|
|
||||||
"99146c54-4f37-4ab8-9327-5f3291665e1e",
|
|
||||||
"c9e7eaf2-b1a1-4efc-a982-721972fa9f02"
|
|
||||||
]
|
|
||||||
|
|
||||||
for example_id in tqdm(thunderbird_list):
|
|
||||||
try:
|
|
||||||
main("thunderbird", example_id)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"An error occurred while running the example: {e}")
|
|
||||||
continue
|
|
||||||
69
lib_run_single.py
Normal file
69
lib_run_single.py
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
import datetime
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
|
from wrapt_timeout_decorator import *
|
||||||
|
|
||||||
|
logger = logging.getLogger("desktopenv.experiment")
|
||||||
|
|
||||||
|
# Open the JSON file
|
||||||
|
with open("./settings.json", "r") as file:
|
||||||
|
# Load the JSON data from the file
|
||||||
|
data = json.load(file)
|
||||||
|
time_limit = data["time_limit"]
|
||||||
|
|
||||||
|
|
||||||
|
@timeout(time_limit, use_signals=False)
|
||||||
|
def run_single_example(agent, env, example, max_steps, instruction, args, example_result_dir, scores):
|
||||||
|
agent.reset()
|
||||||
|
obs = env.reset(task_config=example)
|
||||||
|
done = False
|
||||||
|
step_idx = 0
|
||||||
|
env.controller.start_recording()
|
||||||
|
|
||||||
|
while not done and step_idx < max_steps:
|
||||||
|
actions = agent.predict(
|
||||||
|
instruction,
|
||||||
|
obs
|
||||||
|
)
|
||||||
|
for action in actions:
|
||||||
|
# Capture the timestamp before executing the action
|
||||||
|
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
||||||
|
logger.info("Step %d: %s", step_idx + 1, action)
|
||||||
|
|
||||||
|
obs, reward, done, info = env.step(action, args.sleep_after_execution)
|
||||||
|
|
||||||
|
logger.info("Reward: %.2f", reward)
|
||||||
|
logger.info("Done: %s", done)
|
||||||
|
logger.info("Info: %s", info)
|
||||||
|
|
||||||
|
# Save screenshot and trajectory information
|
||||||
|
with open(os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"),
|
||||||
|
"wb") as _f:
|
||||||
|
with open(obs['screenshot'], "rb") as __f:
|
||||||
|
screenshot = __f.read()
|
||||||
|
_f.write(screenshot)
|
||||||
|
|
||||||
|
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
|
||||||
|
f.write(json.dumps({
|
||||||
|
"step_num": step_idx + 1,
|
||||||
|
"action_timestamp": action_timestamp,
|
||||||
|
"action": action,
|
||||||
|
"reward": reward,
|
||||||
|
"done": done,
|
||||||
|
"info": info,
|
||||||
|
"screenshot_file": f"step_{step_idx + 1}_{action_timestamp}.png"
|
||||||
|
}))
|
||||||
|
f.write("\n")
|
||||||
|
|
||||||
|
if done:
|
||||||
|
logger.info("The episode is done.")
|
||||||
|
break
|
||||||
|
step_idx += 1
|
||||||
|
result = env.evaluate()
|
||||||
|
logger.info("Result: %.2f", result)
|
||||||
|
scores.append(result)
|
||||||
|
with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f:
|
||||||
|
f.write(f"{result}\n")
|
||||||
|
env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
|
||||||
36
main.py
36
main.py
@@ -47,38 +47,38 @@ def human_agent():
|
|||||||
Runs the Gym environment with human input.
|
Runs the Gym environment with human input.
|
||||||
"""
|
"""
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-p', '--path', type=str, required=True, help="Path to the virtual machine .vmx file.")
|
parser.add_argument('-p', '--path', type=str, default=r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu3\Ubuntu3.vmx", help="Path to the virtual machine .vmx file.")
|
||||||
parser.add_argument('-s', '--snapshot', type=str, help="Name of the snapshot to restore.")
|
parser.add_argument('-s', '--snapshot', type=str, default='init_state', help="Name of the snapshot to restore.")
|
||||||
parser.add_argument('-e', '--example', type=str, help="Path to the example json file.")
|
parser.add_argument('-e', '--example', type=str, help="Path to the example json file.")
|
||||||
args = parser.parse_args(sys.argv[1:])
|
args = parser.parse_args(sys.argv[1:])
|
||||||
|
|
||||||
example_path = args.example if args.example is not None and os.path.exists(args.example) else \
|
example_path = args.example if args.example is not None and os.path.exists(args.example) else \
|
||||||
'evaluation_examples/examples/libreoffice_writer/6a33f9b9-0a56-4844-9c3f-96ec3ffb3ba2.json'
|
'evaluation_examples/examples/multi_apps/5990457f-2adb-467b-a4af-5c857c92d762.json'
|
||||||
with open(example_path, "r") as f:
|
with open(example_path, "r", encoding="utf-8") as f:
|
||||||
example = json.load(f)
|
example = json.load(f)
|
||||||
# change to your customized snapshot
|
if args.snapshot is not None:
|
||||||
if args.snapshot is not None: example["snapshot"] = args.snapshot
|
example['snapshot'] = args.snapshot
|
||||||
|
|
||||||
assert os.path.exists(args.path), "The specified path to the .vmx file does not exist."
|
assert os.path.exists(args.path), "The specified path to the .vmx file does not exist."
|
||||||
env = DesktopEnv(
|
env = DesktopEnv(
|
||||||
path_to_vm=args.path,
|
path_to_vm=args.path,
|
||||||
action_space="computer_13",
|
snapshot_name=args.snapshot,
|
||||||
task_config=example
|
action_space="computer_13"
|
||||||
)
|
)
|
||||||
# reset the environment to certain snapshot
|
# reset the environment to certain snapshot
|
||||||
observation = env.reset()
|
observation = env.reset(task_config=example)
|
||||||
logger.info('\x1b[32m[TASK INSTRUCTION]: \x1b[32;3m%s\x1b[0m', example["instruction"])
|
|
||||||
done = False
|
done = False
|
||||||
|
logger.info('\x1b[32m[TASK INSTRUCTION]: \x1b[32;3m%s\x1b[0m', example["instruction"])
|
||||||
|
|
||||||
trajectory = [
|
trajectory = [
|
||||||
# {
|
{
|
||||||
# "action_type": "MOVE_TO",
|
"action_type": "MOVE_TO", #
|
||||||
# "parameters": {
|
"parameters": {
|
||||||
# "x": 754,
|
"x": 754,
|
||||||
# "y": 1057
|
"y": 1057
|
||||||
# }
|
}
|
||||||
# },
|
},
|
||||||
# {"action_type": "CLICK", "parameters": {"button": "right", "num_clicks": 1}}
|
{"action_type": "CLICK", "parameters": {"button": "right", "num_clicks": 1}}
|
||||||
]
|
]
|
||||||
|
|
||||||
for i in range(len(trajectory)):
|
for i in range(len(trajectory)):
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ def find_leaf_nodes(xlm_file_str):
|
|||||||
|
|
||||||
state_ns = "uri:deskat:state.at-spi.gnome.org"
|
state_ns = "uri:deskat:state.at-spi.gnome.org"
|
||||||
component_ns = "uri:deskat:component.at-spi.gnome.org"
|
component_ns = "uri:deskat:component.at-spi.gnome.org"
|
||||||
def judge_node(node: ET, platform="ubuntu") -> bool:
|
def judge_node(node: ET, platform="ubuntu", check_image=False) -> bool:
|
||||||
keeps: bool = node.tag.startswith("document")\
|
keeps: bool = node.tag.startswith("document")\
|
||||||
or node.tag.endswith("item")\
|
or node.tag.endswith("item")\
|
||||||
or node.tag.endswith("button")\
|
or node.tag.endswith("button")\
|
||||||
@@ -55,23 +55,25 @@ def judge_node(node: ET, platform="ubuntu") -> bool:
|
|||||||
or platform=="windows"\
|
or platform=="windows"\
|
||||||
and node.get("{{{:}}}visible".format(state_ns), "false")=="true"\
|
and node.get("{{{:}}}visible".format(state_ns), "false")=="true"\
|
||||||
)\
|
)\
|
||||||
and ( node.get("{{{:}}}enabled".format(state_ns), "false")=="true"\
|
and ( node.get("{{{:}}}enabled".format(state_ns), "false")=="true"\
|
||||||
or node.get("{{{:}}}editable".format(state_ns), "false")=="true"\
|
or node.get("{{{:}}}editable".format(state_ns), "false")=="true"\
|
||||||
or node.get("{{{:}}}expandable".format(state_ns), "false")=="true"\
|
or node.get("{{{:}}}expandable".format(state_ns), "false")=="true"\
|
||||||
or node.get("{{{:}}}checkable".format(state_ns), "false")=="true"
|
or node.get("{{{:}}}checkable".format(state_ns), "false")=="true"
|
||||||
)\
|
)\
|
||||||
and (node.get("name", "") != "" or node.text is not None and len(node.text)>0)
|
and ( node.get("name", "") != "" or node.text is not None and len(node.text)>0\
|
||||||
|
or check_image and node.get("image", "false")=="true"
|
||||||
|
)
|
||||||
|
|
||||||
coordinates: Tuple[int, int] = eval(node.get("{{{:}}}screencoord".format(component_ns), "(-1, -1)"))
|
coordinates: Tuple[int, int] = eval(node.get("{{{:}}}screencoord".format(component_ns), "(-1, -1)"))
|
||||||
sizes: Tuple[int, int] = eval(node.get("{{{:}}}size".format(component_ns), "(-1, -1)"))
|
sizes: Tuple[int, int] = eval(node.get("{{{:}}}size".format(component_ns), "(-1, -1)"))
|
||||||
keeps = keeps and coordinates[0]>0 and coordinates[1]>0 and sizes[0]>0 and sizes[1]>0
|
keeps = keeps and coordinates[0]>0 and coordinates[1]>0 and sizes[0]>0 and sizes[1]>0
|
||||||
return keeps
|
return keeps
|
||||||
|
|
||||||
def filter_nodes(root: ET, platform="ubuntu"):
|
def filter_nodes(root: ET, platform="ubuntu", check_image=False):
|
||||||
filtered_nodes = []
|
filtered_nodes = []
|
||||||
|
|
||||||
for node in root.iter():
|
for node in root.iter():
|
||||||
if judge_node(node, platform):
|
if judge_node(node, platform, check_image):
|
||||||
filtered_nodes.append(node)
|
filtered_nodes.append(node)
|
||||||
#print(ET.tostring(node, encoding="unicode"))
|
#print(ET.tostring(node, encoding="unicode"))
|
||||||
|
|
||||||
@@ -155,12 +157,12 @@ def print_nodes_with_indent(nodes, indent=0):
|
|||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
import json
|
import json
|
||||||
with open('4.json', 'r', encoding='utf-8') as f:
|
with open('selection_sorted(imaged).xml', 'r', encoding='utf-8') as f:
|
||||||
xml_file_str = json.load(f)["AT"]
|
xml_file_str = f.read()
|
||||||
filtered_nodes = filter_nodes(ET.fromstring(xml_file_str))
|
filtered_nodes = filter_nodes(ET.fromstring(xml_file_str))
|
||||||
print(len(filtered_nodes))
|
print(len(filtered_nodes))
|
||||||
masks = draw_bounding_boxes( filtered_nodes, '4.png'
|
masks = draw_bounding_boxes( filtered_nodes, 'selection_sorted(imaged).png'
|
||||||
, '4.a.png'
|
, 'selection_sorted(imaged).ai.png'
|
||||||
)
|
)
|
||||||
|
|
||||||
# print(masks)
|
# print(masks)
|
||||||
|
|||||||
@@ -5,32 +5,23 @@ import os
|
|||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
|
import openai
|
||||||
|
import xml.etree.ElementTree as ET
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
import xml.etree.ElementTree as ET
|
from google.api_core.exceptions import InvalidArgument
|
||||||
|
|
||||||
import backoff
|
import backoff
|
||||||
import dashscope
|
import dashscope
|
||||||
import google.generativeai as genai
|
import google.generativeai as genai
|
||||||
import openai
|
|
||||||
import requests
|
import requests
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from openai import (
|
|
||||||
APIConnectionError,
|
|
||||||
APIError,
|
|
||||||
RateLimitError
|
|
||||||
)
|
|
||||||
|
|
||||||
from mm_agents.accessibility_tree_wrap.heuristic_retrieve import find_leaf_nodes, filter_nodes, draw_bounding_boxes
|
from mm_agents.accessibility_tree_wrap.heuristic_retrieve import find_leaf_nodes, filter_nodes, draw_bounding_boxes
|
||||||
from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION, \
|
from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION, \
|
||||||
SYS_PROMPT_IN_A11Y_OUT_CODE, SYS_PROMPT_IN_A11Y_OUT_ACTION, \
|
SYS_PROMPT_IN_A11Y_OUT_CODE, SYS_PROMPT_IN_A11Y_OUT_ACTION, \
|
||||||
SYS_PROMPT_IN_BOTH_OUT_CODE, SYS_PROMPT_IN_BOTH_OUT_ACTION, \
|
SYS_PROMPT_IN_BOTH_OUT_CODE, SYS_PROMPT_IN_BOTH_OUT_ACTION, \
|
||||||
SYS_PROMPT_IN_SOM_A11Y_OUT_TAG, \
|
SYS_PROMPT_IN_SOM_OUT_TAG
|
||||||
SYS_PROMPT_SEEACT, ACTION_DESCRIPTION_PROMPT_SEEACT, ACTION_GROUNDING_PROMPT_SEEACT
|
|
||||||
|
|
||||||
import logging
|
|
||||||
# todo: cross-check with visualwebarena
|
|
||||||
|
|
||||||
logger = logging.getLogger("desktopenv.agent")
|
logger = logging.getLogger("desktopenv.agent")
|
||||||
|
|
||||||
@@ -42,10 +33,10 @@ def encode_image(image_path):
|
|||||||
|
|
||||||
|
|
||||||
def linearize_accessibility_tree(accessibility_tree):
|
def linearize_accessibility_tree(accessibility_tree):
|
||||||
#leaf_nodes = find_leaf_nodes(accessibility_tree)
|
# leaf_nodes = find_leaf_nodes(accessibility_tree)
|
||||||
filtered_nodes = filter_nodes(ET.fromstring(accessibility_tree))
|
filtered_nodes = filter_nodes(ET.fromstring(accessibility_tree))
|
||||||
|
|
||||||
linearized_accessibility_tree = "tag\tname\ttext\tposition\tsize\n"
|
linearized_accessibility_tree = "tag\tname\ttext\tposition (top-left x&y)\tsize (w&h)\n"
|
||||||
# Linearize the accessibility tree nodes into a table format
|
# Linearize the accessibility tree nodes into a table format
|
||||||
|
|
||||||
for node in filtered_nodes:
|
for node in filtered_nodes:
|
||||||
@@ -73,7 +64,8 @@ def tag_screenshot(screenshot, accessibility_tree):
|
|||||||
uuid_str = str(uuid.uuid4())
|
uuid_str = str(uuid.uuid4())
|
||||||
os.makedirs("tmp/images", exist_ok=True)
|
os.makedirs("tmp/images", exist_ok=True)
|
||||||
tagged_screenshot_file_path = os.path.join("tmp/images", uuid_str + ".png")
|
tagged_screenshot_file_path = os.path.join("tmp/images", uuid_str + ".png")
|
||||||
nodes = filter_nodes(find_leaf_nodes(accessibility_tree))
|
# nodes = filter_nodes(find_leaf_nodes(accessibility_tree))
|
||||||
|
nodes = filter_nodes(ET.fromstring(accessibility_tree), check_image=True)
|
||||||
# Make tag screenshot
|
# Make tag screenshot
|
||||||
marks, drew_nodes = draw_bounding_boxes(nodes, screenshot, tagged_screenshot_file_path)
|
marks, drew_nodes = draw_bounding_boxes(nodes, screenshot, tagged_screenshot_file_path)
|
||||||
|
|
||||||
@@ -169,79 +161,66 @@ def parse_code_from_som_string(input_string, masks):
|
|||||||
return actions
|
return actions
|
||||||
|
|
||||||
|
|
||||||
class GPT4v_Agent:
|
class PromptAgent:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
api_key,
|
|
||||||
instruction,
|
|
||||||
model="gpt-4-vision-preview",
|
model="gpt-4-vision-preview",
|
||||||
max_tokens=500,
|
max_tokens=1500,
|
||||||
|
top_p=0.9,
|
||||||
|
temperature=0.5,
|
||||||
action_space="computer_13",
|
action_space="computer_13",
|
||||||
exp="screenshot_a11y_tree"
|
observation_type="screenshot_a11y_tree",
|
||||||
# exp can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som", "seeact"]
|
# observation_type can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"]
|
||||||
|
max_trajectory_length=3
|
||||||
):
|
):
|
||||||
|
|
||||||
self.instruction = instruction
|
|
||||||
self.model = model
|
self.model = model
|
||||||
self.max_tokens = max_tokens
|
self.max_tokens = max_tokens
|
||||||
|
self.top_p = top_p
|
||||||
|
self.temperature = temperature
|
||||||
self.action_space = action_space
|
self.action_space = action_space
|
||||||
self.exp = exp
|
self.observation_type = observation_type
|
||||||
self.max_trajectory_length = 3
|
self.max_trajectory_length = max_trajectory_length
|
||||||
|
|
||||||
self.headers = {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
"Authorization": f"Bearer {api_key}"
|
|
||||||
}
|
|
||||||
|
|
||||||
self.thoughts = []
|
self.thoughts = []
|
||||||
self.actions = []
|
self.actions = []
|
||||||
self.observations = []
|
self.observations = []
|
||||||
|
|
||||||
if exp == "screenshot":
|
if observation_type == "screenshot":
|
||||||
if action_space == "computer_13":
|
if action_space == "computer_13":
|
||||||
self.system_message = SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION
|
self.system_message = SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION
|
||||||
elif action_space == "pyautogui":
|
elif action_space == "pyautogui":
|
||||||
self.system_message = SYS_PROMPT_IN_SCREENSHOT_OUT_CODE
|
self.system_message = SYS_PROMPT_IN_SCREENSHOT_OUT_CODE
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid action space: " + action_space)
|
raise ValueError("Invalid action space: " + action_space)
|
||||||
elif exp == "a11y_tree":
|
elif observation_type == "a11y_tree":
|
||||||
if action_space == "computer_13":
|
if action_space == "computer_13":
|
||||||
self.system_message = SYS_PROMPT_IN_A11Y_OUT_ACTION
|
self.system_message = SYS_PROMPT_IN_A11Y_OUT_ACTION
|
||||||
elif action_space == "pyautogui":
|
elif action_space == "pyautogui":
|
||||||
self.system_message = SYS_PROMPT_IN_A11Y_OUT_CODE
|
self.system_message = SYS_PROMPT_IN_A11Y_OUT_CODE
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid action space: " + action_space)
|
raise ValueError("Invalid action space: " + action_space)
|
||||||
elif exp == "both":
|
elif observation_type == "screenshot_a11y_tree":
|
||||||
if action_space == "computer_13":
|
if action_space == "computer_13":
|
||||||
self.system_message = SYS_PROMPT_IN_BOTH_OUT_ACTION
|
self.system_message = SYS_PROMPT_IN_BOTH_OUT_ACTION
|
||||||
elif action_space == "pyautogui":
|
elif action_space == "pyautogui":
|
||||||
self.system_message = SYS_PROMPT_IN_BOTH_OUT_CODE
|
self.system_message = SYS_PROMPT_IN_BOTH_OUT_CODE
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid action space: " + action_space)
|
raise ValueError("Invalid action space: " + action_space)
|
||||||
elif exp == "som":
|
elif observation_type == "som":
|
||||||
if action_space == "computer_13":
|
if action_space == "computer_13":
|
||||||
raise ValueError("Invalid action space: " + action_space)
|
raise ValueError("Invalid action space: " + action_space)
|
||||||
elif action_space == "pyautogui":
|
elif action_space == "pyautogui":
|
||||||
self.system_message = SYS_PROMPT_IN_SOM_A11Y_OUT_TAG
|
self.system_message = SYS_PROMPT_IN_SOM_OUT_TAG
|
||||||
else:
|
|
||||||
raise ValueError("Invalid action space: " + action_space)
|
|
||||||
elif exp == "seeact":
|
|
||||||
if action_space == "computer_13":
|
|
||||||
raise ValueError("Invalid action space: " + action_space)
|
|
||||||
elif action_space == "pyautogui":
|
|
||||||
self.system_message = SYS_PROMPT_SEEACT
|
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid action space: " + action_space)
|
raise ValueError("Invalid action space: " + action_space)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid experiment type: " + exp)
|
raise ValueError("Invalid experiment type: " + observation_type)
|
||||||
|
|
||||||
self.system_message = self.system_message + "\nYou are asked to complete the following task: {}".format(
|
def predict(self, instruction: str, obs: Dict) -> List:
|
||||||
self.instruction)
|
|
||||||
|
|
||||||
def predict(self, obs: Dict) -> List:
|
|
||||||
"""
|
"""
|
||||||
Predict the next action(s) based on the current observation.
|
Predict the next action(s) based on the current observation.
|
||||||
"""
|
"""
|
||||||
|
system_message = self.system_message + "\nYou are asked to complete the following task: {}".format(instruction)
|
||||||
|
|
||||||
# Prepare the payload for the API call
|
# Prepare the payload for the API call
|
||||||
messages = []
|
messages = []
|
||||||
@@ -252,7 +231,7 @@ class GPT4v_Agent:
|
|||||||
"content": [
|
"content": [
|
||||||
{
|
{
|
||||||
"type": "text",
|
"type": "text",
|
||||||
"text": self.system_message
|
"text": system_message
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
})
|
})
|
||||||
@@ -273,7 +252,7 @@ class GPT4v_Agent:
|
|||||||
for previous_obs, previous_action, previous_thought in zip(_observations, _actions, _thoughts):
|
for previous_obs, previous_action, previous_thought in zip(_observations, _actions, _thoughts):
|
||||||
|
|
||||||
# {{{1
|
# {{{1
|
||||||
if self.exp == "both":
|
if self.observation_type == "screenshot_a11y_tree":
|
||||||
_screenshot = previous_obs["screenshot"]
|
_screenshot = previous_obs["screenshot"]
|
||||||
_linearized_accessibility_tree = previous_obs["accessibility_tree"]
|
_linearized_accessibility_tree = previous_obs["accessibility_tree"]
|
||||||
logger.debug("LINEAR AT: %s", _linearized_accessibility_tree)
|
logger.debug("LINEAR AT: %s", _linearized_accessibility_tree)
|
||||||
@@ -295,18 +274,15 @@ class GPT4v_Agent:
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
})
|
})
|
||||||
elif self.exp in ["som", "seeact"]:
|
elif self.observation_type in ["som"]:
|
||||||
_screenshot = previous_obs["screenshot"]
|
_screenshot = previous_obs["screenshot"]
|
||||||
_linearized_accessibility_tree = previous_obs["accessibility_tree"]
|
|
||||||
logger.debug("LINEAR AT: %s", _linearized_accessibility_tree)
|
|
||||||
|
|
||||||
messages.append({
|
messages.append({
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": [
|
"content": [
|
||||||
{
|
{
|
||||||
"type": "text",
|
"type": "text",
|
||||||
"text": "Given the tagged screenshot and info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format(
|
"text": "Given the tagged screenshot as below. What's the next step that you will do to help with the task?"
|
||||||
_linearized_accessibility_tree)
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"type": "image_url",
|
"type": "image_url",
|
||||||
@@ -317,7 +293,7 @@ class GPT4v_Agent:
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
})
|
})
|
||||||
elif self.exp == "screenshot":
|
elif self.observation_type == "screenshot":
|
||||||
_screenshot = previous_obs["screenshot"]
|
_screenshot = previous_obs["screenshot"]
|
||||||
|
|
||||||
messages.append({
|
messages.append({
|
||||||
@@ -336,7 +312,7 @@ class GPT4v_Agent:
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
})
|
})
|
||||||
elif self.exp == "a11y_tree":
|
elif self.observation_type == "a11y_tree":
|
||||||
_linearized_accessibility_tree = previous_obs["accessibility_tree"]
|
_linearized_accessibility_tree = previous_obs["accessibility_tree"]
|
||||||
|
|
||||||
messages.append({
|
messages.append({
|
||||||
@@ -350,7 +326,7 @@ class GPT4v_Agent:
|
|||||||
]
|
]
|
||||||
})
|
})
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid experiment type: " + self.exp) # 1}}}
|
raise ValueError("Invalid observation_type type: " + self.observation_type) # 1}}}
|
||||||
|
|
||||||
messages.append({
|
messages.append({
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
@@ -363,11 +339,11 @@ class GPT4v_Agent:
|
|||||||
})
|
})
|
||||||
|
|
||||||
# {{{1
|
# {{{1
|
||||||
if self.exp in ["screenshot", "both"]:
|
if self.observation_type in ["screenshot", "screenshot_a11y_tree"]:
|
||||||
base64_image = encode_image(obs["screenshot"])
|
base64_image = encode_image(obs["screenshot"])
|
||||||
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])
|
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])
|
||||||
|
|
||||||
if self.exp == "both":
|
if self.observation_type == "screenshot_a11y_tree":
|
||||||
self.observations.append({
|
self.observations.append({
|
||||||
"screenshot": base64_image,
|
"screenshot": base64_image,
|
||||||
"accessibility_tree": linearized_accessibility_tree
|
"accessibility_tree": linearized_accessibility_tree
|
||||||
@@ -384,7 +360,7 @@ class GPT4v_Agent:
|
|||||||
{
|
{
|
||||||
"type": "text",
|
"type": "text",
|
||||||
"text": "Given the screenshot as below. What's the next step that you will do to help with the task?"
|
"text": "Given the screenshot as below. What's the next step that you will do to help with the task?"
|
||||||
if self.exp == "screenshot"
|
if self.observation_type == "screenshot"
|
||||||
else "Given the screenshot and info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format(
|
else "Given the screenshot and info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format(
|
||||||
linearized_accessibility_tree)
|
linearized_accessibility_tree)
|
||||||
},
|
},
|
||||||
@@ -397,7 +373,7 @@ class GPT4v_Agent:
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
})
|
})
|
||||||
elif self.exp == "a11y_tree":
|
elif self.observation_type == "a11y_tree":
|
||||||
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])
|
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])
|
||||||
|
|
||||||
self.observations.append({
|
self.observations.append({
|
||||||
@@ -415,15 +391,13 @@ class GPT4v_Agent:
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
})
|
})
|
||||||
elif self.exp == "som":
|
elif self.observation_type == "som":
|
||||||
# Add som to the screenshot
|
# Add som to the screenshot
|
||||||
masks, drew_nodes, tagged_screenshot = tag_screenshot(obs["screenshot"], obs["accessibility_tree"])
|
masks, drew_nodes, tagged_screenshot = tag_screenshot(obs["screenshot"], obs["accessibility_tree"])
|
||||||
base64_image = encode_image(tagged_screenshot)
|
base64_image = encode_image(tagged_screenshot)
|
||||||
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])
|
|
||||||
|
|
||||||
self.observations.append({
|
self.observations.append({
|
||||||
"screenshot": base64_image,
|
"screenshot": base64_image
|
||||||
"accessibility_tree": linearized_accessibility_tree
|
|
||||||
})
|
})
|
||||||
|
|
||||||
messages.append({
|
messages.append({
|
||||||
@@ -431,35 +405,7 @@ class GPT4v_Agent:
|
|||||||
"content": [
|
"content": [
|
||||||
{
|
{
|
||||||
"type": "text",
|
"type": "text",
|
||||||
"text": "Given the tagged screenshot and info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format(
|
"text": "Given the tagged screenshot as below. What's the next step that you will do to help with the task?"
|
||||||
linearized_accessibility_tree)
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "image_url",
|
|
||||||
"image_url": {
|
|
||||||
"url": f"data:image/png;base64,{base64_image}",
|
|
||||||
"detail": "high"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
})
|
|
||||||
elif self.exp == "seeact":
|
|
||||||
# Add som to the screenshot
|
|
||||||
masks, drew_nodes, tagged_screenshot = tag_screenshot(obs["screenshot"], obs["accessibility_tree"])
|
|
||||||
base64_image = encode_image(tagged_screenshot)
|
|
||||||
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])
|
|
||||||
|
|
||||||
self.observations.append({
|
|
||||||
"screenshot": base64_image,
|
|
||||||
"accessibility_tree": linearized_accessibility_tree
|
|
||||||
})
|
|
||||||
|
|
||||||
messages.append({
|
|
||||||
"role": "user",
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": ACTION_DESCRIPTION_PROMPT_SEEACT.format(linearized_accessibility_tree)
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"type": "image_url",
|
"type": "image_url",
|
||||||
@@ -471,53 +417,26 @@ class GPT4v_Agent:
|
|||||||
]
|
]
|
||||||
})
|
})
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid experiment type: " + self.exp) # 1}}}
|
raise ValueError("Invalid observation_type type: " + self.observation_type) # 1}}}
|
||||||
|
|
||||||
with open("messages.json", "w") as f:
|
|
||||||
f.write(json.dumps(messages, indent=4))
|
|
||||||
|
|
||||||
|
# with open("messages.json", "w") as f:
|
||||||
|
# f.write(json.dumps(messages, indent=4))
|
||||||
|
|
||||||
|
logger.info("Generating content with GPT model: %s", self.model)
|
||||||
response = self.call_llm({
|
response = self.call_llm({
|
||||||
"model": self.model,
|
"model": self.model,
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
"max_tokens": self.max_tokens
|
"max_tokens": self.max_tokens,
|
||||||
|
"top_p": self.top_p,
|
||||||
|
"temperature": self.temperature
|
||||||
})
|
})
|
||||||
|
|
||||||
logger.debug("RESPONSE: %s", response)
|
logger.info("RESPONSE: %s", response)
|
||||||
|
|
||||||
if self.exp == "seeact":
|
|
||||||
messages.append({
|
|
||||||
"role": "assistant",
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": response
|
|
||||||
}
|
|
||||||
]
|
|
||||||
})
|
|
||||||
|
|
||||||
messages.append({
|
|
||||||
"role": "user",
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": "{}\n\nWhat's the next step that you will do to help with the task?".format(
|
|
||||||
ACTION_GROUNDING_PROMPT_SEEACT)
|
|
||||||
}
|
|
||||||
]
|
|
||||||
})
|
|
||||||
|
|
||||||
response = self.call_llm({
|
|
||||||
"model": self.model,
|
|
||||||
"messages": messages,
|
|
||||||
"max_tokens": self.max_tokens
|
|
||||||
})
|
|
||||||
print(response)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
actions = self.parse_actions(response, masks)
|
actions = self.parse_actions(response, masks)
|
||||||
self.thoughts.append(response)
|
self.thoughts.append(response)
|
||||||
except Exception as e:
|
except ValueError as e:
|
||||||
print("Failed to parse action from response", e)
|
print("Failed to parse action from response", e)
|
||||||
actions = None
|
actions = None
|
||||||
self.thoughts.append("")
|
self.thoughts.append("")
|
||||||
@@ -526,86 +445,160 @@ class GPT4v_Agent:
|
|||||||
|
|
||||||
@backoff.on_exception(
|
@backoff.on_exception(
|
||||||
backoff.expo,
|
backoff.expo,
|
||||||
(Exception),
|
# here you should add more model exceptions as you want,
|
||||||
max_tries=10
|
# but you are forbidden to add "Exception", that is, a common type of exception
|
||||||
|
# because we want to catch this kind of Exception in the outside to ensure each example won't exceed the time limit
|
||||||
|
(openai.RateLimitError,
|
||||||
|
openai.BadRequestError,
|
||||||
|
openai.InternalServerError,
|
||||||
|
InvalidArgument),
|
||||||
|
max_tries=5
|
||||||
)
|
)
|
||||||
def call_llm(self, payload):
|
def call_llm(self, payload):
|
||||||
|
|
||||||
if self.model.startswith("gpt"):
|
if self.model.startswith("gpt"):
|
||||||
logger.info("Generating content with GPT model: %s", self.model)
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}"
|
||||||
|
}
|
||||||
|
# logger.info("Generating content with GPT model: %s", self.model)
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
"https://api.openai.com/v1/chat/completions",
|
"https://api.openai.com/v1/chat/completions",
|
||||||
headers=self.headers,
|
headers=headers,
|
||||||
json=payload
|
json=payload
|
||||||
)
|
)
|
||||||
|
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
if response.json()['error']['code'] == "context_length_exceeded":
|
if response.json()['error']['code'] == "context_length_exceeded":
|
||||||
print("Context length exceeded. Retrying with a smaller context.")
|
logger.error("Context length exceeded. Retrying with a smaller context.")
|
||||||
payload["messages"] = payload["messages"][-1:]
|
payload["messages"] = [payload["messages"][0]] + payload["messages"][-1:]
|
||||||
retry_response = requests.post(
|
retry_response = requests.post(
|
||||||
"https://api.openai.com/v1/chat/completions",
|
"https://api.openai.com/v1/chat/completions",
|
||||||
headers=self.headers,
|
headers=headers,
|
||||||
json=payload
|
json=payload
|
||||||
)
|
)
|
||||||
if retry_response.status_code != 200:
|
if retry_response.status_code != 200:
|
||||||
print("Failed to call LLM: " + retry_response.text)
|
logger.error(
|
||||||
|
"Failed to call LLM even after attempt on shortening the history: " + retry_response.text)
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
print("Failed to call LLM: " + response.text)
|
logger.error("Failed to call LLM: " + response.text)
|
||||||
time.sleep(5)
|
time.sleep(5)
|
||||||
return ""
|
return ""
|
||||||
else:
|
else:
|
||||||
return response.json()['choices'][0]['message']['content']
|
return response.json()['choices'][0]['message']['content']
|
||||||
|
|
||||||
elif self.model.startswith("mistral"):
|
elif self.model.startswith("claude"):
|
||||||
print("call mistral")
|
|
||||||
messages = payload["messages"]
|
messages = payload["messages"]
|
||||||
max_tokens = payload["max_tokens"]
|
max_tokens = payload["max_tokens"]
|
||||||
|
top_p = payload["top_p"]
|
||||||
|
temperature = payload["temperature"]
|
||||||
|
|
||||||
misrtal_messages = []
|
claude_messages = []
|
||||||
|
|
||||||
for i, message in enumerate(messages):
|
for i, message in enumerate(messages):
|
||||||
mistral_message = {
|
claude_message = {
|
||||||
"role": message["role"],
|
"role": message["role"],
|
||||||
"content": []
|
"content": []
|
||||||
}
|
}
|
||||||
|
assert len(message["content"]) in [1, 2], "One text, or one text with one image"
|
||||||
for part in message["content"]:
|
for part in message["content"]:
|
||||||
mistral_message['content'] = part['text'] if part['type'] == "text" else None
|
|
||||||
|
if part['type'] == "image_url":
|
||||||
|
image_source = {}
|
||||||
|
image_source["type"] = "base64"
|
||||||
|
image_source["media_type"] = "image/png"
|
||||||
|
image_source["data"] = part['image_url']['url'].replace("data:image/png;base64,", "")
|
||||||
|
claude_message['content'].append({"type": "image", "source": image_source})
|
||||||
|
|
||||||
|
if part['type'] == "text":
|
||||||
|
claude_message['content'].append({"type": "text", "text": part['text']})
|
||||||
|
|
||||||
|
claude_messages.append(claude_message)
|
||||||
|
|
||||||
misrtal_messages.append(mistral_message)
|
# the claude not support system message in our endpoint, so we concatenate it at the first user message
|
||||||
|
if claude_messages[0]['role'] == "system":
|
||||||
|
claude_system_message_item = claude_messages[0]['content'][0]
|
||||||
|
claude_messages[1]['content'].insert(0, claude_system_message_item)
|
||||||
|
claude_messages.pop(0)
|
||||||
|
|
||||||
# the mistral not support system message in our endpoint, so we concatenate it at the first user message
|
|
||||||
if misrtal_messages[0]['role'] == "system":
|
|
||||||
misrtal_messages[1]['content'] = misrtal_messages[0]['content'] + "\n" + misrtal_messages[1]['content']
|
|
||||||
misrtal_messages.pop(0)
|
|
||||||
|
|
||||||
# openai.api_base = "http://localhost:8000/v1"
|
headers = {
|
||||||
# openai.api_key = "test"
|
"x-api-key": os.environ["ANTHROPIC_API_KEY"],
|
||||||
# response = openai.ChatCompletion.create(
|
"anthropic-version": "2023-06-01",
|
||||||
# messages=misrtal_messages,
|
"content-type": "application/json"
|
||||||
# model="Mixtral-8x7B-Instruct-v0.1"
|
}
|
||||||
# )
|
|
||||||
|
|
||||||
from openai import OpenAI
|
payload = {
|
||||||
TOGETHER_API_KEY = "d011650e7537797148fb6170ec1e0be7ae75160375686fae02277136078e90d2"
|
"model": self.model,
|
||||||
|
"max_tokens": max_tokens,
|
||||||
|
"messages": claude_messages
|
||||||
|
}
|
||||||
|
|
||||||
client = OpenAI(api_key=TOGETHER_API_KEY,
|
response = requests.post(
|
||||||
base_url='https://api.together.xyz',
|
"https://api.anthropic.com/v1/messages",
|
||||||
)
|
headers=headers,
|
||||||
logger.info("Generating content with Mistral model: %s", self.model)
|
json=payload
|
||||||
response = client.chat.completions.create(
|
|
||||||
messages=misrtal_messages,
|
|
||||||
model="mistralai/Mixtral-8x7B-Instruct-v0.1",
|
|
||||||
max_tokens=1024
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
|
||||||
try:
|
logger.error("Failed to call LLM: " + response.text)
|
||||||
# return response['choices'][0]['message']['content']
|
time.sleep(5)
|
||||||
return response.choices[0].message.content
|
|
||||||
except Exception as e:
|
|
||||||
print("Failed to call LLM: " + str(e))
|
|
||||||
return ""
|
return ""
|
||||||
|
else:
|
||||||
|
return response.json()['content'][0]['text']
|
||||||
|
|
||||||
|
|
||||||
|
# elif self.model.startswith("mistral"):
|
||||||
|
# print("Call mistral")
|
||||||
|
# messages = payload["messages"]
|
||||||
|
# max_tokens = payload["max_tokens"]
|
||||||
|
#
|
||||||
|
# misrtal_messages = []
|
||||||
|
#
|
||||||
|
# for i, message in enumerate(messages):
|
||||||
|
# mistral_message = {
|
||||||
|
# "role": message["role"],
|
||||||
|
# "content": []
|
||||||
|
# }
|
||||||
|
#
|
||||||
|
# for part in message["content"]:
|
||||||
|
# mistral_message['content'] = part['text'] if part['type'] == "text" else None
|
||||||
|
#
|
||||||
|
# misrtal_messages.append(mistral_message)
|
||||||
|
#
|
||||||
|
# # the mistral not support system message in our endpoint, so we concatenate it at the first user message
|
||||||
|
# if misrtal_messages[0]['role'] == "system":
|
||||||
|
# misrtal_messages[1]['content'] = misrtal_messages[0]['content'] + "\n" + misrtal_messages[1]['content']
|
||||||
|
# misrtal_messages.pop(0)
|
||||||
|
#
|
||||||
|
# # openai.api_base = "http://localhost:8000/v1"
|
||||||
|
# # openai.api_key = "test"
|
||||||
|
# # response = openai.ChatCompletion.create(
|
||||||
|
# # messages=misrtal_messages,
|
||||||
|
# # model="Mixtral-8x7B-Instruct-v0.1"
|
||||||
|
# # )
|
||||||
|
#
|
||||||
|
# from openai import OpenAI
|
||||||
|
# TOGETHER_API_KEY = "d011650e7537797148fb6170ec1e0be7ae75160375686fae02277136078e90d2"
|
||||||
|
#
|
||||||
|
# client = OpenAI(api_key=TOGETHER_API_KEY,
|
||||||
|
# base_url='https://api.together.xyz',
|
||||||
|
# )
|
||||||
|
# logger.info("Generating content with Mistral model: %s", self.model)
|
||||||
|
# response = client.chat.completions.create(
|
||||||
|
# messages=misrtal_messages,
|
||||||
|
# model="mistralai/Mixtral-8x7B-Instruct-v0.1",
|
||||||
|
# max_tokens=1024
|
||||||
|
# )
|
||||||
|
#
|
||||||
|
# try:
|
||||||
|
# # return response['choices'][0]['message']['content']
|
||||||
|
# return response.choices[0].message.content
|
||||||
|
# except Exception as e:
|
||||||
|
# print("Failed to call LLM: " + str(e))
|
||||||
|
# return ""
|
||||||
|
|
||||||
elif self.model.startswith("gemini"):
|
elif self.model.startswith("gemini"):
|
||||||
def encoded_img_to_pil_img(data_str):
|
def encoded_img_to_pil_img(data_str):
|
||||||
@@ -617,6 +610,8 @@ class GPT4v_Agent:
|
|||||||
|
|
||||||
messages = payload["messages"]
|
messages = payload["messages"]
|
||||||
max_tokens = payload["max_tokens"]
|
max_tokens = payload["max_tokens"]
|
||||||
|
top_p = payload["top_p"]
|
||||||
|
temperature = payload["temperature"]
|
||||||
|
|
||||||
gemini_messages = []
|
gemini_messages = []
|
||||||
for i, message in enumerate(messages):
|
for i, message in enumerate(messages):
|
||||||
@@ -653,8 +648,9 @@ class GPT4v_Agent:
|
|||||||
for message in gemini_messages:
|
for message in gemini_messages:
|
||||||
message_history_str += "<|" + message['role'] + "|>\n" + message['parts'][0] + "\n"
|
message_history_str += "<|" + message['role'] + "|>\n" + message['parts'][0] + "\n"
|
||||||
gemini_messages = [{"role": "user", "parts": [message_history_str, gemini_messages[-1]['parts'][1]]}]
|
gemini_messages = [{"role": "user", "parts": [message_history_str, gemini_messages[-1]['parts'][1]]}]
|
||||||
|
# gemini_messages[-1]['parts'][1].save("output.png", "PNG")
|
||||||
|
|
||||||
print(gemini_messages)
|
# print(gemini_messages)
|
||||||
api_key = os.environ.get("GENAI_API_KEY")
|
api_key = os.environ.get("GENAI_API_KEY")
|
||||||
assert api_key is not None, "Please set the GENAI_API_KEY environment variable"
|
assert api_key is not None, "Please set the GENAI_API_KEY environment variable"
|
||||||
genai.configure(api_key=api_key)
|
genai.configure(api_key=api_key)
|
||||||
@@ -662,7 +658,16 @@ class GPT4v_Agent:
|
|||||||
response = genai.GenerativeModel(self.model).generate_content(
|
response = genai.GenerativeModel(self.model).generate_content(
|
||||||
gemini_messages,
|
gemini_messages,
|
||||||
generation_config={
|
generation_config={
|
||||||
"max_output_tokens": max_tokens
|
"candidate_count": 1,
|
||||||
|
"max_output_tokens": max_tokens,
|
||||||
|
"top_p": top_p,
|
||||||
|
"temperature": temperature
|
||||||
|
},
|
||||||
|
safety_settings={
|
||||||
|
"harassment": "block_none",
|
||||||
|
"hate": "block_none",
|
||||||
|
"sex": "block_none",
|
||||||
|
"danger": "block_none"
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -673,6 +678,8 @@ class GPT4v_Agent:
|
|||||||
elif self.model.startswith("qwen"):
|
elif self.model.startswith("qwen"):
|
||||||
messages = payload["messages"]
|
messages = payload["messages"]
|
||||||
max_tokens = payload["max_tokens"]
|
max_tokens = payload["max_tokens"]
|
||||||
|
top_p = payload["top_p"]
|
||||||
|
temperature = payload["temperature"]
|
||||||
|
|
||||||
qwen_messages = []
|
qwen_messages = []
|
||||||
|
|
||||||
@@ -683,13 +690,16 @@ class GPT4v_Agent:
|
|||||||
}
|
}
|
||||||
assert len(message["content"]) in [1, 2], "One text, or one text with one image"
|
assert len(message["content"]) in [1, 2], "One text, or one text with one image"
|
||||||
for part in message["content"]:
|
for part in message["content"]:
|
||||||
qwen_message['content'].append({"image": part['image_url']['url']}) if part['type'] == "image_url" else None
|
qwen_message['content'].append({"image": part['image_url']['url']}) if part[
|
||||||
|
'type'] == "image_url" else None
|
||||||
qwen_message['content'].append({"text": part['text']}) if part['type'] == "text" else None
|
qwen_message['content'].append({"text": part['text']}) if part['type'] == "text" else None
|
||||||
|
|
||||||
qwen_messages.append(qwen_message)
|
qwen_messages.append(qwen_message)
|
||||||
|
|
||||||
response = dashscope.MultiModalConversation.call(model='qwen-vl-plus',
|
response = dashscope.MultiModalConversation.call(
|
||||||
messages=messages)
|
model='qwen-vl-plus',
|
||||||
|
messages=messages, # todo: add the hyperparameters
|
||||||
|
)
|
||||||
# The response status_code is HTTPStatus.OK indicate success,
|
# The response status_code is HTTPStatus.OK indicate success,
|
||||||
# otherwise indicate request is failed, you can get error code
|
# otherwise indicate request is failed, you can get error code
|
||||||
# and message from code and message.
|
# and message from code and message.
|
||||||
@@ -708,7 +718,7 @@ class GPT4v_Agent:
|
|||||||
|
|
||||||
def parse_actions(self, response: str, masks=None):
|
def parse_actions(self, response: str, masks=None):
|
||||||
|
|
||||||
if self.exp in ["screenshot", "a11y_tree", "both"]:
|
if self.observation_type in ["screenshot", "a11y_tree", "screenshot_a11y_tree"]:
|
||||||
# parse from the response
|
# parse from the response
|
||||||
if self.action_space == "computer_13":
|
if self.action_space == "computer_13":
|
||||||
actions = parse_actions_from_string(response)
|
actions = parse_actions_from_string(response)
|
||||||
@@ -720,7 +730,7 @@ class GPT4v_Agent:
|
|||||||
self.actions.append(actions)
|
self.actions.append(actions)
|
||||||
|
|
||||||
return actions
|
return actions
|
||||||
elif self.exp in ["som", "seeact"]:
|
elif self.observation_type in ["som"]:
|
||||||
# parse from the response
|
# parse from the response
|
||||||
if self.action_space == "computer_13":
|
if self.action_space == "computer_13":
|
||||||
raise ValueError("Invalid action space: " + self.action_space)
|
raise ValueError("Invalid action space: " + self.action_space)
|
||||||
@@ -732,3 +742,8 @@ class GPT4v_Agent:
|
|||||||
self.actions.append(actions)
|
self.actions.append(actions)
|
||||||
|
|
||||||
return actions
|
return actions
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.thoughts = []
|
||||||
|
self.actions = []
|
||||||
|
self.observations = []
|
||||||
@@ -798,10 +798,10 @@ You MUST choose and ONLY CHOOSE from the action space above, otherwise your acti
|
|||||||
You CAN predict multiple actions at one step, but you should only return one action for each step.
|
You CAN predict multiple actions at one step, but you should only return one action for each step.
|
||||||
""".strip()
|
""".strip()
|
||||||
|
|
||||||
SYS_PROMPT_IN_SOM_A11Y_OUT_TAG = """
|
SYS_PROMPT_IN_SOM_OUT_TAG = """
|
||||||
You are an agent which follow my instruction and perform desktop computer tasks as instructed.
|
You are an agent which follow my instruction and perform desktop computer tasks as instructed.
|
||||||
You have good knowledge of computer and good internet connection and assume your code will run on a computer for controlling the mouse and keyboard.
|
You have good knowledge of computer and good internet connection and assume your code will run on a computer for controlling the mouse and keyboard.
|
||||||
For each step, you will get an observation of the desktop by 1) a screenshot; and 2) accessibility tree, which is based on AT-SPI library.
|
For each step, you will get an observation of the desktop by a screenshot with interact-able elements marked with numerical tags. And you will predict the action of the computer based on the image.
|
||||||
|
|
||||||
You are required to use `pyautogui` to perform the action grounded to the observation, but DONOT use the `pyautogui.locateCenterOnScreen` function to locate the element you want to operate with since we have no image of the element you want to operate with. DONOT USE `pyautogui.screenshot()` to make screenshot.
|
You are required to use `pyautogui` to perform the action grounded to the observation, but DONOT use the `pyautogui.locateCenterOnScreen` function to locate the element you want to operate with since we have no image of the element you want to operate with. DONOT USE `pyautogui.screenshot()` to make screenshot.
|
||||||
You can replace x, y in the code with the tag of the element you want to operate with. such as:
|
You can replace x, y in the code with the tag of the element you want to operate with. such as:
|
||||||
|
|||||||
245
run.py
Normal file
245
run.py
Normal file
@@ -0,0 +1,245 @@
|
|||||||
|
"""Script to run end-to-end evaluation on the benchmark.
|
||||||
|
Utils and basic architecture credit to https://github.com/web-arena-x/webarena/blob/main/run.py.
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
import datetime
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
import lib_run_single
|
||||||
|
from desktop_env.envs.desktop_env import DesktopEnv
|
||||||
|
from mm_agents.agent import PromptAgent
|
||||||
|
|
||||||
|
# Logger Configs {{{ #
# Root logger captures everything at DEBUG; each handler filters below.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

# Timestamp shared by all log files of this run.
datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")

# Ensure the log directory exists before attaching file handlers;
# logging.FileHandler raises FileNotFoundError on a fresh checkout otherwise.
os.makedirs("logs", exist_ok=True)

file_handler = logging.FileHandler(os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8")
debug_handler = logging.FileHandler(os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8")
stdout_handler = logging.StreamHandler(sys.stdout)
sdebug_handler = logging.FileHandler(os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8")

file_handler.setLevel(logging.INFO)
debug_handler.setLevel(logging.DEBUG)
stdout_handler.setLevel(logging.INFO)
sdebug_handler.setLevel(logging.DEBUG)

# ANSI color codes make the console/file output easier to scan.
formatter = logging.Formatter(
    fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s")
file_handler.setFormatter(formatter)
debug_handler.setFormatter(formatter)
stdout_handler.setFormatter(formatter)
sdebug_handler.setFormatter(formatter)

# Console and sdebug log only records from the "desktopenv" logger hierarchy.
stdout_handler.addFilter(logging.Filter("desktopenv"))
sdebug_handler.addFilter(logging.Filter("desktopenv"))

logger.addHandler(file_handler)
logger.addHandler(debug_handler)
logger.addHandler(stdout_handler)
logger.addHandler(sdebug_handler)
# }}} Logger Configs #

# Module-local logger used by the functions below.
logger = logging.getLogger("desktopenv.experiment")
def config() -> argparse.Namespace:
    """Build the CLI argument parser for the benchmark and parse sys.argv.

    Returns:
        argparse.Namespace with environment, agent, LM and logging settings.
    """
    parser = argparse.ArgumentParser(
        description="Run end-to-end evaluation on the benchmark"
    )

    # environment config
    parser.add_argument(
        "--path_to_vm",
        type=str,
        default=r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu\Ubuntu.vmx",
    )
    parser.add_argument(
        "--headless", action="store_true", help="Run in headless machine"
    )
    parser.add_argument(
        "--action_space", type=str, default="pyautogui", help="Action type"
    )
    parser.add_argument(
        "--observation_type",
        choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"],
        default="som",
        help="Observation type",
    )
    parser.add_argument("--screen_width", type=int, default=1920)
    parser.add_argument("--screen_height", type=int, default=1080)
    parser.add_argument("--sleep_after_execution", type=float, default=0.0)
    parser.add_argument("--max_steps", type=int, default=15)

    # agent config
    parser.add_argument("--max_trajectory_length", type=int, default=3)
    parser.add_argument(
        "--test_config_base_dir", type=str, default="evaluation_examples"
    )

    # lm config
    parser.add_argument("--model", type=str, default="gpt-4-vision-preview")
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--max_tokens", type=int, default=1500)
    parser.add_argument("--stop_token", type=str, default=None)

    # logging related
    parser.add_argument("--result_dir", type=str, default="./results")

    return parser.parse_args()
def test(
        args: argparse.Namespace,
        test_all_meta: dict
) -> None:
    """Run every example listed in ``test_all_meta`` against the desktop env.

    Args:
        args: parsed CLI options (see ``config()``).
        test_all_meta: mapping of domain name -> list of example ids.

    Side effects: creates per-example result directories, writes trajectories
    and recordings, and logs the average score at the end.
    """
    scores = []
    max_steps = args.max_steps

    # log args
    logger.info("Args: %s", args)

    agent = PromptAgent(
        model=args.model,
        max_tokens=args.max_tokens,
        action_space=args.action_space,
        observation_type=args.observation_type,
        max_trajectory_length=args.max_trajectory_length,
    )

    env = DesktopEnv(
        path_to_vm=args.path_to_vm,
        action_space=agent.action_space,
        screen_size=(args.screen_width, args.screen_height),
        headless=args.headless,
    )

    for domain in tqdm(test_all_meta, desc="Domain"):
        for example_id in tqdm(test_all_meta[domain], desc="Example", leave=False):
            # example setting
            config_file = os.path.join(args.test_config_base_dir, f"examples/{domain}/{example_id}.json")
            with open(config_file, "r", encoding="utf-8") as f:
                example = json.load(f)

            logger.info(f"[Domain]: {domain}")
            logger.info(f"[Example ID]: {example_id}")

            instruction = example["instruction"]

            logger.info(f"[Instruction]: {instruction}")

            example_result_dir = os.path.join(
                args.result_dir,
                args.action_space,
                args.observation_type,
                args.model,
                domain,
                example_id
            )
            os.makedirs(example_result_dir, exist_ok=True)
            # example start running
            try:
                lib_run_single.run_single_example(agent, env, example, max_steps, instruction, args, example_result_dir,
                                                  scores)
            except Exception as e:
                # Finalize the recording best-effort; a failure here must not
                # mask the error record written below.
                try:
                    env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
                except Exception:
                    logger.warning("Failed to end recording for %s/%s", domain, example_id)
                # Log the actual failure. The previous message claimed
                # "Time limit exceeded" for *every* exception type, which
                # was misleading when debugging other errors.
                logger.exception("Exception in %s/%s: %s", domain, example_id, e)
                with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
                    f.write(json.dumps({
                        "Error": f"{type(e).__name__}: {e} in {domain}/{example_id}"
                    }))
                    f.write("\n")

    env.close()
    # Guard against ZeroDivisionError when no example produced a score.
    if scores:
        logger.info(f"Average score: {sum(scores) / len(scores)}")
    else:
        logger.info("No scores recorded.")
def get_unfinished(action_space, use_model, observation_type, result_dir, total_file_json):
    """Filter out examples that already have a result on disk.

    Scans ``result_dir/action_space/observation_type/use_model`` for example
    directories. Examples containing ``result.txt`` are treated as finished
    and removed from ``total_file_json`` (mutated in place and returned).
    Partial runs (no ``result.txt``) have their artifact files deleted so
    they restart cleanly.
    """
    target_dir = os.path.join(result_dir, action_space, observation_type, use_model)

    if not os.path.exists(target_dir):
        return total_file_json

    completed = {}
    for domain_name in os.listdir(target_dir):
        completed[domain_name] = []
        domain_dir = os.path.join(target_dir, domain_name)
        if not os.path.isdir(domain_dir):
            continue
        for task_id in os.listdir(domain_dir):
            task_dir = os.path.join(domain_dir, task_id)
            if not os.path.isdir(task_dir):
                continue
            if "result.txt" in os.listdir(task_dir):
                completed[domain_name].append(task_id)
            else:
                # Partial run: wipe its files so the example reruns from scratch.
                for leftover in os.listdir(task_dir):
                    os.remove(os.path.join(task_dir, leftover))

    if not completed:
        return total_file_json

    for domain_name, done_ids in completed.items():
        if domain_name in total_file_json:
            total_file_json[domain_name] = [
                task for task in total_file_json[domain_name] if task not in done_ids
            ]

    return total_file_json
def get_result(action_space, use_model, observation_type, result_dir, total_file_json):
    """Collect per-example scores and print the overall success rate.

    Walks ``result_dir/action_space/observation_type/use_model`` and reads
    every ``result.txt`` found under an example directory.

    Returns:
        The list of float scores found. If the target directory does not
        exist yet, returns ``total_file_json`` unchanged (kept for backward
        compatibility with existing callers).
    """
    target_dir = os.path.join(result_dir, action_space, observation_type, use_model)

    all_result = []

    if not os.path.exists(target_dir):
        return total_file_json

    for domain in os.listdir(target_dir):
        domain_path = os.path.join(target_dir, domain)
        if not os.path.isdir(domain_path):
            continue
        for example_id in os.listdir(domain_path):
            example_path = os.path.join(domain_path, example_id)
            if not os.path.isdir(example_path):
                continue
            result_file = os.path.join(example_path, "result.txt")
            if os.path.isfile(result_file):
                # Use a context manager; the original leaked the file handle.
                with open(result_file, "r") as f:
                    all_result.append(float(f.read()))

    if all_result:
        print("Success Rate:", sum(all_result) / len(all_result) * 100, "%")
    else:
        # Previously this divided by len(all_result) == 0 and crashed
        # with ZeroDivisionError when no example had finished yet.
        print("Success Rate: no finished examples found")
    return all_result
if __name__ == '__main__':
    ####### The complete version of the list of examples #######
    # Disable HF tokenizer fork-parallelism warnings/deadlocks in subprocesses.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    args = config()

    # Full benchmark manifest: domain name -> list of example ids.
    with open("evaluation_examples/test_all.json", "r", encoding="utf-8") as f:
        test_all_meta = json.load(f)

    # Remove examples that already have results on disk (also clears
    # partially-run example directories as a side effect).
    test_file_list = get_unfinished(
        args.action_space,
        args.model,
        args.observation_type,
        args.result_dir,
        test_all_meta
    )
    # Summarize how many examples remain per domain.
    left_info = ""
    for domain in test_file_list:
        left_info += f"{domain}: {len(test_file_list[domain])}\n"
    logger.info(f"Left tasks:\n{left_info}")

    # Print the success rate over whatever results exist so far.
    get_result(args.action_space,
               args.model,
               args.observation_type,
               args.result_dir,
               test_all_meta
               )

    # NOTE(review): the actual evaluation run is commented out — as written,
    # this script only reports progress/results; confirm this is intended.
    # test(args, test_all_meta)
3
settings.json
Normal file
3
settings.json
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
{
|
||||||
|
"time_limit": "1200"
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user