This commit is contained in:
Timothyxxx
2024-03-15 21:06:50 +08:00
parent 35ed7cec89
commit 5cbf1b28ca
2 changed files with 39 additions and 43 deletions

View File

@@ -10,16 +10,10 @@ from http import HTTPStatus
from io import BytesIO from io import BytesIO
from typing import Dict, List from typing import Dict, List
import backoff
import dashscope import dashscope
import google.generativeai as genai import google.generativeai as genai
import requests import requests
from PIL import Image from PIL import Image
from vertexai.preview.generative_models import (
HarmBlockThreshold,
HarmCategory,
Image,
)
from mm_agents.accessibility_tree_wrap.heuristic_retrieve import find_leaf_nodes, filter_nodes, draw_bounding_boxes from mm_agents.accessibility_tree_wrap.heuristic_retrieve import find_leaf_nodes, filter_nodes, draw_bounding_boxes
from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION, \ from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION, \
@@ -28,8 +22,6 @@ from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_S
SYS_PROMPT_IN_SOM_A11Y_OUT_TAG, \ SYS_PROMPT_IN_SOM_A11Y_OUT_TAG, \
SYS_PROMPT_SEEACT, ACTION_DESCRIPTION_PROMPT_SEEACT, ACTION_GROUNDING_PROMPT_SEEACT SYS_PROMPT_SEEACT, ACTION_DESCRIPTION_PROMPT_SEEACT, ACTION_GROUNDING_PROMPT_SEEACT
# todo: cross-check with visualwebarena
logger = logging.getLogger("desktopenv.agent") logger = logging.getLogger("desktopenv.agent")
@@ -43,7 +35,7 @@ def linearize_accessibility_tree(accessibility_tree):
# leaf_nodes = find_leaf_nodes(accessibility_tree) # leaf_nodes = find_leaf_nodes(accessibility_tree)
filtered_nodes = filter_nodes(ET.fromstring(accessibility_tree)) filtered_nodes = filter_nodes(ET.fromstring(accessibility_tree))
linearized_accessibility_tree = "tag\tname\ttext\tposition\tsize\n" linearized_accessibility_tree = "tag\tname\ttext\tposition (top-left x&y)\tsize (w&h)\n"
# Linearize the accessibility tree nodes into a table format # Linearize the accessibility tree nodes into a table format
for node in filtered_nodes: for node in filtered_nodes:
@@ -205,7 +197,7 @@ class PromptAgent:
self.system_message = SYS_PROMPT_IN_A11Y_OUT_CODE self.system_message = SYS_PROMPT_IN_A11Y_OUT_CODE
else: else:
raise ValueError("Invalid action space: " + action_space) raise ValueError("Invalid action space: " + action_space)
elif observation_type == "both": elif observation_type == "screenshot_a11y_tree":
if action_space == "computer_13": if action_space == "computer_13":
self.system_message = SYS_PROMPT_IN_BOTH_OUT_ACTION self.system_message = SYS_PROMPT_IN_BOTH_OUT_ACTION
elif action_space == "pyautogui": elif action_space == "pyautogui":
@@ -233,8 +225,7 @@ class PromptAgent:
""" """
Predict the next action(s) based on the current observation. Predict the next action(s) based on the current observation.
""" """
self.system_message = self.system_message + "\nYou are asked to complete the following task: {}".format( system_message = self.system_message + "\nYou are asked to complete the following task: {}".format(instruction)
instruction)
# Prepare the payload for the API call # Prepare the payload for the API call
messages = [] messages = []
@@ -245,7 +236,7 @@ class PromptAgent:
"content": [ "content": [
{ {
"type": "text", "type": "text",
"text": self.system_message "text": system_message
}, },
] ]
}) })
@@ -266,7 +257,7 @@ class PromptAgent:
for previous_obs, previous_action, previous_thought in zip(_observations, _actions, _thoughts): for previous_obs, previous_action, previous_thought in zip(_observations, _actions, _thoughts):
# {{{1 # {{{1
if self.observation_type == "both": if self.observation_type == "screenshot_a11y_tree":
_screenshot = previous_obs["screenshot"] _screenshot = previous_obs["screenshot"]
_linearized_accessibility_tree = previous_obs["accessibility_tree"] _linearized_accessibility_tree = previous_obs["accessibility_tree"]
logger.debug("LINEAR AT: %s", _linearized_accessibility_tree) logger.debug("LINEAR AT: %s", _linearized_accessibility_tree)
@@ -356,11 +347,11 @@ class PromptAgent:
}) })
# {{{1 # {{{1
if self.observation_type in ["screenshot", "both"]: if self.observation_type in ["screenshot", "screenshot_a11y_tree"]:
base64_image = encode_image(obs["screenshot"]) base64_image = encode_image(obs["screenshot"])
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"]) linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])
if self.observation_type == "both": if self.observation_type == "screenshot_a11y_tree":
self.observations.append({ self.observations.append({
"screenshot": base64_image, "screenshot": base64_image,
"accessibility_tree": linearized_accessibility_tree "accessibility_tree": linearized_accessibility_tree
@@ -473,7 +464,9 @@ class PromptAgent:
response = self.call_llm({ response = self.call_llm({
"model": self.model, "model": self.model,
"messages": messages, "messages": messages,
"max_tokens": self.max_tokens "max_tokens": self.max_tokens,
"top_p": self.top_p,
"temperature": self.temperature
}) })
logger.info("RESPONSE: %s", response) logger.info("RESPONSE: %s", response)
@@ -520,11 +513,11 @@ class PromptAgent:
return actions return actions
@backoff.on_exception( # @backoff.on_exception(
backoff.expo, # backoff.expo,
(Exception), # (Exception),
max_tries=5 # max_tries=5
) # )
def call_llm(self, payload): def call_llm(self, payload):
if self.model.startswith("gpt"): if self.model.startswith("gpt"):
@@ -542,14 +535,14 @@ class PromptAgent:
if response.status_code != 200: if response.status_code != 200:
if response.json()['error']['code'] == "context_length_exceeded": if response.json()['error']['code'] == "context_length_exceeded":
logger.error("Context length exceeded. Retrying with a smaller context.") logger.error("Context length exceeded. Retrying with a smaller context.")
payload["messages"] = payload["messages"][-1:] payload["messages"] = [payload["messages"][0]] + payload["messages"][-1:]
retry_response = requests.post( retry_response = requests.post(
"https://api.openai.com/v1/chat/completions", "https://api.openai.com/v1/chat/completions",
headers=headers, headers=headers,
json=payload json=payload
) )
if retry_response.status_code != 200: if retry_response.status_code != 200:
logger.error("Failed to call LLM: " + retry_response.text) logger.error("Failed to call LLM even after attempt on shortening the history: " + retry_response.text)
return "" return ""
logger.error("Failed to call LLM: " + response.text) logger.error("Failed to call LLM: " + response.text)
@@ -656,8 +649,9 @@ class PromptAgent:
for message in gemini_messages: for message in gemini_messages:
message_history_str += "<|" + message['role'] + "|>\n" + message['parts'][0] + "\n" message_history_str += "<|" + message['role'] + "|>\n" + message['parts'][0] + "\n"
gemini_messages = [{"role": "user", "parts": [message_history_str, gemini_messages[-1]['parts'][1]]}] gemini_messages = [{"role": "user", "parts": [message_history_str, gemini_messages[-1]['parts'][1]]}]
# gemini_messages[-1]['parts'][1].save("output.png", "PNG")
print(gemini_messages) # print(gemini_messages)
api_key = os.environ.get("GENAI_API_KEY") api_key = os.environ.get("GENAI_API_KEY")
assert api_key is not None, "Please set the GENAI_API_KEY environment variable" assert api_key is not None, "Please set the GENAI_API_KEY environment variable"
genai.configure(api_key=api_key) genai.configure(api_key=api_key)
@@ -671,11 +665,10 @@ class PromptAgent:
"temperature": temperature "temperature": temperature
}, },
safety_settings={ safety_settings={
HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE, "harassment": "block_none",
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE, "hate": "block_none",
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE, "sex": "block_none",
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE, "danger": "block_none"
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
} }
) )
@@ -726,7 +719,7 @@ class PromptAgent:
def parse_actions(self, response: str, masks=None): def parse_actions(self, response: str, masks=None):
if self.observation_type in ["screenshot", "a11y_tree", "both"]: if self.observation_type in ["screenshot", "a11y_tree", "screenshot_a11y_tree"]:
# parse from the response # parse from the response
if self.action_space == "computer_13": if self.action_space == "computer_13":
actions = parse_actions_from_string(response) actions = parse_actions_from_string(response)

27
run.py
View File

@@ -66,7 +66,7 @@ def config() -> argparse.Namespace:
"screenshot_a11y_tree", "screenshot_a11y_tree",
"som" "som"
], ],
default="a11y_tree", default="som",
help="Observation type", help="Observation type",
) )
parser.add_argument("--screen_width", type=int, default=1920) parser.add_argument("--screen_width", type=int, default=1920)
@@ -146,6 +146,7 @@ def test(
step_idx = 0 step_idx = 0
env.controller.start_recording() env.controller.start_recording()
# todo: update max running time for each example, @xiaochuan
while not done and step_idx < max_steps: while not done and step_idx < max_steps:
actions = agent.predict( actions = agent.predict(
instruction, instruction,
@@ -158,7 +159,7 @@ def test(
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S") action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
logger.info("Step %d: %s", step_idx + 1, action) logger.info("Step %d: %s", step_idx + 1, action)
observation, reward, done, info = env.step(action, args.sleep_after_execution) obs, reward, done, info = env.step(action, args.sleep_after_execution)
logger.info("Reward: %.2f", reward) logger.info("Reward: %.2f", reward)
logger.info("Done: %s", done) logger.info("Done: %s", done)
@@ -167,7 +168,7 @@ def test(
# Save screenshot and trajectory information # Save screenshot and trajectory information
with open(os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"), with open(os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"),
"wb") as _f: "wb") as _f:
with open(observation['screenshot'], "rb") as __f: with open(obs['screenshot'], "rb") as __f:
screenshot = __f.read() screenshot = __f.read()
_f.write(screenshot) _f.write(screenshot)
@@ -186,22 +187,24 @@ def test(
if done: if done:
logger.info("The episode is done.") logger.info("The episode is done.")
break break
try:
result = env.evaluate() result = env.evaluate()
except Exception as e:
logger.error(f"Error in evaluating the example {example_id}: {e}")
result = 0.0
logger.info("Result: %.2f", result) logger.info("Result: %.2f", result)
scores.append(result)
env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4")) env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
scores.append(result)
with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f:
f.write(f"{result}\n")
env.close() env.close()
logger.info(f"Average score: {sum(scores) / len(scores)}") logger.info(f"Average score: {sum(scores) / len(scores)}")
def get_unfinished(test_file_list, result_dir): def get_unfinished(test, result_dir):
finished = [] # todo @xiaochuan
for domain in os.listdir(result_dir): pass
for example_id in os.listdir(os.path.join(result_dir, domain)):
finished.append(f"{domain}/{example_id}")
return [x for x in test_file_list if x not in finished]
if __name__ == '__main__': if __name__ == '__main__':