Merge branch 'main' of github.com:ztjhz/DesktopEnv
This commit is contained in:
4
.vscode/launch.json
vendored
4
.vscode/launch.json
vendored
@@ -11,8 +11,8 @@
|
|||||||
"program": "${file}",
|
"program": "${file}",
|
||||||
"console": "integratedTerminal",
|
"console": "integratedTerminal",
|
||||||
"args": [
|
"args": [
|
||||||
"--path_to_vm", "/Users/lxc/Virtual Machines.localized/DesktopEnv-Ubuntu 64-bit Arm.vmwarevm/DesktopEnv-Ubuntu 64-bit Arm.vmx",
|
"--path_to_vm", "/Users/lxc/Virtual Machines.localized/DesktopEnv-Ubuntu 64-bit Arm.vmwarevm/DesktopEnv-Ubuntu 64-bit Arm.vmx"
|
||||||
"--example_time_limit", "60"
|
// "--example_time_limit", "60"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|||||||
16
demo.py
16
demo.py
@@ -1,16 +0,0 @@
|
|||||||
import signal
|
|
||||||
import time
|
|
||||||
|
|
||||||
def handler(signo, frame):
|
|
||||||
raise RuntimeError("Timeout")
|
|
||||||
|
|
||||||
signal.signal(signal.SIGALRM, handler)
|
|
||||||
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
signal.alarm(5) # seconds
|
|
||||||
time.sleep(10)
|
|
||||||
print("Working...")
|
|
||||||
except Exception as e :
|
|
||||||
print(e)
|
|
||||||
continue
|
|
||||||
@@ -263,16 +263,19 @@ class PythonController:
|
|||||||
"""
|
"""
|
||||||
Ends recording the screen.
|
Ends recording the screen.
|
||||||
"""
|
"""
|
||||||
response = requests.post(self.http_server + "/end_recording")
|
try:
|
||||||
if response.status_code == 200:
|
response = requests.post(self.http_server + "/end_recording")
|
||||||
logger.info("Recording stopped successfully")
|
if response.status_code == 200:
|
||||||
with open(dest, 'wb') as f:
|
logger.info("Recording stopped successfully")
|
||||||
for chunk in response.iter_content(chunk_size=8192):
|
with open(dest, 'wb') as f:
|
||||||
if chunk:
|
for chunk in response.iter_content(chunk_size=8192):
|
||||||
f.write(chunk)
|
if chunk:
|
||||||
else:
|
f.write(chunk)
|
||||||
logger.error("Failed to stop recording. Status code: %d", response.status_code)
|
else:
|
||||||
return None
|
logger.error("Failed to stop recording. Status code: %d", response.status_code)
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("An error occurred while trying to download the recording: %s", e)
|
||||||
|
|
||||||
# Additional info
|
# Additional info
|
||||||
def get_vm_platform(self):
|
def get_vm_platform(self):
|
||||||
|
|||||||
@@ -9,7 +9,7 @@
|
|||||||
"parameters": {
|
"parameters": {
|
||||||
"files": [
|
"files": [
|
||||||
{
|
{
|
||||||
"url": "https://drive.usercontent.google.com/download?id=104pg3yochKyH2Uvlp3BdvKmHgYmSIESu&export=download&authuser=0&confirm=t&uuid=d1926366-4e54-4a44-8dcd-fc49ed6524d7&at=APZUnTXcBFV9kcacsA0toU83lMKJ:1706505549057d",
|
"url": "https://drive.usercontent.google.com/download?id=1gqqY56robX1tb4YPa3Yk1d72T_k-Rgz3&export=download&authuser=0&confirm=t",
|
||||||
"path": "/home/user/Desktop/15-MB-docx-file-download.docx"
|
"path": "/home/user/Desktop/15-MB-docx-file-download.docx"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
{
|
{
|
||||||
"id": "3c8f201a-009d-4bbe-8b65-a6f8b35bb57f",
|
"id": "3c8f201a-009d-4bbe-8b65-a6f8b35bb57f",
|
||||||
"snapshot": "gimp",
|
"snapshot": "gimp",
|
||||||
"instruction": "Download the image from \"https://drive.google.com/uc?export=download&id=1i8j5dGS57sA07jEuPNAlQW-sn5uqUnuK\", and then use GIMP to compress it to under 600KB. Resize if needed.",
|
"instruction": "Download the image from \"https://drive.google.com/uc?export=download&id=1i8j5dGS57sA07jEuPNAlQW-sn5uqUnuK\", and then use GIMP to compress it to under 600KB as \"compressed.jpeg\" on the Desktop. Resize if needed.",
|
||||||
"source": "",
|
"source": "",
|
||||||
"config": [
|
"config": [
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -1,13 +1,17 @@
|
|||||||
{
|
{
|
||||||
"id": "e2392362-125e-4f76-a2ee-524b183a3412",
|
"id": "e2392362-125e-4f76-a2ee-524b183a3412",
|
||||||
"snapshot": "chrome",
|
"snapshot": "chrome",
|
||||||
"instruction": "I recently started using the famous personal academic homepage template from academicpages.github.io to build my own personal homepage, and I have cloned it to my local ~/Code/Website folder. According to an online tutorial, I can configure my name and contact information in the _config.yaml file. However, I am not familiar with the YAML file format. Please help me find the sections related to the name and contact information in this file and change them to “Test Account” and “Test@gmail.com”.",
|
"instruction": "I recently started using the famous personal academic homepage template from academicpages.github.io to build my own personal homepage, and I have cloned it to my local ~/Code/Website folder. According to an online tutorial, I can configure my name and contact information in the _config.yaml file. However, I am not familiar with the YAML file format. Please help me find the sections related to the name and contact information in this file and change them to \"Test Account\" and \"Test@gmail.com\".",
|
||||||
"source": "authors",
|
"source": "authors",
|
||||||
"config": [
|
"config": [
|
||||||
{
|
{
|
||||||
"type": "command",
|
"type": "command",
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"command": ["mkdir", "-p", "/home/user/Code/Website"]
|
"command": [
|
||||||
|
"mkdir",
|
||||||
|
"-p",
|
||||||
|
"/home/user/Code/Website"
|
||||||
|
]
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -24,13 +28,22 @@
|
|||||||
{
|
{
|
||||||
"type": "execute",
|
"type": "execute",
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"command": ["tar", "-xJvf", ".tmp.tar.xz", "-C", "/home/user/Code/Website/"]
|
"command": [
|
||||||
|
"tar",
|
||||||
|
"-xJvf",
|
||||||
|
".tmp.tar.xz",
|
||||||
|
"-C",
|
||||||
|
"/home/user/Code/Website/"
|
||||||
|
]
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"type": "launch",
|
"type": "launch",
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"command": ["google-chrome", "--remote-debugging-port=1337"]
|
"command": [
|
||||||
|
"google-chrome",
|
||||||
|
"--remote-debugging-port=1337"
|
||||||
|
]
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -46,14 +59,20 @@
|
|||||||
{
|
{
|
||||||
"type": "chrome_open_tabs",
|
"type": "chrome_open_tabs",
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"urls_to_open": ["https://academicpages.github.io/"]
|
"urls_to_open": [
|
||||||
|
"https://academicpages.github.io/"
|
||||||
|
]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"trajectory": "trajectories/e2392362-125e-4f76-a2ee-524b183a3412",
|
"trajectory": "trajectories/e2392362-125e-4f76-a2ee-524b183a3412",
|
||||||
"related_apps": ["chrome", "os", "vscode"],
|
"related_apps": [
|
||||||
|
"chrome",
|
||||||
|
"os",
|
||||||
|
"vscode"
|
||||||
|
],
|
||||||
"evaluator": {
|
"evaluator": {
|
||||||
"postconfig":[
|
"postconfig": [
|
||||||
{
|
{
|
||||||
"type": "execute",
|
"type": "execute",
|
||||||
"parameters": {
|
"parameters": {
|
||||||
@@ -66,23 +85,33 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"func": "check_json",
|
"func": "check_json",
|
||||||
"options": {"is_yaml": true},
|
"options": {
|
||||||
|
"is_yaml": true
|
||||||
|
},
|
||||||
"expected": {
|
"expected": {
|
||||||
"type": "rule",
|
"type": "rule",
|
||||||
"rules": {
|
"rules": {
|
||||||
"expect": [
|
"expect": [
|
||||||
{
|
{
|
||||||
"key": ["name"],
|
"key": [
|
||||||
|
"name"
|
||||||
|
],
|
||||||
"method": "eq",
|
"method": "eq",
|
||||||
"ref": "Test Account"
|
"ref": "Test Account"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"key": ["author", "name"],
|
"key": [
|
||||||
|
"author",
|
||||||
|
"name"
|
||||||
|
],
|
||||||
"method": "eq",
|
"method": "eq",
|
||||||
"ref": "Test Account"
|
"ref": "Test Account"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"key": ["author", "email"],
|
"key": [
|
||||||
|
"author",
|
||||||
|
"email"
|
||||||
|
],
|
||||||
"method": "eq",
|
"method": "eq",
|
||||||
"ref": "Test@gmail.com"
|
"ref": "Test@gmail.com"
|
||||||
}
|
}
|
||||||
@@ -95,4 +124,4 @@
|
|||||||
"dest": "_config.yaml"
|
"dest": "_config.yaml"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -103,7 +103,6 @@
|
|||||||
"1e8df695-bd1b-45b3-b557-e7d599cf7597",
|
"1e8df695-bd1b-45b3-b557-e7d599cf7597",
|
||||||
"ecb0df7a-4e8d-4a03-b162-053391d3afaf",
|
"ecb0df7a-4e8d-4a03-b162-053391d3afaf",
|
||||||
"8b1ce5f2-59d2-4dcc-b0b0-666a714b9a14",
|
"8b1ce5f2-59d2-4dcc-b0b0-666a714b9a14",
|
||||||
"7b802dad-6e0f-4204-9815-d4e3f57627d8",
|
|
||||||
"a01fbce3-2793-461f-ab86-43680ccbae25",
|
"a01fbce3-2793-461f-ab86-43680ccbae25",
|
||||||
"0326d92d-d218-48a8-9ca1-981cd6d064c7",
|
"0326d92d-d218-48a8-9ca1-981cd6d064c7",
|
||||||
"0a2e43bf-b26c-4631-a966-af9dfa12c9e5",
|
"0a2e43bf-b26c-4631-a966-af9dfa12c9e5",
|
||||||
@@ -380,7 +379,6 @@
|
|||||||
"9439a27b-18ae-42d8-9778-5f68f891805e",
|
"9439a27b-18ae-42d8-9778-5f68f891805e",
|
||||||
"ae506c68-352c-4094-9caa-ee9d42052317",
|
"ae506c68-352c-4094-9caa-ee9d42052317",
|
||||||
"ea98c5d7-3cf9-4f9b-8ad3-366b58e0fcae",
|
"ea98c5d7-3cf9-4f9b-8ad3-366b58e0fcae",
|
||||||
"c714dcee-cad3-4e12-8f3c-12bdcfcdb048",
|
|
||||||
"930fdb3b-11a8-46fe-9bac-577332e2640e",
|
"930fdb3b-11a8-46fe-9bac-577332e2640e",
|
||||||
"276cc624-87ea-4f08-ab93-f770e3790175",
|
"276cc624-87ea-4f08-ab93-f770e3790175",
|
||||||
"9d425400-e9b2-4424-9a4b-d4c7abac4140",
|
"9d425400-e9b2-4424-9a4b-d4c7abac4140",
|
||||||
|
|||||||
102
evaluation_examples/test_small.json
Normal file
102
evaluation_examples/test_small.json
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
{
|
||||||
|
"chrome": [
|
||||||
|
"bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
|
||||||
|
"7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3"
|
||||||
|
],
|
||||||
|
"gimp": [
|
||||||
|
"7a4deb26-d57d-4ea9-9a73-630f66a7b568",
|
||||||
|
"554785e9-4523-4e7a-b8e1-8016f565f56a"
|
||||||
|
],
|
||||||
|
"libreoffice_calc": [
|
||||||
|
"357ef137-7eeb-4c80-a3bb-0951f26a8aff",
|
||||||
|
"42e0a640-4f19-4b28-973d-729602b5a4a7"
|
||||||
|
],
|
||||||
|
"libreoffice_impress": [
|
||||||
|
"5d901039-a89c-4bfb-967b-bf66f4df075e",
|
||||||
|
"550ce7e7-747b-495f-b122-acdc4d0b8e54"
|
||||||
|
],
|
||||||
|
"libreoffice_writer": [
|
||||||
|
"0810415c-bde4-4443-9047-d5f70165a697",
|
||||||
|
"0a0faba3-5580-44df-965d-f562a99b291c"
|
||||||
|
],
|
||||||
|
"multi_apps": [
|
||||||
|
"2b9493d7-49b8-493a-a71b-56cd1f4d6908",
|
||||||
|
"46407397-a7d5-4c6b-92c6-dbe038b1457b",
|
||||||
|
"4e9f0faf-2ecc-4ae8-a804-28c9a75d1ddc",
|
||||||
|
"510f64c8-9bcc-4be1-8d30-638705850618",
|
||||||
|
"897e3b53-5d4d-444b-85cb-2cdc8a97d903",
|
||||||
|
"c867c42d-a52d-4a24-8ae3-f75d256b5618",
|
||||||
|
"e135df7c-7687-4ac0-a5f0-76b74438b53e",
|
||||||
|
"f7dfbef3-7697-431c-883a-db8583a4e4f9",
|
||||||
|
"6d72aad6-187a-4392-a4c4-ed87269c51cf",
|
||||||
|
"f918266a-b3e0-4914-865d-4faa564f1aef",
|
||||||
|
"da52d699-e8d2-4dc5-9191-a2199e0b6a9b",
|
||||||
|
"74d5859f-ed66-4d3e-aa0e-93d7a592ce41",
|
||||||
|
"b5062e3e-641c-4e3a-907b-ac864d2e7652",
|
||||||
|
"48d05431-6cd5-4e76-82eb-12b60d823f7d",
|
||||||
|
"eb303e01-261e-4972-8c07-c9b4e7a4922a",
|
||||||
|
"d1acdb87-bb67-4f30-84aa-990e56a09c92",
|
||||||
|
"deec51c9-3b1e-4b9e-993c-4776f20e8bb2",
|
||||||
|
"8e116af7-7db7-4e35-a68b-b0939c066c78",
|
||||||
|
"185f29bd-5da0-40a6-b69c-ba7f4e0324ef",
|
||||||
|
"2c1ebcd7-9c6d-4c9a-afad-900e381ecd5e",
|
||||||
|
"3a93cae4-ad3e-403e-8c12-65303b271818",
|
||||||
|
"1f18aa87-af6f-41ef-9853-cdb8f32ebdea",
|
||||||
|
"26150609-0da3-4a7d-8868-0faf9c5f01bb",
|
||||||
|
"7e287123-70ca-47b9-8521-47db09b69b14",
|
||||||
|
"e2392362-125e-4f76-a2ee-524b183a3412",
|
||||||
|
"26660ad1-6ebb-4f59-8cba-a8432dfe8d38",
|
||||||
|
"a82b78bb-7fde-4cb3-94a4-035baf10bcf0",
|
||||||
|
"36037439-2044-4b50-b9d1-875b5a332143",
|
||||||
|
"716a6079-22da-47f1-ba73-c9d58f986a38",
|
||||||
|
"a74b607e-6bb5-4ea8-8a7c-5d97c7bbcd2a",
|
||||||
|
"6f4073b8-d8ea-4ade-8a18-c5d1d5d5aa9a",
|
||||||
|
"da922383-bfa4-4cd3-bbad-6bebab3d7742",
|
||||||
|
"2373b66a-092d-44cb-bfd7-82e86e7a3b4d",
|
||||||
|
"81c425f5-78f3-4771-afd6-3d2973825947",
|
||||||
|
"227d2f97-562b-4ccb-ae47-a5ec9e142fbb",
|
||||||
|
"20236825-b5df-46e7-89bf-62e1d640a897",
|
||||||
|
"02ce9a50-7af2-47ed-8596-af0c230501f8",
|
||||||
|
"4c26e3f3-3a14-4d86-b44a-d3cedebbb487",
|
||||||
|
"09a37c51-e625-49f4-a514-20a773797a8a",
|
||||||
|
"3e3fc409-bff3-4905-bf16-c968eee3f807",
|
||||||
|
"415ef462-bed3-493a-ac36-ca8c6d23bf1b",
|
||||||
|
"9f3bb592-209d-43bc-bb47-d77d9df56504",
|
||||||
|
"dd60633f-2c72-42ba-8547-6f2c8cb0fdb0",
|
||||||
|
"3f05f3b9-29ba-4b6b-95aa-2204697ffc06",
|
||||||
|
"f8369178-fafe-40c2-adc4-b9b08a125456",
|
||||||
|
"778efd0a-153f-4842-9214-f05fc176b877",
|
||||||
|
"47f7c0ce-a5fb-4100-a5e6-65cd0e7429e5",
|
||||||
|
"c2751594-0cd5-4088-be1b-b5f2f9ec97c4",
|
||||||
|
"48c46dc7-fe04-4505-ade7-723cba1aa6f6",
|
||||||
|
"42d25c08-fb87-4927-8b65-93631280a26f",
|
||||||
|
"bb7db4c2-30b5-4be7-8dd7-b8c4ec7d3108",
|
||||||
|
"3c8f201a-009d-4bbe-8b65-a6f8b35bb57f",
|
||||||
|
"d68204bf-11c1-4b13-b48b-d303c73d4bf6",
|
||||||
|
"91190194-f406-4cd6-b3f9-c43fac942b22",
|
||||||
|
"7f35355e-02a6-45b5-b140-f0be698bcf85",
|
||||||
|
"98e8e339-5f91-4ed2-b2b2-12647cb134f4",
|
||||||
|
"df67aebb-fb3a-44fd-b75b-51b6012df509",
|
||||||
|
"5df7b33a-9f77-4101-823e-02f863e1c1ae",
|
||||||
|
"22a4636f-8179-4357-8e87-d1743ece1f81",
|
||||||
|
"236833a3-5704-47fc-888c-4f298f09f799"
|
||||||
|
],
|
||||||
|
"os": [
|
||||||
|
"5ea617a3-0e86-4ba6-aab2-dac9aa2e8d57",
|
||||||
|
"5812b315-e7bd-4265-b51f-863c02174c28",
|
||||||
|
"43c2d64c-bab5-4dcb-a30c-b888321c319a",
|
||||||
|
"7688b85f-87a4-4e4a-b2f8-f3d6c3f29b82"
|
||||||
|
],
|
||||||
|
"thunderbird": [
|
||||||
|
"bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
|
||||||
|
"7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3"
|
||||||
|
],
|
||||||
|
"vlc": [
|
||||||
|
"59f21cfb-0120-4326-b255-a5b827b38967",
|
||||||
|
"8f080098-ddb1-424c-b438-4e96e5e4786e"
|
||||||
|
],
|
||||||
|
"vs_code": [
|
||||||
|
"0ed39f63-6049-43d4-ba4d-5fa2fe04a951",
|
||||||
|
"53ad5833-3455-407b-bbc6-45b4c79ab8fb"
|
||||||
|
]
|
||||||
|
}
|
||||||
72
lib_run_single.py
Normal file
72
lib_run_single.py
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
import datetime
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import wandb
|
||||||
|
|
||||||
|
from wrapt_timeout_decorator import *
|
||||||
|
|
||||||
|
logger = logging.getLogger("desktopenv.experiment")
|
||||||
|
|
||||||
|
# Open the JSON file
|
||||||
|
with open("./settings.json", "r") as file:
|
||||||
|
# Load the JSON data from the file
|
||||||
|
data = json.load(file)
|
||||||
|
time_limit = data["time_limit"]
|
||||||
|
|
||||||
|
@timeout(time_limit, use_signals=False)
|
||||||
|
def run_single_example(agent, env, example, max_steps, instruction, args, example_result_dir, scores):
|
||||||
|
agent.reset()
|
||||||
|
obs = env.reset(task_config=example)
|
||||||
|
done = False
|
||||||
|
step_idx = 0
|
||||||
|
env.controller.start_recording()
|
||||||
|
str_table = wandb.Table(columns=["Screenshot", "A11T", "Modle Response", "Action", "Action timestamp", "Done"])
|
||||||
|
while not done and step_idx < max_steps:
|
||||||
|
response, actions = agent.predict(
|
||||||
|
instruction,
|
||||||
|
obs
|
||||||
|
)
|
||||||
|
for action in actions:
|
||||||
|
# Capture the timestamp before executing the action
|
||||||
|
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
||||||
|
logger.info("Step %d: %s", step_idx + 1, action)
|
||||||
|
obs, reward, done, info = env.step(action, args.sleep_after_execution)
|
||||||
|
|
||||||
|
logger.info("Reward: %.2f", reward)
|
||||||
|
logger.info("Done: %s", done)
|
||||||
|
# Save screenshot and trajectory information
|
||||||
|
with open(os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"),
|
||||||
|
"wb") as _f:
|
||||||
|
with open(obs['screenshot'], "rb") as __f:
|
||||||
|
screenshot = __f.read()
|
||||||
|
_f.write(screenshot)
|
||||||
|
# get a11tree and save to wandb
|
||||||
|
thisrun_a11tree = env.controller.get_accessibility_tree()
|
||||||
|
str_table.add_data(wandb.Image(data_or_path=os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"), caption=f"step_{step_idx + 1}_{action_timestamp}"),
|
||||||
|
thisrun_a11tree,
|
||||||
|
response, action, action_timestamp, done)
|
||||||
|
wandb.log({"Reward": reward})
|
||||||
|
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
|
||||||
|
f.write(json.dumps({
|
||||||
|
"step_num": step_idx + 1,
|
||||||
|
"action_timestamp": action_timestamp,
|
||||||
|
"action": action,
|
||||||
|
"reward": reward,
|
||||||
|
"done": done,
|
||||||
|
"info": info,
|
||||||
|
"screenshot_file": f"step_{step_idx + 1}_{action_timestamp}.png"
|
||||||
|
}))
|
||||||
|
f.write("\n")
|
||||||
|
if done:
|
||||||
|
logger.info("The episode is done.")
|
||||||
|
break
|
||||||
|
step_idx += 1
|
||||||
|
wandb.log({"str_trajectory": str_table})
|
||||||
|
result = env.evaluate()
|
||||||
|
logger.info("Result: %.2f", result)
|
||||||
|
scores.append(result)
|
||||||
|
with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f:
|
||||||
|
f.write(f"{result}\n")
|
||||||
|
env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
|
||||||
|
wandb.log({"Result": result})
|
||||||
@@ -5,19 +5,21 @@ import os
|
|||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
import openai
|
|
||||||
import xml.etree.ElementTree as ET
|
import xml.etree.ElementTree as ET
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
from google.api_core.exceptions import InvalidArgument
|
|
||||||
import backoff
|
import backoff
|
||||||
import dashscope
|
import dashscope
|
||||||
import google.generativeai as genai
|
import google.generativeai as genai
|
||||||
|
import openai
|
||||||
import requests
|
import requests
|
||||||
|
import wandb
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from google.api_core.exceptions import InvalidArgument
|
||||||
|
|
||||||
from mm_agents.accessibility_tree_wrap.heuristic_retrieve import find_leaf_nodes, filter_nodes, draw_bounding_boxes
|
from mm_agents.accessibility_tree_wrap.heuristic_retrieve import filter_nodes, draw_bounding_boxes
|
||||||
from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION, \
|
from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION, \
|
||||||
SYS_PROMPT_IN_A11Y_OUT_CODE, SYS_PROMPT_IN_A11Y_OUT_ACTION, \
|
SYS_PROMPT_IN_A11Y_OUT_CODE, SYS_PROMPT_IN_A11Y_OUT_ACTION, \
|
||||||
SYS_PROMPT_IN_BOTH_OUT_CODE, SYS_PROMPT_IN_BOTH_OUT_ACTION, \
|
SYS_PROMPT_IN_BOTH_OUT_CODE, SYS_PROMPT_IN_BOTH_OUT_ACTION, \
|
||||||
@@ -422,7 +424,6 @@ class PromptAgent:
|
|||||||
# with open("messages.json", "w") as f:
|
# with open("messages.json", "w") as f:
|
||||||
# f.write(json.dumps(messages, indent=4))
|
# f.write(json.dumps(messages, indent=4))
|
||||||
|
|
||||||
logger.info("Generating content with GPT model: %s", self.model)
|
|
||||||
response = self.call_llm({
|
response = self.call_llm({
|
||||||
"model": self.model,
|
"model": self.model,
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
@@ -441,7 +442,7 @@ class PromptAgent:
|
|||||||
actions = None
|
actions = None
|
||||||
self.thoughts.append("")
|
self.thoughts.append("")
|
||||||
|
|
||||||
return actions
|
return response, actions
|
||||||
|
|
||||||
@backoff.on_exception(
|
@backoff.on_exception(
|
||||||
backoff.expo,
|
backoff.expo,
|
||||||
@@ -461,7 +462,7 @@ class PromptAgent:
|
|||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
"Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}"
|
"Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}"
|
||||||
}
|
}
|
||||||
# logger.info("Generating content with GPT model: %s", self.model)
|
logger.info("Generating content with GPT model: %s", self.model)
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
"https://api.openai.com/v1/chat/completions",
|
"https://api.openai.com/v1/chat/completions",
|
||||||
headers=headers,
|
headers=headers,
|
||||||
@@ -488,55 +489,162 @@ class PromptAgent:
|
|||||||
else:
|
else:
|
||||||
return response.json()['choices'][0]['message']['content']
|
return response.json()['choices'][0]['message']['content']
|
||||||
|
|
||||||
# elif self.model.startswith("mistral"):
|
elif self.model.startswith("claude"):
|
||||||
# print("Call mistral")
|
messages = payload["messages"]
|
||||||
# messages = payload["messages"]
|
max_tokens = payload["max_tokens"]
|
||||||
# max_tokens = payload["max_tokens"]
|
top_p = payload["top_p"]
|
||||||
#
|
temperature = payload["temperature"]
|
||||||
# misrtal_messages = []
|
|
||||||
#
|
claude_messages = []
|
||||||
# for i, message in enumerate(messages):
|
|
||||||
# mistral_message = {
|
for i, message in enumerate(messages):
|
||||||
# "role": message["role"],
|
claude_message = {
|
||||||
# "content": []
|
"role": message["role"],
|
||||||
# }
|
"content": []
|
||||||
#
|
}
|
||||||
# for part in message["content"]:
|
assert len(message["content"]) in [1, 2], "One text, or one text with one image"
|
||||||
# mistral_message['content'] = part['text'] if part['type'] == "text" else None
|
for part in message["content"]:
|
||||||
#
|
|
||||||
# misrtal_messages.append(mistral_message)
|
if part['type'] == "image_url":
|
||||||
#
|
image_source = {}
|
||||||
# # the mistral not support system message in our endpoint, so we concatenate it at the first user message
|
image_source["type"] = "base64"
|
||||||
# if misrtal_messages[0]['role'] == "system":
|
image_source["media_type"] = "image/png"
|
||||||
# misrtal_messages[1]['content'] = misrtal_messages[0]['content'] + "\n" + misrtal_messages[1]['content']
|
image_source["data"] = part['image_url']['url'].replace("data:image/png;base64,", "")
|
||||||
# misrtal_messages.pop(0)
|
claude_message['content'].append({"type": "image", "source": image_source})
|
||||||
#
|
|
||||||
# # openai.api_base = "http://localhost:8000/v1"
|
if part['type'] == "text":
|
||||||
# # openai.api_key = "test"
|
claude_message['content'].append({"type": "text", "text": part['text']})
|
||||||
# # response = openai.ChatCompletion.create(
|
|
||||||
# # messages=misrtal_messages,
|
claude_messages.append(claude_message)
|
||||||
# # model="Mixtral-8x7B-Instruct-v0.1"
|
|
||||||
# # )
|
# the claude not support system message in our endpoint, so we concatenate it at the first user message
|
||||||
#
|
if claude_messages[0]['role'] == "system":
|
||||||
# from openai import OpenAI
|
claude_system_message_item = claude_messages[0]['content'][0]
|
||||||
# TOGETHER_API_KEY = "d011650e7537797148fb6170ec1e0be7ae75160375686fae02277136078e90d2"
|
claude_messages[1]['content'].insert(0, claude_system_message_item)
|
||||||
#
|
claude_messages.pop(0)
|
||||||
# client = OpenAI(api_key=TOGETHER_API_KEY,
|
|
||||||
# base_url='https://api.together.xyz',
|
headers = {
|
||||||
# )
|
"x-api-key": os.environ["ANTHROPIC_API_KEY"],
|
||||||
# logger.info("Generating content with Mistral model: %s", self.model)
|
"anthropic-version": "2023-06-01",
|
||||||
# response = client.chat.completions.create(
|
"content-type": "application/json"
|
||||||
# messages=misrtal_messages,
|
}
|
||||||
# model="mistralai/Mixtral-8x7B-Instruct-v0.1",
|
|
||||||
# max_tokens=1024
|
payload = {
|
||||||
# )
|
"model": self.model,
|
||||||
#
|
"max_tokens": max_tokens,
|
||||||
# try:
|
"messages": claude_messages
|
||||||
# # return response['choices'][0]['message']['content']
|
}
|
||||||
# return response.choices[0].message.content
|
|
||||||
# except Exception as e:
|
response = requests.post(
|
||||||
# print("Failed to call LLM: " + str(e))
|
"https://api.anthropic.com/v1/messages",
|
||||||
# return ""
|
headers=headers,
|
||||||
|
json=payload
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
|
||||||
|
logger.error("Failed to call LLM: " + response.text)
|
||||||
|
time.sleep(5)
|
||||||
|
return ""
|
||||||
|
else:
|
||||||
|
return response.json()['content'][0]['text']
|
||||||
|
|
||||||
|
|
||||||
|
elif self.model.startswith("mistral"):
|
||||||
|
print("Call mistral")
|
||||||
|
messages = payload["messages"]
|
||||||
|
max_tokens = payload["max_tokens"]
|
||||||
|
top_p = payload["top_p"]
|
||||||
|
temperature = payload["temperature"]
|
||||||
|
|
||||||
|
misrtal_messages = []
|
||||||
|
|
||||||
|
for i, message in enumerate(messages):
|
||||||
|
mistral_message = {
|
||||||
|
"role": message["role"],
|
||||||
|
"content": ""
|
||||||
|
}
|
||||||
|
|
||||||
|
for part in message["content"]:
|
||||||
|
mistral_message['content'] = part['text'] if part['type'] == "text" else ""
|
||||||
|
|
||||||
|
|
||||||
|
misrtal_messages.append(mistral_message)
|
||||||
|
|
||||||
|
|
||||||
|
# openai.api_base = "http://localhost:8000/v1"
|
||||||
|
# response = openai.ChatCompletion.create(
|
||||||
|
# messages=misrtal_messages,
|
||||||
|
# model="Mixtral-8x7B-Instruct-v0.1"
|
||||||
|
# )
|
||||||
|
|
||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
|
client = OpenAI(api_key=os.environ["TOGETHER_API_KEY"],
|
||||||
|
base_url='https://api.together.xyz',
|
||||||
|
)
|
||||||
|
logger.info("Generating content with Mistral model: %s", self.model)
|
||||||
|
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
messages=misrtal_messages,
|
||||||
|
model=self.model,
|
||||||
|
max_tokens=max_tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
return response.choices[0].message.content
|
||||||
|
except Exception as e:
|
||||||
|
print("Failed to call LLM: " + str(e))
|
||||||
|
return ""
|
||||||
|
|
||||||
|
elif self.model.startswith("THUDM"):
|
||||||
|
# THUDM/cogagent-chat-hf
|
||||||
|
print("Call CogAgent")
|
||||||
|
messages = payload["messages"]
|
||||||
|
max_tokens = payload["max_tokens"]
|
||||||
|
top_p = payload["top_p"]
|
||||||
|
temperature = payload["temperature"]
|
||||||
|
|
||||||
|
cog_messages = []
|
||||||
|
|
||||||
|
for i, message in enumerate(messages):
|
||||||
|
cog_message = {
|
||||||
|
"role": message["role"],
|
||||||
|
"content": []
|
||||||
|
}
|
||||||
|
|
||||||
|
for part in message["content"]:
|
||||||
|
if part['type'] == "image_url":
|
||||||
|
cog_message['content'].append({"type": "image_url", "image_url": {"url": part['image_url']['url'] } })
|
||||||
|
|
||||||
|
if part['type'] == "text":
|
||||||
|
cog_message['content'].append({"type": "text", "text": part['text']})
|
||||||
|
|
||||||
|
cog_messages.append(cog_message)
|
||||||
|
|
||||||
|
# the cogagent not support system message in our endpoint, so we concatenate it at the first user message
|
||||||
|
if cog_messages[0]['role'] == "system":
|
||||||
|
cog_system_message_item = cog_messages[0]['content'][0]
|
||||||
|
cog_messages[1]['content'].insert(0, cog_system_message_item)
|
||||||
|
cog_messages.pop(0)
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"model": self.model,
|
||||||
|
"max_tokens": max_tokens,
|
||||||
|
"messages": cog_messages
|
||||||
|
}
|
||||||
|
|
||||||
|
base_url = "http://127.0.0.1:8000"
|
||||||
|
|
||||||
|
response = requests.post(f"{base_url}/v1/chat/completions", json=payload, stream=False)
|
||||||
|
if response.status_code == 200:
|
||||||
|
decoded_line = response.json()
|
||||||
|
content = decoded_line.get("choices", [{}])[0].get("message", "").get("content", "")
|
||||||
|
return content
|
||||||
|
else:
|
||||||
|
print("Failed to call LLM: ", response.status_code)
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
elif self.model.startswith("gemini"):
|
elif self.model.startswith("gemini"):
|
||||||
def encoded_img_to_pil_img(data_str):
|
def encoded_img_to_pil_img(data_str):
|
||||||
@@ -612,6 +720,7 @@ class PromptAgent:
|
|||||||
try:
|
try:
|
||||||
return response.text
|
return response.text
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
logger.error("Meet exception when calling Gemini API, " + str(e))
|
||||||
return ""
|
return ""
|
||||||
elif self.model.startswith("qwen"):
|
elif self.model.startswith("qwen"):
|
||||||
messages = payload["messages"]
|
messages = payload["messages"]
|
||||||
|
|||||||
189
run.py
189
run.py
@@ -6,13 +6,17 @@ import datetime
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import signal
|
import random
|
||||||
import sys
|
import sys
|
||||||
|
import wandb
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
import lib_run_single
|
||||||
from desktop_env.envs.desktop_env import DesktopEnv
|
from desktop_env.envs.desktop_env import DesktopEnv
|
||||||
from mm_agents.agent import PromptAgent
|
from mm_agents.agent import PromptAgent
|
||||||
|
|
||||||
# Logger Configs {{{ #
|
# Logger Configs {{{ #
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
logger.setLevel(logging.DEBUG)
|
logger.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
@@ -46,13 +50,10 @@ logger.addHandler(sdebug_handler)
|
|||||||
|
|
||||||
logger = logging.getLogger("desktopenv.experiment")
|
logger = logging.getLogger("desktopenv.experiment")
|
||||||
|
|
||||||
|
# wandb config
|
||||||
# make sure each example won't exceed the time limit
|
### set your wandb api key here
|
||||||
def handler(signo, frame):
|
os.environ["WANDB_API_KEY"] = ""
|
||||||
raise RuntimeError("Time limit exceeded!")
|
wandb.login(key=os.environ["WANDB_API_KEY"])
|
||||||
|
|
||||||
|
|
||||||
signal.signal(signal.SIGALRM, handler)
|
|
||||||
|
|
||||||
|
|
||||||
def config() -> argparse.Namespace:
|
def config() -> argparse.Namespace:
|
||||||
@@ -75,7 +76,7 @@ def config() -> argparse.Namespace:
|
|||||||
"screenshot_a11y_tree",
|
"screenshot_a11y_tree",
|
||||||
"som"
|
"som"
|
||||||
],
|
],
|
||||||
default="som",
|
default="a11y_tree",
|
||||||
help="Observation type",
|
help="Observation type",
|
||||||
)
|
)
|
||||||
parser.add_argument("--screen_width", type=int, default=1920)
|
parser.add_argument("--screen_width", type=int, default=1920)
|
||||||
@@ -86,10 +87,9 @@ def config() -> argparse.Namespace:
|
|||||||
# agent config
|
# agent config
|
||||||
parser.add_argument("--max_trajectory_length", type=int, default=3)
|
parser.add_argument("--max_trajectory_length", type=int, default=3)
|
||||||
parser.add_argument("--test_config_base_dir", type=str, default="evaluation_examples")
|
parser.add_argument("--test_config_base_dir", type=str, default="evaluation_examples")
|
||||||
parser.add_argument("--example_time_limit", type=int, default=600)
|
|
||||||
|
|
||||||
# lm config
|
# lm config
|
||||||
parser.add_argument("--model", type=str, default="gpt-4-vision-preview")
|
parser.add_argument("--model", type=str, default="gpt-4-0125-preview")
|
||||||
parser.add_argument("--temperature", type=float, default=1.0)
|
parser.add_argument("--temperature", type=float, default=1.0)
|
||||||
parser.add_argument("--top_p", type=float, default=0.9)
|
parser.add_argument("--top_p", type=float, default=0.9)
|
||||||
parser.add_argument("--max_tokens", type=int, default=1500)
|
parser.add_argument("--max_tokens", type=int, default=1500)
|
||||||
@@ -108,10 +108,28 @@ def test(
|
|||||||
) -> None:
|
) -> None:
|
||||||
scores = []
|
scores = []
|
||||||
max_steps = args.max_steps
|
max_steps = args.max_steps
|
||||||
time_limit = args.example_time_limit
|
|
||||||
|
|
||||||
# log args
|
# log args
|
||||||
logger.info("Args: %s", args)
|
logger.info("Args: %s", args)
|
||||||
|
# set wandb project
|
||||||
|
cfg_args = \
|
||||||
|
{
|
||||||
|
"path_to_vm": args.path_to_vm,
|
||||||
|
"headless": args.headless,
|
||||||
|
"action_space": args.action_space,
|
||||||
|
"observation_type": args.observation_type,
|
||||||
|
"screen_width": args.screen_width,
|
||||||
|
"screen_height": args.screen_height,
|
||||||
|
"sleep_after_execution": args.sleep_after_execution,
|
||||||
|
"max_steps": args.max_steps,
|
||||||
|
"max_trajectory_length": args.max_trajectory_length,
|
||||||
|
"model": args.model,
|
||||||
|
"temperature": args.temperature,
|
||||||
|
"top_p": args.top_p,
|
||||||
|
"max_tokens": args.max_tokens,
|
||||||
|
"stop_token": args.stop_token,
|
||||||
|
"result_dir": args.result_dir
|
||||||
|
}
|
||||||
|
|
||||||
agent = PromptAgent(
|
agent = PromptAgent(
|
||||||
model=args.model,
|
model=args.model,
|
||||||
@@ -128,8 +146,10 @@ def test(
|
|||||||
headless=args.headless,
|
headless=args.headless,
|
||||||
)
|
)
|
||||||
|
|
||||||
for domain in test_all_meta:
|
for domain in tqdm(test_all_meta, desc="Domain"):
|
||||||
for example_id in test_all_meta[domain]:
|
for example_id in tqdm(test_all_meta[domain], desc="Example", leave=False):
|
||||||
|
wandb.init(project=f"OSworld-{args.action_space}-{args.observation_type}-{args.model}", group=f"{domain}",
|
||||||
|
name=f"{example_id}")
|
||||||
# example setting
|
# example setting
|
||||||
config_file = os.path.join(args.test_config_base_dir, f"examples/{domain}/{example_id}.json")
|
config_file = os.path.join(args.test_config_base_dir, f"examples/{domain}/{example_id}.json")
|
||||||
with open(config_file, "r", encoding="utf-8") as f:
|
with open(config_file, "r", encoding="utf-8") as f:
|
||||||
@@ -141,6 +161,10 @@ def test(
|
|||||||
instruction = example["instruction"]
|
instruction = example["instruction"]
|
||||||
|
|
||||||
logger.info(f"[Instruction]: {instruction}")
|
logger.info(f"[Instruction]: {instruction}")
|
||||||
|
# wandb each example config settings
|
||||||
|
cfg_args["instruction"] = instruction
|
||||||
|
cfg_args["start_time"] = datetime.datetime.now().strftime("%Y:%m:%d-%H:%M:%S")
|
||||||
|
wandb.config.update(cfg_args)
|
||||||
|
|
||||||
example_result_dir = os.path.join(
|
example_result_dir = os.path.join(
|
||||||
args.result_dir,
|
args.result_dir,
|
||||||
@@ -151,79 +175,26 @@ def test(
|
|||||||
example_id
|
example_id
|
||||||
)
|
)
|
||||||
os.makedirs(example_result_dir, exist_ok=True)
|
os.makedirs(example_result_dir, exist_ok=True)
|
||||||
|
|
||||||
# example start running
|
# example start running
|
||||||
try:
|
try:
|
||||||
signal.alarm(time_limit)
|
lib_run_single.run_single_example(agent, env, example, max_steps, instruction, args, example_result_dir,
|
||||||
agent.reset()
|
scores)
|
||||||
obs = env.reset(task_config=example)
|
except Exception as e:
|
||||||
done = False
|
logger.error(f"Exception in {domain}/{example_id}: {e}")
|
||||||
step_idx = 0
|
wandb.log({"Exception": wandb.Table(data=[[f"Exception in {domain}/{example_id}: {e}"]], columns=["Error"])})
|
||||||
env.controller.start_recording()
|
|
||||||
|
|
||||||
while not done and step_idx < max_steps:
|
|
||||||
actions = agent.predict(
|
|
||||||
instruction,
|
|
||||||
obs
|
|
||||||
)
|
|
||||||
for action in actions:
|
|
||||||
# Capture the timestamp before executing the action
|
|
||||||
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
|
||||||
logger.info("Step %d: %s", step_idx + 1, action)
|
|
||||||
|
|
||||||
obs, reward, done, info = env.step(action, args.sleep_after_execution)
|
|
||||||
|
|
||||||
logger.info("Reward: %.2f", reward)
|
|
||||||
logger.info("Done: %s", done)
|
|
||||||
logger.info("Info: %s", info)
|
|
||||||
|
|
||||||
# Save screenshot and trajectory information
|
|
||||||
with open(os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"),
|
|
||||||
"wb") as _f:
|
|
||||||
with open(obs['screenshot'], "rb") as __f:
|
|
||||||
screenshot = __f.read()
|
|
||||||
_f.write(screenshot)
|
|
||||||
|
|
||||||
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
|
|
||||||
f.write(json.dumps({
|
|
||||||
"step_num": step_idx + 1,
|
|
||||||
"action_timestamp": action_timestamp,
|
|
||||||
"action": action,
|
|
||||||
"reward": reward,
|
|
||||||
"done": done,
|
|
||||||
"info": info,
|
|
||||||
"screenshot_file": f"step_{step_idx + 1}_{action_timestamp}.png"
|
|
||||||
}))
|
|
||||||
f.write("\n")
|
|
||||||
|
|
||||||
if done:
|
|
||||||
logger.info("The episode is done.")
|
|
||||||
break
|
|
||||||
step_idx += 1
|
|
||||||
|
|
||||||
result = env.evaluate()
|
|
||||||
logger.info("Result: %.2f", result)
|
|
||||||
scores.append(result)
|
|
||||||
env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
|
env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
|
||||||
except RuntimeError as e:
|
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
|
||||||
logger.error(f"Error in example {domain}/{example_id}: {e}")
|
f.write(json.dumps({
|
||||||
# save info of this example and then continue
|
"Error": f"Time limit exceeded in {domain}/{example_id}"
|
||||||
try:
|
}))
|
||||||
env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
|
f.write("\n")
|
||||||
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
|
# wandb settings
|
||||||
f.write(json.dumps({
|
os.mkdir(os.path.join(wandb.run.dir, "results/"))
|
||||||
"Error": f"Error in example {domain}/{example_id}: {e}",
|
for file in os.listdir(example_result_dir):
|
||||||
"step": step_idx + 1,
|
# move file to just under the root dir
|
||||||
}))
|
os.rename(os.path.join(example_result_dir, file), os.path.join(wandb.run.dir, f"./results/{file}"))
|
||||||
f.write("\n")
|
wandb.finish()
|
||||||
except Exception as new_e:
|
|
||||||
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
|
|
||||||
f.write(json.dumps({
|
|
||||||
"Error": f"Error in example {domain}/{example_id}: {e} and {new_e}",
|
|
||||||
"step": "before start recording",
|
|
||||||
}))
|
|
||||||
f.write("\n")
|
|
||||||
continue
|
|
||||||
env.close()
|
env.close()
|
||||||
logger.info(f"Average score: {sum(scores) / len(scores)}")
|
logger.info(f"Average score: {sum(scores) / len(scores)}")
|
||||||
|
|
||||||
@@ -236,9 +207,18 @@ def get_unfinished(action_space, use_model, observation_type, result_dir, total_
|
|||||||
|
|
||||||
finished = {}
|
finished = {}
|
||||||
for domain in os.listdir(target_dir):
|
for domain in os.listdir(target_dir):
|
||||||
|
finished[domain] = []
|
||||||
domain_path = os.path.join(target_dir, domain)
|
domain_path = os.path.join(target_dir, domain)
|
||||||
if os.path.isdir(domain_path):
|
if os.path.isdir(domain_path):
|
||||||
finished[domain] = os.listdir(domain_path)
|
for example_id in os.listdir(domain_path):
|
||||||
|
example_path = os.path.join(domain_path, example_id)
|
||||||
|
if os.path.isdir(example_path):
|
||||||
|
if "result.txt" not in os.listdir(example_path):
|
||||||
|
# empty all files under example_id
|
||||||
|
for file in os.listdir(example_path):
|
||||||
|
os.remove(os.path.join(example_path, file))
|
||||||
|
else:
|
||||||
|
finished[domain].append(example_id)
|
||||||
|
|
||||||
if not finished:
|
if not finished:
|
||||||
return total_file_json
|
return total_file_json
|
||||||
@@ -250,6 +230,35 @@ def get_unfinished(action_space, use_model, observation_type, result_dir, total_
|
|||||||
return total_file_json
|
return total_file_json
|
||||||
|
|
||||||
|
|
||||||
|
def get_result(action_space, use_model, observation_type, result_dir, total_file_json):
|
||||||
|
target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
|
||||||
|
if not os.path.exists(target_dir):
|
||||||
|
print("New experiment, no result yet.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
all_result = []
|
||||||
|
|
||||||
|
for domain in os.listdir(target_dir):
|
||||||
|
domain_path = os.path.join(target_dir, domain)
|
||||||
|
if os.path.isdir(domain_path):
|
||||||
|
for example_id in os.listdir(domain_path):
|
||||||
|
example_path = os.path.join(domain_path, example_id)
|
||||||
|
if os.path.isdir(example_path):
|
||||||
|
if "result.txt" in os.listdir(example_path):
|
||||||
|
# empty all files under example_id
|
||||||
|
try:
|
||||||
|
all_result.append(float(open(os.path.join(example_path, "result.txt"), "r").read()))
|
||||||
|
except:
|
||||||
|
all_result.append(0.0)
|
||||||
|
|
||||||
|
if not all_result:
|
||||||
|
print("New experiment, no result yet.")
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%")
|
||||||
|
return all_result
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
####### The complete version of the list of examples #######
|
####### The complete version of the list of examples #######
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
@@ -270,4 +279,10 @@ if __name__ == '__main__':
|
|||||||
left_info += f"{domain}: {len(test_file_list[domain])}\n"
|
left_info += f"{domain}: {len(test_file_list[domain])}\n"
|
||||||
logger.info(f"Left tasks:\n{left_info}")
|
logger.info(f"Left tasks:\n{left_info}")
|
||||||
|
|
||||||
test(args, test_all_meta)
|
# get_result(args.action_space,
|
||||||
|
# args.model,
|
||||||
|
# args.observation_type,
|
||||||
|
# args.result_dir,
|
||||||
|
# test_all_meta
|
||||||
|
# )
|
||||||
|
test(args, test_file_list)
|
||||||
|
|||||||
3
settings.json
Normal file
3
settings.json
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
{
|
||||||
|
"time_limit": "10"
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user