import base64
import os
import time
from typing import Any, cast, Optional, Dict
from PIL import Image
import io

from anthropic import (
    Anthropic,
    AnthropicBedrock,
    AnthropicVertex,
    APIError,
    APIResponseValidationError,
    APIStatusError,
)
from anthropic.types.beta import (
    BetaMessageParam,
    BetaTextBlockParam,
)
from .utils import COMPUTER_USE_BETA_FLAG, PROMPT_CACHING_BETA_FLAG, SYSTEM_PROMPT, SYSTEM_PROMPT_WINDOWS, APIProvider, PROVIDER_TO_DEFAULT_MODEL_NAME
from .utils import _response_to_params, _inject_prompt_caching, _maybe_filter_to_n_most_recent_images

import logging

logger = logging.getLogger("desktopenv.agent")

# MAX_HISTORY = 10
API_RETRY_TIMES = 500
API_RETRY_INTERVAL = 5


class AnthropicAgent:
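    """Computer-use agent backed by Anthropic Claude models.

    Keeps the conversation history in ``self.messages``, sends each screenshot
    observation through the computer-use beta API, and converts the returned
    ``tool_use`` blocks into pyautogui command strings via
    ``parse_actions_from_tool_call``.
    """
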
    def __init__(self,
                 platform: str = "Ubuntu",
                 model: str = "claude-3-5-sonnet-20241022",
                 provider: APIProvider = APIProvider.BEDROCK,
                 max_tokens: int = 4096,
                 api_key: Optional[str] = os.environ.get("ANTHROPIC_API_KEY", None),
                 system_prompt_suffix: str = "",
                 only_n_most_recent_images: Optional[int] = 10,
                 action_space: str = "claude_computer_use",
                 screen_size: tuple[int, int] = (1920, 1080),
                 *args, **kwargs
                 ):
        self.platform = platform
        self.action_space = action_space
        self.logger = logger
        self.class_name = self.__class__.__name__
        self.model_name = model
        self.provider = provider
        self.max_tokens = max_tokens
        self.api_key = api_key
        self.system_prompt_suffix = system_prompt_suffix
        self.only_n_most_recent_images = only_n_most_recent_images
        self.messages: list[BetaMessageParam] = []
        self.screen_size = screen_size
        self.resize_factor = (
            screen_size[0] / 1280,  # Assuming 1280 is the base width
            screen_size[1] / 720    # Assuming 720 is the base height
        )

    def add_tool_result(self, tool_call_id: str, result: str, screenshot: Optional[bytes] = None):
        """Add tool result to message history"""
        tool_result_content = [
            {
                "type": "tool_result",
                "tool_use_id": tool_call_id,
                "content": [{"type": "text", "text": result}]
            }
        ]

        # Add screenshot if provided
        if screenshot is not None:
            screenshot_base64 = base64.b64encode(screenshot).decode('utf-8')
            tool_result_content[0]["content"].append({
                "type": "image",
                "source": {
                    "type": "base64",
                    "media_type": "image/png",
                    "data": screenshot_base64
                }
            })

        self.messages.append({
            "role": "user",
            "content": tool_result_content
        })

    def parse_actions_from_tool_call(self, tool_call: Dict) -> str:
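        """Convert a computer-use ``tool_use`` block into a pyautogui command string.

        Coordinates are rescaled from the 1280x720 model space to the real screen
        resolution using ``self.resize_factor``; control actions map to the sentinel
        strings ``DONE``, ``FAIL`` and ``CALL_USER``.
        """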
        result = ""
        function_args = tool_call["input"]

        action = function_args.get("action")
        if not action:
            # Fall back to the tool-use block's name (tool_call is a dict, not an object).
            action = tool_call.get("name")
        action_conversion = {
            "left click": "click",
            "right click": "right_click"
        }
        action = action_conversion.get(action, action)

        text = function_args.get("text")
        coordinate = function_args.get("coordinate")
        scroll_direction = function_args.get("scroll_direction")
        scroll_amount = function_args.get("scroll_amount")
        duration = function_args.get("duration")

        # Resize coordinates if resize_factor is set.
        if coordinate and self.resize_factor:
            coordinate = (
                int(coordinate[0] * self.resize_factor[0]),
                int(coordinate[1] * self.resize_factor[1])
            )

        if action == "left_mouse_down":
            result += "pyautogui.mouseDown()\n"
        elif action == "left_mouse_up":
            result += "pyautogui.mouseUp()\n"

        elif action == "hold_key":
            if not isinstance(text, str):
                raise ValueError(f"{text} must be a string")

            keys = text.split('+')
            for key in keys:
                key = key.strip().lower()
                result += f"pyautogui.keyDown('{key}')\n"
            expected_outcome = f"Keys {text} held down."

        # Handle mouse move and drag actions
        elif action in ("mouse_move", "left_click_drag"):
            if coordinate is None:
                raise ValueError(f"coordinate is required for {action}")
            if text is not None:
                raise ValueError(f"text is not accepted for {action}")
            if not isinstance(coordinate, (list, tuple)) or len(coordinate) != 2:
                raise ValueError(f"{coordinate} must be a tuple of length 2")
            if not all(isinstance(i, int) for i in coordinate):
                raise ValueError(f"{coordinate} must be a tuple of ints")

            x, y = coordinate[0], coordinate[1]
            if action == "mouse_move":
                result += f"pyautogui.moveTo({x}, {y}, duration={duration or 0.5})\n"
                expected_outcome = f"Mouse moved to ({x},{y})."
            elif action == "left_click_drag":
                result += f"pyautogui.dragTo({x}, {y}, duration={duration or 0.5})\n"
                expected_outcome = f"Cursor dragged to ({x},{y})."

        # Handle keyboard actions
        elif action in ("key", "type"):
            if text is None:
                raise ValueError(f"text is required for {action}")
            if coordinate is not None:
                raise ValueError(f"coordinate is not accepted for {action}")
            if not isinstance(text, str):
                raise ValueError(f"{text} must be a string")

            if action == "key":
                key_conversion = {
                    "page_down": "pagedown",
                    "page_up": "pageup",
                    "super_l": "win",
                    "super": "command",
                    "escape": "esc"
                }
                keys = text.split('+')
                # Press the combination down in order, then release in reverse order.
                for key in keys:
                    key = key.strip().lower()
                    key = key_conversion.get(key, key)
                    result += f"pyautogui.keyDown('{key}')\n"
                for key in reversed(keys):
                    key = key.strip().lower()
                    key = key_conversion.get(key, key)
                    result += f"pyautogui.keyUp('{key}')\n"
                expected_outcome = f"Key {text} pressed."
            elif action == "type":
                result += f"pyautogui.typewrite(\"\"\"{text}\"\"\", interval=0.01)\n"
                expected_outcome = f"Text {text} written."

        # Handle scroll actions
        elif action == "scroll":
            if coordinate is None:
                if scroll_direction in ("up", "down"):
                    result += f"pyautogui.scroll({scroll_amount if scroll_direction == 'up' else -scroll_amount})\n"
                elif scroll_direction in ("left", "right"):
                    result += f"pyautogui.hscroll({scroll_amount if scroll_direction == 'right' else -scroll_amount})\n"
            else:
                x, y = coordinate[0], coordinate[1]
                if scroll_direction in ("up", "down"):
                    result += f"pyautogui.scroll({scroll_amount if scroll_direction == 'up' else -scroll_amount}, {x}, {y})\n"
                elif scroll_direction in ("left", "right"):
                    result += f"pyautogui.hscroll({scroll_amount if scroll_direction == 'right' else -scroll_amount}, {x}, {y})\n"
            expected_outcome = "Scroll action finished"

        # Handle click actions
        elif action in ("left_click", "right_click", "double_click", "middle_click", "left_press", "triple_click"):
            if coordinate is not None:
                x, y = coordinate
                if action == "left_click":
                    result += f"pyautogui.click({x}, {y})\n"
                elif action == "right_click":
                    result += f"pyautogui.rightClick({x}, {y})\n"
                elif action == "double_click":
                    result += f"pyautogui.doubleClick({x}, {y})\n"
                elif action == "middle_click":
                    result += f"pyautogui.middleClick({x}, {y})\n"
                elif action == "left_press":
                    result += f"pyautogui.mouseDown({x}, {y})\n"
                    result += "time.sleep(1)\n"
                    result += f"pyautogui.mouseUp({x}, {y})\n"
                elif action == "triple_click":
                    result += f"pyautogui.tripleClick({x}, {y})\n"
            else:
                if action == "left_click":
                    result += "pyautogui.click()\n"
                elif action == "right_click":
                    result += "pyautogui.rightClick()\n"
                elif action == "double_click":
                    result += "pyautogui.doubleClick()\n"
                elif action == "middle_click":
                    result += "pyautogui.middleClick()\n"
                elif action == "left_press":
                    result += "pyautogui.mouseDown()\n"
                    result += "time.sleep(1)\n"
                    result += "pyautogui.mouseUp()\n"
                elif action == "triple_click":
                    result += "pyautogui.tripleClick()\n"
            expected_outcome = "Click action finished"

        elif action == "wait":
            result += "pyautogui.sleep(0.5)\n"
            expected_outcome = "Wait for 0.5 seconds"
        elif action == "fail":
            result += "FAIL"
            expected_outcome = "Finished"
        elif action == "done":
            result += "DONE"
            expected_outcome = "Finished"
        elif action == "call_user":
            result += "CALL_USER"
            expected_outcome = "Call user"
        elif action == "screenshot":
            result += "pyautogui.sleep(0.1)\n"
            expected_outcome = "Screenshot taken"
        else:
            raise ValueError(f"Invalid action: {action}")

        return result

    def _trim_history(self, max_rounds=4):
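        """Trim ``self.messages`` to roughly ``max_rounds`` recent rounds.

        The first message (the initial instruction) and the last ``2 * max_rounds``
        messages are kept verbatim; screenshots are stripped from the tool results
        of the older messages in between. Currently unused (see the commented-out
        call in ``predict``).
        """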
        messages = self.messages
        if not messages or len(messages) <= 1:
            return

        # Number of most-recent messages to keep (each round is a user/assistant pair).
        actual_max_rounds = max_rounds * 2

        # Nothing to do if the history is already within the limit.
        if len(messages) <= actual_max_rounds:
            return

        # Keep the first message (the initial message) and the most recent
        # actual_max_rounds messages: messages[0:1] + messages[-actual_max_rounds:]
        keep_messages = []

        # For the messages in between, keep only non-image content.
        for i in range(1, len(messages) - actual_max_rounds):
            old_message = messages[i]
            if old_message["role"] == "user" and "content" in old_message:
                # Drop image content blocks and keep everything else.
                filtered_content = []
                for content_block in old_message["content"]:
                    filtered_content_item = []
                    if content_block.get("type") == "tool_result":
                        for content_block_item in content_block["content"]:
                            if content_block_item.get("type") != "image":
                                filtered_content_item.append(content_block_item)
                        filtered_content.append({
                            "type": content_block.get("type"),
                            "tool_use_id": content_block.get("tool_use_id"),
                            "content": filtered_content_item
                        })
                    else:
                        filtered_content.append(content_block)

                # Keep the message only if anything survived the filtering.
                if filtered_content:
                    keep_messages.append({
                        "role": old_message["role"],
                        "content": filtered_content
                    })
            else:
                # Non-user messages (or messages without content) are kept as-is.
                keep_messages.append(old_message)

        self.messages = messages[0:1] + keep_messages + messages[-actual_max_rounds:]

    def predict(self, task_instruction: str, obs: Dict = None, system: Any = None):
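        """Run one agent step and return ``(reasonings, actions)``.

        ``obs`` is expected to contain a ``"screenshot"`` entry with PNG bytes; the
        screenshot is downscaled to 1280x720 before being sent to the model. Returns
        ``(None, None)`` if the API call fails even after the retry and backup-key
        fallbacks.
        """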
        system = BetaTextBlockParam(
            type="text",
            text=f"{SYSTEM_PROMPT_WINDOWS if self.platform == 'Windows' else SYSTEM_PROMPT}{' ' + self.system_prompt_suffix if self.system_prompt_suffix else ''}"
        )

        # Downscale the screenshot to the 1280x720 resolution expected by the model.
        if obs and "screenshot" in obs:
            # Convert bytes to PIL Image
            screenshot_bytes = obs["screenshot"]
            screenshot_image = Image.open(io.BytesIO(screenshot_bytes))

            # Target size matching the computer-use tool definition
            new_width, new_height = 1280, 720

            # Resize the image
            resized_image = screenshot_image.resize((new_width, new_height), Image.Resampling.LANCZOS)

            # Convert back to bytes
            output_buffer = io.BytesIO()
            resized_image.save(output_buffer, format='PNG')
            obs["screenshot"] = output_buffer.getvalue()

        if not self.messages:
            # First turn: seed the conversation with the initial screenshot and the task instruction.
            init_screenshot = obs
            init_screenshot_base64 = base64.b64encode(init_screenshot["screenshot"]).decode('utf-8')
            self.messages.append({
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": "image/png",
                            "data": init_screenshot_base64,
                        },
                    },
                    {"type": "text", "text": task_instruction},
                ]
            })

        # If the previous assistant turn ended with a tool_use block, answer it with a
        # tool_result (success message plus the new screenshot) before the next request.
        if self.messages and "tool_use" in [content_block["type"] for content_block in self.messages[-1]["content"]]:
            self.add_tool_result(
                self.messages[-1]["content"][-1]["id"],
                "Success",
                screenshot=obs.get("screenshot") if obs else None
            )

        enable_prompt_caching = False
        betas = ["computer-use-2025-01-24"]
        if self.model_name in ("claude-3-7-sonnet-20250219", "claude-4-opus-20250514", "claude-4-sonnet-20250514"):
            betas = ["computer-use-2025-01-24"]
        elif self.model_name == "claude-3-5-sonnet-20241022":
            betas = [COMPUTER_USE_BETA_FLAG]

        image_truncation_threshold = 10
        if self.provider == APIProvider.ANTHROPIC:
            client = Anthropic(api_key=self.api_key, max_retries=4)
            enable_prompt_caching = True
        elif self.provider == APIProvider.VERTEX:
            client = AnthropicVertex()
        elif self.provider == APIProvider.BEDROCK:
            client = AnthropicBedrock(
                # Authenticate by either providing the keys below or use the default AWS credential providers, such as
                # using ~/.aws/credentials or the "AWS_SECRET_ACCESS_KEY" and "AWS_ACCESS_KEY_ID" environment variables.
                aws_access_key=os.getenv('AWS_ACCESS_KEY_ID'),
                aws_secret_key=os.getenv('AWS_SECRET_ACCESS_KEY'),
                # aws_region changes the aws region to which the request is made. By default, we read AWS_REGION,
                # and if that's not present, we default to us-east-1. Note that we do not read ~/.aws/config for the region.
                aws_region=os.getenv('AWS_DEFAULT_REGION'),
            )

        if enable_prompt_caching:
            betas.append(PROMPT_CACHING_BETA_FLAG)
            _inject_prompt_caching(self.messages)
            image_truncation_threshold = 50
            system["cache_control"] = {"type": "ephemeral"}

        if self.only_n_most_recent_images:
            _maybe_filter_to_n_most_recent_images(
                self.messages,
                self.only_n_most_recent_images,
                min_removal_threshold=image_truncation_threshold,
            )

        # self._trim_history(max_rounds=MAX_HISTORY)

        try:
            if self.model_name == "claude-3-5-sonnet-20241022":
                tools = [
                    {'name': 'computer', 'type': 'computer_20241022', 'display_width_px': 1280, 'display_height_px': 720, 'display_number': 1},
                    # {'type': 'bash_20241022', 'name': 'bash'},
                    # {'name': 'str_replace_editor', 'type': 'text_editor_20241022'}
                ] if self.platform == 'Ubuntu' else [
                    {'name': 'computer', 'type': 'computer_20241022', 'display_width_px': 1280, 'display_height_px': 720, 'display_number': 1},
                ]
            elif self.model_name in ["claude-3-7-sonnet-20250219", "claude-4-opus-20250514", "claude-4-sonnet-20250514"]:
                tools = [
                    {'name': 'computer', 'type': 'computer_20250124', 'display_width_px': 1280, 'display_height_px': 720, 'display_number': 1},
                    # {'type': 'bash_20250124', 'name': 'bash'},
                    # {'name': 'str_replace_editor', 'type': 'text_editor_20250124'}
                ] if self.platform == 'Ubuntu' else [
                    {'name': 'computer', 'type': 'computer_20250124', 'display_width_px': 1280, 'display_height_px': 720, 'display_number': 1},
                ]
            extra_body = {
                "thinking": {"type": "enabled", "budget_tokens": 1024}
            }
            response = None

            for attempt in range(API_RETRY_TIMES):
                try:
                    if self.model_name in ["claude-3-7-sonnet-20250219", "claude-4-opus-20250514", "claude-4-sonnet-20250514"]:
                        response = client.beta.messages.create(
                            max_tokens=self.max_tokens,
                            messages=self.messages,
                            model=PROVIDER_TO_DEFAULT_MODEL_NAME[self.provider, self.model_name],
                            system=[system],
                            tools=tools,
                            betas=betas,
                            extra_body=extra_body
                        )
                    elif self.model_name == "claude-3-5-sonnet-20241022":
                        response = client.beta.messages.create(
                            max_tokens=self.max_tokens,
                            messages=self.messages,
                            model=PROVIDER_TO_DEFAULT_MODEL_NAME[self.provider, self.model_name],
                            system=[system],
                            tools=tools,
                            betas=betas,
                        )
                    logger.info(f"Response: {response}")
                    break  # Success: leave the retry loop.
                except (APIError, APIStatusError, APIResponseValidationError) as e:
                    error_msg = str(e)
                    logger.warning(f"Anthropic API error (attempt {attempt+1}/{API_RETRY_TIMES}): {error_msg}")

                    # Check for the 25 MB request-size limit error.
                    if "25000000" in error_msg or "Member must have length less than or equal to" in error_msg:
                        logger.warning("25 MB request limit hit; automatically reducing the number of kept images")
                        # Halve the number of retained images (keep at least one).
                        current_image_count = self.only_n_most_recent_images
                        new_image_count = max(1, current_image_count // 2)
                        self.only_n_most_recent_images = new_image_count

                        # Re-apply the image filtering with the reduced count.
                        _maybe_filter_to_n_most_recent_images(
                            self.messages,
                            new_image_count,
                            min_removal_threshold=image_truncation_threshold,
                        )
                        logger.info(f"Number of kept images reduced from {current_image_count} to {new_image_count}")

                    if attempt < API_RETRY_TIMES - 1:
                        time.sleep(API_RETRY_INTERVAL)
                    else:
                        raise  # All retries failed; fall through to the outer except handlers.

        except (APIError, APIStatusError, APIResponseValidationError) as e:
            logger.exception(f"Anthropic API error: {str(e)}")
            try:
                logger.warning("Retrying with backup API key...")

                backup_client = Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY_BACKUP"), max_retries=4)
                if self.model_name in ["claude-3-7-sonnet-20250219", "claude-4-opus-20250514", "claude-4-sonnet-20250514"]:
                    response = backup_client.beta.messages.create(
                        max_tokens=self.max_tokens,
                        messages=self.messages,
                        model=PROVIDER_TO_DEFAULT_MODEL_NAME[APIProvider.ANTHROPIC, self.model_name],
                        system=[system],
                        tools=tools,
                        betas=betas,
                        extra_body=extra_body
                    )
                elif self.model_name == "claude-3-5-sonnet-20241022":
                    response = backup_client.beta.messages.create(
                        max_tokens=self.max_tokens,
                        messages=self.messages,
                        model=PROVIDER_TO_DEFAULT_MODEL_NAME[APIProvider.ANTHROPIC, self.model_name],
                        system=[system],
                        tools=tools,
                        betas=betas,
                    )
                logger.info("Successfully used backup API key")
            except Exception as backup_e:
                backup_error_msg = str(backup_e)
                logger.exception(f"Backup API call also failed: {backup_error_msg}")

                # Check whether the backup call also hit the 25 MB request limit.
                if "25000000" in backup_error_msg or "Member must have length less than or equal to" in backup_error_msg:
                    logger.warning("Backup API also hit the 25 MB limit; reducing the number of kept images further")
                    # Halve the number of retained images again (keep at least one).
                    current_image_count = self.only_n_most_recent_images
                    new_image_count = max(1, current_image_count // 2)
                    self.only_n_most_recent_images = new_image_count

                    # Re-apply the image filtering with the reduced count.
                    _maybe_filter_to_n_most_recent_images(
                        self.messages,
                        new_image_count,
                        min_removal_threshold=image_truncation_threshold,
                    )
                    logger.info(f"Backup call: number of kept images reduced from {current_image_count} to {new_image_count}")

                return None, None

        except Exception as e:
            logger.exception(f"Error in Anthropic API: {str(e)}")
            return None, None

        response_params = _response_to_params(response)
        logger.info(f"Received response params: {response_params}")

        # Store response in message history
        self.messages.append({
            "role": "assistant",
            "content": response_params
        })

        max_parse_retry = 3
        for parse_retry in range(max_parse_retry):
            actions: list[Any] = []
            reasonings: list[str] = []
            try:
                for content_block in response_params:
                    if content_block["type"] == "tool_use":
                        actions.append({
                            "name": content_block["name"],
                            "input": cast(dict[str, Any], content_block["input"]),
                            "id": content_block["id"],
                            "action_type": content_block.get("type"),
                            "command": self.parse_actions_from_tool_call(content_block)
                        })
                    elif content_block["type"] == "text":
                        reasonings.append(content_block["text"])
                if isinstance(reasonings, list) and len(reasonings) > 0:
                    reasonings = reasonings[0]
                else:
                    reasonings = ""
                logger.info(f"Received actions: {actions}")
                logger.info(f"Received reasonings: {reasonings}")
                if len(actions) == 0:
                    actions = ["DONE"]
                return reasonings, actions
            except Exception as e:
                logger.warning(f"parse_actions_from_tool_call failed (attempt {parse_retry+1}/{max_parse_retry}); re-requesting the API: {e}")
                # Remove the assistant message we just appended so it does not pollute the history.
                self.messages.pop()
                # Re-request the API.
                response = None
                for attempt in range(API_RETRY_TIMES):
                    try:
                        if self.model_name in ["claude-3-7-sonnet-20250219", "claude-4-opus-20250514", "claude-4-sonnet-20250514"]:
                            response = client.beta.messages.create(
                                max_tokens=self.max_tokens,
                                messages=self.messages,
                                model=PROVIDER_TO_DEFAULT_MODEL_NAME[self.provider, self.model_name],
                                system=[system],
                                tools=tools,
                                betas=betas,
                                extra_body=extra_body
                            )
                        elif self.model_name == "claude-3-5-sonnet-20241022":
                            response = client.beta.messages.create(
                                max_tokens=self.max_tokens,
                                messages=self.messages,
                                model=PROVIDER_TO_DEFAULT_MODEL_NAME[self.provider, self.model_name],
                                system=[system],
                                tools=tools,
                                betas=betas,
                            )
                        logger.info(f"Response: {response}")
                        break  # Success: leave the retry loop.
                    except (APIError, APIStatusError, APIResponseValidationError) as e2:
                        error_msg = str(e2)
                        logger.warning(f"Anthropic API error (attempt {attempt+1}/{API_RETRY_TIMES}): {error_msg}")
                        if attempt < API_RETRY_TIMES - 1:
                            time.sleep(API_RETRY_INTERVAL)
                        else:
                            raise
                response_params = _response_to_params(response)
                logger.info(f"Received response params: {response_params}")
                self.messages.append({
                    "role": "assistant",
                    "content": response_params
                })
                if parse_retry == max_parse_retry - 1:
                    logger.error(f"parse_actions_from_tool_call failed {max_parse_retry} times in a row; giving up: {e}")
                    actions = ["FAIL"]
                    return reasonings, actions

    def reset(self, _logger=None, *args, **kwargs):
        """
        Reset the agent's state.
        """
        global logger
        if _logger:
            logger = _logger
        else:
            logger = logging.getLogger("desktopenv.agent")
        self.messages = []
        logger.info(f"{self.class_name} reset.")
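

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the agent implementation).
# Assumptions: an OSWorld-style loop supplies obs = {"screenshot": <PNG bytes>},
# ANTHROPIC_API_KEY is set, and the pyautogui command strings in each action's
# "command" field are executed elsewhere. The file path below is hypothetical.
#
# agent = AnthropicAgent(
#     platform="Ubuntu",
#     model="claude-3-7-sonnet-20250219",
#     provider=APIProvider.ANTHROPIC,
#     screen_size=(1920, 1080),
# )
# agent.reset()
# with open("screenshot.png", "rb") as f:
#     obs = {"screenshot": f.read()}
# reasonings, actions = agent.predict("Open the file manager", obs)
# # `actions` holds tool-use dicts with a "command" string of pyautogui calls,
# # or sentinel strings such as "DONE" / "FAIL".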