Refactor experiments and agent implementation

This commit is contained in:
Timothyxxx
2024-03-14 22:32:49 +08:00
parent 71ca8fbe1c
commit 44ff027801
8 changed files with 359 additions and 1944 deletions

View File

@@ -5,21 +5,20 @@ import os
import re
import time
import uuid
import xml.etree.ElementTree as ET
from http import HTTPStatus
from io import BytesIO
from typing import Dict, List
import xml.etree.ElementTree as ET
import backoff
import dashscope
import google.generativeai as genai
import openai
import requests
from PIL import Image
from openai import (
APIConnectionError,
APIError,
RateLimitError
from vertexai.preview.generative_models import (
HarmBlockThreshold,
HarmCategory,
Image,
)
from mm_agents.accessibility_tree_wrap.heuristic_retrieve import find_leaf_nodes, filter_nodes, draw_bounding_boxes
@@ -29,7 +28,6 @@ from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_S
SYS_PROMPT_IN_SOM_A11Y_OUT_TAG, \
SYS_PROMPT_SEEACT, ACTION_DESCRIPTION_PROMPT_SEEACT, ACTION_GROUNDING_PROMPT_SEEACT
import logging
# todo: cross-check with visualwebarena
logger = logging.getLogger("desktopenv.agent")
@@ -42,7 +40,7 @@ def encode_image(image_path):
def linearize_accessibility_tree(accessibility_tree):
#leaf_nodes = find_leaf_nodes(accessibility_tree)
# leaf_nodes = find_leaf_nodes(accessibility_tree)
filtered_nodes = filter_nodes(ET.fromstring(accessibility_tree))
linearized_accessibility_tree = "tag\tname\ttext\tposition\tsize\n"
@@ -172,60 +170,56 @@ def parse_code_from_som_string(input_string, masks):
class PromptAgent:
def __init__(
self,
api_key,
instruction,
model="gpt-4-vision-preview",
max_tokens=500,
max_tokens=1500,
top_p=0.9,
temperature=0.5,
action_space="computer_13",
exp="screenshot_a11y_tree"
# exp can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som", "seeact"]
observation_type="screenshot_a11y_tree",
# observation_type can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som", "seeact"]
max_trajectory_length=3
):
self.instruction = instruction
self.model = model
self.max_tokens = max_tokens
self.top_p = top_p
self.temperature = temperature
self.action_space = action_space
self.exp = exp
self.max_trajectory_length = 3
self.headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {api_key}"
}
self.observation_type = observation_type
self.max_trajectory_length = max_trajectory_length
self.thoughts = []
self.actions = []
self.observations = []
if exp == "screenshot":
if observation_type == "screenshot":
if action_space == "computer_13":
self.system_message = SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION
elif action_space == "pyautogui":
self.system_message = SYS_PROMPT_IN_SCREENSHOT_OUT_CODE
else:
raise ValueError("Invalid action space: " + action_space)
elif exp == "a11y_tree":
elif observation_type == "a11y_tree":
if action_space == "computer_13":
self.system_message = SYS_PROMPT_IN_A11Y_OUT_ACTION
elif action_space == "pyautogui":
self.system_message = SYS_PROMPT_IN_A11Y_OUT_CODE
else:
raise ValueError("Invalid action space: " + action_space)
elif exp == "both":
elif observation_type == "both":
if action_space == "computer_13":
self.system_message = SYS_PROMPT_IN_BOTH_OUT_ACTION
elif action_space == "pyautogui":
self.system_message = SYS_PROMPT_IN_BOTH_OUT_CODE
else:
raise ValueError("Invalid action space: " + action_space)
elif exp == "som":
elif observation_type == "som":
if action_space == "computer_13":
raise ValueError("Invalid action space: " + action_space)
elif action_space == "pyautogui":
self.system_message = SYS_PROMPT_IN_SOM_A11Y_OUT_TAG
else:
raise ValueError("Invalid action space: " + action_space)
elif exp == "seeact":
elif observation_type == "seeact":
if action_space == "computer_13":
raise ValueError("Invalid action space: " + action_space)
elif action_space == "pyautogui":
@@ -233,15 +227,14 @@ class PromptAgent:
else:
raise ValueError("Invalid action space: " + action_space)
else:
raise ValueError("Invalid experiment type: " + exp)
raise ValueError("Invalid experiment type: " + observation_type)
self.system_message = self.system_message + "\nYou are asked to complete the following task: {}".format(
self.instruction)
def predict(self, obs: Dict) -> List:
def predict(self, instruction: str, obs: Dict) -> List:
"""
Predict the next action(s) based on the current observation.
"""
self.system_message = self.system_message + "\nYou are asked to complete the following task: {}".format(
instruction)
# Prepare the payload for the API call
messages = []
@@ -273,7 +266,7 @@ class PromptAgent:
for previous_obs, previous_action, previous_thought in zip(_observations, _actions, _thoughts):
# {{{1
if self.exp == "both":
if self.observation_type == "both":
_screenshot = previous_obs["screenshot"]
_linearized_accessibility_tree = previous_obs["accessibility_tree"]
logger.debug("LINEAR AT: %s", _linearized_accessibility_tree)
@@ -295,7 +288,7 @@ class PromptAgent:
}
]
})
elif self.exp in ["som", "seeact"]:
elif self.observation_type in ["som", "seeact"]:
_screenshot = previous_obs["screenshot"]
_linearized_accessibility_tree = previous_obs["accessibility_tree"]
logger.debug("LINEAR AT: %s", _linearized_accessibility_tree)
@@ -317,7 +310,7 @@ class PromptAgent:
}
]
})
elif self.exp == "screenshot":
elif self.observation_type == "screenshot":
_screenshot = previous_obs["screenshot"]
messages.append({
@@ -336,7 +329,7 @@ class PromptAgent:
}
]
})
elif self.exp == "a11y_tree":
elif self.observation_type == "a11y_tree":
_linearized_accessibility_tree = previous_obs["accessibility_tree"]
messages.append({
@@ -350,7 +343,7 @@ class PromptAgent:
]
})
else:
raise ValueError("Invalid experiment type: " + self.exp) # 1}}}
raise ValueError("Invalid observation_type type: " + self.observation_type) # 1}}}
messages.append({
"role": "assistant",
@@ -363,11 +356,11 @@ class PromptAgent:
})
# {{{1
if self.exp in ["screenshot", "both"]:
if self.observation_type in ["screenshot", "both"]:
base64_image = encode_image(obs["screenshot"])
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])
if self.exp == "both":
if self.observation_type == "both":
self.observations.append({
"screenshot": base64_image,
"accessibility_tree": linearized_accessibility_tree
@@ -384,7 +377,7 @@ class PromptAgent:
{
"type": "text",
"text": "Given the screenshot as below. What's the next step that you will do to help with the task?"
if self.exp == "screenshot"
if self.observation_type == "screenshot"
else "Given the screenshot and info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format(
linearized_accessibility_tree)
},
@@ -397,7 +390,7 @@ class PromptAgent:
}
]
})
elif self.exp == "a11y_tree":
elif self.observation_type == "a11y_tree":
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])
self.observations.append({
@@ -415,7 +408,7 @@ class PromptAgent:
}
]
})
elif self.exp == "som":
elif self.observation_type == "som":
# Add som to the screenshot
masks, drew_nodes, tagged_screenshot = tag_screenshot(obs["screenshot"], obs["accessibility_tree"])
base64_image = encode_image(tagged_screenshot)
@@ -443,7 +436,7 @@ class PromptAgent:
}
]
})
elif self.exp == "seeact":
elif self.observation_type == "seeact":
# Add som to the screenshot
masks, drew_nodes, tagged_screenshot = tag_screenshot(obs["screenshot"], obs["accessibility_tree"])
base64_image = encode_image(tagged_screenshot)
@@ -471,21 +464,21 @@ class PromptAgent:
]
})
else:
raise ValueError("Invalid experiment type: " + self.exp) # 1}}}
with open("messages.json", "w") as f:
f.write(json.dumps(messages, indent=4))
raise ValueError("Invalid observation_type type: " + self.observation_type) # 1}}}
# with open("messages.json", "w") as f:
# f.write(json.dumps(messages, indent=4))
logger.info("Generating content with GPT model: %s", self.model)
response = self.call_llm({
"model": self.model,
"messages": messages,
"max_tokens": self.max_tokens
})
logger.debug("RESPONSE: %s", response)
logger.info("RESPONSE: %s", response)
if self.exp == "seeact":
if self.observation_type == "seeact":
messages.append({
"role": "assistant",
"content": [
@@ -507,12 +500,15 @@ class PromptAgent:
]
})
logger.info("Generating content with GPT model: %s", self.model)
response = self.call_llm({
"model": self.model,
"messages": messages,
"max_tokens": self.max_tokens
"max_tokens": self.max_tokens,
"top_p": self.top_p,
"temperature": self.temperature
})
print(response)
logger.info("RESPONSE: %s", response)
try:
actions = self.parse_actions(response, masks)
@@ -527,85 +523,90 @@ class PromptAgent:
@backoff.on_exception(
backoff.expo,
(Exception),
max_tries=10
max_tries=5
)
def call_llm(self, payload):
if self.model.startswith("gpt"):
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}"
}
logger.info("Generating content with GPT model: %s", self.model)
response = requests.post(
"https://api.openai.com/v1/chat/completions",
headers=self.headers,
headers=headers,
json=payload
)
if response.status_code != 200:
if response.json()['error']['code'] == "context_length_exceeded":
print("Context length exceeded. Retrying with a smaller context.")
logger.error("Context length exceeded. Retrying with a smaller context.")
payload["messages"] = payload["messages"][-1:]
retry_response = requests.post(
"https://api.openai.com/v1/chat/completions",
headers=self.headers,
headers=headers,
json=payload
)
if retry_response.status_code != 200:
print("Failed to call LLM: " + retry_response.text)
logger.error("Failed to call LLM: " + retry_response.text)
return ""
print("Failed to call LLM: " + response.text)
logger.error("Failed to call LLM: " + response.text)
time.sleep(5)
return ""
else:
return response.json()['choices'][0]['message']['content']
elif self.model.startswith("mistral"):
print("call mistral")
messages = payload["messages"]
max_tokens = payload["max_tokens"]
misrtal_messages = []
for i, message in enumerate(messages):
mistral_message = {
"role": message["role"],
"content": []
}
for part in message["content"]:
mistral_message['content'] = part['text'] if part['type'] == "text" else None
misrtal_messages.append(mistral_message)
# The Mistral endpoint does not support a system message, so we concatenate it onto the first user message
if misrtal_messages[0]['role'] == "system":
misrtal_messages[1]['content'] = misrtal_messages[0]['content'] + "\n" + misrtal_messages[1]['content']
misrtal_messages.pop(0)
# openai.api_base = "http://localhost:8000/v1"
# openai.api_key = "test"
# response = openai.ChatCompletion.create(
# messages=misrtal_messages,
# model="Mixtral-8x7B-Instruct-v0.1"
# )
from openai import OpenAI
TOGETHER_API_KEY = "d011650e7537797148fb6170ec1e0be7ae75160375686fae02277136078e90d2"
client = OpenAI(api_key=TOGETHER_API_KEY,
base_url='https://api.together.xyz',
)
logger.info("Generating content with Mistral model: %s", self.model)
response = client.chat.completions.create(
messages=misrtal_messages,
model="mistralai/Mixtral-8x7B-Instruct-v0.1",
max_tokens=1024
)
try:
# return response['choices'][0]['message']['content']
return response.choices[0].message.content
except Exception as e:
print("Failed to call LLM: " + str(e))
return ""
# elif self.model.startswith("mistral"):
# print("Call mistral")
# messages = payload["messages"]
# max_tokens = payload["max_tokens"]
#
# misrtal_messages = []
#
# for i, message in enumerate(messages):
# mistral_message = {
# "role": message["role"],
# "content": []
# }
#
# for part in message["content"]:
# mistral_message['content'] = part['text'] if part['type'] == "text" else None
#
# misrtal_messages.append(mistral_message)
#
# # The Mistral endpoint does not support a system message, so we concatenate it onto the first user message
# if misrtal_messages[0]['role'] == "system":
# misrtal_messages[1]['content'] = misrtal_messages[0]['content'] + "\n" + misrtal_messages[1]['content']
# misrtal_messages.pop(0)
#
# # openai.api_base = "http://localhost:8000/v1"
# # openai.api_key = "test"
# # response = openai.ChatCompletion.create(
# # messages=misrtal_messages,
# # model="Mixtral-8x7B-Instruct-v0.1"
# # )
#
# from openai import OpenAI
# TOGETHER_API_KEY = "d011650e7537797148fb6170ec1e0be7ae75160375686fae02277136078e90d2"
#
# client = OpenAI(api_key=TOGETHER_API_KEY,
# base_url='https://api.together.xyz',
# )
# logger.info("Generating content with Mistral model: %s", self.model)
# response = client.chat.completions.create(
# messages=misrtal_messages,
# model="mistralai/Mixtral-8x7B-Instruct-v0.1",
# max_tokens=1024
# )
#
# try:
# # return response['choices'][0]['message']['content']
# return response.choices[0].message.content
# except Exception as e:
# print("Failed to call LLM: " + str(e))
# return ""
elif self.model.startswith("gemini"):
def encoded_img_to_pil_img(data_str):
@@ -617,6 +618,8 @@ class PromptAgent:
messages = payload["messages"]
max_tokens = payload["max_tokens"]
top_p = payload["top_p"]
temperature = payload["temperature"]
gemini_messages = []
for i, message in enumerate(messages):
@@ -662,7 +665,17 @@ class PromptAgent:
response = genai.GenerativeModel(self.model).generate_content(
gemini_messages,
generation_config={
"max_output_tokens": max_tokens
"candidate_count": 1,
"max_output_tokens": max_tokens,
"top_p": top_p,
"temperature": temperature
},
safety_settings={
HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
}
)
@@ -673,6 +686,8 @@ class PromptAgent:
elif self.model.startswith("qwen"):
messages = payload["messages"]
max_tokens = payload["max_tokens"]
top_p = payload["top_p"]
temperature = payload["temperature"]
qwen_messages = []
@@ -683,13 +698,16 @@ class PromptAgent:
}
assert len(message["content"]) in [1, 2], "One text, or one text with one image"
for part in message["content"]:
qwen_message['content'].append({"image": part['image_url']['url']}) if part['type'] == "image_url" else None
qwen_message['content'].append({"image": part['image_url']['url']}) if part[
'type'] == "image_url" else None
qwen_message['content'].append({"text": part['text']}) if part['type'] == "text" else None
qwen_messages.append(qwen_message)
response = dashscope.MultiModalConversation.call(model='qwen-vl-plus',
messages=messages)
response = dashscope.MultiModalConversation.call(
model='qwen-vl-plus',
messages=messages, # todo: add the hyperparameters
)
# A response status_code of HTTPStatus.OK indicates success;
# otherwise the request failed, and the error code and
# message can be read from the `code` and `message` fields.
@@ -708,7 +726,7 @@ class PromptAgent:
def parse_actions(self, response: str, masks=None):
if self.exp in ["screenshot", "a11y_tree", "both"]:
if self.observation_type in ["screenshot", "a11y_tree", "both"]:
# parse from the response
if self.action_space == "computer_13":
actions = parse_actions_from_string(response)
@@ -720,7 +738,7 @@ class PromptAgent:
self.actions.append(actions)
return actions
elif self.exp in ["som", "seeact"]:
elif self.observation_type in ["som", "seeact"]:
# parse from the response
if self.action_space == "computer_13":
raise ValueError("Invalid action space: " + self.action_space)
@@ -732,3 +750,8 @@ class PromptAgent:
self.actions.append(actions)
return actions
def reset(self):
self.thoughts = []
self.actions = []
self.observations = []