update timer

This commit is contained in:
Jason Lee
2024-03-15 23:28:45 +08:00
7 changed files with 106 additions and 147 deletions

View File

@@ -21,10 +21,12 @@
Please refer to [guidance](https://docs.google.com/document/d/1KBdeZwmZs2Vi_Wsnngb3Wf1-RiwMMpXTftwMqP2Ztak/edit#heading=h.uh0x0tkl7fuw)
2. Install the environment package, download the examples and the virtual machine image.
For x86_64 Linux or Windows, you can install the environment package and download the examples and the virtual machine image by running the following commands:
```bash
pip install desktop-env
gdown xxxx
gdown xxxx
vmrun -T ws start "Ubuntu/Ubuntu.vmx" nogui
vmrun -T ws snapshot "Ubuntu/Ubuntu.vmx" "init_state"
```
## Quick Start

View File

@@ -284,6 +284,15 @@ def _create_atspi_node(node: Accessible, depth: int = 0, flag: Optional[str] = N
text = text.replace("\ufffc", "").replace("\ufffd", "")
# }}} Text #
# Image {{{ #
try:
node.queryImage()
except NotImplementedError:
pass
else:
attribute_dict["image"] = "true"
# }}} Image #
# Selection {{{ #
try:
node.querySelection()

24
main.py
View File

@@ -4,7 +4,7 @@ import logging
import os
import sys
import time
import argparse
from desktop_env.envs.desktop_env import DesktopEnv
# Logger Configs {{{ #
@@ -46,19 +46,29 @@ def human_agent():
"""
Runs the Gym environment with human input.
"""
parser = argparse.ArgumentParser()
parser.add_argument('-p', '--path', type=str, default=r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu3\Ubuntu3.vmx", help="Path to the virtual machine .vmx file.")
parser.add_argument('-s', '--snapshot', type=str, default='init_state', help="Name of the snapshot to restore.")
parser.add_argument('-e', '--example', type=str, help="Path to the example json file.")
args = parser.parse_args(sys.argv[1:])
with open("evaluation_examples/examples/multi_apps/4c26e3f3-3a14-4d86-b44a-d3cedebbb487.json", "r", encoding="utf-8") as f:
example_path = args.example if args.example is not None and os.path.exists(args.example) else \
'evaluation_examples/examples/multi_apps/5990457f-2adb-467b-a4af-5c857c92d762.json'
with open(example_path, "r", encoding="utf-8") as f:
example = json.load(f)
example["snapshot"] = "exp_v5"
if args.snapshot is not None:
example['snapshot'] = args.snapshot
assert os.path.exists(args.path), "The specified path to the .vmx file does not exist."
env = DesktopEnv(
path_to_vm=r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu3\Ubuntu3.vmx",
action_space="computer_13",
task_config=example
path_to_vm=args.path,
snapshot_name=args.snapshot,
action_space="computer_13"
)
# reset the environment to certain snapshot
observation = env.reset()
observation = env.reset(task_config=example)
done = False
logger.info('\x1b[32m[TASK INSTRUCTION]: \x1b[32;3m%s\x1b[0m', example["instruction"])
trajectory = [
{

View File

@@ -26,7 +26,7 @@ def find_leaf_nodes(xlm_file_str):
state_ns = "uri:deskat:state.at-spi.gnome.org"
component_ns = "uri:deskat:component.at-spi.gnome.org"
def judge_node(node: ET, platform="ubuntu") -> bool:
def judge_node(node: ET, platform="ubuntu", check_image=False) -> bool:
keeps: bool = node.tag.startswith("document")\
or node.tag.endswith("item")\
or node.tag.endswith("button")\
@@ -60,18 +60,20 @@ def judge_node(node: ET, platform="ubuntu") -> bool:
or node.get("{{{:}}}expandable".format(state_ns), "false")=="true"\
or node.get("{{{:}}}checkable".format(state_ns), "false")=="true"
)\
and (node.get("name", "") != "" or node.text is not None and len(node.text)>0)
and ( node.get("name", "") != "" or node.text is not None and len(node.text)>0\
or check_image and node.get("image", "false")=="true"
)
coordinates: Tuple[int, int] = eval(node.get("{{{:}}}screencoord".format(component_ns), "(-1, -1)"))
sizes: Tuple[int, int] = eval(node.get("{{{:}}}size".format(component_ns), "(-1, -1)"))
keeps = keeps and coordinates[0]>0 and coordinates[1]>0 and sizes[0]>0 and sizes[1]>0
return keeps
def filter_nodes(root: ET, platform="ubuntu"):
def filter_nodes(root: ET, platform="ubuntu", check_image=False):
filtered_nodes = []
for node in root.iter():
if judge_node(node, platform):
if judge_node(node, platform, check_image):
filtered_nodes.append(node)
#print(ET.tostring(node, encoding="unicode"))
@@ -155,12 +157,12 @@ def print_nodes_with_indent(nodes, indent=0):
if __name__ == '__main__':
import json
with open('4.json', 'r', encoding='utf-8') as f:
xml_file_str = json.load(f)["AT"]
with open('selection_sorted(imaged).xml', 'r', encoding='utf-8') as f:
xml_file_str = f.read()
filtered_nodes = filter_nodes(ET.fromstring(xml_file_str))
print(len(filtered_nodes))
masks = draw_bounding_boxes( filtered_nodes, '4.png'
, '4.a.png'
masks = draw_bounding_boxes( filtered_nodes, 'selection_sorted(imaged).png'
, 'selection_sorted(imaged).ai.png'
)
# print(masks)

View File

@@ -11,26 +11,17 @@ from http import HTTPStatus
from io import BytesIO
from typing import Dict, List
from google.api_core.exceptions import InvalidArgument
import backoff
import dashscope
import google.generativeai as genai
import requests
from PIL import Image
from vertexai.preview.generative_models import (
HarmBlockThreshold,
HarmCategory,
Image,
)
from mm_agents.accessibility_tree_wrap.heuristic_retrieve import find_leaf_nodes, filter_nodes, draw_bounding_boxes
from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION, \
SYS_PROMPT_IN_A11Y_OUT_CODE, SYS_PROMPT_IN_A11Y_OUT_ACTION, \
SYS_PROMPT_IN_BOTH_OUT_CODE, SYS_PROMPT_IN_BOTH_OUT_ACTION, \
SYS_PROMPT_IN_SOM_A11Y_OUT_TAG, \
SYS_PROMPT_SEEACT, ACTION_DESCRIPTION_PROMPT_SEEACT, ACTION_GROUNDING_PROMPT_SEEACT
# todo: cross-check with visualwebarena
SYS_PROMPT_IN_SOM_OUT_TAG
logger = logging.getLogger("desktopenv.agent")
@@ -45,7 +36,7 @@ def linearize_accessibility_tree(accessibility_tree):
# leaf_nodes = find_leaf_nodes(accessibility_tree)
filtered_nodes = filter_nodes(ET.fromstring(accessibility_tree))
linearized_accessibility_tree = "tag\tname\ttext\tposition\tsize\n"
linearized_accessibility_tree = "tag\tname\ttext\tposition (top-left x&y)\tsize (w&h)\n"
# Linearize the accessibility tree nodes into a table format
for node in filtered_nodes:
@@ -73,7 +64,8 @@ def tag_screenshot(screenshot, accessibility_tree):
uuid_str = str(uuid.uuid4())
os.makedirs("tmp/images", exist_ok=True)
tagged_screenshot_file_path = os.path.join("tmp/images", uuid_str + ".png")
nodes = filter_nodes(find_leaf_nodes(accessibility_tree))
# nodes = filter_nodes(find_leaf_nodes(accessibility_tree))
nodes = filter_nodes(ET.fromstring(accessibility_tree), check_image=True)
# Make tag screenshot
marks, drew_nodes = draw_bounding_boxes(nodes, screenshot, tagged_screenshot_file_path)
@@ -178,7 +170,7 @@ class PromptAgent:
temperature=0.5,
action_space="computer_13",
observation_type="screenshot_a11y_tree",
# observation_type can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som", "seeact"]
# observation_type can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"]
max_trajectory_length=3
):
self.model = model
@@ -207,7 +199,7 @@ class PromptAgent:
self.system_message = SYS_PROMPT_IN_A11Y_OUT_CODE
else:
raise ValueError("Invalid action space: " + action_space)
elif observation_type == "both":
elif observation_type == "screenshot_a11y_tree":
if action_space == "computer_13":
self.system_message = SYS_PROMPT_IN_BOTH_OUT_ACTION
elif action_space == "pyautogui":
@@ -218,14 +210,7 @@ class PromptAgent:
if action_space == "computer_13":
raise ValueError("Invalid action space: " + action_space)
elif action_space == "pyautogui":
self.system_message = SYS_PROMPT_IN_SOM_A11Y_OUT_TAG
else:
raise ValueError("Invalid action space: " + action_space)
elif observation_type == "seeact":
if action_space == "computer_13":
raise ValueError("Invalid action space: " + action_space)
elif action_space == "pyautogui":
self.system_message = SYS_PROMPT_SEEACT
self.system_message = SYS_PROMPT_IN_SOM_OUT_TAG
else:
raise ValueError("Invalid action space: " + action_space)
else:
@@ -235,8 +220,7 @@ class PromptAgent:
"""
Predict the next action(s) based on the current observation.
"""
self.system_message = self.system_message + "\nYou are asked to complete the following task: {}".format(
instruction)
system_message = self.system_message + "\nYou are asked to complete the following task: {}".format(instruction)
# Prepare the payload for the API call
messages = []
@@ -247,7 +231,7 @@ class PromptAgent:
"content": [
{
"type": "text",
"text": self.system_message
"text": system_message
},
]
})
@@ -268,7 +252,7 @@ class PromptAgent:
for previous_obs, previous_action, previous_thought in zip(_observations, _actions, _thoughts):
# {{{1
if self.observation_type == "both":
if self.observation_type == "screenshot_a11y_tree":
_screenshot = previous_obs["screenshot"]
_linearized_accessibility_tree = previous_obs["accessibility_tree"]
logger.debug("LINEAR AT: %s", _linearized_accessibility_tree)
@@ -290,18 +274,15 @@ class PromptAgent:
}
]
})
elif self.observation_type in ["som", "seeact"]:
elif self.observation_type in ["som"]:
_screenshot = previous_obs["screenshot"]
_linearized_accessibility_tree = previous_obs["accessibility_tree"]
logger.debug("LINEAR AT: %s", _linearized_accessibility_tree)
messages.append({
"role": "user",
"content": [
{
"type": "text",
"text": "Given the tagged screenshot and info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format(
_linearized_accessibility_tree)
"text": "Given the tagged screenshot as below. What's the next step that you will do to help with the task?"
},
{
"type": "image_url",
@@ -358,11 +339,11 @@ class PromptAgent:
})
# {{{1
if self.observation_type in ["screenshot", "both"]:
if self.observation_type in ["screenshot", "screenshot_a11y_tree"]:
base64_image = encode_image(obs["screenshot"])
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])
if self.observation_type == "both":
if self.observation_type == "screenshot_a11y_tree":
self.observations.append({
"screenshot": base64_image,
"accessibility_tree": linearized_accessibility_tree
@@ -414,11 +395,9 @@ class PromptAgent:
# Add som to the screenshot
masks, drew_nodes, tagged_screenshot = tag_screenshot(obs["screenshot"], obs["accessibility_tree"])
base64_image = encode_image(tagged_screenshot)
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])
self.observations.append({
"screenshot": base64_image,
"accessibility_tree": linearized_accessibility_tree
"screenshot": base64_image
})
messages.append({
@@ -426,35 +405,7 @@ class PromptAgent:
"content": [
{
"type": "text",
"text": "Given the tagged screenshot and info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format(
linearized_accessibility_tree)
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{base64_image}",
"detail": "high"
}
}
]
})
elif self.observation_type == "seeact":
# Add som to the screenshot
masks, drew_nodes, tagged_screenshot = tag_screenshot(obs["screenshot"], obs["accessibility_tree"])
base64_image = encode_image(tagged_screenshot)
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])
self.observations.append({
"screenshot": base64_image,
"accessibility_tree": linearized_accessibility_tree
})
messages.append({
"role": "user",
"content": [
{
"type": "text",
"text": ACTION_DESCRIPTION_PROMPT_SEEACT.format(linearized_accessibility_tree)
"text": "Given the tagged screenshot as below. What's the next step that you will do to help with the task?"
},
{
"type": "image_url",
@@ -475,43 +426,13 @@ class PromptAgent:
response = self.call_llm({
"model": self.model,
"messages": messages,
"max_tokens": self.max_tokens
"max_tokens": self.max_tokens,
"top_p": self.top_p,
"temperature": self.temperature
})
logger.info("RESPONSE: %s", response)
if self.observation_type == "seeact":
messages.append({
"role": "assistant",
"content": [
{
"type": "text",
"text": response
}
]
})
messages.append({
"role": "user",
"content": [
{
"type": "text",
"text": "{}\n\nWhat's the next step that you will do to help with the task?".format(
ACTION_GROUNDING_PROMPT_SEEACT)
}
]
})
logger.info("Generating content with GPT model: %s", self.model)
response = self.call_llm({
"model": self.model,
"messages": messages,
"max_tokens": self.max_tokens,
"top_p": self.top_p,
"temperature": self.temperature
})
logger.info("RESPONSE: %s", response)
try:
actions = self.parse_actions(response, masks)
self.thoughts.append(response)
@@ -524,16 +445,15 @@ class PromptAgent:
@backoff.on_exception(
backoff.expo,
# here you should add more model exceptions as you want,
# here you should add more model exceptions as you want,
# but you are forbidden to add "Exception", that is, a common type of exception
# because we want to catch this kind of Exception in the outside to ensure each example won't exceed the time limit
(openai.RateLimitError,
openai.BadRequestError,
openai.InternalServerError,
InvalidArgument),
openai.BadRequestError,
openai.InternalServerError,
InvalidArgument),
max_tries=5
)
def call_llm(self, payload):
if self.model.startswith("gpt"):
@@ -551,14 +471,15 @@ class PromptAgent:
if response.status_code != 200:
if response.json()['error']['code'] == "context_length_exceeded":
logger.error("Context length exceeded. Retrying with a smaller context.")
payload["messages"] = payload["messages"][-1:]
payload["messages"] = [payload["messages"][0]] + payload["messages"][-1:]
retry_response = requests.post(
"https://api.openai.com/v1/chat/completions",
headers=headers,
json=payload
)
if retry_response.status_code != 200:
logger.error("Failed to call LLM: " + retry_response.text)
logger.error(
"Failed to call LLM even after attempt on shortening the history: " + retry_response.text)
return ""
logger.error("Failed to call LLM: " + response.text)
@@ -665,8 +586,9 @@ class PromptAgent:
for message in gemini_messages:
message_history_str += "<|" + message['role'] + "|>\n" + message['parts'][0] + "\n"
gemini_messages = [{"role": "user", "parts": [message_history_str, gemini_messages[-1]['parts'][1]]}]
# gemini_messages[-1]['parts'][1].save("output.png", "PNG")
print(gemini_messages)
# print(gemini_messages)
api_key = os.environ.get("GENAI_API_KEY")
assert api_key is not None, "Please set the GENAI_API_KEY environment variable"
genai.configure(api_key=api_key)
@@ -680,11 +602,10 @@ class PromptAgent:
"temperature": temperature
},
safety_settings={
HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
"harassment": "block_none",
"hate": "block_none",
"sex": "block_none",
"danger": "block_none"
}
)
@@ -735,7 +656,7 @@ class PromptAgent:
def parse_actions(self, response: str, masks=None):
if self.observation_type in ["screenshot", "a11y_tree", "both"]:
if self.observation_type in ["screenshot", "a11y_tree", "screenshot_a11y_tree"]:
# parse from the response
if self.action_space == "computer_13":
actions = parse_actions_from_string(response)
@@ -747,7 +668,7 @@ class PromptAgent:
self.actions.append(actions)
return actions
elif self.observation_type in ["som", "seeact"]:
elif self.observation_type in ["som"]:
# parse from the response
if self.action_space == "computer_13":
raise ValueError("Invalid action space: " + self.action_space)

View File

@@ -798,10 +798,10 @@ You MUST choose and ONLY CHOOSE from the action space above, otherwise your acti
You CAN predict multiple actions at one step, but you should only return one action for each step.
""".strip()
SYS_PROMPT_IN_SOM_A11Y_OUT_TAG = """
SYS_PROMPT_IN_SOM_OUT_TAG = """
You are an agent which follow my instruction and perform desktop computer tasks as instructed.
You have good knowledge of computer and good internet connection and assume your code will run on a computer for controlling the mouse and keyboard.
For each step, you will get an observation of the desktop by 1) a screenshot; and 2) accessibility tree, which is based on AT-SPI library.
For each step, you will get an observation of the desktop by a screenshot with interact-able elements marked with numerical tags. And you will predict the action of the computer based on the image.
You are required to use `pyautogui` to perform the action grounded to the observation, but DONOT use the `pyautogui.locateCenterOnScreen` function to locate the element you want to operate with since we have no image of the element you want to operate with. DONOT USE `pyautogui.screenshot()` to make screenshot.
You can replace x, y in the code with the tag of the element you want to operate with. such as:

47
run.py
View File

@@ -7,7 +7,8 @@ import json
import logging
import os
import sys
# import signal
from tqdm # import tqdm
import time
import timeout_decorator
@@ -48,6 +49,7 @@ logger.addHandler(sdebug_handler)
logger = logging.getLogger("desktopenv.experiment")
# make sure each example won't exceed the time limit
# def handler(signo, frame):
# raise RuntimeError("Time limit exceeded!")
@@ -73,7 +75,7 @@ def config() -> argparse.Namespace:
"screenshot_a11y_tree",
"som"
],
default="a11y_tree",
default="som",
help="Observation type",
)
parser.add_argument("--screen_width", type=int, default=1920)
@@ -126,8 +128,8 @@ def test(
headless=args.headless,
)
for domain in test_all_meta:
for example_id in test_all_meta[domain]:
for domain in tqdm(test_all_meta, desc="Domain"):
for example_id in tqdm(test_all_meta[domain], desc="Example", leave=False):
# example setting
config_file = os.path.join(args.test_config_base_dir, f"examples/{domain}/{example_id}.json")
with open(config_file, "r", encoding="utf-8") as f:
@@ -169,7 +171,7 @@ def test(
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
logger.info("Step %d: %s", step_idx + 1, action)
observation, reward, done, info = env.step(action, args.sleep_after_execution)
obs, reward, done, info = env.step(action, args.sleep_after_execution)
logger.info("Reward: %.2f", reward)
logger.info("Done: %s", done)
@@ -177,8 +179,8 @@ def test(
# Save screenshot and trajectory information
with open(os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"),
"wb") as _f:
with open(observation['screenshot'], "rb") as __f:
"wb") as _f:
with open(obs['screenshot'], "rb") as __f:
screenshot = __f.read()
_f.write(screenshot)
@@ -198,10 +200,12 @@ def test(
logger.info("The episode is done.")
break
step_idx += 1
result = env.evaluate()
logger.info("Result: %.2f", result)
scores.append(result)
with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f:
f.write(f"{result}\n")
env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
# example start running
@@ -218,18 +222,23 @@ def test(
}))
f.write("\n")
continue
except Exception as e:
logger.error(f"Error in example {domain}/{example_id}: {e}")
continue
env.close()
logger.info(f"Average score: {sum(scores) / len(scores)}")
def get_unfinished(action_space, use_model, observation_type, result_dir, total_file_json):
target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
if not os.path.exists(target_dir):
return total_file_json
finished = {}
for domain in os.listdir(target_dir):
finished[domain] = []
finished[domain] = []
domain_path = os.path.join(target_dir, domain)
if os.path.isdir(domain_path):
@@ -245,13 +254,14 @@ def get_unfinished(action_space, use_model, observation_type, result_dir, total_
if not finished:
return total_file_json
for domain, examples in finished.items():
if domain in total_file_json:
if domain in total_file_json:
total_file_json[domain] = [x for x in total_file_json[domain] if x not in examples]
return total_file_json
if __name__ == '__main__':
####### The complete version of the list of examples #######
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -260,11 +270,16 @@ if __name__ == '__main__':
with open("evaluation_examples/test_all.json", "r", encoding="utf-8") as f:
test_all_meta = json.load(f)
test_file_list = get_unfinished(args.action_space, args.model, args.observation_type, args.result_dir, test_all_meta)
test_file_list = get_unfinished(
args.action_space,
args.model,
args.observation_type,
args.result_dir,
test_all_meta
)
left_info = ""
for domain in test_file_list:
left_info += f"{domain}: {len(test_file_list[domain])}\n"
logger.info(f"Left tasks:\n{left_info}")
os.environ['OPENAI_API_KEY'] = "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"
test(args, test_all_meta)
test(args, test_all_meta)