Wxy/opencua (#274)

* OpenCUA Agent codebase

* update url

* debug, modify url input

* debug opencua

* show result

* debug agent history overlap

* modify opencua agent; add comment lines

* update parallel; clean code; use sleep 3s

* ui-tars-0717
This commit is contained in:
Xinyuan Wang
2025-07-20 15:52:23 +08:00
committed by GitHub
parent bec7129fff
commit e10dd9267c
5 changed files with 320 additions and 224 deletions

View File

@@ -172,9 +172,7 @@ def run_single_example_opencua(agent, env, example, max_steps, instruction, args
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
logger.info("Step %d: %s", step_idx + 1, action)
obs, reward, done, info = env.step(action)
time.sleep(3)
obs = env._get_obs()
obs, reward, done, info = env.step(action, args.sleep_after_execution)
logger.info(f"Action {action} executed, reward: {reward}, done: {done}")
# Save screenshot and trajectory information

View File

@@ -571,11 +571,6 @@ class OpenCUAAgent:
logger.info(f"========================== {self.model} ===================================")
logger.info(f"Instruction: \n{instruction}")
image_bytes = BytesIO(obs['screenshot'])
with Image.open(image_bytes) as img:
print("Actual screen size", img.size)
print("Logical screen size", self.screen_size)
messages = []
messages.append({
"role": "system",
@@ -598,7 +593,7 @@ class OpenCUAAgent:
history_content = STEP_TEMPLATE.format(step_num=i+1) + self.HISTORY_TEMPLATE.format(
observation=self.cots[i].get('observation'),
thought=self.cots[i].get('thought'),
action=self.cots[i]['action']
action=self.cots[i].get('action')
)
messages.append({
@@ -609,7 +604,7 @@ class OpenCUAAgent:
history_content = STEP_TEMPLATE.format(step_num=i+1) + self.HISTORY_TEMPLATE.format(
observation=self.cots[i].get('observation'),
thought=self.cots[i].get('thought'),
action=self.cots[i]['action']
action=self.cots[i].get('action')
)
history_step_texts.append(history_content)
if i == len(self.actions) - self.max_image_history_length:
@@ -640,7 +635,7 @@ class OpenCUAAgent:
"temperature": self.temperature
}, self.model)
logger.info(f"Model Output: \n\n{response}")
logger.info(f"Model Output: \n{response}")
if not response:
logger.error("No response found in the response.")
return "ERROR", [], {}
@@ -666,23 +661,23 @@ class OpenCUAAgent:
self.cots.append(other_cot)
# Print message structure if needed
logger.info(f"\nInstruction: {instruction}")
messages_to_print = []
current_image = 1
for msg in messages:
msg_copy = copy.deepcopy(msg)
if isinstance(msg_copy['content'], list):
for content in msg_copy['content']:
if content['type'] == 'image_url':
content['image_url']['url'] = f'Image {current_image}'
current_image += 1
messages_to_print.append(msg_copy)
# messages_to_print = []
# current_image = 1
# for msg in messages:
# msg_copy = copy.deepcopy(msg)
# if isinstance(msg_copy['content'], list):
# for content in msg_copy['content']:
# if content['type'] == 'image_url':
# content['image_url']['url'] = f'Image {current_image}'
# current_image += 1
# messages_to_print.append(msg_copy)
messages_to_print.append({
"new_step_cot": other_cot,
"response": response
})
logger.info(json.dumps(messages_to_print, indent=2))
# messages_to_print.append({
# "new_step_cot": other_cot,
# "response": response
# })
# logger.info(json.dumps(messages_to_print, indent=2))
logger.info(f"New step cot: {other_cot}")
return response, pyautogui_actions, {}
@@ -720,4 +715,10 @@ class OpenCUAAgent:
logger.error("Retrying...")
time.sleep(5)
else:
return response.json()['choices'][0]['message']['content']
response = response.json()
finish_reason = response["choices"][0].get("finish_reason")
if finish_reason is not None and finish_reason == "stop": # for most of the time, length will not exceed max_tokens
return response['choices'][0]['message']['content']
else:
logger.error("LLM did not finish properly, retrying...")
time.sleep(5)

View File

@@ -3,7 +3,8 @@ import os
import re
import base64
import requests
from typing import Optional, Dict, List, Tuple
import logging
from typing import Optional, Dict, List, Tuple, Union
from loguru import logger
import ast
@@ -573,7 +574,34 @@ finished(content='xxx') # Use escape characters \\', \\", and \\n in content par
# Prompt template for the Doubao grounding-only mode: the model is asked to emit
# a single `click` action (no Thought section) for the given instruction.
# NOTE: fixed a stray doubled quote after `</point>'` in the action spec, which
# would have shown the model a malformed example `click(...)'')`.
GROUNDING_DOUBAO = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task. \n\n## Output Format\n\nAction: ...\n\n\n## Action Space\nclick(point='<point>x1 y1</point>')\n\n## User Instruction
{instruction}"""
# Prompt template for the non-thinking Doubao computer-use mode (selected when
# `use_thinking` is False). Unlike GROUNDING_DOUBAO it exposes the full action
# space (click/drag/hotkey/type/scroll/wait/finished) and still asks for a
# short Thought line before the Action.
# NOTE(review): the note says to use Chinese in the `Thought` part even though
# this is the "no thinking" variant — presumably intentional, confirm upstream.
COMPUTER_USE_NO_THINKING = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
## Output Format
```
Thought: ...
Action: ...
```
## Action Space
click(point='<point>x1 y1</point>')
left_double(point='<point>x1 y1</point>')
right_single(point='<point>x1 y1</point>')
drag(start_point='<point>x1 y1</point>', end_point='<point>x2 y2</point>')
hotkey(key='ctrl c') # Split keys with a space and use lowercase. Also, do not use more than 3 keys in one hotkey action.
type(content='xxx') # Use escape characters \\', \\\", and \\n in content part to ensure we can parse the content in normal python string format. If you want to submit your input, use \\n at the end of content.
scroll(point='<point>x1 y1</point>', direction='down or up or right or left') # Show more information on the `direction` side.
wait() #Sleep for 5s and take a screenshot to check for any changes.
finished(content='xxx') # Use escape characters \\', \\", and \\n in content part to ensure we can parse the content in normal python string format.
## Note
- Use Chinese in `Thought` part.
- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.
## User Instruction
{instruction}
"""
class UITarsAgent:
"""
@@ -638,9 +666,11 @@ class UITarsAgent:
self.history_images = []
self.history_responses = []
self.system_prompt = COMPUTER_USE_DOUBAO
if use_thinking:
self.system_prompt = COMPUTER_USE_DOUBAO
else:
self.system_prompt = COMPUTER_USE_NO_THINKING
self.action_parse_res_factor = 1000
self.model_type = "doubao"
self.history_n = 5
@@ -648,6 +678,9 @@ class UITarsAgent:
self.temperature = temperature
self.max_tokens = max_tokens
self.platform = "ubuntu"
self.use_thinking = use_thinking
self.inference_func = self.inference_with_thinking if use_thinking else self.inference_without_thinking
def reset(self, _logger=None):
global logger
@@ -721,7 +754,36 @@ class UITarsAgent:
"details": response.text
}
def predict(self, task_instruction: str, obs: dict) -> Tuple[str, List]:
def inference_without_thinking(self, messages):
    """Call the Doubao chat-completions API with thinking mode disabled.

    Args:
        messages: Chat messages in OpenAI/Doubao list-of-dicts format.

    Returns:
        The assistant message content (str) on HTTP 200; otherwise a dict
        with ``error`` and ``details`` keys describing the failure.
    """
    # Endpoint and credentials come from the environment; a missing variable
    # raises KeyError, which surfaces in the caller's try/except around
    # self.inference_func(messages).
    api_key = os.environ['DOUBAO_API_KEY']
    api_url = os.environ['DOUBAO_API_URL']
    headers = {
        'Authorization': f'Bearer {api_key}',
        'Content-Type': 'application/json'
    }
    data = {
        "model": self.model,
        "messages": messages,
        # Explicitly disable the model's thinking mode for this variant.
        "thinking": {"type": "disabled"},
        "max_tokens": self.max_tokens,
        "top_p": self.top_p,
        "temperature": self.temperature,
    }
    response = requests.post(api_url, headers=headers, json=data)
    if response.status_code == 200:
        return response.json()["choices"][0]["message"]["content"]
    # Bug fix: the original printed response.json() here, which raises
    # JSONDecodeError when the error body is not JSON (e.g. an HTML error
    # page) and masks the error-dict return below. Use the raw text instead.
    print(f"Request failed with status code {response.status_code}")
    print(response.text)
    return {
        "error": f"Request failed with status code {response.status_code}",
        "details": response.text
    }
def predict(self, task_instruction: str, obs: dict) -> Tuple[Union[str, Dict, None], List]:
"""Predict the next action based on the current observation."""
self.task_instruction = task_instruction
@@ -793,7 +855,7 @@ class UITarsAgent:
return prediction, ["FAIL"]
try:
logger.info(f"Messages: {self.pretty_print_messages(messages[-1])}")
prediction = self.inference_with_thinking(messages)
prediction = self.inference_func(messages)
except Exception as e:
self.logger.error(f"Error when fetching response from client, with error:\n{e}")

View File

@@ -11,6 +11,7 @@ from typing import List, Dict
import math
from tqdm import tqdm
from multiprocessing import Process, Manager
from multiprocessing import current_process
import lib_run_single
from desktop_env.desktop_env import DesktopEnv
from mm_agents.opencua_agent import OpenCUAAgent
@@ -45,7 +46,7 @@ def config() -> argparse.Namespace:
default="screenshot",
help="Observation type",
)
parser.add_argument("--sleep_after_execution", type=float, default=0.0)
parser.add_argument("--sleep_after_execution", type=float, default=3.0)
parser.add_argument("--max_steps", type=int, default=15)
# evaluation config
@@ -57,7 +58,7 @@ def config() -> argparse.Namespace:
parser.add_argument("--model", type=str, default="opencua")
parser.add_argument("--temperature", type=float, default=0)
parser.add_argument("--top_p", type=float, default=0.9)
parser.add_argument("--max_tokens", type=int, default=1500)
parser.add_argument("--max_tokens", type=int, default=8196)
parser.add_argument("--stop_token", type=str, default=None)
# OpenCUAagent config
@@ -133,32 +134,12 @@ logger.addHandler(stdout_handler)
logger = logging.getLogger("desktopenv.experiment")
def distribute_tasks(test_all_meta: dict, num_envs: int) -> List[Dict]:
"""Distribute tasks evenly across environments."""
# Flatten the tasks into a single list
def distribute_tasks(test_all_meta: dict) -> List[tuple]:
all_tasks = []
for domain, examples in test_all_meta.items():
for example_id in examples:
all_tasks.append((domain, example_id))
# Calculate tasks per environment
tasks_per_env = math.ceil(len(all_tasks) / num_envs)
# Distribute tasks
distributed_tasks = []
for i in range(num_envs):
env_tasks = {}
start_idx = i * tasks_per_env
end_idx = min((i + 1) * tasks_per_env, len(all_tasks))
for domain, example_id in all_tasks[start_idx:end_idx]:
if domain not in env_tasks:
env_tasks[domain] = []
env_tasks[domain].append(example_id)
distributed_tasks.append(env_tasks)
return distributed_tasks
return all_tasks
def process_signal_handler(signum, frame, env_idx):
@@ -182,51 +163,45 @@ def process_signal_handler(signum, frame, env_idx):
logger.info(f"Process {env_idx + 1} shutdown complete. Exiting.")
sys.exit(0)
def run_env_tasks(env_idx: int, env_tasks: dict, args: argparse.Namespace, shared_scores: list):
"""Run tasks for a single environment."""
# Each process has its own list of active environments
def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: list):
active_environments = []
env = None
# Setup signal handlers for this process too
signal.signal(signal.SIGINT, lambda signum, frame: process_signal_handler(signum, frame, env_idx))
signal.signal(signal.SIGTERM, lambda signum, frame: process_signal_handler(signum, frame, env_idx))
from desktop_env.providers.aws.manager import IMAGE_ID_MAP
REGION = args.region
screen_size = (args.screen_width, args.screen_height)
ami_id = IMAGE_ID_MAP[REGION].get(screen_size, IMAGE_ID_MAP[REGION][(1920, 1080)])
env = DesktopEnv(
path_to_vm=args.path_to_vm,
action_space=args.action_space,
provider_name=args.provider_name,
region=REGION,
snapshot_name=ami_id,
screen_size=screen_size,
headless=args.headless,
os_type="Ubuntu",
require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
enable_proxy=True,
client_password=args.client_password
)
active_environments.append(env)
logger.info(f"Executing tasks in environment {env_idx + 1}/{args.num_envs}")
try:
for domain in tqdm(env_tasks, desc=f"Env{env_idx+1}-Domain"):
for example_id in tqdm(env_tasks[domain], desc="Example", leave=False):
from desktop_env.providers.aws.manager import IMAGE_ID_MAP
REGION = args.region
screen_size = (args.screen_width, args.screen_height)
ami_id = IMAGE_ID_MAP[REGION].get(screen_size, IMAGE_ID_MAP[REGION][(1920, 1080)])
env = DesktopEnv(
path_to_vm=args.path_to_vm,
action_space=args.action_space,
provider_name=args.provider_name,
region=REGION,
snapshot_name=ami_id,
screen_size=screen_size,
headless=args.headless,
os_type="Ubuntu",
require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
enable_proxy=True,
client_password=args.client_password
)
active_environments.append(env)
logger.info(f"Process {current_process().name} started.")
while True:
try:
item = task_queue.get(timeout=5)
except Exception:
break
domain, example_id = item
try:
config_file = os.path.join(
args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
)
with open(config_file, "r", encoding="utf-8") as f:
example = json.load(f)
logger.info(f"[Env {env_idx+1}][Domain]: {domain}")
logger.info(f"[Env {env_idx+1}][Example ID]: {example_id}")
logger.info(f"[Env {env_idx+1}][Instruction]: {example['instruction']}")
logger.info(f"[{current_process().name}][Domain]: {domain}")
logger.info(f"[{current_process().name}][Example ID]: {example_id}")
logger.info(f"[{current_process().name}][Instruction]: {example['instruction']}")
example_result_dir = os.path.join(
args.result_dir,
args.action_space,
@@ -236,7 +211,6 @@ def run_env_tasks(env_idx: int, env_tasks: dict, args: argparse.Namespace, share
example_id,
)
os.makedirs(example_result_dir, exist_ok=True)
agent = OpenCUAAgent(
env=env,
model=args.model,
@@ -251,7 +225,6 @@ def run_env_tasks(env_idx: int, env_tasks: dict, args: argparse.Namespace, share
coordinate_type=args.coordinate_type,
max_image_history_length=args.max_image_history_length,
)
try:
lib_run_single.run_single_example_opencua(
agent,
@@ -265,7 +238,7 @@ def run_env_tasks(env_idx: int, env_tasks: dict, args: argparse.Namespace, share
)
except Exception as e:
import traceback
logger.error(f"Exception in Env{env_idx+1} {domain}/{example_id}: {e}")
logger.error(f"Exception in {current_process().name} {domain}/{example_id}: {e}")
logger.error(traceback.format_exc())
try:
env.controller.end_recording(
@@ -273,7 +246,6 @@ def run_env_tasks(env_idx: int, env_tasks: dict, args: argparse.Namespace, share
)
except Exception as rec_e:
logger.error(f"Failed to end recording: {rec_e}")
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
f.write(
json.dumps(
@@ -281,15 +253,23 @@ def run_env_tasks(env_idx: int, env_tasks: dict, args: argparse.Namespace, share
)
)
f.write("\n")
except Exception as e:
logger.error(f"Task-level error in {current_process().name}: {e}")
import traceback
logger.error(traceback.format_exc())
except Exception as e:
logger.error(f"Process-level error in {current_process().name}: {e}")
import traceback
logger.error(traceback.format_exc())
finally:
# This ensures the environment is closed even if there's an exception
logger.info(f"Process {env_idx + 1} cleaning up environment...")
logger.info(f"{current_process().name} cleaning up environment...")
try:
env.close()
logger.info(f"Process {env_idx + 1} environment closed successfully")
if env:
env.close()
logger.info(f"{current_process().name} environment closed successfully")
except Exception as e:
logger.error(f"Process {env_idx + 1} error during environment cleanup: {e}")
logger.error(f"{current_process().name} error during environment cleanup: {e}")
def signal_handler(signum, frame):
"""Handle termination signals (SIGINT, SIGTERM) to gracefully shutdown environments."""
@@ -328,8 +308,8 @@ def signal_handler(signum, frame):
if p.is_alive():
try:
logger.info(f"Forcefully terminating process {p.name}...")
import signal
os.kill(p.pid, signal.SIGKILL)
import signal as sig
os.kill(p.pid, sig.SIGKILL)
except Exception as e:
logger.error(f"Error forcefully terminating process: {e}")
@@ -340,38 +320,56 @@ def signal_handler(signum, frame):
def test(args: argparse.Namespace, test_all_meta: dict) -> None:
global processes
logger.info("Args: %s", args)
distributed_tasks = distribute_tasks(test_all_meta, args.num_envs)
logger.info("All environments are ready. Starting parallel task execution...")
# Create a shared list for scores across processes
all_tasks = distribute_tasks(test_all_meta)
logger.info(f"Total tasks: {len(all_tasks)}")
with Manager() as manager:
shared_scores = manager.list()
# Create and start processes for each environment
task_queue = manager.Queue()
for item in all_tasks:
task_queue.put(item)
num_envs = args.num_envs
processes = []
for env_idx, env_tasks in enumerate(distributed_tasks):
for i in range(num_envs):
p = Process(
target=run_env_tasks,
args=(env_idx, env_tasks, args, shared_scores)
args=(task_queue, args, shared_scores),
name=f"EnvProcess-{i+1}"
)
processes.append(p)
p.daemon = True
p.start()
processes.append(p)
logger.info(f"Started process {p.name} with PID {p.pid}")
try:
# Wait for all processes to complete
while True:
alive_count = 0
for idx, p in enumerate(processes):
if not p.is_alive():
logger.warning(f"Process {p.name} died, restarting...")
new_p = Process(
target=run_env_tasks,
args=(task_queue, args, shared_scores),
name=f"EnvProcess-Restart-{idx+1}"
)
new_p.daemon = True
new_p.start()
processes[idx] = new_p
logger.info(f"Restarted process {new_p.name} with PID {new_p.pid}")
else:
alive_count += 1
if task_queue.empty():
logger.info("All tasks finished.")
break
if alive_count == 0:
logger.error("All processes died, exiting.")
break
time.sleep(5)
for p in processes:
p.join()
logger.info(f"Process {p.name} completed")
except KeyboardInterrupt:
logger.info("Main process received KeyboardInterrupt. Initiating graceful shutdown...")
# Let the signal handler do the cleanup
raise
except Exception as e:
logger.error(f"Unexpected error while waiting for processes: {e}", exc_info=True)
# Ensure cleanup happens
for p in processes:
if p.is_alive():
try:
@@ -380,10 +378,7 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
except Exception as term_e:
logger.error(f"Error terminating process {p.name}: {term_e}")
raise
# Convert shared list to regular list
scores = list(shared_scores)
logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")
@@ -469,6 +464,18 @@ if __name__ == "__main__":
try:
args = config()
# save args to json in result_dir/action_space/observation_type/model/args.json
path_to_args = os.path.join(
args.result_dir,
args.action_space,
args.observation_type,
args.model,
"args.json",
)
os.makedirs(os.path.dirname(path_to_args), exist_ok=True)
with open(path_to_args, "w", encoding="utf-8") as f:
json.dump(vars(args), f, indent=4)
with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
test_all_meta = json.load(f)

View File

@@ -11,10 +11,29 @@ from typing import List, Dict
import math
from tqdm import tqdm
from multiprocessing import Process, Manager
from multiprocessing import current_process
import lib_run_single
from desktop_env.desktop_env import DesktopEnv
from mm_agents.uitars15_agent import UITarsAgent
import shutil
import os
# def clear_cache():
# cache_path = "cache"
# try:
# if os.path.exists(cache_path):
# logger.info(f"Deleting cache directory: {cache_path}")
# shutil.rmtree(cache_path)
# logger.info(f"Cache directory deleted successfully")
# else:
# logger.info(f"Cache directory {cache_path} does not exist")
# except Exception as e:
# logger.error(f"Error deleting cache directory: {e}")
# clear_cache()
# Global variables for signal handling
active_environments = []
processes = []
@@ -45,7 +64,7 @@ def config() -> argparse.Namespace:
default="screenshot",
help="Observation type",
)
parser.add_argument("--sleep_after_execution", type=float, default=0)
parser.add_argument("--sleep_after_execution", type=float, default=3.0)
parser.add_argument("--max_steps", type=int, default=15)
# evaluation config
@@ -58,6 +77,7 @@ def config() -> argparse.Namespace:
parser.add_argument("--temperature", type=float, default=0)
parser.add_argument("--top_p", type=float, default=None)
parser.add_argument("--max_tokens", type=int, default=3000)
parser.add_argument("--use_thinking", action="store_true", default=False)
# OpenCUAagent config
parser.add_argument("--max_trajectory_length", type=int, default=None, help="The max number of trajectory steps.")
@@ -131,32 +151,12 @@ logger.addHandler(stdout_handler)
logger = logging.getLogger("desktopenv.experiment")
def distribute_tasks(test_all_meta: dict, num_envs: int) -> List[Dict]:
"""Distribute tasks evenly across environments."""
# Flatten the tasks into a single list
def distribute_tasks(test_all_meta: dict) -> List[tuple]:
all_tasks = []
for domain, examples in test_all_meta.items():
for example_id in examples:
all_tasks.append((domain, example_id))
# Calculate tasks per environment
tasks_per_env = math.ceil(len(all_tasks) / num_envs)
# Distribute tasks
distributed_tasks = []
for i in range(num_envs):
env_tasks = {}
start_idx = i * tasks_per_env
end_idx = min((i + 1) * tasks_per_env, len(all_tasks))
for domain, example_id in all_tasks[start_idx:end_idx]:
if domain not in env_tasks:
env_tasks[domain] = []
env_tasks[domain].append(example_id)
distributed_tasks.append(env_tasks)
return distributed_tasks
return all_tasks
def process_signal_handler(signum, frame, env_idx):
@@ -180,62 +180,55 @@ def process_signal_handler(signum, frame, env_idx):
logger.info(f"Process {env_idx + 1} shutdown complete. Exiting.")
sys.exit(0)
def run_env_tasks(env_idx: int, env_tasks: dict, args: argparse.Namespace, shared_scores: list):
"""Run tasks for a single environment."""
# Each process has its own list of active environments
def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: list):
active_environments = []
env = None
# Setup signal handlers for this process too
signal.signal(signal.SIGINT, lambda signum, frame: process_signal_handler(signum, frame, env_idx))
signal.signal(signal.SIGTERM, lambda signum, frame: process_signal_handler(signum, frame, env_idx))
from desktop_env.providers.aws.manager import IMAGE_ID_MAP
REGION = args.region
screen_size = (args.screen_width, args.screen_height)
ami_id = IMAGE_ID_MAP[REGION].get(screen_size, IMAGE_ID_MAP[REGION][(1920, 1080)])
env = DesktopEnv(
path_to_vm=args.path_to_vm,
action_space=args.action_space,
provider_name=args.provider_name,
region=REGION,
snapshot_name=ami_id,
screen_size=screen_size,
headless=args.headless,
os_type="Ubuntu",
require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
enable_proxy=True,
client_password=args.client_password
)
active_environments.append(env)
agent = UITarsAgent(
model=args.model,
max_tokens=args.max_tokens,
top_p=args.top_p,
temperature=args.temperature,
max_trajectory_length=args.max_trajectory_length,
max_image_history_length=args.max_image_history_length,
use_thinking=True,
language=args.language,
)
logger.info(f"Executing tasks in environment {env_idx + 1}/{args.num_envs}")
try:
for domain in tqdm(env_tasks, desc=f"Env{env_idx+1}-Domain"):
for example_id in tqdm(env_tasks[domain], desc="Example", leave=False):
from desktop_env.providers.aws.manager import IMAGE_ID_MAP
REGION = args.region
screen_size = (args.screen_width, args.screen_height)
ami_id = IMAGE_ID_MAP[REGION].get(screen_size, IMAGE_ID_MAP[REGION][(1920, 1080)])
env = DesktopEnv(
path_to_vm=args.path_to_vm,
action_space=args.action_space,
provider_name=args.provider_name,
region=REGION,
snapshot_name=ami_id,
screen_size=screen_size,
headless=args.headless,
os_type="Ubuntu",
require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
enable_proxy=True,
client_password=args.client_password
)
active_environments.append(env)
agent = UITarsAgent(
model=args.model,
max_tokens=args.max_tokens,
top_p=args.top_p,
temperature=args.temperature,
max_trajectory_length=args.max_trajectory_length,
max_image_history_length=args.max_image_history_length,
use_thinking=args.use_thinking,
language=args.language,
)
logger.info(f"Process {current_process().name} started.")
while True:
try:
item = task_queue.get(timeout=5)
except Exception:
break
domain, example_id = item
try:
config_file = os.path.join(
args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
)
with open(config_file, "r", encoding="utf-8") as f:
example = json.load(f)
logger.info(f"[Env {env_idx+1}][Domain]: {domain}")
logger.info(f"[Env {env_idx+1}][Example ID]: {example_id}")
logger.info(f"[Env {env_idx+1}][Instruction]: {example['instruction']}")
logger.info(f"[{current_process().name}][Domain]: {domain}")
logger.info(f"[{current_process().name}][Example ID]: {example_id}")
logger.info(f"[{current_process().name}][Instruction]: {example['instruction']}")
example_result_dir = os.path.join(
args.result_dir,
args.action_space,
@@ -258,7 +251,7 @@ def run_env_tasks(env_idx: int, env_tasks: dict, args: argparse.Namespace, share
)
except Exception as e:
import traceback
logger.error(f"Exception in Env{env_idx+1} {domain}/{example_id}: {e}")
logger.error(f"Exception in {current_process().name} {domain}/{example_id}: {e}")
logger.error(traceback.format_exc())
try:
env.controller.end_recording(
@@ -266,7 +259,6 @@ def run_env_tasks(env_idx: int, env_tasks: dict, args: argparse.Namespace, share
)
except Exception as rec_e:
logger.error(f"Failed to end recording: {rec_e}")
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
f.write(
json.dumps(
@@ -274,14 +266,23 @@ def run_env_tasks(env_idx: int, env_tasks: dict, args: argparse.Namespace, share
)
)
f.write("\n")
except Exception as e:
logger.error(f"Task-level error in {current_process().name}: {e}")
import traceback
logger.error(traceback.format_exc())
except Exception as e:
logger.error(f"Process-level error in {current_process().name}: {e}")
import traceback
logger.error(traceback.format_exc())
finally:
# This ensures the environment is closed even if there's an exception
logger.info(f"Process {env_idx + 1} cleaning up environment...")
logger.info(f"{current_process().name} cleaning up environment...")
try:
env.close()
logger.info(f"Process {env_idx + 1} environment closed successfully")
if env:
env.close()
logger.info(f"{current_process().name} environment closed successfully")
except Exception as e:
logger.error(f"Process {env_idx + 1} error during environment cleanup: {e}")
logger.error(f"{current_process().name} error during environment cleanup: {e}")
def signal_handler(signum, frame):
@@ -321,8 +322,8 @@ def signal_handler(signum, frame):
if p.is_alive():
try:
logger.info(f"Forcefully terminating process {p.name}...")
import signal
os.kill(p.pid, signal.SIGKILL)
import signal as sig
os.kill(p.pid, sig.SIGKILL)
except Exception as e:
logger.error(f"Error forcefully terminating process: {e}")
@@ -333,38 +334,56 @@ def signal_handler(signum, frame):
def test(args: argparse.Namespace, test_all_meta: dict) -> None:
global processes
logger.info("Args: %s", args)
distributed_tasks = distribute_tasks(test_all_meta, args.num_envs)
logger.info("All environments are ready. Starting parallel task execution...")
# Create a shared list for scores across processes
all_tasks = distribute_tasks(test_all_meta)
logger.info(f"Total tasks: {len(all_tasks)}")
with Manager() as manager:
shared_scores = manager.list()
# Create and start processes for each environment
task_queue = manager.Queue()
for item in all_tasks:
task_queue.put(item)
num_envs = args.num_envs
processes = []
for env_idx, env_tasks in enumerate(distributed_tasks):
for i in range(num_envs):
p = Process(
target=run_env_tasks,
args=(env_idx, env_tasks, args, shared_scores)
args=(task_queue, args, shared_scores),
name=f"EnvProcess-{i+1}"
)
processes.append(p)
p.daemon = True
p.start()
processes.append(p)
logger.info(f"Started process {p.name} with PID {p.pid}")
try:
# Wait for all processes to complete
while True:
alive_count = 0
for idx, p in enumerate(processes):
if not p.is_alive():
logger.warning(f"Process {p.name} died, restarting...")
new_p = Process(
target=run_env_tasks,
args=(task_queue, args, shared_scores),
name=f"EnvProcess-Restart-{idx+1}"
)
new_p.daemon = True
new_p.start()
processes[idx] = new_p
logger.info(f"Restarted process {new_p.name} with PID {new_p.pid}")
else:
alive_count += 1
if task_queue.empty():
logger.info("All tasks finished.")
break
if alive_count == 0:
logger.error("All processes died, exiting.")
break
time.sleep(5)
for p in processes:
p.join()
logger.info(f"Process {p.name} completed")
except KeyboardInterrupt:
logger.info("Main process received KeyboardInterrupt. Initiating graceful shutdown...")
# Let the signal handler do the cleanup
raise
except Exception as e:
logger.error(f"Unexpected error while waiting for processes: {e}", exc_info=True)
# Ensure cleanup happens
for p in processes:
if p.is_alive():
try:
@@ -373,10 +392,7 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
except Exception as term_e:
logger.error(f"Error terminating process {p.name}: {term_e}")
raise
# Convert shared list to regular list
scores = list(shared_scores)
logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")
@@ -462,6 +478,18 @@ if __name__ == "__main__":
try:
args = config()
# save args to json in result_dir/action_space/observation_type/model/args.json
path_to_args = os.path.join(
args.result_dir,
args.action_space,
args.observation_type,
args.model,
"args.json",
)
os.makedirs(os.path.dirname(path_to_args), exist_ok=True)
with open(path_to_args, "w", encoding="utf-8") as f:
json.dump(vars(args), f, indent=4)
with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
test_all_meta = json.load(f)