Change SoM input and output

This commit is contained in:
Timothyxxx
2024-03-15 22:10:35 +08:00
parent cfa9aaf3a7
commit 1ad4527e8b
2 changed files with 17 additions and 91 deletions

View File

@@ -21,10 +21,7 @@ from mm_agents.accessibility_tree_wrap.heuristic_retrieve import find_leaf_nodes
from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION, \
SYS_PROMPT_IN_A11Y_OUT_CODE, SYS_PROMPT_IN_A11Y_OUT_ACTION, \
SYS_PROMPT_IN_BOTH_OUT_CODE, SYS_PROMPT_IN_BOTH_OUT_ACTION, \
SYS_PROMPT_IN_SOM_A11Y_OUT_TAG, \
SYS_PROMPT_SEEACT, ACTION_DESCRIPTION_PROMPT_SEEACT, ACTION_GROUNDING_PROMPT_SEEACT
# todo: cross-check with visualwebarena
SYS_PROMPT_IN_SOM_OUT_TAG
logger = logging.getLogger("desktopenv.agent")
@@ -67,7 +64,8 @@ def tag_screenshot(screenshot, accessibility_tree):
uuid_str = str(uuid.uuid4())
os.makedirs("tmp/images", exist_ok=True)
tagged_screenshot_file_path = os.path.join("tmp/images", uuid_str + ".png")
nodes = filter_nodes(find_leaf_nodes(accessibility_tree))
# nodes = filter_nodes(find_leaf_nodes(accessibility_tree))
nodes = filter_nodes(ET.fromstring(accessibility_tree))
# Make tag screenshot
marks, drew_nodes = draw_bounding_boxes(nodes, screenshot, tagged_screenshot_file_path)
@@ -172,7 +170,7 @@ class PromptAgent:
temperature=0.5,
action_space="computer_13",
observation_type="screenshot_a11y_tree",
# observation_type can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som", "seeact"]
# observation_type can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"]
max_trajectory_length=3
):
self.model = model
@@ -212,14 +210,7 @@ class PromptAgent:
if action_space == "computer_13":
raise ValueError("Invalid action space: " + action_space)
elif action_space == "pyautogui":
self.system_message = SYS_PROMPT_IN_SOM_A11Y_OUT_TAG
else:
raise ValueError("Invalid action space: " + action_space)
elif observation_type == "seeact":
if action_space == "computer_13":
raise ValueError("Invalid action space: " + action_space)
elif action_space == "pyautogui":
self.system_message = SYS_PROMPT_SEEACT
self.system_message = SYS_PROMPT_IN_SOM_OUT_TAG
else:
raise ValueError("Invalid action space: " + action_space)
else:
@@ -283,18 +274,15 @@ class PromptAgent:
}
]
})
elif self.observation_type in ["som", "seeact"]:
elif self.observation_type in ["som"]:
_screenshot = previous_obs["screenshot"]
_linearized_accessibility_tree = previous_obs["accessibility_tree"]
logger.debug("LINEAR AT: %s", _linearized_accessibility_tree)
messages.append({
"role": "user",
"content": [
{
"type": "text",
"text": "Given the tagged screenshot and info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format(
_linearized_accessibility_tree)
"text": "Given the tagged screenshot as below. What's the next step that you will do to help with the task?"
},
{
"type": "image_url",
@@ -407,11 +395,9 @@ class PromptAgent:
# Add som to the screenshot
masks, drew_nodes, tagged_screenshot = tag_screenshot(obs["screenshot"], obs["accessibility_tree"])
base64_image = encode_image(tagged_screenshot)
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])
self.observations.append({
"screenshot": base64_image,
"accessibility_tree": linearized_accessibility_tree
"screenshot": base64_image
})
messages.append({
@@ -419,35 +405,7 @@ class PromptAgent:
"content": [
{
"type": "text",
"text": "Given the tagged screenshot and info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format(
linearized_accessibility_tree)
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{base64_image}",
"detail": "high"
}
}
]
})
elif self.observation_type == "seeact":
# Add som to the screenshot
masks, drew_nodes, tagged_screenshot = tag_screenshot(obs["screenshot"], obs["accessibility_tree"])
base64_image = encode_image(tagged_screenshot)
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])
self.observations.append({
"screenshot": base64_image,
"accessibility_tree": linearized_accessibility_tree
})
messages.append({
"role": "user",
"content": [
{
"type": "text",
"text": ACTION_DESCRIPTION_PROMPT_SEEACT.format(linearized_accessibility_tree)
"text": "Given the tagged screenshot as below. What's the next step that you will do to help with the task?"
},
{
"type": "image_url",
@@ -475,38 +433,6 @@ class PromptAgent:
logger.info("RESPONSE: %s", response)
if self.observation_type == "seeact":
messages.append({
"role": "assistant",
"content": [
{
"type": "text",
"text": response
}
]
})
messages.append({
"role": "user",
"content": [
{
"type": "text",
"text": "{}\n\nWhat's the next step that you will do to help with the task?".format(
ACTION_GROUNDING_PROMPT_SEEACT)
}
]
})
logger.info("Generating content with GPT model: %s", self.model)
response = self.call_llm({
"model": self.model,
"messages": messages,
"max_tokens": self.max_tokens,
"top_p": self.top_p,
"temperature": self.temperature
})
logger.info("RESPONSE: %s", response)
try:
actions = self.parse_actions(response, masks)
self.thoughts.append(response)
@@ -523,12 +449,11 @@ class PromptAgent:
# but you are forbidden to add "Exception", that is, a common type of exception
# because we want to catch this kind of Exception in the outside to ensure each example won't exceed the time limit
(openai.RateLimitError,
openai.BadRequestError,
openai.InternalServerError,
InvalidArgument),
openai.BadRequestError,
openai.InternalServerError,
InvalidArgument),
max_tries=5
)
def call_llm(self, payload):
if self.model.startswith("gpt"):
@@ -553,7 +478,8 @@ class PromptAgent:
json=payload
)
if retry_response.status_code != 200:
logger.error("Failed to call LLM even after attempt on shortening the history: " + retry_response.text)
logger.error(
"Failed to call LLM even after attempt on shortening the history: " + retry_response.text)
return ""
logger.error("Failed to call LLM: " + response.text)
@@ -742,7 +668,7 @@ class PromptAgent:
self.actions.append(actions)
return actions
elif self.observation_type in ["som", "seeact"]:
elif self.observation_type in ["som"]:
# parse from the response
if self.action_space == "computer_13":
raise ValueError("Invalid action space: " + self.action_space)

View File

@@ -798,10 +798,10 @@ You MUST choose and ONLY CHOOSE from the action space above, otherwise your acti
You CAN predict multiple actions at one step, but you should only return one action for each step.
""".strip()
SYS_PROMPT_IN_SOM_A11Y_OUT_TAG = """
SYS_PROMPT_IN_SOM_OUT_TAG = """
You are an agent which follow my instruction and perform desktop computer tasks as instructed.
You have good knowledge of computer and good internet connection and assume your code will run on a computer for controlling the mouse and keyboard.
For each step, you will get an observation of the desktop by 1) a screenshot; and 2) accessibility tree, which is based on AT-SPI library.
For each step, you will get an observation of the desktop by a screenshot with interact-able elements marked with numerical tags. And you will predict the action of the computer based on the image.
You are required to use `pyautogui` to perform the action grounded to the observation, but DONOT use the `pyautogui.locateCenterOnScreen` function to locate the element you want to operate with since we have no image of the element you want to operate with. DONOT USE `pyautogui.screenshot()` to make screenshot.
You can replace x, y in the code with the tag of the element you want to operate with. such as: