629 lines
22 KiB
Python
629 lines
22 KiB
Python
import base64
|
|
import json
|
|
import logging
|
|
import os
|
|
import re
|
|
import tempfile
|
|
import time
|
|
from http import HTTPStatus
|
|
from io import BytesIO
|
|
from typing import Dict, List, Tuple
|
|
|
|
import backoff
|
|
import openai
|
|
import requests
|
|
from PIL import Image
|
|
from google.api_core.exceptions import InvalidArgument, ResourceExhausted, InternalServerError, BadRequest
|
|
from requests.exceptions import SSLError
|
|
from mm_agents.prompts import (
|
|
AGUVIS_PLANNER_SYS_PROMPT,
|
|
AGUVIS_SYS_PROMPT,
|
|
AGUVIS_PLANNING_PROMPT,
|
|
AGUVIS_INNER_MONOLOGUE_APPEND_PROMPT,
|
|
AGUVIS_GROUNDING_PROMPT,
|
|
AGUVIS_GROUNDING_APPEND_PROMPT
|
|
)
|
|
|
|
# Module-wide logger. Defaults to the same named logger that AguvisAgent.reset()
# installs, so module helpers that log work even before reset() is called;
# reset() may still replace it with a caller-supplied logger.
logger = logging.getLogger("desktopenv.aguvis_agent")
|
|
|
|
|
|
def encode_image(image_content):
    """Return the raw bytes *image_content* base64-encoded as a UTF-8 string."""
    encoded_bytes = base64.b64encode(image_content)
    return encoded_bytes.decode('utf-8')
|
|
|
|
|
|
def encoded_img_to_pil_img(data_str):
    """Decode a base64 image string into a PIL image.

    Any ``data:image/png;base64,`` marker in *data_str* is removed before the
    remaining text is base64-decoded and opened with PIL.
    """
    raw_bytes = base64.b64decode(data_str.replace("data:image/png;base64,", ""))
    return Image.open(BytesIO(raw_bytes))
|
|
|
|
|
|
def save_to_tmp_img_file(data_str):
    """Decode a base64 image string and save it as a PNG in a new temporary directory.

    Decoding is delegated to ``encoded_img_to_pil_img`` so both helpers accept
    exactly the same input.

    Args:
        data_str: Base64 image data, optionally containing the
            ``data:image/png;base64,`` marker.

    Returns:
        Path of ``tmp_img.png`` inside a directory freshly created by
        ``tempfile.mkdtemp()``. Removing it is left to the caller.
    """
    tmp_img_path = os.path.join(tempfile.mkdtemp(), "tmp_img.png")
    # Use the image as a context manager so PIL releases its resources after saving.
    with encoded_img_to_pil_img(data_str) as image:
        image.save(tmp_img_path)
    return tmp_img_path
|
|
|
|
|
|
# FIXME: the logical screen size is hardcoded. Relative (0..1) coordinates
# emitted by the models are scaled against this (width, height).
SCREEN_LOGIC_SIZE: Tuple[int, int] = (1280, 720)
|
|
|
|
|
|
def parse_code_from_planner_response(input_string: str) -> List[str]:
    """Parse the planner's response into a list of executable pyautogui snippets.

    Semicolon-separated statements are first moved onto their own lines. A
    response that is exactly ``WAIT``, ``DONE`` or ``FAIL`` is returned as a
    one-element list. Otherwise the content of every fenced code block
    (```` ```code``` ```` or ```` ```python code``` ````) is collected:

    * empty blocks are skipped;
    * a block that is only a control command yields that command;
    * a block whose last line is a control command yields the preceding code
      and the command as two separate entries;
    * any other block is returned as-is (stripped).

    Args:
        input_string: Raw planner output.

    Returns:
        The extracted snippets, in order; empty if no code block is found.
    """
    commands = ('WAIT', 'DONE', 'FAIL')

    input_string = "\n".join(line.strip() for line in input_string.split(';') if line.strip())
    if input_string.strip() in commands:
        return [input_string.strip()]

    # Matches both ```code``` and ```python code``` and captures the `code`
    # part non-greedily; re.DOTALL lets the captured code span multiple lines.
    pattern = r"```(?:\w+\s+)?(.*?)```"

    codes = []
    for match in re.findall(pattern, input_string, re.DOTALL):
        block = match.strip()
        if not block:
            # An empty fence carries no action; don't emit an empty snippet.
            continue
        if block in commands:
            codes.append(block)
            continue

        *body_lines, last_line = block.split('\n')
        if last_line.strip() in commands:
            body = "\n".join(body_lines).strip()
            if body:
                codes.append(body)
            codes.append(last_line.strip())
        else:
            codes.append(block)

    return codes
|
|
|
|
|
|
def parse_aguvis_response(input_string, screen_logic_size=SCREEN_LOGIC_SIZE) -> Tuple[str | None, str | None]:
    """Split an Aguvis response into ``(low_level_instruction, pyautogui_code)``.

    The first non-blank line is taken as the low-level instruction. The code
    starts at the first line that is exactly ``assistantos`` or starts with
    ``pyautogui``. Any ``assistantos`` markers are removed, hallucinated keyword
    arguments are corrected, and relative coordinates are converted to absolute
    ones for *screen_logic_size*.

    Args:
        input_string: Raw text returned by the Aguvis model (may be empty).
        screen_logic_size: ``(width, height)`` used for coordinate conversion.

    Returns:
        ``(cmd, cmd)`` when the response starts with wait/done/fail
        (case-insensitive, ignoring leading whitespace), where ``cmd`` is the
        upper-case command. ``(instruction, code)`` on success.
        ``(None, None)`` if the response cannot be parsed.
    """
    # Fall back to the named logger so this works even before AguvisAgent.reset().
    log = logger if logger is not None else logging.getLogger("desktopenv.aguvis_agent")

    text = (input_string or "").strip()
    lowered = text.lower()
    for command in ("WAIT", "DONE", "FAIL"):
        if lowered.startswith(command.lower()):
            return command, command

    if not text:
        log.error("Error: Could not parse response %s", input_string)
        return None, None

    try:
        lines = [line for line in text.split("\n") if line.strip() != ""]
        low_level_instruction = lines[0]

        pyautogui_index = next(
            (
                i for i, line in enumerate(lines)
                if line.strip() == "assistantos" or line.strip().startswith("pyautogui")
            ),
            None,
        )
        if pyautogui_index is None:
            log.error("Error: Could not parse response %s", input_string)
            return None, None

        relative_code = "\n".join(lines[pyautogui_index:]).replace("assistantos", "").strip()
        corrected_code = correct_pyautogui_arguments(relative_code)
        parsed_action = _pyautogui_code_to_absolute_coordinates(corrected_code, screen_logic_size)
        return low_level_instruction, parsed_action
    except Exception:
        # Boundary for model-generated text: any parse/convert failure means "unparseable".
        log.exception("Error: Could not parse response %s", input_string)
        return None, None
|
|
|
|
def correct_pyautogui_arguments(code: str) -> str:
    """Rewrite commonly hallucinated keyword arguments in pyautogui calls.

    Each line is stripped. Lines that start with ``pyautogui.write(``,
    ``pyautogui.press(`` or ``pyautogui.hotkey(`` are re-emitted with these fixes:

    * ``write(text=...)`` becomes ``write(message=...)``;
    * ``press(key=...)`` / ``press(button=...)`` pass the value positionally;
    * ``hotkey(key1=..., key2=..., keys=...)`` pass the values positionally.

    Positional arguments are always emitted before keyword arguments, so the
    result is valid Python regardless of the order the model used. Argument
    values keep their original source text. Arguments are parsed with ``ast``,
    so commas inside strings, lists or nested calls are handled. Lines that do
    not parse, and all other lines, are returned unchanged (but stripped).

    Args:
        code: Multi-line pyautogui source produced by a model.

    Returns:
        The corrected source, one output line per input line.
    """
    import ast

    # keyword_arg: replacement keyword name, or None to pass the value positionally.
    function_corrections = {
        'write': {'incorrect_args': ['text'], 'keyword_arg': 'message'},
        'press': {'incorrect_args': ['key', 'button'], 'keyword_arg': None},
        'hotkey': {'incorrect_args': ['key1', 'key2', 'keys'], 'keyword_arg': None},
    }

    corrected_lines = []
    for line in code.strip().split('\n'):
        line = line.strip()
        match = re.match(r'(pyautogui\.(\w+))\((.*)\)', line)
        if not match or match.group(2) not in function_corrections:
            corrected_lines.append(line)
            continue

        full_func_call, func_name, args_str = match.groups()
        func_info = function_corrections[func_name]
        source = f"f({args_str})"
        try:
            call = ast.parse(source, mode="eval").body
        except SyntaxError:
            corrected_lines.append(line)
            continue
        if not (isinstance(call, ast.Call) and isinstance(call.func, ast.Name)):
            corrected_lines.append(line)
            continue

        positional = [ast.get_source_segment(source, arg) for arg in call.args]
        keywords = []
        for kw in call.keywords:
            value = ast.get_source_segment(source, kw.value)
            if kw.arg is None:
                keywords.append(f"**{value}")
            elif kw.arg in func_info['incorrect_args']:
                if func_info['keyword_arg']:
                    keywords.append(f"{func_info['keyword_arg']}={value}")
                else:
                    positional.append(value)
            else:
                keywords.append(f"{kw.arg}={value}")

        # Keep whatever followed the call on this line (e.g. a trailing comment).
        trailer = line[match.end():]
        corrected_lines.append(f"{full_func_call}({', '.join(positional + keywords)}){trailer}")

    return '\n'.join(corrected_lines)
|
|
|
|
def split_args(args_str: str) -> List[str]:
    """Split a call's argument text on its top-level commas.

    Commas inside single- or double-quoted strings, and inside ``()``, ``[]``
    or ``{}`` nesting, do not split. Inside strings a backslash escapes the next
    character (so ``'\\\\'`` closes correctly). Pieces keep their surrounding
    whitespace. A trailing empty piece (e.g. after a final comma) is dropped.

    Args:
        args_str: Text between a call's outer parentheses.

    Returns:
        The argument pieces in order; empty for an empty string.
    """
    args = []
    current = []
    quote = None      # the active quote character while inside a string literal
    escaped = False   # previous character inside the string was an unescaped backslash
    depth = 0         # bracket nesting level outside of strings

    for char in args_str:
        if quote is not None:
            current.append(char)
            if escaped:
                escaped = False
            elif char == '\\':
                escaped = True
            elif char == quote:
                quote = None
            continue

        if char == ',' and depth == 0:
            args.append(''.join(current))
            current = []
            continue

        if char in ('"', "'"):
            quote = char
        elif char in '([{':
            depth += 1
        elif char in ')]}':
            depth = max(depth - 1, 0)
        current.append(char)

    if current:
        args.append(''.join(current))
    return args
|
|
|
|
def extract_coordinates(text, logical_screen_size=SCREEN_LOGIC_SIZE) -> Tuple[int, int] | None:
    """Find the first relative ``(x, y)`` pair in *text* and scale it to pixels.

    Accepted forms include ``(0.1, 0.2)``, ``(x=0.1, y=0.2)``, ``( 0.1 , 0.2 )``
    and calls with trailing arguments such as
    ``pyautogui.click(x=0.1, y=0.2, duration=0.5)``. Parenthesised groups that
    do not contain a pair (e.g. ``(1)``) are skipped rather than ending the search.

    Args:
        text: Model output to search.
        logical_screen_size: ``(width, height)`` the relative values are scaled by.

    Returns:
        ``(x, y)`` rounded to the nearest pixel (the same rounding as
        ``_pyautogui_code_to_absolute_coordinates``), or ``None`` if no pair is found.
    """
    # Fall back to the named logger so this works even before AguvisAgent.reset().
    log = logger if logger is not None else logging.getLogger("desktopenv.aguvis_agent")

    text = text.strip()
    log.info("Extracting coordinates from: %s", text)

    number = r'([-+]?(?:\d*\.\d+|\d+))'
    pattern = (
        r'\(\s*(?:x\s*=\s*)?' + number
        + r'\s*,\s*(?:y\s*=\s*)?' + number
        + r'\s*(?:,[^()]*)?\)'   # tolerate trailing args such as duration=...
    )

    match = re.search(pattern, text)
    if match is None:
        log.warning("Error: No coordinates found in: %s", text)
        return None

    width, height = logical_screen_size
    x = int(round(float(match.group(1)) * width))
    y = int(round(float(match.group(2)) * height))
    return (x, y)
|
|
|
|
|
|
def _pyautogui_code_to_absolute_coordinates(pyautogui_code_relative_coordinates, logical_screen_size=SCREEN_LOGIC_SIZE):
    """
    Convert the relative coordinates in the pyautogui code to absolute coordinates based on the logical screen size.

    Every ``pyautogui.<func>(...)`` call without nested parentheses is inspected.
    Arguments named ``x``/``xOffset`` are multiplied by the logical width and
    ``y``/``yOffset`` by the logical height, then rounded to ints.

    * Keyword coordinates are recognised for any pyautogui function.
    * Positional coordinates are recognised only for the functions in
      ``positional_params``.

    Each call is rewritten exactly once, in place (no text-wide replace, so
    identical or previously converted calls are never scaled twice). Every
    non-coordinate argument keeps its original source text, and the call keeps
    its positional/keyword shape. A coordinate that is not a numeric literal is
    left as-is, and a call that does not parse is left untouched.

    Args:
        pyautogui_code_relative_coordinates: Python source whose pyautogui
            coordinates are fractions of the screen.
        logical_screen_size: ``(width, height)`` of the logical screen.

    Returns:
        The source with convertible calls rewritten.
    """
    import ast

    width, height = logical_screen_size
    # Scale factor for each coordinate parameter name.
    axis_scale = {'x': width, 'xOffset': width, 'y': height, 'yOffset': height}

    # Leading positional parameters of pyautogui mouse functions, up to and
    # including the coordinate parameters. Only these positions are ever scaled;
    # later positional arguments are passed through untouched.
    xy = ('x', 'y')
    offsets = ('xOffset', 'yOffset')
    scroll_params = ('clicks', 'x', 'y')
    positional_params = {
        'click': xy, 'rightClick': xy, 'middleClick': xy,
        'doubleClick': xy, 'tripleClick': xy,
        'moveTo': xy, 'dragTo': xy, 'mouseDown': xy, 'mouseUp': xy,
        'moveRel': offsets, 'move': offsets, 'dragRel': offsets, 'drag': offsets,
        'scroll': scroll_params, 'hscroll': scroll_params, 'vscroll': scroll_params,
    }

    def to_pixels(node, factor):
        """Return the scaled integer source for a numeric literal *node*, or None."""
        try:
            return str(int(round(float(ast.literal_eval(node)) * factor)))
        except (ValueError, TypeError, SyntaxError, OverflowError, MemoryError, RecursionError):
            return None

    def convert(match):
        """re.sub callback: rewrite one pyautogui call, or return it unchanged."""
        full_call, func_name, args_str = match.group(0), match.group(1), match.group(2)
        source = f"func({args_str})"
        try:
            call = ast.parse(source, mode="eval").body
        except SyntaxError:
            return full_call
        if not (isinstance(call, ast.Call) and isinstance(call.func, ast.Name)):
            return full_call

        param_names = positional_params.get(func_name.split('.')[-1], ())
        new_args = []
        updated = False

        for idx, node in enumerate(call.args):
            text = ast.get_source_segment(source, node)
            name = param_names[idx] if idx < len(param_names) else None
            if name in axis_scale:
                scaled = to_pixels(node, axis_scale[name])
                if scaled is not None:
                    text, updated = scaled, True
            new_args.append(text)

        for kw in call.keywords:
            text = ast.get_source_segment(source, kw.value)
            if kw.arg is None:  # **kwargs expansion
                new_args.append(f"**{text}")
                continue
            if kw.arg in axis_scale:
                scaled = to_pixels(kw.value, axis_scale[kw.arg])
                if scaled is not None:
                    text, updated = scaled, True
            new_args.append(f"{kw.arg}={text}")

        if not updated:
            return full_call
        return f"{func_name}({', '.join(new_args)})"

    return re.sub(r'(pyautogui\.\w+)\(([^\)]*)\)', convert, pyautogui_code_relative_coordinates)
|
|
|
|
|
|
class AguvisAgent:
    """Desktop agent driven by the Aguvis model, optionally behind a planner model.

    With ``planner_model=None``, Aguvis both plans and grounds each step. Otherwise
    the planner model proposes pyautogui code, and every commented
    ``pyautogui.moveTo/click/rightClick(x, y)`` line in it is re-grounded by Aguvis
    from the comment text.
    """

    # Chat-completions endpoints of the Aguvis deployments. The first key that
    # appears as a substring of the executor model name selects the endpoint.
    AGUVIS_ENDPOINTS = {
        "7b": "http://101.132.136.195:7908/v1/chat/completions",
        "72b": "http://123.57.10.166:7908/v1/chat/completions",
    }

    def __init__(
        self,
        platform="ubuntu",
        planner_model="gpt-4o",
        executor_model="qwen-aguvis-7b",
        max_tokens=1500,
        top_p=0.9,
        temperature=0.5,
        action_space="pyautogui",
        observation_type="screenshot",
    ):
        """Store the model and sampling configuration and start with empty history.

        Raises:
            AssertionError: if ``executor_model`` is None, or the action space or
                observation type is not supported.
        """
        self.platform = platform
        self.planner_model = planner_model
        self.executor_model = executor_model
        assert self.executor_model is not None, "Executor model cannot be None"
        self.max_tokens = max_tokens
        self.top_p = top_p
        self.temperature = temperature
        self.action_space = action_space
        self.observation_type = observation_type
        assert action_space in ["pyautogui"], "Invalid action space"
        assert observation_type in ["screenshot"], "Invalid observation type"
        self.thoughts = []
        # Initialised here as well as in reset() so every history attribute exists after construction.
        self.action_descriptions = []
        self.actions = []
        self.observations = []

    def _chat_payload(self, model, messages):
        """Build the chat-completions request body shared by every model call."""
        return {
            "model": model,
            "messages": messages,
            "max_tokens": self.max_tokens,
            "top_p": self.top_p,
            "temperature": self.temperature,
        }

    def predict(self, instruction: str, obs: Dict) -> Tuple[str, List]:
        """
        Predict the next action(s) based on the current observation.

        Args:
            instruction: The overall task description.
            obs: Observation dict. ``obs['screenshot']`` is base64-encoded via
                ``encode_image`` and sent as a PNG data URL.

        Returns:
            A ``(response, actions)`` pair.

            * Without a planner: ``response`` is the raw Aguvis output and
              ``actions`` a one-element list with the parsed pyautogui code
              (``None`` if parsing failed).
            * With a planner: ``response`` is ``""`` and ``actions`` holds one
              grounded snippet per planner code block.
        """
        previous_actions = "\n".join(
            f"Step {i + 1}: {action}" for i, action in enumerate(self.actions)
        ) if self.actions else "None"
        screenshot_url = f"data:image/png;base64,{encode_image(obs['screenshot'])}"

        if self.planner_model is None:
            aguvis_messages = [
                {
                    "role": "system",
                    "content": [{"type": "text", "text": AGUVIS_SYS_PROMPT}],
                },
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": AGUVIS_PLANNING_PROMPT.format(
                                instruction=instruction,
                                previous_actions=previous_actions,
                            ),
                        },
                        {"type": "image_url", "image_url": {"url": screenshot_url}},
                    ],
                },
                {
                    "role": "assistant",
                    "content": [{"type": "text", "text": AGUVIS_INNER_MONOLOGUE_APPEND_PROMPT}],
                },
            ]
            aguvis_response = self.call_llm(
                self._chat_payload(self.executor_model, aguvis_messages), self.executor_model
            )
            logger.info(f"Aguvis Output: {aguvis_response}")
            low_level_instruction, pyautogui_actions = parse_aguvis_response(aguvis_response)

            self.actions.append(low_level_instruction)
            return aguvis_response, [pyautogui_actions]

        # FIXME [junli]:
        # Using an external planner (GPT-4o) requires relying on more
        # detailed prompt to provide Aguvis with low level instructions.
        # So we temporarily separate the planner prompt and aguvis prompt.
        planner_messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": AGUVIS_PLANNER_SYS_PROMPT}],
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": f"You are asked to complete the following task: {instruction}",
                    },
                    {
                        "type": "image_url",
                        "image_url": {"url": screenshot_url, "detail": "high"},
                    },
                ],
            },
        ]
        planner_response = self.call_llm(
            self._chat_payload(self.planner_model, planner_messages), self.planner_model
        )
        logger.info(f"Planner output: {planner_response}")
        planner_code = parse_code_from_planner_response(planner_response)
        pyautogui_actions = [
            self.convert_action_to_grounding_model_instruction(snippet, obs, instruction)
            for snippet in planner_code
        ]
        return "", pyautogui_actions

    def convert_action_to_grounding_model_instruction(
        self, line: str, obs: Dict, instruction: str
    ) -> str:
        """Re-ground commented mouse actions in a planner snippet using Aguvis.

        For each ``#<comment>`` followed on the next line by
        ``pyautogui.moveTo/click/rightClick(<int>, <int>[, duration=<num>])``,
        the comment text and the screenshot are sent to Aguvis. The call's
        coordinates are then replaced by the grounded ones. If Aguvis's answer
        contains no coordinates, the planner's action is kept unchanged.

        Args:
            line: A planner code snippet (may span several lines).
            obs: Observation dict holding ``obs['screenshot']``.
            instruction: The overall task (not used by this method).

        Returns:
            The snippet with grounded coordinates substituted.
        """
        pattern = r'(#.*?)\n(pyautogui\.(moveTo|click|rightClick)\((?:x=)?(\d+)(?:,\s*|\s*,\s*y=)(\d+)(?:,\s*duration=[\d.]+)?\))'
        matches = re.findall(pattern, line, re.DOTALL)
        if not matches:
            return line

        screenshot_url = f"data:image/png;base64,{encode_image(obs['screenshot'])}"
        new_instruction = line
        for raw_comment, original_action, func_name, _x, _y in matches:
            comment = raw_comment.split("#")[1].strip()
            func_name = func_name.strip()

            aguvis_messages = [
                {
                    "role": "system",
                    "content": [{"type": "text", "text": AGUVIS_SYS_PROMPT}],
                },
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image_url",
                            "image_url": {"url": screenshot_url, "detail": "high"},
                        },
                        {"type": "text", "text": '\n' + comment},
                    ],
                },
                {
                    "role": "assistant",
                    "content": [
                        {"type": "text", "text": AGUVIS_GROUNDING_APPEND_PROMPT.format(function_name=func_name)}
                    ],
                },
            ]
            grounding_response = self.call_llm(
                self._chat_payload(self.executor_model, aguvis_messages), self.executor_model
            )
            coordinates = extract_coordinates(grounding_response, SCREEN_LOGIC_SIZE)
            if coordinates is None:
                # Previously this crashed with TypeError; keep the planner's action instead.
                logger.warning("Could not ground action %s; keeping planner coordinates", original_action)
                continue

            # FIXME [junli]: Use ast to reconstruct the action with coordinates
            # The pattern only admits an optional duration=... after the coordinates.
            duration = re.search(r'duration=[\d.]+', original_action)
            new_action = f"pyautogui.{func_name}({coordinates[0]}, {coordinates[1]}"
            new_action += f", {duration.group(0)})" if duration else ")"
            logger.info(new_action)
            new_instruction = new_instruction.replace(original_action, new_action)

        return new_instruction

    @backoff.on_exception(
        backoff.constant,
        # here you should add more model exceptions as you want,
        # but you are forbidden to add "Exception", that is, a common type of exception
        # because we want to catch this kind of Exception in the outside to ensure
        # each example won't exceed the time limit
        (
            # General exceptions
            SSLError,

            # OpenAI exceptions
            openai.RateLimitError,
            openai.BadRequestError,
            openai.InternalServerError,

            # Google exceptions
            InvalidArgument,
            ResourceExhausted,
            InternalServerError,
            BadRequest,

            # Groq exceptions
            # todo: check
        ),
        interval=30,
        max_tries=10
    )
    def call_llm(self, payload, model):
        """Send *payload* to the chat-completions endpoint for *model*.

        Routing:

        * models starting with ``gpt`` go to the OpenAI API (key taken from
          ``OPENAI_API_KEY``);
        * models containing ``aguvis`` go to the endpoint in
          ``AGUVIS_ENDPOINTS`` whose key appears in the model name.

        Returns:
            The first choice's message content. On a non-200 response, the error
            is logged, the call sleeps 5 seconds, and ``""`` is returned.

        Raises:
            Exception: for an Aguvis model name matching no endpoint.
            ValueError: for any other unsupported model family (previously this
                silently returned ``None``).
        """
        if model.startswith("gpt"):
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}",
            }
            logger.info("Generating content with GPT model: %s", model)
            url = "https://api.openai.com/v1/chat/completions"
        elif "aguvis" in model:
            headers = {
                "Content-Type": "application/json",
            }
            logger.info("Generating content with Aguvis model: %s", model)
            url = next(
                (endpoint for tag, endpoint in self.AGUVIS_ENDPOINTS.items() if tag in model),
                None,
            )
            if url is None:
                raise Exception("Unsupported Aguvis model version")
        else:
            raise ValueError(f"Unsupported model: {model}")

        response = requests.post(url, headers=headers, json=payload)
        if response.status_code != 200:
            logger.error("Failed to call LLM: " + response.text)
            time.sleep(5)
            return ""
        return response.json()['choices'][0]['message']['content']

    def reset(self, _logger=None):
        """Install the module-level logger and clear all per-episode history."""
        global logger
        logger = _logger if _logger is not None else logging.getLogger("desktopenv.aguvis_agent")

        self.thoughts = []
        self.action_descriptions = []
        self.actions = []
        self.observations = []
|