update claude (#280)

* add uitars agent code

* improve claude

* improve claude

* improve claude

* improve claude

* improve claude
This commit is contained in:
Yuan Mengqi
2025-07-23 03:35:49 +08:00
committed by GitHub
parent 53fb96298a
commit 0a37cccd53
3 changed files with 472 additions and 227 deletions

View File

@@ -32,6 +32,8 @@ def run_single_example(agent, env, example, max_steps, instruction, args, exampl
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S") action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
logger.info("Step %d: %s", step_idx + 1, action) logger.info("Step %d: %s", step_idx + 1, action)
obs, reward, done, info = env.step(action, args.sleep_after_execution) obs, reward, done, info = env.step(action, args.sleep_after_execution)
time.sleep(3)
obs = env._get_obs()
logger.info("Reward: %.2f", reward) logger.info("Reward: %.2f", reward)
logger.info("Done: %s", done) logger.info("Done: %s", done)

View File

@@ -23,6 +23,10 @@ from .utils import _response_to_params, _inject_prompt_caching, _maybe_filter_to
import logging import logging
logger = logging.getLogger("desktopenv.agent") logger = logging.getLogger("desktopenv.agent")
# MAX_HISTORY = 10
API_RETRY_TIMES = 500
API_RETRY_INTERVAL = 5
class AnthropicAgent: class AnthropicAgent:
def __init__(self, def __init__(self,
platform: str = "Ubuntu", platform: str = "Ubuntu",
@@ -107,9 +111,24 @@ class AnthropicAgent:
int(coordinate[0] * self.resize_factor[0]), int(coordinate[0] * self.resize_factor[0]),
int(coordinate[1] * self.resize_factor[1]) int(coordinate[1] * self.resize_factor[1])
) )
if action == "left_mouse_down":
result += "pyautogui.mouseDown()\n"
elif action == "left_mouse_up":
result += "pyautogui.mouseUp()\n"
elif action == "hold_key":
if not isinstance(text, str):
raise ValueError(f"{text} must be a string")
keys = text.split('+')
for key in keys:
key = key.strip().lower()
result += f"pyautogui.keyDown('{key}')\n"
expected_outcome = f"Keys {text} held down."
# Handle mouse move and drag actions # Handle mouse move and drag actions
if action in ("mouse_move", "left_click_drag"): elif action in ("mouse_move", "left_click_drag"):
if coordinate is None: if coordinate is None:
raise ValueError(f"coordinate is required for {action}") raise ValueError(f"coordinate is required for {action}")
if text is not None: if text is not None:
@@ -189,7 +208,7 @@ class AnthropicAgent:
expected_outcome = "Scroll action finished" expected_outcome = "Scroll action finished"
# Handle click actions # Handle click actions
elif action in ("left_click", "right_click", "double_click", "middle_click", "left_press"): elif action in ("left_click", "right_click", "double_click", "middle_click", "left_press", "triple_click"):
if coordinate is not None: if coordinate is not None:
x, y = coordinate x, y = coordinate
if action == "left_click": if action == "left_click":
@@ -204,6 +223,9 @@ class AnthropicAgent:
result += (f"pyautogui.mouseDown({x}, {y})\n") result += (f"pyautogui.mouseDown({x}, {y})\n")
result += ("time.sleep(1)\n") result += ("time.sleep(1)\n")
result += (f"pyautogui.mouseUp({x}, {y})\n") result += (f"pyautogui.mouseUp({x}, {y})\n")
elif action == "triple_click":
result += (f"pyautogui.tripleClick({x}, {y})\n")
else: else:
if action == "left_click": if action == "left_click":
result += ("pyautogui.click()\n") result += ("pyautogui.click()\n")
@@ -217,6 +239,8 @@ class AnthropicAgent:
result += ("pyautogui.mouseDown()\n") result += ("pyautogui.mouseDown()\n")
result += ("time.sleep(1)\n") result += ("time.sleep(1)\n")
result += ("pyautogui.mouseUp()\n") result += ("pyautogui.mouseUp()\n")
elif action == "triple_click":
result += ("pyautogui.tripleClick()\n")
expected_outcome = "Click action finished" expected_outcome = "Click action finished"
elif action == "wait": elif action == "wait":
@@ -239,6 +263,54 @@ class AnthropicAgent:
return result return result
def _trim_history(self, max_rounds=4):
    """Drop screenshots from the older part of ``self.messages``.

    The first message and the last ``max_rounds * 2`` messages are kept
    untouched. Every message in between is kept as well, except that in
    user messages the ``image`` items nested inside ``tool_result``
    blocks are removed. ``self.messages`` is rebound to a new list; the
    original message dicts are not mutated.

    Args:
        max_rounds: Number of most recent rounds left unmodified. A round
            is counted as two messages (presumably one assistant turn
            plus the user/tool-result reply -- confirm against caller).

    Raises:
        ValueError: If ``max_rounds`` is negative.
    """
    if max_rounds < 0:
        raise ValueError(f"max_rounds must be non-negative, got {max_rounds}")

    messages = self.messages
    if not messages or len(messages) <= 1:
        return

    recent_count = max_rounds * 2
    # Nothing is old enough to be trimmed yet.
    if len(messages) <= recent_count:
        return

    # Split by explicit index instead of ``messages[-recent_count:]``:
    # with recent_count == 0 the negative slice would return the whole list
    # and duplicate the history.
    split_at = len(messages) - recent_count
    older = messages[1:split_at]
    recent = messages[split_at:]

    trimmed = []
    for message in older:
        # Non-user messages and messages without content pass through as-is.
        if message["role"] != "user" or "content" not in message:
            trimmed.append(message)
            continue

        content = message["content"]
        # Content that is not a list of blocks (e.g. a plain string) holds
        # no image blocks to remove, so keep the message unchanged.
        if not isinstance(content, list):
            trimmed.append(message)
            continue

        filtered_content = []
        for block in content:
            is_tool_result = isinstance(block, dict) and block.get("type") == "tool_result"
            if not is_tool_result:
                filtered_content.append(block)
                continue

            inner = block.get("content", [])
            if isinstance(inner, list):
                inner = [
                    item for item in inner
                    if not (isinstance(item, dict) and item.get("type") == "image")
                ]
            rebuilt = {
                "type": block.get("type"),
                "tool_use_id": block.get("tool_use_id"),
                "content": inner,
            }
            # Keep the error flag so a failed tool call is not turned into
            # an apparently successful one by the rebuild.
            if "is_error" in block:
                rebuilt["is_error"] = block["is_error"]
            filtered_content.append(rebuilt)

        # A user message whose content list ends up empty is dropped.
        if filtered_content:
            trimmed.append({
                "role": message["role"],
                "content": filtered_content,
            })

    self.messages = messages[0:1] + trimmed + recent
def predict(self, task_instruction: str, obs: Dict = None, system: Any = None): def predict(self, task_instruction: str, obs: Dict = None, system: Any = None):
system = BetaTextBlockParam( system = BetaTextBlockParam(
type="text", type="text",
@@ -326,8 +398,10 @@ class AnthropicAgent:
min_removal_threshold=image_truncation_threshold, min_removal_threshold=image_truncation_threshold,
) )
try:
#self._trim_history(max_rounds=MAX_HISTORY)
try:
if self.model_name == "claude-3-5-sonnet-20241022": if self.model_name == "claude-3-5-sonnet-20241022":
tools = [ tools = [
{'name': 'computer', 'type': 'computer_20241022', 'display_width_px': 1280, 'display_height_px': 720, 'display_number': 1}, {'name': 'computer', 'type': 'computer_20241022', 'display_width_px': 1280, 'display_height_px': 720, 'display_number': 1},
@@ -336,7 +410,7 @@ class AnthropicAgent:
] if self.platform == 'Ubuntu' else [ ] if self.platform == 'Ubuntu' else [
{'name': 'computer', 'type': 'computer_20241022', 'display_width_px': 1280, 'display_height_px': 720, 'display_number': 1}, {'name': 'computer', 'type': 'computer_20241022', 'display_width_px': 1280, 'display_height_px': 720, 'display_number': 1},
] ]
elif self.model_name == "claude-3-7-sonnet-20250219" or self.model_name == "claude-4-opus-20250514" or self.model_name == "claude-4-sonnet-20250514": elif self.model_name in ["claude-3-7-sonnet-20250219", "claude-4-opus-20250514", "claude-4-sonnet-20250514"]:
tools = [ tools = [
{'name': 'computer', 'type': 'computer_20250124', 'display_width_px': 1280, 'display_height_px': 720, 'display_number': 1}, {'name': 'computer', 'type': 'computer_20250124', 'display_width_px': 1280, 'display_height_px': 720, 'display_number': 1},
# {'type': 'bash_20250124', 'name': 'bash'}, # {'type': 'bash_20250124', 'name': 'bash'},
@@ -348,25 +422,54 @@ class AnthropicAgent:
"thinking": {"type": "enabled", "budget_tokens": 1024} "thinking": {"type": "enabled", "budget_tokens": 1024}
} }
response = None response = None
if self.model_name == "claude-3-7-sonnet-20250219" or self.model_name == "claude-4-opus-20250514" or self.model_name == "claude-4-sonnet-20250514":
response = client.beta.messages.create( for attempt in range(API_RETRY_TIMES):
max_tokens=self.max_tokens, try:
messages=self.messages, if self.model_name in ["claude-3-7-sonnet-20250219", "claude-4-opus-20250514", "claude-4-sonnet-20250514"]:
model=PROVIDER_TO_DEFAULT_MODEL_NAME[self.provider, self.model_name], response = client.beta.messages.create(
system=[system], max_tokens=self.max_tokens,
tools=tools, messages=self.messages,
betas=betas, model=PROVIDER_TO_DEFAULT_MODEL_NAME[self.provider, self.model_name],
extra_body=extra_body system=[system],
) tools=tools,
elif self.model_name == "claude-3-5-sonnet-20241022": betas=betas,
response = client.beta.messages.create( extra_body=extra_body
max_tokens=self.max_tokens, )
messages=self.messages, elif self.model_name == "claude-3-5-sonnet-20241022":
model=PROVIDER_TO_DEFAULT_MODEL_NAME[self.provider, self.model_name], response = client.beta.messages.create(
system=[system], max_tokens=self.max_tokens,
tools=tools, messages=self.messages,
betas=betas, model=PROVIDER_TO_DEFAULT_MODEL_NAME[self.provider, self.model_name],
) system=[system],
tools=tools,
betas=betas,
)
logger.info(f"Response: {response}")
break # 成功则跳出重试循环
except (APIError, APIStatusError, APIResponseValidationError) as e:
error_msg = str(e)
logger.warning(f"Anthropic API error (attempt {attempt+1}/{API_RETRY_TIMES}): {error_msg}")
# 检查是否是25MB限制错误
if "25000000" in error_msg or "Member must have length less than or equal to" in error_msg:
logger.warning("检测到25MB限制错误自动裁剪图片数量")
# 将图片数量减半
current_image_count = self.only_n_most_recent_images
new_image_count = max(1, current_image_count // 2) # 至少保留1张图片
self.only_n_most_recent_images = new_image_count
# 重新应用图片过滤
_maybe_filter_to_n_most_recent_images(
self.messages,
new_image_count,
min_removal_threshold=image_truncation_threshold,
)
logger.info(f"图片数量已从 {current_image_count} 减少到 {new_image_count}")
if attempt < API_RETRY_TIMES - 1:
time.sleep(API_RETRY_INTERVAL)
else:
raise # 全部失败后抛出异常进入原有except逻辑
except (APIError, APIStatusError, APIResponseValidationError) as e: except (APIError, APIStatusError, APIResponseValidationError) as e:
logger.exception(f"Anthropic API error: {str(e)}") logger.exception(f"Anthropic API error: {str(e)}")
@@ -374,8 +477,7 @@ class AnthropicAgent:
logger.warning("Retrying with backup API key...") logger.warning("Retrying with backup API key...")
backup_client = Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY_BACKUP"), max_retries=4) backup_client = Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY_BACKUP"), max_retries=4)
if self.model_name in ["claude-3-7-sonnet-20250219", "claude-4-opus-20250514", "claude-4-sonnet-20250514"]:
if self.model_name == "claude-3-7-sonnet-20250219" or self.model_name == "claude-4-opus-20250514" or self.model_name == "claude-4-sonnet-20250514":
response = backup_client.beta.messages.create( response = backup_client.beta.messages.create(
max_tokens=self.max_tokens, max_tokens=self.max_tokens,
messages=self.messages, messages=self.messages,
@@ -396,7 +498,25 @@ class AnthropicAgent:
) )
logger.info("Successfully used backup API key") logger.info("Successfully used backup API key")
except Exception as backup_e: except Exception as backup_e:
logger.exception(f"Backup API call also failed: {str(backup_e)}") backup_error_msg = str(backup_e)
logger.exception(f"Backup API call also failed: {backup_error_msg}")
# 检查备用API是否也是25MB限制错误
if "25000000" in backup_error_msg or "Member must have length less than or equal to" in backup_error_msg:
logger.warning("备用API也遇到25MB限制错误进一步裁剪图片数量")
# 将图片数量再减半
current_image_count = self.only_n_most_recent_images
new_image_count = max(1, current_image_count // 2) # 至少保留1张图片
self.only_n_most_recent_images = new_image_count
# 重新应用图片过滤
_maybe_filter_to_n_most_recent_images(
self.messages,
new_image_count,
min_removal_threshold=image_truncation_threshold,
)
logger.info(f"备用API图片数量已从 {current_image_count} 减少到 {new_image_count}")
return None, None return None, None
except Exception as e: except Exception as e:
@@ -412,29 +532,77 @@ class AnthropicAgent:
"content": response_params "content": response_params
}) })
actions: list[Any] = [] max_parse_retry = 3
reasonings: list[str] = [] for parse_retry in range(max_parse_retry):
for content_block in response_params: actions: list[Any] = []
if content_block["type"] == "tool_use": reasonings: list[str] = []
actions.append({ try:
"name": content_block["name"], for content_block in response_params:
"input": cast(dict[str, Any], content_block["input"]), if content_block["type"] == "tool_use":
"id": content_block["id"], actions.append({
"action_type": content_block.get("type"), "name": content_block["name"],
"command": self.parse_actions_from_tool_call(content_block) "input": cast(dict[str, Any], content_block["input"]),
"id": content_block["id"],
"action_type": content_block.get("type"),
"command": self.parse_actions_from_tool_call(content_block)
})
elif content_block["type"] == "text":
reasonings.append(content_block["text"])
if isinstance(reasonings, list) and len(reasonings) > 0:
reasonings = reasonings[0]
else:
reasonings = ""
logger.info(f"Received actions: {actions}")
logger.info(f"Received reasonings: {reasonings}")
if len(actions) == 0:
actions = ["DONE"]
return reasonings, actions
except Exception as e:
logger.warning(f"parse_actions_from_tool_call解析失败{parse_retry+1}/3次将重新请求API: {e}")
# 删除刚刚append的assistant消息避免污染history
self.messages.pop()
# 重新请求API
response = None
for attempt in range(API_RETRY_TIMES):
try:
if self.model_name in ["claude-3-7-sonnet-20250219", "claude-4-opus-20250514", "claude-4-sonnet-20250514"]:
response = client.beta.messages.create(
max_tokens=self.max_tokens,
messages=self.messages,
model=PROVIDER_TO_DEFAULT_MODEL_NAME[self.provider, self.model_name],
system=[system],
tools=tools,
betas=betas,
extra_body=extra_body
)
elif self.model_name == "claude-3-5-sonnet-20241022":
response = client.beta.messages.create(
max_tokens=self.max_tokens,
messages=self.messages,
model=PROVIDER_TO_DEFAULT_MODEL_NAME[self.provider, self.model_name],
system=[system],
tools=tools,
betas=betas,
)
logger.info(f"Response: {response}")
break # 成功则跳出重试循环
except (APIError, APIStatusError, APIResponseValidationError) as e2:
error_msg = str(e2)
logger.warning(f"Anthropic API error (attempt {attempt+1}/{API_RETRY_TIMES}): {error_msg}")
if attempt < API_RETRY_TIMES - 1:
time.sleep(API_RETRY_INTERVAL)
else:
raise
response_params = _response_to_params(response)
logger.info(f"Received response params: {response_params}")
self.messages.append({
"role": "assistant",
"content": response_params
}) })
elif content_block["type"] == "text": if parse_retry == max_parse_retry - 1:
reasonings.append(content_block["text"]) logger.error(f"连续3次parse_actions_from_tool_call解析失败终止: {e}")
if isinstance(reasonings, list) and len(reasonings) > 0: actions = ["FAIL"]
reasonings = reasonings[0] return reasonings, actions
else:
reasonings = ""
logger.info(f"Received actions: {actions}")
logger.info(f"Received reasonings: {reasonings}")
if len(actions) == 0:
actions = ["DONE"]
return reasonings, actions
def reset(self, _logger = None, *args, **kwargs): def reset(self, _logger = None, *args, **kwargs):
""" """
Reset the agent's state. Reset the agent's state.

View File

@@ -11,9 +11,9 @@ import sys
from typing import List, Dict from typing import List, Dict
import math import math
from tqdm import tqdm from tqdm import tqdm
from multiprocessing import Process, Manager from multiprocessing import Process, Manager, current_process
# import lib_run_single import lib_run_single
# from desktop_env.desktop_env import DesktopEnv from desktop_env.desktop_env import DesktopEnv
from mm_agents.anthropic import AnthropicAgent as PromptAgent from mm_agents.anthropic import AnthropicAgent as PromptAgent
import fake_run_single as lib_run_single import fake_run_single as lib_run_single
@@ -26,41 +26,28 @@ load_dotenv()
# Logger Configs {{{ # # Logger Configs {{{ #
logger = logging.getLogger() logger = logging.getLogger()
logger.setLevel(logging.DEBUG) logger.setLevel(logging.INFO)
datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S") datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
file_handler = logging.FileHandler( file_handler = logging.FileHandler(
os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8" os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8"
) )
debug_handler = logging.FileHandler(
os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8"
)
stdout_handler = logging.StreamHandler(sys.stdout) stdout_handler = logging.StreamHandler(sys.stdout)
sdebug_handler = logging.FileHandler(
os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8"
)
file_handler.setLevel(logging.INFO) file_handler.setLevel(logging.INFO)
debug_handler.setLevel(logging.DEBUG)
stdout_handler.setLevel(logging.INFO) stdout_handler.setLevel(logging.INFO)
sdebug_handler.setLevel(logging.DEBUG)
formatter = logging.Formatter( formatter = logging.Formatter(
fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s" fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
) )
file_handler.setFormatter(formatter) file_handler.setFormatter(formatter)
debug_handler.setFormatter(formatter)
stdout_handler.setFormatter(formatter) stdout_handler.setFormatter(formatter)
sdebug_handler.setFormatter(formatter)
stdout_handler.addFilter(logging.Filter("desktopenv")) stdout_handler.addFilter(logging.Filter("desktopenv"))
sdebug_handler.addFilter(logging.Filter("desktopenv"))
logger.addHandler(file_handler) logger.addHandler(file_handler)
logger.addHandler(debug_handler)
logger.addHandler(stdout_handler) logger.addHandler(stdout_handler)
logger.addHandler(sdebug_handler)
# }}} Logger Configs # # }}} Logger Configs #
logger = logging.getLogger("desktopenv.experiment") logger = logging.getLogger("desktopenv.experiment")
@@ -85,8 +72,18 @@ def config() -> argparse.Namespace:
default="a11y_tree", default="a11y_tree",
help="Observation type", help="Observation type",
) )
parser.add_argument("--screen_width", type=int, default=1920) parser.add_argument(
parser.add_argument("--screen_height", type=int, default=1080) "--provider_name", type=str, default="aws", choices=["aws", "virtualbox", "vmware", "docker", "azure"], help="Provider name"
)
parser.add_argument(
"--client_password", type=str, default="", help="Client password"
)
parser.add_argument(
"--screen_width", type=int, default=1920, help="Screen width"
)
parser.add_argument(
"--screen_height", type=int, default=1080, help="Screen height"
)
parser.add_argument("--sleep_after_execution", type=float, default=0.0) parser.add_argument("--sleep_after_execution", type=float, default=0.0)
parser.add_argument("--max_steps", type=int, default=15) parser.add_argument("--max_steps", type=int, default=15)
@@ -122,7 +119,7 @@ def config() -> argparse.Namespace:
return args return args
def distribute_tasks(test_all_meta: dict, num_envs: int) -> List[Dict]: def distribute_tasks(test_all_meta: dict) -> List[tuple]:
"""Distribute tasks evenly across environments.""" """Distribute tasks evenly across environments."""
# Flatten the tasks into a single list # Flatten the tasks into a single list
all_tasks = [] all_tasks = []
@@ -130,97 +127,35 @@ def distribute_tasks(test_all_meta: dict, num_envs: int) -> List[Dict]:
for example_id in examples: for example_id in examples:
all_tasks.append((domain, example_id)) all_tasks.append((domain, example_id))
# Calculate tasks per environment return all_tasks
tasks_per_env = math.ceil(len(all_tasks) / num_envs)
# Distribute tasks
distributed_tasks = []
for i in range(num_envs):
env_tasks = {}
start_idx = i * tasks_per_env
end_idx = min((i + 1) * tasks_per_env, len(all_tasks))
for domain, example_id in all_tasks[start_idx:end_idx]:
if domain not in env_tasks:
env_tasks[domain] = []
env_tasks[domain].append(example_id)
distributed_tasks.append(env_tasks)
return distributed_tasks
def run_env_tasks(env_idx: int, env: DesktopEnv, agent: PromptAgent, env_tasks: dict, args: argparse.Namespace, shared_scores: list): def run_env_tasks(task_queue, args, shared_scores):
"""Run tasks for a single environment.""" """Run tasks for a single environment."""
logger.info(f"Executing tasks in environment {env_idx + 1}/{args.num_envs}") active_environments = []
env = None
for domain in tqdm(env_tasks, desc=f"Env{env_idx+1}-Domain"): try:
for example_id in tqdm(env_tasks[domain], desc="Example", leave=False): from desktop_env.providers.aws.manager import IMAGE_ID_MAP
config_file = os.path.join( REGION = args.region
args.test_config_base_dir, f"examples/{domain}/{example_id}.json" screen_size = (args.screen_width, args.screen_height)
) ami_id = IMAGE_ID_MAP[REGION].get(screen_size, IMAGE_ID_MAP[REGION][(1920, 1080)])
with open(config_file, "r", encoding="utf-8") as f: env = DesktopEnv(
example = json.load(f) path_to_vm=args.path_to_vm,
action_space=args.action_space,
logger.info(f"[Env {env_idx+1}][Domain]: {domain}") provider_name=args.provider_name,
logger.info(f"[Env {env_idx+1}][Example ID]: {example_id}") region=REGION,
logger.info(f"[Env {env_idx+1}][Instruction]: {example['instruction']}") snapshot_name=ami_id,
screen_size=screen_size,
example_result_dir = os.path.join( headless=args.headless,
args.result_dir, os_type="Ubuntu",
args.action_space, require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
args.observation_type, enable_proxy=True,
args.model, client_password=args.client_password
domain, )
example_id, active_environments.append(env)
)
os.makedirs(example_result_dir, exist_ok=True)
try:
lib_run_single.run_single_example(
agent,
env,
example,
args.max_steps,
example["instruction"],
args,
example_result_dir,
shared_scores,
)
except Exception as e:
import traceback
# logger traceback
logger.error(f"Exception in Env{env_idx+1} {domain}/{example_id}: {e}")
logger.error(traceback.format_exc())
env.controller.end_recording(
os.path.join(example_result_dir, "recording.mp4")
)
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
f.write(
json.dumps(
{"Error": f"Time limit exceeded in {domain}/{example_id}"}
)
)
f.write("\n")
env.close()
def test(args: argparse.Namespace, test_all_meta: dict) -> None:
logger.info("Args: %s", args)
distributed_tasks = distribute_tasks(test_all_meta, args.num_envs)
# First, set up all environments
logger.info("Setting up all environments...")
envs = []
agents = []
for env_idx in range(args.num_envs):
logger.info(f"Setting up environment {env_idx + 1}/{args.num_envs}")
agent = PromptAgent( agent = PromptAgent(
env=env,
model=args.model, model=args.model,
max_tokens=args.max_tokens, max_tokens=args.max_tokens,
top_p=args.top_p, top_p=args.top_p,
@@ -228,50 +163,193 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
action_space=args.action_space, action_space=args.action_space,
observation_type=args.observation_type, observation_type=args.observation_type,
max_trajectory_length=args.max_trajectory_length, max_trajectory_length=args.max_trajectory_length,
screen_size=(args.screen_width, args.screen_height), client_password=args.client_password,
provider_name=args.provider_name,
screen_width=args.screen_width,
screen_height=args.screen_height
) )
agents.append(agent) logger.info(f"Process {current_process().name} started.")
while True:
from desktop_env.providers.aws.manager import IMAGE_ID_MAP try:
REGION = "us-east-1" item = task_queue.get(timeout=5)
env = DesktopEnv( except Exception:
path_to_vm=args.path_to_vm, break
action_space=agent.action_space, domain, example_id = item
try:
config_file = os.path.join(
args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
)
with open(config_file, "r", encoding="utf-8") as f:
example = json.load(f)
logger.info(f"[{current_process().name}][Domain]: {domain}")
logger.info(f"[{current_process().name}][Example ID]: {example_id}")
logger.info(f"[{current_process().name}][Instruction]: {example['instruction']}")
example_result_dir = os.path.join(
args.result_dir,
args.action_space,
args.observation_type,
args.model,
domain,
example_id,
)
os.makedirs(example_result_dir, exist_ok=True)
try:
lib_run_single.run_single_example(
agent,
env,
example,
args.max_steps,
example["instruction"],
args,
example_result_dir,
shared_scores,
)
except Exception as e:
import traceback
logger.error(f"Exception in {current_process().name} {domain}/{example_id}: {e}")
logger.error(traceback.format_exc())
try:
env.controller.end_recording(
os.path.join(example_result_dir, "recording.mp4")
)
except Exception as rec_e:
logger.error(f"Failed to end recording: {rec_e}")
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
f.write(
json.dumps(
{"Error": f"{domain}/{example_id} - {e}"}
)
)
f.write("\n")
except Exception as e:
logger.error(f"Task-level error in {current_process().name}: {e}")
import traceback
logger.error(traceback.format_exc())
except Exception as e:
logger.error(f"Process-level error in {current_process().name}: {e}")
import traceback
logger.error(traceback.format_exc())
finally:
logger.info(f"{current_process().name} cleaning up environment...")
try:
if env:
env.close()
logger.info(f"{current_process().name} environment closed successfully")
except Exception as e:
logger.error(f"{current_process().name} error during environment cleanup: {e}")
provider_name="aws",
region="us-east-1", def process_signal_handler(signum, frame, env_idx):
snapshot_name=IMAGE_ID_MAP[REGION], logger.info(f"Process {env_idx + 1} received signal {signum}. Shutting down...")
screen_size=(args.screen_width, args.screen_height), local_vars = frame.f_locals
headless=args.headless, active_environments = local_vars.get('active_environments', [])
os_type="Ubuntu", for env in active_environments:
require_a11y_tree=args.observation_type if env is not None:
in ["a11y_tree", "screenshot_a11y_tree", "som"], try:
) logger.info(f"Process {env_idx + 1} closing environment...")
envs.append(env) env.close()
logger.info(f"Process {env_idx + 1} environment closed successfully")
logger.info("All environments are ready. Starting parallel task execution...") except Exception as e:
logger.error(f"Process {env_idx + 1} error closing environment: {e}")
# Create a shared list for scores across processes logger.info(f"Process {env_idx + 1} shutdown complete. Exiting.")
sys.exit(0)
def signal_handler(signum, frame):
    """Shut down environments and worker processes, then exit with status 0.

    Installed for SIGINT/SIGTERM by the script entry point and also called
    directly from its error path. Steps, in order: close every environment
    in ``active_environments``, ask every live process in ``processes`` to
    terminate, wait one second, kill whatever is still alive, and finally
    call ``sys.exit(0)``. Failures in any individual step are logged and do
    not stop the remaining steps.

    Args:
        signum: Signal number that triggered the handler (only logged).
        frame: Current stack frame supplied by the signal machinery; unused.

    NOTE(review): ``is_terminating`` and ``active_environments`` are read as
    module globals but are not defined in the visible part of the file --
    presumably they are initialised at module level; verify. ``processes``
    is assigned by ``test()``.
    """
    # Imported locally so the handler does not rely on the ``import time``
    # that the visible code performs only inside the ``__main__`` guard.
    import time

    global is_terminating, active_environments, processes

    # Re-entrancy guard: a second signal during shutdown is ignored.
    if is_terminating:
        return
    is_terminating = True

    logger.info(f"Received signal {signum}. Gracefully shutting down...")

    for env in active_environments:
        try:
            logger.info("Closing environment...")
            env.close()
            logger.info("Environment closed successfully")
        except Exception as e:
            logger.error(f"Error closing environment: {e}")

    # First pass: polite termination request.
    for p in processes:
        if p.is_alive():
            try:
                logger.info(f"Sending termination signal to process {p.name}...")
                p.terminate()
            except Exception as e:
                logger.error(f"Error sending termination signal to process: {e}")

    # Grace period before escalating.
    time.sleep(1)

    # Second pass: force-kill survivors. Process.kill() sends SIGKILL on
    # Unix and falls back to terminate() on Windows, where signal.SIGKILL
    # is not available.
    for p in processes:
        if p.is_alive():
            try:
                logger.info(f"Forcefully terminating process {p.name}...")
                p.kill()
            except Exception as e:
                logger.error(f"Error forcefully terminating process: {e}")

    logger.info("Shutdown complete. Exiting.")
    sys.exit(0)
def test(args: argparse.Namespace, test_all_meta: dict) -> None:
global processes
logger.info("Args: %s", args)
all_tasks = distribute_tasks(test_all_meta)
logger.info(f"Total tasks: {len(all_tasks)}")
with Manager() as manager: with Manager() as manager:
shared_scores = manager.list() shared_scores = manager.list()
task_queue = manager.Queue()
# Create and start processes for each environment for item in all_tasks:
task_queue.put(item)
num_envs = args.num_envs
processes = [] processes = []
for env_idx, (env, agent, env_tasks) in enumerate(zip(envs, agents, distributed_tasks)): for i in range(num_envs):
p = Process( p = Process(
target=run_env_tasks, target=run_env_tasks,
args=(env_idx, env, agent, env_tasks, args, shared_scores) args=(task_queue, args, shared_scores),
name=f"EnvProcess-{i+1}"
) )
processes.append(p) p.daemon = True
p.start() p.start()
processes.append(p)
# Wait for all processes to complete logger.info(f"Started process {p.name} with PID {p.pid}")
for p in processes: try:
p.join() while True:
alive_count = 0
# Convert shared list to regular list for idx, p in enumerate(processes):
if not p.is_alive():
logger.warning(f"Process {p.name} died, restarting...")
new_p = Process(
target=run_env_tasks,
args=(task_queue, args, shared_scores),
name=f"EnvProcess-Restart-{idx+1}"
)
new_p.daemon = True
new_p.start()
processes[idx] = new_p
logger.info(f"Restarted process {new_p.name} with PID {new_p.pid}")
else:
alive_count += 1
if task_queue.empty():
logger.info("All tasks finished.")
break
if alive_count == 0:
logger.error("All processes died, exiting.")
break
time.sleep(5)
for p in processes:
p.join()
except KeyboardInterrupt:
logger.info("Main process received KeyboardInterrupt. Initiating graceful shutdown...")
raise
except Exception as e:
logger.error(f"Unexpected error while waiting for processes: {e}", exc_info=True)
for p in processes:
if p.is_alive():
try:
logger.info(f"Terminating process {p.name} due to error...")
p.terminate()
except Exception as term_e:
logger.error(f"Error terminating process {p.name}: {term_e}")
raise
scores = list(shared_scores) scores = list(shared_scores)
logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}") logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")
@@ -350,46 +428,43 @@ def get_result(action_space, use_model, observation_type, result_dir, total_file
if __name__ == "__main__": if __name__ == "__main__":
####### The complete version of the list of examples ####### ####### The complete version of the list of examples #######
os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["TOKENIZERS_PARALLELISM"] = "false"
import signal
import time
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
try:
args = config()
args = config() with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
# save args to json in result_dir/action_space/observation_type/model/args.json test_all_meta = json.load(f)
path_to_args = os.path.join(
args.result_dir,
args.action_space,
args.observation_type,
args.model,
"args.json",
)
os.makedirs(os.path.dirname(path_to_args), exist_ok=True)
with open(path_to_args, "w", encoding="utf-8") as f:
json.dump(vars(args), f, indent=4)
with open(args.test_all_meta_path, "r", encoding="utf-8") as f: if args.domain != "all":
test_all_meta = json.load(f) test_all_meta = {args.domain: test_all_meta[args.domain]}
if args.domain != "all": test_file_list = get_unfinished(
test_all_meta = {args.domain: test_all_meta[args.domain]} args.action_space,
args.model,
args.observation_type,
args.result_dir,
test_all_meta,
)
left_info = ""
for domain in test_file_list:
left_info += f"{domain}: {len(test_file_list[domain])}\n"
logger.info(f"Left tasks:\n{left_info}")
test_file_list = get_unfinished( get_result(
args.action_space, args.action_space,
args.model, args.model,
args.observation_type, args.observation_type,
args.result_dir, args.result_dir,
test_all_meta, test_all_meta,
) )
left_info = "" test(args, test_file_list)
for domain in test_file_list: except KeyboardInterrupt:
left_info += f"{domain}: {len(test_file_list[domain])}\n" logger.info("Main process received KeyboardInterrupt.")
logger.info(f"Left tasks:\n{left_info}") except Exception as e:
logger.error(f"Unexpected error in main process: {e}", exc_info=True)
get_result( signal_handler(signal.SIGTERM, None)
args.action_space, finally:
args.model, logger.info("Main process final cleanup...")
args.observation_type,
args.result_dir,
test_all_meta,
)
test(args, test_file_list)
# path_to_vm can be a list["xxx","xxx"]