update claude (#280)
* add uitars agent code * improve claude * improve claude * improve claude * improve claude * improve claude
This commit is contained in:
@@ -32,6 +32,8 @@ def run_single_example(agent, env, example, max_steps, instruction, args, exampl
|
||||
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
||||
logger.info("Step %d: %s", step_idx + 1, action)
|
||||
obs, reward, done, info = env.step(action, args.sleep_after_execution)
|
||||
time.sleep(3)
|
||||
obs = env._get_obs()
|
||||
|
||||
logger.info("Reward: %.2f", reward)
|
||||
logger.info("Done: %s", done)
|
||||
|
||||
@@ -23,6 +23,10 @@ from .utils import _response_to_params, _inject_prompt_caching, _maybe_filter_to
|
||||
import logging
|
||||
logger = logging.getLogger("desktopenv.agent")
|
||||
|
||||
# MAX_HISTORY = 10
|
||||
API_RETRY_TIMES = 500
|
||||
API_RETRY_INTERVAL = 5
|
||||
|
||||
class AnthropicAgent:
|
||||
def __init__(self,
|
||||
platform: str = "Ubuntu",
|
||||
@@ -107,9 +111,24 @@ class AnthropicAgent:
|
||||
int(coordinate[0] * self.resize_factor[0]),
|
||||
int(coordinate[1] * self.resize_factor[1])
|
||||
)
|
||||
|
||||
if action == "left_mouse_down":
|
||||
result += "pyautogui.mouseDown()\n"
|
||||
elif action == "left_mouse_up":
|
||||
result += "pyautogui.mouseUp()\n"
|
||||
|
||||
elif action == "hold_key":
|
||||
if not isinstance(text, str):
|
||||
raise ValueError(f"{text} must be a string")
|
||||
|
||||
keys = text.split('+')
|
||||
for key in keys:
|
||||
key = key.strip().lower()
|
||||
result += f"pyautogui.keyDown('{key}')\n"
|
||||
expected_outcome = f"Keys {text} held down."
|
||||
|
||||
# Handle mouse move and drag actions
|
||||
if action in ("mouse_move", "left_click_drag"):
|
||||
elif action in ("mouse_move", "left_click_drag"):
|
||||
if coordinate is None:
|
||||
raise ValueError(f"coordinate is required for {action}")
|
||||
if text is not None:
|
||||
@@ -189,7 +208,7 @@ class AnthropicAgent:
|
||||
expected_outcome = "Scroll action finished"
|
||||
|
||||
# Handle click actions
|
||||
elif action in ("left_click", "right_click", "double_click", "middle_click", "left_press"):
|
||||
elif action in ("left_click", "right_click", "double_click", "middle_click", "left_press", "triple_click"):
|
||||
if coordinate is not None:
|
||||
x, y = coordinate
|
||||
if action == "left_click":
|
||||
@@ -204,6 +223,9 @@ class AnthropicAgent:
|
||||
result += (f"pyautogui.mouseDown({x}, {y})\n")
|
||||
result += ("time.sleep(1)\n")
|
||||
result += (f"pyautogui.mouseUp({x}, {y})\n")
|
||||
elif action == "triple_click":
|
||||
result += (f"pyautogui.tripleClick({x}, {y})\n")
|
||||
|
||||
else:
|
||||
if action == "left_click":
|
||||
result += ("pyautogui.click()\n")
|
||||
@@ -217,6 +239,8 @@ class AnthropicAgent:
|
||||
result += ("pyautogui.mouseDown()\n")
|
||||
result += ("time.sleep(1)\n")
|
||||
result += ("pyautogui.mouseUp()\n")
|
||||
elif action == "triple_click":
|
||||
result += ("pyautogui.tripleClick()\n")
|
||||
expected_outcome = "Click action finished"
|
||||
|
||||
elif action == "wait":
|
||||
@@ -239,6 +263,54 @@ class AnthropicAgent:
|
||||
|
||||
return result
|
||||
|
||||
def _trim_history(self, max_rounds=4):
|
||||
|
||||
messages = self.messages
|
||||
if not messages or len(messages) <= 1:
|
||||
return
|
||||
|
||||
# 计算需要保留的最近轮次数
|
||||
actual_max_rounds = max_rounds * 2
|
||||
|
||||
# 如果消息数量不超过限制,不需要处理
|
||||
if len(messages) <= actual_max_rounds:
|
||||
return
|
||||
|
||||
# 保留前3条消息(初始消息)和最近的actual_max_rounds条消息 messages[0:1] + messages[-actual_max_rounds:]
|
||||
keep_messages = []
|
||||
|
||||
# 对于中间被删除的消息,只保留非图片内容
|
||||
for i in range(1, len(messages) - actual_max_rounds):
|
||||
old_message = messages[i]
|
||||
if old_message["role"] == "user" and "content" in old_message:
|
||||
# 过滤掉image类型的内容块,保留其他类型
|
||||
filtered_content = []
|
||||
for content_block in old_message["content"]:
|
||||
filtered_content_item = []
|
||||
if content_block.get("type") == "tool_result":
|
||||
for content_block_item in content_block["content"]:
|
||||
if content_block_item.get("type") != "image":
|
||||
filtered_content_item.append(content_block_item)
|
||||
filtered_content.append({
|
||||
"type": content_block.get("type"),
|
||||
"tool_use_id": content_block.get("tool_use_id"),
|
||||
"content": filtered_content_item
|
||||
})
|
||||
else:
|
||||
filtered_content.append(content_block)
|
||||
|
||||
# 如果过滤后还有内容,则保留这条消息
|
||||
if filtered_content:
|
||||
keep_messages.append({
|
||||
"role": old_message["role"],
|
||||
"content": filtered_content
|
||||
})
|
||||
else:
|
||||
# 非用户消息或没有content的消息直接保留
|
||||
keep_messages.append(old_message)
|
||||
|
||||
self.messages = messages[0:1] + keep_messages + messages[-actual_max_rounds:]
|
||||
|
||||
def predict(self, task_instruction: str, obs: Dict = None, system: Any = None):
|
||||
system = BetaTextBlockParam(
|
||||
type="text",
|
||||
@@ -326,8 +398,10 @@ class AnthropicAgent:
|
||||
min_removal_threshold=image_truncation_threshold,
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
#self._trim_history(max_rounds=MAX_HISTORY)
|
||||
|
||||
try:
|
||||
if self.model_name == "claude-3-5-sonnet-20241022":
|
||||
tools = [
|
||||
{'name': 'computer', 'type': 'computer_20241022', 'display_width_px': 1280, 'display_height_px': 720, 'display_number': 1},
|
||||
@@ -336,7 +410,7 @@ class AnthropicAgent:
|
||||
] if self.platform == 'Ubuntu' else [
|
||||
{'name': 'computer', 'type': 'computer_20241022', 'display_width_px': 1280, 'display_height_px': 720, 'display_number': 1},
|
||||
]
|
||||
elif self.model_name == "claude-3-7-sonnet-20250219" or self.model_name == "claude-4-opus-20250514" or self.model_name == "claude-4-sonnet-20250514":
|
||||
elif self.model_name in ["claude-3-7-sonnet-20250219", "claude-4-opus-20250514", "claude-4-sonnet-20250514"]:
|
||||
tools = [
|
||||
{'name': 'computer', 'type': 'computer_20250124', 'display_width_px': 1280, 'display_height_px': 720, 'display_number': 1},
|
||||
# {'type': 'bash_20250124', 'name': 'bash'},
|
||||
@@ -348,25 +422,54 @@ class AnthropicAgent:
|
||||
"thinking": {"type": "enabled", "budget_tokens": 1024}
|
||||
}
|
||||
response = None
|
||||
if self.model_name == "claude-3-7-sonnet-20250219" or self.model_name == "claude-4-opus-20250514" or self.model_name == "claude-4-sonnet-20250514":
|
||||
response = client.beta.messages.create(
|
||||
max_tokens=self.max_tokens,
|
||||
messages=self.messages,
|
||||
model=PROVIDER_TO_DEFAULT_MODEL_NAME[self.provider, self.model_name],
|
||||
system=[system],
|
||||
tools=tools,
|
||||
betas=betas,
|
||||
extra_body=extra_body
|
||||
)
|
||||
elif self.model_name == "claude-3-5-sonnet-20241022":
|
||||
response = client.beta.messages.create(
|
||||
max_tokens=self.max_tokens,
|
||||
messages=self.messages,
|
||||
model=PROVIDER_TO_DEFAULT_MODEL_NAME[self.provider, self.model_name],
|
||||
system=[system],
|
||||
tools=tools,
|
||||
betas=betas,
|
||||
)
|
||||
|
||||
for attempt in range(API_RETRY_TIMES):
|
||||
try:
|
||||
if self.model_name in ["claude-3-7-sonnet-20250219", "claude-4-opus-20250514", "claude-4-sonnet-20250514"]:
|
||||
response = client.beta.messages.create(
|
||||
max_tokens=self.max_tokens,
|
||||
messages=self.messages,
|
||||
model=PROVIDER_TO_DEFAULT_MODEL_NAME[self.provider, self.model_name],
|
||||
system=[system],
|
||||
tools=tools,
|
||||
betas=betas,
|
||||
extra_body=extra_body
|
||||
)
|
||||
elif self.model_name == "claude-3-5-sonnet-20241022":
|
||||
response = client.beta.messages.create(
|
||||
max_tokens=self.max_tokens,
|
||||
messages=self.messages,
|
||||
model=PROVIDER_TO_DEFAULT_MODEL_NAME[self.provider, self.model_name],
|
||||
system=[system],
|
||||
tools=tools,
|
||||
betas=betas,
|
||||
)
|
||||
logger.info(f"Response: {response}")
|
||||
break # 成功则跳出重试循环
|
||||
except (APIError, APIStatusError, APIResponseValidationError) as e:
|
||||
error_msg = str(e)
|
||||
logger.warning(f"Anthropic API error (attempt {attempt+1}/{API_RETRY_TIMES}): {error_msg}")
|
||||
|
||||
# 检查是否是25MB限制错误
|
||||
if "25000000" in error_msg or "Member must have length less than or equal to" in error_msg:
|
||||
logger.warning("检测到25MB限制错误,自动裁剪图片数量")
|
||||
# 将图片数量减半
|
||||
current_image_count = self.only_n_most_recent_images
|
||||
new_image_count = max(1, current_image_count // 2) # 至少保留1张图片
|
||||
self.only_n_most_recent_images = new_image_count
|
||||
|
||||
# 重新应用图片过滤
|
||||
_maybe_filter_to_n_most_recent_images(
|
||||
self.messages,
|
||||
new_image_count,
|
||||
min_removal_threshold=image_truncation_threshold,
|
||||
)
|
||||
logger.info(f"图片数量已从 {current_image_count} 减少到 {new_image_count}")
|
||||
|
||||
if attempt < API_RETRY_TIMES - 1:
|
||||
time.sleep(API_RETRY_INTERVAL)
|
||||
else:
|
||||
raise # 全部失败后抛出异常,进入原有except逻辑
|
||||
|
||||
except (APIError, APIStatusError, APIResponseValidationError) as e:
|
||||
logger.exception(f"Anthropic API error: {str(e)}")
|
||||
@@ -374,8 +477,7 @@ class AnthropicAgent:
|
||||
logger.warning("Retrying with backup API key...")
|
||||
|
||||
backup_client = Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY_BACKUP"), max_retries=4)
|
||||
|
||||
if self.model_name == "claude-3-7-sonnet-20250219" or self.model_name == "claude-4-opus-20250514" or self.model_name == "claude-4-sonnet-20250514":
|
||||
if self.model_name in ["claude-3-7-sonnet-20250219", "claude-4-opus-20250514", "claude-4-sonnet-20250514"]:
|
||||
response = backup_client.beta.messages.create(
|
||||
max_tokens=self.max_tokens,
|
||||
messages=self.messages,
|
||||
@@ -396,7 +498,25 @@ class AnthropicAgent:
|
||||
)
|
||||
logger.info("Successfully used backup API key")
|
||||
except Exception as backup_e:
|
||||
logger.exception(f"Backup API call also failed: {str(backup_e)}")
|
||||
backup_error_msg = str(backup_e)
|
||||
logger.exception(f"Backup API call also failed: {backup_error_msg}")
|
||||
|
||||
# 检查备用API是否也是25MB限制错误
|
||||
if "25000000" in backup_error_msg or "Member must have length less than or equal to" in backup_error_msg:
|
||||
logger.warning("备用API也遇到25MB限制错误,进一步裁剪图片数量")
|
||||
# 将图片数量再减半
|
||||
current_image_count = self.only_n_most_recent_images
|
||||
new_image_count = max(1, current_image_count // 2) # 至少保留1张图片
|
||||
self.only_n_most_recent_images = new_image_count
|
||||
|
||||
# 重新应用图片过滤
|
||||
_maybe_filter_to_n_most_recent_images(
|
||||
self.messages,
|
||||
new_image_count,
|
||||
min_removal_threshold=image_truncation_threshold,
|
||||
)
|
||||
logger.info(f"备用API图片数量已从 {current_image_count} 减少到 {new_image_count}")
|
||||
|
||||
return None, None
|
||||
|
||||
except Exception as e:
|
||||
@@ -412,29 +532,77 @@ class AnthropicAgent:
|
||||
"content": response_params
|
||||
})
|
||||
|
||||
actions: list[Any] = []
|
||||
reasonings: list[str] = []
|
||||
for content_block in response_params:
|
||||
if content_block["type"] == "tool_use":
|
||||
actions.append({
|
||||
"name": content_block["name"],
|
||||
"input": cast(dict[str, Any], content_block["input"]),
|
||||
"id": content_block["id"],
|
||||
"action_type": content_block.get("type"),
|
||||
"command": self.parse_actions_from_tool_call(content_block)
|
||||
max_parse_retry = 3
|
||||
for parse_retry in range(max_parse_retry):
|
||||
actions: list[Any] = []
|
||||
reasonings: list[str] = []
|
||||
try:
|
||||
for content_block in response_params:
|
||||
if content_block["type"] == "tool_use":
|
||||
actions.append({
|
||||
"name": content_block["name"],
|
||||
"input": cast(dict[str, Any], content_block["input"]),
|
||||
"id": content_block["id"],
|
||||
"action_type": content_block.get("type"),
|
||||
"command": self.parse_actions_from_tool_call(content_block)
|
||||
})
|
||||
elif content_block["type"] == "text":
|
||||
reasonings.append(content_block["text"])
|
||||
if isinstance(reasonings, list) and len(reasonings) > 0:
|
||||
reasonings = reasonings[0]
|
||||
else:
|
||||
reasonings = ""
|
||||
logger.info(f"Received actions: {actions}")
|
||||
logger.info(f"Received reasonings: {reasonings}")
|
||||
if len(actions) == 0:
|
||||
actions = ["DONE"]
|
||||
return reasonings, actions
|
||||
except Exception as e:
|
||||
logger.warning(f"parse_actions_from_tool_call解析失败(第{parse_retry+1}/3次),将重新请求API: {e}")
|
||||
# 删除刚刚append的assistant消息,避免污染history
|
||||
self.messages.pop()
|
||||
# 重新请求API
|
||||
response = None
|
||||
for attempt in range(API_RETRY_TIMES):
|
||||
try:
|
||||
if self.model_name in ["claude-3-7-sonnet-20250219", "claude-4-opus-20250514", "claude-4-sonnet-20250514"]:
|
||||
response = client.beta.messages.create(
|
||||
max_tokens=self.max_tokens,
|
||||
messages=self.messages,
|
||||
model=PROVIDER_TO_DEFAULT_MODEL_NAME[self.provider, self.model_name],
|
||||
system=[system],
|
||||
tools=tools,
|
||||
betas=betas,
|
||||
extra_body=extra_body
|
||||
)
|
||||
elif self.model_name == "claude-3-5-sonnet-20241022":
|
||||
response = client.beta.messages.create(
|
||||
max_tokens=self.max_tokens,
|
||||
messages=self.messages,
|
||||
model=PROVIDER_TO_DEFAULT_MODEL_NAME[self.provider, self.model_name],
|
||||
system=[system],
|
||||
tools=tools,
|
||||
betas=betas,
|
||||
)
|
||||
logger.info(f"Response: {response}")
|
||||
break # 成功则跳出重试循环
|
||||
except (APIError, APIStatusError, APIResponseValidationError) as e2:
|
||||
error_msg = str(e2)
|
||||
logger.warning(f"Anthropic API error (attempt {attempt+1}/{API_RETRY_TIMES}): {error_msg}")
|
||||
if attempt < API_RETRY_TIMES - 1:
|
||||
time.sleep(API_RETRY_INTERVAL)
|
||||
else:
|
||||
raise
|
||||
response_params = _response_to_params(response)
|
||||
logger.info(f"Received response params: {response_params}")
|
||||
self.messages.append({
|
||||
"role": "assistant",
|
||||
"content": response_params
|
||||
})
|
||||
elif content_block["type"] == "text":
|
||||
reasonings.append(content_block["text"])
|
||||
if isinstance(reasonings, list) and len(reasonings) > 0:
|
||||
reasonings = reasonings[0]
|
||||
else:
|
||||
reasonings = ""
|
||||
logger.info(f"Received actions: {actions}")
|
||||
logger.info(f"Received reasonings: {reasonings}")
|
||||
if len(actions) == 0:
|
||||
actions = ["DONE"]
|
||||
return reasonings, actions
|
||||
|
||||
if parse_retry == max_parse_retry - 1:
|
||||
logger.error(f"连续3次parse_actions_from_tool_call解析失败,终止: {e}")
|
||||
actions = ["FAIL"]
|
||||
return reasonings, actions
|
||||
def reset(self, _logger = None, *args, **kwargs):
|
||||
"""
|
||||
Reset the agent's state.
|
||||
|
||||
@@ -11,9 +11,9 @@ import sys
|
||||
from typing import List, Dict
|
||||
import math
|
||||
from tqdm import tqdm
|
||||
from multiprocessing import Process, Manager
|
||||
# import lib_run_single
|
||||
# from desktop_env.desktop_env import DesktopEnv
|
||||
from multiprocessing import Process, Manager, current_process
|
||||
import lib_run_single
|
||||
from desktop_env.desktop_env import DesktopEnv
|
||||
from mm_agents.anthropic import AnthropicAgent as PromptAgent
|
||||
|
||||
import fake_run_single as lib_run_single
|
||||
@@ -26,41 +26,28 @@ load_dotenv()
|
||||
|
||||
# Logger Configs {{{ #
|
||||
logger = logging.getLogger()
|
||||
logger.setLevel(logging.DEBUG)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
||||
|
||||
file_handler = logging.FileHandler(
|
||||
os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8"
|
||||
)
|
||||
debug_handler = logging.FileHandler(
|
||||
os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8"
|
||||
)
|
||||
stdout_handler = logging.StreamHandler(sys.stdout)
|
||||
sdebug_handler = logging.FileHandler(
|
||||
os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8"
|
||||
)
|
||||
|
||||
file_handler.setLevel(logging.INFO)
|
||||
debug_handler.setLevel(logging.DEBUG)
|
||||
stdout_handler.setLevel(logging.INFO)
|
||||
sdebug_handler.setLevel(logging.DEBUG)
|
||||
|
||||
formatter = logging.Formatter(
|
||||
fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
|
||||
)
|
||||
file_handler.setFormatter(formatter)
|
||||
debug_handler.setFormatter(formatter)
|
||||
stdout_handler.setFormatter(formatter)
|
||||
sdebug_handler.setFormatter(formatter)
|
||||
|
||||
stdout_handler.addFilter(logging.Filter("desktopenv"))
|
||||
sdebug_handler.addFilter(logging.Filter("desktopenv"))
|
||||
|
||||
logger.addHandler(file_handler)
|
||||
logger.addHandler(debug_handler)
|
||||
logger.addHandler(stdout_handler)
|
||||
logger.addHandler(sdebug_handler)
|
||||
# }}} Logger Configs #
|
||||
|
||||
logger = logging.getLogger("desktopenv.experiment")
|
||||
@@ -85,8 +72,18 @@ def config() -> argparse.Namespace:
|
||||
default="a11y_tree",
|
||||
help="Observation type",
|
||||
)
|
||||
parser.add_argument("--screen_width", type=int, default=1920)
|
||||
parser.add_argument("--screen_height", type=int, default=1080)
|
||||
parser.add_argument(
|
||||
"--provider_name", type=str, default="aws", choices=["aws", "virtualbox", "vmware", "docker", "azure"], help="Provider name"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--client_password", type=str, default="", help="Client password"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--screen_width", type=int, default=1920, help="Screen width"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--screen_height", type=int, default=1080, help="Screen height"
|
||||
)
|
||||
parser.add_argument("--sleep_after_execution", type=float, default=0.0)
|
||||
parser.add_argument("--max_steps", type=int, default=15)
|
||||
|
||||
@@ -122,7 +119,7 @@ def config() -> argparse.Namespace:
|
||||
return args
|
||||
|
||||
|
||||
def distribute_tasks(test_all_meta: dict, num_envs: int) -> List[Dict]:
|
||||
def distribute_tasks(test_all_meta: dict) -> List[tuple]:
|
||||
"""Distribute tasks evenly across environments."""
|
||||
# Flatten the tasks into a single list
|
||||
all_tasks = []
|
||||
@@ -130,97 +127,35 @@ def distribute_tasks(test_all_meta: dict, num_envs: int) -> List[Dict]:
|
||||
for example_id in examples:
|
||||
all_tasks.append((domain, example_id))
|
||||
|
||||
# Calculate tasks per environment
|
||||
tasks_per_env = math.ceil(len(all_tasks) / num_envs)
|
||||
|
||||
# Distribute tasks
|
||||
distributed_tasks = []
|
||||
for i in range(num_envs):
|
||||
env_tasks = {}
|
||||
start_idx = i * tasks_per_env
|
||||
end_idx = min((i + 1) * tasks_per_env, len(all_tasks))
|
||||
|
||||
for domain, example_id in all_tasks[start_idx:end_idx]:
|
||||
if domain not in env_tasks:
|
||||
env_tasks[domain] = []
|
||||
env_tasks[domain].append(example_id)
|
||||
|
||||
distributed_tasks.append(env_tasks)
|
||||
|
||||
return distributed_tasks
|
||||
return all_tasks
|
||||
|
||||
|
||||
|
||||
def run_env_tasks(env_idx: int, env: DesktopEnv, agent: PromptAgent, env_tasks: dict, args: argparse.Namespace, shared_scores: list):
|
||||
def run_env_tasks(task_queue, args, shared_scores):
|
||||
"""Run tasks for a single environment."""
|
||||
logger.info(f"Executing tasks in environment {env_idx + 1}/{args.num_envs}")
|
||||
|
||||
for domain in tqdm(env_tasks, desc=f"Env{env_idx+1}-Domain"):
|
||||
for example_id in tqdm(env_tasks[domain], desc="Example", leave=False):
|
||||
config_file = os.path.join(
|
||||
args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
|
||||
)
|
||||
with open(config_file, "r", encoding="utf-8") as f:
|
||||
example = json.load(f)
|
||||
|
||||
logger.info(f"[Env {env_idx+1}][Domain]: {domain}")
|
||||
logger.info(f"[Env {env_idx+1}][Example ID]: {example_id}")
|
||||
logger.info(f"[Env {env_idx+1}][Instruction]: {example['instruction']}")
|
||||
|
||||
example_result_dir = os.path.join(
|
||||
args.result_dir,
|
||||
args.action_space,
|
||||
args.observation_type,
|
||||
args.model,
|
||||
domain,
|
||||
example_id,
|
||||
)
|
||||
os.makedirs(example_result_dir, exist_ok=True)
|
||||
|
||||
try:
|
||||
lib_run_single.run_single_example(
|
||||
agent,
|
||||
env,
|
||||
example,
|
||||
args.max_steps,
|
||||
example["instruction"],
|
||||
args,
|
||||
example_result_dir,
|
||||
shared_scores,
|
||||
)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
# logger traceback
|
||||
logger.error(f"Exception in Env{env_idx+1} {domain}/{example_id}: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
env.controller.end_recording(
|
||||
os.path.join(example_result_dir, "recording.mp4")
|
||||
)
|
||||
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
|
||||
f.write(
|
||||
json.dumps(
|
||||
{"Error": f"Time limit exceeded in {domain}/{example_id}"}
|
||||
)
|
||||
)
|
||||
f.write("\n")
|
||||
|
||||
env.close()
|
||||
|
||||
|
||||
def test(args: argparse.Namespace, test_all_meta: dict) -> None:
|
||||
logger.info("Args: %s", args)
|
||||
|
||||
distributed_tasks = distribute_tasks(test_all_meta, args.num_envs)
|
||||
|
||||
# First, set up all environments
|
||||
logger.info("Setting up all environments...")
|
||||
envs = []
|
||||
agents = []
|
||||
|
||||
for env_idx in range(args.num_envs):
|
||||
logger.info(f"Setting up environment {env_idx + 1}/{args.num_envs}")
|
||||
|
||||
active_environments = []
|
||||
env = None
|
||||
try:
|
||||
from desktop_env.providers.aws.manager import IMAGE_ID_MAP
|
||||
REGION = args.region
|
||||
screen_size = (args.screen_width, args.screen_height)
|
||||
ami_id = IMAGE_ID_MAP[REGION].get(screen_size, IMAGE_ID_MAP[REGION][(1920, 1080)])
|
||||
env = DesktopEnv(
|
||||
path_to_vm=args.path_to_vm,
|
||||
action_space=args.action_space,
|
||||
provider_name=args.provider_name,
|
||||
region=REGION,
|
||||
snapshot_name=ami_id,
|
||||
screen_size=screen_size,
|
||||
headless=args.headless,
|
||||
os_type="Ubuntu",
|
||||
require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
|
||||
enable_proxy=True,
|
||||
client_password=args.client_password
|
||||
)
|
||||
active_environments.append(env)
|
||||
agent = PromptAgent(
|
||||
env=env,
|
||||
model=args.model,
|
||||
max_tokens=args.max_tokens,
|
||||
top_p=args.top_p,
|
||||
@@ -228,50 +163,193 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
|
||||
action_space=args.action_space,
|
||||
observation_type=args.observation_type,
|
||||
max_trajectory_length=args.max_trajectory_length,
|
||||
screen_size=(args.screen_width, args.screen_height),
|
||||
client_password=args.client_password,
|
||||
provider_name=args.provider_name,
|
||||
screen_width=args.screen_width,
|
||||
screen_height=args.screen_height
|
||||
)
|
||||
agents.append(agent)
|
||||
|
||||
from desktop_env.providers.aws.manager import IMAGE_ID_MAP
|
||||
REGION = "us-east-1"
|
||||
env = DesktopEnv(
|
||||
path_to_vm=args.path_to_vm,
|
||||
action_space=agent.action_space,
|
||||
logger.info(f"Process {current_process().name} started.")
|
||||
while True:
|
||||
try:
|
||||
item = task_queue.get(timeout=5)
|
||||
except Exception:
|
||||
break
|
||||
domain, example_id = item
|
||||
try:
|
||||
config_file = os.path.join(
|
||||
args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
|
||||
)
|
||||
with open(config_file, "r", encoding="utf-8") as f:
|
||||
example = json.load(f)
|
||||
logger.info(f"[{current_process().name}][Domain]: {domain}")
|
||||
logger.info(f"[{current_process().name}][Example ID]: {example_id}")
|
||||
logger.info(f"[{current_process().name}][Instruction]: {example['instruction']}")
|
||||
example_result_dir = os.path.join(
|
||||
args.result_dir,
|
||||
args.action_space,
|
||||
args.observation_type,
|
||||
args.model,
|
||||
domain,
|
||||
example_id,
|
||||
)
|
||||
os.makedirs(example_result_dir, exist_ok=True)
|
||||
try:
|
||||
lib_run_single.run_single_example(
|
||||
agent,
|
||||
env,
|
||||
example,
|
||||
args.max_steps,
|
||||
example["instruction"],
|
||||
args,
|
||||
example_result_dir,
|
||||
shared_scores,
|
||||
)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
logger.error(f"Exception in {current_process().name} {domain}/{example_id}: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
try:
|
||||
env.controller.end_recording(
|
||||
os.path.join(example_result_dir, "recording.mp4")
|
||||
)
|
||||
except Exception as rec_e:
|
||||
logger.error(f"Failed to end recording: {rec_e}")
|
||||
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
|
||||
f.write(
|
||||
json.dumps(
|
||||
{"Error": f"{domain}/{example_id} - {e}"}
|
||||
)
|
||||
)
|
||||
f.write("\n")
|
||||
except Exception as e:
|
||||
logger.error(f"Task-level error in {current_process().name}: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
except Exception as e:
|
||||
logger.error(f"Process-level error in {current_process().name}: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
finally:
|
||||
logger.info(f"{current_process().name} cleaning up environment...")
|
||||
try:
|
||||
if env:
|
||||
env.close()
|
||||
logger.info(f"{current_process().name} environment closed successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"{current_process().name} error during environment cleanup: {e}")
|
||||
|
||||
provider_name="aws",
|
||||
region="us-east-1",
|
||||
snapshot_name=IMAGE_ID_MAP[REGION],
|
||||
screen_size=(args.screen_width, args.screen_height),
|
||||
headless=args.headless,
|
||||
os_type="Ubuntu",
|
||||
require_a11y_tree=args.observation_type
|
||||
in ["a11y_tree", "screenshot_a11y_tree", "som"],
|
||||
)
|
||||
envs.append(env)
|
||||
|
||||
logger.info("All environments are ready. Starting parallel task execution...")
|
||||
|
||||
# Create a shared list for scores across processes
|
||||
|
||||
def process_signal_handler(signum, frame, env_idx):
|
||||
logger.info(f"Process {env_idx + 1} received signal {signum}. Shutting down...")
|
||||
local_vars = frame.f_locals
|
||||
active_environments = local_vars.get('active_environments', [])
|
||||
for env in active_environments:
|
||||
if env is not None:
|
||||
try:
|
||||
logger.info(f"Process {env_idx + 1} closing environment...")
|
||||
env.close()
|
||||
logger.info(f"Process {env_idx + 1} environment closed successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Process {env_idx + 1} error closing environment: {e}")
|
||||
logger.info(f"Process {env_idx + 1} shutdown complete. Exiting.")
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
def signal_handler(signum, frame):
|
||||
global is_terminating, active_environments, processes
|
||||
if is_terminating:
|
||||
return
|
||||
is_terminating = True
|
||||
logger.info(f"Received signal {signum}. Gracefully shutting down...")
|
||||
for env in active_environments:
|
||||
try:
|
||||
logger.info(f"Closing environment...")
|
||||
env.close()
|
||||
logger.info(f"Environment closed successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing environment: {e}")
|
||||
for p in processes:
|
||||
if p.is_alive():
|
||||
try:
|
||||
logger.info(f"Sending termination signal to process {p.name}...")
|
||||
p.terminate()
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending termination signal to process: {e}")
|
||||
time.sleep(1)
|
||||
for p in processes:
|
||||
if p.is_alive():
|
||||
try:
|
||||
logger.info(f"Forcefully terminating process {p.name}...")
|
||||
import signal as sig
|
||||
os.kill(p.pid, sig.SIGKILL)
|
||||
except Exception as e:
|
||||
logger.error(f"Error forcefully terminating process: {e}")
|
||||
logger.info("Shutdown complete. Exiting.")
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
def test(args: argparse.Namespace, test_all_meta: dict) -> None:
|
||||
global processes
|
||||
logger.info("Args: %s", args)
|
||||
all_tasks = distribute_tasks(test_all_meta)
|
||||
logger.info(f"Total tasks: {len(all_tasks)}")
|
||||
with Manager() as manager:
|
||||
shared_scores = manager.list()
|
||||
|
||||
# Create and start processes for each environment
|
||||
task_queue = manager.Queue()
|
||||
for item in all_tasks:
|
||||
task_queue.put(item)
|
||||
num_envs = args.num_envs
|
||||
processes = []
|
||||
for env_idx, (env, agent, env_tasks) in enumerate(zip(envs, agents, distributed_tasks)):
|
||||
for i in range(num_envs):
|
||||
p = Process(
|
||||
target=run_env_tasks,
|
||||
args=(env_idx, env, agent, env_tasks, args, shared_scores)
|
||||
args=(task_queue, args, shared_scores),
|
||||
name=f"EnvProcess-{i+1}"
|
||||
)
|
||||
processes.append(p)
|
||||
p.daemon = True
|
||||
p.start()
|
||||
|
||||
# Wait for all processes to complete
|
||||
for p in processes:
|
||||
p.join()
|
||||
|
||||
# Convert shared list to regular list
|
||||
processes.append(p)
|
||||
logger.info(f"Started process {p.name} with PID {p.pid}")
|
||||
try:
|
||||
while True:
|
||||
alive_count = 0
|
||||
for idx, p in enumerate(processes):
|
||||
if not p.is_alive():
|
||||
logger.warning(f"Process {p.name} died, restarting...")
|
||||
new_p = Process(
|
||||
target=run_env_tasks,
|
||||
args=(task_queue, args, shared_scores),
|
||||
name=f"EnvProcess-Restart-{idx+1}"
|
||||
)
|
||||
new_p.daemon = True
|
||||
new_p.start()
|
||||
processes[idx] = new_p
|
||||
logger.info(f"Restarted process {new_p.name} with PID {new_p.pid}")
|
||||
else:
|
||||
alive_count += 1
|
||||
if task_queue.empty():
|
||||
logger.info("All tasks finished.")
|
||||
break
|
||||
if alive_count == 0:
|
||||
logger.error("All processes died, exiting.")
|
||||
break
|
||||
time.sleep(5)
|
||||
for p in processes:
|
||||
p.join()
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Main process received KeyboardInterrupt. Initiating graceful shutdown...")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error while waiting for processes: {e}", exc_info=True)
|
||||
for p in processes:
|
||||
if p.is_alive():
|
||||
try:
|
||||
logger.info(f"Terminating process {p.name} due to error...")
|
||||
p.terminate()
|
||||
except Exception as term_e:
|
||||
logger.error(f"Error terminating process {p.name}: {term_e}")
|
||||
raise
|
||||
scores = list(shared_scores)
|
||||
|
||||
logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")
|
||||
|
||||
|
||||
@@ -350,46 +428,43 @@ def get_result(action_space, use_model, observation_type, result_dir, total_file
|
||||
if __name__ == "__main__":
|
||||
####### The complete version of the list of examples #######
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
import signal
|
||||
import time
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
try:
|
||||
args = config()
|
||||
|
||||
args = config()
|
||||
# save args to json in result_dir/action_space/observation_type/model/args.json
|
||||
path_to_args = os.path.join(
|
||||
args.result_dir,
|
||||
args.action_space,
|
||||
args.observation_type,
|
||||
args.model,
|
||||
"args.json",
|
||||
)
|
||||
os.makedirs(os.path.dirname(path_to_args), exist_ok=True)
|
||||
with open(path_to_args, "w", encoding="utf-8") as f:
|
||||
json.dump(vars(args), f, indent=4)
|
||||
with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
|
||||
test_all_meta = json.load(f)
|
||||
|
||||
with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
|
||||
test_all_meta = json.load(f)
|
||||
if args.domain != "all":
|
||||
test_all_meta = {args.domain: test_all_meta[args.domain]}
|
||||
|
||||
if args.domain != "all":
|
||||
test_all_meta = {args.domain: test_all_meta[args.domain]}
|
||||
test_file_list = get_unfinished(
|
||||
args.action_space,
|
||||
args.model,
|
||||
args.observation_type,
|
||||
args.result_dir,
|
||||
test_all_meta,
|
||||
)
|
||||
left_info = ""
|
||||
for domain in test_file_list:
|
||||
left_info += f"{domain}: {len(test_file_list[domain])}\n"
|
||||
logger.info(f"Left tasks:\n{left_info}")
|
||||
|
||||
test_file_list = get_unfinished(
|
||||
args.action_space,
|
||||
args.model,
|
||||
args.observation_type,
|
||||
args.result_dir,
|
||||
test_all_meta,
|
||||
)
|
||||
left_info = ""
|
||||
for domain in test_file_list:
|
||||
left_info += f"{domain}: {len(test_file_list[domain])}\n"
|
||||
logger.info(f"Left tasks:\n{left_info}")
|
||||
|
||||
get_result(
|
||||
args.action_space,
|
||||
args.model,
|
||||
args.observation_type,
|
||||
args.result_dir,
|
||||
test_all_meta,
|
||||
)
|
||||
test(args, test_file_list)
|
||||
|
||||
|
||||
# path_to_vm can be a list["xxx","xxx"]
|
||||
get_result(
|
||||
args.action_space,
|
||||
args.model,
|
||||
args.observation_type,
|
||||
args.result_dir,
|
||||
test_all_meta,
|
||||
)
|
||||
test(args, test_file_list)
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Main process received KeyboardInterrupt.")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in main process: {e}", exc_info=True)
|
||||
signal_handler(signal.SIGTERM, None)
|
||||
finally:
|
||||
logger.info("Main process final cleanup...")
|
||||
Reference in New Issue
Block a user