sci-gui-agent-benchmark/mm_agents/anthropic/main.py
import base64
import io
import logging
import os
import time
from typing import Any, cast, Optional, Dict

from PIL import Image
from anthropic import (
    Anthropic,
    AnthropicBedrock,
    AnthropicVertex,
    APIError,
    APIResponseValidationError,
    APIStatusError,
)
from anthropic.types.beta import (
    BetaMessageParam,
    BetaTextBlockParam,
)

from .utils import (
    COMPUTER_USE_BETA_FLAG,
    PROMPT_CACHING_BETA_FLAG,
    SYSTEM_PROMPT,
    SYSTEM_PROMPT_WINDOWS,
    APIProvider,
    PROVIDER_TO_DEFAULT_MODEL_NAME,
)
from .utils import (
    _response_to_params,
    _inject_prompt_caching,
    _maybe_filter_to_n_most_recent_images,
)

logger = logging.getLogger("desktopenv.agent")

# MAX_HISTORY = 10
API_RETRY_TIMES = 500
API_RETRY_INTERVAL = 5


class AnthropicAgent:
    def __init__(self,
                 platform: str = "Ubuntu",
                 model: str = "claude-3-5-sonnet-20241022",
                 provider: APIProvider = APIProvider.BEDROCK,
                 max_tokens: int = 4096,
                 api_key: str = os.environ.get("ANTHROPIC_API_KEY", None),
                 system_prompt_suffix: str = "",
                 only_n_most_recent_images: Optional[int] = 10,
                 action_space: str = "claude_computer_use",
                 screen_size: tuple[int, int] = (1920, 1080),
                 *args, **kwargs
                 ):
        self.platform = platform
        self.action_space = action_space
        self.logger = logger
        self.class_name = self.__class__.__name__
        self.model_name = model
        self.provider = provider
        self.max_tokens = max_tokens
        self.api_key = api_key
        self.system_prompt_suffix = system_prompt_suffix
        self.only_n_most_recent_images = only_n_most_recent_images
        self.messages: list[BetaMessageParam] = []
        self.screen_size = screen_size
        self.resize_factor = (
            screen_size[0] / 1280,  # assuming 1280 is the base width
            screen_size[1] / 720    # assuming 720 is the base height
        )
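
    # Note on coordinates: the model sees screenshots downscaled to 1280x720,
    # so with the default screen_size of (1920, 1080) resize_factor is
    # (1.5, 1.5), and every coordinate the model returns is scaled up by that
    # factor before a pyautogui command is generated.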

    def add_tool_result(self, tool_call_id: str, result: str, screenshot: bytes = None):
        """Add tool result to message history"""
        tool_result_content = [
            {
                "type": "tool_result",
                "tool_use_id": tool_call_id,
                "content": [{"type": "text", "text": result}]
            }
        ]
        # Add screenshot if provided
        if screenshot is not None:
            screenshot_base64 = base64.b64encode(screenshot).decode('utf-8')
            tool_result_content[0]["content"].append({
                "type": "image",
                "source": {
                    "type": "base64",
                    "media_type": "image/png",
                    "data": screenshot_base64
                }
            })
        self.messages.append({
            "role": "user",
            "content": tool_result_content
        })

    def parse_actions_from_tool_call(self, tool_call: Dict) -> str:
        result = ""
        function_args = tool_call["input"]
        action = function_args.get("action")
        if not action:
            # Fall back to the tool name when no explicit action is given.
            action = tool_call["name"]
        action_conversion = {
            "left click": "click",
            "right click": "right_click"
        }
        action = action_conversion.get(action, action)
        text = function_args.get("text")
        coordinate = function_args.get("coordinate")
        scroll_direction = function_args.get("scroll_direction")
        scroll_amount = function_args.get("scroll_amount")
        duration = function_args.get("duration")
        # Rescale coordinates from the 1280x720 model space to the real screen.
        if coordinate and self.resize_factor:
            coordinate = (
                int(coordinate[0] * self.resize_factor[0]),
                int(coordinate[1] * self.resize_factor[1])
            )
if action == "left_mouse_down":
result += "pyautogui.mouseDown()\n"
elif action == "left_mouse_up":
result += "pyautogui.mouseUp()\n"
elif action == "hold_key":
if not isinstance(text, str):
raise ValueError(f"{text} must be a string")
keys = text.split('+')
for key in keys:
key = key.strip().lower()
result += f"pyautogui.keyDown('{key}')\n"
expected_outcome = f"Keys {text} held down."
# Handle mouse move and drag actions
elif action in ("mouse_move", "left_click_drag"):
if coordinate is None:
raise ValueError(f"coordinate is required for {action}")
if text is not None:
raise ValueError(f"text is not accepted for {action}")
if not isinstance(coordinate, (list, tuple)) or len(coordinate) != 2:
raise ValueError(f"{coordinate} must be a tuple of length 2")
if not all(isinstance(i, int) for i in coordinate):
raise ValueError(f"{coordinate} must be a tuple of ints")
x, y = coordinate[0], coordinate[1]
if action == "mouse_move":
result += (
f"pyautogui.moveTo({x}, {y}, duration={duration or 0.5})\n"
)
expected_outcome = f"Mouse moved to ({x},{y})."
elif action == "left_click_drag":
result += (
f"pyautogui.dragTo({x}, {y}, duration={duration or 0.5})\n"
)
expected_outcome = f"Cursor dragged to ({x},{y})."
# Handle keyboard actions
elif action in ("key", "type"):
if text is None:
raise ValueError(f"text is required for {action}")
if coordinate is not None:
raise ValueError(f"coordinate is not accepted for {action}")
if not isinstance(text, str):
raise ValueError(f"{text} must be a string")
if action == "key":
key_conversion = {
"page_down": "pagedown",
"page_up": "pageup",
"super_l": "win",
"super": "command",
"escape": "esc"
}
keys = text.split('+')
for key in keys:
key = key.strip().lower()
key = key_conversion.get(key, key)
result += (f"pyautogui.keyDown('{key}')\n")
for key in reversed(keys):
key = key.strip().lower()
key = key_conversion.get(key, key)
result += (f"pyautogui.keyUp('{key}')\n")
expected_outcome = f"Key {key} pressed."
elif action == "type":
result += (
f"pyautogui.typewrite(\"\"\"{text}\"\"\", interval=0.01)\n"
)
expected_outcome = f"Text {text} written."
# Handle scroll actions
elif action == "scroll":
if coordinate is None:
if scroll_direction in ("up", "down"):
result += (
f"pyautogui.scroll({scroll_amount if scroll_direction == 'up' else -scroll_amount})\n"
)
elif scroll_direction in ("left", "right"):
result += (
f"pyautogui.hscroll({scroll_amount if scroll_direction == 'right' else -scroll_amount})\n"
)
else:
if scroll_direction in ("up", "down"):
x, y = coordinate[0], coordinate[1]
result += (
f"pyautogui.scroll({scroll_amount if scroll_direction == 'up' else -scroll_amount}, {x}, {y})\n"
)
elif scroll_direction in ("left", "right"):
x, y = coordinate[0], coordinate[1]
result += (
f"pyautogui.hscroll({scroll_amount if scroll_direction == 'right' else -scroll_amount}, {x}, {y})\n"
)
expected_outcome = "Scroll action finished"
# Handle click actions
elif action in ("left_click", "right_click", "double_click", "middle_click", "left_press", "triple_click"):
if coordinate is not None:
x, y = coordinate
if action == "left_click":
result += (f"pyautogui.click({x}, {y})\n")
elif action == "right_click":
result += (f"pyautogui.rightClick({x}, {y})\n")
elif action == "double_click":
result += (f"pyautogui.doubleClick({x}, {y})\n")
elif action == "middle_click":
result += (f"pyautogui.middleClick({x}, {y})\n")
elif action == "left_press":
result += (f"pyautogui.mouseDown({x}, {y})\n")
result += ("time.sleep(1)\n")
result += (f"pyautogui.mouseUp({x}, {y})\n")
elif action == "triple_click":
result += (f"pyautogui.tripleClick({x}, {y})\n")
else:
if action == "left_click":
result += ("pyautogui.click()\n")
elif action == "right_click":
result += ("pyautogui.rightClick()\n")
elif action == "double_click":
result += ("pyautogui.doubleClick()\n")
elif action == "middle_click":
result += ("pyautogui.middleClick()\n")
elif action == "left_press":
result += ("pyautogui.mouseDown()\n")
result += ("time.sleep(1)\n")
result += ("pyautogui.mouseUp()\n")
elif action == "triple_click":
result += ("pyautogui.tripleClick()\n")
expected_outcome = "Click action finished"
elif action == "wait":
result += "pyautogui.sleep(0.5)\n"
expected_outcome = "Wait for 0.5 seconds"
elif action == "fail":
result += "FAIL"
expected_outcome = "Finished"
elif action == "done":
result += "DONE"
expected_outcome = "Finished"
elif action == "call_user":
result += "CALL_USER"
expected_outcome = "Call user"
elif action == "screenshot":
result += "pyautogui.sleep(0.1)\n"
expected_outcome = "Screenshot taken"
else:
raise ValueError(f"Invalid action: {action}")
return result
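
    # Illustrative example (assuming the default 1920x1080 screen, i.e. a
    # resize factor of 1.5): a tool_use block such as
    #   {"name": "computer", "id": "toolu_01...", "input": {"action": "left_click", "coordinate": [640, 360]}}
    # is rescaled from the 1280x720 model space and rendered as the command
    #   pyautogui.click(960, 540)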

    def _trim_history(self, max_rounds=4):
        messages = self.messages
        if not messages or len(messages) <= 1:
            return
        # Each round is a user/assistant pair, so keep 2 * max_rounds messages.
        actual_max_rounds = max_rounds * 2
        # Nothing to do if the history is already within the limit.
        if len(messages) <= actual_max_rounds:
            return
        # Keep the initial message and the most recent actual_max_rounds messages:
        # messages[0:1] + messages[-actual_max_rounds:]
        keep_messages = []
        # For the older messages in between, keep only non-image content.
        for i in range(1, len(messages) - actual_max_rounds):
            old_message = messages[i]
            if old_message["role"] == "user" and "content" in old_message:
                # Filter out image content blocks and keep everything else.
                filtered_content = []
                for content_block in old_message["content"]:
                    filtered_content_item = []
                    if content_block.get("type") == "tool_result":
                        for content_block_item in content_block["content"]:
                            if content_block_item.get("type") != "image":
                                filtered_content_item.append(content_block_item)
                        filtered_content.append({
                            "type": content_block.get("type"),
                            "tool_use_id": content_block.get("tool_use_id"),
                            "content": filtered_content_item
                        })
                    else:
                        filtered_content.append(content_block)
                # Keep the message if anything remains after filtering.
                if filtered_content:
                    keep_messages.append({
                        "role": old_message["role"],
                        "content": filtered_content
                    })
            else:
                # Keep non-user messages (and messages without content) unchanged.
                keep_messages.append(old_message)
        self.messages = messages[0:1] + keep_messages + messages[-actual_max_rounds:]
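
    # Illustrative example: with max_rounds=4 and a history of 12 messages, the
    # first message and the last 8 messages are kept as-is, while screenshots
    # inside the tool_result blocks of the 3 messages in between are dropped.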

    def predict(self, task_instruction: str, obs: Dict = None, system: Any = None):
        system = BetaTextBlockParam(
            type="text",
            text=f"{SYSTEM_PROMPT_WINDOWS if self.platform == 'Windows' else SYSTEM_PROMPT}{' ' + self.system_prompt_suffix if self.system_prompt_suffix else ''}"
        )
        # Downscale the screenshot to the fixed 1280x720 resolution the model works in.
        if obs and "screenshot" in obs:
            # Convert bytes to PIL Image
            screenshot_bytes = obs["screenshot"]
            screenshot_image = Image.open(io.BytesIO(screenshot_bytes))
            new_width, new_height = 1280, 720
            # Resize the image
            resized_image = screenshot_image.resize((new_width, new_height), Image.Resampling.LANCZOS)
            # Convert back to bytes
            output_buffer = io.BytesIO()
            resized_image.save(output_buffer, format='PNG')
            obs["screenshot"] = output_buffer.getvalue()
        if not self.messages:
            init_screenshot = obs
            init_screenshot_base64 = base64.b64encode(init_screenshot["screenshot"]).decode('utf-8')
            self.messages.append({
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": "image/png",
                            "data": init_screenshot_base64,
                        },
                    },
                    {"type": "text", "text": task_instruction},
                ]
            })
        if self.messages and "tool_use" in [content_block["type"] for content_block in self.messages[-1]["content"]]:
            # Feed the previous tool call's result (and the new screenshot) back to the model.
            self.add_tool_result(
                self.messages[-1]["content"][-1]["id"],
                "Success",
                screenshot=obs.get("screenshot") if obs else None
            )
        enable_prompt_caching = False
        betas = ["computer-use-2025-01-24"]
        if self.model_name == "claude-3-7-sonnet-20250219" or self.model_name == "claude-4-opus-20250514" or self.model_name == "claude-4-sonnet-20250514":
            betas = ["computer-use-2025-01-24"]
        elif self.model_name == "claude-3-5-sonnet-20241022":
            betas = [COMPUTER_USE_BETA_FLAG]
        image_truncation_threshold = 10
        if self.provider == APIProvider.ANTHROPIC:
            client = Anthropic(api_key=self.api_key, max_retries=4)
            enable_prompt_caching = True
        elif self.provider == APIProvider.VERTEX:
            client = AnthropicVertex()
        elif self.provider == APIProvider.BEDROCK:
            client = AnthropicBedrock(
                # Authenticate by either providing the keys below or use the default AWS credential providers, such as
                # using ~/.aws/credentials or the "AWS_SECRET_ACCESS_KEY" and "AWS_ACCESS_KEY_ID" environment variables.
                aws_access_key=os.getenv('AWS_ACCESS_KEY_ID'),
                aws_secret_key=os.getenv('AWS_SECRET_ACCESS_KEY'),
                # aws_region changes the aws region to which the request is made. By default, we read AWS_REGION,
                # and if that's not present, we default to us-east-1. Note that we do not read ~/.aws/config for the region.
                aws_region=os.getenv('AWS_DEFAULT_REGION'),
            )
        if enable_prompt_caching:
            betas.append(PROMPT_CACHING_BETA_FLAG)
            _inject_prompt_caching(self.messages)
            image_truncation_threshold = 50
            system["cache_control"] = {"type": "ephemeral"}
        if self.only_n_most_recent_images:
            _maybe_filter_to_n_most_recent_images(
                self.messages,
                self.only_n_most_recent_images,
                min_removal_threshold=image_truncation_threshold,
            )
        # self._trim_history(max_rounds=MAX_HISTORY)
        try:
            if self.model_name == "claude-3-5-sonnet-20241022":
                tools = [
                    {'name': 'computer', 'type': 'computer_20241022', 'display_width_px': 1280, 'display_height_px': 720, 'display_number': 1},
                    # {'type': 'bash_20241022', 'name': 'bash'},
                    # {'name': 'str_replace_editor', 'type': 'text_editor_20241022'}
                ] if self.platform == 'Ubuntu' else [
                    {'name': 'computer', 'type': 'computer_20241022', 'display_width_px': 1280, 'display_height_px': 720, 'display_number': 1},
                ]
            elif self.model_name in ["claude-3-7-sonnet-20250219", "claude-4-opus-20250514", "claude-4-sonnet-20250514"]:
                tools = [
                    {'name': 'computer', 'type': 'computer_20250124', 'display_width_px': 1280, 'display_height_px': 720, 'display_number': 1},
                    # {'type': 'bash_20250124', 'name': 'bash'},
                    # {'name': 'str_replace_editor', 'type': 'text_editor_20250124'}
                ] if self.platform == 'Ubuntu' else [
                    {'name': 'computer', 'type': 'computer_20250124', 'display_width_px': 1280, 'display_height_px': 720, 'display_number': 1},
                ]
            extra_body = {
                "thinking": {"type": "enabled", "budget_tokens": 1024}
            }
            response = None
            for attempt in range(API_RETRY_TIMES):
                try:
                    if self.model_name in ["claude-3-7-sonnet-20250219", "claude-4-opus-20250514", "claude-4-sonnet-20250514"]:
                        response = client.beta.messages.create(
                            max_tokens=self.max_tokens,
                            messages=self.messages,
                            model=PROVIDER_TO_DEFAULT_MODEL_NAME[self.provider, self.model_name],
                            system=[system],
                            tools=tools,
                            betas=betas,
                            extra_body=extra_body
                        )
                    elif self.model_name == "claude-3-5-sonnet-20241022":
                        response = client.beta.messages.create(
                            max_tokens=self.max_tokens,
                            messages=self.messages,
                            model=PROVIDER_TO_DEFAULT_MODEL_NAME[self.provider, self.model_name],
                            system=[system],
                            tools=tools,
                            betas=betas,
                        )
                    logger.info(f"Response: {response}")
                    break  # success: exit the retry loop
                except (APIError, APIStatusError, APIResponseValidationError) as e:
                    error_msg = str(e)
                    logger.warning(f"Anthropic API error (attempt {attempt+1}/{API_RETRY_TIMES}): {error_msg}")
                    # Check whether this is the 25 MB request-size limit error.
                    if "25000000" in error_msg or "Member must have length less than or equal to" in error_msg:
                        logger.warning("Hit the 25 MB request limit; automatically reducing the number of retained images")
                        # Halve the number of retained images, keeping at least one.
                        current_image_count = self.only_n_most_recent_images
                        new_image_count = max(1, current_image_count // 2)
                        self.only_n_most_recent_images = new_image_count
                        # Re-apply the image filter to the message history.
                        _maybe_filter_to_n_most_recent_images(
                            self.messages,
                            new_image_count,
                            min_removal_threshold=image_truncation_threshold,
                        )
                        logger.info(f"Reduced retained images from {current_image_count} to {new_image_count}")
                    if attempt < API_RETRY_TIMES - 1:
                        time.sleep(API_RETRY_INTERVAL)
                    else:
                        raise  # all retries failed: re-raise into the outer except handlers
        except (APIError, APIStatusError, APIResponseValidationError) as e:
            logger.exception(f"Anthropic API error: {str(e)}")
            try:
                logger.warning("Retrying with backup API key...")
                backup_client = Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY_BACKUP"), max_retries=4)
                if self.model_name in ["claude-3-7-sonnet-20250219", "claude-4-opus-20250514", "claude-4-sonnet-20250514"]:
                    response = backup_client.beta.messages.create(
                        max_tokens=self.max_tokens,
                        messages=self.messages,
                        model=PROVIDER_TO_DEFAULT_MODEL_NAME[APIProvider.ANTHROPIC, self.model_name],
                        system=[system],
                        tools=tools,
                        betas=betas,
                        extra_body=extra_body
                    )
                elif self.model_name == "claude-3-5-sonnet-20241022":
                    response = backup_client.beta.messages.create(
                        max_tokens=self.max_tokens,
                        messages=self.messages,
                        model=PROVIDER_TO_DEFAULT_MODEL_NAME[APIProvider.ANTHROPIC, self.model_name],
                        system=[system],
                        tools=tools,
                        betas=betas,
                    )
                logger.info("Successfully used backup API key")
            except Exception as backup_e:
                backup_error_msg = str(backup_e)
                logger.exception(f"Backup API call also failed: {backup_error_msg}")
                # Check whether the backup call also hit the 25 MB request-size limit.
                if "25000000" in backup_error_msg or "Member must have length less than or equal to" in backup_error_msg:
                    logger.warning("Backup API also hit the 25 MB limit; reducing the number of retained images further")
                    # Halve the number of retained images again, keeping at least one.
                    current_image_count = self.only_n_most_recent_images
                    new_image_count = max(1, current_image_count // 2)
                    self.only_n_most_recent_images = new_image_count
                    # Re-apply the image filter to the message history.
                    _maybe_filter_to_n_most_recent_images(
                        self.messages,
                        new_image_count,
                        min_removal_threshold=image_truncation_threshold,
                    )
                    logger.info(f"Backup API: reduced retained images from {current_image_count} to {new_image_count}")
                return None, None
        except Exception as e:
            logger.exception(f"Error in Anthropic API: {str(e)}")
            return None, None
        response_params = _response_to_params(response)
        logger.info(f"Received response params: {response_params}")
        # Store response in message history
        self.messages.append({
            "role": "assistant",
            "content": response_params
        })
        max_parse_retry = 3
        for parse_retry in range(max_parse_retry):
            actions: list[Any] = []
            reasonings: list[str] = []
            try:
                for content_block in response_params:
                    if content_block["type"] == "tool_use":
                        actions.append({
                            "name": content_block["name"],
                            "input": cast(dict[str, Any], content_block["input"]),
                            "id": content_block["id"],
                            "action_type": content_block.get("type"),
                            "command": self.parse_actions_from_tool_call(content_block)
                        })
                    elif content_block["type"] == "text":
                        reasonings.append(content_block["text"])
                if isinstance(reasonings, list) and len(reasonings) > 0:
                    reasonings = reasonings[0]
                else:
                    reasonings = ""
                logger.info(f"Received actions: {actions}")
                logger.info(f"Received reasonings: {reasonings}")
                if len(actions) == 0:
                    actions = ["DONE"]
                return reasonings, actions
            except Exception as e:
                logger.warning(f"parse_actions_from_tool_call failed ({parse_retry+1}/{max_parse_retry}); re-requesting the API: {e}")
                # Remove the assistant message we just appended so it does not pollute the history.
                self.messages.pop()
                # Re-request the API.
                response = None
                for attempt in range(API_RETRY_TIMES):
                    try:
                        if self.model_name in ["claude-3-7-sonnet-20250219", "claude-4-opus-20250514", "claude-4-sonnet-20250514"]:
                            response = client.beta.messages.create(
                                max_tokens=self.max_tokens,
                                messages=self.messages,
                                model=PROVIDER_TO_DEFAULT_MODEL_NAME[self.provider, self.model_name],
                                system=[system],
                                tools=tools,
                                betas=betas,
                                extra_body=extra_body
                            )
                        elif self.model_name == "claude-3-5-sonnet-20241022":
                            response = client.beta.messages.create(
                                max_tokens=self.max_tokens,
                                messages=self.messages,
                                model=PROVIDER_TO_DEFAULT_MODEL_NAME[self.provider, self.model_name],
                                system=[system],
                                tools=tools,
                                betas=betas,
                            )
                        logger.info(f"Response: {response}")
                        break  # success: exit the retry loop
                    except (APIError, APIStatusError, APIResponseValidationError) as e2:
                        error_msg = str(e2)
                        logger.warning(f"Anthropic API error (attempt {attempt+1}/{API_RETRY_TIMES}): {error_msg}")
                        if attempt < API_RETRY_TIMES - 1:
                            time.sleep(API_RETRY_INTERVAL)
                        else:
                            raise
                response_params = _response_to_params(response)
                logger.info(f"Received response params: {response_params}")
                self.messages.append({
                    "role": "assistant",
                    "content": response_params
                })
                if parse_retry == max_parse_retry - 1:
                    logger.error(f"parse_actions_from_tool_call failed {max_parse_retry} times in a row; giving up: {e}")
                    actions = ["FAIL"]
                    return reasonings, actions
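
    # predict() returns a (reasonings, actions) pair: reasonings is the first
    # text block of the response (or ""), and actions is a list of dicts with
    # "name", "input", "id", "action_type" and the generated pyautogui
    # "command". On API failure it returns (None, None); the sentinel values
    # ["DONE"] / ["FAIL"] are used when no tool call is produced or parsing
    # keeps failing.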

    def reset(self, _logger=None, *args, **kwargs):
        """
        Reset the agent's state.
        """
        global logger
        if _logger:
            logger = _logger
        else:
            logger = logging.getLogger("desktopenv.agent")
        self.messages = []
        logger.info(f"{self.class_name} reset.")