feat: add client password argument to multiple agents and scripts

- Introduced `--client_password` argument in `run_multienv_aguvis.py`, `run_multienv_claude.py`, and `run_multienv_gta1.py` for enhanced security and flexibility.
- Updated agent classes (`PromptAgent`, `AguvisAgent`, `GTA1Agent`) to accept and utilize `client_password` for improved configuration.
- Modified evaluation guidelines to reflect the new client password requirement.
- Ensured existing logic remains intact while enhancing functionality for better user experience.
This commit is contained in:
yuanmengqi
2025-07-27 16:11:23 +00:00
parent 122b16742b
commit 523d553e88
9 changed files with 627 additions and 28 deletions

View File

@@ -270,6 +270,7 @@ Use the `run_multienv_xxx.py` scripts to launch tasks in parallel.
Example (with the OpenAI CUA agent):
```bash
# --client_password set to the one you set to the client machine
# Run OpenAI CUA
python run_multienv_openaicua.py \
--headless \
@@ -279,7 +280,8 @@ python run_multienv_openaicua.py \
--test_all_meta_path evaluation_examples/test_all.json \
--region us-east-1 \
--max_steps 50 \
--num_envs 5
--num_envs 5 \
--client_password osworld-public-evaluation
# Run Anthropic (via AWS Bedrock), please modify agent if you want Anthropic endpoint
python run_multienv_claude.py \
@@ -291,7 +293,8 @@ python run_multienv_claude.py \
--test_all_meta_path evaluation_examples/test_all.json \
--max_steps 50 \
--num_envs 5 \
--provider_name aws
--provider_name aws \
--client_password osworld-public-evaluation
```
Key Parameters:
@@ -330,7 +333,7 @@ For more, see: [MONITOR_README](./monitor/README.md)
### 4.2 VNC Remote Desktop Access
We pre-install vnc for every virtual machine so you can have a look on it during the running.
You can access via VNC at`http://<client-public-ip>:5910/vnc.html`
The password set default is `osworld-public-evaluation` to prevent attack.
The password set default is `osworld-public-evaluation` in our AMI to prevent attack.
## 5. Contact the team to update leaderboard and fix errors (optional)

View File

@@ -235,7 +235,8 @@ class PromptAgent:
observation_type="screenshot_a11y_tree",
# observation_type can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"]
max_trajectory_length=3,
a11y_tree_max_tokens=10000
a11y_tree_max_tokens=10000,
client_password="password"
):
self.platform = platform
self.model = model
@@ -246,6 +247,7 @@ class PromptAgent:
self.observation_type = observation_type
self.max_trajectory_length = max_trajectory_length
self.a11y_tree_max_tokens = a11y_tree_max_tokens
self.client_password = client_password
self.thoughts = []
self.actions = []
@@ -281,6 +283,8 @@ class PromptAgent:
raise ValueError("Invalid action space: " + action_space)
else:
raise ValueError("Invalid experiment type: " + observation_type)
self.system_message = self.system_message.format(CLIENT_PASSWORD=self.client_password)
def predict(self, instruction: str, obs: Dict) -> List:
"""

View File

@@ -360,6 +360,7 @@ class AguvisAgent:
temperature=0.5,
action_space="pyautogui",
observation_type="screenshot",
client_password="password"
):
self.platform = platform
self.planner_model = planner_model
@@ -372,6 +373,8 @@ class AguvisAgent:
self.observation_type = observation_type
assert action_space in ["pyautogui"], "Invalid action space"
assert observation_type in ["screenshot"], "Invalid observation type"
self.client_password = client_password
self.thoughts = []
self.actions = []
self.observations = []
@@ -429,7 +432,7 @@ class AguvisAgent:
# So we temporarily separate the planner prompt and aguvis prompt.
planner_messages = []
planner_system_message = AGUVIS_PLANNER_SYS_PROMPT
planner_system_message = AGUVIS_PLANNER_SYS_PROMPT.format(CLIENT_PASSWORD=self.client_password)
planner_messages.append({
"role": "system",
"content": [{"type": "text", "text": planner_system_message}]

View File

@@ -45,6 +45,8 @@ GTA1_MODEL_NMAE = os.environ.get("GTA1_API_KEY",None) #Your served model name
GTA1_SERVICE_URL = os.environ.get("GTA1_SERVICE_URL",None) #"Your GTA1 Service URL"
proxies = None # Your proxies
MAX_RETRY_TIMES = 20
def encode_image(image_content):
return base64.b64encode(image_content).decode("utf-8")
@@ -1126,17 +1128,16 @@ def call_llm_safe(agent):
functions borrow from https://github.com/simular-ai/Agent-S/blob/a0c5c9bf0c526119b1f023c8948563c780729428/gui_agents/s2/utils/common_utils.py#L27
'''
# Retry if fails
max_retries = 3 # Set the maximum number of retries
attempt = 0
response = ""
while attempt < max_retries:
while attempt < MAX_RETRY_TIMES:
try:
response = agent.get_response()
break # If successful, break out of the loop
except Exception as e:
attempt += 1
print(f"Attempt {attempt} failed: {e}")
if attempt == max_retries:
if attempt == MAX_RETRY_TIMES:
print("Max retries reached. Handling failure.")
time.sleep(1.0)
return response
@@ -1200,11 +1201,13 @@ class GTA1Agent:
max_steps=100,
max_image_history_length = 5,
N_SEQ = 8,
client_password="password"
):
self.platform = platform
self.max_tokens = max_tokens
self.top_p = top_p
self.temperature = temperature
self.client_password = client_password
self.action_space = action_space
self.observation_type = observation_type
assert action_space in ["pyautogui"], "Invalid action space"
@@ -1343,7 +1346,7 @@ class GTA1Agent:
valid_responses.extend(valid_responses_)
retry_count += 1
assert len(valid_responses) > int(self.N_SEQ) * 0.8, f"Not enough valid responses generated {len(valid_responses)}"
# assert len(valid_responses) > int(self.N_SEQ) * 0.8, f"Not enough valid responses generated {len(valid_responses)}"
logger.info(f"Executing selection")
if self.N_SEQ > 1:
@@ -1438,7 +1441,7 @@ class GTA1Agent:
)
image = screenshot.resize((height, width))
system_promt = GTA1_JUDGE_SYSTEM_PROMPT.format(N_PLANNING=len(response), N_INDEX=len(response)-1,width=width,height=height)
system_promt = GTA1_JUDGE_SYSTEM_PROMPT.format(N_PLANNING=len(response), N_INDEX=len(response)-1,width=width,height=height, CLIENT_PASSWORD=self.client_password)
lines = [
f"The goal of the task is:\n{instruction}",
]
@@ -1482,7 +1485,7 @@ class GTA1Agent:
}
wait = 1
for _ in range(10):
for _ in range(MAX_RETRY_TIMES):
try:
prediction = requests.post(url, headers=headers, json=payload, proxies=proxies, timeout=180)
if prediction.status_code != 200:

View File

@@ -644,7 +644,7 @@ class OpenAICUAAgent:
"""
Predict the next action(s) based on the current observation.
"""
prompt = OPERATOR_PROMPT.replace("{CLIENT_PASSWORD}", self.client_password)
prompt = OPERATOR_PROMPT.format(CLIENT_PASSWORD=self.client_password)
base64_image = encode_image(obs["screenshot"])
if self.cua_messages == []:

View File

@@ -15,7 +15,7 @@ When you think you have to wait for some time, return ```WAIT```;
When you think the task can not be done, return ```FAIL```, don't easily say ```FAIL```, try your best to do the task;
When you think the task is done, return ```DONE```.
My computer's password is 'password', feel free to use it when you need sudo rights.
My computer's password is '{CLIENT_PASSWORD}', feel free to use it when you need sudo rights.
First give the current screenshot and previous things we did a short reflection, then RETURN ME THE CODE OR SPECIAL CODE I ASKED FOR. NEVER EVER RETURN ME ANYTHING ELSE.
""".strip()
@@ -36,7 +36,7 @@ When you think you have to wait for some time, return ```WAIT```;
When you think the task can not be done, return ```FAIL```, don't easily say ```FAIL```, try your best to do the task;
When you think the task is done, return ```DONE```.
My computer's password is 'password', feel free to use it when you need sudo rights.
My computer's password is '{CLIENT_PASSWORD}', feel free to use it when you need sudo rights.
Our past communication is great, and what you have done is very helpful. I will now give you another task to complete.
First take a deep breath, think step by step, give the current screenshot a thinking, then RETURN ME THE CODE OR SPECIAL CODE I ASKED FOR. NEVER EVER RETURN ME ANYTHING ELSE.
""".strip()
@@ -550,7 +550,7 @@ When you think you have to wait for some time, return ```WAIT```;
When you think the task can not be done, return ```FAIL```, don't easily say ```FAIL```, try your best to do the task;
When you think the task is done, return ```DONE```.
My computer's password is 'password', feel free to use it when you need sudo rights.
My computer's password is '{CLIENT_PASSWORD}', feel free to use it when you need sudo rights.
First give the current screenshot and previous things we did a short reflection, then RETURN ME THE CODE OR SPECIAL CODE I ASKED FOR. NEVER EVER RETURN ME ANYTHING ELSE.
""".strip()
@@ -817,7 +817,7 @@ When you think you have to wait for some time, return ```WAIT```;
When you think the task can not be done, return ```FAIL```, don't easily say ```FAIL```, try your best to do the task;
When you think the task is done, return ```DONE```.
My computer's password is 'password', feel free to use it when you need sudo rights.
My computer's password is '{CLIENT_PASSWORD}', feel free to use it when you need sudo rights.
First give the current screenshot and previous things we did a short reflection, then RETURN ME THE CODE OR SPECIAL CODE I ASKED FOR. NEVER EVER RETURN ME ANYTHING ELSE.
""".strip()
@@ -1092,7 +1092,7 @@ When you think you have to wait for some time, return ```WAIT```;
When you think the task can not be done, return ```FAIL```, don't easily say ```FAIL```, try your best to do the task;
When you think the task is done, return ```DONE```.
My computer's password is 'password', feel free to use it when you need sudo rights.
My computer's password is '{CLIENT_PASSWORD}', feel free to use it when you need sudo rights.
First give the current screenshot and previous things we did a short reflection, then RETURN ME THE CODE OR SPECIAL CODE I ASKED FOR. NEVER EVER RETURN ME ANYTHING ELSE.
""".strip()
@@ -1142,7 +1142,7 @@ When you think you have to wait for some time, return ```WAIT```;
When you think the task can not be done, return ```FAIL```, don't easily say ```FAIL```, try your best to do the task;
When you think the task is done, return ```DONE```.
My computer's password is 'password', feel free to use it when you need sudo rights.
My computer's password is '{CLIENT_PASSWORD}', feel free to use it when you need sudo rights.
First give the current screenshot and previous things we did a short reflection, then RETURN ME THE CODE OR SPECIAL CODE I ASKED FOR. NEVER EVER RETURN ME ANYTHING ELSE.
"""
@@ -1168,7 +1168,7 @@ Here are some guidelines for you:
2. If a click action is needed, use only the following functions: pyautogui.click, pyautogui.rightClick or pyautogui.doubleClick.
3. Return ```Done``` when you think the task is done. Return ```Fail``` when you think the task can not be done.
My computer's password is 'password', feel free to use it when you need sudo rights.
My computer's password is '{CLIENT_PASSWORD}', feel free to use it when you need sudo rights.
First give the current screenshot and previous things we did a short reflection, then RETURN ME THE CODE OR SPECIAL CODE I ASKED FOR. NEVER EVER RETURN ME ANYTHING ELSE.
""".strip()
@@ -1335,7 +1335,7 @@ Here are some guidelines for you:
2. If a click action is needed, use only the following functions: pyautogui.click, pyautogui.rightClick or pyautogui.doubleClick.
3. Return ```Done``` when you think the task is done. Return ```Fail``` when you think the task can not be done.
My computer's password is 'password', feel free to use it when you need sudo rights.
My computer's password is '{CLIENT_PASSWORD}', feel free to use it when you need sudo rights.
First give the current screenshot and previous things we did a short reflection, then RETURN ME THE CODE OR SPECIAL CODE I ASKED FOR NEVER EVER RETURN ME ANYTHING ELSE.
"""
@@ -1475,7 +1475,7 @@ Here are some guidelines for you:
5. Save modified files before returning ```agent.done()```. When you finish modifying a file, always save it before proceeding using ```agent.hotkey(['ctrl', 's'])``` or equivalent. Tasks may involve multiple files. Save each after finishing modification.
6. If you meet "Authentication required" prompt, you can continue to click "Cancel" to close it.
My computer's password is 'password', feel free to use it when you need sudo rights.
My computer's password is '{CLIENT_PASSWORD}', feel free to use it when you need sudo rights.
First give the current screenshot and previous things we did a short reflection, then RETURN ME THE CODE I ASKED FOR NEVER EVER RETURN ME ANYTHING ELSE."""
GTA1_GROUNDING_SYSTEM_PROMPT = '''
@@ -1488,7 +1488,7 @@ Output the coordinate pair exactly:
GTA1_JUDGE_SYSTEM_PROMPT='''
You are an expert at evaluating the planning and reasoning of UI agents working toward achieving a goal.
My computer's password is 'password', feel free to use it when you need sudo rights or login.
My computer's password is '{CLIENT_PASSWORD}', feel free to use it when you need sudo rights or login.
Each time, I will provide you with:
- The current screenshot of the UI of width {width} and height {height}
@@ -1517,3 +1517,56 @@ Respond **only** with valid JSON (no extra keys or comments):
}}
```
'''.strip()
O3_SYSTEM_PROMPT = """
You are an agent which follow my instruction and perform desktop computer tasks as instructed.
You have good knowledge of computer and good internet connection and assume your code will run on a computer for controlling the mouse and keyboard.
You are on Ubuntu operating system and the resolution of the screen is 1920x1080.
For each step, you will get an observation of an image, which is the screenshot of the computer screen and you will predict the action of the computer based on the image.
The following rules are IMPORTANT:
- If previous actions didn't achieve the expected result, do not repeat them, especially the last one. Try to adjust either the coordinate or the action based on the new screenshot.
- Do not predict multiple clicks at once. Base each action on the current screenshot; do not predict actions for elements or events not yet visible in the screenshot.
- You cannot complete the task by outputting text content in your response. You must use mouse and keyboard to interact with the computer. Return ```Fail``` when you think the task can not be done.
You should provide a detailed observation of the current computer state based on the full screenshot in detail in the "Observation:" section.
Provide any information that is possibly relevant to achieving the task goal and any elements that may affect the task execution, such as pop-ups, notifications, error messages, loading states, etc..
You MUST return the observation before the thought.
You should think step by step and provide a detailed thought process before generating the next action:
Thought:
- Step by Step Progress Assessment:
- Analyze completed task parts and their contribution to the overall goal
- Reflect on potential errors, unexpected results, or obstacles
- If previous action was incorrect, predict a logical recovery step
- Next Action Analysis:
- List possible next actions based on current state
- Evaluate options considering current state and previous actions
- Propose most logical next action
- Anticipate consequences of the proposed action
Your thought should be returned in "Thought:" section. You MUST return the thought before the code.
You are required to use `pyautogui` to perform the action grounded to the observation, but DONOT use the `pyautogui.locateCenterOnScreen` function to locate the element you want to operate with since we have no image of the element you want to operate with. DONOT USE `pyautogui.screenshot()` to make screenshot.
Return exactly ONE line of python code to perform the action each time. At each step, you MUST generate the corresponding instruction to the code before a # in a comment (example: # Click \"Yes, I trust the authors\" button\npyautogui.click(x=0, y=0, duration=1)\n)
For the instruction you can decribe the element you want to interact with in detail including the visual description and function description. And make it clear and concise.
For example you can describe what the element looks like, and what will be the expected result when you interact with it.
You need to to specify the coordinates of by yourself based on your observation of current observation, but you should be careful to ensure that the coordinates are correct.
Remember you should only return ONE line of code, DO NOT RETURN more. You should return the code inside a code block, like this:
```python
# your code here
```
Specially, it is also allowed to return the following special code:
When you think you have to wait for some time, return ```WAIT```;
When you think the task can not be done, return ```FAIL```, don't easily say ```FAIL```, try your best to do the task;
When you think the task is done, return ```DONE```.
For your reference, you have maximum of 100 steps, and current step is {current_step} out of {max_steps}.
If you are in the last step, you should return ```DONE``` or ```FAIL``` according to the result.
Here are some guidelines for you:
1. Remember to generate the corresponding instruction to the code before a # in a comment and only return ONE line of code.
2. If a click action is needed, use only the following functions: pyautogui.click, pyautogui.rightClick or pyautogui.doubleClick.
3. Return ```Done``` when you think the task is done. Return ```Fail``` when you think the task can not be done.
My computer's password is '{CLIENT_PASSWORD}', feel free to use it when you need sudo rights.
First give the current screenshot and previous things we did a short reflection, then RETURN ME THE CODE OR SPECIAL CODE I ASKED FOR NEVER EVER RETURN ME ANYTHING ELSE.
"""

View File

@@ -80,6 +80,12 @@ def config() -> argparse.Namespace:
default="screenshot",
help="Observation type",
)
parser.add_argument(
"--provider_name", type=str, default="aws", choices=["aws", "virtualbox", "vmware", "docker", "azure"], help="Provider name"
)
parser.add_argument(
"--client_password", type=str, default="", help="Client password"
)
parser.add_argument("--screen_width", type=int, default=1920)
parser.add_argument("--screen_height", type=int, default=1080)
parser.add_argument("--sleep_after_execution", type=float, default=2.0)
@@ -216,6 +222,7 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
temperature=args.temperature,
action_space=args.action_space,
observation_type=args.observation_type,
client_password=args.client_password
)
agents.append(agent)
@@ -227,7 +234,8 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
os_type="Ubuntu",
require_a11y_tree=args.observation_type
in ["a11y_tree", "screenshot_a11y_tree", "som"],
provider_name = "docker"
provider_name = args.provider_name,
client_password=args.client_password
)
envs.append(env)

View File

@@ -14,8 +14,7 @@ from tqdm import tqdm
from multiprocessing import Process, Manager, current_process
import lib_run_single
from desktop_env.desktop_env import DesktopEnv
from mm_agents.anthropic import AnthropicAgent as PromptAgent
from mm_agents.anthropic import AnthropicAgent
# .env
from dotenv import load_dotenv
@@ -152,7 +151,7 @@ def run_env_tasks(task_queue, args, shared_scores):
client_password=args.client_password
)
active_environments.append(env)
agent = PromptAgent(
agent = AnthropicAgent(
env=env,
model=args.model,
max_tokens=args.max_tokens,
@@ -161,10 +160,9 @@ def run_env_tasks(task_queue, args, shared_scores):
action_space=args.action_space,
observation_type=args.observation_type,
max_trajectory_length=args.max_trajectory_length,
client_password=args.client_password,
provider_name=args.provider_name,
screen_width=args.screen_width,
screen_height=args.screen_height
screen_height=args.screen_height,
)
logger.info(f"Process {current_process().name} started.")
while True:

527
run_multienv_gta1.py Normal file
View File

@@ -0,0 +1,527 @@
from __future__ import annotations
import argparse
import datetime
import json
import logging
import os
import sys
import signal
import time
from typing import List, Dict
import math
from tqdm import tqdm
from multiprocessing import Process, Manager
from multiprocessing import current_process
import lib_run_single
from desktop_env.desktop_env import DesktopEnv
from mm_agents.gta1_agent import GTA1Agent
# Global variables for signal handling
active_environments = []
processes = []
is_terminating = False
# import wandb
# load the environment variables from .env file
if os.path.exists(".env"):
from dotenv import load_dotenv
load_dotenv()
# Logger Configs {{{ #
def config() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Run end-to-end evaluation on the benchmark"
)
# environment config
parser.add_argument("--path_to_vm", type=str, default=None)
parser.add_argument(
"--headless", action="store_true", help="Run in headless machine"
)
parser.add_argument(
"--action_space", type=str, default="pyautogui", help="Action type"
)
parser.add_argument(
"--observation_type",
choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"],
default="screenshot",
help="Observation type",
)
parser.add_argument("--sleep_after_execution", type=float, default=0.0)
parser.add_argument("--max_steps", type=int, default=15)
# agent config
parser.add_argument(
"--test_config_base_dir", type=str, default="evaluation_examples"
)
# lm config
parser.add_argument("--model", type=str, default="o3")
# example config
parser.add_argument("--domain", type=str, default="all")
parser.add_argument(
"--test_all_meta_path", type=str, default="evaluation_examples/test_all.json"
)
# logging related
parser.add_argument("--result_dir", type=str, default="./results")
parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to run in parallel")
parser.add_argument("--log_level", type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
default='INFO', help="Set the logging level")
# aws config
parser.add_argument(
"--region", type=str, default="us-east-1", help="AWS region for the VM"
)
parser.add_argument(
"--provider_name", type=str, default="aws", choices=["aws", "virtualbox", "vmware", "docker", "azure"], help="Provider name"
)
parser.add_argument(
"--client_password", type=str, default="", help="Client password"
)
parser.add_argument(
"--screen_width", type=int, default=1920, help="Screen width"
)
parser.add_argument(
"--screen_height", type=int, default=1080, help="Screen height"
)
args = parser.parse_args()
return args
args = config() # Get command line arguments first
logger = logging.getLogger()
log_level = getattr(logging, args.log_level.upper())
logger.setLevel(log_level)
datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
file_handler = logging.FileHandler(
os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8"
)
debug_handler = logging.FileHandler(
os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8"
)
stdout_handler = logging.StreamHandler(sys.stdout)
file_handler.setLevel(logging.INFO)
debug_handler.setLevel(logging.DEBUG)
stdout_handler.setLevel(log_level)
formatter = logging.Formatter(
fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
)
file_handler.setFormatter(formatter)
debug_handler.setFormatter(formatter)
stdout_handler.setFormatter(formatter)
stdout_handler.addFilter(logging.Filter("desktopenv"))
logger.addHandler(file_handler)
logger.addHandler(debug_handler)
logger.addHandler(stdout_handler)
# }}} Logger Configs #
logger = logging.getLogger("desktopenv.experiment")
def distribute_tasks(test_all_meta: dict) -> List[tuple]:
all_tasks = []
for domain, examples in test_all_meta.items():
for example_id in examples:
all_tasks.append((domain, example_id))
return all_tasks
def process_signal_handler(signum, frame, env_idx):
"""Signal handler for child processes to gracefully shut down their environments."""
logger.info(f"Process {env_idx + 1} received signal {signum}. Shutting down...")
# Get the active_environments from the caller's frame
local_vars = frame.f_locals
active_environments = local_vars.get('active_environments', [])
# Close environment in the current process context
for env in active_environments:
if env is not None:
try:
logger.info(f"Process {env_idx + 1} closing environment...")
env.close()
logger.info(f"Process {env_idx + 1} environment closed successfully")
except Exception as e:
logger.error(f"Process {env_idx + 1} error closing environment: {e}")
logger.info(f"Process {env_idx + 1} shutdown complete. Exiting.")
sys.exit(0)
def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: list):
active_environments = []
env = None
try:
from desktop_env.providers.aws.manager import IMAGE_ID_MAP
REGION = args.region
screen_size = (args.screen_width, args.screen_height)
ami_id = IMAGE_ID_MAP[REGION].get(screen_size, IMAGE_ID_MAP[REGION][(1920, 1080)])
env = DesktopEnv(
path_to_vm=args.path_to_vm,
action_space=args.action_space,
provider_name=args.provider_name,
region=REGION,
snapshot_name=ami_id,
screen_size=screen_size,
headless=args.headless,
os_type="Ubuntu",
require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
enable_proxy=True,
client_password=args.client_password
)
active_environments.append(env)
agent = GTA1Agent(
max_steps=args.max_steps,
client_password=args.client_password,
)
logger.info(f"Process {current_process().name} started.")
while True:
try:
item = task_queue.get(timeout=5)
except Exception:
break
domain, example_id = item
try:
config_file = os.path.join(
args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
)
with open(config_file, "r", encoding="utf-8") as f:
example = json.load(f)
logger.info(f"[{current_process().name}][Domain]: {domain}")
logger.info(f"[{current_process().name}][Example ID]: {example_id}")
logger.info(f"[{current_process().name}][Instruction]: {example['instruction']}")
example_result_dir = os.path.join(
args.result_dir,
args.action_space,
args.observation_type,
args.model,
domain,
example_id,
)
os.makedirs(example_result_dir, exist_ok=True)
try:
lib_run_single.run_single_example(
agent,
env,
example,
args.max_steps,
example["instruction"],
args,
example_result_dir,
shared_scores,
)
except Exception as e:
import traceback
logger.error(f"Exception in {current_process().name} {domain}/{example_id}: {e}")
logger.error(traceback.format_exc())
try:
env.controller.end_recording(
os.path.join(example_result_dir, "recording.mp4")
)
except Exception as rec_e:
logger.error(f"Failed to end recording: {rec_e}")
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
f.write(
json.dumps(
{"Error": f"{domain}/{example_id} - {e}"}
)
)
f.write("\n")
except Exception as e:
logger.error(f"Task-level error in {current_process().name}: {e}")
import traceback
logger.error(traceback.format_exc())
except Exception as e:
logger.error(f"Process-level error in {current_process().name}: {e}")
import traceback
logger.error(traceback.format_exc())
finally:
logger.info(f"{current_process().name} cleaning up environment...")
try:
if env:
env.close()
logger.info(f"{current_process().name} environment closed successfully")
except Exception as e:
logger.error(f"{current_process().name} error during environment cleanup: {e}")
def signal_handler(signum, frame):
"""Handle termination signals (SIGINT, SIGTERM) to gracefully shutdown environments."""
global is_terminating, active_environments, processes
# Avoid duplicate handling
if is_terminating:
return
is_terminating = True
logger.info(f"Received signal {signum}. Gracefully shutting down...")
# Close all registered environments in the main process
for env in active_environments:
try:
logger.info(f"Closing environment...")
env.close()
logger.info(f"Environment closed successfully")
except Exception as e:
logger.error(f"Error closing environment: {e}")
# Send termination signal to all child processes first
for p in processes:
if p.is_alive():
try:
logger.info(f"Sending termination signal to process {p.name}...")
p.terminate()
except Exception as e:
logger.error(f"Error sending termination signal to process: {e}")
# Allow a short time for processes to handle their own cleanup
time.sleep(1)
# Forcefully terminate any processes that didn't exit
for p in processes:
if p.is_alive():
try:
logger.info(f"Forcefully terminating process {p.name}...")
import signal as sig
os.kill(p.pid, sig.SIGKILL)
except Exception as e:
logger.error(f"Error forcefully terminating process: {e}")
logger.info("Shutdown complete. Exiting.")
sys.exit(0)
def test(args: argparse.Namespace, test_all_meta: dict) -> None:
global processes
logger.info("Args: %s", args)
all_tasks = distribute_tasks(test_all_meta)
logger.info(f"Total tasks: {len(all_tasks)}")
with Manager() as manager:
shared_scores = manager.list()
task_queue = manager.Queue()
for item in all_tasks:
task_queue.put(item)
num_envs = args.num_envs
processes = []
for i in range(num_envs):
p = Process(
target=run_env_tasks,
args=(task_queue, args, shared_scores),
name=f"EnvProcess-{i+1}"
)
p.daemon = True
p.start()
processes.append(p)
logger.info(f"Started process {p.name} with PID {p.pid}")
try:
while True:
alive_count = 0
for idx, p in enumerate(processes):
if not p.is_alive():
logger.warning(f"Process {p.name} died, restarting...")
new_p = Process(
target=run_env_tasks,
args=(task_queue, args, shared_scores),
name=f"EnvProcess-Restart-{idx+1}"
)
new_p.daemon = True
new_p.start()
processes[idx] = new_p
logger.info(f"Restarted process {new_p.name} with PID {new_p.pid}")
else:
alive_count += 1
if task_queue.empty():
logger.info("All tasks finished.")
break
if alive_count == 0:
logger.error("All processes died, exiting.")
break
time.sleep(5)
for p in processes:
p.join()
except KeyboardInterrupt:
logger.info("Main process received KeyboardInterrupt. Initiating graceful shutdown...")
raise
except Exception as e:
logger.error(f"Unexpected error while waiting for processes: {e}", exc_info=True)
for p in processes:
if p.is_alive():
try:
logger.info(f"Terminating process {p.name} due to error...")
p.terminate()
except Exception as term_e:
logger.error(f"Error terminating process {p.name}: {term_e}")
raise
scores = list(shared_scores)
logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")
def get_unfinished(
action_space, use_model, observation_type, result_dir, total_file_json
):
target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
if not os.path.exists(target_dir):
return total_file_json
finished = {}
for domain in os.listdir(target_dir):
finished[domain] = []
domain_path = os.path.join(target_dir, domain)
if os.path.isdir(domain_path):
for example_id in os.listdir(domain_path):
if example_id == "onboard":
continue
example_path = os.path.join(domain_path, example_id)
if os.path.isdir(example_path):
if "result.txt" not in os.listdir(example_path):
# empty all files under example_id
for file in os.listdir(example_path):
os.remove(os.path.join(example_path, file))
else:
finished[domain].append(example_id)
if not finished:
return total_file_json
for domain, examples in finished.items():
if domain in total_file_json:
total_file_json[domain] = [
x for x in total_file_json[domain] if x not in examples
]
return total_file_json
def get_result(action_space, use_model, observation_type, result_dir, total_file_json):
target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
if not os.path.exists(target_dir):
print("New experiment, no result yet.")
return None
all_result = []
for domain in os.listdir(target_dir):
domain_path = os.path.join(target_dir, domain)
if os.path.isdir(domain_path):
for example_id in os.listdir(domain_path):
example_path = os.path.join(domain_path, example_id)
if os.path.isdir(example_path):
if "result.txt" in os.listdir(example_path):
# empty all files under example_id
try:
all_result.append(
float(
open(
os.path.join(example_path, "result.txt"), "r"
).read()
)
)
except:
all_result.append(0.0)
if not all_result:
print("New experiment, no result yet.")
return None
else:
print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%")
return all_result
if __name__ == "__main__":
####### The complete version of the list of examples #######
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Register signal handlers for graceful termination
signal.signal(signal.SIGINT, signal_handler) # Handle Ctrl+C
signal.signal(signal.SIGTERM, signal_handler) # Handle termination signal
try:
args = config()
# save args to json in result_dir/action_space/observation_type/model/args.json
path_to_args = os.path.join(
args.result_dir,
args.action_space,
args.observation_type,
args.model,
"args.json",
)
os.makedirs(os.path.dirname(path_to_args), exist_ok=True)
with open(path_to_args, "w", encoding="utf-8") as f:
json.dump(vars(args), f, indent=4)
with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
test_all_meta = json.load(f)
if args.domain != "all":
test_all_meta = {args.domain: test_all_meta[args.domain]}
test_file_list = get_unfinished(
args.action_space,
args.model,
args.observation_type,
args.result_dir,
test_all_meta,
)
left_info = ""
for domain in test_file_list:
left_info += f"{domain}: {len(test_file_list[domain])}\n"
logger.info(f"Left tasks:\n{left_info}")
get_result(
args.action_space,
args.model,
args.observation_type,
args.result_dir,
test_all_meta,
)
test(args, test_file_list)
except KeyboardInterrupt:
logger.info("Main process received KeyboardInterrupt.")
# Signal handler will take care of cleanup
except Exception as e:
logger.error(f"Unexpected error in main process: {e}", exc_info=True)
# Also trigger cleanup for unhandled exceptions
signal_handler(signal.SIGTERM, None)
finally:
# Final cleanup in case any environments or processes remain
logger.info("Main process final cleanup...")
for env in active_environments:
if env is not None:
try:
logger.info(f"Closing environment in final cleanup...")
env.close()
logger.info(f"Environment closed successfully in final cleanup")
except Exception as e:
logger.error(f"Error during final environment cleanup: {e}")
# First try gentle termination
for p in processes:
if p is not None and p.is_alive():
try:
logger.info(f"Terminating process {p.name}...")
p.terminate()
except Exception as e:
logger.error(f"Error terminating process: {e}")
# Wait a moment for processes to terminate
time.sleep(1)
# Then force kill if needed
for p in processes:
if p is not None and p.is_alive():
try:
logger.info(f"Force killing process {p.name}...")
os.kill(p.pid, signal.SIGKILL)
logger.info(f"Process {p.name} force killed")
except Exception as e:
logger.error(f"Error force killing process: {e}")