Change SoM input and output
This commit is contained in:
@@ -21,10 +21,7 @@ from mm_agents.accessibility_tree_wrap.heuristic_retrieve import find_leaf_nodes
|
|||||||
from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION, \
|
from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION, \
|
||||||
SYS_PROMPT_IN_A11Y_OUT_CODE, SYS_PROMPT_IN_A11Y_OUT_ACTION, \
|
SYS_PROMPT_IN_A11Y_OUT_CODE, SYS_PROMPT_IN_A11Y_OUT_ACTION, \
|
||||||
SYS_PROMPT_IN_BOTH_OUT_CODE, SYS_PROMPT_IN_BOTH_OUT_ACTION, \
|
SYS_PROMPT_IN_BOTH_OUT_CODE, SYS_PROMPT_IN_BOTH_OUT_ACTION, \
|
||||||
SYS_PROMPT_IN_SOM_A11Y_OUT_TAG, \
|
SYS_PROMPT_IN_SOM_OUT_TAG
|
||||||
SYS_PROMPT_SEEACT, ACTION_DESCRIPTION_PROMPT_SEEACT, ACTION_GROUNDING_PROMPT_SEEACT
|
|
||||||
|
|
||||||
# todo: cross-check with visualwebarena
|
|
||||||
|
|
||||||
logger = logging.getLogger("desktopenv.agent")
|
logger = logging.getLogger("desktopenv.agent")
|
||||||
|
|
||||||
@@ -67,7 +64,8 @@ def tag_screenshot(screenshot, accessibility_tree):
|
|||||||
uuid_str = str(uuid.uuid4())
|
uuid_str = str(uuid.uuid4())
|
||||||
os.makedirs("tmp/images", exist_ok=True)
|
os.makedirs("tmp/images", exist_ok=True)
|
||||||
tagged_screenshot_file_path = os.path.join("tmp/images", uuid_str + ".png")
|
tagged_screenshot_file_path = os.path.join("tmp/images", uuid_str + ".png")
|
||||||
nodes = filter_nodes(find_leaf_nodes(accessibility_tree))
|
# nodes = filter_nodes(find_leaf_nodes(accessibility_tree))
|
||||||
|
nodes = filter_nodes(ET.fromstring(accessibility_tree))
|
||||||
# Make tag screenshot
|
# Make tag screenshot
|
||||||
marks, drew_nodes = draw_bounding_boxes(nodes, screenshot, tagged_screenshot_file_path)
|
marks, drew_nodes = draw_bounding_boxes(nodes, screenshot, tagged_screenshot_file_path)
|
||||||
|
|
||||||
@@ -172,7 +170,7 @@ class PromptAgent:
|
|||||||
temperature=0.5,
|
temperature=0.5,
|
||||||
action_space="computer_13",
|
action_space="computer_13",
|
||||||
observation_type="screenshot_a11y_tree",
|
observation_type="screenshot_a11y_tree",
|
||||||
# observation_type can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som", "seeact"]
|
# observation_type can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"]
|
||||||
max_trajectory_length=3
|
max_trajectory_length=3
|
||||||
):
|
):
|
||||||
self.model = model
|
self.model = model
|
||||||
@@ -212,14 +210,7 @@ class PromptAgent:
|
|||||||
if action_space == "computer_13":
|
if action_space == "computer_13":
|
||||||
raise ValueError("Invalid action space: " + action_space)
|
raise ValueError("Invalid action space: " + action_space)
|
||||||
elif action_space == "pyautogui":
|
elif action_space == "pyautogui":
|
||||||
self.system_message = SYS_PROMPT_IN_SOM_A11Y_OUT_TAG
|
self.system_message = SYS_PROMPT_IN_SOM_OUT_TAG
|
||||||
else:
|
|
||||||
raise ValueError("Invalid action space: " + action_space)
|
|
||||||
elif observation_type == "seeact":
|
|
||||||
if action_space == "computer_13":
|
|
||||||
raise ValueError("Invalid action space: " + action_space)
|
|
||||||
elif action_space == "pyautogui":
|
|
||||||
self.system_message = SYS_PROMPT_SEEACT
|
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid action space: " + action_space)
|
raise ValueError("Invalid action space: " + action_space)
|
||||||
else:
|
else:
|
||||||
@@ -283,18 +274,15 @@ class PromptAgent:
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
})
|
})
|
||||||
elif self.observation_type in ["som", "seeact"]:
|
elif self.observation_type in ["som"]:
|
||||||
_screenshot = previous_obs["screenshot"]
|
_screenshot = previous_obs["screenshot"]
|
||||||
_linearized_accessibility_tree = previous_obs["accessibility_tree"]
|
|
||||||
logger.debug("LINEAR AT: %s", _linearized_accessibility_tree)
|
|
||||||
|
|
||||||
messages.append({
|
messages.append({
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": [
|
"content": [
|
||||||
{
|
{
|
||||||
"type": "text",
|
"type": "text",
|
||||||
"text": "Given the tagged screenshot and info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format(
|
"text": "Given the tagged screenshot as below. What's the next step that you will do to help with the task?"
|
||||||
_linearized_accessibility_tree)
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"type": "image_url",
|
"type": "image_url",
|
||||||
@@ -407,11 +395,9 @@ class PromptAgent:
|
|||||||
# Add som to the screenshot
|
# Add som to the screenshot
|
||||||
masks, drew_nodes, tagged_screenshot = tag_screenshot(obs["screenshot"], obs["accessibility_tree"])
|
masks, drew_nodes, tagged_screenshot = tag_screenshot(obs["screenshot"], obs["accessibility_tree"])
|
||||||
base64_image = encode_image(tagged_screenshot)
|
base64_image = encode_image(tagged_screenshot)
|
||||||
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])
|
|
||||||
|
|
||||||
self.observations.append({
|
self.observations.append({
|
||||||
"screenshot": base64_image,
|
"screenshot": base64_image
|
||||||
"accessibility_tree": linearized_accessibility_tree
|
|
||||||
})
|
})
|
||||||
|
|
||||||
messages.append({
|
messages.append({
|
||||||
@@ -419,35 +405,7 @@ class PromptAgent:
|
|||||||
"content": [
|
"content": [
|
||||||
{
|
{
|
||||||
"type": "text",
|
"type": "text",
|
||||||
"text": "Given the tagged screenshot and info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format(
|
"text": "Given the tagged screenshot as below. What's the next step that you will do to help with the task?"
|
||||||
linearized_accessibility_tree)
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "image_url",
|
|
||||||
"image_url": {
|
|
||||||
"url": f"data:image/png;base64,{base64_image}",
|
|
||||||
"detail": "high"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
})
|
|
||||||
elif self.observation_type == "seeact":
|
|
||||||
# Add som to the screenshot
|
|
||||||
masks, drew_nodes, tagged_screenshot = tag_screenshot(obs["screenshot"], obs["accessibility_tree"])
|
|
||||||
base64_image = encode_image(tagged_screenshot)
|
|
||||||
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])
|
|
||||||
|
|
||||||
self.observations.append({
|
|
||||||
"screenshot": base64_image,
|
|
||||||
"accessibility_tree": linearized_accessibility_tree
|
|
||||||
})
|
|
||||||
|
|
||||||
messages.append({
|
|
||||||
"role": "user",
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": ACTION_DESCRIPTION_PROMPT_SEEACT.format(linearized_accessibility_tree)
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"type": "image_url",
|
"type": "image_url",
|
||||||
@@ -475,38 +433,6 @@ class PromptAgent:
|
|||||||
|
|
||||||
logger.info("RESPONSE: %s", response)
|
logger.info("RESPONSE: %s", response)
|
||||||
|
|
||||||
if self.observation_type == "seeact":
|
|
||||||
messages.append({
|
|
||||||
"role": "assistant",
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": response
|
|
||||||
}
|
|
||||||
]
|
|
||||||
})
|
|
||||||
|
|
||||||
messages.append({
|
|
||||||
"role": "user",
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": "{}\n\nWhat's the next step that you will do to help with the task?".format(
|
|
||||||
ACTION_GROUNDING_PROMPT_SEEACT)
|
|
||||||
}
|
|
||||||
]
|
|
||||||
})
|
|
||||||
|
|
||||||
logger.info("Generating content with GPT model: %s", self.model)
|
|
||||||
response = self.call_llm({
|
|
||||||
"model": self.model,
|
|
||||||
"messages": messages,
|
|
||||||
"max_tokens": self.max_tokens,
|
|
||||||
"top_p": self.top_p,
|
|
||||||
"temperature": self.temperature
|
|
||||||
})
|
|
||||||
logger.info("RESPONSE: %s", response)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
actions = self.parse_actions(response, masks)
|
actions = self.parse_actions(response, masks)
|
||||||
self.thoughts.append(response)
|
self.thoughts.append(response)
|
||||||
@@ -523,12 +449,11 @@ class PromptAgent:
|
|||||||
# but you are forbidden to add "Exception", that is, a common type of exception
|
# but you are forbidden to add "Exception", that is, a common type of exception
|
||||||
# because we want to catch this kind of Exception in the outside to ensure each example won't exceed the time limit
|
# because we want to catch this kind of Exception in the outside to ensure each example won't exceed the time limit
|
||||||
(openai.RateLimitError,
|
(openai.RateLimitError,
|
||||||
openai.BadRequestError,
|
openai.BadRequestError,
|
||||||
openai.InternalServerError,
|
openai.InternalServerError,
|
||||||
InvalidArgument),
|
InvalidArgument),
|
||||||
max_tries=5
|
max_tries=5
|
||||||
)
|
)
|
||||||
|
|
||||||
def call_llm(self, payload):
|
def call_llm(self, payload):
|
||||||
|
|
||||||
if self.model.startswith("gpt"):
|
if self.model.startswith("gpt"):
|
||||||
@@ -553,7 +478,8 @@ class PromptAgent:
|
|||||||
json=payload
|
json=payload
|
||||||
)
|
)
|
||||||
if retry_response.status_code != 200:
|
if retry_response.status_code != 200:
|
||||||
logger.error("Failed to call LLM even after attempt on shortening the history: " + retry_response.text)
|
logger.error(
|
||||||
|
"Failed to call LLM even after attempt on shortening the history: " + retry_response.text)
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
logger.error("Failed to call LLM: " + response.text)
|
logger.error("Failed to call LLM: " + response.text)
|
||||||
@@ -742,7 +668,7 @@ class PromptAgent:
|
|||||||
self.actions.append(actions)
|
self.actions.append(actions)
|
||||||
|
|
||||||
return actions
|
return actions
|
||||||
elif self.observation_type in ["som", "seeact"]:
|
elif self.observation_type in ["som"]:
|
||||||
# parse from the response
|
# parse from the response
|
||||||
if self.action_space == "computer_13":
|
if self.action_space == "computer_13":
|
||||||
raise ValueError("Invalid action space: " + self.action_space)
|
raise ValueError("Invalid action space: " + self.action_space)
|
||||||
|
|||||||
@@ -798,10 +798,10 @@ You MUST choose and ONLY CHOOSE from the action space above, otherwise your acti
|
|||||||
You CAN predict multiple actions at one step, but you should only return one action for each step.
|
You CAN predict multiple actions at one step, but you should only return one action for each step.
|
||||||
""".strip()
|
""".strip()
|
||||||
|
|
||||||
SYS_PROMPT_IN_SOM_A11Y_OUT_TAG = """
|
SYS_PROMPT_IN_SOM_OUT_TAG = """
|
||||||
You are an agent which follow my instruction and perform desktop computer tasks as instructed.
|
You are an agent which follow my instruction and perform desktop computer tasks as instructed.
|
||||||
You have good knowledge of computer and good internet connection and assume your code will run on a computer for controlling the mouse and keyboard.
|
You have good knowledge of computer and good internet connection and assume your code will run on a computer for controlling the mouse and keyboard.
|
||||||
For each step, you will get an observation of the desktop by 1) a screenshot; and 2) accessibility tree, which is based on AT-SPI library.
|
For each step, you will get an observation of the desktop by a screenshot with interact-able elements marked with numerical tags. And you will predict the action of the computer based on the image.
|
||||||
|
|
||||||
You are required to use `pyautogui` to perform the action grounded to the observation, but DONOT use the `pyautogui.locateCenterOnScreen` function to locate the element you want to operate with since we have no image of the element you want to operate with. DONOT USE `pyautogui.screenshot()` to make screenshot.
|
You are required to use `pyautogui` to perform the action grounded to the observation, but DONOT use the `pyautogui.locateCenterOnScreen` function to locate the element you want to operate with since we have no image of the element you want to operate with. DONOT USE `pyautogui.screenshot()` to make screenshot.
|
||||||
You can replace x, y in the code with the tag of the element you want to operate with. such as:
|
You can replace x, y in the code with the tag of the element you want to operate with. such as:
|
||||||
|
|||||||
Reference in New Issue
Block a user