rhythmcao
2024-03-15 22:09:44 +08:00
6 changed files with 198 additions and 83 deletions

.vscode/launch.json vendored Normal file (+19)

@@ -0,0 +1,19 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python Debugger: Current File with Arguments",
"type": "debugpy",
"request": "launch",
"program": "${file}",
"console": "integratedTerminal",
"args": [
"--path_to_vm", "/Users/lxc/Virtual Machines.localized/DesktopEnv-Ubuntu 64-bit Arm.vmwarevm/DesktopEnv-Ubuntu 64-bit Arm.vmx",
"--example_time_limit", "60"
]
}
]
}
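For reference, the same launch can be reproduced outside VS Code; a minimal sketch, assuming `${file}` resolves to the repository's run.py (which accepts both flags in this commit):

```python
# Sketch only: mirrors the launch.json arguments above, assuming "${file}" is run.py.
import subprocess
import sys

subprocess.run([
    sys.executable, "run.py",
    "--path_to_vm", "/Users/lxc/Virtual Machines.localized/DesktopEnv-Ubuntu 64-bit Arm.vmwarevm/DesktopEnv-Ubuntu 64-bit Arm.vmx",
    "--example_time_limit", "60",
], check=True)
```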

README.md

@@ -21,10 +21,12 @@
 Please refer to [guidance](https://docs.google.com/document/d/1KBdeZwmZs2Vi_Wsnngb3Wf1-RiwMMpXTftwMqP2Ztak/edit#heading=h.uh0x0tkl7fuw)
 2. Install the environment package, download the examples and the virtual machine image.
+For x86_64 Linux or Windows, you can install the environment package and download the examples and the virtual machine image by running the following commands:
 ```bash
 pip install desktop-env
 gdown xxxx
-gdown xxxx
+vmrun -T ws start "Ubuntu/Ubuntu.vmx" nogui
+vmrun -T ws snapshot "Ubuntu/Ubuntu.vmx" "init_state"
 ```
 ## Quick Start
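The two new `vmrun` lines boot the image headless and snapshot it as `init_state` so later runs can restore a clean desktop. The same bootstrap can be scripted; a minimal sketch, assuming `vmrun` (VMware's CLI) is on PATH and the image was downloaded to `Ubuntu/Ubuntu.vmx`:

```python
# Sketch only (not part of the commit): the README's VM bootstrap from Python.
import subprocess

vmx = "Ubuntu/Ubuntu.vmx"
subprocess.run(["vmrun", "-T", "ws", "start", vmx, "nogui"], check=True)          # boot headless
subprocess.run(["vmrun", "-T", "ws", "snapshot", vmx, "init_state"], check=True)  # snapshot the clean state
```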

demo.py Normal file (+16)

@@ -0,0 +1,16 @@
import signal
import time


# Demo of the per-example time limit used in run.py: SIGALRM fires after 5 s
# and the handler turns it into a RuntimeError, interrupting the 10 s sleep.
def handler(signo, frame):
    raise RuntimeError("Timeout")


signal.signal(signal.SIGALRM, handler)

while True:
    try:
        signal.alarm(5)  # seconds until SIGALRM
        time.sleep(10)   # always interrupted after 5 s
        print("Working...")
    except Exception as e:
        print(e)
        continue
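As written, demo.py prints "Timeout" roughly every five seconds: the 5-second alarm always interrupts the 10-second sleep, so the print("Working...") line is never reached. This is the same SIGALRM pattern (Unix-only) that the commit wires into run.py below.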

Normal file (+19)

@@ -0,0 +1,19 @@
import pandas as pd

file_path = "/Users/lxc/Downloads/Speedtest.csv"

# Find the value in the second cell of the second row of the CSV
# with open(file_path, "r") as f:
#     for i, line in enumerate(f):
#         if i == 1:
#             data = line.split(",")[1]
#             break
#     print(data)

with open(file_path, "r") as f:
    # keep the header row: with header=None the columns are integers and
    # reader['TEST_DATE'] would raise a KeyError
    reader = pd.read_csv(f, sep=',')
    # for column in reader.columns:
    #     if column.startswith("TEST_DATE"):
    #         data_col = column
    #         break
    for data in reader['TEST_DATE']:
        print(data)
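If the goal stated in the translated comment is just the single cell (second row, second column), `iloc` is a shorter route; a sketch, assuming the CSV's first line is a header row:

```python
import pandas as pd

df = pd.read_csv("/Users/lxc/Downloads/Speedtest.csv")  # first line treated as header
print(df.iloc[0, 1])  # second file row = first data row; second column
```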

mm_agents/agent.py

@@ -5,21 +5,17 @@ import os
 import re
 import time
 import uuid
+import openai
 import xml.etree.ElementTree as ET
 from http import HTTPStatus
 from io import BytesIO
 from typing import Dict, List
+from google.api_core.exceptions import InvalidArgument
 import backoff
 import dashscope
 import google.generativeai as genai
 import requests
 from PIL import Image
-from vertexai.preview.generative_models import (
-    HarmBlockThreshold,
-    HarmCategory,
-    Image,
-)
 from mm_agents.accessibility_tree_wrap.heuristic_retrieve import find_leaf_nodes, filter_nodes, draw_bounding_boxes
 from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION, \
@@ -43,7 +39,7 @@ def linearize_accessibility_tree(accessibility_tree):
     # leaf_nodes = find_leaf_nodes(accessibility_tree)
     filtered_nodes = filter_nodes(ET.fromstring(accessibility_tree))
-    linearized_accessibility_tree = "tag\tname\ttext\tposition\tsize\n"
+    linearized_accessibility_tree = "tag\tname\ttext\tposition (top-left x&y)\tsize (w&h)\n"
     # Linearize the accessibility tree nodes into a table format
     for node in filtered_nodes:
@@ -205,7 +201,7 @@ class PromptAgent:
                 self.system_message = SYS_PROMPT_IN_A11Y_OUT_CODE
             else:
                 raise ValueError("Invalid action space: " + action_space)
-        elif observation_type == "both":
+        elif observation_type == "screenshot_a11y_tree":
             if action_space == "computer_13":
                 self.system_message = SYS_PROMPT_IN_BOTH_OUT_ACTION
             elif action_space == "pyautogui":
@@ -233,8 +229,7 @@ class PromptAgent:
""" """
Predict the next action(s) based on the current observation. Predict the next action(s) based on the current observation.
""" """
self.system_message = self.system_message + "\nYou are asked to complete the following task: {}".format( system_message = self.system_message + "\nYou are asked to complete the following task: {}".format(instruction)
instruction)
# Prepare the payload for the API call # Prepare the payload for the API call
messages = [] messages = []
@@ -245,7 +240,7 @@ class PromptAgent:
             "content": [
                 {
                     "type": "text",
-                    "text": self.system_message
+                    "text": system_message
                 },
             ]
         })
@@ -266,7 +261,7 @@ class PromptAgent:
         for previous_obs, previous_action, previous_thought in zip(_observations, _actions, _thoughts):
             # {{{1
-            if self.observation_type == "both":
+            if self.observation_type == "screenshot_a11y_tree":
                 _screenshot = previous_obs["screenshot"]
                 _linearized_accessibility_tree = previous_obs["accessibility_tree"]
                 logger.debug("LINEAR AT: %s", _linearized_accessibility_tree)
@@ -356,11 +351,11 @@ class PromptAgent:
         })
         # {{{1
-        if self.observation_type in ["screenshot", "both"]:
+        if self.observation_type in ["screenshot", "screenshot_a11y_tree"]:
             base64_image = encode_image(obs["screenshot"])
             linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])
-            if self.observation_type == "both":
+            if self.observation_type == "screenshot_a11y_tree":
                 self.observations.append({
                     "screenshot": base64_image,
                     "accessibility_tree": linearized_accessibility_tree
@@ -473,7 +468,9 @@ class PromptAgent:
         response = self.call_llm({
             "model": self.model,
             "messages": messages,
-            "max_tokens": self.max_tokens
+            "max_tokens": self.max_tokens,
+            "top_p": self.top_p,
+            "temperature": self.temperature
         })
         logger.info("RESPONSE: %s", response)
@@ -513,7 +510,7 @@ class PromptAgent:
try: try:
actions = self.parse_actions(response, masks) actions = self.parse_actions(response, masks)
self.thoughts.append(response) self.thoughts.append(response)
except Exception as e: except ValueError as e:
print("Failed to parse action from response", e) print("Failed to parse action from response", e)
actions = None actions = None
self.thoughts.append("") self.thoughts.append("")
@@ -522,9 +519,16 @@ class PromptAgent:
     @backoff.on_exception(
         backoff.expo,
-        (Exception),
+        # Add more model-specific exceptions here as needed, but never the generic
+        # Exception: the RuntimeError raised by the SIGALRM time-limit handler must
+        # propagate to the outer loop so each example stays within its time limit.
+        (openai.RateLimitError,
+         openai.BadRequestError,
+         openai.InternalServerError,
+         InvalidArgument),
         max_tries=5
     )
     def call_llm(self, payload):
         if self.model.startswith("gpt"):
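The point of the new exception tuple is that only transient, model-specific errors are retried; the RuntimeError raised by run.py's SIGALRM handler is not in the tuple and escapes immediately, so the per-example time limit still binds. A self-contained sketch of the pattern (TransientError is a hypothetical stand-in for openai.RateLimitError and friends):

```python
import backoff


class TransientError(Exception):
    """Stand-in for model-specific errors such as openai.RateLimitError."""


attempts = {"n": 0}


@backoff.on_exception(backoff.expo, (TransientError,), max_tries=5)
def flaky_call():
    attempts["n"] += 1
    if attempts["n"] < 3:
        raise TransientError("retried with exponential backoff")
    return "ok"


print(flaky_call(), "after", attempts["n"], "attempts")
# A RuntimeError (e.g. from the time-limit handler) is NOT in the tuple,
# so it would escape immediately instead of being swallowed by retries.
```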
@@ -532,7 +536,7 @@ class PromptAgent:
                 "Content-Type": "application/json",
                 "Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}"
             }
-            logger.info("Generating content with GPT model: %s", self.model)
+            # logger.info("Generating content with GPT model: %s", self.model)
             response = requests.post(
                 "https://api.openai.com/v1/chat/completions",
                 headers=headers,
@@ -542,14 +546,14 @@ class PromptAgent:
             if response.status_code != 200:
                 if response.json()['error']['code'] == "context_length_exceeded":
                     logger.error("Context length exceeded. Retrying with a smaller context.")
-                    payload["messages"] = payload["messages"][-1:]
+                    payload["messages"] = [payload["messages"][0]] + payload["messages"][-1:]
                     retry_response = requests.post(
                         "https://api.openai.com/v1/chat/completions",
                         headers=headers,
                         json=payload
                     )
                     if retry_response.status_code != 200:
-                        logger.error("Failed to call LLM: " + retry_response.text)
+                        logger.error("Failed to call LLM even after shortening the history: " + retry_response.text)
                         return ""
                 logger.error("Failed to call LLM: " + response.text)
@@ -656,8 +660,9 @@ class PromptAgent:
             for message in gemini_messages:
                 message_history_str += "<|" + message['role'] + "|>\n" + message['parts'][0] + "\n"
             gemini_messages = [{"role": "user", "parts": [message_history_str, gemini_messages[-1]['parts'][1]]}]
+            # gemini_messages[-1]['parts'][1].save("output.png", "PNG")
-        print(gemini_messages)
+        # print(gemini_messages)
         api_key = os.environ.get("GENAI_API_KEY")
         assert api_key is not None, "Please set the GENAI_API_KEY environment variable"
         genai.configure(api_key=api_key)
@@ -671,11 +676,10 @@ class PromptAgent:
"temperature": temperature "temperature": temperature
}, },
safety_settings={ safety_settings={
HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE, "harassment": "block_none",
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE, "hate": "block_none",
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE, "sex": "block_none",
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE, "danger": "block_none"
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
} }
) )
@@ -726,7 +730,7 @@ class PromptAgent:
     def parse_actions(self, response: str, masks=None):
-        if self.observation_type in ["screenshot", "a11y_tree", "both"]:
+        if self.observation_type in ["screenshot", "a11y_tree", "screenshot_a11y_tree"]:
             # parse from the response
             if self.action_space == "computer_13":
                 actions = parse_actions_from_string(response)

run.py (165 lines changed)

@@ -6,6 +6,7 @@ import datetime
 import json
 import logging
 import os
+import signal
 import sys
 from desktop_env.envs.desktop_env import DesktopEnv
@@ -46,6 +47,14 @@ logger.addHandler(sdebug_handler)
 logger = logging.getLogger("desktopenv.experiment")
+# make sure each example won't exceed the time limit
+def handler(signo, frame):
+    raise RuntimeError("Time limit exceeded!")
+
+signal.signal(signal.SIGALRM, handler)
 def config() -> argparse.Namespace:
     parser = argparse.ArgumentParser(
         description="Run end-to-end evaluation on the benchmark"
@@ -66,7 +75,7 @@ def config() -> argparse.Namespace:
             "screenshot_a11y_tree",
             "som"
         ],
-        default="a11y_tree",
+        default="som",
         help="Observation type",
     )
     parser.add_argument("--screen_width", type=int, default=1920)
@@ -77,6 +86,7 @@ def config() -> argparse.Namespace:
     # agent config
     parser.add_argument("--max_trajectory_length", type=int, default=3)
     parser.add_argument("--test_config_base_dir", type=str, default="evaluation_examples")
+    parser.add_argument("--example_time_limit", type=int, default=600)
     # lm config
     parser.add_argument("--model", type=str, default="gpt-4-vision-preview")
@@ -98,6 +108,7 @@ def test(
 ) -> None:
     scores = []
     max_steps = args.max_steps
+    time_limit = args.example_time_limit
     # log args
     logger.info("Args: %s", args)
@@ -119,6 +130,7 @@ def test(
     for domain in test_all_meta:
         for example_id in test_all_meta[domain]:
+            # example setting
             config_file = os.path.join(args.test_config_base_dir, f"examples/{domain}/{example_id}.json")
             with open(config_file, "r", encoding="utf-8") as f:
                 example = json.load(f)
@@ -140,68 +152,102 @@
             )
             os.makedirs(example_result_dir, exist_ok=True)
-            agent.reset()
-            obs = env.reset(task_config=example)
-            done = False
-            step_idx = 0
-            env.controller.start_recording()
-            while not done and step_idx < max_steps:
-                actions = agent.predict(
-                    instruction,
-                    obs
-                )
-                for action in actions:
-                    # Capture the timestamp before executing the action
-                    action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
-                    logger.info("Step %d: %s", step_idx + 1, action)
-                    observation, reward, done, info = env.step(action, args.sleep_after_execution)
-                    logger.info("Reward: %.2f", reward)
-                    logger.info("Done: %s", done)
-                    logger.info("Info: %s", info)
-                    # Save screenshot and trajectory information
-                    with open(os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"),
-                              "wb") as _f:
-                        with open(observation['screenshot'], "rb") as __f:
-                            screenshot = __f.read()
-                        _f.write(screenshot)
-                    with open(os.path.join(example_result_dir, "traj.json"), "a") as f:
-                        f.write(json.dumps({
-                            "step_num": step_idx + 1,
-                            "action_timestamp": action_timestamp,
-                            "action": action,
-                            "reward": reward,
-                            "done": done,
-                            "info": info,
-                            "screenshot_file": f"step_{step_idx + 1}_{action_timestamp}.png"
-                        }))
-                        f.write("\n")
-                    if done:
-                        logger.info("The episode is done.")
-                        break
-                step_idx += 1
-            result = env.evaluate()
-            logger.info("Result: %.2f", result)
-            scores.append(result)
-            env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
+            # example start running
+            try:
+                signal.alarm(time_limit)
+                agent.reset()
+                obs = env.reset(task_config=example)
+                done = False
+                step_idx = 0
+                env.controller.start_recording()
+                while not done and step_idx < max_steps:
+                    actions = agent.predict(
+                        instruction,
+                        obs
+                    )
+                    for action in actions:
+                        # Capture the timestamp before executing the action
+                        action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
+                        logger.info("Step %d: %s", step_idx + 1, action)
+                        obs, reward, done, info = env.step(action, args.sleep_after_execution)
+                        logger.info("Reward: %.2f", reward)
+                        logger.info("Done: %s", done)
+                        logger.info("Info: %s", info)
+                        # Save screenshot and trajectory information
+                        with open(os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"),
+                                  "wb") as _f:
+                            with open(obs['screenshot'], "rb") as __f:
+                                screenshot = __f.read()
+                            _f.write(screenshot)
+                        with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
+                            f.write(json.dumps({
+                                "step_num": step_idx + 1,
+                                "action_timestamp": action_timestamp,
+                                "action": action,
+                                "reward": reward,
+                                "done": done,
+                                "info": info,
+                                "screenshot_file": f"step_{step_idx + 1}_{action_timestamp}.png"
+                            }))
+                            f.write("\n")
+                        if done:
+                            logger.info("The episode is done.")
+                            break
+                    step_idx += 1
+                result = env.evaluate()
+                logger.info("Result: %.2f", result)
+                scores.append(result)
+                env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
+            except RuntimeError as e:
+                logger.error(f"Error in example {domain}/{example_id}: {e}")
+                # save info of this example and then continue
+                try:
+                    env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
+                    with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
+                        f.write(json.dumps({
+                            "Error": f"Error in example {domain}/{example_id}: {e}",
+                            "step": step_idx + 1,
+                        }))
+                        f.write("\n")
+                except Exception as new_e:
+                    with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
+                        f.write(json.dumps({
+                            "Error": f"Error in example {domain}/{example_id}: {e} and {new_e}",
+                            "step": "before start recording",
+                        }))
+                        f.write("\n")
+                continue
 
     env.close()
     logger.info(f"Average score: {sum(scores) / len(scores)}")
 
-def get_unfinished(test_file_list, result_dir):
-    finished = []
-    for domain in os.listdir(result_dir):
-        for example_id in os.listdir(os.path.join(result_dir, domain)):
-            finished.append(f"{domain}/{example_id}")
-    return [x for x in test_file_list if x not in finished]
+def get_unfinished(action_space, use_model, observation_type, result_dir, total_file_json):
+    target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
+    if not os.path.exists(target_dir):
+        return total_file_json
+    finished = {}
+    for domain in os.listdir(target_dir):
+        domain_path = os.path.join(target_dir, domain)
+        if os.path.isdir(domain_path):
+            finished[domain] = os.listdir(domain_path)
+    if not finished:
+        return total_file_json
+    for domain, examples in finished.items():
+        if domain in total_file_json:
+            total_file_json[domain] = [x for x in total_file_json[domain] if x not in examples]
+    return total_file_json
 
 if __name__ == '__main__':
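The rewritten get_unfinished mirrors the on-disk result layout (result_dir/action_space/observation_type/model/domain/example_id) and filters a domain-to-examples mapping rather than a flat list. A toy run of the filtering step, with made-up domains and IDs:

```python
total_file_json = {"chrome": ["a", "b", "c"], "gimp": ["x"]}
finished = {"chrome": ["b"]}  # already has a result directory on disk

for domain, examples in finished.items():
    if domain in total_file_json:
        total_file_json[domain] = [x for x in total_file_json[domain] if x not in examples]

print(total_file_json)  # {'chrome': ['a', 'c'], 'gimp': ['x']}
```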
@@ -209,10 +255,19 @@ if __name__ == '__main__':
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
     args = config()
-    # test_file_list = get_unfinished(args.test, args.result_dir)
-    # logger.info(f"Total {len(test_file_list)} tasks left")
     with open("evaluation_examples/test_all.json", "r", encoding="utf-8") as f:
         test_all_meta = json.load(f)
+    test_file_list = get_unfinished(
+        args.action_space,
+        args.model,
+        args.observation_type,
+        args.result_dir,
+        test_all_meta
+    )
+    left_info = ""
+    for domain in test_file_list:
+        left_info += f"{domain}: {len(test_file_list[domain])}\n"
+    logger.info(f"Left tasks:\n{left_info}")
     test(args, test_all_meta)
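Because get_unfinished filters total_file_json in place and returns the same dict, test_file_list and test_all_meta refer to one object here, so passing test_all_meta to test() already reflects the filtering.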