From 0a37cccd53fb25499a6397a38f62ac270dbf3c9f Mon Sep 17 00:00:00 2001
From: Yuan Mengqi <100453613+yuanmengqi@users.noreply.github.com>
Date: Wed, 23 Jul 2025 03:35:49 +0800
Subject: [PATCH] update claude (#280)

* add uitars agent code

* improve claude

* improve claude

* improve claude

* improve claude

* improve claude
---
 lib_run_single.py           |   2 +
 mm_agents/anthropic/main.py | 264 ++++++++++++++++++----
 run_multienv_claude.py      | 433 +++++++++++++++++++++---------------
 3 files changed, 472 insertions(+), 227 deletions(-)

diff --git a/lib_run_single.py b/lib_run_single.py
index 91a0163..eb91c72 100644
--- a/lib_run_single.py
+++ b/lib_run_single.py
@@ -32,6 +32,8 @@ def run_single_example(agent, env, example, max_steps, instruction, args, exampl
         action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
         logger.info("Step %d: %s", step_idx + 1, action)
         obs, reward, done, info = env.step(action, args.sleep_after_execution)
+        time.sleep(3)
+        obs = env._get_obs()

         logger.info("Reward: %.2f", reward)
         logger.info("Done: %s", done)
diff --git a/mm_agents/anthropic/main.py b/mm_agents/anthropic/main.py
index 493a7bb..903796d 100644
--- a/mm_agents/anthropic/main.py
+++ b/mm_agents/anthropic/main.py
@@ -23,6 +23,10 @@ from .utils import _response_to_params, _inject_prompt_caching, _maybe_filter_to
 import logging
 logger = logging.getLogger("desktopenv.agent")

+# MAX_HISTORY = 10
+API_RETRY_TIMES = 500
+API_RETRY_INTERVAL = 5
+
 class AnthropicAgent:
     def __init__(self,
                  platform: str = "Ubuntu",
@@ -107,9 +111,24 @@ class AnthropicAgent:
                 int(coordinate[0] * self.resize_factor[0]),
                 int(coordinate[1] * self.resize_factor[1])
             )
+
+        if action == "left_mouse_down":
+            result += "pyautogui.mouseDown()\n"
+        elif action == "left_mouse_up":
+            result += "pyautogui.mouseUp()\n"
+
+        elif action == "hold_key":
+            if not isinstance(text, str):
+                raise ValueError(f"{text} must be a string")
+
+            keys = text.split('+')
+            for key in keys:
+                key = key.strip().lower()
+                result += f"pyautogui.keyDown('{key}')\n"
+            expected_outcome = f"Keys {text} held down."
         # Handle mouse move and drag actions
-        if action in ("mouse_move", "left_click_drag"):
+        elif action in ("mouse_move", "left_click_drag"):
             if coordinate is None:
                 raise ValueError(f"coordinate is required for {action}")
             if text is not None:
@@ -189,7 +208,7 @@ class AnthropicAgent:
             expected_outcome = "Scroll action finished"

         # Handle click actions
-        elif action in ("left_click", "right_click", "double_click", "middle_click", "left_press"):
+        elif action in ("left_click", "right_click", "double_click", "middle_click", "left_press", "triple_click"):
             if coordinate is not None:
                 x, y = coordinate
                 if action == "left_click":
@@ -204,6 +223,9 @@ class AnthropicAgent:
                     result += (f"pyautogui.mouseDown({x}, {y})\n")
                     result += ("time.sleep(1)\n")
                     result += (f"pyautogui.mouseUp({x}, {y})\n")
+                elif action == "triple_click":
+                    result += (f"pyautogui.tripleClick({x}, {y})\n")
+
             else:
                 if action == "left_click":
                     result += ("pyautogui.click()\n")
@@ -217,6 +239,8 @@ class AnthropicAgent:
                     result += ("pyautogui.mouseDown()\n")
                     result += ("time.sleep(1)\n")
                     result += ("pyautogui.mouseUp()\n")
+                elif action == "triple_click":
+                    result += ("pyautogui.tripleClick()\n")

             expected_outcome = "Click action finished"

         elif action == "wait":
@@ -239,6 +263,54 @@ class AnthropicAgent:

         return result

+    def _trim_history(self, max_rounds=4):
+
+        messages = self.messages
+        if not messages or len(messages) <= 1:
+            return
+
+        # Number of recent messages to keep (each round is a user/assistant pair)
+        actual_max_rounds = max_rounds * 2
+
+        # If the message count is within the limit, nothing needs to be done
+        if len(messages) <= actual_max_rounds:
+            return
+
+        # Keep the first message (the initial one) and the most recent messages: messages[0:1] + messages[-actual_max_rounds:]
+        keep_messages = []
+
+        # For the messages dropped in between, keep only their non-image content
+        for i in range(1, len(messages) - actual_max_rounds):
+            old_message = messages[i]
+            if old_message["role"] == "user" and "content" in old_message:
+                # Filter out image-type content blocks and keep the other types
+                filtered_content = []
+                for content_block in old_message["content"]:
+                    filtered_content_item = []
+                    if content_block.get("type") == "tool_result":
+                        for content_block_item in content_block["content"]:
+                            if content_block_item.get("type") != "image":
+                                filtered_content_item.append(content_block_item)
+                        filtered_content.append({
+                            "type": content_block.get("type"),
+                            "tool_use_id": content_block.get("tool_use_id"),
+                            "content": filtered_content_item
+                        })
+                    else:
+                        filtered_content.append(content_block)
+
+                # Keep this message only if it still has content after filtering
+                if filtered_content:
+                    keep_messages.append({
+                        "role": old_message["role"],
+                        "content": filtered_content
+                    })
+            else:
+                # Keep non-user messages and messages without content as-is
+                keep_messages.append(old_message)
+
+        self.messages = messages[0:1] + keep_messages + messages[-actual_max_rounds:]
+
     def predict(self, task_instruction: str, obs: Dict = None, system: Any = None):
         system = BetaTextBlockParam(
             type="text",
@@ -326,8 +398,10 @@ class AnthropicAgent:
             min_removal_threshold=image_truncation_threshold,
         )

-        try:
+
+        #self._trim_history(max_rounds=MAX_HISTORY)
+        try:
             if self.model_name == "claude-3-5-sonnet-20241022":
                 tools = [
                     {'name': 'computer', 'type': 'computer_20241022', 'display_width_px': 1280, 'display_height_px': 720, 'display_number': 1},
@@ -336,7 +410,7 @@ class AnthropicAgent:
                 ] if self.platform == 'Ubuntu' else [
                     {'name': 'computer', 'type': 'computer_20241022', 'display_width_px': 1280, 'display_height_px': 720, 'display_number': 1},
                 ]
-            elif self.model_name == "claude-3-7-sonnet-20250219" or self.model_name == "claude-4-opus-20250514" or self.model_name == "claude-4-sonnet-20250514":
+            elif self.model_name in ["claude-3-7-sonnet-20250219", "claude-4-opus-20250514", "claude-4-sonnet-20250514"]:
"claude-4-sonnet-20250514"]: tools = [ {'name': 'computer', 'type': 'computer_20250124', 'display_width_px': 1280, 'display_height_px': 720, 'display_number': 1}, # {'type': 'bash_20250124', 'name': 'bash'}, @@ -348,25 +422,54 @@ class AnthropicAgent: "thinking": {"type": "enabled", "budget_tokens": 1024} } response = None - if self.model_name == "claude-3-7-sonnet-20250219" or self.model_name == "claude-4-opus-20250514" or self.model_name == "claude-4-sonnet-20250514": - response = client.beta.messages.create( - max_tokens=self.max_tokens, - messages=self.messages, - model=PROVIDER_TO_DEFAULT_MODEL_NAME[self.provider, self.model_name], - system=[system], - tools=tools, - betas=betas, - extra_body=extra_body - ) - elif self.model_name == "claude-3-5-sonnet-20241022": - response = client.beta.messages.create( - max_tokens=self.max_tokens, - messages=self.messages, - model=PROVIDER_TO_DEFAULT_MODEL_NAME[self.provider, self.model_name], - system=[system], - tools=tools, - betas=betas, - ) + + for attempt in range(API_RETRY_TIMES): + try: + if self.model_name in ["claude-3-7-sonnet-20250219", "claude-4-opus-20250514", "claude-4-sonnet-20250514"]: + response = client.beta.messages.create( + max_tokens=self.max_tokens, + messages=self.messages, + model=PROVIDER_TO_DEFAULT_MODEL_NAME[self.provider, self.model_name], + system=[system], + tools=tools, + betas=betas, + extra_body=extra_body + ) + elif self.model_name == "claude-3-5-sonnet-20241022": + response = client.beta.messages.create( + max_tokens=self.max_tokens, + messages=self.messages, + model=PROVIDER_TO_DEFAULT_MODEL_NAME[self.provider, self.model_name], + system=[system], + tools=tools, + betas=betas, + ) + logger.info(f"Response: {response}") + break # 成功则跳出重试循环 + except (APIError, APIStatusError, APIResponseValidationError) as e: + error_msg = str(e) + logger.warning(f"Anthropic API error (attempt {attempt+1}/{API_RETRY_TIMES}): {error_msg}") + + # 检查是否是25MB限制错误 + if "25000000" in error_msg or "Member must have length less than or equal to" in error_msg: + logger.warning("检测到25MB限制错误,自动裁剪图片数量") + # 将图片数量减半 + current_image_count = self.only_n_most_recent_images + new_image_count = max(1, current_image_count // 2) # 至少保留1张图片 + self.only_n_most_recent_images = new_image_count + + # 重新应用图片过滤 + _maybe_filter_to_n_most_recent_images( + self.messages, + new_image_count, + min_removal_threshold=image_truncation_threshold, + ) + logger.info(f"图片数量已从 {current_image_count} 减少到 {new_image_count}") + + if attempt < API_RETRY_TIMES - 1: + time.sleep(API_RETRY_INTERVAL) + else: + raise # 全部失败后抛出异常,进入原有except逻辑 except (APIError, APIStatusError, APIResponseValidationError) as e: logger.exception(f"Anthropic API error: {str(e)}") @@ -374,8 +477,7 @@ class AnthropicAgent: logger.warning("Retrying with backup API key...") backup_client = Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY_BACKUP"), max_retries=4) - - if self.model_name == "claude-3-7-sonnet-20250219" or self.model_name == "claude-4-opus-20250514" or self.model_name == "claude-4-sonnet-20250514": + if self.model_name in ["claude-3-7-sonnet-20250219", "claude-4-opus-20250514", "claude-4-sonnet-20250514"]: response = backup_client.beta.messages.create( max_tokens=self.max_tokens, messages=self.messages, @@ -396,7 +498,25 @@ class AnthropicAgent: ) logger.info("Successfully used backup API key") except Exception as backup_e: - logger.exception(f"Backup API call also failed: {str(backup_e)}") + backup_error_msg = str(backup_e) + logger.exception(f"Backup API call also failed: 
{backup_error_msg}") + + # 检查备用API是否也是25MB限制错误 + if "25000000" in backup_error_msg or "Member must have length less than or equal to" in backup_error_msg: + logger.warning("备用API也遇到25MB限制错误,进一步裁剪图片数量") + # 将图片数量再减半 + current_image_count = self.only_n_most_recent_images + new_image_count = max(1, current_image_count // 2) # 至少保留1张图片 + self.only_n_most_recent_images = new_image_count + + # 重新应用图片过滤 + _maybe_filter_to_n_most_recent_images( + self.messages, + new_image_count, + min_removal_threshold=image_truncation_threshold, + ) + logger.info(f"备用API图片数量已从 {current_image_count} 减少到 {new_image_count}") + return None, None except Exception as e: @@ -412,29 +532,77 @@ class AnthropicAgent: "content": response_params }) - actions: list[Any] = [] - reasonings: list[str] = [] - for content_block in response_params: - if content_block["type"] == "tool_use": - actions.append({ - "name": content_block["name"], - "input": cast(dict[str, Any], content_block["input"]), - "id": content_block["id"], - "action_type": content_block.get("type"), - "command": self.parse_actions_from_tool_call(content_block) + max_parse_retry = 3 + for parse_retry in range(max_parse_retry): + actions: list[Any] = [] + reasonings: list[str] = [] + try: + for content_block in response_params: + if content_block["type"] == "tool_use": + actions.append({ + "name": content_block["name"], + "input": cast(dict[str, Any], content_block["input"]), + "id": content_block["id"], + "action_type": content_block.get("type"), + "command": self.parse_actions_from_tool_call(content_block) + }) + elif content_block["type"] == "text": + reasonings.append(content_block["text"]) + if isinstance(reasonings, list) and len(reasonings) > 0: + reasonings = reasonings[0] + else: + reasonings = "" + logger.info(f"Received actions: {actions}") + logger.info(f"Received reasonings: {reasonings}") + if len(actions) == 0: + actions = ["DONE"] + return reasonings, actions + except Exception as e: + logger.warning(f"parse_actions_from_tool_call解析失败(第{parse_retry+1}/3次),将重新请求API: {e}") + # 删除刚刚append的assistant消息,避免污染history + self.messages.pop() + # 重新请求API + response = None + for attempt in range(API_RETRY_TIMES): + try: + if self.model_name in ["claude-3-7-sonnet-20250219", "claude-4-opus-20250514", "claude-4-sonnet-20250514"]: + response = client.beta.messages.create( + max_tokens=self.max_tokens, + messages=self.messages, + model=PROVIDER_TO_DEFAULT_MODEL_NAME[self.provider, self.model_name], + system=[system], + tools=tools, + betas=betas, + extra_body=extra_body + ) + elif self.model_name == "claude-3-5-sonnet-20241022": + response = client.beta.messages.create( + max_tokens=self.max_tokens, + messages=self.messages, + model=PROVIDER_TO_DEFAULT_MODEL_NAME[self.provider, self.model_name], + system=[system], + tools=tools, + betas=betas, + ) + logger.info(f"Response: {response}") + break # 成功则跳出重试循环 + except (APIError, APIStatusError, APIResponseValidationError) as e2: + error_msg = str(e2) + logger.warning(f"Anthropic API error (attempt {attempt+1}/{API_RETRY_TIMES}): {error_msg}") + if attempt < API_RETRY_TIMES - 1: + time.sleep(API_RETRY_INTERVAL) + else: + raise + response_params = _response_to_params(response) + logger.info(f"Received response params: {response_params}") + self.messages.append({ + "role": "assistant", + "content": response_params }) - elif content_block["type"] == "text": - reasonings.append(content_block["text"]) - if isinstance(reasonings, list) and len(reasonings) > 0: - reasonings = reasonings[0] - else: - reasonings = "" - 
logger.info(f"Received actions: {actions}") - logger.info(f"Received reasonings: {reasonings}") - if len(actions) == 0: - actions = ["DONE"] - return reasonings, actions - + if parse_retry == max_parse_retry - 1: + logger.error(f"连续3次parse_actions_from_tool_call解析失败,终止: {e}") + actions = ["FAIL"] + return reasonings, actions def reset(self, _logger = None, *args, **kwargs): """ Reset the agent's state. diff --git a/run_multienv_claude.py b/run_multienv_claude.py index 170ac2e..69bcd73 100644 --- a/run_multienv_claude.py +++ b/run_multienv_claude.py @@ -11,9 +11,9 @@ import sys from typing import List, Dict import math from tqdm import tqdm -from multiprocessing import Process, Manager -# import lib_run_single -# from desktop_env.desktop_env import DesktopEnv +from multiprocessing import Process, Manager, current_process +import lib_run_single +from desktop_env.desktop_env import DesktopEnv from mm_agents.anthropic import AnthropicAgent as PromptAgent import fake_run_single as lib_run_single @@ -26,41 +26,28 @@ load_dotenv() # Logger Configs {{{ # logger = logging.getLogger() -logger.setLevel(logging.DEBUG) +logger.setLevel(logging.INFO) datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S") file_handler = logging.FileHandler( os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8" ) -debug_handler = logging.FileHandler( - os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8" -) stdout_handler = logging.StreamHandler(sys.stdout) -sdebug_handler = logging.FileHandler( - os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8" -) file_handler.setLevel(logging.INFO) -debug_handler.setLevel(logging.DEBUG) stdout_handler.setLevel(logging.INFO) -sdebug_handler.setLevel(logging.DEBUG) formatter = logging.Formatter( fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s" ) file_handler.setFormatter(formatter) -debug_handler.setFormatter(formatter) stdout_handler.setFormatter(formatter) -sdebug_handler.setFormatter(formatter) stdout_handler.addFilter(logging.Filter("desktopenv")) -sdebug_handler.addFilter(logging.Filter("desktopenv")) logger.addHandler(file_handler) -logger.addHandler(debug_handler) logger.addHandler(stdout_handler) -logger.addHandler(sdebug_handler) # }}} Logger Configs # logger = logging.getLogger("desktopenv.experiment") @@ -85,8 +72,18 @@ def config() -> argparse.Namespace: default="a11y_tree", help="Observation type", ) - parser.add_argument("--screen_width", type=int, default=1920) - parser.add_argument("--screen_height", type=int, default=1080) + parser.add_argument( + "--provider_name", type=str, default="aws", choices=["aws", "virtualbox", "vmware", "docker", "azure"], help="Provider name" + ) + parser.add_argument( + "--client_password", type=str, default="", help="Client password" + ) + parser.add_argument( + "--screen_width", type=int, default=1920, help="Screen width" + ) + parser.add_argument( + "--screen_height", type=int, default=1080, help="Screen height" + ) parser.add_argument("--sleep_after_execution", type=float, default=0.0) parser.add_argument("--max_steps", type=int, default=15) @@ -122,7 +119,7 @@ def config() -> argparse.Namespace: return args -def distribute_tasks(test_all_meta: dict, num_envs: int) -> List[Dict]: +def distribute_tasks(test_all_meta: dict) -> List[tuple]: """Distribute tasks evenly across environments.""" # Flatten the tasks into a single list all_tasks = [] @@ -130,97 +127,35 @@ def 
diff --git a/run_multienv_claude.py b/run_multienv_claude.py
index 170ac2e..69bcd73 100644
--- a/run_multienv_claude.py
+++ b/run_multienv_claude.py
@@ -11,9 +11,9 @@ import sys
 from typing import List, Dict
 import math
 from tqdm import tqdm
-from multiprocessing import Process, Manager
-# import lib_run_single
-# from desktop_env.desktop_env import DesktopEnv
+from multiprocessing import Process, Manager, current_process
+import lib_run_single
+from desktop_env.desktop_env import DesktopEnv
 from mm_agents.anthropic import AnthropicAgent as PromptAgent
 import fake_run_single as lib_run_single
@@ -26,41 +26,28 @@
 load_dotenv()

 # Logger Configs {{{ #
 logger = logging.getLogger()
-logger.setLevel(logging.DEBUG)
+logger.setLevel(logging.INFO)
 datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
 file_handler = logging.FileHandler(
     os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8"
 )
-debug_handler = logging.FileHandler(
-    os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8"
-)
 stdout_handler = logging.StreamHandler(sys.stdout)
-sdebug_handler = logging.FileHandler(
-    os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8"
-)
 file_handler.setLevel(logging.INFO)
-debug_handler.setLevel(logging.DEBUG)
 stdout_handler.setLevel(logging.INFO)
-sdebug_handler.setLevel(logging.DEBUG)
 formatter = logging.Formatter(
     fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
 )
 file_handler.setFormatter(formatter)
-debug_handler.setFormatter(formatter)
 stdout_handler.setFormatter(formatter)
-sdebug_handler.setFormatter(formatter)
 stdout_handler.addFilter(logging.Filter("desktopenv"))
-sdebug_handler.addFilter(logging.Filter("desktopenv"))
 logger.addHandler(file_handler)
-logger.addHandler(debug_handler)
 logger.addHandler(stdout_handler)
-logger.addHandler(sdebug_handler)
 # }}} Logger Configs #

 logger = logging.getLogger("desktopenv.experiment")
@@ -85,8 +72,18 @@ def config() -> argparse.Namespace:
         default="a11y_tree",
         help="Observation type",
     )
-    parser.add_argument("--screen_width", type=int, default=1920)
-    parser.add_argument("--screen_height", type=int, default=1080)
+    parser.add_argument(
+        "--provider_name", type=str, default="aws", choices=["aws", "virtualbox", "vmware", "docker", "azure"], help="Provider name"
+    )
+    parser.add_argument(
+        "--client_password", type=str, default="", help="Client password"
+    )
+    parser.add_argument(
+        "--screen_width", type=int, default=1920, help="Screen width"
+    )
+    parser.add_argument(
+        "--screen_height", type=int, default=1080, help="Screen height"
+    )
     parser.add_argument("--sleep_after_execution", type=float, default=0.0)
     parser.add_argument("--max_steps", type=int, default=15)
@@ -122,7 +119,7 @@
     return args

-def distribute_tasks(test_all_meta: dict, num_envs: int) -> List[Dict]:
+def distribute_tasks(test_all_meta: dict) -> List[tuple]:
     """Distribute tasks evenly across environments."""
     # Flatten the tasks into a single list
     all_tasks = []
@@ -130,97 +127,35 @@ def
         for example_id in examples:
             all_tasks.append((domain, example_id))

-    # Calculate tasks per environment
-    tasks_per_env = math.ceil(len(all_tasks) / num_envs)
-
-    # Distribute tasks
-    distributed_tasks = []
-    for i in range(num_envs):
-        env_tasks = {}
-        start_idx = i * tasks_per_env
-        end_idx = min((i + 1) * tasks_per_env, len(all_tasks))
-
-        for domain, example_id in all_tasks[start_idx:end_idx]:
-            if domain not in env_tasks:
-                env_tasks[domain] = []
-            env_tasks[domain].append(example_id)
-
-        distributed_tasks.append(env_tasks)
-
-    return distributed_tasks
+    return all_tasks

-def run_env_tasks(env_idx: int, env: DesktopEnv, agent: PromptAgent, env_tasks: dict, args: argparse.Namespace, shared_scores: list):
+def run_env_tasks(task_queue, args, shared_scores):
     """Run tasks for a single environment."""
-    logger.info(f"Executing tasks in environment {env_idx + 1}/{args.num_envs}")
-
-    for domain in tqdm(env_tasks, desc=f"Env{env_idx+1}-Domain"):
-        for example_id in tqdm(env_tasks[domain], desc="Example", leave=False):
-            config_file = os.path.join(
-                args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
-            )
-            with open(config_file, "r", encoding="utf-8") as f:
-                example = json.load(f)
-
-            logger.info(f"[Env {env_idx+1}][Domain]: {domain}")
-            logger.info(f"[Env {env_idx+1}][Example ID]: {example_id}")
-            logger.info(f"[Env {env_idx+1}][Instruction]: {example['instruction']}")
-
-            example_result_dir = os.path.join(
-                args.result_dir,
-                args.action_space,
-                args.observation_type,
-                args.model,
-                domain,
-                example_id,
-            )
-            os.makedirs(example_result_dir, exist_ok=True)
-
-            try:
-                lib_run_single.run_single_example(
-                    agent,
-                    env,
-                    example,
-                    args.max_steps,
-                    example["instruction"],
-                    args,
-                    example_result_dir,
-                    shared_scores,
-                )
-            except Exception as e:
-                import traceback
-                # logger traceback
-                logger.error(f"Exception in Env{env_idx+1} {domain}/{example_id}: {e}")
-                logger.error(traceback.format_exc())
-                env.controller.end_recording(
-                    os.path.join(example_result_dir, "recording.mp4")
-                )
-                with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
-                    f.write(
-                        json.dumps(
-                            {"Error": f"Time limit exceeded in {domain}/{example_id}"}
-                        )
-                    )
-                    f.write("\n")
-
-    env.close()
-
-
-def test(args: argparse.Namespace, test_all_meta: dict) -> None:
-    logger.info("Args: %s", args)
-
-    distributed_tasks = distribute_tasks(test_all_meta, args.num_envs)
-
-    # First, set up all environments
-    logger.info("Setting up all environments...")
-    envs = []
-    agents = []
-
-    for env_idx in range(args.num_envs):
-        logger.info(f"Setting up environment {env_idx + 1}/{args.num_envs}")
-
+    active_environments = []
+    env = None
+    try:
+        from desktop_env.providers.aws.manager import IMAGE_ID_MAP
+        REGION = args.region
+        screen_size = (args.screen_width, args.screen_height)
+        ami_id = IMAGE_ID_MAP[REGION].get(screen_size, IMAGE_ID_MAP[REGION][(1920, 1080)])
+        env = DesktopEnv(
+            path_to_vm=args.path_to_vm,
+            action_space=args.action_space,
+            provider_name=args.provider_name,
+            region=REGION,
+            snapshot_name=ami_id,
+            screen_size=screen_size,
+            headless=args.headless,
+            os_type="Ubuntu",
+            require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
+            enable_proxy=True,
+            client_password=args.client_password
+        )
+        active_environments.append(env)
         agent = PromptAgent(
+            env=env,
             model=args.model,
             max_tokens=args.max_tokens,
             top_p=args.top_p,
@@ -228,50 +163,193 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
             action_space=args.action_space,
             observation_type=args.observation_type,
             max_trajectory_length=args.max_trajectory_length,
-            screen_size=(args.screen_width, args.screen_height),
+            client_password=args.client_password,
+            provider_name=args.provider_name,
+            screen_width=args.screen_width,
+            screen_height=args.screen_height
         )
-        agents.append(agent)
-
-        from desktop_env.providers.aws.manager import IMAGE_ID_MAP
-        REGION = "us-east-1"
-        env = DesktopEnv(
-            path_to_vm=args.path_to_vm,
-            action_space=agent.action_space,
+        logger.info(f"Process {current_process().name} started.")
+        while True:
+            try:
+                item = task_queue.get(timeout=5)
+            except Exception:
+                break
+            domain, example_id = item
+            try:
+                config_file = os.path.join(
+                    args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
+                )
+                with open(config_file, "r", encoding="utf-8") as f:
+                    example = json.load(f)
+                logger.info(f"[{current_process().name}][Domain]: {domain}")
+                logger.info(f"[{current_process().name}][Example ID]: {example_id}")
+                logger.info(f"[{current_process().name}][Instruction]: {example['instruction']}")
+                example_result_dir = os.path.join(
+                    args.result_dir,
+                    args.action_space,
+                    args.observation_type,
+                    args.model,
+                    domain,
+                    example_id,
+                )
+                os.makedirs(example_result_dir, exist_ok=True)
+                try:
+                    lib_run_single.run_single_example(
+                        agent,
+                        env,
+                        example,
+                        args.max_steps,
+                        example["instruction"],
+                        args,
+                        example_result_dir,
+                        shared_scores,
+                    )
+                except Exception as e:
+                    import traceback
+                    logger.error(f"Exception in {current_process().name} {domain}/{example_id}: {e}")
+                    logger.error(traceback.format_exc())
+                    try:
+                        env.controller.end_recording(
+                            os.path.join(example_result_dir, "recording.mp4")
+                        )
+                    except Exception as rec_e:
+                        logger.error(f"Failed to end recording: {rec_e}")
+                    with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
+                        f.write(
+                            json.dumps(
+                                {"Error": f"{domain}/{example_id} - {e}"}
+                            )
+                        )
+                        f.write("\n")
+            except Exception as e:
+                logger.error(f"Task-level error in {current_process().name}: {e}")
+                import traceback
+                logger.error(traceback.format_exc())
+    except Exception as e:
+        logger.error(f"Process-level error in {current_process().name}: {e}")
+        import traceback
+        logger.error(traceback.format_exc())
+    finally:
+        logger.info(f"{current_process().name} cleaning up environment...")
+        try:
+            if env:
+                env.close()
+                logger.info(f"{current_process().name} environment closed successfully")
+        except Exception as e:
+            logger.error(f"{current_process().name} error during environment cleanup: {e}")
-            provider_name="aws",
-            region="us-east-1",
-            snapshot_name=IMAGE_ID_MAP[REGION],
-            screen_size=(args.screen_width, args.screen_height),
-            headless=args.headless,
-            os_type="Ubuntu",
-            require_a11y_tree=args.observation_type
-            in ["a11y_tree", "screenshot_a11y_tree", "som"],
-        )
-        envs.append(env)
-
-    logger.info("All environments are ready. Starting parallel task execution...")
-
-    # Create a shared list for scores across processes
+
+def process_signal_handler(signum, frame, env_idx):
+    logger.info(f"Process {env_idx + 1} received signal {signum}. Shutting down...")
Shutting down...") + local_vars = frame.f_locals + active_environments = local_vars.get('active_environments', []) + for env in active_environments: + if env is not None: + try: + logger.info(f"Process {env_idx + 1} closing environment...") + env.close() + logger.info(f"Process {env_idx + 1} environment closed successfully") + except Exception as e: + logger.error(f"Process {env_idx + 1} error closing environment: {e}") + logger.info(f"Process {env_idx + 1} shutdown complete. Exiting.") + sys.exit(0) + + +def signal_handler(signum, frame): + global is_terminating, active_environments, processes + if is_terminating: + return + is_terminating = True + logger.info(f"Received signal {signum}. Gracefully shutting down...") + for env in active_environments: + try: + logger.info(f"Closing environment...") + env.close() + logger.info(f"Environment closed successfully") + except Exception as e: + logger.error(f"Error closing environment: {e}") + for p in processes: + if p.is_alive(): + try: + logger.info(f"Sending termination signal to process {p.name}...") + p.terminate() + except Exception as e: + logger.error(f"Error sending termination signal to process: {e}") + time.sleep(1) + for p in processes: + if p.is_alive(): + try: + logger.info(f"Forcefully terminating process {p.name}...") + import signal as sig + os.kill(p.pid, sig.SIGKILL) + except Exception as e: + logger.error(f"Error forcefully terminating process: {e}") + logger.info("Shutdown complete. Exiting.") + sys.exit(0) + + +def test(args: argparse.Namespace, test_all_meta: dict) -> None: + global processes + logger.info("Args: %s", args) + all_tasks = distribute_tasks(test_all_meta) + logger.info(f"Total tasks: {len(all_tasks)}") with Manager() as manager: shared_scores = manager.list() - - # Create and start processes for each environment + task_queue = manager.Queue() + for item in all_tasks: + task_queue.put(item) + num_envs = args.num_envs processes = [] - for env_idx, (env, agent, env_tasks) in enumerate(zip(envs, agents, distributed_tasks)): + for i in range(num_envs): p = Process( target=run_env_tasks, - args=(env_idx, env, agent, env_tasks, args, shared_scores) + args=(task_queue, args, shared_scores), + name=f"EnvProcess-{i+1}" ) - processes.append(p) + p.daemon = True p.start() - - # Wait for all processes to complete - for p in processes: - p.join() - - # Convert shared list to regular list + processes.append(p) + logger.info(f"Started process {p.name} with PID {p.pid}") + try: + while True: + alive_count = 0 + for idx, p in enumerate(processes): + if not p.is_alive(): + logger.warning(f"Process {p.name} died, restarting...") + new_p = Process( + target=run_env_tasks, + args=(task_queue, args, shared_scores), + name=f"EnvProcess-Restart-{idx+1}" + ) + new_p.daemon = True + new_p.start() + processes[idx] = new_p + logger.info(f"Restarted process {new_p.name} with PID {new_p.pid}") + else: + alive_count += 1 + if task_queue.empty(): + logger.info("All tasks finished.") + break + if alive_count == 0: + logger.error("All processes died, exiting.") + break + time.sleep(5) + for p in processes: + p.join() + except KeyboardInterrupt: + logger.info("Main process received KeyboardInterrupt. 
+            raise
+        except Exception as e:
+            logger.error(f"Unexpected error while waiting for processes: {e}", exc_info=True)
+            for p in processes:
+                if p.is_alive():
+                    try:
+                        logger.info(f"Terminating process {p.name} due to error...")
+                        p.terminate()
+                    except Exception as term_e:
+                        logger.error(f"Error terminating process {p.name}: {term_e}")
+            raise
         scores = list(shared_scores)
-
         logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")
@@ -350,46 +428,43 @@ def get_result(action_space, use_model, observation_type, result_dir, total_file
 if __name__ == "__main__":
     ####### The complete version of the list of examples #######
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
+    import signal
+    import time
+    signal.signal(signal.SIGINT, signal_handler)
+    signal.signal(signal.SIGTERM, signal_handler)
+    try:
+        args = config()
-    args = config()
-    # save args to json in result_dir/action_space/observation_type/model/args.json
-    path_to_args = os.path.join(
-        args.result_dir,
-        args.action_space,
-        args.observation_type,
-        args.model,
-        "args.json",
-    )
-    os.makedirs(os.path.dirname(path_to_args), exist_ok=True)
-    with open(path_to_args, "w", encoding="utf-8") as f:
-        json.dump(vars(args), f, indent=4)
+        with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
+            test_all_meta = json.load(f)
-    with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
-        test_all_meta = json.load(f)
+        if args.domain != "all":
+            test_all_meta = {args.domain: test_all_meta[args.domain]}
-    if args.domain != "all":
-        test_all_meta = {args.domain: test_all_meta[args.domain]}
+        test_file_list = get_unfinished(
+            args.action_space,
+            args.model,
+            args.observation_type,
+            args.result_dir,
+            test_all_meta,
+        )
+        left_info = ""
+        for domain in test_file_list:
+            left_info += f"{domain}: {len(test_file_list[domain])}\n"
+        logger.info(f"Left tasks:\n{left_info}")
-    test_file_list = get_unfinished(
-        args.action_space,
-        args.model,
-        args.observation_type,
-        args.result_dir,
-        test_all_meta,
-    )
-    left_info = ""
-    for domain in test_file_list:
-        left_info += f"{domain}: {len(test_file_list[domain])}\n"
-    logger.info(f"Left tasks:\n{left_info}")
-
-    get_result(
-        args.action_space,
-        args.model,
-        args.observation_type,
-        args.result_dir,
-        test_all_meta,
-    )
-    test(args, test_file_list)
-
-
-# path_to_vm can be a list["xxx","xxx"]
\ No newline at end of file
+        get_result(
+            args.action_space,
+            args.model,
+            args.observation_type,
+            args.result_dir,
+            test_all_meta,
+        )
+        test(args, test_file_list)
+    except KeyboardInterrupt:
+        logger.info("Main process received KeyboardInterrupt.")
+    except Exception as e:
+        logger.error(f"Unexpected error in main process: {e}", exc_info=True)
+        signal_handler(signal.SIGTERM, None)
+    finally:
+        logger.info("Main process final cleanup...")
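For reference, the reworked run_multienv_claude.py above replaces the per-environment task lists with a single shared queue that a fixed pool of worker processes drains. A simplified, runnable sketch of that layout (with placeholder tasks and scores standing in for DesktopEnv and the real evaluation) looks like this:

from multiprocessing import Manager, Process

def worker(task_queue, shared_scores):
    # Each worker owns one environment and pulls tasks until the queue is empty.
    while True:
        try:
            domain, example_id = task_queue.get(timeout=5)
        except Exception:
            break  # queue drained, let the worker exit
        shared_scores.append(1.0)  # placeholder for the real per-task score

if __name__ == "__main__":
    with Manager() as manager:
        task_queue = manager.Queue()
        shared_scores = manager.list()
        for task in [("chrome", "example-1"), ("gimp", "example-2")]:
            task_queue.put(task)
        processes = [
            Process(target=worker, args=(task_queue, shared_scores), name=f"EnvProcess-{i+1}", daemon=True)
            for i in range(2)
        ]
        for p in processes:
            p.start()
        for p in processes:
            p.join()
        print(f"Average score: {sum(shared_scores) / len(shared_scores) if shared_scores else 0}")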