diff --git a/lib_run_single.py b/lib_run_single.py
index dda5407..51518e2 100644
--- a/lib_run_single.py
+++ b/lib_run_single.py
@@ -172,9 +172,7 @@ def run_single_example_opencua(agent, env, example, max_steps, instruction, args
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
logger.info("Step %d: %s", step_idx + 1, action)
- obs, reward, done, info = env.step(action)
- time.sleep(3)
- obs = env._get_obs()
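+            # env.step now applies args.sleep_after_execution itself, replacing the explicit sleep and _get_obs() above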
+ obs, reward, done, info = env.step(action, args.sleep_after_execution)
logger.info(f"Action {action} executed, reward: {reward}, done: {done}")
# Save screenshot and trajectory information
diff --git a/mm_agents/opencua_agent.py b/mm_agents/opencua_agent.py
index c1ddb60..4e1b371 100644
--- a/mm_agents/opencua_agent.py
+++ b/mm_agents/opencua_agent.py
@@ -571,11 +571,6 @@ class OpenCUAAgent:
logger.info(f"========================== {self.model} ===================================")
logger.info(f"Instruction: \n{instruction}")
- image_bytes = BytesIO(obs['screenshot'])
- with Image.open(image_bytes) as img:
- print("Actual screen size", img.size)
- print("Logical screen size", self.screen_size)
-
messages = []
messages.append({
"role": "system",
@@ -598,7 +593,7 @@ class OpenCUAAgent:
history_content = STEP_TEMPLATE.format(step_num=i+1) + self.HISTORY_TEMPLATE.format(
observation=self.cots[i].get('observation'),
thought=self.cots[i].get('thought'),
- action=self.cots[i]['action']
+ action=self.cots[i].get('action')
)
messages.append({
@@ -609,7 +604,7 @@ class OpenCUAAgent:
history_content = STEP_TEMPLATE.format(step_num=i+1) + self.HISTORY_TEMPLATE.format(
observation=self.cots[i].get('observation'),
thought=self.cots[i].get('thought'),
- action=self.cots[i]['action']
+ action=self.cots[i].get('action')
)
history_step_texts.append(history_content)
if i == len(self.actions) - self.max_image_history_length:
@@ -640,7 +635,7 @@ class OpenCUAAgent:
"temperature": self.temperature
}, self.model)
- logger.info(f"Model Output: \n\n{response}")
+ logger.info(f"Model Output: \n{response}")
if not response:
logger.error("No response found in the response.")
return "ERROR", [], {}
@@ -666,23 +661,23 @@ class OpenCUAAgent:
self.cots.append(other_cot)
# Print message structure if needed
- logger.info(f"\nInstruction: {instruction}")
- messages_to_print = []
- current_image = 1
- for msg in messages:
- msg_copy = copy.deepcopy(msg)
- if isinstance(msg_copy['content'], list):
- for content in msg_copy['content']:
- if content['type'] == 'image_url':
- content['image_url']['url'] = f'Image {current_image}'
- current_image += 1
- messages_to_print.append(msg_copy)
+ # messages_to_print = []
+ # current_image = 1
+ # for msg in messages:
+ # msg_copy = copy.deepcopy(msg)
+ # if isinstance(msg_copy['content'], list):
+ # for content in msg_copy['content']:
+ # if content['type'] == 'image_url':
+ # content['image_url']['url'] = f'Image {current_image}'
+ # current_image += 1
+ # messages_to_print.append(msg_copy)
- messages_to_print.append({
- "new_step_cot": other_cot,
- "response": response
- })
- logger.info(json.dumps(messages_to_print, indent=2))
+ # messages_to_print.append({
+ # "new_step_cot": other_cot,
+ # "response": response
+ # })
+ # logger.info(json.dumps(messages_to_print, indent=2))
+ logger.info(f"New step cot: {other_cot}")
return response, pyautogui_actions, {}
@@ -720,4 +715,10 @@ class OpenCUAAgent:
logger.error("Retrying...")
time.sleep(5)
else:
- return response.json()['choices'][0]['message']['content']
+ response = response.json()
+                finish_reason = response["choices"][0].get("finish_reason")
+                if finish_reason == "stop":  # in the common case the reply fits within max_tokens
+                    return response['choices'][0]['message']['content']
+                else:
+                    # e.g. finish_reason == "length": the output was truncated, so fall through and retry
+                    logger.error(f"LLM did not finish properly (finish_reason={finish_reason}), retrying...")
+                    time.sleep(5)
diff --git a/mm_agents/uitars15_agent.py b/mm_agents/uitars15_agent.py
index d8213b7..34280cd 100644
--- a/mm_agents/uitars15_agent.py
+++ b/mm_agents/uitars15_agent.py
@@ -3,7 +3,8 @@ import os
import re
import base64
import requests
-from typing import Optional, Dict, List, Tuple
+import logging
+from typing import Optional, Dict, List, Tuple, Union
from loguru import logger
import ast
@@ -573,7 +574,34 @@ finished(content='xxx') # Use escape characters \\', \\", and \\n in content par
GROUNDING_DOUBAO = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task. \n\n## Output Format\n\nAction: ...\n\n\n## Action Space\nclick(point='x1 y1'')\n\n## User Instruction
{instruction}"""
+COMPUTER_USE_NO_THINKING = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
+## Output Format
+```
+Thought: ...
+Action: ...
+```
+
+## Action Space
+
+click(point='x1 y1')
+left_double(point='x1 y1')
+right_single(point='x1 y1')
+drag(start_point='x1 y1', end_point='x2 y2')
+hotkey(key='ctrl c') # Split keys with a space and use lowercase. Also, do not use more than 3 keys in one hotkey action.
+type(content='xxx') # Use escape characters \\', \\\", and \\n in content part to ensure we can parse the content in normal python string format. If you want to submit your input, use \\n at the end of content.
+scroll(point='x1 y1', direction='down or up or right or left') # Show more information on the `direction` side.
+wait() # Sleep for 5s and take a screenshot to check for any changes.
+finished(content='xxx') # Use escape characters \\', \\", and \\n in content part to ensure we can parse the content in normal python string format.
+
+
+## Note
+- Use Chinese in `Thought` part.
+- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.
+
+## User Instruction
+{instruction}
+"""
class UITarsAgent:
"""
@@ -638,9 +666,11 @@ class UITarsAgent:
self.history_images = []
self.history_responses = []
- self.system_prompt = COMPUTER_USE_DOUBAO
+ if use_thinking:
+ self.system_prompt = COMPUTER_USE_DOUBAO
+ else:
+ self.system_prompt = COMPUTER_USE_NO_THINKING
-
self.action_parse_res_factor = 1000
self.model_type = "doubao"
self.history_n = 5
@@ -648,6 +678,9 @@ class UITarsAgent:
self.temperature = temperature
self.max_tokens = max_tokens
self.platform = "ubuntu"
+ self.use_thinking = use_thinking
+
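+        # choose the request helper once; predict() dispatches through self.inference_func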
+        self.inference_func = self.inference_with_thinking if use_thinking else self.inference_without_thinking
def reset(self, _logger=None):
global logger
@@ -721,7 +754,36 @@ class UITarsAgent:
"details": response.text
}
- def predict(self, task_instruction: str, obs: dict) -> Tuple[str, List]:
+ def inference_without_thinking(self, messages):
+ api_key = os.environ['DOUBAO_API_KEY']
+ api_url = os.environ['DOUBAO_API_URL']
+ headers = {
+ 'Authorization': f'Bearer {api_key}',
+ 'Content-Type': 'application/json'
+ }
+ data = {
+ "model": self.model,
+ "messages": messages,
+ "thinking": {"type": "disabled"},
+ "max_tokens": self.max_tokens,
+ "top_p": self.top_p,
+ "temperature": self.temperature,
+ }
+
+ response = requests.post(api_url, headers=headers, json=data)
+
+ if response.status_code == 200:
+ return response.json()["choices"][0]["message"]["content"]
+ else:
+ print(f"Request failed with status code {response.status_code}")
+ print(response.json())
+ return {
+ "error": f"Request failed with status code {response.status_code}",
+ "details": response.text
+ }
+
+ def predict(self, task_instruction: str, obs: dict) -> Tuple[Union[str, Dict, None], List]:
"""Predict the next action based on the current observation."""
self.task_instruction = task_instruction
@@ -793,7 +855,7 @@ class UITarsAgent:
return prediction, ["FAIL"]
try:
logger.info(f"Messages: {self.pretty_print_messages(messages[-1])}")
- prediction = self.inference_with_thinking(messages)
+ prediction = self.inference_func(messages)
except Exception as e:
self.logger.error(f"Error when fetching response from client, with error:\n{e}")
diff --git a/run_multienv_opencua.py b/run_multienv_opencua.py
index 1765976..6f8af63 100644
--- a/run_multienv_opencua.py
+++ b/run_multienv_opencua.py
@@ -11,6 +11,7 @@ from typing import List, Dict
import math
from tqdm import tqdm
from multiprocessing import Process, Manager
+from multiprocessing import current_process, Queue
import lib_run_single
from desktop_env.desktop_env import DesktopEnv
from mm_agents.opencua_agent import OpenCUAAgent
@@ -45,7 +46,7 @@ def config() -> argparse.Namespace:
default="screenshot",
help="Observation type",
)
- parser.add_argument("--sleep_after_execution", type=float, default=0.0)
+ parser.add_argument("--sleep_after_execution", type=float, default=3.0)
parser.add_argument("--max_steps", type=int, default=15)
# evaluation config
@@ -57,7 +58,7 @@ def config() -> argparse.Namespace:
parser.add_argument("--model", type=str, default="opencua")
parser.add_argument("--temperature", type=float, default=0)
parser.add_argument("--top_p", type=float, default=0.9)
- parser.add_argument("--max_tokens", type=int, default=1500)
+ parser.add_argument("--max_tokens", type=int, default=8196)
parser.add_argument("--stop_token", type=str, default=None)
# OpenCUAagent config
@@ -133,32 +134,12 @@ logger.addHandler(stdout_handler)
logger = logging.getLogger("desktopenv.experiment")
-def distribute_tasks(test_all_meta: dict, num_envs: int) -> List[Dict]:
- """Distribute tasks evenly across environments."""
- # Flatten the tasks into a single list
+def distribute_tasks(test_all_meta: dict) -> List[tuple]:
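+    """Flatten test_all_meta into a single list of (domain, example_id) task tuples."""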
all_tasks = []
for domain, examples in test_all_meta.items():
for example_id in examples:
all_tasks.append((domain, example_id))
-
- # Calculate tasks per environment
- tasks_per_env = math.ceil(len(all_tasks) / num_envs)
-
- # Distribute tasks
- distributed_tasks = []
- for i in range(num_envs):
- env_tasks = {}
- start_idx = i * tasks_per_env
- end_idx = min((i + 1) * tasks_per_env, len(all_tasks))
-
- for domain, example_id in all_tasks[start_idx:end_idx]:
- if domain not in env_tasks:
- env_tasks[domain] = []
- env_tasks[domain].append(example_id)
-
- distributed_tasks.append(env_tasks)
-
- return distributed_tasks
+ return all_tasks
def process_signal_handler(signum, frame, env_idx):
@@ -182,51 +163,45 @@ def process_signal_handler(signum, frame, env_idx):
logger.info(f"Process {env_idx + 1} shutdown complete. Exiting.")
sys.exit(0)
-
-def run_env_tasks(env_idx: int, env_tasks: dict, args: argparse.Namespace, shared_scores: list):
- """Run tasks for a single environment."""
- # Each process has its own list of active environments
+def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: list):
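+    """Worker process: create one DesktopEnv and keep pulling (domain, example_id) tasks from the shared queue until it is empty."""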
active_environments = []
env = None
-
- # Setup signal handlers for this process too
- signal.signal(signal.SIGINT, lambda signum, frame: process_signal_handler(signum, frame, env_idx))
- signal.signal(signal.SIGTERM, lambda signum, frame: process_signal_handler(signum, frame, env_idx))
-
- from desktop_env.providers.aws.manager import IMAGE_ID_MAP
- REGION = args.region
- screen_size = (args.screen_width, args.screen_height)
- ami_id = IMAGE_ID_MAP[REGION].get(screen_size, IMAGE_ID_MAP[REGION][(1920, 1080)])
- env = DesktopEnv(
- path_to_vm=args.path_to_vm,
- action_space=args.action_space,
- provider_name=args.provider_name,
- region=REGION,
- snapshot_name=ami_id,
- screen_size=screen_size,
- headless=args.headless,
- os_type="Ubuntu",
- require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
- enable_proxy=True,
- client_password=args.client_password
- )
- active_environments.append(env)
-
- logger.info(f"Executing tasks in environment {env_idx + 1}/{args.num_envs}")
-
try:
- for domain in tqdm(env_tasks, desc=f"Env{env_idx+1}-Domain"):
- for example_id in tqdm(env_tasks[domain], desc="Example", leave=False):
+ from desktop_env.providers.aws.manager import IMAGE_ID_MAP
+ REGION = args.region
+ screen_size = (args.screen_width, args.screen_height)
+ ami_id = IMAGE_ID_MAP[REGION].get(screen_size, IMAGE_ID_MAP[REGION][(1920, 1080)])
+ env = DesktopEnv(
+ path_to_vm=args.path_to_vm,
+ action_space=args.action_space,
+ provider_name=args.provider_name,
+ region=REGION,
+ snapshot_name=ami_id,
+ screen_size=screen_size,
+ headless=args.headless,
+ os_type="Ubuntu",
+ require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
+ enable_proxy=True,
+ client_password=args.client_password
+ )
+ active_environments.append(env)
+
+ logger.info(f"Process {current_process().name} started.")
+ while True:
+ try:
+ item = task_queue.get(timeout=5)
+            except Exception:  # queue is empty (get timed out): no tasks left for this worker
+ break
+            domain, example_id = item
+            try:
config_file = os.path.join(
args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
)
with open(config_file, "r", encoding="utf-8") as f:
example = json.load(f)
-
- logger.info(f"[Env {env_idx+1}][Domain]: {domain}")
- logger.info(f"[Env {env_idx+1}][Example ID]: {example_id}")
- logger.info(f"[Env {env_idx+1}][Instruction]: {example['instruction']}")
-
+ logger.info(f"[{current_process().name}][Domain]: {domain}")
+ logger.info(f"[{current_process().name}][Example ID]: {example_id}")
+ logger.info(f"[{current_process().name}][Instruction]: {example['instruction']}")
example_result_dir = os.path.join(
args.result_dir,
args.action_space,
@@ -236,7 +211,6 @@ def run_env_tasks(env_idx: int, env_tasks: dict, args: argparse.Namespace, share
example_id,
)
os.makedirs(example_result_dir, exist_ok=True)
-
agent = OpenCUAAgent(
env=env,
model=args.model,
@@ -251,7 +225,6 @@ def run_env_tasks(env_idx: int, env_tasks: dict, args: argparse.Namespace, share
coordinate_type=args.coordinate_type,
max_image_history_length=args.max_image_history_length,
)
-
try:
lib_run_single.run_single_example_opencua(
agent,
@@ -265,7 +238,7 @@ def run_env_tasks(env_idx: int, env_tasks: dict, args: argparse.Namespace, share
)
except Exception as e:
import traceback
- logger.error(f"Exception in Env{env_idx+1} {domain}/{example_id}: {e}")
+ logger.error(f"Exception in {current_process().name} {domain}/{example_id}: {e}")
logger.error(traceback.format_exc())
try:
env.controller.end_recording(
@@ -273,7 +246,6 @@ def run_env_tasks(env_idx: int, env_tasks: dict, args: argparse.Namespace, share
)
except Exception as rec_e:
logger.error(f"Failed to end recording: {rec_e}")
-
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
f.write(
json.dumps(
@@ -281,15 +253,23 @@ def run_env_tasks(env_idx: int, env_tasks: dict, args: argparse.Namespace, share
)
)
f.write("\n")
+ except Exception as e:
+ logger.error(f"Task-level error in {current_process().name}: {e}")
+ import traceback
+ logger.error(traceback.format_exc())
+ except Exception as e:
+ logger.error(f"Process-level error in {current_process().name}: {e}")
+ import traceback
+ logger.error(traceback.format_exc())
finally:
- # This ensures the environment is closed even if there's an exception
- logger.info(f"Process {env_idx + 1} cleaning up environment...")
+ logger.info(f"{current_process().name} cleaning up environment...")
try:
- env.close()
- logger.info(f"Process {env_idx + 1} environment closed successfully")
+ if env:
+ env.close()
+ logger.info(f"{current_process().name} environment closed successfully")
except Exception as e:
- logger.error(f"Process {env_idx + 1} error during environment cleanup: {e}")
-
+ logger.error(f"{current_process().name} error during environment cleanup: {e}")
+
def signal_handler(signum, frame):
"""Handle termination signals (SIGINT, SIGTERM) to gracefully shutdown environments."""
@@ -328,8 +308,8 @@ def signal_handler(signum, frame):
if p.is_alive():
try:
logger.info(f"Forcefully terminating process {p.name}...")
- import signal
- os.kill(p.pid, signal.SIGKILL)
+ import signal as sig
+ os.kill(p.pid, sig.SIGKILL)
except Exception as e:
logger.error(f"Error forcefully terminating process: {e}")
@@ -340,38 +320,56 @@ def signal_handler(signum, frame):
def test(args: argparse.Namespace, test_all_meta: dict) -> None:
global processes
logger.info("Args: %s", args)
-
- distributed_tasks = distribute_tasks(test_all_meta, args.num_envs)
-
- logger.info("All environments are ready. Starting parallel task execution...")
-
- # Create a shared list for scores across processes
+ all_tasks = distribute_tasks(test_all_meta)
+ logger.info(f"Total tasks: {len(all_tasks)}")
with Manager() as manager:
shared_scores = manager.list()
-
- # Create and start processes for each environment
+ task_queue = manager.Queue()
+ for item in all_tasks:
+ task_queue.put(item)
+ num_envs = args.num_envs
processes = []
- for env_idx, env_tasks in enumerate(distributed_tasks):
+ for i in range(num_envs):
p = Process(
target=run_env_tasks,
- args=(env_idx, env_tasks, args, shared_scores)
+ args=(task_queue, args, shared_scores),
+ name=f"EnvProcess-{i+1}"
)
- processes.append(p)
+ p.daemon = True
p.start()
+ processes.append(p)
logger.info(f"Started process {p.name} with PID {p.pid}")
-
try:
- # Wait for all processes to complete
+ while True:
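+                # supervise the workers: restart any that exit early and stop once the task queue is drained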
+ alive_count = 0
+ for idx, p in enumerate(processes):
+ if not p.is_alive():
+ logger.warning(f"Process {p.name} died, restarting...")
+ new_p = Process(
+ target=run_env_tasks,
+ args=(task_queue, args, shared_scores),
+ name=f"EnvProcess-Restart-{idx+1}"
+ )
+ new_p.daemon = True
+ new_p.start()
+ processes[idx] = new_p
+ logger.info(f"Restarted process {new_p.name} with PID {new_p.pid}")
+ else:
+ alive_count += 1
+ if task_queue.empty():
+ logger.info("All tasks finished.")
+ break
+ if alive_count == 0:
+ logger.error("All processes died, exiting.")
+ break
+ time.sleep(5)
for p in processes:
p.join()
- logger.info(f"Process {p.name} completed")
except KeyboardInterrupt:
logger.info("Main process received KeyboardInterrupt. Initiating graceful shutdown...")
- # Let the signal handler do the cleanup
raise
except Exception as e:
logger.error(f"Unexpected error while waiting for processes: {e}", exc_info=True)
- # Ensure cleanup happens
for p in processes:
if p.is_alive():
try:
@@ -380,10 +378,7 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
except Exception as term_e:
logger.error(f"Error terminating process {p.name}: {term_e}")
raise
-
- # Convert shared list to regular list
scores = list(shared_scores)
-
logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")
@@ -469,6 +464,18 @@ if __name__ == "__main__":
try:
args = config()
+
+ # save args to json in result_dir/action_space/observation_type/model/args.json
+ path_to_args = os.path.join(
+ args.result_dir,
+ args.action_space,
+ args.observation_type,
+ args.model,
+ "args.json",
+ )
+ os.makedirs(os.path.dirname(path_to_args), exist_ok=True)
+ with open(path_to_args, "w", encoding="utf-8") as f:
+ json.dump(vars(args), f, indent=4)
with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
test_all_meta = json.load(f)
diff --git a/run_multienv_uitars15.py b/run_multienv_uitars15.py
index 0884dc2..c041b3b 100644
--- a/run_multienv_uitars15.py
+++ b/run_multienv_uitars15.py
@@ -11,10 +11,29 @@ from typing import List, Dict
import math
from tqdm import tqdm
from multiprocessing import Process, Manager
+from multiprocessing import current_process, Queue
import lib_run_single
from desktop_env.desktop_env import DesktopEnv
from mm_agents.uitars15_agent import UITarsAgent
+import shutil
+import os
+
+# def clear_cache():
+# cache_path = "cache"
+
+# try:
+# if os.path.exists(cache_path):
+# logger.info(f"Deleting cache directory: {cache_path}")
+# shutil.rmtree(cache_path)
+# logger.info(f"Cache directory deleted successfully")
+# else:
+# logger.info(f"Cache directory {cache_path} does not exist")
+# except Exception as e:
+# logger.error(f"Error deleting cache directory: {e}")
+
+# clear_cache()
+
# Global variables for signal handling
active_environments = []
processes = []
@@ -45,7 +64,7 @@ def config() -> argparse.Namespace:
default="screenshot",
help="Observation type",
)
- parser.add_argument("--sleep_after_execution", type=float, default=0)
+ parser.add_argument("--sleep_after_execution", type=float, default=3.0)
parser.add_argument("--max_steps", type=int, default=15)
# evaluation config
@@ -58,6 +77,7 @@ def config() -> argparse.Namespace:
parser.add_argument("--temperature", type=float, default=0)
parser.add_argument("--top_p", type=float, default=None)
parser.add_argument("--max_tokens", type=int, default=3000)
+ parser.add_argument("--use_thinking", action="store_true", default=False)
# OpenCUAagent config
parser.add_argument("--max_trajectory_length", type=int, default=None, help="The max number of trajectory steps.")
@@ -131,32 +151,12 @@ logger.addHandler(stdout_handler)
logger = logging.getLogger("desktopenv.experiment")
-def distribute_tasks(test_all_meta: dict, num_envs: int) -> List[Dict]:
- """Distribute tasks evenly across environments."""
- # Flatten the tasks into a single list
+def distribute_tasks(test_all_meta: dict) -> List[tuple]:
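+    """Flatten test_all_meta into a single list of (domain, example_id) task tuples."""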
all_tasks = []
for domain, examples in test_all_meta.items():
for example_id in examples:
all_tasks.append((domain, example_id))
-
- # Calculate tasks per environment
- tasks_per_env = math.ceil(len(all_tasks) / num_envs)
-
- # Distribute tasks
- distributed_tasks = []
- for i in range(num_envs):
- env_tasks = {}
- start_idx = i * tasks_per_env
- end_idx = min((i + 1) * tasks_per_env, len(all_tasks))
-
- for domain, example_id in all_tasks[start_idx:end_idx]:
- if domain not in env_tasks:
- env_tasks[domain] = []
- env_tasks[domain].append(example_id)
-
- distributed_tasks.append(env_tasks)
-
- return distributed_tasks
+ return all_tasks
def process_signal_handler(signum, frame, env_idx):
@@ -180,62 +180,55 @@ def process_signal_handler(signum, frame, env_idx):
logger.info(f"Process {env_idx + 1} shutdown complete. Exiting.")
sys.exit(0)
-
-def run_env_tasks(env_idx: int, env_tasks: dict, args: argparse.Namespace, shared_scores: list):
- """Run tasks for a single environment."""
- # Each process has its own list of active environments
+def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: list):
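+    """Worker process: create one DesktopEnv and UITarsAgent, then pull (domain, example_id) tasks from the shared queue until it is empty."""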
active_environments = []
env = None
-
- # Setup signal handlers for this process too
- signal.signal(signal.SIGINT, lambda signum, frame: process_signal_handler(signum, frame, env_idx))
- signal.signal(signal.SIGTERM, lambda signum, frame: process_signal_handler(signum, frame, env_idx))
-
- from desktop_env.providers.aws.manager import IMAGE_ID_MAP
- REGION = args.region
- screen_size = (args.screen_width, args.screen_height)
- ami_id = IMAGE_ID_MAP[REGION].get(screen_size, IMAGE_ID_MAP[REGION][(1920, 1080)])
- env = DesktopEnv(
- path_to_vm=args.path_to_vm,
- action_space=args.action_space,
- provider_name=args.provider_name,
- region=REGION,
- snapshot_name=ami_id,
- screen_size=screen_size,
- headless=args.headless,
- os_type="Ubuntu",
- require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
- enable_proxy=True,
- client_password=args.client_password
- )
- active_environments.append(env)
- agent = UITarsAgent(
- model=args.model,
- max_tokens=args.max_tokens,
- top_p=args.top_p,
- temperature=args.temperature,
-
- max_trajectory_length=args.max_trajectory_length,
- max_image_history_length=args.max_image_history_length,
- use_thinking=True,
- language=args.language,
- )
-
- logger.info(f"Executing tasks in environment {env_idx + 1}/{args.num_envs}")
-
try:
- for domain in tqdm(env_tasks, desc=f"Env{env_idx+1}-Domain"):
- for example_id in tqdm(env_tasks[domain], desc="Example", leave=False):
+ from desktop_env.providers.aws.manager import IMAGE_ID_MAP
+ REGION = args.region
+ screen_size = (args.screen_width, args.screen_height)
+ ami_id = IMAGE_ID_MAP[REGION].get(screen_size, IMAGE_ID_MAP[REGION][(1920, 1080)])
+ env = DesktopEnv(
+ path_to_vm=args.path_to_vm,
+ action_space=args.action_space,
+ provider_name=args.provider_name,
+ region=REGION,
+ snapshot_name=ami_id,
+ screen_size=screen_size,
+ headless=args.headless,
+ os_type="Ubuntu",
+ require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
+ enable_proxy=True,
+ client_password=args.client_password
+ )
+ active_environments.append(env)
+ agent = UITarsAgent(
+ model=args.model,
+ max_tokens=args.max_tokens,
+ top_p=args.top_p,
+ temperature=args.temperature,
+ max_trajectory_length=args.max_trajectory_length,
+ max_image_history_length=args.max_image_history_length,
+ use_thinking=args.use_thinking,
+ language=args.language,
+ )
+ logger.info(f"Process {current_process().name} started.")
+ while True:
+ try:
+ item = task_queue.get(timeout=5)
+            except Exception:  # queue is empty (get timed out): no tasks left for this worker
+ break
+ domain, example_id = item
+ try:
config_file = os.path.join(
args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
)
with open(config_file, "r", encoding="utf-8") as f:
example = json.load(f)
-
- logger.info(f"[Env {env_idx+1}][Domain]: {domain}")
- logger.info(f"[Env {env_idx+1}][Example ID]: {example_id}")
- logger.info(f"[Env {env_idx+1}][Instruction]: {example['instruction']}")
-
+ logger.info(f"[{current_process().name}][Domain]: {domain}")
+ logger.info(f"[{current_process().name}][Example ID]: {example_id}")
+ logger.info(f"[{current_process().name}][Instruction]: {example['instruction']}")
example_result_dir = os.path.join(
args.result_dir,
args.action_space,
@@ -258,7 +251,7 @@ def run_env_tasks(env_idx: int, env_tasks: dict, args: argparse.Namespace, share
)
except Exception as e:
import traceback
- logger.error(f"Exception in Env{env_idx+1} {domain}/{example_id}: {e}")
+ logger.error(f"Exception in {current_process().name} {domain}/{example_id}: {e}")
logger.error(traceback.format_exc())
try:
env.controller.end_recording(
@@ -266,7 +259,6 @@ def run_env_tasks(env_idx: int, env_tasks: dict, args: argparse.Namespace, share
)
except Exception as rec_e:
logger.error(f"Failed to end recording: {rec_e}")
-
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
f.write(
json.dumps(
@@ -274,14 +266,23 @@ def run_env_tasks(env_idx: int, env_tasks: dict, args: argparse.Namespace, share
)
)
f.write("\n")
+ except Exception as e:
+ logger.error(f"Task-level error in {current_process().name}: {e}")
+ import traceback
+ logger.error(traceback.format_exc())
+ except Exception as e:
+ logger.error(f"Process-level error in {current_process().name}: {e}")
+ import traceback
+ logger.error(traceback.format_exc())
finally:
- # This ensures the environment is closed even if there's an exception
- logger.info(f"Process {env_idx + 1} cleaning up environment...")
+ logger.info(f"{current_process().name} cleaning up environment...")
try:
- env.close()
- logger.info(f"Process {env_idx + 1} environment closed successfully")
+ if env:
+ env.close()
+ logger.info(f"{current_process().name} environment closed successfully")
except Exception as e:
- logger.error(f"Process {env_idx + 1} error during environment cleanup: {e}")
+ logger.error(f"{current_process().name} error during environment cleanup: {e}")
+
def signal_handler(signum, frame):
@@ -321,8 +322,8 @@ def signal_handler(signum, frame):
if p.is_alive():
try:
logger.info(f"Forcefully terminating process {p.name}...")
- import signal
- os.kill(p.pid, signal.SIGKILL)
+ import signal as sig
+ os.kill(p.pid, sig.SIGKILL)
except Exception as e:
logger.error(f"Error forcefully terminating process: {e}")
@@ -333,38 +334,56 @@ def signal_handler(signum, frame):
def test(args: argparse.Namespace, test_all_meta: dict) -> None:
global processes
logger.info("Args: %s", args)
-
- distributed_tasks = distribute_tasks(test_all_meta, args.num_envs)
-
- logger.info("All environments are ready. Starting parallel task execution...")
-
- # Create a shared list for scores across processes
+ all_tasks = distribute_tasks(test_all_meta)
+ logger.info(f"Total tasks: {len(all_tasks)}")
with Manager() as manager:
shared_scores = manager.list()
-
- # Create and start processes for each environment
+ task_queue = manager.Queue()
+ for item in all_tasks:
+ task_queue.put(item)
+ num_envs = args.num_envs
processes = []
- for env_idx, env_tasks in enumerate(distributed_tasks):
+ for i in range(num_envs):
p = Process(
target=run_env_tasks,
- args=(env_idx, env_tasks, args, shared_scores)
+ args=(task_queue, args, shared_scores),
+ name=f"EnvProcess-{i+1}"
)
- processes.append(p)
+ p.daemon = True
p.start()
+ processes.append(p)
logger.info(f"Started process {p.name} with PID {p.pid}")
-
try:
- # Wait for all processes to complete
+ while True:
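+                # supervise the workers: restart any that exit early and stop once the task queue is drained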
+ alive_count = 0
+ for idx, p in enumerate(processes):
+ if not p.is_alive():
+ logger.warning(f"Process {p.name} died, restarting...")
+ new_p = Process(
+ target=run_env_tasks,
+ args=(task_queue, args, shared_scores),
+ name=f"EnvProcess-Restart-{idx+1}"
+ )
+ new_p.daemon = True
+ new_p.start()
+ processes[idx] = new_p
+ logger.info(f"Restarted process {new_p.name} with PID {new_p.pid}")
+ else:
+ alive_count += 1
+ if task_queue.empty():
+ logger.info("All tasks finished.")
+ break
+ if alive_count == 0:
+ logger.error("All processes died, exiting.")
+ break
+ time.sleep(5)
for p in processes:
p.join()
- logger.info(f"Process {p.name} completed")
except KeyboardInterrupt:
logger.info("Main process received KeyboardInterrupt. Initiating graceful shutdown...")
- # Let the signal handler do the cleanup
raise
except Exception as e:
logger.error(f"Unexpected error while waiting for processes: {e}", exc_info=True)
- # Ensure cleanup happens
for p in processes:
if p.is_alive():
try:
@@ -373,10 +392,7 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
except Exception as term_e:
logger.error(f"Error terminating process {p.name}: {term_e}")
raise
-
- # Convert shared list to regular list
scores = list(shared_scores)
-
logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")
@@ -462,6 +478,18 @@ if __name__ == "__main__":
try:
args = config()
+
+ # save args to json in result_dir/action_space/observation_type/model/args.json
+ path_to_args = os.path.join(
+ args.result_dir,
+ args.action_space,
+ args.observation_type,
+ args.model,
+ "args.json",
+ )
+ os.makedirs(os.path.dirname(path_to_args), exist_ok=True)
+ with open(path_to_args, "w", encoding="utf-8") as f:
+ json.dump(vars(args), f, indent=4)
with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
test_all_meta = json.load(f)