update timer
@@ -21,10 +21,12 @@
Please refer to [guidance](https://docs.google.com/document/d/1KBdeZwmZs2Vi_Wsnngb3Wf1-RiwMMpXTftwMqP2Ztak/edit#heading=h.uh0x0tkl7fuw)

2. Install the environment package, download the examples and the virtual machine image.

For x86_64 Linux or Windows, you can install the environment package and download the examples and the virtual machine image by running the following commands:

```bash
pip install desktop-env
gdown xxxx
gdown xxxx
vmrun -T ws start "Ubuntu/Ubuntu.vmx" nogui
vmrun -T ws snapshot "Ubuntu/Ubuntu.vmx" "init_state"
```

## Quick Start
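The Quick Start section itself is cut off by this hunk. As a bridge, here is a minimal sketch of driving the environment, based only on the `DesktopEnv` interface visible in the main.py diff below; the `.vmx` path and example file are placeholders for your own setup:

```python
# Minimal sketch, assuming the DesktopEnv interface shown in main.py below.
import json

from desktop_env.envs.desktop_env import DesktopEnv

with open("evaluation_examples/examples/multi_apps/5990457f-2adb-467b-a4af-5c857c92d762.json",
          "r", encoding="utf-8") as f:
    example = json.load(f)

env = DesktopEnv(
    path_to_vm="Ubuntu/Ubuntu.vmx",   # placeholder: the VM image started above
    snapshot_name="init_state",       # the snapshot taken during installation
    action_space="computer_13",
)
observation = env.reset(task_config=example)  # restore the snapshot and load the task
```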
@@ -284,6 +284,15 @@ def _create_atspi_node(node: Accessible, depth: int = 0, flag: Optional[str] = N
    text = text.replace("\ufffc", "").replace("\ufffd", "")
    # }}} Text #

    # Image {{{ #
    try:
        node.queryImage()
    except NotImplementedError:
        # pyatspi raises NotImplementedError when the node does not implement the queried interface
        pass
    else:
        attribute_dict["image"] = "true"
    # }}} Image #

    # Selection {{{ #
    try:
        node.querySelection()
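The probe-and-flag idiom above generalizes to other AT-SPI interfaces. A minimal sketch (the helper name and attribute keys are mine, not the repository's):

```python
# Hypothetical helper illustrating the same capability probe for several
# AT-SPI interfaces; pyatspi signals an unsupported interface by raising
# NotImplementedError from the query* call.
def probe_interfaces(node, attribute_dict):
    for query, key in ((node.queryImage, "image"),
                       (node.querySelection, "selection")):
        try:
            query()
        except NotImplementedError:
            pass  # interface not implemented by this accessible
        else:
            attribute_dict[key] = "true"
```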
main.py
@@ -4,7 +4,7 @@ import logging
import os
import sys
import time

import argparse
from desktop_env.envs.desktop_env import DesktopEnv

# Logger Configs {{{ #

@@ -46,19 +46,29 @@ def human_agent():
    """
    Runs the Gym environment with human input.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-p', '--path', type=str, default=r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu3\Ubuntu3.vmx", help="Path to the virtual machine .vmx file.")
    parser.add_argument('-s', '--snapshot', type=str, default='init_state', help="Name of the snapshot to restore.")
    parser.add_argument('-e', '--example', type=str, help="Path to the example json file.")
    args = parser.parse_args(sys.argv[1:])

    with open("evaluation_examples/examples/multi_apps/4c26e3f3-3a14-4d86-b44a-d3cedebbb487.json", "r", encoding="utf-8") as f:
    example_path = args.example if args.example is not None and os.path.exists(args.example) else \
        'evaluation_examples/examples/multi_apps/5990457f-2adb-467b-a4af-5c857c92d762.json'
    with open(example_path, "r", encoding="utf-8") as f:
        example = json.load(f)
    example["snapshot"] = "exp_v5"
    if args.snapshot is not None:
        example['snapshot'] = args.snapshot

    assert os.path.exists(args.path), "The specified path to the .vmx file does not exist."
    env = DesktopEnv(
        path_to_vm=r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu3\Ubuntu3.vmx",
        action_space="computer_13",
        task_config=example
        path_to_vm=args.path,
        snapshot_name=args.snapshot,
        action_space="computer_13"
    )
    # reset the environment to certain snapshot
    observation = env.reset()
    observation = env.reset(task_config=example)
    done = False
    logger.info('\x1b[32m[TASK INSTRUCTION]: \x1b[32;3m%s\x1b[0m', example["instruction"])

    trajectory = [
        {
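With these flags, the human agent can be pointed at any VM, snapshot, and task from the command line, e.g. `python main.py -p <path-to-.vmx> -s init_state -e evaluation_examples/examples/<domain>/<id>.json` (paths illustrative), instead of editing hard-coded constants.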
@@ -26,7 +26,7 @@ def find_leaf_nodes(xlm_file_str):

state_ns = "uri:deskat:state.at-spi.gnome.org"
component_ns = "uri:deskat:component.at-spi.gnome.org"
def judge_node(node: ET, platform="ubuntu") -> bool:
def judge_node(node: ET, platform="ubuntu", check_image=False) -> bool:
    keeps: bool = node.tag.startswith("document")\
        or node.tag.endswith("item")\
        or node.tag.endswith("button")\

@@ -60,18 +60,20 @@ def judge_node(node: ET, platform="ubuntu") -> bool:
        or node.get("{{{:}}}expandable".format(state_ns), "false")=="true"\
        or node.get("{{{:}}}checkable".format(state_ns), "false")=="true"
        )\
        and (node.get("name", "") != "" or node.text is not None and len(node.text)>0)
        and ( node.get("name", "") != "" or node.text is not None and len(node.text)>0\
            or check_image and node.get("image", "false")=="true"
        )

    coordinates: Tuple[int, int] = eval(node.get("{{{:}}}screencoord".format(component_ns), "(-1, -1)"))
    sizes: Tuple[int, int] = eval(node.get("{{{:}}}size".format(component_ns), "(-1, -1)"))
    keeps = keeps and coordinates[0]>0 and coordinates[1]>0 and sizes[0]>0 and sizes[1]>0
    return keeps

def filter_nodes(root: ET, platform="ubuntu"):
def filter_nodes(root: ET, platform="ubuntu", check_image=False):
    filtered_nodes = []

    for node in root.iter():
        if judge_node(node, platform):
        if judge_node(node, platform, check_image):
            filtered_nodes.append(node)
            #print(ET.tostring(node, encoding="unicode"))

@@ -155,12 +157,12 @@ def print_nodes_with_indent(nodes, indent=0):

if __name__ == '__main__':
    import json
    with open('4.json', 'r', encoding='utf-8') as f:
        xml_file_str = json.load(f)["AT"]
    with open('selection_sorted(imaged).xml', 'r', encoding='utf-8') as f:
        xml_file_str = f.read()
    filtered_nodes = filter_nodes(ET.fromstring(xml_file_str))
    print(len(filtered_nodes))
    masks = draw_bounding_boxes( filtered_nodes, '4.png'
                               , '4.a.png'
    masks = draw_bounding_boxes( filtered_nodes, 'selection_sorted(imaged).png'
                               , 'selection_sorted(imaged).ai.png'
                               )

    # print(masks)
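A usage sketch of the extended signature (the XML file name is invented; the import path follows agent.py below):

```python
import xml.etree.ElementTree as ET

from mm_agents.accessibility_tree_wrap.heuristic_retrieve import filter_nodes

# check_image=True additionally keeps nodes flagged image="true" by the
# AT-SPI dumper above, even when they expose no name or text.
with open("at_tree.xml", "r", encoding="utf-8") as f:  # hypothetical dump
    root = ET.fromstring(f.read())
nodes = filter_nodes(root, platform="ubuntu", check_image=True)
print(len(nodes))
```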
@@ -11,26 +11,17 @@ from http import HTTPStatus
from io import BytesIO
from typing import Dict, List
from google.api_core.exceptions import InvalidArgument

import backoff
import dashscope
import google.generativeai as genai
import requests
from PIL import Image
from vertexai.preview.generative_models import (
    HarmBlockThreshold,
    HarmCategory,
    Image,
)

from mm_agents.accessibility_tree_wrap.heuristic_retrieve import find_leaf_nodes, filter_nodes, draw_bounding_boxes
from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION, \
    SYS_PROMPT_IN_A11Y_OUT_CODE, SYS_PROMPT_IN_A11Y_OUT_ACTION, \
    SYS_PROMPT_IN_BOTH_OUT_CODE, SYS_PROMPT_IN_BOTH_OUT_ACTION, \
    SYS_PROMPT_IN_SOM_A11Y_OUT_TAG, \
    SYS_PROMPT_SEEACT, ACTION_DESCRIPTION_PROMPT_SEEACT, ACTION_GROUNDING_PROMPT_SEEACT

# todo: cross-check with visualwebarena
    SYS_PROMPT_IN_SOM_OUT_TAG

logger = logging.getLogger("desktopenv.agent")
@@ -45,7 +36,7 @@ def linearize_accessibility_tree(accessibility_tree):
    # leaf_nodes = find_leaf_nodes(accessibility_tree)
    filtered_nodes = filter_nodes(ET.fromstring(accessibility_tree))

    linearized_accessibility_tree = "tag\tname\ttext\tposition\tsize\n"
    linearized_accessibility_tree = "tag\tname\ttext\tposition (top-left x&y)\tsize (w&h)\n"
    # Linearize the accessibility tree nodes into a table format

    for node in filtered_nodes:
@@ -73,7 +64,8 @@ def tag_screenshot(screenshot, accessibility_tree):
    uuid_str = str(uuid.uuid4())
    os.makedirs("tmp/images", exist_ok=True)
    tagged_screenshot_file_path = os.path.join("tmp/images", uuid_str + ".png")
    nodes = filter_nodes(find_leaf_nodes(accessibility_tree))
    # nodes = filter_nodes(find_leaf_nodes(accessibility_tree))
    nodes = filter_nodes(ET.fromstring(accessibility_tree), check_image=True)
    # Make tag screenshot
    marks, drew_nodes = draw_bounding_boxes(nodes, screenshot, tagged_screenshot_file_path)
@@ -178,7 +170,7 @@ class PromptAgent:
            temperature=0.5,
            action_space="computer_13",
            observation_type="screenshot_a11y_tree",
            # observation_type can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som", "seeact"]
            # observation_type can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"]
            max_trajectory_length=3
    ):
        self.model = model

@@ -207,7 +199,7 @@ class PromptAgent:
                self.system_message = SYS_PROMPT_IN_A11Y_OUT_CODE
            else:
                raise ValueError("Invalid action space: " + action_space)
        elif observation_type == "both":
        elif observation_type == "screenshot_a11y_tree":
            if action_space == "computer_13":
                self.system_message = SYS_PROMPT_IN_BOTH_OUT_ACTION
            elif action_space == "pyautogui":

@@ -218,14 +210,7 @@ class PromptAgent:
            if action_space == "computer_13":
                raise ValueError("Invalid action space: " + action_space)
            elif action_space == "pyautogui":
                self.system_message = SYS_PROMPT_IN_SOM_A11Y_OUT_TAG
            else:
                raise ValueError("Invalid action space: " + action_space)
        elif observation_type == "seeact":
            if action_space == "computer_13":
                raise ValueError("Invalid action space: " + action_space)
            elif action_space == "pyautogui":
                self.system_message = SYS_PROMPT_SEEACT
                self.system_message = SYS_PROMPT_IN_SOM_OUT_TAG
            else:
                raise ValueError("Invalid action space: " + action_space)
        else:

@@ -235,8 +220,7 @@ class PromptAgent:
        """
        Predict the next action(s) based on the current observation.
        """
        self.system_message = self.system_message + "\nYou are asked to complete the following task: {}".format(
            instruction)
        system_message = self.system_message + "\nYou are asked to complete the following task: {}".format(instruction)

        # Prepare the payload for the API call
        messages = []

@@ -247,7 +231,7 @@ class PromptAgent:
                "content": [
                    {
                        "type": "text",
                        "text": self.system_message
                        "text": system_message
                    },
                ]
            })
@@ -268,7 +252,7 @@ class PromptAgent:
        for previous_obs, previous_action, previous_thought in zip(_observations, _actions, _thoughts):

            # {{{1
            if self.observation_type == "both":
            if self.observation_type == "screenshot_a11y_tree":
                _screenshot = previous_obs["screenshot"]
                _linearized_accessibility_tree = previous_obs["accessibility_tree"]
                logger.debug("LINEAR AT: %s", _linearized_accessibility_tree)

@@ -290,18 +274,15 @@ class PromptAgent:
                        }
                    ]
                })
            elif self.observation_type in ["som", "seeact"]:
            elif self.observation_type in ["som"]:
                _screenshot = previous_obs["screenshot"]
                _linearized_accessibility_tree = previous_obs["accessibility_tree"]
                logger.debug("LINEAR AT: %s", _linearized_accessibility_tree)

                messages.append({
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": "Given the tagged screenshot and info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format(
                                _linearized_accessibility_tree)
                            "text": "Given the tagged screenshot as below. What's the next step that you will do to help with the task?"
                        },
                        {
                            "type": "image_url",

@@ -358,11 +339,11 @@ class PromptAgent:
                })

        # {{{1
        if self.observation_type in ["screenshot", "both"]:
        if self.observation_type in ["screenshot", "screenshot_a11y_tree"]:
            base64_image = encode_image(obs["screenshot"])
            linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])

            if self.observation_type == "both":
            if self.observation_type == "screenshot_a11y_tree":
                self.observations.append({
                    "screenshot": base64_image,
                    "accessibility_tree": linearized_accessibility_tree

@@ -414,11 +395,9 @@ class PromptAgent:
            # Add som to the screenshot
            masks, drew_nodes, tagged_screenshot = tag_screenshot(obs["screenshot"], obs["accessibility_tree"])
            base64_image = encode_image(tagged_screenshot)
            linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])

            self.observations.append({
                "screenshot": base64_image,
                "accessibility_tree": linearized_accessibility_tree
                "screenshot": base64_image
            })

            messages.append({
@@ -426,35 +405,7 @@ class PromptAgent:
                "content": [
                    {
                        "type": "text",
                        "text": "Given the tagged screenshot and info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format(
                            linearized_accessibility_tree)
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/png;base64,{base64_image}",
                            "detail": "high"
                        }
                    }
                ]
            })
        elif self.observation_type == "seeact":
            # Add som to the screenshot
            masks, drew_nodes, tagged_screenshot = tag_screenshot(obs["screenshot"], obs["accessibility_tree"])
            base64_image = encode_image(tagged_screenshot)
            linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])

            self.observations.append({
                "screenshot": base64_image,
                "accessibility_tree": linearized_accessibility_tree
            })

            messages.append({
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": ACTION_DESCRIPTION_PROMPT_SEEACT.format(linearized_accessibility_tree)
                        "text": "Given the tagged screenshot as below. What's the next step that you will do to help with the task?"
                    },
                    {
                        "type": "image_url",
@@ -475,43 +426,13 @@ class PromptAgent:
        response = self.call_llm({
            "model": self.model,
            "messages": messages,
            "max_tokens": self.max_tokens
            "max_tokens": self.max_tokens,
            "top_p": self.top_p,
            "temperature": self.temperature
        })

        logger.info("RESPONSE: %s", response)

        if self.observation_type == "seeact":
            messages.append({
                "role": "assistant",
                "content": [
                    {
                        "type": "text",
                        "text": response
                    }
                ]
            })

            messages.append({
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "{}\n\nWhat's the next step that you will do to help with the task?".format(
                            ACTION_GROUNDING_PROMPT_SEEACT)
                    }
                ]
            })

            logger.info("Generating content with GPT model: %s", self.model)
            response = self.call_llm({
                "model": self.model,
                "messages": messages,
                "max_tokens": self.max_tokens,
                "top_p": self.top_p,
                "temperature": self.temperature
            })
            logger.info("RESPONSE: %s", response)

        try:
            actions = self.parse_actions(response, masks)
            self.thoughts.append(response)

@@ -524,16 +445,15 @@ class PromptAgent:

    @backoff.on_exception(
        backoff.expo,
        # here you should add more model exceptions as you want,
        # but you are forbidden to add "Exception", that is, a common type of exception
        # because we want to catch this kind of Exception in the outside to ensure each example won't exceed the time limit
        (openai.RateLimitError,
         openai.BadRequestError,
         openai.InternalServerError,
         InvalidArgument),
        max_tries=5
    )

    def call_llm(self, payload):

        if self.model.startswith("gpt"):
@@ -551,14 +471,15 @@ class PromptAgent:
            if response.status_code != 200:
                if response.json()['error']['code'] == "context_length_exceeded":
                    logger.error("Context length exceeded. Retrying with a smaller context.")
                    payload["messages"] = payload["messages"][-1:]
                    payload["messages"] = [payload["messages"][0]] + payload["messages"][-1:]
                    retry_response = requests.post(
                        "https://api.openai.com/v1/chat/completions",
                        headers=headers,
                        json=payload
                    )
                    if retry_response.status_code != 200:
                        logger.error("Failed to call LLM: " + retry_response.text)
                        logger.error(
                            "Failed to call LLM even after attempt on shortening the history: " + retry_response.text)
                        return ""

                logger.error("Failed to call LLM: " + response.text)

@@ -665,8 +586,9 @@ class PromptAgent:
            for message in gemini_messages:
                message_history_str += "<|" + message['role'] + "|>\n" + message['parts'][0] + "\n"
            gemini_messages = [{"role": "user", "parts": [message_history_str, gemini_messages[-1]['parts'][1]]}]
            # gemini_messages[-1]['parts'][1].save("output.png", "PNG")

            print(gemini_messages)
            # print(gemini_messages)
            api_key = os.environ.get("GENAI_API_KEY")
            assert api_key is not None, "Please set the GENAI_API_KEY environment variable"
            genai.configure(api_key=api_key)

@@ -680,11 +602,10 @@ class PromptAgent:
                    "temperature": temperature
                },
                safety_settings={
                    HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE,
                    HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
                    HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
                    HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
                    HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
                    "harassment": "block_none",
                    "hate": "block_none",
                    "sex": "block_none",
                    "danger": "block_none"
                }
            )

@@ -735,7 +656,7 @@ class PromptAgent:

    def parse_actions(self, response: str, masks=None):

        if self.observation_type in ["screenshot", "a11y_tree", "both"]:
        if self.observation_type in ["screenshot", "a11y_tree", "screenshot_a11y_tree"]:
            # parse from the response
            if self.action_space == "computer_13":
                actions = parse_actions_from_string(response)

@@ -747,7 +668,7 @@ class PromptAgent:
            self.actions.append(actions)

            return actions
        elif self.observation_type in ["som", "seeact"]:
        elif self.observation_type in ["som"]:
            # parse from the response
            if self.action_space == "computer_13":
                raise ValueError("Invalid action space: " + self.action_space)
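The context-length fallback above keeps the system prompt plus only the newest message. A distilled sketch of that retry pattern (function name illustrative):

```python
import requests

def post_with_context_fallback(url, headers, payload):
    # First attempt; on context_length_exceeded, retry with the system
    # message plus only the most recent message, as in call_llm above.
    response = requests.post(url, headers=headers, json=payload)
    if response.status_code != 200 and \
            response.json().get("error", {}).get("code") == "context_length_exceeded":
        payload["messages"] = [payload["messages"][0]] + payload["messages"][-1:]
        response = requests.post(url, headers=headers, json=payload)
    return response
```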
@@ -798,10 +798,10 @@ You MUST choose and ONLY CHOOSE from the action space above, otherwise your acti
You CAN predict multiple actions at one step, but you should only return one action for each step.
""".strip()

SYS_PROMPT_IN_SOM_A11Y_OUT_TAG = """
SYS_PROMPT_IN_SOM_OUT_TAG = """
You are an agent which follow my instruction and perform desktop computer tasks as instructed.
You have good knowledge of computer and good internet connection and assume your code will run on a computer for controlling the mouse and keyboard.
For each step, you will get an observation of the desktop by 1) a screenshot; and 2) accessibility tree, which is based on AT-SPI library.
For each step, you will get an observation of the desktop by a screenshot with interact-able elements marked with numerical tags. And you will predict the action of the computer based on the image.

You are required to use `pyautogui` to perform the action grounded to the observation, but DONOT use the `pyautogui.locateCenterOnScreen` function to locate the element you want to operate with since we have no image of the element you want to operate with. DONOT USE `pyautogui.screenshot()` to make screenshot.
You can replace x, y in the code with the tag of the element you want to operate with. such as:
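The concrete example after "such as:" falls outside this hunk. Purely as an illustration (invented here, not the repository's actual prompt text), tag-grounded code could look like:

```python
import pyautogui

# Hypothetical: "12" is a numerical tag drawn on the screenshot; the agent
# framework substitutes the tagged element's screen coordinates before execution.
pyautogui.click(12)  # instead of pyautogui.click(x, y)
```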
run.py
@@ -7,7 +7,8 @@ import json
import logging
import os
import sys
# import signal

from tqdm import tqdm
import time
import timeout_decorator

@@ -48,6 +49,7 @@ logger.addHandler(sdebug_handler)

logger = logging.getLogger("desktopenv.experiment")


# make sure each example won't exceed the time limit
# def handler(signo, frame):
#     raise RuntimeError("Time limit exceeded!")

@@ -73,7 +75,7 @@ def config() -> argparse.Namespace:
            "screenshot_a11y_tree",
            "som"
        ],
        default="a11y_tree",
        default="som",
        help="Observation type",
    )
    parser.add_argument("--screen_width", type=int, default=1920)

@@ -126,8 +128,8 @@ def test(
        headless=args.headless,
    )

    for domain in test_all_meta:
        for example_id in test_all_meta[domain]:
    for domain in tqdm(test_all_meta, desc="Domain"):
        for example_id in tqdm(test_all_meta[domain], desc="Example", leave=False):
            # example setting
            config_file = os.path.join(args.test_config_base_dir, f"examples/{domain}/{example_id}.json")
            with open(config_file, "r", encoding="utf-8") as f:

@@ -169,7 +171,7 @@ def test(
            action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
            logger.info("Step %d: %s", step_idx + 1, action)

            observation, reward, done, info = env.step(action, args.sleep_after_execution)
            obs, reward, done, info = env.step(action, args.sleep_after_execution)

            logger.info("Reward: %.2f", reward)
            logger.info("Done: %s", done)

@@ -177,8 +179,8 @@ def test(

            # Save screenshot and trajectory information
            with open(os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"),
                      "wb") as _f:
                with open(observation['screenshot'], "rb") as __f:
                with open(obs['screenshot'], "rb") as __f:
                    screenshot = __f.read()
                _f.write(screenshot)

@@ -198,10 +200,12 @@ def test(
                logger.info("The episode is done.")
                break
            step_idx += 1

        result = env.evaluate()
        logger.info("Result: %.2f", result)
        scores.append(result)
        with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f:
            f.write(f"{result}\n")
        env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))

        # example start running
@@ -218,18 +222,23 @@ def test(
                }))
                f.write("\n")
            continue
        except Exception as e:
            logger.error(f"Error in example {domain}/{example_id}: {e}")
            continue

    env.close()
    logger.info(f"Average score: {sum(scores) / len(scores)}")


def get_unfinished(action_space, use_model, observation_type, result_dir, total_file_json):
    target_dir = os.path.join(result_dir, action_space, observation_type, use_model)

    if not os.path.exists(target_dir):
        return total_file_json

    finished = {}
    for domain in os.listdir(target_dir):
        finished[domain] = []
        domain_path = os.path.join(target_dir, domain)
        if os.path.isdir(domain_path):

@@ -245,13 +254,14 @@ def get_unfinished(action_space, use_model, observation_type, result_dir, total_

    if not finished:
        return total_file_json

    for domain, examples in finished.items():
        if domain in total_file_json:
            total_file_json[domain] = [x for x in total_file_json[domain] if x not in examples]

    return total_file_json


if __name__ == '__main__':
    ####### The complete version of the list of examples #######
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

@@ -260,11 +270,16 @@ if __name__ == '__main__':
    with open("evaluation_examples/test_all.json", "r", encoding="utf-8") as f:
        test_all_meta = json.load(f)

    test_file_list = get_unfinished(args.action_space, args.model, args.observation_type, args.result_dir, test_all_meta)
    test_file_list = get_unfinished(
        args.action_space,
        args.model,
        args.observation_type,
        args.result_dir,
        test_all_meta
    )
    left_info = ""
    for domain in test_file_list:
        left_info += f"{domain}: {len(test_file_list[domain])}\n"
    logger.info(f"Left tasks:\n{left_info}")

    os.environ['OPENAI_API_KEY'] = "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"
    test(args, test_all_meta)