update timer
This commit is contained in:
@@ -21,10 +21,12 @@
|
|||||||
Please refer to [guidance](https://docs.google.com/document/d/1KBdeZwmZs2Vi_Wsnngb3Wf1-RiwMMpXTftwMqP2Ztak/edit#heading=h.uh0x0tkl7fuw)
|
Please refer to [guidance](https://docs.google.com/document/d/1KBdeZwmZs2Vi_Wsnngb3Wf1-RiwMMpXTftwMqP2Ztak/edit#heading=h.uh0x0tkl7fuw)
|
||||||
|
|
||||||
2. Install the environment package, download the examples and the virtual machine image.
|
2. Install the environment package, download the examples and the virtual machine image.
|
||||||
|
For x86_64 Linux or Windows, you can install the environment package and download the examples and the virtual machine image by running the following commands:
|
||||||
```bash
|
```bash
|
||||||
pip install desktop-env
|
pip install desktop-env
|
||||||
gdown xxxx
|
gdown xxxx
|
||||||
gdown xxxx
|
vmrun -T ws start "Ubuntu/Ubuntu.vmx" nogui
|
||||||
|
vmrun -T ws snapshot "Ubuntu/Ubuntu.vmx" "init_state"
|
||||||
```
|
```
|
||||||
|
|
||||||
## Quick Start
|
## Quick Start
|
||||||
|
|||||||
@@ -284,6 +284,15 @@ def _create_atspi_node(node: Accessible, depth: int = 0, flag: Optional[str] = N
|
|||||||
text = text.replace("\ufffc", "").replace("\ufffd", "")
|
text = text.replace("\ufffc", "").replace("\ufffd", "")
|
||||||
# }}} Text #
|
# }}} Text #
|
||||||
|
|
||||||
|
# Image {{{ #
|
||||||
|
try:
|
||||||
|
node.queryImage()
|
||||||
|
except NotImplementedError:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
attribute_dict["image"] = "true"
|
||||||
|
# }}} Image #
|
||||||
|
|
||||||
# Selection {{{ #
|
# Selection {{{ #
|
||||||
try:
|
try:
|
||||||
node.querySelection()
|
node.querySelection()
|
||||||
|
|||||||
24
main.py
24
main.py
@@ -4,7 +4,7 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
|
import argparse
|
||||||
from desktop_env.envs.desktop_env import DesktopEnv
|
from desktop_env.envs.desktop_env import DesktopEnv
|
||||||
|
|
||||||
# Logger Configs {{{ #
|
# Logger Configs {{{ #
|
||||||
@@ -46,19 +46,29 @@ def human_agent():
|
|||||||
"""
|
"""
|
||||||
Runs the Gym environment with human input.
|
Runs the Gym environment with human input.
|
||||||
"""
|
"""
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('-p', '--path', type=str, default=r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu3\Ubuntu3.vmx", help="Path to the virtual machine .vmx file.")
|
||||||
|
parser.add_argument('-s', '--snapshot', type=str, default='init_state', help="Name of the snapshot to restore.")
|
||||||
|
parser.add_argument('-e', '--example', type=str, help="Path to the example json file.")
|
||||||
|
args = parser.parse_args(sys.argv[1:])
|
||||||
|
|
||||||
with open("evaluation_examples/examples/multi_apps/4c26e3f3-3a14-4d86-b44a-d3cedebbb487.json", "r", encoding="utf-8") as f:
|
example_path = args.example if args.example is not None and os.path.exists(args.example) else \
|
||||||
|
'evaluation_examples/examples/multi_apps/5990457f-2adb-467b-a4af-5c857c92d762.json'
|
||||||
|
with open(example_path, "r", encoding="utf-8") as f:
|
||||||
example = json.load(f)
|
example = json.load(f)
|
||||||
example["snapshot"] = "exp_v5"
|
if args.snapshot is not None:
|
||||||
|
example['snapshot'] = args.snapshot
|
||||||
|
|
||||||
|
assert os.path.exists(args.path), "The specified path to the .vmx file does not exist."
|
||||||
env = DesktopEnv(
|
env = DesktopEnv(
|
||||||
path_to_vm=r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu3\Ubuntu3.vmx",
|
path_to_vm=args.path,
|
||||||
action_space="computer_13",
|
snapshot_name=args.snapshot,
|
||||||
task_config=example
|
action_space="computer_13"
|
||||||
)
|
)
|
||||||
# reset the environment to certain snapshot
|
# reset the environment to certain snapshot
|
||||||
observation = env.reset()
|
observation = env.reset(task_config=example)
|
||||||
done = False
|
done = False
|
||||||
|
logger.info('\x1b[32m[TASK INSTRUCTION]: \x1b[32;3m%s\x1b[0m', example["instruction"])
|
||||||
|
|
||||||
trajectory = [
|
trajectory = [
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ def find_leaf_nodes(xlm_file_str):
|
|||||||
|
|
||||||
state_ns = "uri:deskat:state.at-spi.gnome.org"
|
state_ns = "uri:deskat:state.at-spi.gnome.org"
|
||||||
component_ns = "uri:deskat:component.at-spi.gnome.org"
|
component_ns = "uri:deskat:component.at-spi.gnome.org"
|
||||||
def judge_node(node: ET, platform="ubuntu") -> bool:
|
def judge_node(node: ET, platform="ubuntu", check_image=False) -> bool:
|
||||||
keeps: bool = node.tag.startswith("document")\
|
keeps: bool = node.tag.startswith("document")\
|
||||||
or node.tag.endswith("item")\
|
or node.tag.endswith("item")\
|
||||||
or node.tag.endswith("button")\
|
or node.tag.endswith("button")\
|
||||||
@@ -60,18 +60,20 @@ def judge_node(node: ET, platform="ubuntu") -> bool:
|
|||||||
or node.get("{{{:}}}expandable".format(state_ns), "false")=="true"\
|
or node.get("{{{:}}}expandable".format(state_ns), "false")=="true"\
|
||||||
or node.get("{{{:}}}checkable".format(state_ns), "false")=="true"
|
or node.get("{{{:}}}checkable".format(state_ns), "false")=="true"
|
||||||
)\
|
)\
|
||||||
and (node.get("name", "") != "" or node.text is not None and len(node.text)>0)
|
and ( node.get("name", "") != "" or node.text is not None and len(node.text)>0\
|
||||||
|
or check_image and node.get("image", "false")=="true"
|
||||||
|
)
|
||||||
|
|
||||||
coordinates: Tuple[int, int] = eval(node.get("{{{:}}}screencoord".format(component_ns), "(-1, -1)"))
|
coordinates: Tuple[int, int] = eval(node.get("{{{:}}}screencoord".format(component_ns), "(-1, -1)"))
|
||||||
sizes: Tuple[int, int] = eval(node.get("{{{:}}}size".format(component_ns), "(-1, -1)"))
|
sizes: Tuple[int, int] = eval(node.get("{{{:}}}size".format(component_ns), "(-1, -1)"))
|
||||||
keeps = keeps and coordinates[0]>0 and coordinates[1]>0 and sizes[0]>0 and sizes[1]>0
|
keeps = keeps and coordinates[0]>0 and coordinates[1]>0 and sizes[0]>0 and sizes[1]>0
|
||||||
return keeps
|
return keeps
|
||||||
|
|
||||||
def filter_nodes(root: ET, platform="ubuntu"):
|
def filter_nodes(root: ET, platform="ubuntu", check_image=False):
|
||||||
filtered_nodes = []
|
filtered_nodes = []
|
||||||
|
|
||||||
for node in root.iter():
|
for node in root.iter():
|
||||||
if judge_node(node, platform):
|
if judge_node(node, platform, check_image):
|
||||||
filtered_nodes.append(node)
|
filtered_nodes.append(node)
|
||||||
#print(ET.tostring(node, encoding="unicode"))
|
#print(ET.tostring(node, encoding="unicode"))
|
||||||
|
|
||||||
@@ -155,12 +157,12 @@ def print_nodes_with_indent(nodes, indent=0):
|
|||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
import json
|
import json
|
||||||
with open('4.json', 'r', encoding='utf-8') as f:
|
with open('selection_sorted(imaged).xml', 'r', encoding='utf-8') as f:
|
||||||
xml_file_str = json.load(f)["AT"]
|
xml_file_str = f.read()
|
||||||
filtered_nodes = filter_nodes(ET.fromstring(xml_file_str))
|
filtered_nodes = filter_nodes(ET.fromstring(xml_file_str))
|
||||||
print(len(filtered_nodes))
|
print(len(filtered_nodes))
|
||||||
masks = draw_bounding_boxes( filtered_nodes, '4.png'
|
masks = draw_bounding_boxes( filtered_nodes, 'selection_sorted(imaged).png'
|
||||||
, '4.a.png'
|
, 'selection_sorted(imaged).ai.png'
|
||||||
)
|
)
|
||||||
|
|
||||||
# print(masks)
|
# print(masks)
|
||||||
|
|||||||
@@ -11,26 +11,17 @@ from http import HTTPStatus
|
|||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
from google.api_core.exceptions import InvalidArgument
|
from google.api_core.exceptions import InvalidArgument
|
||||||
|
|
||||||
import backoff
|
import backoff
|
||||||
import dashscope
|
import dashscope
|
||||||
import google.generativeai as genai
|
import google.generativeai as genai
|
||||||
import requests
|
import requests
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from vertexai.preview.generative_models import (
|
|
||||||
HarmBlockThreshold,
|
|
||||||
HarmCategory,
|
|
||||||
Image,
|
|
||||||
)
|
|
||||||
|
|
||||||
from mm_agents.accessibility_tree_wrap.heuristic_retrieve import find_leaf_nodes, filter_nodes, draw_bounding_boxes
|
from mm_agents.accessibility_tree_wrap.heuristic_retrieve import find_leaf_nodes, filter_nodes, draw_bounding_boxes
|
||||||
from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION, \
|
from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION, \
|
||||||
SYS_PROMPT_IN_A11Y_OUT_CODE, SYS_PROMPT_IN_A11Y_OUT_ACTION, \
|
SYS_PROMPT_IN_A11Y_OUT_CODE, SYS_PROMPT_IN_A11Y_OUT_ACTION, \
|
||||||
SYS_PROMPT_IN_BOTH_OUT_CODE, SYS_PROMPT_IN_BOTH_OUT_ACTION, \
|
SYS_PROMPT_IN_BOTH_OUT_CODE, SYS_PROMPT_IN_BOTH_OUT_ACTION, \
|
||||||
SYS_PROMPT_IN_SOM_A11Y_OUT_TAG, \
|
SYS_PROMPT_IN_SOM_OUT_TAG
|
||||||
SYS_PROMPT_SEEACT, ACTION_DESCRIPTION_PROMPT_SEEACT, ACTION_GROUNDING_PROMPT_SEEACT
|
|
||||||
|
|
||||||
# todo: cross-check with visualwebarena
|
|
||||||
|
|
||||||
logger = logging.getLogger("desktopenv.agent")
|
logger = logging.getLogger("desktopenv.agent")
|
||||||
|
|
||||||
@@ -45,7 +36,7 @@ def linearize_accessibility_tree(accessibility_tree):
|
|||||||
# leaf_nodes = find_leaf_nodes(accessibility_tree)
|
# leaf_nodes = find_leaf_nodes(accessibility_tree)
|
||||||
filtered_nodes = filter_nodes(ET.fromstring(accessibility_tree))
|
filtered_nodes = filter_nodes(ET.fromstring(accessibility_tree))
|
||||||
|
|
||||||
linearized_accessibility_tree = "tag\tname\ttext\tposition\tsize\n"
|
linearized_accessibility_tree = "tag\tname\ttext\tposition (top-left x&y)\tsize (w&h)\n"
|
||||||
# Linearize the accessibility tree nodes into a table format
|
# Linearize the accessibility tree nodes into a table format
|
||||||
|
|
||||||
for node in filtered_nodes:
|
for node in filtered_nodes:
|
||||||
@@ -73,7 +64,8 @@ def tag_screenshot(screenshot, accessibility_tree):
|
|||||||
uuid_str = str(uuid.uuid4())
|
uuid_str = str(uuid.uuid4())
|
||||||
os.makedirs("tmp/images", exist_ok=True)
|
os.makedirs("tmp/images", exist_ok=True)
|
||||||
tagged_screenshot_file_path = os.path.join("tmp/images", uuid_str + ".png")
|
tagged_screenshot_file_path = os.path.join("tmp/images", uuid_str + ".png")
|
||||||
nodes = filter_nodes(find_leaf_nodes(accessibility_tree))
|
# nodes = filter_nodes(find_leaf_nodes(accessibility_tree))
|
||||||
|
nodes = filter_nodes(ET.fromstring(accessibility_tree), check_image=True)
|
||||||
# Make tag screenshot
|
# Make tag screenshot
|
||||||
marks, drew_nodes = draw_bounding_boxes(nodes, screenshot, tagged_screenshot_file_path)
|
marks, drew_nodes = draw_bounding_boxes(nodes, screenshot, tagged_screenshot_file_path)
|
||||||
|
|
||||||
@@ -178,7 +170,7 @@ class PromptAgent:
|
|||||||
temperature=0.5,
|
temperature=0.5,
|
||||||
action_space="computer_13",
|
action_space="computer_13",
|
||||||
observation_type="screenshot_a11y_tree",
|
observation_type="screenshot_a11y_tree",
|
||||||
# observation_type can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som", "seeact"]
|
# observation_type can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"]
|
||||||
max_trajectory_length=3
|
max_trajectory_length=3
|
||||||
):
|
):
|
||||||
self.model = model
|
self.model = model
|
||||||
@@ -207,7 +199,7 @@ class PromptAgent:
|
|||||||
self.system_message = SYS_PROMPT_IN_A11Y_OUT_CODE
|
self.system_message = SYS_PROMPT_IN_A11Y_OUT_CODE
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid action space: " + action_space)
|
raise ValueError("Invalid action space: " + action_space)
|
||||||
elif observation_type == "both":
|
elif observation_type == "screenshot_a11y_tree":
|
||||||
if action_space == "computer_13":
|
if action_space == "computer_13":
|
||||||
self.system_message = SYS_PROMPT_IN_BOTH_OUT_ACTION
|
self.system_message = SYS_PROMPT_IN_BOTH_OUT_ACTION
|
||||||
elif action_space == "pyautogui":
|
elif action_space == "pyautogui":
|
||||||
@@ -218,14 +210,7 @@ class PromptAgent:
|
|||||||
if action_space == "computer_13":
|
if action_space == "computer_13":
|
||||||
raise ValueError("Invalid action space: " + action_space)
|
raise ValueError("Invalid action space: " + action_space)
|
||||||
elif action_space == "pyautogui":
|
elif action_space == "pyautogui":
|
||||||
self.system_message = SYS_PROMPT_IN_SOM_A11Y_OUT_TAG
|
self.system_message = SYS_PROMPT_IN_SOM_OUT_TAG
|
||||||
else:
|
|
||||||
raise ValueError("Invalid action space: " + action_space)
|
|
||||||
elif observation_type == "seeact":
|
|
||||||
if action_space == "computer_13":
|
|
||||||
raise ValueError("Invalid action space: " + action_space)
|
|
||||||
elif action_space == "pyautogui":
|
|
||||||
self.system_message = SYS_PROMPT_SEEACT
|
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid action space: " + action_space)
|
raise ValueError("Invalid action space: " + action_space)
|
||||||
else:
|
else:
|
||||||
@@ -235,8 +220,7 @@ class PromptAgent:
|
|||||||
"""
|
"""
|
||||||
Predict the next action(s) based on the current observation.
|
Predict the next action(s) based on the current observation.
|
||||||
"""
|
"""
|
||||||
self.system_message = self.system_message + "\nYou are asked to complete the following task: {}".format(
|
system_message = self.system_message + "\nYou are asked to complete the following task: {}".format(instruction)
|
||||||
instruction)
|
|
||||||
|
|
||||||
# Prepare the payload for the API call
|
# Prepare the payload for the API call
|
||||||
messages = []
|
messages = []
|
||||||
@@ -247,7 +231,7 @@ class PromptAgent:
|
|||||||
"content": [
|
"content": [
|
||||||
{
|
{
|
||||||
"type": "text",
|
"type": "text",
|
||||||
"text": self.system_message
|
"text": system_message
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
})
|
})
|
||||||
@@ -268,7 +252,7 @@ class PromptAgent:
|
|||||||
for previous_obs, previous_action, previous_thought in zip(_observations, _actions, _thoughts):
|
for previous_obs, previous_action, previous_thought in zip(_observations, _actions, _thoughts):
|
||||||
|
|
||||||
# {{{1
|
# {{{1
|
||||||
if self.observation_type == "both":
|
if self.observation_type == "screenshot_a11y_tree":
|
||||||
_screenshot = previous_obs["screenshot"]
|
_screenshot = previous_obs["screenshot"]
|
||||||
_linearized_accessibility_tree = previous_obs["accessibility_tree"]
|
_linearized_accessibility_tree = previous_obs["accessibility_tree"]
|
||||||
logger.debug("LINEAR AT: %s", _linearized_accessibility_tree)
|
logger.debug("LINEAR AT: %s", _linearized_accessibility_tree)
|
||||||
@@ -290,18 +274,15 @@ class PromptAgent:
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
})
|
})
|
||||||
elif self.observation_type in ["som", "seeact"]:
|
elif self.observation_type in ["som"]:
|
||||||
_screenshot = previous_obs["screenshot"]
|
_screenshot = previous_obs["screenshot"]
|
||||||
_linearized_accessibility_tree = previous_obs["accessibility_tree"]
|
|
||||||
logger.debug("LINEAR AT: %s", _linearized_accessibility_tree)
|
|
||||||
|
|
||||||
messages.append({
|
messages.append({
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": [
|
"content": [
|
||||||
{
|
{
|
||||||
"type": "text",
|
"type": "text",
|
||||||
"text": "Given the tagged screenshot and info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format(
|
"text": "Given the tagged screenshot as below. What's the next step that you will do to help with the task?"
|
||||||
_linearized_accessibility_tree)
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"type": "image_url",
|
"type": "image_url",
|
||||||
@@ -358,11 +339,11 @@ class PromptAgent:
|
|||||||
})
|
})
|
||||||
|
|
||||||
# {{{1
|
# {{{1
|
||||||
if self.observation_type in ["screenshot", "both"]:
|
if self.observation_type in ["screenshot", "screenshot_a11y_tree"]:
|
||||||
base64_image = encode_image(obs["screenshot"])
|
base64_image = encode_image(obs["screenshot"])
|
||||||
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])
|
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])
|
||||||
|
|
||||||
if self.observation_type == "both":
|
if self.observation_type == "screenshot_a11y_tree":
|
||||||
self.observations.append({
|
self.observations.append({
|
||||||
"screenshot": base64_image,
|
"screenshot": base64_image,
|
||||||
"accessibility_tree": linearized_accessibility_tree
|
"accessibility_tree": linearized_accessibility_tree
|
||||||
@@ -414,11 +395,9 @@ class PromptAgent:
|
|||||||
# Add som to the screenshot
|
# Add som to the screenshot
|
||||||
masks, drew_nodes, tagged_screenshot = tag_screenshot(obs["screenshot"], obs["accessibility_tree"])
|
masks, drew_nodes, tagged_screenshot = tag_screenshot(obs["screenshot"], obs["accessibility_tree"])
|
||||||
base64_image = encode_image(tagged_screenshot)
|
base64_image = encode_image(tagged_screenshot)
|
||||||
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])
|
|
||||||
|
|
||||||
self.observations.append({
|
self.observations.append({
|
||||||
"screenshot": base64_image,
|
"screenshot": base64_image
|
||||||
"accessibility_tree": linearized_accessibility_tree
|
|
||||||
})
|
})
|
||||||
|
|
||||||
messages.append({
|
messages.append({
|
||||||
@@ -426,35 +405,7 @@ class PromptAgent:
|
|||||||
"content": [
|
"content": [
|
||||||
{
|
{
|
||||||
"type": "text",
|
"type": "text",
|
||||||
"text": "Given the tagged screenshot and info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format(
|
"text": "Given the tagged screenshot as below. What's the next step that you will do to help with the task?"
|
||||||
linearized_accessibility_tree)
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "image_url",
|
|
||||||
"image_url": {
|
|
||||||
"url": f"data:image/png;base64,{base64_image}",
|
|
||||||
"detail": "high"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
})
|
|
||||||
elif self.observation_type == "seeact":
|
|
||||||
# Add som to the screenshot
|
|
||||||
masks, drew_nodes, tagged_screenshot = tag_screenshot(obs["screenshot"], obs["accessibility_tree"])
|
|
||||||
base64_image = encode_image(tagged_screenshot)
|
|
||||||
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])
|
|
||||||
|
|
||||||
self.observations.append({
|
|
||||||
"screenshot": base64_image,
|
|
||||||
"accessibility_tree": linearized_accessibility_tree
|
|
||||||
})
|
|
||||||
|
|
||||||
messages.append({
|
|
||||||
"role": "user",
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": ACTION_DESCRIPTION_PROMPT_SEEACT.format(linearized_accessibility_tree)
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"type": "image_url",
|
"type": "image_url",
|
||||||
@@ -475,43 +426,13 @@ class PromptAgent:
|
|||||||
response = self.call_llm({
|
response = self.call_llm({
|
||||||
"model": self.model,
|
"model": self.model,
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
"max_tokens": self.max_tokens
|
"max_tokens": self.max_tokens,
|
||||||
|
"top_p": self.top_p,
|
||||||
|
"temperature": self.temperature
|
||||||
})
|
})
|
||||||
|
|
||||||
logger.info("RESPONSE: %s", response)
|
logger.info("RESPONSE: %s", response)
|
||||||
|
|
||||||
if self.observation_type == "seeact":
|
|
||||||
messages.append({
|
|
||||||
"role": "assistant",
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": response
|
|
||||||
}
|
|
||||||
]
|
|
||||||
})
|
|
||||||
|
|
||||||
messages.append({
|
|
||||||
"role": "user",
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": "{}\n\nWhat's the next step that you will do to help with the task?".format(
|
|
||||||
ACTION_GROUNDING_PROMPT_SEEACT)
|
|
||||||
}
|
|
||||||
]
|
|
||||||
})
|
|
||||||
|
|
||||||
logger.info("Generating content with GPT model: %s", self.model)
|
|
||||||
response = self.call_llm({
|
|
||||||
"model": self.model,
|
|
||||||
"messages": messages,
|
|
||||||
"max_tokens": self.max_tokens,
|
|
||||||
"top_p": self.top_p,
|
|
||||||
"temperature": self.temperature
|
|
||||||
})
|
|
||||||
logger.info("RESPONSE: %s", response)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
actions = self.parse_actions(response, masks)
|
actions = self.parse_actions(response, masks)
|
||||||
self.thoughts.append(response)
|
self.thoughts.append(response)
|
||||||
@@ -528,12 +449,11 @@ class PromptAgent:
|
|||||||
# but you are forbidden to add "Exception", that is, a common type of exception
|
# but you are forbidden to add "Exception", that is, a common type of exception
|
||||||
# because we want to catch this kind of Exception in the outside to ensure each example won't exceed the time limit
|
# because we want to catch this kind of Exception in the outside to ensure each example won't exceed the time limit
|
||||||
(openai.RateLimitError,
|
(openai.RateLimitError,
|
||||||
openai.BadRequestError,
|
openai.BadRequestError,
|
||||||
openai.InternalServerError,
|
openai.InternalServerError,
|
||||||
InvalidArgument),
|
InvalidArgument),
|
||||||
max_tries=5
|
max_tries=5
|
||||||
)
|
)
|
||||||
|
|
||||||
def call_llm(self, payload):
|
def call_llm(self, payload):
|
||||||
|
|
||||||
if self.model.startswith("gpt"):
|
if self.model.startswith("gpt"):
|
||||||
@@ -551,14 +471,15 @@ class PromptAgent:
|
|||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
if response.json()['error']['code'] == "context_length_exceeded":
|
if response.json()['error']['code'] == "context_length_exceeded":
|
||||||
logger.error("Context length exceeded. Retrying with a smaller context.")
|
logger.error("Context length exceeded. Retrying with a smaller context.")
|
||||||
payload["messages"] = payload["messages"][-1:]
|
payload["messages"] = [payload["messages"][0]] + payload["messages"][-1:]
|
||||||
retry_response = requests.post(
|
retry_response = requests.post(
|
||||||
"https://api.openai.com/v1/chat/completions",
|
"https://api.openai.com/v1/chat/completions",
|
||||||
headers=headers,
|
headers=headers,
|
||||||
json=payload
|
json=payload
|
||||||
)
|
)
|
||||||
if retry_response.status_code != 200:
|
if retry_response.status_code != 200:
|
||||||
logger.error("Failed to call LLM: " + retry_response.text)
|
logger.error(
|
||||||
|
"Failed to call LLM even after attempt on shortening the history: " + retry_response.text)
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
logger.error("Failed to call LLM: " + response.text)
|
logger.error("Failed to call LLM: " + response.text)
|
||||||
@@ -665,8 +586,9 @@ class PromptAgent:
|
|||||||
for message in gemini_messages:
|
for message in gemini_messages:
|
||||||
message_history_str += "<|" + message['role'] + "|>\n" + message['parts'][0] + "\n"
|
message_history_str += "<|" + message['role'] + "|>\n" + message['parts'][0] + "\n"
|
||||||
gemini_messages = [{"role": "user", "parts": [message_history_str, gemini_messages[-1]['parts'][1]]}]
|
gemini_messages = [{"role": "user", "parts": [message_history_str, gemini_messages[-1]['parts'][1]]}]
|
||||||
|
# gemini_messages[-1]['parts'][1].save("output.png", "PNG")
|
||||||
|
|
||||||
print(gemini_messages)
|
# print(gemini_messages)
|
||||||
api_key = os.environ.get("GENAI_API_KEY")
|
api_key = os.environ.get("GENAI_API_KEY")
|
||||||
assert api_key is not None, "Please set the GENAI_API_KEY environment variable"
|
assert api_key is not None, "Please set the GENAI_API_KEY environment variable"
|
||||||
genai.configure(api_key=api_key)
|
genai.configure(api_key=api_key)
|
||||||
@@ -680,11 +602,10 @@ class PromptAgent:
|
|||||||
"temperature": temperature
|
"temperature": temperature
|
||||||
},
|
},
|
||||||
safety_settings={
|
safety_settings={
|
||||||
HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE,
|
"harassment": "block_none",
|
||||||
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
|
"hate": "block_none",
|
||||||
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
|
"sex": "block_none",
|
||||||
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
|
"danger": "block_none"
|
||||||
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -735,7 +656,7 @@ class PromptAgent:
|
|||||||
|
|
||||||
def parse_actions(self, response: str, masks=None):
|
def parse_actions(self, response: str, masks=None):
|
||||||
|
|
||||||
if self.observation_type in ["screenshot", "a11y_tree", "both"]:
|
if self.observation_type in ["screenshot", "a11y_tree", "screenshot_a11y_tree"]:
|
||||||
# parse from the response
|
# parse from the response
|
||||||
if self.action_space == "computer_13":
|
if self.action_space == "computer_13":
|
||||||
actions = parse_actions_from_string(response)
|
actions = parse_actions_from_string(response)
|
||||||
@@ -747,7 +668,7 @@ class PromptAgent:
|
|||||||
self.actions.append(actions)
|
self.actions.append(actions)
|
||||||
|
|
||||||
return actions
|
return actions
|
||||||
elif self.observation_type in ["som", "seeact"]:
|
elif self.observation_type in ["som"]:
|
||||||
# parse from the response
|
# parse from the response
|
||||||
if self.action_space == "computer_13":
|
if self.action_space == "computer_13":
|
||||||
raise ValueError("Invalid action space: " + self.action_space)
|
raise ValueError("Invalid action space: " + self.action_space)
|
||||||
|
|||||||
@@ -798,10 +798,10 @@ You MUST choose and ONLY CHOOSE from the action space above, otherwise your acti
|
|||||||
You CAN predict multiple actions at one step, but you should only return one action for each step.
|
You CAN predict multiple actions at one step, but you should only return one action for each step.
|
||||||
""".strip()
|
""".strip()
|
||||||
|
|
||||||
SYS_PROMPT_IN_SOM_A11Y_OUT_TAG = """
|
SYS_PROMPT_IN_SOM_OUT_TAG = """
|
||||||
You are an agent which follow my instruction and perform desktop computer tasks as instructed.
|
You are an agent which follow my instruction and perform desktop computer tasks as instructed.
|
||||||
You have good knowledge of computer and good internet connection and assume your code will run on a computer for controlling the mouse and keyboard.
|
You have good knowledge of computer and good internet connection and assume your code will run on a computer for controlling the mouse and keyboard.
|
||||||
For each step, you will get an observation of the desktop by 1) a screenshot; and 2) accessibility tree, which is based on AT-SPI library.
|
For each step, you will get an observation of the desktop by a screenshot with interact-able elements marked with numerical tags. And you will predict the action of the computer based on the image.
|
||||||
|
|
||||||
You are required to use `pyautogui` to perform the action grounded to the observation, but DONOT use the `pyautogui.locateCenterOnScreen` function to locate the element you want to operate with since we have no image of the element you want to operate with. DONOT USE `pyautogui.screenshot()` to make screenshot.
|
You are required to use `pyautogui` to perform the action grounded to the observation, but DONOT use the `pyautogui.locateCenterOnScreen` function to locate the element you want to operate with since we have no image of the element you want to operate with. DONOT USE `pyautogui.screenshot()` to make screenshot.
|
||||||
You can replace x, y in the code with the tag of the element you want to operate with. such as:
|
You can replace x, y in the code with the tag of the element you want to operate with. such as:
|
||||||
|
|||||||
33
run.py
33
run.py
@@ -7,7 +7,8 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
# import signal
|
|
||||||
|
from tqdm # import tqdm
|
||||||
import time
|
import time
|
||||||
import timeout_decorator
|
import timeout_decorator
|
||||||
|
|
||||||
@@ -48,6 +49,7 @@ logger.addHandler(sdebug_handler)
|
|||||||
|
|
||||||
logger = logging.getLogger("desktopenv.experiment")
|
logger = logging.getLogger("desktopenv.experiment")
|
||||||
|
|
||||||
|
|
||||||
# make sure each example won't exceed the time limit
|
# make sure each example won't exceed the time limit
|
||||||
# def handler(signo, frame):
|
# def handler(signo, frame):
|
||||||
# raise RuntimeError("Time limit exceeded!")
|
# raise RuntimeError("Time limit exceeded!")
|
||||||
@@ -73,7 +75,7 @@ def config() -> argparse.Namespace:
|
|||||||
"screenshot_a11y_tree",
|
"screenshot_a11y_tree",
|
||||||
"som"
|
"som"
|
||||||
],
|
],
|
||||||
default="a11y_tree",
|
default="som",
|
||||||
help="Observation type",
|
help="Observation type",
|
||||||
)
|
)
|
||||||
parser.add_argument("--screen_width", type=int, default=1920)
|
parser.add_argument("--screen_width", type=int, default=1920)
|
||||||
@@ -126,8 +128,8 @@ def test(
|
|||||||
headless=args.headless,
|
headless=args.headless,
|
||||||
)
|
)
|
||||||
|
|
||||||
for domain in test_all_meta:
|
for domain in tqdm(test_all_meta, desc="Domain"):
|
||||||
for example_id in test_all_meta[domain]:
|
for example_id in tqdm(test_all_meta[domain], desc="Example", leave=False):
|
||||||
# example setting
|
# example setting
|
||||||
config_file = os.path.join(args.test_config_base_dir, f"examples/{domain}/{example_id}.json")
|
config_file = os.path.join(args.test_config_base_dir, f"examples/{domain}/{example_id}.json")
|
||||||
with open(config_file, "r", encoding="utf-8") as f:
|
with open(config_file, "r", encoding="utf-8") as f:
|
||||||
@@ -169,7 +171,7 @@ def test(
|
|||||||
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
||||||
logger.info("Step %d: %s", step_idx + 1, action)
|
logger.info("Step %d: %s", step_idx + 1, action)
|
||||||
|
|
||||||
observation, reward, done, info = env.step(action, args.sleep_after_execution)
|
obs, reward, done, info = env.step(action, args.sleep_after_execution)
|
||||||
|
|
||||||
logger.info("Reward: %.2f", reward)
|
logger.info("Reward: %.2f", reward)
|
||||||
logger.info("Done: %s", done)
|
logger.info("Done: %s", done)
|
||||||
@@ -177,8 +179,8 @@ def test(
|
|||||||
|
|
||||||
# Save screenshot and trajectory information
|
# Save screenshot and trajectory information
|
||||||
with open(os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"),
|
with open(os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"),
|
||||||
"wb") as _f:
|
"wb") as _f:
|
||||||
with open(observation['screenshot'], "rb") as __f:
|
with open(obs['screenshot'], "rb") as __f:
|
||||||
screenshot = __f.read()
|
screenshot = __f.read()
|
||||||
_f.write(screenshot)
|
_f.write(screenshot)
|
||||||
|
|
||||||
@@ -202,6 +204,8 @@ def test(
|
|||||||
result = env.evaluate()
|
result = env.evaluate()
|
||||||
logger.info("Result: %.2f", result)
|
logger.info("Result: %.2f", result)
|
||||||
scores.append(result)
|
scores.append(result)
|
||||||
|
with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f:
|
||||||
|
f.write(f"{result}\n")
|
||||||
env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
|
env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
|
||||||
|
|
||||||
# example start running
|
# example start running
|
||||||
@@ -218,6 +222,10 @@ def test(
|
|||||||
}))
|
}))
|
||||||
f.write("\n")
|
f.write("\n")
|
||||||
continue
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in example {domain}/{example_id}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
env.close()
|
env.close()
|
||||||
logger.info(f"Average score: {sum(scores) / len(scores)}")
|
logger.info(f"Average score: {sum(scores) / len(scores)}")
|
||||||
|
|
||||||
@@ -230,6 +238,7 @@ def get_unfinished(action_space, use_model, observation_type, result_dir, total_
|
|||||||
|
|
||||||
finished = {}
|
finished = {}
|
||||||
for domain in os.listdir(target_dir):
|
for domain in os.listdir(target_dir):
|
||||||
|
finished[domain] = []
|
||||||
finished[domain] = []
|
finished[domain] = []
|
||||||
domain_path = os.path.join(target_dir, domain)
|
domain_path = os.path.join(target_dir, domain)
|
||||||
if os.path.isdir(domain_path):
|
if os.path.isdir(domain_path):
|
||||||
@@ -252,6 +261,7 @@ def get_unfinished(action_space, use_model, observation_type, result_dir, total_
|
|||||||
|
|
||||||
return total_file_json
|
return total_file_json
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
####### The complete version of the list of examples #######
|
####### The complete version of the list of examples #######
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
@@ -260,11 +270,16 @@ if __name__ == '__main__':
|
|||||||
with open("evaluation_examples/test_all.json", "r", encoding="utf-8") as f:
|
with open("evaluation_examples/test_all.json", "r", encoding="utf-8") as f:
|
||||||
test_all_meta = json.load(f)
|
test_all_meta = json.load(f)
|
||||||
|
|
||||||
test_file_list = get_unfinished(args.action_space, args.model, args.observation_type, args.result_dir, test_all_meta)
|
test_file_list = get_unfinished(
|
||||||
|
args.action_space,
|
||||||
|
args.model,
|
||||||
|
args.observation_type,
|
||||||
|
args.result_dir,
|
||||||
|
test_all_meta
|
||||||
|
)
|
||||||
left_info = ""
|
left_info = ""
|
||||||
for domain in test_file_list:
|
for domain in test_file_list:
|
||||||
left_info += f"{domain}: {len(test_file_list[domain])}\n"
|
left_info += f"{domain}: {len(test_file_list[domain])}\n"
|
||||||
logger.info(f"Left tasks:\n{left_info}")
|
logger.info(f"Left tasks:\n{left_info}")
|
||||||
|
|
||||||
os.environ['OPENAI_API_KEY'] = "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"
|
|
||||||
test(args, test_all_meta)
|
test(args, test_all_meta)
|
||||||
Reference in New Issue
Block a user