update timer

This commit is contained in:
Jason Lee
2024-03-15 23:28:45 +08:00
7 changed files with 106 additions and 147 deletions

View File

@@ -21,10 +21,12 @@
Please refer to [guidance](https://docs.google.com/document/d/1KBdeZwmZs2Vi_Wsnngb3Wf1-RiwMMpXTftwMqP2Ztak/edit#heading=h.uh0x0tkl7fuw) Please refer to [guidance](https://docs.google.com/document/d/1KBdeZwmZs2Vi_Wsnngb3Wf1-RiwMMpXTftwMqP2Ztak/edit#heading=h.uh0x0tkl7fuw)
2. Install the environment package, download the examples and the virtual machine image. 2. Install the environment package, download the examples and the virtual machine image.
For x86_64 Linux or Windows, you can install the environment package and download the examples and the virtual machine image by running the following commands:
```bash ```bash
pip install desktop-env pip install desktop-env
gdown xxxx gdown xxxx
gdown xxxx vmrun -T ws start "Ubuntu/Ubuntu.vmx" nogui
vmrun -T ws snapshot "Ubuntu/Ubuntu.vmx" "init_state"
``` ```
## Quick Start ## Quick Start

View File

@@ -284,6 +284,15 @@ def _create_atspi_node(node: Accessible, depth: int = 0, flag: Optional[str] = N
text = text.replace("\ufffc", "").replace("\ufffd", "") text = text.replace("\ufffc", "").replace("\ufffd", "")
# }}} Text # # }}} Text #
# Image {{{ #
try:
node.queryImage()
except NotImplementedError:
pass
else:
attribute_dict["image"] = "true"
# }}} Image #
# Selection {{{ # # Selection {{{ #
try: try:
node.querySelection() node.querySelection()

24
main.py
View File

@@ -4,7 +4,7 @@ import logging
import os import os
import sys import sys
import time import time
import argparse
from desktop_env.envs.desktop_env import DesktopEnv from desktop_env.envs.desktop_env import DesktopEnv
# Logger Configs {{{ # # Logger Configs {{{ #
@@ -46,19 +46,29 @@ def human_agent():
""" """
Runs the Gym environment with human input. Runs the Gym environment with human input.
""" """
parser = argparse.ArgumentParser()
parser.add_argument('-p', '--path', type=str, default=r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu3\Ubuntu3.vmx", help="Path to the virtual machine .vmx file.")
parser.add_argument('-s', '--snapshot', type=str, default='init_state', help="Name of the snapshot to restore.")
parser.add_argument('-e', '--example', type=str, help="Path to the example json file.")
args = parser.parse_args(sys.argv[1:])
with open("evaluation_examples/examples/multi_apps/4c26e3f3-3a14-4d86-b44a-d3cedebbb487.json", "r", encoding="utf-8") as f: example_path = args.example if args.example is not None and os.path.exists(args.example) else \
'evaluation_examples/examples/multi_apps/5990457f-2adb-467b-a4af-5c857c92d762.json'
with open(example_path, "r", encoding="utf-8") as f:
example = json.load(f) example = json.load(f)
example["snapshot"] = "exp_v5" if args.snapshot is not None:
example['snapshot'] = args.snapshot
assert os.path.exists(args.path), "The specified path to the .vmx file does not exist."
env = DesktopEnv( env = DesktopEnv(
path_to_vm=r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu3\Ubuntu3.vmx", path_to_vm=args.path,
action_space="computer_13", snapshot_name=args.snapshot,
task_config=example action_space="computer_13"
) )
# reset the environment to certain snapshot # reset the environment to certain snapshot
observation = env.reset() observation = env.reset(task_config=example)
done = False done = False
logger.info('\x1b[32m[TASK INSTRUCTION]: \x1b[32;3m%s\x1b[0m', example["instruction"])
trajectory = [ trajectory = [
{ {

View File

@@ -26,7 +26,7 @@ def find_leaf_nodes(xlm_file_str):
state_ns = "uri:deskat:state.at-spi.gnome.org" state_ns = "uri:deskat:state.at-spi.gnome.org"
component_ns = "uri:deskat:component.at-spi.gnome.org" component_ns = "uri:deskat:component.at-spi.gnome.org"
def judge_node(node: ET, platform="ubuntu") -> bool: def judge_node(node: ET, platform="ubuntu", check_image=False) -> bool:
keeps: bool = node.tag.startswith("document")\ keeps: bool = node.tag.startswith("document")\
or node.tag.endswith("item")\ or node.tag.endswith("item")\
or node.tag.endswith("button")\ or node.tag.endswith("button")\
@@ -60,18 +60,20 @@ def judge_node(node: ET, platform="ubuntu") -> bool:
or node.get("{{{:}}}expandable".format(state_ns), "false")=="true"\ or node.get("{{{:}}}expandable".format(state_ns), "false")=="true"\
or node.get("{{{:}}}checkable".format(state_ns), "false")=="true" or node.get("{{{:}}}checkable".format(state_ns), "false")=="true"
)\ )\
and (node.get("name", "") != "" or node.text is not None and len(node.text)>0) and ( node.get("name", "") != "" or node.text is not None and len(node.text)>0\
or check_image and node.get("image", "false")=="true"
)
coordinates: Tuple[int, int] = eval(node.get("{{{:}}}screencoord".format(component_ns), "(-1, -1)")) coordinates: Tuple[int, int] = eval(node.get("{{{:}}}screencoord".format(component_ns), "(-1, -1)"))
sizes: Tuple[int, int] = eval(node.get("{{{:}}}size".format(component_ns), "(-1, -1)")) sizes: Tuple[int, int] = eval(node.get("{{{:}}}size".format(component_ns), "(-1, -1)"))
keeps = keeps and coordinates[0]>0 and coordinates[1]>0 and sizes[0]>0 and sizes[1]>0 keeps = keeps and coordinates[0]>0 and coordinates[1]>0 and sizes[0]>0 and sizes[1]>0
return keeps return keeps
def filter_nodes(root: ET, platform="ubuntu"): def filter_nodes(root: ET, platform="ubuntu", check_image=False):
filtered_nodes = [] filtered_nodes = []
for node in root.iter(): for node in root.iter():
if judge_node(node, platform): if judge_node(node, platform, check_image):
filtered_nodes.append(node) filtered_nodes.append(node)
#print(ET.tostring(node, encoding="unicode")) #print(ET.tostring(node, encoding="unicode"))
@@ -155,12 +157,12 @@ def print_nodes_with_indent(nodes, indent=0):
if __name__ == '__main__': if __name__ == '__main__':
import json import json
with open('4.json', 'r', encoding='utf-8') as f: with open('selection_sorted(imaged).xml', 'r', encoding='utf-8') as f:
xml_file_str = json.load(f)["AT"] xml_file_str = f.read()
filtered_nodes = filter_nodes(ET.fromstring(xml_file_str)) filtered_nodes = filter_nodes(ET.fromstring(xml_file_str))
print(len(filtered_nodes)) print(len(filtered_nodes))
masks = draw_bounding_boxes( filtered_nodes, '4.png' masks = draw_bounding_boxes( filtered_nodes, 'selection_sorted(imaged).png'
, '4.a.png' , 'selection_sorted(imaged).ai.png'
) )
# print(masks) # print(masks)

View File

@@ -11,26 +11,17 @@ from http import HTTPStatus
from io import BytesIO from io import BytesIO
from typing import Dict, List from typing import Dict, List
from google.api_core.exceptions import InvalidArgument from google.api_core.exceptions import InvalidArgument
import backoff import backoff
import dashscope import dashscope
import google.generativeai as genai import google.generativeai as genai
import requests import requests
from PIL import Image from PIL import Image
from vertexai.preview.generative_models import (
HarmBlockThreshold,
HarmCategory,
Image,
)
from mm_agents.accessibility_tree_wrap.heuristic_retrieve import find_leaf_nodes, filter_nodes, draw_bounding_boxes from mm_agents.accessibility_tree_wrap.heuristic_retrieve import find_leaf_nodes, filter_nodes, draw_bounding_boxes
from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION, \ from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION, \
SYS_PROMPT_IN_A11Y_OUT_CODE, SYS_PROMPT_IN_A11Y_OUT_ACTION, \ SYS_PROMPT_IN_A11Y_OUT_CODE, SYS_PROMPT_IN_A11Y_OUT_ACTION, \
SYS_PROMPT_IN_BOTH_OUT_CODE, SYS_PROMPT_IN_BOTH_OUT_ACTION, \ SYS_PROMPT_IN_BOTH_OUT_CODE, SYS_PROMPT_IN_BOTH_OUT_ACTION, \
SYS_PROMPT_IN_SOM_A11Y_OUT_TAG, \ SYS_PROMPT_IN_SOM_OUT_TAG
SYS_PROMPT_SEEACT, ACTION_DESCRIPTION_PROMPT_SEEACT, ACTION_GROUNDING_PROMPT_SEEACT
# todo: cross-check with visualwebarena
logger = logging.getLogger("desktopenv.agent") logger = logging.getLogger("desktopenv.agent")
@@ -45,7 +36,7 @@ def linearize_accessibility_tree(accessibility_tree):
# leaf_nodes = find_leaf_nodes(accessibility_tree) # leaf_nodes = find_leaf_nodes(accessibility_tree)
filtered_nodes = filter_nodes(ET.fromstring(accessibility_tree)) filtered_nodes = filter_nodes(ET.fromstring(accessibility_tree))
linearized_accessibility_tree = "tag\tname\ttext\tposition\tsize\n" linearized_accessibility_tree = "tag\tname\ttext\tposition (top-left x&y)\tsize (w&h)\n"
# Linearize the accessibility tree nodes into a table format # Linearize the accessibility tree nodes into a table format
for node in filtered_nodes: for node in filtered_nodes:
@@ -73,7 +64,8 @@ def tag_screenshot(screenshot, accessibility_tree):
uuid_str = str(uuid.uuid4()) uuid_str = str(uuid.uuid4())
os.makedirs("tmp/images", exist_ok=True) os.makedirs("tmp/images", exist_ok=True)
tagged_screenshot_file_path = os.path.join("tmp/images", uuid_str + ".png") tagged_screenshot_file_path = os.path.join("tmp/images", uuid_str + ".png")
nodes = filter_nodes(find_leaf_nodes(accessibility_tree)) # nodes = filter_nodes(find_leaf_nodes(accessibility_tree))
nodes = filter_nodes(ET.fromstring(accessibility_tree), check_image=True)
# Make tag screenshot # Make tag screenshot
marks, drew_nodes = draw_bounding_boxes(nodes, screenshot, tagged_screenshot_file_path) marks, drew_nodes = draw_bounding_boxes(nodes, screenshot, tagged_screenshot_file_path)
@@ -178,7 +170,7 @@ class PromptAgent:
temperature=0.5, temperature=0.5,
action_space="computer_13", action_space="computer_13",
observation_type="screenshot_a11y_tree", observation_type="screenshot_a11y_tree",
# observation_type can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som", "seeact"] # observation_type can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"]
max_trajectory_length=3 max_trajectory_length=3
): ):
self.model = model self.model = model
@@ -207,7 +199,7 @@ class PromptAgent:
self.system_message = SYS_PROMPT_IN_A11Y_OUT_CODE self.system_message = SYS_PROMPT_IN_A11Y_OUT_CODE
else: else:
raise ValueError("Invalid action space: " + action_space) raise ValueError("Invalid action space: " + action_space)
elif observation_type == "both": elif observation_type == "screenshot_a11y_tree":
if action_space == "computer_13": if action_space == "computer_13":
self.system_message = SYS_PROMPT_IN_BOTH_OUT_ACTION self.system_message = SYS_PROMPT_IN_BOTH_OUT_ACTION
elif action_space == "pyautogui": elif action_space == "pyautogui":
@@ -218,14 +210,7 @@ class PromptAgent:
if action_space == "computer_13": if action_space == "computer_13":
raise ValueError("Invalid action space: " + action_space) raise ValueError("Invalid action space: " + action_space)
elif action_space == "pyautogui": elif action_space == "pyautogui":
self.system_message = SYS_PROMPT_IN_SOM_A11Y_OUT_TAG self.system_message = SYS_PROMPT_IN_SOM_OUT_TAG
else:
raise ValueError("Invalid action space: " + action_space)
elif observation_type == "seeact":
if action_space == "computer_13":
raise ValueError("Invalid action space: " + action_space)
elif action_space == "pyautogui":
self.system_message = SYS_PROMPT_SEEACT
else: else:
raise ValueError("Invalid action space: " + action_space) raise ValueError("Invalid action space: " + action_space)
else: else:
@@ -235,8 +220,7 @@ class PromptAgent:
""" """
Predict the next action(s) based on the current observation. Predict the next action(s) based on the current observation.
""" """
self.system_message = self.system_message + "\nYou are asked to complete the following task: {}".format( system_message = self.system_message + "\nYou are asked to complete the following task: {}".format(instruction)
instruction)
# Prepare the payload for the API call # Prepare the payload for the API call
messages = [] messages = []
@@ -247,7 +231,7 @@ class PromptAgent:
"content": [ "content": [
{ {
"type": "text", "type": "text",
"text": self.system_message "text": system_message
}, },
] ]
}) })
@@ -268,7 +252,7 @@ class PromptAgent:
for previous_obs, previous_action, previous_thought in zip(_observations, _actions, _thoughts): for previous_obs, previous_action, previous_thought in zip(_observations, _actions, _thoughts):
# {{{1 # {{{1
if self.observation_type == "both": if self.observation_type == "screenshot_a11y_tree":
_screenshot = previous_obs["screenshot"] _screenshot = previous_obs["screenshot"]
_linearized_accessibility_tree = previous_obs["accessibility_tree"] _linearized_accessibility_tree = previous_obs["accessibility_tree"]
logger.debug("LINEAR AT: %s", _linearized_accessibility_tree) logger.debug("LINEAR AT: %s", _linearized_accessibility_tree)
@@ -290,18 +274,15 @@ class PromptAgent:
} }
] ]
}) })
elif self.observation_type in ["som", "seeact"]: elif self.observation_type in ["som"]:
_screenshot = previous_obs["screenshot"] _screenshot = previous_obs["screenshot"]
_linearized_accessibility_tree = previous_obs["accessibility_tree"]
logger.debug("LINEAR AT: %s", _linearized_accessibility_tree)
messages.append({ messages.append({
"role": "user", "role": "user",
"content": [ "content": [
{ {
"type": "text", "type": "text",
"text": "Given the tagged screenshot and info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format( "text": "Given the tagged screenshot as below. What's the next step that you will do to help with the task?"
_linearized_accessibility_tree)
}, },
{ {
"type": "image_url", "type": "image_url",
@@ -358,11 +339,11 @@ class PromptAgent:
}) })
# {{{1 # {{{1
if self.observation_type in ["screenshot", "both"]: if self.observation_type in ["screenshot", "screenshot_a11y_tree"]:
base64_image = encode_image(obs["screenshot"]) base64_image = encode_image(obs["screenshot"])
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"]) linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])
if self.observation_type == "both": if self.observation_type == "screenshot_a11y_tree":
self.observations.append({ self.observations.append({
"screenshot": base64_image, "screenshot": base64_image,
"accessibility_tree": linearized_accessibility_tree "accessibility_tree": linearized_accessibility_tree
@@ -414,11 +395,9 @@ class PromptAgent:
# Add som to the screenshot # Add som to the screenshot
masks, drew_nodes, tagged_screenshot = tag_screenshot(obs["screenshot"], obs["accessibility_tree"]) masks, drew_nodes, tagged_screenshot = tag_screenshot(obs["screenshot"], obs["accessibility_tree"])
base64_image = encode_image(tagged_screenshot) base64_image = encode_image(tagged_screenshot)
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])
self.observations.append({ self.observations.append({
"screenshot": base64_image, "screenshot": base64_image
"accessibility_tree": linearized_accessibility_tree
}) })
messages.append({ messages.append({
@@ -426,35 +405,7 @@ class PromptAgent:
"content": [ "content": [
{ {
"type": "text", "type": "text",
"text": "Given the tagged screenshot and info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format( "text": "Given the tagged screenshot as below. What's the next step that you will do to help with the task?"
linearized_accessibility_tree)
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{base64_image}",
"detail": "high"
}
}
]
})
elif self.observation_type == "seeact":
# Add som to the screenshot
masks, drew_nodes, tagged_screenshot = tag_screenshot(obs["screenshot"], obs["accessibility_tree"])
base64_image = encode_image(tagged_screenshot)
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])
self.observations.append({
"screenshot": base64_image,
"accessibility_tree": linearized_accessibility_tree
})
messages.append({
"role": "user",
"content": [
{
"type": "text",
"text": ACTION_DESCRIPTION_PROMPT_SEEACT.format(linearized_accessibility_tree)
}, },
{ {
"type": "image_url", "type": "image_url",
@@ -475,43 +426,13 @@ class PromptAgent:
response = self.call_llm({ response = self.call_llm({
"model": self.model, "model": self.model,
"messages": messages, "messages": messages,
"max_tokens": self.max_tokens "max_tokens": self.max_tokens,
"top_p": self.top_p,
"temperature": self.temperature
}) })
logger.info("RESPONSE: %s", response) logger.info("RESPONSE: %s", response)
if self.observation_type == "seeact":
messages.append({
"role": "assistant",
"content": [
{
"type": "text",
"text": response
}
]
})
messages.append({
"role": "user",
"content": [
{
"type": "text",
"text": "{}\n\nWhat's the next step that you will do to help with the task?".format(
ACTION_GROUNDING_PROMPT_SEEACT)
}
]
})
logger.info("Generating content with GPT model: %s", self.model)
response = self.call_llm({
"model": self.model,
"messages": messages,
"max_tokens": self.max_tokens,
"top_p": self.top_p,
"temperature": self.temperature
})
logger.info("RESPONSE: %s", response)
try: try:
actions = self.parse_actions(response, masks) actions = self.parse_actions(response, masks)
self.thoughts.append(response) self.thoughts.append(response)
@@ -524,16 +445,15 @@ class PromptAgent:
@backoff.on_exception( @backoff.on_exception(
backoff.expo, backoff.expo,
# here you should add more model exceptions as you want, # here you should add more model exceptions as you want,
# but you are forbidden to add "Exception", that is, a common type of exception # but you are forbidden to add "Exception", that is, a common type of exception
# because we want to catch this kind of Exception in the outside to ensure each example won't exceed the time limit # because we want to catch this kind of Exception in the outside to ensure each example won't exceed the time limit
(openai.RateLimitError, (openai.RateLimitError,
openai.BadRequestError, openai.BadRequestError,
openai.InternalServerError, openai.InternalServerError,
InvalidArgument), InvalidArgument),
max_tries=5 max_tries=5
) )
def call_llm(self, payload): def call_llm(self, payload):
if self.model.startswith("gpt"): if self.model.startswith("gpt"):
@@ -551,14 +471,15 @@ class PromptAgent:
if response.status_code != 200: if response.status_code != 200:
if response.json()['error']['code'] == "context_length_exceeded": if response.json()['error']['code'] == "context_length_exceeded":
logger.error("Context length exceeded. Retrying with a smaller context.") logger.error("Context length exceeded. Retrying with a smaller context.")
payload["messages"] = payload["messages"][-1:] payload["messages"] = [payload["messages"][0]] + payload["messages"][-1:]
retry_response = requests.post( retry_response = requests.post(
"https://api.openai.com/v1/chat/completions", "https://api.openai.com/v1/chat/completions",
headers=headers, headers=headers,
json=payload json=payload
) )
if retry_response.status_code != 200: if retry_response.status_code != 200:
logger.error("Failed to call LLM: " + retry_response.text) logger.error(
"Failed to call LLM even after attempt on shortening the history: " + retry_response.text)
return "" return ""
logger.error("Failed to call LLM: " + response.text) logger.error("Failed to call LLM: " + response.text)
@@ -665,8 +586,9 @@ class PromptAgent:
for message in gemini_messages: for message in gemini_messages:
message_history_str += "<|" + message['role'] + "|>\n" + message['parts'][0] + "\n" message_history_str += "<|" + message['role'] + "|>\n" + message['parts'][0] + "\n"
gemini_messages = [{"role": "user", "parts": [message_history_str, gemini_messages[-1]['parts'][1]]}] gemini_messages = [{"role": "user", "parts": [message_history_str, gemini_messages[-1]['parts'][1]]}]
# gemini_messages[-1]['parts'][1].save("output.png", "PNG")
print(gemini_messages) # print(gemini_messages)
api_key = os.environ.get("GENAI_API_KEY") api_key = os.environ.get("GENAI_API_KEY")
assert api_key is not None, "Please set the GENAI_API_KEY environment variable" assert api_key is not None, "Please set the GENAI_API_KEY environment variable"
genai.configure(api_key=api_key) genai.configure(api_key=api_key)
@@ -680,11 +602,10 @@ class PromptAgent:
"temperature": temperature "temperature": temperature
}, },
safety_settings={ safety_settings={
HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE, "harassment": "block_none",
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE, "hate": "block_none",
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE, "sex": "block_none",
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE, "danger": "block_none"
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
} }
) )
@@ -735,7 +656,7 @@ class PromptAgent:
def parse_actions(self, response: str, masks=None): def parse_actions(self, response: str, masks=None):
if self.observation_type in ["screenshot", "a11y_tree", "both"]: if self.observation_type in ["screenshot", "a11y_tree", "screenshot_a11y_tree"]:
# parse from the response # parse from the response
if self.action_space == "computer_13": if self.action_space == "computer_13":
actions = parse_actions_from_string(response) actions = parse_actions_from_string(response)
@@ -747,7 +668,7 @@ class PromptAgent:
self.actions.append(actions) self.actions.append(actions)
return actions return actions
elif self.observation_type in ["som", "seeact"]: elif self.observation_type in ["som"]:
# parse from the response # parse from the response
if self.action_space == "computer_13": if self.action_space == "computer_13":
raise ValueError("Invalid action space: " + self.action_space) raise ValueError("Invalid action space: " + self.action_space)

View File

@@ -798,10 +798,10 @@ You MUST choose and ONLY CHOOSE from the action space above, otherwise your acti
You CAN predict multiple actions at one step, but you should only return one action for each step. You CAN predict multiple actions at one step, but you should only return one action for each step.
""".strip() """.strip()
SYS_PROMPT_IN_SOM_A11Y_OUT_TAG = """ SYS_PROMPT_IN_SOM_OUT_TAG = """
You are an agent which follow my instruction and perform desktop computer tasks as instructed. You are an agent which follow my instruction and perform desktop computer tasks as instructed.
You have good knowledge of computer and good internet connection and assume your code will run on a computer for controlling the mouse and keyboard. You have good knowledge of computer and good internet connection and assume your code will run on a computer for controlling the mouse and keyboard.
For each step, you will get an observation of the desktop by 1) a screenshot; and 2) accessibility tree, which is based on AT-SPI library. For each step, you will get an observation of the desktop by a screenshot with interact-able elements marked with numerical tags. And you will predict the action of the computer based on the image.
You are required to use `pyautogui` to perform the action grounded to the observation, but DONOT use the `pyautogui.locateCenterOnScreen` function to locate the element you want to operate with since we have no image of the element you want to operate with. DONOT USE `pyautogui.screenshot()` to make screenshot. You are required to use `pyautogui` to perform the action grounded to the observation, but DONOT use the `pyautogui.locateCenterOnScreen` function to locate the element you want to operate with since we have no image of the element you want to operate with. DONOT USE `pyautogui.screenshot()` to make screenshot.
You can replace x, y in the code with the tag of the element you want to operate with. such as: You can replace x, y in the code with the tag of the element you want to operate with. such as:

47
run.py
View File

@@ -7,7 +7,8 @@ import json
import logging import logging
import os import os
import sys import sys
# import signal
from tqdm # import tqdm
import time import time
import timeout_decorator import timeout_decorator
@@ -48,6 +49,7 @@ logger.addHandler(sdebug_handler)
logger = logging.getLogger("desktopenv.experiment") logger = logging.getLogger("desktopenv.experiment")
# make sure each example won't exceed the time limit # make sure each example won't exceed the time limit
# def handler(signo, frame): # def handler(signo, frame):
# raise RuntimeError("Time limit exceeded!") # raise RuntimeError("Time limit exceeded!")
@@ -73,7 +75,7 @@ def config() -> argparse.Namespace:
"screenshot_a11y_tree", "screenshot_a11y_tree",
"som" "som"
], ],
default="a11y_tree", default="som",
help="Observation type", help="Observation type",
) )
parser.add_argument("--screen_width", type=int, default=1920) parser.add_argument("--screen_width", type=int, default=1920)
@@ -126,8 +128,8 @@ def test(
headless=args.headless, headless=args.headless,
) )
for domain in test_all_meta: for domain in tqdm(test_all_meta, desc="Domain"):
for example_id in test_all_meta[domain]: for example_id in tqdm(test_all_meta[domain], desc="Example", leave=False):
# example setting # example setting
config_file = os.path.join(args.test_config_base_dir, f"examples/{domain}/{example_id}.json") config_file = os.path.join(args.test_config_base_dir, f"examples/{domain}/{example_id}.json")
with open(config_file, "r", encoding="utf-8") as f: with open(config_file, "r", encoding="utf-8") as f:
@@ -169,7 +171,7 @@ def test(
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S") action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
logger.info("Step %d: %s", step_idx + 1, action) logger.info("Step %d: %s", step_idx + 1, action)
observation, reward, done, info = env.step(action, args.sleep_after_execution) obs, reward, done, info = env.step(action, args.sleep_after_execution)
logger.info("Reward: %.2f", reward) logger.info("Reward: %.2f", reward)
logger.info("Done: %s", done) logger.info("Done: %s", done)
@@ -177,8 +179,8 @@ def test(
# Save screenshot and trajectory information # Save screenshot and trajectory information
with open(os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"), with open(os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"),
"wb") as _f: "wb") as _f:
with open(observation['screenshot'], "rb") as __f: with open(obs['screenshot'], "rb") as __f:
screenshot = __f.read() screenshot = __f.read()
_f.write(screenshot) _f.write(screenshot)
@@ -198,10 +200,12 @@ def test(
logger.info("The episode is done.") logger.info("The episode is done.")
break break
step_idx += 1 step_idx += 1
result = env.evaluate() result = env.evaluate()
logger.info("Result: %.2f", result) logger.info("Result: %.2f", result)
scores.append(result) scores.append(result)
with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f:
f.write(f"{result}\n")
env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4")) env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
# example start running # example start running
@@ -218,18 +222,23 @@ def test(
})) }))
f.write("\n") f.write("\n")
continue continue
except Exception as e:
logger.error(f"Error in example {domain}/{example_id}: {e}")
continue
env.close() env.close()
logger.info(f"Average score: {sum(scores) / len(scores)}") logger.info(f"Average score: {sum(scores) / len(scores)}")
def get_unfinished(action_space, use_model, observation_type, result_dir, total_file_json): def get_unfinished(action_space, use_model, observation_type, result_dir, total_file_json):
target_dir = os.path.join(result_dir, action_space, observation_type, use_model) target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
if not os.path.exists(target_dir): if not os.path.exists(target_dir):
return total_file_json return total_file_json
finished = {} finished = {}
for domain in os.listdir(target_dir): for domain in os.listdir(target_dir):
finished[domain] = []
finished[domain] = [] finished[domain] = []
domain_path = os.path.join(target_dir, domain) domain_path = os.path.join(target_dir, domain)
if os.path.isdir(domain_path): if os.path.isdir(domain_path):
@@ -245,13 +254,14 @@ def get_unfinished(action_space, use_model, observation_type, result_dir, total_
if not finished: if not finished:
return total_file_json return total_file_json
for domain, examples in finished.items(): for domain, examples in finished.items():
if domain in total_file_json: if domain in total_file_json:
total_file_json[domain] = [x for x in total_file_json[domain] if x not in examples] total_file_json[domain] = [x for x in total_file_json[domain] if x not in examples]
return total_file_json return total_file_json
if __name__ == '__main__': if __name__ == '__main__':
####### The complete version of the list of examples ####### ####### The complete version of the list of examples #######
os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -260,11 +270,16 @@ if __name__ == '__main__':
with open("evaluation_examples/test_all.json", "r", encoding="utf-8") as f: with open("evaluation_examples/test_all.json", "r", encoding="utf-8") as f:
test_all_meta = json.load(f) test_all_meta = json.load(f)
test_file_list = get_unfinished(args.action_space, args.model, args.observation_type, args.result_dir, test_all_meta) test_file_list = get_unfinished(
args.action_space,
args.model,
args.observation_type,
args.result_dir,
test_all_meta
)
left_info = "" left_info = ""
for domain in test_file_list: for domain in test_file_list:
left_info += f"{domain}: {len(test_file_list[domain])}\n" left_info += f"{domain}: {len(test_file_list[domain])}\n"
logger.info(f"Left tasks:\n{left_info}") logger.info(f"Left tasks:\n{left_info}")
os.environ['OPENAI_API_KEY'] = "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" test(args, test_all_meta)
test(args, test_all_meta)