update claude (#280)
* add uitars agent code * improve claude * improve claude * improve claude * improve claude * improve claude
This commit is contained in:
@@ -23,6 +23,10 @@ from .utils import _response_to_params, _inject_prompt_caching, _maybe_filter_to
|
||||
import logging
|
||||
logger = logging.getLogger("desktopenv.agent")
|
||||
|
||||
# MAX_HISTORY = 10
|
||||
API_RETRY_TIMES = 500
|
||||
API_RETRY_INTERVAL = 5
|
||||
|
||||
class AnthropicAgent:
|
||||
def __init__(self,
|
||||
platform: str = "Ubuntu",
|
||||
@@ -107,9 +111,24 @@ class AnthropicAgent:
|
||||
int(coordinate[0] * self.resize_factor[0]),
|
||||
int(coordinate[1] * self.resize_factor[1])
|
||||
)
|
||||
|
||||
if action == "left_mouse_down":
|
||||
result += "pyautogui.mouseDown()\n"
|
||||
elif action == "left_mouse_up":
|
||||
result += "pyautogui.mouseUp()\n"
|
||||
|
||||
elif action == "hold_key":
|
||||
if not isinstance(text, str):
|
||||
raise ValueError(f"{text} must be a string")
|
||||
|
||||
keys = text.split('+')
|
||||
for key in keys:
|
||||
key = key.strip().lower()
|
||||
result += f"pyautogui.keyDown('{key}')\n"
|
||||
expected_outcome = f"Keys {text} held down."
|
||||
|
||||
# Handle mouse move and drag actions
|
||||
if action in ("mouse_move", "left_click_drag"):
|
||||
elif action in ("mouse_move", "left_click_drag"):
|
||||
if coordinate is None:
|
||||
raise ValueError(f"coordinate is required for {action}")
|
||||
if text is not None:
|
||||
@@ -189,7 +208,7 @@ class AnthropicAgent:
|
||||
expected_outcome = "Scroll action finished"
|
||||
|
||||
# Handle click actions
|
||||
elif action in ("left_click", "right_click", "double_click", "middle_click", "left_press"):
|
||||
elif action in ("left_click", "right_click", "double_click", "middle_click", "left_press", "triple_click"):
|
||||
if coordinate is not None:
|
||||
x, y = coordinate
|
||||
if action == "left_click":
|
||||
@@ -204,6 +223,9 @@ class AnthropicAgent:
|
||||
result += (f"pyautogui.mouseDown({x}, {y})\n")
|
||||
result += ("time.sleep(1)\n")
|
||||
result += (f"pyautogui.mouseUp({x}, {y})\n")
|
||||
elif action == "triple_click":
|
||||
result += (f"pyautogui.tripleClick({x}, {y})\n")
|
||||
|
||||
else:
|
||||
if action == "left_click":
|
||||
result += ("pyautogui.click()\n")
|
||||
@@ -217,6 +239,8 @@ class AnthropicAgent:
|
||||
result += ("pyautogui.mouseDown()\n")
|
||||
result += ("time.sleep(1)\n")
|
||||
result += ("pyautogui.mouseUp()\n")
|
||||
elif action == "triple_click":
|
||||
result += ("pyautogui.tripleClick()\n")
|
||||
expected_outcome = "Click action finished"
|
||||
|
||||
elif action == "wait":
|
||||
@@ -239,6 +263,54 @@ class AnthropicAgent:
|
||||
|
||||
return result
|
||||
|
||||
def _trim_history(self, max_rounds=4):
    """Strip screenshots from the older part of ``self.messages``.

    The first message and the most recent ``max_rounds * 2`` messages are
    left untouched. For every message in between, user messages get the
    ``image`` items removed from their ``tool_result`` blocks; every other
    message is kept as-is. No message is dropped unless a user message ends
    up with an empty content list, so this bounds the image payload rather
    than the number of messages.

    Args:
        max_rounds: Number of recent rounds to leave untouched; each round
            is counted as two messages. Negative values are treated as 0.

    Returns:
        None. ``self.messages`` is rebound to a new list when the history is
        longer than the protected window; the original message dicts are
        not mutated.
    """
    messages = self.messages
    if not messages or len(messages) <= 1:
        return

    # Size of the untouched tail, in messages (two messages per round).
    recent_count = max(0, max_rounds) * 2

    # Nothing to do while the whole history fits in the protected window.
    if len(messages) <= recent_count:
        return

    # Use an explicit start index for the tail: messages[-0:] would be the
    # WHOLE list and duplicate the history when recent_count is 0.
    tail_start = len(messages) - recent_count

    # Keep the first (initial) message and the recent tail verbatim; the
    # messages in between are kept too, but with their images removed.
    keep_messages = []
    for old_message in messages[1:tail_start]:
        content = old_message.get("content")

        # Non-user messages, messages without content, and messages whose
        # content is not a list of blocks (e.g. a plain string) are kept
        # unchanged -- there is no image block to strip from them.
        if old_message["role"] != "user" or not isinstance(content, list):
            keep_messages.append(old_message)
            continue

        filtered_content = []
        for content_block in content:
            is_dict = isinstance(content_block, dict)
            inner = content_block.get("content") if is_dict else None

            # Only tool_result blocks carrying a list of items can hold
            # screenshots; anything else passes through untouched.
            if not (is_dict
                    and content_block.get("type") == "tool_result"
                    and isinstance(inner, list)):
                filtered_content.append(content_block)
                continue

            stripped_block = {
                "type": content_block.get("type"),
                "tool_use_id": content_block.get("tool_use_id"),
                "content": [
                    item for item in inner
                    if not (isinstance(item, dict)
                            and item.get("type") == "image")
                ],
            }
            # Preserve the error flag so a failed tool call is still
            # reported as failed after its screenshot is removed.
            if "is_error" in content_block:
                stripped_block["is_error"] = content_block["is_error"]
            filtered_content.append(stripped_block)

        # Keep the message only if something is left after filtering.
        if filtered_content:
            keep_messages.append({
                "role": old_message["role"],
                "content": filtered_content,
            })

    self.messages = messages[0:1] + keep_messages + messages[tail_start:]
|
||||
|
||||
def predict(self, task_instruction: str, obs: Dict = None, system: Any = None):
|
||||
system = BetaTextBlockParam(
|
||||
type="text",
|
||||
@@ -326,8 +398,10 @@ class AnthropicAgent:
|
||||
min_removal_threshold=image_truncation_threshold,
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
#self._trim_history(max_rounds=MAX_HISTORY)
|
||||
|
||||
try:
|
||||
if self.model_name == "claude-3-5-sonnet-20241022":
|
||||
tools = [
|
||||
{'name': 'computer', 'type': 'computer_20241022', 'display_width_px': 1280, 'display_height_px': 720, 'display_number': 1},
|
||||
@@ -336,7 +410,7 @@ class AnthropicAgent:
|
||||
] if self.platform == 'Ubuntu' else [
|
||||
{'name': 'computer', 'type': 'computer_20241022', 'display_width_px': 1280, 'display_height_px': 720, 'display_number': 1},
|
||||
]
|
||||
elif self.model_name == "claude-3-7-sonnet-20250219" or self.model_name == "claude-4-opus-20250514" or self.model_name == "claude-4-sonnet-20250514":
|
||||
elif self.model_name in ["claude-3-7-sonnet-20250219", "claude-4-opus-20250514", "claude-4-sonnet-20250514"]:
|
||||
tools = [
|
||||
{'name': 'computer', 'type': 'computer_20250124', 'display_width_px': 1280, 'display_height_px': 720, 'display_number': 1},
|
||||
# {'type': 'bash_20250124', 'name': 'bash'},
|
||||
@@ -348,25 +422,54 @@ class AnthropicAgent:
|
||||
"thinking": {"type": "enabled", "budget_tokens": 1024}
|
||||
}
|
||||
response = None
|
||||
if self.model_name == "claude-3-7-sonnet-20250219" or self.model_name == "claude-4-opus-20250514" or self.model_name == "claude-4-sonnet-20250514":
|
||||
response = client.beta.messages.create(
|
||||
max_tokens=self.max_tokens,
|
||||
messages=self.messages,
|
||||
model=PROVIDER_TO_DEFAULT_MODEL_NAME[self.provider, self.model_name],
|
||||
system=[system],
|
||||
tools=tools,
|
||||
betas=betas,
|
||||
extra_body=extra_body
|
||||
)
|
||||
elif self.model_name == "claude-3-5-sonnet-20241022":
|
||||
response = client.beta.messages.create(
|
||||
max_tokens=self.max_tokens,
|
||||
messages=self.messages,
|
||||
model=PROVIDER_TO_DEFAULT_MODEL_NAME[self.provider, self.model_name],
|
||||
system=[system],
|
||||
tools=tools,
|
||||
betas=betas,
|
||||
)
|
||||
|
||||
for attempt in range(API_RETRY_TIMES):
|
||||
try:
|
||||
if self.model_name in ["claude-3-7-sonnet-20250219", "claude-4-opus-20250514", "claude-4-sonnet-20250514"]:
|
||||
response = client.beta.messages.create(
|
||||
max_tokens=self.max_tokens,
|
||||
messages=self.messages,
|
||||
model=PROVIDER_TO_DEFAULT_MODEL_NAME[self.provider, self.model_name],
|
||||
system=[system],
|
||||
tools=tools,
|
||||
betas=betas,
|
||||
extra_body=extra_body
|
||||
)
|
||||
elif self.model_name == "claude-3-5-sonnet-20241022":
|
||||
response = client.beta.messages.create(
|
||||
max_tokens=self.max_tokens,
|
||||
messages=self.messages,
|
||||
model=PROVIDER_TO_DEFAULT_MODEL_NAME[self.provider, self.model_name],
|
||||
system=[system],
|
||||
tools=tools,
|
||||
betas=betas,
|
||||
)
|
||||
logger.info(f"Response: {response}")
|
||||
break # 成功则跳出重试循环
|
||||
except (APIError, APIStatusError, APIResponseValidationError) as e:
|
||||
error_msg = str(e)
|
||||
logger.warning(f"Anthropic API error (attempt {attempt+1}/{API_RETRY_TIMES}): {error_msg}")
|
||||
|
||||
# 检查是否是25MB限制错误
|
||||
if "25000000" in error_msg or "Member must have length less than or equal to" in error_msg:
|
||||
logger.warning("检测到25MB限制错误,自动裁剪图片数量")
|
||||
# 将图片数量减半
|
||||
current_image_count = self.only_n_most_recent_images
|
||||
new_image_count = max(1, current_image_count // 2) # 至少保留1张图片
|
||||
self.only_n_most_recent_images = new_image_count
|
||||
|
||||
# 重新应用图片过滤
|
||||
_maybe_filter_to_n_most_recent_images(
|
||||
self.messages,
|
||||
new_image_count,
|
||||
min_removal_threshold=image_truncation_threshold,
|
||||
)
|
||||
logger.info(f"图片数量已从 {current_image_count} 减少到 {new_image_count}")
|
||||
|
||||
if attempt < API_RETRY_TIMES - 1:
|
||||
time.sleep(API_RETRY_INTERVAL)
|
||||
else:
|
||||
raise # 全部失败后抛出异常,进入原有except逻辑
|
||||
|
||||
except (APIError, APIStatusError, APIResponseValidationError) as e:
|
||||
logger.exception(f"Anthropic API error: {str(e)}")
|
||||
@@ -374,8 +477,7 @@ class AnthropicAgent:
|
||||
logger.warning("Retrying with backup API key...")
|
||||
|
||||
backup_client = Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY_BACKUP"), max_retries=4)
|
||||
|
||||
if self.model_name == "claude-3-7-sonnet-20250219" or self.model_name == "claude-4-opus-20250514" or self.model_name == "claude-4-sonnet-20250514":
|
||||
if self.model_name in ["claude-3-7-sonnet-20250219", "claude-4-opus-20250514", "claude-4-sonnet-20250514"]:
|
||||
response = backup_client.beta.messages.create(
|
||||
max_tokens=self.max_tokens,
|
||||
messages=self.messages,
|
||||
@@ -396,7 +498,25 @@ class AnthropicAgent:
|
||||
)
|
||||
logger.info("Successfully used backup API key")
|
||||
except Exception as backup_e:
|
||||
logger.exception(f"Backup API call also failed: {str(backup_e)}")
|
||||
backup_error_msg = str(backup_e)
|
||||
logger.exception(f"Backup API call also failed: {backup_error_msg}")
|
||||
|
||||
# 检查备用API是否也是25MB限制错误
|
||||
if "25000000" in backup_error_msg or "Member must have length less than or equal to" in backup_error_msg:
|
||||
logger.warning("备用API也遇到25MB限制错误,进一步裁剪图片数量")
|
||||
# 将图片数量再减半
|
||||
current_image_count = self.only_n_most_recent_images
|
||||
new_image_count = max(1, current_image_count // 2) # 至少保留1张图片
|
||||
self.only_n_most_recent_images = new_image_count
|
||||
|
||||
# 重新应用图片过滤
|
||||
_maybe_filter_to_n_most_recent_images(
|
||||
self.messages,
|
||||
new_image_count,
|
||||
min_removal_threshold=image_truncation_threshold,
|
||||
)
|
||||
logger.info(f"备用API图片数量已从 {current_image_count} 减少到 {new_image_count}")
|
||||
|
||||
return None, None
|
||||
|
||||
except Exception as e:
|
||||
@@ -412,29 +532,77 @@ class AnthropicAgent:
|
||||
"content": response_params
|
||||
})
|
||||
|
||||
actions: list[Any] = []
|
||||
reasonings: list[str] = []
|
||||
for content_block in response_params:
|
||||
if content_block["type"] == "tool_use":
|
||||
actions.append({
|
||||
"name": content_block["name"],
|
||||
"input": cast(dict[str, Any], content_block["input"]),
|
||||
"id": content_block["id"],
|
||||
"action_type": content_block.get("type"),
|
||||
"command": self.parse_actions_from_tool_call(content_block)
|
||||
max_parse_retry = 3
|
||||
for parse_retry in range(max_parse_retry):
|
||||
actions: list[Any] = []
|
||||
reasonings: list[str] = []
|
||||
try:
|
||||
for content_block in response_params:
|
||||
if content_block["type"] == "tool_use":
|
||||
actions.append({
|
||||
"name": content_block["name"],
|
||||
"input": cast(dict[str, Any], content_block["input"]),
|
||||
"id": content_block["id"],
|
||||
"action_type": content_block.get("type"),
|
||||
"command": self.parse_actions_from_tool_call(content_block)
|
||||
})
|
||||
elif content_block["type"] == "text":
|
||||
reasonings.append(content_block["text"])
|
||||
if isinstance(reasonings, list) and len(reasonings) > 0:
|
||||
reasonings = reasonings[0]
|
||||
else:
|
||||
reasonings = ""
|
||||
logger.info(f"Received actions: {actions}")
|
||||
logger.info(f"Received reasonings: {reasonings}")
|
||||
if len(actions) == 0:
|
||||
actions = ["DONE"]
|
||||
return reasonings, actions
|
||||
except Exception as e:
|
||||
logger.warning(f"parse_actions_from_tool_call解析失败(第{parse_retry+1}/3次),将重新请求API: {e}")
|
||||
# 删除刚刚append的assistant消息,避免污染history
|
||||
self.messages.pop()
|
||||
# 重新请求API
|
||||
response = None
|
||||
for attempt in range(API_RETRY_TIMES):
|
||||
try:
|
||||
if self.model_name in ["claude-3-7-sonnet-20250219", "claude-4-opus-20250514", "claude-4-sonnet-20250514"]:
|
||||
response = client.beta.messages.create(
|
||||
max_tokens=self.max_tokens,
|
||||
messages=self.messages,
|
||||
model=PROVIDER_TO_DEFAULT_MODEL_NAME[self.provider, self.model_name],
|
||||
system=[system],
|
||||
tools=tools,
|
||||
betas=betas,
|
||||
extra_body=extra_body
|
||||
)
|
||||
elif self.model_name == "claude-3-5-sonnet-20241022":
|
||||
response = client.beta.messages.create(
|
||||
max_tokens=self.max_tokens,
|
||||
messages=self.messages,
|
||||
model=PROVIDER_TO_DEFAULT_MODEL_NAME[self.provider, self.model_name],
|
||||
system=[system],
|
||||
tools=tools,
|
||||
betas=betas,
|
||||
)
|
||||
logger.info(f"Response: {response}")
|
||||
break # 成功则跳出重试循环
|
||||
except (APIError, APIStatusError, APIResponseValidationError) as e2:
|
||||
error_msg = str(e2)
|
||||
logger.warning(f"Anthropic API error (attempt {attempt+1}/{API_RETRY_TIMES}): {error_msg}")
|
||||
if attempt < API_RETRY_TIMES - 1:
|
||||
time.sleep(API_RETRY_INTERVAL)
|
||||
else:
|
||||
raise
|
||||
response_params = _response_to_params(response)
|
||||
logger.info(f"Received response params: {response_params}")
|
||||
self.messages.append({
|
||||
"role": "assistant",
|
||||
"content": response_params
|
||||
})
|
||||
elif content_block["type"] == "text":
|
||||
reasonings.append(content_block["text"])
|
||||
if isinstance(reasonings, list) and len(reasonings) > 0:
|
||||
reasonings = reasonings[0]
|
||||
else:
|
||||
reasonings = ""
|
||||
logger.info(f"Received actions: {actions}")
|
||||
logger.info(f"Received reasonings: {reasonings}")
|
||||
if len(actions) == 0:
|
||||
actions = ["DONE"]
|
||||
return reasonings, actions
|
||||
|
||||
if parse_retry == max_parse_retry - 1:
|
||||
logger.error(f"连续3次parse_actions_from_tool_call解析失败,终止: {e}")
|
||||
actions = ["FAIL"]
|
||||
return reasonings, actions
|
||||
def reset(self, _logger = None, *args, **kwargs):
|
||||
"""
|
||||
Reset the agent's state.
|
||||
|
||||
Reference in New Issue
Block a user