update claude (#280)
* add uitars agent code * improve claude * improve claude * improve claude * improve claude * improve claude
This commit is contained in:
@@ -32,6 +32,8 @@ def run_single_example(agent, env, example, max_steps, instruction, args, exampl
|
|||||||
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
||||||
logger.info("Step %d: %s", step_idx + 1, action)
|
logger.info("Step %d: %s", step_idx + 1, action)
|
||||||
obs, reward, done, info = env.step(action, args.sleep_after_execution)
|
obs, reward, done, info = env.step(action, args.sleep_after_execution)
|
||||||
|
time.sleep(3)
|
||||||
|
obs = env._get_obs()
|
||||||
|
|
||||||
logger.info("Reward: %.2f", reward)
|
logger.info("Reward: %.2f", reward)
|
||||||
logger.info("Done: %s", done)
|
logger.info("Done: %s", done)
|
||||||
|
|||||||
@@ -23,6 +23,10 @@ from .utils import _response_to_params, _inject_prompt_caching, _maybe_filter_to
|
|||||||
import logging
|
import logging
|
||||||
logger = logging.getLogger("desktopenv.agent")
|
logger = logging.getLogger("desktopenv.agent")
|
||||||
|
|
||||||
|
# MAX_HISTORY = 10
|
||||||
|
API_RETRY_TIMES = 500
|
||||||
|
API_RETRY_INTERVAL = 5
|
||||||
|
|
||||||
class AnthropicAgent:
|
class AnthropicAgent:
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
platform: str = "Ubuntu",
|
platform: str = "Ubuntu",
|
||||||
@@ -107,9 +111,24 @@ class AnthropicAgent:
|
|||||||
int(coordinate[0] * self.resize_factor[0]),
|
int(coordinate[0] * self.resize_factor[0]),
|
||||||
int(coordinate[1] * self.resize_factor[1])
|
int(coordinate[1] * self.resize_factor[1])
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if action == "left_mouse_down":
|
||||||
|
result += "pyautogui.mouseDown()\n"
|
||||||
|
elif action == "left_mouse_up":
|
||||||
|
result += "pyautogui.mouseUp()\n"
|
||||||
|
|
||||||
|
elif action == "hold_key":
|
||||||
|
if not isinstance(text, str):
|
||||||
|
raise ValueError(f"{text} must be a string")
|
||||||
|
|
||||||
|
keys = text.split('+')
|
||||||
|
for key in keys:
|
||||||
|
key = key.strip().lower()
|
||||||
|
result += f"pyautogui.keyDown('{key}')\n"
|
||||||
|
expected_outcome = f"Keys {text} held down."
|
||||||
|
|
||||||
# Handle mouse move and drag actions
|
# Handle mouse move and drag actions
|
||||||
if action in ("mouse_move", "left_click_drag"):
|
elif action in ("mouse_move", "left_click_drag"):
|
||||||
if coordinate is None:
|
if coordinate is None:
|
||||||
raise ValueError(f"coordinate is required for {action}")
|
raise ValueError(f"coordinate is required for {action}")
|
||||||
if text is not None:
|
if text is not None:
|
||||||
@@ -189,7 +208,7 @@ class AnthropicAgent:
|
|||||||
expected_outcome = "Scroll action finished"
|
expected_outcome = "Scroll action finished"
|
||||||
|
|
||||||
# Handle click actions
|
# Handle click actions
|
||||||
elif action in ("left_click", "right_click", "double_click", "middle_click", "left_press"):
|
elif action in ("left_click", "right_click", "double_click", "middle_click", "left_press", "triple_click"):
|
||||||
if coordinate is not None:
|
if coordinate is not None:
|
||||||
x, y = coordinate
|
x, y = coordinate
|
||||||
if action == "left_click":
|
if action == "left_click":
|
||||||
@@ -204,6 +223,9 @@ class AnthropicAgent:
|
|||||||
result += (f"pyautogui.mouseDown({x}, {y})\n")
|
result += (f"pyautogui.mouseDown({x}, {y})\n")
|
||||||
result += ("time.sleep(1)\n")
|
result += ("time.sleep(1)\n")
|
||||||
result += (f"pyautogui.mouseUp({x}, {y})\n")
|
result += (f"pyautogui.mouseUp({x}, {y})\n")
|
||||||
|
elif action == "triple_click":
|
||||||
|
result += (f"pyautogui.tripleClick({x}, {y})\n")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if action == "left_click":
|
if action == "left_click":
|
||||||
result += ("pyautogui.click()\n")
|
result += ("pyautogui.click()\n")
|
||||||
@@ -217,6 +239,8 @@ class AnthropicAgent:
|
|||||||
result += ("pyautogui.mouseDown()\n")
|
result += ("pyautogui.mouseDown()\n")
|
||||||
result += ("time.sleep(1)\n")
|
result += ("time.sleep(1)\n")
|
||||||
result += ("pyautogui.mouseUp()\n")
|
result += ("pyautogui.mouseUp()\n")
|
||||||
|
elif action == "triple_click":
|
||||||
|
result += ("pyautogui.tripleClick()\n")
|
||||||
expected_outcome = "Click action finished"
|
expected_outcome = "Click action finished"
|
||||||
|
|
||||||
elif action == "wait":
|
elif action == "wait":
|
||||||
@@ -239,6 +263,54 @@ class AnthropicAgent:
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def _trim_history(self, max_rounds=4):
|
||||||
|
|
||||||
|
messages = self.messages
|
||||||
|
if not messages or len(messages) <= 1:
|
||||||
|
return
|
||||||
|
|
||||||
|
# 计算需要保留的最近轮次数
|
||||||
|
actual_max_rounds = max_rounds * 2
|
||||||
|
|
||||||
|
# 如果消息数量不超过限制,不需要处理
|
||||||
|
if len(messages) <= actual_max_rounds:
|
||||||
|
return
|
||||||
|
|
||||||
|
# 保留前3条消息(初始消息)和最近的actual_max_rounds条消息 messages[0:1] + messages[-actual_max_rounds:]
|
||||||
|
keep_messages = []
|
||||||
|
|
||||||
|
# 对于中间被删除的消息,只保留非图片内容
|
||||||
|
for i in range(1, len(messages) - actual_max_rounds):
|
||||||
|
old_message = messages[i]
|
||||||
|
if old_message["role"] == "user" and "content" in old_message:
|
||||||
|
# 过滤掉image类型的内容块,保留其他类型
|
||||||
|
filtered_content = []
|
||||||
|
for content_block in old_message["content"]:
|
||||||
|
filtered_content_item = []
|
||||||
|
if content_block.get("type") == "tool_result":
|
||||||
|
for content_block_item in content_block["content"]:
|
||||||
|
if content_block_item.get("type") != "image":
|
||||||
|
filtered_content_item.append(content_block_item)
|
||||||
|
filtered_content.append({
|
||||||
|
"type": content_block.get("type"),
|
||||||
|
"tool_use_id": content_block.get("tool_use_id"),
|
||||||
|
"content": filtered_content_item
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
filtered_content.append(content_block)
|
||||||
|
|
||||||
|
# 如果过滤后还有内容,则保留这条消息
|
||||||
|
if filtered_content:
|
||||||
|
keep_messages.append({
|
||||||
|
"role": old_message["role"],
|
||||||
|
"content": filtered_content
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
# 非用户消息或没有content的消息直接保留
|
||||||
|
keep_messages.append(old_message)
|
||||||
|
|
||||||
|
self.messages = messages[0:1] + keep_messages + messages[-actual_max_rounds:]
|
||||||
|
|
||||||
def predict(self, task_instruction: str, obs: Dict = None, system: Any = None):
|
def predict(self, task_instruction: str, obs: Dict = None, system: Any = None):
|
||||||
system = BetaTextBlockParam(
|
system = BetaTextBlockParam(
|
||||||
type="text",
|
type="text",
|
||||||
@@ -326,8 +398,10 @@ class AnthropicAgent:
|
|||||||
min_removal_threshold=image_truncation_threshold,
|
min_removal_threshold=image_truncation_threshold,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
|
||||||
|
#self._trim_history(max_rounds=MAX_HISTORY)
|
||||||
|
|
||||||
|
try:
|
||||||
if self.model_name == "claude-3-5-sonnet-20241022":
|
if self.model_name == "claude-3-5-sonnet-20241022":
|
||||||
tools = [
|
tools = [
|
||||||
{'name': 'computer', 'type': 'computer_20241022', 'display_width_px': 1280, 'display_height_px': 720, 'display_number': 1},
|
{'name': 'computer', 'type': 'computer_20241022', 'display_width_px': 1280, 'display_height_px': 720, 'display_number': 1},
|
||||||
@@ -336,7 +410,7 @@ class AnthropicAgent:
|
|||||||
] if self.platform == 'Ubuntu' else [
|
] if self.platform == 'Ubuntu' else [
|
||||||
{'name': 'computer', 'type': 'computer_20241022', 'display_width_px': 1280, 'display_height_px': 720, 'display_number': 1},
|
{'name': 'computer', 'type': 'computer_20241022', 'display_width_px': 1280, 'display_height_px': 720, 'display_number': 1},
|
||||||
]
|
]
|
||||||
elif self.model_name == "claude-3-7-sonnet-20250219" or self.model_name == "claude-4-opus-20250514" or self.model_name == "claude-4-sonnet-20250514":
|
elif self.model_name in ["claude-3-7-sonnet-20250219", "claude-4-opus-20250514", "claude-4-sonnet-20250514"]:
|
||||||
tools = [
|
tools = [
|
||||||
{'name': 'computer', 'type': 'computer_20250124', 'display_width_px': 1280, 'display_height_px': 720, 'display_number': 1},
|
{'name': 'computer', 'type': 'computer_20250124', 'display_width_px': 1280, 'display_height_px': 720, 'display_number': 1},
|
||||||
# {'type': 'bash_20250124', 'name': 'bash'},
|
# {'type': 'bash_20250124', 'name': 'bash'},
|
||||||
@@ -348,25 +422,54 @@ class AnthropicAgent:
|
|||||||
"thinking": {"type": "enabled", "budget_tokens": 1024}
|
"thinking": {"type": "enabled", "budget_tokens": 1024}
|
||||||
}
|
}
|
||||||
response = None
|
response = None
|
||||||
if self.model_name == "claude-3-7-sonnet-20250219" or self.model_name == "claude-4-opus-20250514" or self.model_name == "claude-4-sonnet-20250514":
|
|
||||||
response = client.beta.messages.create(
|
for attempt in range(API_RETRY_TIMES):
|
||||||
max_tokens=self.max_tokens,
|
try:
|
||||||
messages=self.messages,
|
if self.model_name in ["claude-3-7-sonnet-20250219", "claude-4-opus-20250514", "claude-4-sonnet-20250514"]:
|
||||||
model=PROVIDER_TO_DEFAULT_MODEL_NAME[self.provider, self.model_name],
|
response = client.beta.messages.create(
|
||||||
system=[system],
|
max_tokens=self.max_tokens,
|
||||||
tools=tools,
|
messages=self.messages,
|
||||||
betas=betas,
|
model=PROVIDER_TO_DEFAULT_MODEL_NAME[self.provider, self.model_name],
|
||||||
extra_body=extra_body
|
system=[system],
|
||||||
)
|
tools=tools,
|
||||||
elif self.model_name == "claude-3-5-sonnet-20241022":
|
betas=betas,
|
||||||
response = client.beta.messages.create(
|
extra_body=extra_body
|
||||||
max_tokens=self.max_tokens,
|
)
|
||||||
messages=self.messages,
|
elif self.model_name == "claude-3-5-sonnet-20241022":
|
||||||
model=PROVIDER_TO_DEFAULT_MODEL_NAME[self.provider, self.model_name],
|
response = client.beta.messages.create(
|
||||||
system=[system],
|
max_tokens=self.max_tokens,
|
||||||
tools=tools,
|
messages=self.messages,
|
||||||
betas=betas,
|
model=PROVIDER_TO_DEFAULT_MODEL_NAME[self.provider, self.model_name],
|
||||||
)
|
system=[system],
|
||||||
|
tools=tools,
|
||||||
|
betas=betas,
|
||||||
|
)
|
||||||
|
logger.info(f"Response: {response}")
|
||||||
|
break # 成功则跳出重试循环
|
||||||
|
except (APIError, APIStatusError, APIResponseValidationError) as e:
|
||||||
|
error_msg = str(e)
|
||||||
|
logger.warning(f"Anthropic API error (attempt {attempt+1}/{API_RETRY_TIMES}): {error_msg}")
|
||||||
|
|
||||||
|
# 检查是否是25MB限制错误
|
||||||
|
if "25000000" in error_msg or "Member must have length less than or equal to" in error_msg:
|
||||||
|
logger.warning("检测到25MB限制错误,自动裁剪图片数量")
|
||||||
|
# 将图片数量减半
|
||||||
|
current_image_count = self.only_n_most_recent_images
|
||||||
|
new_image_count = max(1, current_image_count // 2) # 至少保留1张图片
|
||||||
|
self.only_n_most_recent_images = new_image_count
|
||||||
|
|
||||||
|
# 重新应用图片过滤
|
||||||
|
_maybe_filter_to_n_most_recent_images(
|
||||||
|
self.messages,
|
||||||
|
new_image_count,
|
||||||
|
min_removal_threshold=image_truncation_threshold,
|
||||||
|
)
|
||||||
|
logger.info(f"图片数量已从 {current_image_count} 减少到 {new_image_count}")
|
||||||
|
|
||||||
|
if attempt < API_RETRY_TIMES - 1:
|
||||||
|
time.sleep(API_RETRY_INTERVAL)
|
||||||
|
else:
|
||||||
|
raise # 全部失败后抛出异常,进入原有except逻辑
|
||||||
|
|
||||||
except (APIError, APIStatusError, APIResponseValidationError) as e:
|
except (APIError, APIStatusError, APIResponseValidationError) as e:
|
||||||
logger.exception(f"Anthropic API error: {str(e)}")
|
logger.exception(f"Anthropic API error: {str(e)}")
|
||||||
@@ -374,8 +477,7 @@ class AnthropicAgent:
|
|||||||
logger.warning("Retrying with backup API key...")
|
logger.warning("Retrying with backup API key...")
|
||||||
|
|
||||||
backup_client = Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY_BACKUP"), max_retries=4)
|
backup_client = Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY_BACKUP"), max_retries=4)
|
||||||
|
if self.model_name in ["claude-3-7-sonnet-20250219", "claude-4-opus-20250514", "claude-4-sonnet-20250514"]:
|
||||||
if self.model_name == "claude-3-7-sonnet-20250219" or self.model_name == "claude-4-opus-20250514" or self.model_name == "claude-4-sonnet-20250514":
|
|
||||||
response = backup_client.beta.messages.create(
|
response = backup_client.beta.messages.create(
|
||||||
max_tokens=self.max_tokens,
|
max_tokens=self.max_tokens,
|
||||||
messages=self.messages,
|
messages=self.messages,
|
||||||
@@ -396,7 +498,25 @@ class AnthropicAgent:
|
|||||||
)
|
)
|
||||||
logger.info("Successfully used backup API key")
|
logger.info("Successfully used backup API key")
|
||||||
except Exception as backup_e:
|
except Exception as backup_e:
|
||||||
logger.exception(f"Backup API call also failed: {str(backup_e)}")
|
backup_error_msg = str(backup_e)
|
||||||
|
logger.exception(f"Backup API call also failed: {backup_error_msg}")
|
||||||
|
|
||||||
|
# 检查备用API是否也是25MB限制错误
|
||||||
|
if "25000000" in backup_error_msg or "Member must have length less than or equal to" in backup_error_msg:
|
||||||
|
logger.warning("备用API也遇到25MB限制错误,进一步裁剪图片数量")
|
||||||
|
# 将图片数量再减半
|
||||||
|
current_image_count = self.only_n_most_recent_images
|
||||||
|
new_image_count = max(1, current_image_count // 2) # 至少保留1张图片
|
||||||
|
self.only_n_most_recent_images = new_image_count
|
||||||
|
|
||||||
|
# 重新应用图片过滤
|
||||||
|
_maybe_filter_to_n_most_recent_images(
|
||||||
|
self.messages,
|
||||||
|
new_image_count,
|
||||||
|
min_removal_threshold=image_truncation_threshold,
|
||||||
|
)
|
||||||
|
logger.info(f"备用API图片数量已从 {current_image_count} 减少到 {new_image_count}")
|
||||||
|
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -412,29 +532,77 @@ class AnthropicAgent:
|
|||||||
"content": response_params
|
"content": response_params
|
||||||
})
|
})
|
||||||
|
|
||||||
actions: list[Any] = []
|
max_parse_retry = 3
|
||||||
reasonings: list[str] = []
|
for parse_retry in range(max_parse_retry):
|
||||||
for content_block in response_params:
|
actions: list[Any] = []
|
||||||
if content_block["type"] == "tool_use":
|
reasonings: list[str] = []
|
||||||
actions.append({
|
try:
|
||||||
"name": content_block["name"],
|
for content_block in response_params:
|
||||||
"input": cast(dict[str, Any], content_block["input"]),
|
if content_block["type"] == "tool_use":
|
||||||
"id": content_block["id"],
|
actions.append({
|
||||||
"action_type": content_block.get("type"),
|
"name": content_block["name"],
|
||||||
"command": self.parse_actions_from_tool_call(content_block)
|
"input": cast(dict[str, Any], content_block["input"]),
|
||||||
|
"id": content_block["id"],
|
||||||
|
"action_type": content_block.get("type"),
|
||||||
|
"command": self.parse_actions_from_tool_call(content_block)
|
||||||
|
})
|
||||||
|
elif content_block["type"] == "text":
|
||||||
|
reasonings.append(content_block["text"])
|
||||||
|
if isinstance(reasonings, list) and len(reasonings) > 0:
|
||||||
|
reasonings = reasonings[0]
|
||||||
|
else:
|
||||||
|
reasonings = ""
|
||||||
|
logger.info(f"Received actions: {actions}")
|
||||||
|
logger.info(f"Received reasonings: {reasonings}")
|
||||||
|
if len(actions) == 0:
|
||||||
|
actions = ["DONE"]
|
||||||
|
return reasonings, actions
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"parse_actions_from_tool_call解析失败(第{parse_retry+1}/3次),将重新请求API: {e}")
|
||||||
|
# 删除刚刚append的assistant消息,避免污染history
|
||||||
|
self.messages.pop()
|
||||||
|
# 重新请求API
|
||||||
|
response = None
|
||||||
|
for attempt in range(API_RETRY_TIMES):
|
||||||
|
try:
|
||||||
|
if self.model_name in ["claude-3-7-sonnet-20250219", "claude-4-opus-20250514", "claude-4-sonnet-20250514"]:
|
||||||
|
response = client.beta.messages.create(
|
||||||
|
max_tokens=self.max_tokens,
|
||||||
|
messages=self.messages,
|
||||||
|
model=PROVIDER_TO_DEFAULT_MODEL_NAME[self.provider, self.model_name],
|
||||||
|
system=[system],
|
||||||
|
tools=tools,
|
||||||
|
betas=betas,
|
||||||
|
extra_body=extra_body
|
||||||
|
)
|
||||||
|
elif self.model_name == "claude-3-5-sonnet-20241022":
|
||||||
|
response = client.beta.messages.create(
|
||||||
|
max_tokens=self.max_tokens,
|
||||||
|
messages=self.messages,
|
||||||
|
model=PROVIDER_TO_DEFAULT_MODEL_NAME[self.provider, self.model_name],
|
||||||
|
system=[system],
|
||||||
|
tools=tools,
|
||||||
|
betas=betas,
|
||||||
|
)
|
||||||
|
logger.info(f"Response: {response}")
|
||||||
|
break # 成功则跳出重试循环
|
||||||
|
except (APIError, APIStatusError, APIResponseValidationError) as e2:
|
||||||
|
error_msg = str(e2)
|
||||||
|
logger.warning(f"Anthropic API error (attempt {attempt+1}/{API_RETRY_TIMES}): {error_msg}")
|
||||||
|
if attempt < API_RETRY_TIMES - 1:
|
||||||
|
time.sleep(API_RETRY_INTERVAL)
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
response_params = _response_to_params(response)
|
||||||
|
logger.info(f"Received response params: {response_params}")
|
||||||
|
self.messages.append({
|
||||||
|
"role": "assistant",
|
||||||
|
"content": response_params
|
||||||
})
|
})
|
||||||
elif content_block["type"] == "text":
|
if parse_retry == max_parse_retry - 1:
|
||||||
reasonings.append(content_block["text"])
|
logger.error(f"连续3次parse_actions_from_tool_call解析失败,终止: {e}")
|
||||||
if isinstance(reasonings, list) and len(reasonings) > 0:
|
actions = ["FAIL"]
|
||||||
reasonings = reasonings[0]
|
return reasonings, actions
|
||||||
else:
|
|
||||||
reasonings = ""
|
|
||||||
logger.info(f"Received actions: {actions}")
|
|
||||||
logger.info(f"Received reasonings: {reasonings}")
|
|
||||||
if len(actions) == 0:
|
|
||||||
actions = ["DONE"]
|
|
||||||
return reasonings, actions
|
|
||||||
|
|
||||||
def reset(self, _logger = None, *args, **kwargs):
|
def reset(self, _logger = None, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
Reset the agent's state.
|
Reset the agent's state.
|
||||||
|
|||||||
@@ -11,9 +11,9 @@ import sys
|
|||||||
from typing import List, Dict
|
from typing import List, Dict
|
||||||
import math
|
import math
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from multiprocessing import Process, Manager
|
from multiprocessing import Process, Manager, current_process
|
||||||
# import lib_run_single
|
import lib_run_single
|
||||||
# from desktop_env.desktop_env import DesktopEnv
|
from desktop_env.desktop_env import DesktopEnv
|
||||||
from mm_agents.anthropic import AnthropicAgent as PromptAgent
|
from mm_agents.anthropic import AnthropicAgent as PromptAgent
|
||||||
|
|
||||||
import fake_run_single as lib_run_single
|
import fake_run_single as lib_run_single
|
||||||
@@ -26,41 +26,28 @@ load_dotenv()
|
|||||||
|
|
||||||
# Logger Configs {{{ #
|
# Logger Configs {{{ #
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
logger.setLevel(logging.DEBUG)
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
||||||
|
|
||||||
file_handler = logging.FileHandler(
|
file_handler = logging.FileHandler(
|
||||||
os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8"
|
os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8"
|
||||||
)
|
)
|
||||||
debug_handler = logging.FileHandler(
|
|
||||||
os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8"
|
|
||||||
)
|
|
||||||
stdout_handler = logging.StreamHandler(sys.stdout)
|
stdout_handler = logging.StreamHandler(sys.stdout)
|
||||||
sdebug_handler = logging.FileHandler(
|
|
||||||
os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8"
|
|
||||||
)
|
|
||||||
|
|
||||||
file_handler.setLevel(logging.INFO)
|
file_handler.setLevel(logging.INFO)
|
||||||
debug_handler.setLevel(logging.DEBUG)
|
|
||||||
stdout_handler.setLevel(logging.INFO)
|
stdout_handler.setLevel(logging.INFO)
|
||||||
sdebug_handler.setLevel(logging.DEBUG)
|
|
||||||
|
|
||||||
formatter = logging.Formatter(
|
formatter = logging.Formatter(
|
||||||
fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
|
fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
|
||||||
)
|
)
|
||||||
file_handler.setFormatter(formatter)
|
file_handler.setFormatter(formatter)
|
||||||
debug_handler.setFormatter(formatter)
|
|
||||||
stdout_handler.setFormatter(formatter)
|
stdout_handler.setFormatter(formatter)
|
||||||
sdebug_handler.setFormatter(formatter)
|
|
||||||
|
|
||||||
stdout_handler.addFilter(logging.Filter("desktopenv"))
|
stdout_handler.addFilter(logging.Filter("desktopenv"))
|
||||||
sdebug_handler.addFilter(logging.Filter("desktopenv"))
|
|
||||||
|
|
||||||
logger.addHandler(file_handler)
|
logger.addHandler(file_handler)
|
||||||
logger.addHandler(debug_handler)
|
|
||||||
logger.addHandler(stdout_handler)
|
logger.addHandler(stdout_handler)
|
||||||
logger.addHandler(sdebug_handler)
|
|
||||||
# }}} Logger Configs #
|
# }}} Logger Configs #
|
||||||
|
|
||||||
logger = logging.getLogger("desktopenv.experiment")
|
logger = logging.getLogger("desktopenv.experiment")
|
||||||
@@ -85,8 +72,18 @@ def config() -> argparse.Namespace:
|
|||||||
default="a11y_tree",
|
default="a11y_tree",
|
||||||
help="Observation type",
|
help="Observation type",
|
||||||
)
|
)
|
||||||
parser.add_argument("--screen_width", type=int, default=1920)
|
parser.add_argument(
|
||||||
parser.add_argument("--screen_height", type=int, default=1080)
|
"--provider_name", type=str, default="aws", choices=["aws", "virtualbox", "vmware", "docker", "azure"], help="Provider name"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--client_password", type=str, default="", help="Client password"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--screen_width", type=int, default=1920, help="Screen width"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--screen_height", type=int, default=1080, help="Screen height"
|
||||||
|
)
|
||||||
parser.add_argument("--sleep_after_execution", type=float, default=0.0)
|
parser.add_argument("--sleep_after_execution", type=float, default=0.0)
|
||||||
parser.add_argument("--max_steps", type=int, default=15)
|
parser.add_argument("--max_steps", type=int, default=15)
|
||||||
|
|
||||||
@@ -122,7 +119,7 @@ def config() -> argparse.Namespace:
|
|||||||
return args
|
return args
|
||||||
|
|
||||||
|
|
||||||
def distribute_tasks(test_all_meta: dict, num_envs: int) -> List[Dict]:
|
def distribute_tasks(test_all_meta: dict) -> List[tuple]:
|
||||||
"""Distribute tasks evenly across environments."""
|
"""Distribute tasks evenly across environments."""
|
||||||
# Flatten the tasks into a single list
|
# Flatten the tasks into a single list
|
||||||
all_tasks = []
|
all_tasks = []
|
||||||
@@ -130,97 +127,35 @@ def distribute_tasks(test_all_meta: dict, num_envs: int) -> List[Dict]:
|
|||||||
for example_id in examples:
|
for example_id in examples:
|
||||||
all_tasks.append((domain, example_id))
|
all_tasks.append((domain, example_id))
|
||||||
|
|
||||||
# Calculate tasks per environment
|
return all_tasks
|
||||||
tasks_per_env = math.ceil(len(all_tasks) / num_envs)
|
|
||||||
|
|
||||||
# Distribute tasks
|
|
||||||
distributed_tasks = []
|
|
||||||
for i in range(num_envs):
|
|
||||||
env_tasks = {}
|
|
||||||
start_idx = i * tasks_per_env
|
|
||||||
end_idx = min((i + 1) * tasks_per_env, len(all_tasks))
|
|
||||||
|
|
||||||
for domain, example_id in all_tasks[start_idx:end_idx]:
|
|
||||||
if domain not in env_tasks:
|
|
||||||
env_tasks[domain] = []
|
|
||||||
env_tasks[domain].append(example_id)
|
|
||||||
|
|
||||||
distributed_tasks.append(env_tasks)
|
|
||||||
|
|
||||||
return distributed_tasks
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def run_env_tasks(env_idx: int, env: DesktopEnv, agent: PromptAgent, env_tasks: dict, args: argparse.Namespace, shared_scores: list):
|
def run_env_tasks(task_queue, args, shared_scores):
|
||||||
"""Run tasks for a single environment."""
|
"""Run tasks for a single environment."""
|
||||||
logger.info(f"Executing tasks in environment {env_idx + 1}/{args.num_envs}")
|
active_environments = []
|
||||||
|
env = None
|
||||||
for domain in tqdm(env_tasks, desc=f"Env{env_idx+1}-Domain"):
|
try:
|
||||||
for example_id in tqdm(env_tasks[domain], desc="Example", leave=False):
|
from desktop_env.providers.aws.manager import IMAGE_ID_MAP
|
||||||
config_file = os.path.join(
|
REGION = args.region
|
||||||
args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
|
screen_size = (args.screen_width, args.screen_height)
|
||||||
)
|
ami_id = IMAGE_ID_MAP[REGION].get(screen_size, IMAGE_ID_MAP[REGION][(1920, 1080)])
|
||||||
with open(config_file, "r", encoding="utf-8") as f:
|
env = DesktopEnv(
|
||||||
example = json.load(f)
|
path_to_vm=args.path_to_vm,
|
||||||
|
action_space=args.action_space,
|
||||||
logger.info(f"[Env {env_idx+1}][Domain]: {domain}")
|
provider_name=args.provider_name,
|
||||||
logger.info(f"[Env {env_idx+1}][Example ID]: {example_id}")
|
region=REGION,
|
||||||
logger.info(f"[Env {env_idx+1}][Instruction]: {example['instruction']}")
|
snapshot_name=ami_id,
|
||||||
|
screen_size=screen_size,
|
||||||
example_result_dir = os.path.join(
|
headless=args.headless,
|
||||||
args.result_dir,
|
os_type="Ubuntu",
|
||||||
args.action_space,
|
require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
|
||||||
args.observation_type,
|
enable_proxy=True,
|
||||||
args.model,
|
client_password=args.client_password
|
||||||
domain,
|
)
|
||||||
example_id,
|
active_environments.append(env)
|
||||||
)
|
|
||||||
os.makedirs(example_result_dir, exist_ok=True)
|
|
||||||
|
|
||||||
try:
|
|
||||||
lib_run_single.run_single_example(
|
|
||||||
agent,
|
|
||||||
env,
|
|
||||||
example,
|
|
||||||
args.max_steps,
|
|
||||||
example["instruction"],
|
|
||||||
args,
|
|
||||||
example_result_dir,
|
|
||||||
shared_scores,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
import traceback
|
|
||||||
# logger traceback
|
|
||||||
logger.error(f"Exception in Env{env_idx+1} {domain}/{example_id}: {e}")
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
env.controller.end_recording(
|
|
||||||
os.path.join(example_result_dir, "recording.mp4")
|
|
||||||
)
|
|
||||||
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
|
|
||||||
f.write(
|
|
||||||
json.dumps(
|
|
||||||
{"Error": f"Time limit exceeded in {domain}/{example_id}"}
|
|
||||||
)
|
|
||||||
)
|
|
||||||
f.write("\n")
|
|
||||||
|
|
||||||
env.close()
|
|
||||||
|
|
||||||
|
|
||||||
def test(args: argparse.Namespace, test_all_meta: dict) -> None:
|
|
||||||
logger.info("Args: %s", args)
|
|
||||||
|
|
||||||
distributed_tasks = distribute_tasks(test_all_meta, args.num_envs)
|
|
||||||
|
|
||||||
# First, set up all environments
|
|
||||||
logger.info("Setting up all environments...")
|
|
||||||
envs = []
|
|
||||||
agents = []
|
|
||||||
|
|
||||||
for env_idx in range(args.num_envs):
|
|
||||||
logger.info(f"Setting up environment {env_idx + 1}/{args.num_envs}")
|
|
||||||
|
|
||||||
agent = PromptAgent(
|
agent = PromptAgent(
|
||||||
|
env=env,
|
||||||
model=args.model,
|
model=args.model,
|
||||||
max_tokens=args.max_tokens,
|
max_tokens=args.max_tokens,
|
||||||
top_p=args.top_p,
|
top_p=args.top_p,
|
||||||
@@ -228,50 +163,193 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
|
|||||||
action_space=args.action_space,
|
action_space=args.action_space,
|
||||||
observation_type=args.observation_type,
|
observation_type=args.observation_type,
|
||||||
max_trajectory_length=args.max_trajectory_length,
|
max_trajectory_length=args.max_trajectory_length,
|
||||||
screen_size=(args.screen_width, args.screen_height),
|
client_password=args.client_password,
|
||||||
|
provider_name=args.provider_name,
|
||||||
|
screen_width=args.screen_width,
|
||||||
|
screen_height=args.screen_height
|
||||||
)
|
)
|
||||||
agents.append(agent)
|
logger.info(f"Process {current_process().name} started.")
|
||||||
|
while True:
|
||||||
from desktop_env.providers.aws.manager import IMAGE_ID_MAP
|
try:
|
||||||
REGION = "us-east-1"
|
item = task_queue.get(timeout=5)
|
||||||
env = DesktopEnv(
|
except Exception:
|
||||||
path_to_vm=args.path_to_vm,
|
break
|
||||||
action_space=agent.action_space,
|
domain, example_id = item
|
||||||
|
try:
|
||||||
|
config_file = os.path.join(
|
||||||
|
args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
|
||||||
|
)
|
||||||
|
with open(config_file, "r", encoding="utf-8") as f:
|
||||||
|
example = json.load(f)
|
||||||
|
logger.info(f"[{current_process().name}][Domain]: {domain}")
|
||||||
|
logger.info(f"[{current_process().name}][Example ID]: {example_id}")
|
||||||
|
logger.info(f"[{current_process().name}][Instruction]: {example['instruction']}")
|
||||||
|
example_result_dir = os.path.join(
|
||||||
|
args.result_dir,
|
||||||
|
args.action_space,
|
||||||
|
args.observation_type,
|
||||||
|
args.model,
|
||||||
|
domain,
|
||||||
|
example_id,
|
||||||
|
)
|
||||||
|
os.makedirs(example_result_dir, exist_ok=True)
|
||||||
|
try:
|
||||||
|
lib_run_single.run_single_example(
|
||||||
|
agent,
|
||||||
|
env,
|
||||||
|
example,
|
||||||
|
args.max_steps,
|
||||||
|
example["instruction"],
|
||||||
|
args,
|
||||||
|
example_result_dir,
|
||||||
|
shared_scores,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
import traceback
|
||||||
|
logger.error(f"Exception in {current_process().name} {domain}/{example_id}: {e}")
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
try:
|
||||||
|
env.controller.end_recording(
|
||||||
|
os.path.join(example_result_dir, "recording.mp4")
|
||||||
|
)
|
||||||
|
except Exception as rec_e:
|
||||||
|
logger.error(f"Failed to end recording: {rec_e}")
|
||||||
|
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
|
||||||
|
f.write(
|
||||||
|
json.dumps(
|
||||||
|
{"Error": f"{domain}/{example_id} - {e}"}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
f.write("\n")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Task-level error in {current_process().name}: {e}")
|
||||||
|
import traceback
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Process-level error in {current_process().name}: {e}")
|
||||||
|
import traceback
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
finally:
|
||||||
|
logger.info(f"{current_process().name} cleaning up environment...")
|
||||||
|
try:
|
||||||
|
if env:
|
||||||
|
env.close()
|
||||||
|
logger.info(f"{current_process().name} environment closed successfully")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"{current_process().name} error during environment cleanup: {e}")
|
||||||
|
|
||||||
provider_name="aws",
|
|
||||||
region="us-east-1",
|
def process_signal_handler(signum, frame, env_idx):
|
||||||
snapshot_name=IMAGE_ID_MAP[REGION],
|
logger.info(f"Process {env_idx + 1} received signal {signum}. Shutting down...")
|
||||||
screen_size=(args.screen_width, args.screen_height),
|
local_vars = frame.f_locals
|
||||||
headless=args.headless,
|
active_environments = local_vars.get('active_environments', [])
|
||||||
os_type="Ubuntu",
|
for env in active_environments:
|
||||||
require_a11y_tree=args.observation_type
|
if env is not None:
|
||||||
in ["a11y_tree", "screenshot_a11y_tree", "som"],
|
try:
|
||||||
)
|
logger.info(f"Process {env_idx + 1} closing environment...")
|
||||||
envs.append(env)
|
env.close()
|
||||||
|
logger.info(f"Process {env_idx + 1} environment closed successfully")
|
||||||
logger.info("All environments are ready. Starting parallel task execution...")
|
except Exception as e:
|
||||||
|
logger.error(f"Process {env_idx + 1} error closing environment: {e}")
|
||||||
# Create a shared list for scores across processes
|
logger.info(f"Process {env_idx + 1} shutdown complete. Exiting.")
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
|
||||||
|
def signal_handler(signum, frame):
|
||||||
|
global is_terminating, active_environments, processes
|
||||||
|
if is_terminating:
|
||||||
|
return
|
||||||
|
is_terminating = True
|
||||||
|
logger.info(f"Received signal {signum}. Gracefully shutting down...")
|
||||||
|
for env in active_environments:
|
||||||
|
try:
|
||||||
|
logger.info(f"Closing environment...")
|
||||||
|
env.close()
|
||||||
|
logger.info(f"Environment closed successfully")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error closing environment: {e}")
|
||||||
|
for p in processes:
|
||||||
|
if p.is_alive():
|
||||||
|
try:
|
||||||
|
logger.info(f"Sending termination signal to process {p.name}...")
|
||||||
|
p.terminate()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error sending termination signal to process: {e}")
|
||||||
|
time.sleep(1)
|
||||||
|
for p in processes:
|
||||||
|
if p.is_alive():
|
||||||
|
try:
|
||||||
|
logger.info(f"Forcefully terminating process {p.name}...")
|
||||||
|
import signal as sig
|
||||||
|
os.kill(p.pid, sig.SIGKILL)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error forcefully terminating process: {e}")
|
||||||
|
logger.info("Shutdown complete. Exiting.")
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
|
||||||
|
def test(args: argparse.Namespace, test_all_meta: dict) -> None:
|
||||||
|
global processes
|
||||||
|
logger.info("Args: %s", args)
|
||||||
|
all_tasks = distribute_tasks(test_all_meta)
|
||||||
|
logger.info(f"Total tasks: {len(all_tasks)}")
|
||||||
with Manager() as manager:
|
with Manager() as manager:
|
||||||
shared_scores = manager.list()
|
shared_scores = manager.list()
|
||||||
|
task_queue = manager.Queue()
|
||||||
# Create and start processes for each environment
|
for item in all_tasks:
|
||||||
|
task_queue.put(item)
|
||||||
|
num_envs = args.num_envs
|
||||||
processes = []
|
processes = []
|
||||||
for env_idx, (env, agent, env_tasks) in enumerate(zip(envs, agents, distributed_tasks)):
|
for i in range(num_envs):
|
||||||
p = Process(
|
p = Process(
|
||||||
target=run_env_tasks,
|
target=run_env_tasks,
|
||||||
args=(env_idx, env, agent, env_tasks, args, shared_scores)
|
args=(task_queue, args, shared_scores),
|
||||||
|
name=f"EnvProcess-{i+1}"
|
||||||
)
|
)
|
||||||
processes.append(p)
|
p.daemon = True
|
||||||
p.start()
|
p.start()
|
||||||
|
processes.append(p)
|
||||||
# Wait for all processes to complete
|
logger.info(f"Started process {p.name} with PID {p.pid}")
|
||||||
for p in processes:
|
try:
|
||||||
p.join()
|
while True:
|
||||||
|
alive_count = 0
|
||||||
# Convert shared list to regular list
|
for idx, p in enumerate(processes):
|
||||||
|
if not p.is_alive():
|
||||||
|
logger.warning(f"Process {p.name} died, restarting...")
|
||||||
|
new_p = Process(
|
||||||
|
target=run_env_tasks,
|
||||||
|
args=(task_queue, args, shared_scores),
|
||||||
|
name=f"EnvProcess-Restart-{idx+1}"
|
||||||
|
)
|
||||||
|
new_p.daemon = True
|
||||||
|
new_p.start()
|
||||||
|
processes[idx] = new_p
|
||||||
|
logger.info(f"Restarted process {new_p.name} with PID {new_p.pid}")
|
||||||
|
else:
|
||||||
|
alive_count += 1
|
||||||
|
if task_queue.empty():
|
||||||
|
logger.info("All tasks finished.")
|
||||||
|
break
|
||||||
|
if alive_count == 0:
|
||||||
|
logger.error("All processes died, exiting.")
|
||||||
|
break
|
||||||
|
time.sleep(5)
|
||||||
|
for p in processes:
|
||||||
|
p.join()
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
logger.info("Main process received KeyboardInterrupt. Initiating graceful shutdown...")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Unexpected error while waiting for processes: {e}", exc_info=True)
|
||||||
|
for p in processes:
|
||||||
|
if p.is_alive():
|
||||||
|
try:
|
||||||
|
logger.info(f"Terminating process {p.name} due to error...")
|
||||||
|
p.terminate()
|
||||||
|
except Exception as term_e:
|
||||||
|
logger.error(f"Error terminating process {p.name}: {term_e}")
|
||||||
|
raise
|
||||||
scores = list(shared_scores)
|
scores = list(shared_scores)
|
||||||
|
|
||||||
logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")
|
logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")
|
||||||
|
|
||||||
|
|
||||||
@@ -350,46 +428,43 @@ def get_result(action_space, use_model, observation_type, result_dir, total_file
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
####### The complete version of the list of examples #######
|
####### The complete version of the list of examples #######
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
import signal
|
||||||
|
import time
|
||||||
|
signal.signal(signal.SIGINT, signal_handler)
|
||||||
|
signal.signal(signal.SIGTERM, signal_handler)
|
||||||
|
try:
|
||||||
|
args = config()
|
||||||
|
|
||||||
args = config()
|
with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
|
||||||
# save args to json in result_dir/action_space/observation_type/model/args.json
|
test_all_meta = json.load(f)
|
||||||
path_to_args = os.path.join(
|
|
||||||
args.result_dir,
|
|
||||||
args.action_space,
|
|
||||||
args.observation_type,
|
|
||||||
args.model,
|
|
||||||
"args.json",
|
|
||||||
)
|
|
||||||
os.makedirs(os.path.dirname(path_to_args), exist_ok=True)
|
|
||||||
with open(path_to_args, "w", encoding="utf-8") as f:
|
|
||||||
json.dump(vars(args), f, indent=4)
|
|
||||||
|
|
||||||
with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
|
if args.domain != "all":
|
||||||
test_all_meta = json.load(f)
|
test_all_meta = {args.domain: test_all_meta[args.domain]}
|
||||||
|
|
||||||
if args.domain != "all":
|
test_file_list = get_unfinished(
|
||||||
test_all_meta = {args.domain: test_all_meta[args.domain]}
|
args.action_space,
|
||||||
|
args.model,
|
||||||
|
args.observation_type,
|
||||||
|
args.result_dir,
|
||||||
|
test_all_meta,
|
||||||
|
)
|
||||||
|
left_info = ""
|
||||||
|
for domain in test_file_list:
|
||||||
|
left_info += f"{domain}: {len(test_file_list[domain])}\n"
|
||||||
|
logger.info(f"Left tasks:\n{left_info}")
|
||||||
|
|
||||||
test_file_list = get_unfinished(
|
get_result(
|
||||||
args.action_space,
|
args.action_space,
|
||||||
args.model,
|
args.model,
|
||||||
args.observation_type,
|
args.observation_type,
|
||||||
args.result_dir,
|
args.result_dir,
|
||||||
test_all_meta,
|
test_all_meta,
|
||||||
)
|
)
|
||||||
left_info = ""
|
test(args, test_file_list)
|
||||||
for domain in test_file_list:
|
except KeyboardInterrupt:
|
||||||
left_info += f"{domain}: {len(test_file_list[domain])}\n"
|
logger.info("Main process received KeyboardInterrupt.")
|
||||||
logger.info(f"Left tasks:\n{left_info}")
|
except Exception as e:
|
||||||
|
logger.error(f"Unexpected error in main process: {e}", exc_info=True)
|
||||||
get_result(
|
signal_handler(signal.SIGTERM, None)
|
||||||
args.action_space,
|
finally:
|
||||||
args.model,
|
logger.info("Main process final cleanup...")
|
||||||
args.observation_type,
|
|
||||||
args.result_dir,
|
|
||||||
test_all_meta,
|
|
||||||
)
|
|
||||||
test(args, test_file_list)
|
|
||||||
|
|
||||||
|
|
||||||
# path_to_vm can be a list["xxx","xxx"]
|
|
||||||
Reference in New Issue
Block a user