Refactor experiments and agent implementation
@@ -5,21 +5,20 @@ import os
 import re
 import time
 import uuid
 import xml.etree.ElementTree as ET
 from http import HTTPStatus
 from io import BytesIO
 from typing import Dict, List
-import xml.etree.ElementTree as ET

 import backoff
 import dashscope
 import google.generativeai as genai
 import openai
 import requests
 from PIL import Image
 from openai import (
     APIConnectionError,
     APIError,
     RateLimitError
 )
 from vertexai.preview.generative_models import (
     HarmBlockThreshold,
     HarmCategory,
     Image,
 )

 from mm_agents.accessibility_tree_wrap.heuristic_retrieve import find_leaf_nodes, filter_nodes, draw_bounding_boxes
@@ -29,7 +28,6 @@ from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_S
     SYS_PROMPT_IN_SOM_A11Y_OUT_TAG, \
     SYS_PROMPT_SEEACT, ACTION_DESCRIPTION_PROMPT_SEEACT, ACTION_GROUNDING_PROMPT_SEEACT

 import logging
-# todo: cross-check with visualwebarena

 logger = logging.getLogger("desktopenv.agent")
@@ -42,7 +40,7 @@ def encode_image(image_path):


 def linearize_accessibility_tree(accessibility_tree):
-    #leaf_nodes = find_leaf_nodes(accessibility_tree)
+    # leaf_nodes = find_leaf_nodes(accessibility_tree)
     filtered_nodes = filter_nodes(ET.fromstring(accessibility_tree))

     linearized_accessibility_tree = "tag\tname\ttext\tposition\tsize\n"
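
The function flattens the filtered XML nodes into one tab-separated row per element, matching the header above. A minimal sketch of the row-building loop that follows; the attribute names are assumptions, since `filter_nodes` and the exact accessibility attributes live elsewhere in the repo:

```python
# Hypothetical continuation of linearize_accessibility_tree; the
# attribute names ("name", "position", "size") are assumed here.
for node in filtered_nodes:
    linearized_accessibility_tree += "{}\t{}\t{}\t{}\t{}\n".format(
        node.tag,                        # role/tag of the element
        node.get("name", ""),            # accessible name
        (node.text or "").strip(),       # visible text, if any
        node.get("position", ""),        # on-screen coordinates
        node.get("size", ""),            # width/height
    )
```
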
@@ -172,60 +170,56 @@ def parse_code_from_som_string(input_string, masks):
 class PromptAgent:
     def __init__(
             self,
-            api_key,
-            instruction,
             model="gpt-4-vision-preview",
-            max_tokens=500,
+            max_tokens=1500,
             top_p=0.9,
             temperature=0.5,
             action_space="computer_13",
-            exp="screenshot_a11y_tree"
-            # exp can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som", "seeact"]
+            observation_type="screenshot_a11y_tree",
+            # observation_type can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som", "seeact"]
+            max_trajectory_length=3
     ):

-        self.instruction = instruction
         self.model = model
         self.max_tokens = max_tokens
         self.top_p = top_p
         self.temperature = temperature
         self.action_space = action_space
-        self.exp = exp
-        self.max_trajectory_length = 3
-
-        self.headers = {
-            "Content-Type": "application/json",
-            "Authorization": f"Bearer {api_key}"
-        }
+        self.observation_type = observation_type
+        self.max_trajectory_length = max_trajectory_length

         self.thoughts = []
         self.actions = []
         self.observations = []

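The constructor no longer takes api_key (call_llm now reads OPENAI_API_KEY from the environment, see below) or instruction (now an argument to predict). A hypothetical call site under the new signature, using only the defaults shown above:

```python
agent = PromptAgent(
    model="gpt-4-vision-preview",
    max_tokens=1500,
    top_p=0.9,
    temperature=0.5,
    action_space="pyautogui",
    observation_type="screenshot_a11y_tree",
    max_trajectory_length=3,
)
```
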
-        if exp == "screenshot":
+        if observation_type == "screenshot":
             if action_space == "computer_13":
                 self.system_message = SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION
             elif action_space == "pyautogui":
                 self.system_message = SYS_PROMPT_IN_SCREENSHOT_OUT_CODE
             else:
                 raise ValueError("Invalid action space: " + action_space)
-        elif exp == "a11y_tree":
+        elif observation_type == "a11y_tree":
             if action_space == "computer_13":
                 self.system_message = SYS_PROMPT_IN_A11Y_OUT_ACTION
             elif action_space == "pyautogui":
                 self.system_message = SYS_PROMPT_IN_A11Y_OUT_CODE
             else:
                 raise ValueError("Invalid action space: " + action_space)
-        elif exp == "both":
+        elif observation_type == "both":
             if action_space == "computer_13":
                 self.system_message = SYS_PROMPT_IN_BOTH_OUT_ACTION
             elif action_space == "pyautogui":
                 self.system_message = SYS_PROMPT_IN_BOTH_OUT_CODE
             else:
                 raise ValueError("Invalid action space: " + action_space)
-        elif exp == "som":
+        elif observation_type == "som":
             if action_space == "computer_13":
                 raise ValueError("Invalid action space: " + action_space)
             elif action_space == "pyautogui":
                 self.system_message = SYS_PROMPT_IN_SOM_A11Y_OUT_TAG
             else:
                 raise ValueError("Invalid action space: " + action_space)
-        elif exp == "seeact":
+        elif observation_type == "seeact":
             if action_space == "computer_13":
                 raise ValueError("Invalid action space: " + action_space)
             elif action_space == "pyautogui":
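
Not part of the commit, but the rename touches every branch of this chain; an equivalent, rename-proof formulation is a lookup table keyed by both settings. A sketch using the prompt constants imported from mm_agents.prompts:

```python
# Hypothetical alternative to the if/elif chain above (not what the
# commit does): dispatch on (observation_type, action_space).
PROMPTS = {
    ("screenshot", "computer_13"): SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION,
    ("screenshot", "pyautogui"): SYS_PROMPT_IN_SCREENSHOT_OUT_CODE,
    ("a11y_tree", "computer_13"): SYS_PROMPT_IN_A11Y_OUT_ACTION,
    ("a11y_tree", "pyautogui"): SYS_PROMPT_IN_A11Y_OUT_CODE,
    ("both", "computer_13"): SYS_PROMPT_IN_BOTH_OUT_ACTION,
    ("both", "pyautogui"): SYS_PROMPT_IN_BOTH_OUT_CODE,
    ("som", "pyautogui"): SYS_PROMPT_IN_SOM_A11Y_OUT_TAG,
    # the ("seeact", "pyautogui") entry is elided by the diff hunk above
}

key = (observation_type, action_space)
if key not in PROMPTS:
    raise ValueError("Invalid combination: {}/{}".format(*key))
self.system_message = PROMPTS[key]
```
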
@@ -233,15 +227,14 @@ class PromptAgent:
             else:
                 raise ValueError("Invalid action space: " + action_space)
         else:
-            raise ValueError("Invalid experiment type: " + exp)
+            raise ValueError("Invalid experiment type: " + observation_type)

-        self.system_message = self.system_message + "\nYou are asked to complete the following task: {}".format(
-            self.instruction)
-
-    def predict(self, obs: Dict) -> List:
+    def predict(self, instruction: str, obs: Dict) -> List:
         """
         Predict the next action(s) based on the current observation.
         """
+        self.system_message = self.system_message + "\nYou are asked to complete the following task: {}".format(
+            instruction)

         # Prepare the payload for the API call
         messages = []
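
With instruction now a predict argument, the task prompt is appended at call time instead of being frozen in the constructor. A hypothetical driver step; env and its obs keys ("screenshot", "accessibility_tree") stand in for the benchmark harness:

```python
# Hypothetical usage; "env" is a placeholder for the desktop
# environment that produces the obs dict consumed below.
obs = env.reset()
actions = agent.predict(
    instruction="Open the file manager and create a folder named 'docs'",
    obs=obs,  # expects obs["screenshot"] and obs["accessibility_tree"]
)
for action in actions:
    obs, reward, done, info = env.step(action)
```
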
@@ -273,7 +266,7 @@ class PromptAgent:
         for previous_obs, previous_action, previous_thought in zip(_observations, _actions, _thoughts):

             # {{{1
-            if self.exp == "both":
+            if self.observation_type == "both":
                 _screenshot = previous_obs["screenshot"]
                 _linearized_accessibility_tree = previous_obs["accessibility_tree"]
                 logger.debug("LINEAR AT: %s", _linearized_accessibility_tree)
@@ -295,7 +288,7 @@ class PromptAgent:
                         }
                     ]
                 })
-            elif self.exp in ["som", "seeact"]:
+            elif self.observation_type in ["som", "seeact"]:
                 _screenshot = previous_obs["screenshot"]
                 _linearized_accessibility_tree = previous_obs["accessibility_tree"]
                 logger.debug("LINEAR AT: %s", _linearized_accessibility_tree)
@@ -317,7 +310,7 @@ class PromptAgent:
                         }
                     ]
                 })
-            elif self.exp == "screenshot":
+            elif self.observation_type == "screenshot":
                 _screenshot = previous_obs["screenshot"]

                 messages.append({
@@ -336,7 +329,7 @@ class PromptAgent:
                         }
                     ]
                 })
-            elif self.exp == "a11y_tree":
+            elif self.observation_type == "a11y_tree":
                 _linearized_accessibility_tree = previous_obs["accessibility_tree"]

                 messages.append({
@@ -350,7 +343,7 @@ class PromptAgent:
                     ]
                 })
             else:
-                raise ValueError("Invalid experiment type: " + self.exp)  # 1}}}
+                raise ValueError("Invalid observation_type type: " + self.observation_type)  # 1}}}

             messages.append({
                 "role": "assistant",
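
Every turn appended here is an OpenAI-style chat message whose content is a list of typed parts; the qwen and gemini branches of call_llm later convert exactly this shape. A minimal sketch of one user turn carrying text plus a base64-encoded screenshot (the data-URL form is the GPT-4V chat/completions convention and is assumed rather than shown in this hunk):

```python
# Sketch of the message shape assembled in this loop.
messages.append({
    "role": "user",
    "content": [
        {
            "type": "text",
            "text": "Given the screenshot as below. What's the next step that you will do to help with the task?"
        },
        {
            "type": "image_url",
            "image_url": {"url": f"data:image/png;base64,{base64_image}"}
        }
    ]
})
```
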
@@ -363,11 +356,11 @@ class PromptAgent:
             })

         # {{{1
-        if self.exp in ["screenshot", "both"]:
+        if self.observation_type in ["screenshot", "both"]:
             base64_image = encode_image(obs["screenshot"])
             linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])

-            if self.exp == "both":
+            if self.observation_type == "both":
                 self.observations.append({
                     "screenshot": base64_image,
                     "accessibility_tree": linearized_accessibility_tree
@@ -384,7 +377,7 @@ class PromptAgent:
                     {
                         "type": "text",
                         "text": "Given the screenshot as below. What's the next step that you will do to help with the task?"
-                        if self.exp == "screenshot"
+                        if self.observation_type == "screenshot"
                         else "Given the screenshot and info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format(
                             linearized_accessibility_tree)
                     },
@@ -397,7 +390,7 @@ class PromptAgent:
                     }
                 ]
             })
-        elif self.exp == "a11y_tree":
+        elif self.observation_type == "a11y_tree":
             linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])

             self.observations.append({
@@ -415,7 +408,7 @@ class PromptAgent:
                     }
                 ]
             })
-        elif self.exp == "som":
+        elif self.observation_type == "som":
             # Add som to the screenshot
             masks, drew_nodes, tagged_screenshot = tag_screenshot(obs["screenshot"], obs["accessibility_tree"])
             base64_image = encode_image(tagged_screenshot)
@@ -443,7 +436,7 @@ class PromptAgent:
                     }
                 ]
             })
-        elif self.exp == "seeact":
+        elif self.observation_type == "seeact":
             # Add som to the screenshot
             masks, drew_nodes, tagged_screenshot = tag_screenshot(obs["screenshot"], obs["accessibility_tree"])
             base64_image = encode_image(tagged_screenshot)
@@ -471,21 +464,21 @@ class PromptAgent:
                 ]
             })
         else:
-            raise ValueError("Invalid experiment type: " + self.exp)  # 1}}}
-
-        with open("messages.json", "w") as f:
-            f.write(json.dumps(messages, indent=4))
+            raise ValueError("Invalid observation_type type: " + self.observation_type)  # 1}}}
+
+        # with open("messages.json", "w") as f:
+        #     f.write(json.dumps(messages, indent=4))

         logger.info("Generating content with GPT model: %s", self.model)
         response = self.call_llm({
             "model": self.model,
             "messages": messages,
             "max_tokens": self.max_tokens
         })

-        logger.debug("RESPONSE: %s", response)
+        logger.info("RESPONSE: %s", response)

-        if self.exp == "seeact":
+        if self.observation_type == "seeact":
             messages.append({
                 "role": "assistant",
                 "content": [
@@ -507,12 +500,15 @@ class PromptAgent:
                 ]
             })

             logger.info("Generating content with GPT model: %s", self.model)
             response = self.call_llm({
                 "model": self.model,
                 "messages": messages,
-                "max_tokens": self.max_tokens
+                "max_tokens": self.max_tokens,
+                "top_p": self.top_p,
+                "temperature": self.temperature
             })
-            print(response)
+            logger.info("RESPONSE: %s", response)

         try:
             actions = self.parse_actions(response, masks)
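
The SeeAct grounding call now forwards the sampling settings, which the provider branches of call_llm unpack by key. The payload is a plain dict; its full shape, as assembled above:

```python
# Keys consumed by call_llm; the gemini and qwen branches read
# "top_p" and "temperature" explicitly (see below).
payload = {
    "model": self.model,
    "messages": messages,
    "max_tokens": self.max_tokens,
    "top_p": self.top_p,
    "temperature": self.temperature,
}
response = self.call_llm(payload)
```
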
@@ -527,85 +523,90 @@ class PromptAgent:
     @backoff.on_exception(
         backoff.expo,
         (Exception),
-        max_tries=10
+        max_tries=5
     )
     def call_llm(self, payload):

         if self.model.startswith("gpt"):
+            headers = {
+                "Content-Type": "application/json",
+                "Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}"
+            }
             logger.info("Generating content with GPT model: %s", self.model)
             response = requests.post(
                 "https://api.openai.com/v1/chat/completions",
-                headers=self.headers,
+                headers=headers,
                 json=payload
             )

             if response.status_code != 200:
                 if response.json()['error']['code'] == "context_length_exceeded":
-                    print("Context length exceeded. Retrying with a smaller context.")
+                    logger.error("Context length exceeded. Retrying with a smaller context.")
                     payload["messages"] = payload["messages"][-1:]
                     retry_response = requests.post(
                         "https://api.openai.com/v1/chat/completions",
-                        headers=self.headers,
+                        headers=headers,
                         json=payload
                     )
                     if retry_response.status_code != 200:
-                        print("Failed to call LLM: " + retry_response.text)
+                        logger.error("Failed to call LLM: " + retry_response.text)
                         return ""

-                print("Failed to call LLM: " + response.text)
+                logger.error("Failed to call LLM: " + response.text)
                 time.sleep(5)
                 return ""
             else:
                 return response.json()['choices'][0]['message']['content']

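The retry budget drops from ten attempts to five. For reference, this is the standard backoff pattern; a self-contained sketch of what the decorator does to a flaky call:

```python
import backoff
import requests

@backoff.on_exception(backoff.expo, Exception, max_tries=5)
def flaky_post(url, **kwargs):
    # backoff.expo waits roughly 1s, 2s, 4s, 8s between attempts;
    # after the fifth failure the exception propagates to the caller.
    response = requests.post(url, timeout=30, **kwargs)
    response.raise_for_status()
    return response.json()
```

Note that returning "" on a non-200 status does not raise, so the decorator only retries transport-level exceptions here, not API-level failures.
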
elif self.model.startswith("mistral"):
|
||||
print("call mistral")
|
||||
messages = payload["messages"]
|
||||
max_tokens = payload["max_tokens"]
|
||||
|
||||
misrtal_messages = []
|
||||
|
||||
for i, message in enumerate(messages):
|
||||
mistral_message = {
|
||||
"role": message["role"],
|
||||
"content": []
|
||||
}
|
||||
|
||||
for part in message["content"]:
|
||||
mistral_message['content'] = part['text'] if part['type'] == "text" else None
|
||||
|
||||
misrtal_messages.append(mistral_message)
|
||||
|
||||
# the mistral not support system message in our endpoint, so we concatenate it at the first user message
|
||||
if misrtal_messages[0]['role'] == "system":
|
||||
misrtal_messages[1]['content'] = misrtal_messages[0]['content'] + "\n" + misrtal_messages[1]['content']
|
||||
misrtal_messages.pop(0)
|
||||
|
||||
# openai.api_base = "http://localhost:8000/v1"
|
||||
# openai.api_key = "test"
|
||||
# response = openai.ChatCompletion.create(
|
||||
# messages=misrtal_messages,
|
||||
# model="Mixtral-8x7B-Instruct-v0.1"
|
||||
# )
|
||||
|
||||
from openai import OpenAI
|
||||
TOGETHER_API_KEY = "d011650e7537797148fb6170ec1e0be7ae75160375686fae02277136078e90d2"
|
||||
|
||||
client = OpenAI(api_key=TOGETHER_API_KEY,
|
||||
base_url='https://api.together.xyz',
|
||||
)
|
||||
logger.info("Generating content with Mistral model: %s", self.model)
|
||||
response = client.chat.completions.create(
|
||||
messages=misrtal_messages,
|
||||
model="mistralai/Mixtral-8x7B-Instruct-v0.1",
|
||||
max_tokens=1024
|
||||
)
|
||||
|
||||
try:
|
||||
# return response['choices'][0]['message']['content']
|
||||
return response.choices[0].message.content
|
||||
except Exception as e:
|
||||
print("Failed to call LLM: " + str(e))
|
||||
return ""
|
||||
# elif self.model.startswith("mistral"):
|
||||
# print("Call mistral")
|
||||
# messages = payload["messages"]
|
||||
# max_tokens = payload["max_tokens"]
|
||||
#
|
||||
# misrtal_messages = []
|
||||
#
|
||||
# for i, message in enumerate(messages):
|
||||
# mistral_message = {
|
||||
# "role": message["role"],
|
||||
# "content": []
|
||||
# }
|
||||
#
|
||||
# for part in message["content"]:
|
||||
# mistral_message['content'] = part['text'] if part['type'] == "text" else None
|
||||
#
|
||||
# misrtal_messages.append(mistral_message)
|
||||
#
|
||||
# # the mistral not support system message in our endpoint, so we concatenate it at the first user message
|
||||
# if misrtal_messages[0]['role'] == "system":
|
||||
# misrtal_messages[1]['content'] = misrtal_messages[0]['content'] + "\n" + misrtal_messages[1]['content']
|
||||
# misrtal_messages.pop(0)
|
||||
#
|
||||
# # openai.api_base = "http://localhost:8000/v1"
|
||||
# # openai.api_key = "test"
|
||||
# # response = openai.ChatCompletion.create(
|
||||
# # messages=misrtal_messages,
|
||||
# # model="Mixtral-8x7B-Instruct-v0.1"
|
||||
# # )
|
||||
#
|
||||
# from openai import OpenAI
|
||||
# TOGETHER_API_KEY = "d011650e7537797148fb6170ec1e0be7ae75160375686fae02277136078e90d2"
|
||||
#
|
||||
# client = OpenAI(api_key=TOGETHER_API_KEY,
|
||||
# base_url='https://api.together.xyz',
|
||||
# )
|
||||
# logger.info("Generating content with Mistral model: %s", self.model)
|
||||
# response = client.chat.completions.create(
|
||||
# messages=misrtal_messages,
|
||||
# model="mistralai/Mixtral-8x7B-Instruct-v0.1",
|
||||
# max_tokens=1024
|
||||
# )
|
||||
#
|
||||
# try:
|
||||
# # return response['choices'][0]['message']['content']
|
||||
# return response.choices[0].message.content
|
||||
# except Exception as e:
|
||||
# print("Failed to call LLM: " + str(e))
|
||||
# return ""
|
||||
|
||||
elif self.model.startswith("gemini"):
|
||||
def encoded_img_to_pil_img(data_str):
|
||||
@@ -617,6 +618,8 @@ class PromptAgent:

             messages = payload["messages"]
             max_tokens = payload["max_tokens"]
+            top_p = payload["top_p"]
+            temperature = payload["temperature"]

             gemini_messages = []
             for i, message in enumerate(messages):
@@ -662,7 +665,17 @@ class PromptAgent:
             response = genai.GenerativeModel(self.model).generate_content(
                 gemini_messages,
                 generation_config={
-                    "max_output_tokens": max_tokens
+                    "candidate_count": 1,
+                    "max_output_tokens": max_tokens,
+                    "top_p": top_p,
+                    "temperature": temperature
                 },
+                safety_settings={
+                    HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE,
+                    HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
+                    HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
+                    HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
+                    HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
+                }
             )
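
The gemini branch converts the OpenAI-style messages into PIL images plus text via encoded_img_to_pil_img, whose body is elided by the diff. A plausible sketch, assuming the images arrive as base64 PNG strings (BytesIO and PIL are already imported at the top of the file):

```python
import base64
from io import BytesIO
from PIL import Image

def encoded_img_to_pil_img(data_str):
    # Assumed implementation: strip any data-URL prefix, decode the
    # base64 payload, and wrap the bytes so PIL can open them.
    base64_str = data_str.replace("data:image/png;base64,", "")
    return Image.open(BytesIO(base64.b64decode(base64_str)))
```
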
@@ -673,6 +686,8 @@ class PromptAgent:
         elif self.model.startswith("qwen"):
             messages = payload["messages"]
             max_tokens = payload["max_tokens"]
+            top_p = payload["top_p"]
+            temperature = payload["temperature"]

             qwen_messages = []

@@ -683,13 +698,16 @@ class PromptAgent:
                 }
                 assert len(message["content"]) in [1, 2], "One text, or one text with one image"
                 for part in message["content"]:
-                    qwen_message['content'].append({"image": part['image_url']['url']}) if part['type'] == "image_url" else None
+                    qwen_message['content'].append({"image": part['image_url']['url']}) if part[
+                        'type'] == "image_url" else None
                     qwen_message['content'].append({"text": part['text']}) if part['type'] == "text" else None

                 qwen_messages.append(qwen_message)

-            response = dashscope.MultiModalConversation.call(model='qwen-vl-plus',
-                                                             messages=messages)
+            response = dashscope.MultiModalConversation.call(
+                model='qwen-vl-plus',
+                messages=messages,  # todo: add the hyperparameters
+            )
             # The response status_code is HTTPStatus.OK indicate success,
             # otherwise indicate request is failed, you can get error code
             # and message from code and message.
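
As the comment says, dashscope reports failure in-band via status_code rather than raising. A minimal check in that style; the exact result-extraction path is an assumption about the dashscope response object:

```python
from http import HTTPStatus  # already imported at the top of the file

if response.status_code == HTTPStatus.OK:
    # Assumed shape: first choice of the multimodal response.
    return response.output.choices[0].message.content
else:
    # On failure dashscope exposes an error code and message instead.
    logger.error("Failed to call LLM: %s (code %s)",
                 response.message, response.code)
    return ""
```
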
@@ -708,7 +726,7 @@ class PromptAgent:

     def parse_actions(self, response: str, masks=None):

-        if self.exp in ["screenshot", "a11y_tree", "both"]:
+        if self.observation_type in ["screenshot", "a11y_tree", "both"]:
             # parse from the response
             if self.action_space == "computer_13":
                 actions = parse_actions_from_string(response)
@@ -720,7 +738,7 @@ class PromptAgent:
             self.actions.append(actions)

             return actions
-        elif self.exp in ["som", "seeact"]:
+        elif self.observation_type in ["som", "seeact"]:
             # parse from the response
             if self.action_space == "computer_13":
                 raise ValueError("Invalid action space: " + self.action_space)
@@ -732,3 +750,8 @@ class PromptAgent:
             self.actions.append(actions)

             return actions
+
+    def reset(self):
+        self.thoughts = []
+        self.actions = []
+        self.observations = []
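The new reset clears the per-episode history, so one agent instance can be reused across tasks now that the instruction travels with each predict call. A hypothetical reuse loop; env and the task list are placeholders for the benchmark harness:

```python
# Hypothetical evaluation loop enabled by reset().
agent = PromptAgent(observation_type="screenshot", action_space="pyautogui")

for instruction in ["Open a terminal", "Mute the system volume"]:
    agent.reset()          # drop thoughts/actions/observations
    obs = env.reset()
    actions = agent.predict(instruction, obs)
```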