Add Jedi agent implementation to mm_agents (#192)

* feat: implement Jedi agent

* chore: code clean
This commit is contained in:
MillanK
2025-05-10 19:55:33 +08:00
committed by GitHub
parent 5678b510d7
commit 51f5ddea04
4 changed files with 1171 additions and 1 deletions

233
mm_agents/img_utils.py Normal file
View File

@@ -0,0 +1,233 @@
import math
from typing import List, Union, Dict, Any
def round_by_factor(number: int, factor: int) -> int:
"""返回最接近 number 的且能被 factor 整除的整数"""
return round(number / factor) * factor
def ceil_by_factor(number: int, factor: int) -> int:
"""返回大于等于 number 的且能被 factor 整除的整数"""
return math.ceil(number / factor) * factor
def floor_by_factor(number: int, factor: int) -> int:
"""返回小于等于 number 的且能被 factor 整除的整数"""
return math.floor(number / factor) * factor
def smart_resize(height, width, factor=28, min_pixels=56 * 56, max_pixels=14 * 14 * 4 * 1280, max_long_side=8192):
"""缩放后图片满足以下条件:
1. 长宽能被 factor 整除
2. pixels 总数被限制在 [min_pixels, max_pixels] 内
3. 最长边限制在 max_long_side 内
4. 保证其长宽比基本不变
"""
if height < 2 or width < 2:
raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}")
elif max(height, width) / min(height, width) > 200:
raise ValueError(f"absolute aspect ratio must be smaller than 100, got {height} / {width}")
if max(height, width) > max_long_side:
beta = max(height, width) / max_long_side
height, width = int(height / beta), int(width / beta)
h_bar = round_by_factor(height, factor)
w_bar = round_by_factor(width, factor)
if h_bar * w_bar > max_pixels:
beta = math.sqrt((height * width) / max_pixels)
h_bar = floor_by_factor(height / beta, factor)
w_bar = floor_by_factor(width / beta, factor)
elif h_bar * w_bar < min_pixels:
beta = math.sqrt(min_pixels / (height * width))
h_bar = ceil_by_factor(height * beta, factor)
w_bar = ceil_by_factor(width * beta, factor)
return h_bar, w_bar
def update_image_size_(image_ele: dict, min_tokens=1, max_tokens=12800, merge_base=2, patch_size=14):
"""根据 min_tokens, max_tokens 更新 image_ele 的尺寸信息
Args:
image_ele (dict):
- image_ele["image"]: str 图片路径
- image_ele["height"]: int 图片原始高度
- image_ele["width"]: int 图片原始宽度
Returns:
更新后的 image_ele, 新增如下 key-value pair
dict:
- image_ele["resized_height"]: int 输入到模型的真实高度
- image_ele["resized_width"]: int 输入到模型的真实宽度
- image_ele["seq_len"]: int 输入到模型所占的序列长度
"""
height, width = image_ele["height"], image_ele["width"]
pixels_per_token = patch_size * patch_size * merge_base * merge_base
resized_height, resized_width = smart_resize(
height,
width,
factor=merge_base * patch_size,
min_pixels=pixels_per_token * min_tokens,
max_pixels=pixels_per_token * max_tokens,
max_long_side=50000,
)
image_ele.update(
{
"resized_height": resized_height,
"resized_width": resized_width,
"seq_len": resized_height * resized_width // pixels_per_token + 2,
}
)
return image_ele
def _convert_bbox_format_from_abs_origin(bbox, image_ele: dict, *, tgt_format: str):
x1, y1, x2, y2 = bbox
if tgt_format == "abs_origin":
new_bbox = [int(x1), int(y1), int(x2), int(y2)]
elif tgt_format == "abs_resized":
new_bbox = [
int(x1 / image_ele["width"] * image_ele["resized_width"]),
int(y1 / image_ele["height"] * image_ele["resized_height"]),
int(x2 / image_ele["width"] * image_ele["resized_width"]),
int(y2 / image_ele["height"] * image_ele["resized_height"]),
]
elif tgt_format == "qwen-vl":
new_bbox = [
int(x1 / image_ele["width"] * 999),
int(y1 / image_ele["height"] * 999),
int(x2 / image_ele["width"] * 999),
int(y2 / image_ele["height"] * 999),
]
elif tgt_format == "rel":
new_bbox = [
float(x1 / image_ele["width"]),
float(y1 / image_ele["height"]),
float(x2 / image_ele["width"]),
float(y2 / image_ele["height"]),
]
elif tgt_format == "molmo":
new_bbox = [
round(x1 / image_ele["width"] * 100, ndigits=1),
round(y1 / image_ele["height"] * 100, ndigits=1),
round(x2 / image_ele["width"] * 100, ndigits=1),
round(y2 / image_ele["height"] * 100, ndigits=1),
]
else:
assert False, f"Unknown tgt_format: {tgt_format}"
return new_bbox
def _convert_bbox_format_to_abs_origin(bbox, image_ele: dict, *, src_format: str):
x1, y1, x2, y2 = bbox
if src_format == "abs_origin":
new_bbox = [int(x1), int(y1), int(x2), int(y2)]
elif src_format == "abs_resized":
new_bbox = [
int(x1 / image_ele["resized_width"] * image_ele["width"]),
int(y1 / image_ele["resized_height"] * image_ele["height"]),
int(x2 / image_ele["resized_width"] * image_ele["width"]),
int(y2 / image_ele["resized_height"] * image_ele["height"]),
]
elif src_format == "qwen-vl":
new_bbox = [
int(x1 / 999 * image_ele["width"]),
int(y1 / 999 * image_ele["height"]),
int(x2 / 999 * image_ele["width"]),
int(y2 / 999 * image_ele["height"]),
]
elif src_format == "rel":
new_bbox = [
int(x1 * image_ele["width"]),
int(y1 * image_ele["height"]),
int(x2 * image_ele["width"]),
int(y2 * image_ele["height"]),
]
elif src_format == "molmo":
new_bbox = [
int(x1 / 100 * image_ele["width"]),
int(y1 / 100 * image_ele["height"]),
int(x2 / 100 * image_ele["width"]),
int(y2 / 100 * image_ele["height"]),
]
else:
assert False, f"Unknown src_format: {src_format}"
return new_bbox
def convert_bbox_format(bbox, image_ele: dict, *, src_format: str, tgt_format: str):
bbox_abs_origin = _convert_bbox_format_to_abs_origin(bbox, image_ele, src_format=src_format)
bbox_tgt_format = _convert_bbox_format_from_abs_origin(bbox_abs_origin, image_ele, tgt_format=tgt_format)
return bbox_tgt_format
def _convert_point_format_from_abs_origin(point, image_ele: dict, *, tgt_format: str):
x, y = point
if tgt_format == "abs_origin":
new_point = [int(x), int(y)]
elif tgt_format == "abs_resized":
new_point = [
int(x / image_ele["width"] * image_ele["resized_width"]),
int(y / image_ele["height"] * image_ele["resized_height"]),
]
elif tgt_format == "qwen-vl":
new_point = [
int(x / image_ele["width"] * 999),
int(y / image_ele["height"] * 999),
]
elif tgt_format == "rel":
new_point = [
float(x / image_ele["width"]),
float(y / image_ele["height"]),
]
elif tgt_format == "molmo":
new_point = [
round(x / image_ele["width"] * 100, ndigits=1),
round(y / image_ele["height"] * 100, ndigits=1),
]
else:
assert False, f"Unknown tgt_format: {tgt_format}"
return new_point
def _convert_point_format_to_abs_origin(point, image_ele: dict, *, src_format: str):
x, y = point
if src_format == "abs_origin":
new_point = [int(x), int(y)]
elif src_format == "abs_resized":
new_point = [
int(x / image_ele["resized_width"] * image_ele["width"]),
int(y / image_ele["resized_height"] * image_ele["height"]),
]
elif src_format == "qwen-vl":
new_point = [
int(x / 999 * image_ele["width"]),
int(y / 999 * image_ele["height"]),
]
elif src_format == "rel":
new_point = [
int(x * image_ele["width"]),
int(y * image_ele["height"]),
]
elif src_format == "molmo":
new_point = [
int(x / 100 * image_ele["width"]),
int(y / 100 * image_ele["height"]),
]
else:
assert False, f"Unknown src_format: {src_format}"
return new_point
def convert_point_format(point, image_ele: dict, *, src_format: str, tgt_format: str):
point_abs_origin = _convert_point_format_to_abs_origin(point, image_ele, src_format=src_format)
point_tgt_format = _convert_point_format_from_abs_origin(point_abs_origin, image_ele, tgt_format=tgt_format)
return point_tgt_format
__all__ = [
"update_image_size_",
"convert_bbox_format",
"convert_point_format",
]

441
mm_agents/jedi_3b_agent.py Normal file
View File

@@ -0,0 +1,441 @@
import base64
import json
import logging
import os
import re
import time
from io import BytesIO
from typing import Dict, List
import backoff
import openai
import requests
from PIL import Image
from google.api_core.exceptions import (
InvalidArgument,
ResourceExhausted,
InternalServerError,
BadRequest,
)
from requests.exceptions import SSLError
logger = None
OPENAI_API_KEY = "Your OpenAI API Key"
JEDI_API_KEY = "Your Jedi API Key"
JEDI_SERVICE_URL = "Your Jedi Service URL"
from mm_agents.prompts import JEDI_PLANNER_SYS_PROMPT, JEDI_GROUNDER_SYS_PROMPT
from mm_agents.img_utils import smart_resize
def encode_image(image_content):
return base64.b64encode(image_content).decode("utf-8")
class JediAgent3B:
def __init__(
self,
platform="ubuntu",
planner_model="gpt-4o",
executor_model="jedi-3b",
max_tokens=1500,
top_p=0.9,
temperature=0.5,
action_space="pyautogui",
observation_type="screenshot",
max_steps=15,
):
self.platform = platform
self.planner_model = planner_model
self.executor_model = executor_model
assert self.executor_model is not None, "Executor model cannot be None"
self.max_tokens = max_tokens
self.top_p = top_p
self.temperature = temperature
self.action_space = action_space
self.observation_type = observation_type
assert action_space in ["pyautogui"], "Invalid action space"
assert observation_type in ["screenshot"], "Invalid observation type"
self.thoughts = []
self.actions = []
self.observations = []
self.observation_captions = []
self.max_image_history_length = 5
self.current_step = 1
self.max_steps = max_steps
def predict(self, instruction: str, obs: Dict) -> List:
"""
Predict the next action(s) based on the current observation.
"""
# get the width and height of the screenshot
image = Image.open(BytesIO(obs["screenshot"]))
width, height = image.convert("RGB").size
previous_actions = ("\n".join([
f"Step {i+1}: {action}" for i, action in enumerate(self.actions)
]) if self.actions else "None")
user_prompt = (
f"""Please generate the next move according to the UI screenshot and instruction. And you can refer to the previous actions and observations for reflection.\n\nInstruction: {instruction}\n\n""")
messages = [{
"role": "system",
"content": [{
"type": "text",
"text": JEDI_PLANNER_SYS_PROMPT.replace("{current_step}", str(self.current_step)).replace("{max_steps}", str(self.max_steps))
}]
}]
# Determine which observations to include images for (only most recent ones)
obs_start_idx = max(0, len(self.observations) - self.max_image_history_length)
# Add all thought and action history
for i in range(len(self.thoughts)):
# For recent steps, include the actual screenshot
if i >= obs_start_idx:
messages.append({
"role": "user",
"content": [{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{encode_image(self.observations[i]['screenshot'])}",
"detail": "high"
},
}]
})
# For older steps, use the observation caption instead of the image
else:
messages.append({
"role": "user",
"content": [{
"type": "text",
"text": f"Observation: {self.observation_captions[i]}"
}]
})
thought_messages = f"Thought:\n{self.thoughts[i]}"
action_messages = f"Action:"
for action in self.actions[i]:
action_messages += f"\n{action}"
messages.append({
"role": "assistant",
"content": [{
"type": "text",
"text": thought_messages + "\n" + action_messages
}]
})
#print(thought_messages + "\n" + action_messages)
messages.append({
"role":"user",
"content": [
{
"type":"image_url",
"image_url":{
"url":f"data:image/png;base64,{encode_image(obs['screenshot'])}",
"detail": "high"
},
},
{
"type": "text",
"text": user_prompt
},
],
})
planner_response = self.call_llm(
{
"model": self.planner_model,
"messages": messages,
"max_tokens": self.max_tokens,
"top_p": self.top_p,
"temperature": self.temperature,
},
self.planner_model,
)
logger.info(f"Planner Output: {planner_response}")
codes = self.parse_code_from_planner_response(planner_response)
# Add retry logic if no codes were parsed
retry_count = 0
max_retries = 5
while not codes and retry_count < max_retries:
logger.info(f"No codes parsed from planner response. Retrying ({retry_count+1}/{max_retries})...")
messages.append({
"role": "user",
"content": [
{"type": "text", "text": "You didn't generate valid actions. Please try again."}
]
})
planner_response = self.call_llm(
{
"model": self.planner_model,
"messages": messages,
"max_tokens": self.max_tokens,
"top_p": self.top_p,
"temperature": self.temperature,
},
self.planner_model,
)
logger.info(f"Retry Planner Output: {planner_response}")
codes = self.parse_code_from_planner_response(planner_response)
retry_count += 1
thought = self.parse_thought_from_planner_response(planner_response)
observation_caption = self.parse_observation_caption_from_planner_response(planner_response)
resized_height, resized_width = smart_resize(height, width, max_pixels= 2700 * 28 * 28)
pyautogui_actions = []
for line in codes:
code = self.convert_action_to_grounding_model_instruction(
line,
obs,
instruction,
height,
width,
resized_height,
resized_width
)
pyautogui_actions.append(code)
self.actions.append([pyautogui_actions])
self.observations.append(obs)
self.thoughts.append(thought)
self.observation_captions.append(observation_caption)
self.current_step += 1
return planner_response, pyautogui_actions, {}
def parse_observation_caption_from_planner_response(self, input_string: str) -> str:
pattern = r"Observation:\n(.*?)\n"
matches = re.findall(pattern, input_string, re.DOTALL)
if matches:
return matches[0].strip()
return ""
def parse_thought_from_planner_response(self, input_string: str) -> str:
pattern = r"Thought:\n(.*?)\n"
matches = re.findall(pattern, input_string, re.DOTALL)
if matches:
return matches[0].strip()
return ""
def parse_code_from_planner_response(self, input_string: str) -> List[str]:
input_string = "\n".join([line.strip() for line in input_string.split(';') if line.strip()])
if input_string.strip() in ['WAIT', 'DONE', 'FAIL']:
return [input_string.strip()]
pattern = r"```(?:\w+\s+)?(.*?)```"
matches = re.findall(pattern, input_string, re.DOTALL)
codes = []
for match in matches:
match = match.strip()
commands = ['WAIT', 'DONE', 'FAIL']
if match in commands:
codes.append(match.strip())
elif match.split('\n')[-1] in commands:
if len(match.split('\n')) > 1:
codes.append("\n".join(match.split('\n')[:-1]))
codes.append(match.split('\n')[-1])
else:
codes.append(match)
return codes
def convert_action_to_grounding_model_instruction(self, line: str, obs: Dict, instruction: str, height: int, width: int, resized_height: int, resized_width: int ) -> str:
pattern = r'(#.*?)\n(pyautogui\.(moveTo|click|rightClick|doubleClick|middleClick|dragTo)\((?:x=)?(\d+)(?:,\s*|\s*,\s*y=)(\d+)(?:,\s*duration=[\d.]+)?\))'
matches = re.findall(pattern, line, re.DOTALL)
if not matches:
return line
new_instruction = line
for match in matches:
comment = match[0].split("#")[1].strip()
original_action = match[1]
func_name = match[2].strip()
if "click()" in original_action.lower():
continue
messages = []
messages.append({
"role": "system",
"content": [{"type": "text", "text": JEDI_GROUNDER_SYS_PROMPT.replace("{height}", str(resized_height)).replace("{width}", str(resized_width))}]
})
messages.append(
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{encode_image(obs['screenshot'])}",
"detail": "high",
},
},
{
"type": "text",
"text": '\n' + comment,
},
],
}
)
grounding_response = self.call_llm({
"model": self.executor_model,
"messages": messages,
"max_tokens": self.max_tokens,
"top_p": self.top_p,
"temperature": self.temperature
}, self.executor_model)
coordinates = self.parse_jedi_response(grounding_response, height, width, resized_width, resized_height)
logger.info(coordinates)
if coordinates == [-1, -1]:
continue
action_parts = original_action.split('(')
new_action = f"{action_parts[0]}({coordinates[0]}, {coordinates[1]}"
if len(action_parts) > 1 and 'duration' in action_parts[1]:
duration_part = action_parts[1].split(',')[-1]
new_action += f", {duration_part}"
elif len(action_parts) > 1 and 'button' in action_parts[1]:
button_part = action_parts[1].split(',')[-1]
new_action += f", {button_part}"
else:
new_action += ")"
logger.info(new_action)
new_instruction = new_instruction.replace(original_action, new_action)
return new_instruction
def parse_jedi_response(self, response, width: int, height: int, resized_width: int, resized_height: int) -> List[str]:
"""
Parse the LLM response and convert it to low level action and pyautogui code.
"""
low_level_instruction = ""
pyautogui_code = []
try:
# Define possible tag combinations
start_tags = ["<tool_call>", ""]
end_tags = ["</tool_call>", ""]
# Find valid start and end tags
start_tag = next((tag for tag in start_tags if tag in response), None)
end_tag = next((tag for tag in end_tags if tag in response), None)
if not start_tag or not end_tag:
print("Missing valid start or end tags in the response")
return [-1, -1]
# Split the response to extract low_level_instruction and tool_call
parts = response.split(start_tag)
if len(parts) < 2:
print("Missing start tag in the response")
return [-1, -1]
low_level_instruction = parts[0].strip().replace("Action: ", "")
tool_call_str = parts[1].split(end_tag)[0].strip()
# Fix for double curly braces and clean up JSON string
tool_call_str = tool_call_str.replace("{{", "{").replace("}}", "}")
tool_call_str = tool_call_str.replace("\n", "").replace("\r", "").strip()
try:
tool_call = json.loads(tool_call_str)
action = tool_call.get("arguments", {}).get("action", "")
args = tool_call.get("arguments", {})
except json.JSONDecodeError as e:
print(f"JSON parsing error: {e}")
# Try an alternative parsing approach
try:
# Try to extract the coordinate directly using regex
import re
coordinate_match = re.search(r'"coordinate":\s*\[(\d+),\s*(\d+)\]', tool_call_str)
if coordinate_match:
x = int(coordinate_match.group(1))
y = int(coordinate_match.group(2))
x = int(x * width / resized_width)
y = int(y * height / resized_height)
return [x, y]
except Exception as inner_e:
print(f"Alternative parsing method also failed: {inner_e}")
return [-1, -1]
# convert the coordinate to the original resolution
x = int(args.get("coordinate", [-1, -1])[0] * width / resized_width)
y = int(args.get("coordinate", [-1, -1])[1] * height / resized_height)
return [x, y]
except Exception as e:
logger.error(f"Failed to parse response: {e}")
return [-1, -1]
@backoff.on_exception(
backoff.constant,
# here you should add more model exceptions as you want,
# but you are forbidden to add "Exception", that is, a common type of exception
# because we want to catch this kind of Exception in the outside to ensure
# each example won't exceed the time limit
(
# General exceptions
SSLError,
# OpenAI exceptions
openai.RateLimitError,
openai.BadRequestError,
openai.InternalServerError,
# Google exceptions
InvalidArgument,
ResourceExhausted,
InternalServerError,
BadRequest,
# Groq exceptions
# todo: check
),
interval=30,
max_tries=10,
)
def call_llm(self, payload, model):
if model.startswith("gpt"):
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {OPENAI_API_KEY}",
}
logger.info("Generating content with GPT model: %s", model)
response = requests.post(
"https://api.openai.com/v1/chat/completions",
headers=headers,
json=payload,
)
if response.status_code != 200:
logger.error("Failed to call LLM: " + response.text)
time.sleep(5)
return ""
else:
return response.json()["choices"][0]["message"]["content"]
elif model.startswith("jedi"):
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {JEDI_API_KEY}",
}
response = requests.post(
f"{JEDI_SERVICE_URL}/v1/chat/completions",
headers=headers,
json=payload,
)
if response.status_code != 200:
logger.error("Failed to call LLM: " + response.text)
time.sleep(5)
return ""
else:
return response.json()["choices"][0]["message"]["content"]
def reset(self, _logger=None):
global logger
logger = (_logger if _logger is not None else
logging.getLogger("desktopenv.jedi_3b_agent"))
self.thoughts = []
self.action_descriptions = []
self.actions = []
self.observations = []
self.observation_captions = []

427
mm_agents/jedi_7b_agent.py Normal file
View File

@@ -0,0 +1,427 @@
import base64
import json
import logging
import os
import re
import time
from io import BytesIO
from typing import Dict, List
import backoff
import openai
import requests
from PIL import Image
from google.api_core.exceptions import (
InvalidArgument,
ResourceExhausted,
InternalServerError,
BadRequest,
)
from requests.exceptions import SSLError
logger = None
OPENAI_API_KEY = "Your OpenAI API Key"
JEDI_API_KEY = "Your Jedi API Key"
JEDI_SERVICE_URL = "Your Jedi Service URL"
from mm_agents.prompts import JEDI_PLANNER_SYS_PROMPT, JEDI_GROUNDER_SYS_PROMPT
from mm_agents.img_utils import smart_resize
def encode_image(image_content):
return base64.b64encode(image_content).decode("utf-8")
class JediAgent7B:
def __init__(
self,
platform="ubuntu",
planner_model="gpt-4o",
executor_model="jedi-7b",
max_tokens=1500,
top_p=0.9,
temperature=0.5,
action_space="pyautogui",
observation_type="screenshot",
max_steps=15
):
self.platform = platform
self.planner_model = planner_model
self.executor_model = executor_model
assert self.executor_model is not None, "Executor model cannot be None"
self.max_tokens = max_tokens
self.top_p = top_p
self.temperature = temperature
self.action_space = action_space
self.observation_type = observation_type
assert action_space in ["pyautogui"], "Invalid action space"
assert observation_type in ["screenshot"], "Invalid observation type"
self.thoughts = []
self.actions = []
self.observations = []
self.observation_captions = []
self.max_image_history_length = 5
self.current_step = 1
self.max_steps = max_steps
def predict(self, instruction: str, obs: Dict) -> List:
"""
Predict the next action(s) based on the current observation.
"""
# get the width and height of the screenshot
image = Image.open(BytesIO(obs["screenshot"]))
width, height = image.convert("RGB").size
previous_actions = ("\n".join([
f"Step {i+1}: {action}" for i, action in enumerate(self.actions)
]) if self.actions else "None")
user_prompt = (
f"""Please generate the next move according to the UI screenshot and instruction. And you can refer to the previous actions and observations for reflection.\n\nInstruction: {instruction}\n\n""")
messages = [{
"role": "system",
"content": [{
"type": "text",
"text": JEDI_PLANNER_SYS_PROMPT.replace("{current_step}", str(self.current_step)).replace("{max_steps}", str(self.max_steps))
}]
}]
# Determine which observations to include images for (only most recent ones)
obs_start_idx = max(0, len(self.observations) - self.max_image_history_length)
# Add all thought and action history
for i in range(len(self.thoughts)):
# For recent steps, include the actual screenshot
if i >= obs_start_idx:
messages.append({
"role": "user",
"content": [{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{encode_image(self.observations[i]['screenshot'])}",
"detail": "high"
},
}]
})
# For older steps, use the observation caption instead of the image
else:
messages.append({
"role": "user",
"content": [{
"type": "text",
"text": f"Observation: {self.observation_captions[i]}"
}]
})
thought_messages = f"Thought:\n{self.thoughts[i]}"
action_messages = f"Action:"
for action in self.actions[i]:
action_messages += f"\n{action}"
messages.append({
"role": "assistant",
"content": [{
"type": "text",
"text": thought_messages + "\n" + action_messages
}]
})
#print(thought_messages + "\n" + action_messages)
messages.append({
"role":"user",
"content": [
{
"type":"image_url",
"image_url":{
"url":f"data:image/png;base64,{encode_image(obs['screenshot'])}",
"detail": "high"
},
},
{
"type": "text",
"text": user_prompt
},
],
})
planner_response = self.call_llm(
{
"model": self.planner_model,
"messages": messages,
"max_tokens": self.max_tokens,
"top_p": self.top_p,
"temperature": self.temperature,
},
self.planner_model,
)
logger.info(f"Planner Output: {planner_response}")
codes = self.parse_code_from_planner_response(planner_response)
# Add retry logic if no codes were parsed
retry_count = 0
max_retries = 5
while not codes and retry_count < max_retries:
logger.info(f"No codes parsed from planner response. Retrying ({retry_count+1}/{max_retries})...")
messages.append({
"role": "user",
"content": [
{"type": "text", "text": "You didn't generate valid actions. Please try again."}
]
})
planner_response = self.call_llm(
{
"model": self.planner_model,
"messages": messages,
"max_tokens": self.max_tokens,
"top_p": self.top_p,
"temperature": self.temperature,
},
self.planner_model,
)
logger.info(f"Retry Planner Output: {planner_response}")
codes = self.parse_code_from_planner_response(planner_response)
retry_count += 1
thought = self.parse_thought_from_planner_response(planner_response)
observation_caption = self.parse_observation_caption_from_planner_response(planner_response)
resized_height, resized_width = smart_resize(height, width, max_pixels= 2700 * 28 * 28)
pyautogui_actions = []
for line in codes:
code = self.convert_action_to_grounding_model_instruction(
line,
obs,
instruction,
height,
width,
resized_height,
resized_width
)
pyautogui_actions.append(code)
self.actions.append([pyautogui_actions])
self.observations.append(obs)
self.thoughts.append(thought)
self.observation_captions.append(observation_caption)
self.current_step += 1
return planner_response, pyautogui_actions, {}
def parse_observation_caption_from_planner_response(self, input_string: str) -> str:
pattern = r"Observation:\n(.*?)\n"
matches = re.findall(pattern, input_string, re.DOTALL)
if matches:
return matches[0].strip()
return ""
def parse_thought_from_planner_response(self, input_string: str) -> str:
pattern = r"Thought:\n(.*?)\n"
matches = re.findall(pattern, input_string, re.DOTALL)
if matches:
return matches[0].strip()
return ""
def parse_code_from_planner_response(self, input_string: str) -> List[str]:
input_string = "\n".join([line.strip() for line in input_string.split(';') if line.strip()])
if input_string.strip() in ['WAIT', 'DONE', 'FAIL']:
return [input_string.strip()]
pattern = r"```(?:\w+\s+)?(.*?)```"
matches = re.findall(pattern, input_string, re.DOTALL)
codes = []
for match in matches:
match = match.strip()
commands = ['WAIT', 'DONE', 'FAIL']
if match in commands:
codes.append(match.strip())
elif match.split('\n')[-1] in commands:
if len(match.split('\n')) > 1:
codes.append("\n".join(match.split('\n')[:-1]))
codes.append(match.split('\n')[-1])
else:
codes.append(match)
return codes
def convert_action_to_grounding_model_instruction(self, line: str, obs: Dict, instruction: str, height: int, width: int, resized_height: int, resized_width: int ) -> str:
pattern = r'(#.*?)\n(pyautogui\.(moveTo|click|rightClick|doubleClick|middleClick|dragTo)\((?:x=)?(\d+)(?:,\s*|\s*,\s*y=)(\d+)(?:,\s*duration=[\d.]+)?\))'
matches = re.findall(pattern, line, re.DOTALL)
if not matches:
return line
new_instruction = line
for match in matches:
comment = match[0].split("#")[1].strip()
original_action = match[1]
func_name = match[2].strip()
if "click()" in original_action.lower():
continue
messages = []
messages.append({
"role": "system",
"content": [{"type": "text", "text": JEDI_GROUNDER_SYS_PROMPT.replace("{height}", str(resized_height)).replace("{width}", str(resized_width))}]
})
messages.append(
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{encode_image(obs['screenshot'])}",
"detail": "high",
},
},
{
"type": "text",
"text": '\n' + comment,
},
],
}
)
grounding_response = self.call_llm({
"model": self.executor_model,
"messages": messages,
"max_tokens": self.max_tokens,
"top_p": self.top_p,
"temperature": self.temperature
}, self.executor_model)
coordinates = self.parse_jedi_response(grounding_response, width, height, resized_width, resized_height)
logger.info(coordinates)
if coordinates == [-1, -1]:
continue
action_parts = original_action.split('(')
new_action = f"{action_parts[0]}({coordinates[0]}, {coordinates[1]}"
if len(action_parts) > 1 and 'duration' in action_parts[1]:
duration_part = action_parts[1].split(',')[-1]
new_action += f", {duration_part}"
elif len(action_parts) > 1 and 'button' in action_parts[1]:
button_part = action_parts[1].split(',')[-1]
new_action += f", {button_part}"
else:
new_action += ")"
logger.info(new_action)
new_instruction = new_instruction.replace(original_action, new_action)
return new_instruction
def parse_jedi_response(self, response, width: int, height: int, resized_width: int, resized_height: int) -> List[str]:
"""
Parse the LLM response and convert it to low level action and pyautogui code.
"""
low_level_instruction = ""
pyautogui_code = []
try:
# 定义可能的标签组合
start_tags = ["<tool_call>", ""]
end_tags = ["</tool_call>", ""]
# 找到有效的开始和结束标签
start_tag = next((tag for tag in start_tags if tag in response), None)
end_tag = next((tag for tag in end_tags if tag in response), None)
if not start_tag or not end_tag:
print("The response is missing valid start or end tags")
return low_level_instruction, pyautogui_code
# 分割响应以提取low_level_instruction和tool_call
parts = response.split(start_tag)
if len(parts) < 2:
print("The response is missing the start tag")
return low_level_instruction, pyautogui_code
low_level_instruction = parts[0].strip().replace("Action: ", "")
tool_call_str = parts[1].split(end_tag)[0].strip()
try:
tool_call = json.loads(tool_call_str)
action = tool_call.get("arguments", {}).get("action", "")
args = tool_call.get("arguments", {})
except json.JSONDecodeError as e:
print(f"JSON parsing error: {e}")
# 处理解析错误,返回默认值或空值
action = ""
args = {}
# convert the coordinate to the original resolution
x = int(args.get("coordinate", [-1, -1])[0] * width / resized_width)
y = int(args.get("coordinate", [-1, -1])[1] * height / resized_height)
return [x, y]
except Exception as e:
logger.error(f"Failed to parse response: {e}")
return [-1, -1]
@backoff.on_exception(
backoff.constant,
# here you should add more model exceptions as you want,
# but you are forbidden to add "Exception", that is, a common type of exception
# because we want to catch this kind of Exception in the outside to ensure
# each example won't exceed the time limit
(
# General exceptions
SSLError,
# OpenAI exceptions
openai.RateLimitError,
openai.BadRequestError,
openai.InternalServerError,
# Google exceptions
InvalidArgument,
ResourceExhausted,
InternalServerError,
BadRequest,
# Groq exceptions
# todo: check
),
interval=30,
max_tries=10,
)
def call_llm(self, payload, model):
if model.startswith("gpt"):
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {OPENAI_API_KEY}"
}
logger.info("Generating content with GPT model: %s", model)
response = requests.post(
"https://api.openai.com/v1/chat/completions",
headers=headers,
json=payload,
)
if response.status_code != 200:
logger.error("Failed to call LLM: " + response.text)
time.sleep(5)
return ""
else:
return response.json()["choices"][0]["message"]["content"]
elif model.startswith("jedi"):
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {JEDI_API_KEY}"
}
response = requests.post(
f"{JEDI_SERVICE_URL}/v1/chat/completions",
headers=headers,
json=payload,
)
if response.status_code != 200:
logger.error("Failed to call LLM: " + response.text)
time.sleep(5)
return ""
else:
return response.json()["choices"][0]["message"]["content"]
def reset(self, _logger=None):
global logger
logger = (_logger if _logger is not None else
logging.getLogger("desktopenv.jedi_7b_agent"))
self.thoughts = []
self.action_descriptions = []
self.actions = []
self.observations = []
self.observation_captions = []

View File

@@ -1268,4 +1268,73 @@ Action: ...
## User Instruction
{instruction}
"""
"""
JEDI_GROUNDER_SYS_PROMPT = """You are a helpful assistant.
# Tools
You may call one or more functions to assist with the user query.
You are provided with function signatures within <tools></tools> XML tags:
<tools>
{{"type": "function", "function": {{"name": "computer_use", "description": "Use a mouse and keyboard to interact with a computer, and take screenshots.\n* This is an interface to a desktop GUI. You do not have access to a terminal or applications menu. You must click on desktop icons to start applications.\n* Some applications may take time to start or process actions, so you may need to wait and take successive screenshots to see the results of your actions. E.g. if you click on Firefox and a window doesn't open, try wait and taking another screenshot.\n* The screen's resolution is {width}x{height}.\n* Whenever you intend to move the cursor to click on an element like an icon, you should consult a screenshot to determine the coordinates of the element before moving the cursor.\n* If you tried clicking on a program or link but it failed to load, even after waiting, try adjusting your cursor position so that the tip of the cursor visually falls on the element that you want to click.\n* Make sure to click any buttons, links, icons, etc with the cursor tip in the center of the element. Don't click boxes on their edges unless asked.", "parameters": {{"properties": {{"action": {{"description": "The action to perform. The available actions are:\n* `key`: Performs key down presses on the arguments passed in order, then performs key releases in reverse order.\n* `type`: Type a string of text on the keyboard.\n* `mouse_move`: Move the cursor to a specified (x, y) pixel coordinate on the screen.\n* `left_click`: Click the left mouse button.\n* `left_click_drag`: Click and drag the cursor to a specified (x, y) pixel coordinate on the screen.\n* `right_click`: Click the right mouse button.\n* `middle_click`: Click the middle mouse button.\n* `double_click`: Double-click the left mouse button.\n* `scroll`: Performs a scroll of the mouse scroll wheel.\n* `wait`: Wait specified seconds for the change to happen.\n* `terminate`: Terminate the current task and report its completion status.", "enum": ["key", "type", "mouse_move", "left_click", "left_click_drag", "right_click", "middle_click", "double_click", "scroll", "wait", "terminate"], "type": "string"}}, "keys": {{"description": "Required only by `action=key`.", "type": "array"}}, "text": {{"description": "Required only by `action=type`.", "type": "string"}}, "coordinate": {{"description": "(x, y): The x (pixels from the left edge) and y (pixels from the top edge) coordinates to move the mouse to. Required only by `action=mouse_move`, `action=left_click_drag`, `action=left_click`, `action=right_click`, `action=double_click`.", "type": "array"}}, "pixels": {{"description": "The amount of scrolling to perform. Positive values scroll up, negative values scroll down. Required only by `action=scroll`.", "type": "number"}}, "time": {{"description": "The seconds to wait. Required only by `action=wait`.", "type": "number"}}, "status": {{"description": "The status of the task. Required only by `action=terminate`.", "type": "string", "enum": ["success", "failure"]}}}}, "required": ["action"], "type": "object"}}}}}}
</tools>
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
<tool_call>
{{"name": <function-name>, "arguments": <args-json-object>}}
</tool_call>"""
JEDI_PLANNER_SYS_PROMPT = """
You are an agent which follow my instruction and perform desktop computer tasks as instructed.
You have good knowledge of computer and good internet connection and assume your code will run on a computer for controlling the mouse and keyboard.
You are on Ubuntu operating system and the resolution of the screen is 1920x1080.
For each step, you will get an observation of an image, which is the screenshot of the computer screen and you will predict the action of the computer based on the image.
The following rules are IMPORTANT:
- If previous actions didn't achieve the expected result, do not repeat them, especially the last one. Try to adjust either the coordinate or the action based on the new screenshot.
- Do not predict multiple clicks at once. Base each action on the current screenshot; do not predict actions for elements or events not yet visible in the screenshot.
- You cannot complete the task by outputting text content in your response. You must use mouse and keyboard to interact with the computer. Return ```Fail``` when you think the task can not be done.
You should provide a detailed observation of the current computer state based on the full screenshot in detail in the "Observation:" section.
Provide any information that is possibly relevant to achieving the task goal and any elements that may affect the task execution, such as pop-ups, notifications, error messages, loading states, etc..
You MUST return the observation before the thought.
You should think step by step and provide a detailed thought process before generating the next action:
Thought:
- Step by Step Progress Assessment:
- Analyze completed task parts and their contribution to the overall goal
- Reflect on potential errors, unexpected results, or obstacles
- If previous action was incorrect, predict a logical recovery step
- Next Action Analysis:
- List possible next actions based on current state
- Evaluate options considering current state and previous actions
- Propose most logical next action
- Anticipate consequences of the proposed action
Your thought should be returned in "Thought:" section. You MUST return the thought before the code.
You are required to use `pyautogui` to perform the action grounded to the observation, but DONOT use the `pyautogui.locateCenterOnScreen` function to locate the element you want to operate with since we have no image of the element you want to operate with. DONOT USE `pyautogui.screenshot()` to make screenshot.
Return exactly ONE line of python code to perform the action each time. At each step, you MUST generate the corresponding instruction to the code before a # in a comment (example: # Click \"Yes, I trust the authors\" button\npyautogui.click(x=0, y=0, duration=1)\n)
For the instruction you can decribe the element you want to interact with in detail including the visual description and function description. And make it clear and concise.
For example you can describe what the element looks like, and what will be the expected result when you interact with it.
You need to to specify the coordinates of by yourself based on your observation of current observation, but you should be careful to ensure that the coordinates are correct.
Remember you should only return ONE line of code, DO NOT RETURN more. You should return the code inside a code block, like this:
```python
# your code here
```
Specially, it is also allowed to return the following special code:
When you think you have to wait for some time, return ```WAIT```;
When you think the task can not be done, return ```FAIL```, don't easily say ```FAIL```, try your best to do the task;
When you think the task is done, return ```DONE```.
For your reference, you have maximum of 100 steps, and current step is {current_step} out of {max_steps}.
If you are in the last step, you should return ```DONE``` or ```FAIL``` according to the result.
Here are some guidelines for you:
1. Remember to generate the corresponding instruction to the code before a # in a comment and only return ONE line of code.
2. If a click action is needed, use only the following functions: pyautogui.click, pyautogui.rightClick or pyautogui.doubleClick.
3. Return ```Done``` when you think the task is done. Return ```Fail``` when you think the task can not be done.
My computer's password is 'password', feel free to use it when you need sudo rights.
First give the current screenshot and previous things we did a short reflection, then RETURN ME THE CODE OR SPECIAL CODE I ASKED FOR NEVER EVER RETURN ME ANYTHING ELSE.
"""