Merge branch 'main' of github.com:ztjhz/DesktopEnv

This commit is contained in:
David Chang
2024-03-17 23:04:12 +08:00
12 changed files with 501 additions and 186 deletions

4
.vscode/launch.json vendored
View File

@@ -11,8 +11,8 @@
"program": "${file}", "program": "${file}",
"console": "integratedTerminal", "console": "integratedTerminal",
"args": [ "args": [
"--path_to_vm", "/Users/lxc/Virtual Machines.localized/DesktopEnv-Ubuntu 64-bit Arm.vmwarevm/DesktopEnv-Ubuntu 64-bit Arm.vmx", "--path_to_vm", "/Users/lxc/Virtual Machines.localized/DesktopEnv-Ubuntu 64-bit Arm.vmwarevm/DesktopEnv-Ubuntu 64-bit Arm.vmx"
"--example_time_limit", "60" // "--example_time_limit", "60"
] ]
} }
] ]

16
demo.py
View File

@@ -1,16 +0,0 @@
# Minimal demo of interrupting a blocking call with SIGALRM.
# NOTE(review): standalone demo script with an infinite loop; it is not
# imported by anything else in the project.
import signal
import time


def handler(signo, frame):
    # SIGALRM handler: convert the alarm into an exception so the blocking
    # time.sleep() below is interrupted.
    raise RuntimeError("Timeout")


signal.signal(signal.SIGALRM, handler)

while True:
    try:
        signal.alarm(5)  # seconds until SIGALRM fires
        # Sleeps longer than the alarm window, so each iteration is expected
        # to be interrupted by the handler before "Working..." prints.
        time.sleep(10)
        print("Working...")
    except Exception as e :
        # The RuntimeError raised by the handler lands here; print and retry.
        print(e)
        continue

View File

@@ -263,16 +263,19 @@ class PythonController:
""" """
Ends recording the screen. Ends recording the screen.
""" """
response = requests.post(self.http_server + "/end_recording") try:
if response.status_code == 200: response = requests.post(self.http_server + "/end_recording")
logger.info("Recording stopped successfully") if response.status_code == 200:
with open(dest, 'wb') as f: logger.info("Recording stopped successfully")
for chunk in response.iter_content(chunk_size=8192): with open(dest, 'wb') as f:
if chunk: for chunk in response.iter_content(chunk_size=8192):
f.write(chunk) if chunk:
else: f.write(chunk)
logger.error("Failed to stop recording. Status code: %d", response.status_code) else:
return None logger.error("Failed to stop recording. Status code: %d", response.status_code)
return None
except Exception as e:
logger.error("An error occurred while trying to download the recording: %s", e)
# Additional info # Additional info
def get_vm_platform(self): def get_vm_platform(self):

View File

@@ -9,7 +9,7 @@
"parameters": { "parameters": {
"files": [ "files": [
{ {
"url": "https://drive.usercontent.google.com/download?id=104pg3yochKyH2Uvlp3BdvKmHgYmSIESu&export=download&authuser=0&confirm=t&uuid=d1926366-4e54-4a44-8dcd-fc49ed6524d7&at=APZUnTXcBFV9kcacsA0toU83lMKJ:1706505549057d", "url": "https://drive.usercontent.google.com/download?id=1gqqY56robX1tb4YPa3Yk1d72T_k-Rgz3&export=download&authuser=0&confirm=t",
"path": "/home/user/Desktop/15-MB-docx-file-download.docx" "path": "/home/user/Desktop/15-MB-docx-file-download.docx"
} }
] ]

View File

@@ -1,7 +1,7 @@
{ {
"id": "3c8f201a-009d-4bbe-8b65-a6f8b35bb57f", "id": "3c8f201a-009d-4bbe-8b65-a6f8b35bb57f",
"snapshot": "gimp", "snapshot": "gimp",
"instruction": "Download the image from \"https://drive.google.com/uc?export=download&id=1i8j5dGS57sA07jEuPNAlQW-sn5uqUnuK\", and then use GIMP to compress it to under 600KB. Resize if needed.", "instruction": "Download the image from \"https://drive.google.com/uc?export=download&id=1i8j5dGS57sA07jEuPNAlQW-sn5uqUnuK\", and then use GIMP to compress it to under 600KB as \"compressed.jpeg\" on the Desktop. Resize if needed.",
"source": "", "source": "",
"config": [ "config": [
{ {

View File

@@ -1,13 +1,17 @@
{ {
"id": "e2392362-125e-4f76-a2ee-524b183a3412", "id": "e2392362-125e-4f76-a2ee-524b183a3412",
"snapshot": "chrome", "snapshot": "chrome",
"instruction": "I recently started using the famous personal academic homepage template from academicpages.github.io to build my own personal homepage, and I have cloned it to my local ~/Code/Website folder. According to an online tutorial, I can configure my name and contact information in the _config.yaml file. However, I am not familiar with the YAML file format. Please help me find the sections related to the name and contact information in this file and change them to Test Account and Test@gmail.com.", "instruction": "I recently started using the famous personal academic homepage template from academicpages.github.io to build my own personal homepage, and I have cloned it to my local ~/Code/Website folder. According to an online tutorial, I can configure my name and contact information in the _config.yaml file. However, I am not familiar with the YAML file format. Please help me find the sections related to the name and contact information in this file and change them to \"Test Account\" and \"Test@gmail.com\".",
"source": "authors", "source": "authors",
"config": [ "config": [
{ {
"type": "command", "type": "command",
"parameters": { "parameters": {
"command": ["mkdir", "-p", "/home/user/Code/Website"] "command": [
"mkdir",
"-p",
"/home/user/Code/Website"
]
} }
}, },
{ {
@@ -24,13 +28,22 @@
{ {
"type": "execute", "type": "execute",
"parameters": { "parameters": {
"command": ["tar", "-xJvf", ".tmp.tar.xz", "-C", "/home/user/Code/Website/"] "command": [
"tar",
"-xJvf",
".tmp.tar.xz",
"-C",
"/home/user/Code/Website/"
]
} }
}, },
{ {
"type": "launch", "type": "launch",
"parameters": { "parameters": {
"command": ["google-chrome", "--remote-debugging-port=1337"] "command": [
"google-chrome",
"--remote-debugging-port=1337"
]
} }
}, },
{ {
@@ -46,14 +59,20 @@
{ {
"type": "chrome_open_tabs", "type": "chrome_open_tabs",
"parameters": { "parameters": {
"urls_to_open": ["https://academicpages.github.io/"] "urls_to_open": [
"https://academicpages.github.io/"
]
} }
} }
], ],
"trajectory": "trajectories/e2392362-125e-4f76-a2ee-524b183a3412", "trajectory": "trajectories/e2392362-125e-4f76-a2ee-524b183a3412",
"related_apps": ["chrome", "os", "vscode"], "related_apps": [
"chrome",
"os",
"vscode"
],
"evaluator": { "evaluator": {
"postconfig":[ "postconfig": [
{ {
"type": "execute", "type": "execute",
"parameters": { "parameters": {
@@ -66,23 +85,33 @@
} }
], ],
"func": "check_json", "func": "check_json",
"options": {"is_yaml": true}, "options": {
"is_yaml": true
},
"expected": { "expected": {
"type": "rule", "type": "rule",
"rules": { "rules": {
"expect": [ "expect": [
{ {
"key": ["name"], "key": [
"name"
],
"method": "eq", "method": "eq",
"ref": "Test Account" "ref": "Test Account"
}, },
{ {
"key": ["author", "name"], "key": [
"author",
"name"
],
"method": "eq", "method": "eq",
"ref": "Test Account" "ref": "Test Account"
}, },
{ {
"key": ["author", "email"], "key": [
"author",
"email"
],
"method": "eq", "method": "eq",
"ref": "Test@gmail.com" "ref": "Test@gmail.com"
} }
@@ -95,4 +124,4 @@
"dest": "_config.yaml" "dest": "_config.yaml"
} }
} }
} }

View File

@@ -103,7 +103,6 @@
"1e8df695-bd1b-45b3-b557-e7d599cf7597", "1e8df695-bd1b-45b3-b557-e7d599cf7597",
"ecb0df7a-4e8d-4a03-b162-053391d3afaf", "ecb0df7a-4e8d-4a03-b162-053391d3afaf",
"8b1ce5f2-59d2-4dcc-b0b0-666a714b9a14", "8b1ce5f2-59d2-4dcc-b0b0-666a714b9a14",
"7b802dad-6e0f-4204-9815-d4e3f57627d8",
"a01fbce3-2793-461f-ab86-43680ccbae25", "a01fbce3-2793-461f-ab86-43680ccbae25",
"0326d92d-d218-48a8-9ca1-981cd6d064c7", "0326d92d-d218-48a8-9ca1-981cd6d064c7",
"0a2e43bf-b26c-4631-a966-af9dfa12c9e5", "0a2e43bf-b26c-4631-a966-af9dfa12c9e5",
@@ -380,7 +379,6 @@
"9439a27b-18ae-42d8-9778-5f68f891805e", "9439a27b-18ae-42d8-9778-5f68f891805e",
"ae506c68-352c-4094-9caa-ee9d42052317", "ae506c68-352c-4094-9caa-ee9d42052317",
"ea98c5d7-3cf9-4f9b-8ad3-366b58e0fcae", "ea98c5d7-3cf9-4f9b-8ad3-366b58e0fcae",
"c714dcee-cad3-4e12-8f3c-12bdcfcdb048",
"930fdb3b-11a8-46fe-9bac-577332e2640e", "930fdb3b-11a8-46fe-9bac-577332e2640e",
"276cc624-87ea-4f08-ab93-f770e3790175", "276cc624-87ea-4f08-ab93-f770e3790175",
"9d425400-e9b2-4424-9a4b-d4c7abac4140", "9d425400-e9b2-4424-9a4b-d4c7abac4140",

View File

@@ -0,0 +1,102 @@
{
"chrome": [
"bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
"7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3"
],
"gimp": [
"7a4deb26-d57d-4ea9-9a73-630f66a7b568",
"554785e9-4523-4e7a-b8e1-8016f565f56a"
],
"libreoffice_calc": [
"357ef137-7eeb-4c80-a3bb-0951f26a8aff",
"42e0a640-4f19-4b28-973d-729602b5a4a7"
],
"libreoffice_impress": [
"5d901039-a89c-4bfb-967b-bf66f4df075e",
"550ce7e7-747b-495f-b122-acdc4d0b8e54"
],
"libreoffice_writer": [
"0810415c-bde4-4443-9047-d5f70165a697",
"0a0faba3-5580-44df-965d-f562a99b291c"
],
"multi_apps": [
"2b9493d7-49b8-493a-a71b-56cd1f4d6908",
"46407397-a7d5-4c6b-92c6-dbe038b1457b",
"4e9f0faf-2ecc-4ae8-a804-28c9a75d1ddc",
"510f64c8-9bcc-4be1-8d30-638705850618",
"897e3b53-5d4d-444b-85cb-2cdc8a97d903",
"c867c42d-a52d-4a24-8ae3-f75d256b5618",
"e135df7c-7687-4ac0-a5f0-76b74438b53e",
"f7dfbef3-7697-431c-883a-db8583a4e4f9",
"6d72aad6-187a-4392-a4c4-ed87269c51cf",
"f918266a-b3e0-4914-865d-4faa564f1aef",
"da52d699-e8d2-4dc5-9191-a2199e0b6a9b",
"74d5859f-ed66-4d3e-aa0e-93d7a592ce41",
"b5062e3e-641c-4e3a-907b-ac864d2e7652",
"48d05431-6cd5-4e76-82eb-12b60d823f7d",
"eb303e01-261e-4972-8c07-c9b4e7a4922a",
"d1acdb87-bb67-4f30-84aa-990e56a09c92",
"deec51c9-3b1e-4b9e-993c-4776f20e8bb2",
"8e116af7-7db7-4e35-a68b-b0939c066c78",
"185f29bd-5da0-40a6-b69c-ba7f4e0324ef",
"2c1ebcd7-9c6d-4c9a-afad-900e381ecd5e",
"3a93cae4-ad3e-403e-8c12-65303b271818",
"1f18aa87-af6f-41ef-9853-cdb8f32ebdea",
"26150609-0da3-4a7d-8868-0faf9c5f01bb",
"7e287123-70ca-47b9-8521-47db09b69b14",
"e2392362-125e-4f76-a2ee-524b183a3412",
"26660ad1-6ebb-4f59-8cba-a8432dfe8d38",
"a82b78bb-7fde-4cb3-94a4-035baf10bcf0",
"36037439-2044-4b50-b9d1-875b5a332143",
"716a6079-22da-47f1-ba73-c9d58f986a38",
"a74b607e-6bb5-4ea8-8a7c-5d97c7bbcd2a",
"6f4073b8-d8ea-4ade-8a18-c5d1d5d5aa9a",
"da922383-bfa4-4cd3-bbad-6bebab3d7742",
"2373b66a-092d-44cb-bfd7-82e86e7a3b4d",
"81c425f5-78f3-4771-afd6-3d2973825947",
"227d2f97-562b-4ccb-ae47-a5ec9e142fbb",
"20236825-b5df-46e7-89bf-62e1d640a897",
"02ce9a50-7af2-47ed-8596-af0c230501f8",
"4c26e3f3-3a14-4d86-b44a-d3cedebbb487",
"09a37c51-e625-49f4-a514-20a773797a8a",
"3e3fc409-bff3-4905-bf16-c968eee3f807",
"415ef462-bed3-493a-ac36-ca8c6d23bf1b",
"9f3bb592-209d-43bc-bb47-d77d9df56504",
"dd60633f-2c72-42ba-8547-6f2c8cb0fdb0",
"3f05f3b9-29ba-4b6b-95aa-2204697ffc06",
"f8369178-fafe-40c2-adc4-b9b08a125456",
"778efd0a-153f-4842-9214-f05fc176b877",
"47f7c0ce-a5fb-4100-a5e6-65cd0e7429e5",
"c2751594-0cd5-4088-be1b-b5f2f9ec97c4",
"48c46dc7-fe04-4505-ade7-723cba1aa6f6",
"42d25c08-fb87-4927-8b65-93631280a26f",
"bb7db4c2-30b5-4be7-8dd7-b8c4ec7d3108",
"3c8f201a-009d-4bbe-8b65-a6f8b35bb57f",
"d68204bf-11c1-4b13-b48b-d303c73d4bf6",
"91190194-f406-4cd6-b3f9-c43fac942b22",
"7f35355e-02a6-45b5-b140-f0be698bcf85",
"98e8e339-5f91-4ed2-b2b2-12647cb134f4",
"df67aebb-fb3a-44fd-b75b-51b6012df509",
"5df7b33a-9f77-4101-823e-02f863e1c1ae",
"22a4636f-8179-4357-8e87-d1743ece1f81",
"236833a3-5704-47fc-888c-4f298f09f799"
],
"os": [
"5ea617a3-0e86-4ba6-aab2-dac9aa2e8d57",
"5812b315-e7bd-4265-b51f-863c02174c28",
"43c2d64c-bab5-4dcb-a30c-b888321c319a",
"7688b85f-87a4-4e4a-b2f8-f3d6c3f29b82"
],
"thunderbird": [
"bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
"7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3"
],
"vlc": [
"59f21cfb-0120-4326-b255-a5b827b38967",
"8f080098-ddb1-424c-b438-4e96e5e4786e"
],
"vs_code": [
"0ed39f63-6049-43d4-ba4d-5fa2fe04a951",
"53ad5833-3455-407b-bbc6-45b4c79ab8fb"
]
}

72
lib_run_single.py Normal file
View File

@@ -0,0 +1,72 @@
import datetime
import json
import logging
import os

import wandb
# Import only the decorator we use instead of `import *`, which pollutes the
# module namespace with everything wrapt_timeout_decorator exports.
from wrapt_timeout_decorator import timeout

logger = logging.getLogger("desktopenv.experiment")

# Load the per-example wall-clock limit (seconds) from the experiment settings.
with open("./settings.json", "r") as file:
    data = json.load(file)

# settings.json may store the limit as a string (the shipped file has "10");
# normalize to a number so @timeout always receives plain seconds rather than
# relying on the decorator's string-evaluation behavior.
time_limit = float(data["time_limit"])
@timeout(time_limit, use_signals=False)
def run_single_example(agent, env, example, max_steps, instruction, args, example_result_dir, scores):
    """Run one benchmark example end-to-end and record its results.

    Resets the agent and environment, steps the agent until the episode is
    done or ``max_steps`` is reached, saves a screenshot and a traj.jsonl
    entry per action, mirrors everything to wandb, then evaluates the episode
    and appends the score to ``scores``. Aborted by the module-level
    ``@timeout`` if the example exceeds ``time_limit`` seconds.

    Args:
        agent: prompt agent exposing ``reset()`` and ``predict()``.
        env: desktop environment exposing ``reset()``, ``step()``,
            ``evaluate()`` and a ``controller`` used for screen recording.
        example: task configuration dict for this example.
        max_steps: maximum number of agent decision steps.
        instruction: natural-language task instruction given to the agent.
        args: CLI namespace; only ``sleep_after_execution`` is read here.
        example_result_dir: directory receiving screenshots, traj.jsonl,
            result.txt and recording.mp4.
        scores: mutable list the final evaluation result is appended to.
    """
    agent.reset()
    obs = env.reset(task_config=example)
    done = False
    step_idx = 0
    env.controller.start_recording()
    # One wandb table row per executed action ("Model" typo fixed).
    str_table = wandb.Table(columns=["Screenshot", "A11T", "Model Response", "Action", "Action timestamp", "Done"])
    while not done and step_idx < max_steps:
        response, actions = agent.predict(
            instruction,
            obs
        )
        for action in actions:
            # Capture the timestamp before executing the action so the
            # screenshot filename matches the moment the action was issued.
            action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
            logger.info("Step %d: %s", step_idx + 1, action)
            obs, reward, done, info = env.step(action, args.sleep_after_execution)

            logger.info("Reward: %.2f", reward)
            logger.info("Done: %s", done)
            logger.info("Info: %s", info)

            # Save screenshot and trajectory information; build the filename
            # once instead of re-formatting it at every use site.
            screenshot_file = f"step_{step_idx + 1}_{action_timestamp}.png"
            with open(os.path.join(example_result_dir, screenshot_file), "wb") as _f:
                with open(obs['screenshot'], "rb") as __f:
                    _f.write(__f.read())

            # Fetch the accessibility tree after the step so the logged row
            # reflects the post-action UI state.
            thisrun_a11tree = env.controller.get_accessibility_tree()
            str_table.add_data(
                wandb.Image(data_or_path=os.path.join(example_result_dir, screenshot_file),
                            caption=f"step_{step_idx + 1}_{action_timestamp}"),
                thisrun_a11tree,
                response, action, action_timestamp, done)
            wandb.log({"Reward": reward})

            with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
                f.write(json.dumps({
                    "step_num": step_idx + 1,
                    "action_timestamp": action_timestamp,
                    "action": action,
                    "reward": reward,
                    "done": done,
                    "info": info,
                    "screenshot_file": screenshot_file
                }))
                f.write("\n")
            if done:
                logger.info("The episode is done.")
                break
        step_idx += 1
    wandb.log({"str_trajectory": str_table})
    result = env.evaluate()
    logger.info("Result: %.2f", result)
    scores.append(result)
    with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f:
        f.write(f"{result}\n")
    env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
    wandb.log({"Result": result})

View File

@@ -5,19 +5,21 @@ import os
import re import re
import time import time
import uuid import uuid
import openai
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from http import HTTPStatus from http import HTTPStatus
from io import BytesIO from io import BytesIO
from typing import Dict, List from typing import Dict, List
from google.api_core.exceptions import InvalidArgument
import backoff import backoff
import dashscope import dashscope
import google.generativeai as genai import google.generativeai as genai
import openai
import requests import requests
import wandb
from PIL import Image from PIL import Image
from google.api_core.exceptions import InvalidArgument
from mm_agents.accessibility_tree_wrap.heuristic_retrieve import find_leaf_nodes, filter_nodes, draw_bounding_boxes from mm_agents.accessibility_tree_wrap.heuristic_retrieve import filter_nodes, draw_bounding_boxes
from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION, \ from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION, \
SYS_PROMPT_IN_A11Y_OUT_CODE, SYS_PROMPT_IN_A11Y_OUT_ACTION, \ SYS_PROMPT_IN_A11Y_OUT_CODE, SYS_PROMPT_IN_A11Y_OUT_ACTION, \
SYS_PROMPT_IN_BOTH_OUT_CODE, SYS_PROMPT_IN_BOTH_OUT_ACTION, \ SYS_PROMPT_IN_BOTH_OUT_CODE, SYS_PROMPT_IN_BOTH_OUT_ACTION, \
@@ -422,7 +424,6 @@ class PromptAgent:
# with open("messages.json", "w") as f: # with open("messages.json", "w") as f:
# f.write(json.dumps(messages, indent=4)) # f.write(json.dumps(messages, indent=4))
logger.info("Generating content with GPT model: %s", self.model)
response = self.call_llm({ response = self.call_llm({
"model": self.model, "model": self.model,
"messages": messages, "messages": messages,
@@ -441,7 +442,7 @@ class PromptAgent:
actions = None actions = None
self.thoughts.append("") self.thoughts.append("")
return actions return response, actions
@backoff.on_exception( @backoff.on_exception(
backoff.expo, backoff.expo,
@@ -461,7 +462,7 @@ class PromptAgent:
"Content-Type": "application/json", "Content-Type": "application/json",
"Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}" "Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}"
} }
# logger.info("Generating content with GPT model: %s", self.model) logger.info("Generating content with GPT model: %s", self.model)
response = requests.post( response = requests.post(
"https://api.openai.com/v1/chat/completions", "https://api.openai.com/v1/chat/completions",
headers=headers, headers=headers,
@@ -488,55 +489,162 @@ class PromptAgent:
else: else:
return response.json()['choices'][0]['message']['content'] return response.json()['choices'][0]['message']['content']
# elif self.model.startswith("mistral"): elif self.model.startswith("claude"):
# print("Call mistral") messages = payload["messages"]
# messages = payload["messages"] max_tokens = payload["max_tokens"]
# max_tokens = payload["max_tokens"] top_p = payload["top_p"]
# temperature = payload["temperature"]
# misrtal_messages = []
# claude_messages = []
# for i, message in enumerate(messages):
# mistral_message = { for i, message in enumerate(messages):
# "role": message["role"], claude_message = {
# "content": [] "role": message["role"],
# } "content": []
# }
# for part in message["content"]: assert len(message["content"]) in [1, 2], "One text, or one text with one image"
# mistral_message['content'] = part['text'] if part['type'] == "text" else None for part in message["content"]:
#
# misrtal_messages.append(mistral_message) if part['type'] == "image_url":
# image_source = {}
# # the mistral not support system message in our endpoint, so we concatenate it at the first user message image_source["type"] = "base64"
# if misrtal_messages[0]['role'] == "system": image_source["media_type"] = "image/png"
# misrtal_messages[1]['content'] = misrtal_messages[0]['content'] + "\n" + misrtal_messages[1]['content'] image_source["data"] = part['image_url']['url'].replace("data:image/png;base64,", "")
# misrtal_messages.pop(0) claude_message['content'].append({"type": "image", "source": image_source})
#
# # openai.api_base = "http://localhost:8000/v1" if part['type'] == "text":
# # openai.api_key = "test" claude_message['content'].append({"type": "text", "text": part['text']})
# # response = openai.ChatCompletion.create(
# # messages=misrtal_messages, claude_messages.append(claude_message)
# # model="Mixtral-8x7B-Instruct-v0.1"
# # ) # the claude not support system message in our endpoint, so we concatenate it at the first user message
# if claude_messages[0]['role'] == "system":
# from openai import OpenAI claude_system_message_item = claude_messages[0]['content'][0]
# TOGETHER_API_KEY = "d011650e7537797148fb6170ec1e0be7ae75160375686fae02277136078e90d2" claude_messages[1]['content'].insert(0, claude_system_message_item)
# claude_messages.pop(0)
# client = OpenAI(api_key=TOGETHER_API_KEY,
# base_url='https://api.together.xyz', headers = {
# ) "x-api-key": os.environ["ANTHROPIC_API_KEY"],
# logger.info("Generating content with Mistral model: %s", self.model) "anthropic-version": "2023-06-01",
# response = client.chat.completions.create( "content-type": "application/json"
# messages=misrtal_messages, }
# model="mistralai/Mixtral-8x7B-Instruct-v0.1",
# max_tokens=1024 payload = {
# ) "model": self.model,
# "max_tokens": max_tokens,
# try: "messages": claude_messages
# # return response['choices'][0]['message']['content'] }
# return response.choices[0].message.content
# except Exception as e: response = requests.post(
# print("Failed to call LLM: " + str(e)) "https://api.anthropic.com/v1/messages",
# return "" headers=headers,
json=payload
)
if response.status_code != 200:
logger.error("Failed to call LLM: " + response.text)
time.sleep(5)
return ""
else:
return response.json()['content'][0]['text']
elif self.model.startswith("mistral"):
print("Call mistral")
messages = payload["messages"]
max_tokens = payload["max_tokens"]
top_p = payload["top_p"]
temperature = payload["temperature"]
misrtal_messages = []
for i, message in enumerate(messages):
mistral_message = {
"role": message["role"],
"content": ""
}
for part in message["content"]:
mistral_message['content'] = part['text'] if part['type'] == "text" else ""
misrtal_messages.append(mistral_message)
# openai.api_base = "http://localhost:8000/v1"
# response = openai.ChatCompletion.create(
# messages=misrtal_messages,
# model="Mixtral-8x7B-Instruct-v0.1"
# )
from openai import OpenAI
client = OpenAI(api_key=os.environ["TOGETHER_API_KEY"],
base_url='https://api.together.xyz',
)
logger.info("Generating content with Mistral model: %s", self.model)
response = client.chat.completions.create(
messages=misrtal_messages,
model=self.model,
max_tokens=max_tokens
)
try:
return response.choices[0].message.content
except Exception as e:
print("Failed to call LLM: " + str(e))
return ""
elif self.model.startswith("THUDM"):
# THUDM/cogagent-chat-hf
print("Call CogAgent")
messages = payload["messages"]
max_tokens = payload["max_tokens"]
top_p = payload["top_p"]
temperature = payload["temperature"]
cog_messages = []
for i, message in enumerate(messages):
cog_message = {
"role": message["role"],
"content": []
}
for part in message["content"]:
if part['type'] == "image_url":
cog_message['content'].append({"type": "image_url", "image_url": {"url": part['image_url']['url'] } })
if part['type'] == "text":
cog_message['content'].append({"type": "text", "text": part['text']})
cog_messages.append(cog_message)
# the cogagent not support system message in our endpoint, so we concatenate it at the first user message
if cog_messages[0]['role'] == "system":
cog_system_message_item = cog_messages[0]['content'][0]
cog_messages[1]['content'].insert(0, cog_system_message_item)
cog_messages.pop(0)
payload = {
"model": self.model,
"max_tokens": max_tokens,
"messages": cog_messages
}
base_url = "http://127.0.0.1:8000"
response = requests.post(f"{base_url}/v1/chat/completions", json=payload, stream=False)
if response.status_code == 200:
decoded_line = response.json()
content = decoded_line.get("choices", [{}])[0].get("message", "").get("content", "")
return content
else:
print("Failed to call LLM: ", response.status_code)
return ""
elif self.model.startswith("gemini"): elif self.model.startswith("gemini"):
def encoded_img_to_pil_img(data_str): def encoded_img_to_pil_img(data_str):
@@ -612,6 +720,7 @@ class PromptAgent:
try: try:
return response.text return response.text
except Exception as e: except Exception as e:
logger.error("Meet exception when calling Gemini API, " + str(e))
return "" return ""
elif self.model.startswith("qwen"): elif self.model.startswith("qwen"):
messages = payload["messages"] messages = payload["messages"]

189
run.py
View File

@@ -6,13 +6,17 @@ import datetime
import json import json
import logging import logging
import os import os
import signal import random
import sys import sys
import wandb
from tqdm import tqdm
import lib_run_single
from desktop_env.envs.desktop_env import DesktopEnv from desktop_env.envs.desktop_env import DesktopEnv
from mm_agents.agent import PromptAgent from mm_agents.agent import PromptAgent
# Logger Configs {{{ # # Logger Configs {{{ #
logger = logging.getLogger() logger = logging.getLogger()
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
@@ -46,13 +50,10 @@ logger.addHandler(sdebug_handler)
logger = logging.getLogger("desktopenv.experiment") logger = logging.getLogger("desktopenv.experiment")
# wandb config
# make sure each example won't exceed the time limit ### set your wandb api key here
def handler(signo, frame): os.environ["WANDB_API_KEY"] = ""
raise RuntimeError("Time limit exceeded!") wandb.login(key=os.environ["WANDB_API_KEY"])
signal.signal(signal.SIGALRM, handler)
def config() -> argparse.Namespace: def config() -> argparse.Namespace:
@@ -75,7 +76,7 @@ def config() -> argparse.Namespace:
"screenshot_a11y_tree", "screenshot_a11y_tree",
"som" "som"
], ],
default="som", default="a11y_tree",
help="Observation type", help="Observation type",
) )
parser.add_argument("--screen_width", type=int, default=1920) parser.add_argument("--screen_width", type=int, default=1920)
@@ -86,10 +87,9 @@ def config() -> argparse.Namespace:
# agent config # agent config
parser.add_argument("--max_trajectory_length", type=int, default=3) parser.add_argument("--max_trajectory_length", type=int, default=3)
parser.add_argument("--test_config_base_dir", type=str, default="evaluation_examples") parser.add_argument("--test_config_base_dir", type=str, default="evaluation_examples")
parser.add_argument("--example_time_limit", type=int, default=600)
# lm config # lm config
parser.add_argument("--model", type=str, default="gpt-4-vision-preview") parser.add_argument("--model", type=str, default="gpt-4-0125-preview")
parser.add_argument("--temperature", type=float, default=1.0) parser.add_argument("--temperature", type=float, default=1.0)
parser.add_argument("--top_p", type=float, default=0.9) parser.add_argument("--top_p", type=float, default=0.9)
parser.add_argument("--max_tokens", type=int, default=1500) parser.add_argument("--max_tokens", type=int, default=1500)
@@ -108,10 +108,28 @@ def test(
) -> None: ) -> None:
scores = [] scores = []
max_steps = args.max_steps max_steps = args.max_steps
time_limit = args.example_time_limit
# log args # log args
logger.info("Args: %s", args) logger.info("Args: %s", args)
# set wandb project
cfg_args = \
{
"path_to_vm": args.path_to_vm,
"headless": args.headless,
"action_space": args.action_space,
"observation_type": args.observation_type,
"screen_width": args.screen_width,
"screen_height": args.screen_height,
"sleep_after_execution": args.sleep_after_execution,
"max_steps": args.max_steps,
"max_trajectory_length": args.max_trajectory_length,
"model": args.model,
"temperature": args.temperature,
"top_p": args.top_p,
"max_tokens": args.max_tokens,
"stop_token": args.stop_token,
"result_dir": args.result_dir
}
agent = PromptAgent( agent = PromptAgent(
model=args.model, model=args.model,
@@ -128,8 +146,10 @@ def test(
headless=args.headless, headless=args.headless,
) )
for domain in test_all_meta: for domain in tqdm(test_all_meta, desc="Domain"):
for example_id in test_all_meta[domain]: for example_id in tqdm(test_all_meta[domain], desc="Example", leave=False):
wandb.init(project=f"OSworld-{args.action_space}-{args.observation_type}-{args.model}", group=f"{domain}",
name=f"{example_id}")
# example setting # example setting
config_file = os.path.join(args.test_config_base_dir, f"examples/{domain}/{example_id}.json") config_file = os.path.join(args.test_config_base_dir, f"examples/{domain}/{example_id}.json")
with open(config_file, "r", encoding="utf-8") as f: with open(config_file, "r", encoding="utf-8") as f:
@@ -141,6 +161,10 @@ def test(
instruction = example["instruction"] instruction = example["instruction"]
logger.info(f"[Instruction]: {instruction}") logger.info(f"[Instruction]: {instruction}")
# wandb each example config settings
cfg_args["instruction"] = instruction
cfg_args["start_time"] = datetime.datetime.now().strftime("%Y:%m:%d-%H:%M:%S")
wandb.config.update(cfg_args)
example_result_dir = os.path.join( example_result_dir = os.path.join(
args.result_dir, args.result_dir,
@@ -151,79 +175,26 @@ def test(
example_id example_id
) )
os.makedirs(example_result_dir, exist_ok=True) os.makedirs(example_result_dir, exist_ok=True)
# example start running # example start running
try: try:
signal.alarm(time_limit) lib_run_single.run_single_example(agent, env, example, max_steps, instruction, args, example_result_dir,
agent.reset() scores)
obs = env.reset(task_config=example) except Exception as e:
done = False logger.error(f"Exception in {domain}/{example_id}: {e}")
step_idx = 0 wandb.log({"Exception": wandb.Table(data=[[f"Exception in {domain}/{example_id}: {e}"]], columns=["Error"])})
env.controller.start_recording()
while not done and step_idx < max_steps:
actions = agent.predict(
instruction,
obs
)
for action in actions:
# Capture the timestamp before executing the action
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
logger.info("Step %d: %s", step_idx + 1, action)
obs, reward, done, info = env.step(action, args.sleep_after_execution)
logger.info("Reward: %.2f", reward)
logger.info("Done: %s", done)
logger.info("Info: %s", info)
# Save screenshot and trajectory information
with open(os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"),
"wb") as _f:
with open(obs['screenshot'], "rb") as __f:
screenshot = __f.read()
_f.write(screenshot)
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
f.write(json.dumps({
"step_num": step_idx + 1,
"action_timestamp": action_timestamp,
"action": action,
"reward": reward,
"done": done,
"info": info,
"screenshot_file": f"step_{step_idx + 1}_{action_timestamp}.png"
}))
f.write("\n")
if done:
logger.info("The episode is done.")
break
step_idx += 1
result = env.evaluate()
logger.info("Result: %.2f", result)
scores.append(result)
env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4")) env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
except RuntimeError as e: with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
logger.error(f"Error in example {domain}/{example_id}: {e}") f.write(json.dumps({
# save info of this example and then continue "Error": f"Time limit exceeded in {domain}/{example_id}"
try: }))
env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4")) f.write("\n")
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f: # wandb settings
f.write(json.dumps({ os.mkdir(os.path.join(wandb.run.dir, "results/"))
"Error": f"Error in example {domain}/{example_id}: {e}", for file in os.listdir(example_result_dir):
"step": step_idx + 1, # move file to just under the root dir
})) os.rename(os.path.join(example_result_dir, file), os.path.join(wandb.run.dir, f"./results/{file}"))
f.write("\n") wandb.finish()
except Exception as new_e:
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
f.write(json.dumps({
"Error": f"Error in example {domain}/{example_id}: {e} and {new_e}",
"step": "before start recording",
}))
f.write("\n")
continue
env.close() env.close()
logger.info(f"Average score: {sum(scores) / len(scores)}") logger.info(f"Average score: {sum(scores) / len(scores)}")
@@ -236,9 +207,18 @@ def get_unfinished(action_space, use_model, observation_type, result_dir, total_
finished = {} finished = {}
for domain in os.listdir(target_dir): for domain in os.listdir(target_dir):
finished[domain] = []
domain_path = os.path.join(target_dir, domain) domain_path = os.path.join(target_dir, domain)
if os.path.isdir(domain_path): if os.path.isdir(domain_path):
finished[domain] = os.listdir(domain_path) for example_id in os.listdir(domain_path):
example_path = os.path.join(domain_path, example_id)
if os.path.isdir(example_path):
if "result.txt" not in os.listdir(example_path):
# empty all files under example_id
for file in os.listdir(example_path):
os.remove(os.path.join(example_path, file))
else:
finished[domain].append(example_id)
if not finished: if not finished:
return total_file_json return total_file_json
@@ -250,6 +230,35 @@ def get_unfinished(action_space, use_model, observation_type, result_dir, total_
return total_file_json return total_file_json
def get_result(action_space, use_model, observation_type, result_dir, total_file_json):
target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
if not os.path.exists(target_dir):
print("New experiment, no result yet.")
return None
all_result = []
for domain in os.listdir(target_dir):
domain_path = os.path.join(target_dir, domain)
if os.path.isdir(domain_path):
for example_id in os.listdir(domain_path):
example_path = os.path.join(domain_path, example_id)
if os.path.isdir(example_path):
if "result.txt" in os.listdir(example_path):
# empty all files under example_id
try:
all_result.append(float(open(os.path.join(example_path, "result.txt"), "r").read()))
except:
all_result.append(0.0)
if not all_result:
print("New experiment, no result yet.")
return None
else:
print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%")
return all_result
if __name__ == '__main__': if __name__ == '__main__':
####### The complete version of the list of examples ####### ####### The complete version of the list of examples #######
os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -270,4 +279,10 @@ if __name__ == '__main__':
left_info += f"{domain}: {len(test_file_list[domain])}\n" left_info += f"{domain}: {len(test_file_list[domain])}\n"
logger.info(f"Left tasks:\n{left_info}") logger.info(f"Left tasks:\n{left_info}")
test(args, test_all_meta) # get_result(args.action_space,
# args.model,
# args.observation_type,
# args.result_dir,
# test_all_meta
# )
test(args, test_file_list)

3
settings.json Normal file
View File

@@ -0,0 +1,3 @@
{
    "time_limit": 10
}