fix multienv bug (#327)

hanyullai
2025-08-30 11:10:53 +08:00
committed by GitHub
parent 3344abd641
commit 54a14cbc07
2 changed files with 440 additions and 421 deletions

View File

@@ -19,8 +19,9 @@ from requests.exceptions import SSLError
 from tqdm import tqdm

 import lib_run_single
-from desktop_env.desktop_env import DesktopEnv as DesktopEnvBase
+from desktop_env.desktop_env import MAX_RETRIES, DesktopEnv as DesktopEnvBase
 from mm_agents.autoglm import AutoGLMAgent
+from typing import Optional, Dict, Any

 # Almost deprecated since it's not multi-env, use run_multienv_*.py instead

View File

@@ -8,13 +8,9 @@ import json
 import logging
 import os
 import sys
-import signal
+import math
+import ast
 import time
-from typing import List
-from multiprocessing import Process, Manager, current_process
-import lib_run_single
-from run_autoglm import DesktopEnv
-from mm_agents.autoglm import AutoGLMAgent
 import backoff
 import httpx
@@ -22,52 +18,75 @@ from openai import APIConnectionError, APIError, OpenAI, RateLimitError
 from requests.exceptions import SSLError
 from tqdm import tqdm

-# Global variables for signal handling
-active_environments = []
-processes = []
-is_terminating = False
+import lib_run_single
+from desktop_env.desktop_env import MAX_RETRIES, DesktopEnv as DesktopEnvBase
+from mm_agents.autoglm import AutoGLMAgent
+from typing import Optional, Dict, Any
+from multiprocessing import Pool
+
+# .env
+from dotenv import load_dotenv
+
+load_dotenv()

 # Logger Configs {{{ #
+logger = logging.getLogger()
+logger.setLevel(logging.DEBUG)
+
+datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
+
+file_handler = logging.FileHandler(os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8")
+debug_handler = logging.FileHandler(os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8")
+stdout_handler = logging.StreamHandler(sys.stdout)
+sdebug_handler = logging.FileHandler(os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8")
+
+file_handler.setLevel(logging.INFO)
+debug_handler.setLevel(logging.DEBUG)
+stdout_handler.setLevel(logging.INFO)
+sdebug_handler.setLevel(logging.DEBUG)
+
+formatter = logging.Formatter(
+    fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
+)
+file_handler.setFormatter(formatter)
+debug_handler.setFormatter(formatter)
+stdout_handler.setFormatter(formatter)
+sdebug_handler.setFormatter(formatter)
+
+stdout_handler.addFilter(logging.Filter("desktopenv"))
+sdebug_handler.addFilter(logging.Filter("desktopenv"))
+
+logger.addHandler(file_handler)
+logger.addHandler(debug_handler)
+logger.addHandler(stdout_handler)
+logger.addHandler(sdebug_handler)
+# }}} Logger Configs #
+
+logger = logging.getLogger("desktopenv.experiment")


 def config() -> argparse.Namespace:
-    parser = argparse.ArgumentParser(
-        description="Run end-to-end evaluation on the benchmark"
-    )
+    parser = argparse.ArgumentParser(description="Run end-to-end evaluation on the benchmark")

     # environment config
     parser.add_argument("--path_to_vm", type=str)
-    parser.add_argument(
-        "--headless", action="store_true", default=True, help="Run in headless machine"
-    )
-    parser.add_argument(
-        "--action_space", type=str, default="autoglm_computer_use", help="Action type"
-    )
+    parser.add_argument(
+        "--provider_name",
+        type=str,
+        default="docker",
+        help="Virtualization provider (vmware, docker, aws, azure, gcp, virtualbox)",
+    )
+    parser.add_argument("--headless", action="store_true", default=True, help="Run in headless machine")
+    parser.add_argument("--action_space", type=str, default="autoglm_computer_use", help="Action type")
     parser.add_argument(
         "--observation_type",
         choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"],
         default="a11y_tree",
         help="Observation type",
     )
-    parser.add_argument(
-        "--provider_name", type=str, default="docker", choices=["aws", "virtualbox", "vmware", "docker", "azure"], help="Provider name"
-    )
-    parser.add_argument(
-        "--screen_width", type=int, default=1920, help="Screen width"
-    )
-    parser.add_argument(
-        "--screen_height", type=int, default=1080, help="Screen height"
-    )
+    parser.add_argument("--screen_width", type=int, default=1920)
+    parser.add_argument("--screen_height", type=int, default=1080)
     parser.add_argument("--sleep_after_execution", type=float, default=1.0)
     parser.add_argument("--max_steps", type=int, default=50)

     # agent config
     parser.add_argument("--max_trajectory_length", type=int, default=3)
-    parser.add_argument(
-        "--test_config_base_dir", type=str, default="evaluation_examples"
-    )
+    parser.add_argument("--test_config_base_dir", type=str, default="evaluation_examples")

     # lm config
     parser.add_argument("--model", type=str, default="autoglm-os")
@@ -78,331 +97,255 @@ def config() -> argparse.Namespace:
     # example config
     parser.add_argument("--domain", type=str, default="all")
-    parser.add_argument(
-        "--test_all_meta_path", type=str, default="evaluation_examples/test_nogdrive.json"
-    )
+    parser.add_argument("--test_all_meta_path", type=str, default="evaluation_examples/test_nogdrive.json")

     # aws config
     parser.add_argument(
         "--region", type=str, default="us-east-1", help="AWS region for the VM"
     )
-    parser.add_argument(
-        "--client_password", type=str, default="", help="Client password"
-    )
+    parser.add_argument("--client_password", type=str, default="", help="Client password")

     # logging related
     parser.add_argument("--result_dir", type=str, default="./results")
-    parser.add_argument("--num_envs", type=int, default=20, help="Number of environments to run in parallel")
-    parser.add_argument("--log_level", type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], default='INFO', help="Set the logging level")
+
+    # parallel number
+    parser.add_argument("--num_workers", type=int, default=20, help="Number of parallel workers")
     args = parser.parse_args()

     return args

-
-args = config()  # Get command line arguments first
-
-if args.client_password == "":
-    if args.provider_name == "aws":
-        args.client_password = "osworld-public-evaluation"
-    else:
-        args.client_password = "password"
-else:
-    args.client_password = args.client_password
-
-logger = logging.getLogger()
-log_level = getattr(logging, args.log_level.upper())
-logger.setLevel(log_level)
-
-datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
-
-file_handler = logging.FileHandler(
-    os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8"
-)
-debug_handler = logging.FileHandler(
-    os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8"
-)
-stdout_handler = logging.StreamHandler(sys.stdout)
-
-file_handler.setLevel(logging.INFO)
-debug_handler.setLevel(logging.DEBUG)
-stdout_handler.setLevel(log_level)
-
-formatter = logging.Formatter(
-    fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
-)
-file_handler.setFormatter(formatter)
-debug_handler.setFormatter(formatter)
-stdout_handler.setFormatter(formatter)
-
-stdout_handler.addFilter(logging.Filter("desktopenv"))
-
-logger.addHandler(file_handler)
-logger.addHandler(debug_handler)
-logger.addHandler(stdout_handler)
-# }}} Logger Configs #
-
-logger = logging.getLogger("desktopenv.experiment")
-
-
-def distribute_tasks(test_all_meta: dict) -> List[tuple]:
-    """Distribute tasks evenly across environments."""
-    # Flatten the tasks into a single list
-    all_tasks = []
-    for domain, examples in test_all_meta.items():
-        for example_id in examples:
-            all_tasks.append((domain, example_id))
-    return all_tasks
-
-
-def process_signal_handler(signum, frame, env_idx):
-    """Signal handler for child processes to gracefully shut down their environments."""
-    logger.info(f"Process {env_idx + 1} received signal {signum}. Shutting down...")
-    # Get the active_environments from the caller's frame
-    local_vars = frame.f_locals
-    active_environments = local_vars.get('active_environments', [])
-    # Close environment in the current process context
-    for env in active_environments:
-        if env is not None:
-            try:
-                logger.info(f"Process {env_idx + 1} closing environment...")
-                env.close()
-                logger.info(f"Process {env_idx + 1} environment closed successfully")
-            except Exception as e:
-                logger.error(f"Process {env_idx + 1} error closing environment: {e}")
-    logger.info(f"Process {env_idx + 1} shutdown complete. Exiting.")
-    sys.exit(0)
+
+class DesktopEnv(DesktopEnvBase):
+    def step(self, action, pause=2):
+        self._step_no += 1
+        self.action_history.append(action)
+        # Mark environment as used when step is called
+        self.is_environment_used = True
+
+        reward = 0  # todo: Define reward calculation for each example
+        done = False  # todo: Define episode termination condition for each example
+        info = {}
+        logger.info(f"Step {self._step_no} in trajectory {self._traj_no} with action: {action}")
+
+        # handle the special actions
+        if action in ['WAIT', 'FAIL', 'DONE']:
+            if action == 'WAIT':
+                time.sleep(pause)
+                exe_result = 'Wait ' + str(pause) + ' seconds'
+            elif action == 'FAIL':
+                done = True
+                info = {"fail": True}
+                exe_result = 'Finish: fail'
+            elif action == 'DONE':
+                done = True
+                info = {"done": True}
+                exe_result = 'Finish: success'
+        elif type(action) == dict:
+            if action['action_type'] == 'OPEN_APP':
+                self.setup_controller._launch_setup(action['parameters']['launch_app_command'], shell=True)
+                exe_result = 'Open ' + action['parameters']['app_name']
+            elif action['action_type'] == 'OPEN_CHROME_TAB':
+                self.setup_controller._chrome_open_tabs_setup(action['parameters']['urls_to_open'])
+                exe_result = 'Open ' + str(action['parameters']['urls_to_open']) + ' in Chrome successfully'
+        else:
+            # the set of all possible python commands inside `pyautogui`
+            result = self.controller.execute_python_command(action)
+            try:
+                if result['error']:
+                    exe_result = result['error'].strip()
+                else:
+                    exe_result = result['output'].strip()
+            except Exception as e:
+                exe_result = 'Error Action: ' + action
+                logger.error(f"Error executing action: {e}")
+
+        time.sleep(pause)
+        observation = self._get_obs()
+        observation['exe_result'] = exe_result
+        return observation, reward, done, info
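
To make the rewritten step() contract concrete, here is a minimal driver sketch; the env handle and the action list are illustrative assumptions, not part of this commit:

# Hypothetical driver loop showing the three action forms the new
# DesktopEnv.step() accepts: special strings, dict actions, and raw
# pyautogui code executed inside the VM.
agent_actions = [
    {"action_type": "OPEN_APP",
     "parameters": {"app_name": "gnome-terminal",
                    "launch_app_command": "gnome-terminal"}},
    "import pyautogui; pyautogui.hotkey('ctrl', 'l')",
    "WAIT",   # sleeps `pause` seconds inside the env
    "DONE",   # sets done=True and info={"done": True}
]
for action in agent_actions:
    obs, reward, done, info = env.step(action, pause=2)
    print(obs["exe_result"])  # textual result attached by the new step()
    if done:
        break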
-def run_env_tasks(task_queue, args, shared_scores):
-    """Run tasks for a single environment."""
-    active_environments = []
-    env = None
-    try:
-        @backoff.on_exception(
-            backoff.constant,
-            (RateLimitError, APIConnectionError),
-            interval=0.1,
-        )
-        def call_llm(messages):
-            logger.info("Calling LLM...")
-            # set api_key and base_url by environment variables
-            engine = OpenAI(timeout=60.0)
-            response = engine.chat.completions.create(
-                model=args.model,
-                messages=messages,
-                max_tokens=args.max_tokens,
-                temperature=args.temperature,
-                top_p=args.top_p,
-            )
-            logger.info("LLM called successfully.")
-            return response.choices[0].message.content
-
-        env = DesktopEnv(
-            provider_name=args.provider_name,
-            region=args.region,
-            client_password=args.client_password,
-            path_to_vm=args.path_to_vm,
-            action_space=args.action_space,
-            screen_size=(args.screen_width, args.screen_height),
-            headless=args.headless,
-            os_type="Ubuntu",
-            require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
-        )
-        active_environments.append(env)
-        agent = AutoGLMAgent(
-            action_space=args.action_space,
-            observation_type=args.observation_type,
-            max_trajectory_length=args.max_trajectory_length,
-            client_password=args.client_password,
-            gen_func=call_llm,
-        )
-        logger.info(f"Process {current_process().name} started.")
-        while True:
-            try:
-                item = task_queue.get(timeout=5)
-            except Exception:
-                break
-            domain, example_id = item
-            try:
-                config_file = os.path.join(
-                    args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
-                )
-                with open(config_file, "r", encoding="utf-8") as f:
-                    example = json.load(f)
-                logger.info(f"[{current_process().name}][Domain]: {domain}")
-                logger.info(f"[{current_process().name}][Example ID]: {example_id}")
-                logger.info(f"[{current_process().name}][Instruction]: {example['instruction']}")
-                example_result_dir = os.path.join(
-                    args.result_dir,
-                    args.action_space,
-                    args.observation_type,
-                    args.model,
-                    domain,
-                    example_id,
-                )
-                os.makedirs(example_result_dir, exist_ok=True)
-                try:
-                    lib_run_single.run_single_example_autoglm(
-                        agent,
-                        env,
-                        example,
-                        args.max_steps,
-                        example["instruction"],
-                        args,
-                        example_result_dir,
-                        shared_scores,
-                    )
-                except Exception as e:
-                    import traceback
-                    logger.error(f"Exception in {current_process().name} {domain}/{example_id}: {e}")
-                    logger.error(traceback.format_exc())
-                    try:
-                        env.controller.end_recording(
-                            os.path.join(example_result_dir, "recording.mp4")
-                        )
-                    except Exception as rec_e:
-                        logger.error(f"Failed to end recording: {rec_e}")
-                    with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
-                        f.write(
-                            json.dumps(
-                                {"Error": f"{domain}/{example_id} - {e}"}
-                            )
-                        )
-                        f.write("\n")
-            except Exception as e:
-                logger.error(f"Task-level error in {current_process().name}: {e}")
-                import traceback
-                logger.error(traceback.format_exc())
-    except Exception as e:
-        logger.error(f"Process-level error in {current_process().name}: {e}")
-        import traceback
-        logger.error(traceback.format_exc())
-    finally:
-        logger.info(f"{current_process().name} cleaning up environment...")
-        try:
-            if env:
-                env.close()
-                logger.info(f"{current_process().name} environment closed successfully")
-        except Exception as e:
-            logger.error(f"{current_process().name} error during environment cleanup: {e}")
+    def reset(self, task_config: Optional[Dict[str, Any]] = None, seed=None, options=None) -> Dict[str, Any]:
+        # Reset to certain task in OSWorld
+        logger.info("Resetting environment...")
+        logger.info("Switching task...")
+        logger.info("Setting counters...")
+        self._traj_no += 1
+        self._step_no = 0
+        self.action_history.clear()
+
+        for attempt in range(MAX_RETRIES):
+            # Only revert to snapshot if environment has been used (step/setup)
+            # This optimization is especially important for cloud providers like AWS
+            # where unnecessary snapshot operations are costly and time-consuming
+
+            if task_config is not None:
+                # Only consider task proxy requirement if proxy is enabled at system level
+                task_use_proxy = task_config.get("proxy", False) and self.enable_proxy
+                if not self.enable_proxy and task_config.get("proxy", False):
+                    logger.info("Task requires proxy but proxy is disabled at system level, ignoring proxy requirement.")
+
+                if task_use_proxy != self.current_use_proxy:
+                    # keep because get_info_from_website depends on this
+                    self.current_use_proxy = task_use_proxy
+
+            if self.is_environment_used:
+                logger.info("Environment has been used, reverting to snapshot {}...".format(self.snapshot_name))
+                self._revert_to_snapshot()
+                logger.info("Starting emulator...")
+                self._start_emulator()
+                logger.info("Emulator started.")
+                # Reset the usage flag after reverting
+                self.is_environment_used = False
+            else:
+                logger.info("Environment is clean, skipping snapshot revert (provider: {}).".format(self.provider_name))
+
+            if task_config is not None:
+                if task_config.get("proxy", False) and self.enable_proxy:
+                    # If using proxy and proxy is enabled, set up the proxy configuration
+                    self.setup_controller._proxy_setup(self.client_password)
+                self._set_task_info(task_config)
+                self.setup_controller.reset_cache_dir(self.cache_dir)
+                logger.info("Setting up environment...")
+                success = self.setup_controller.setup(self.config, task_config.get("proxy", False) and self.enable_proxy)
+                if success:
+                    # Mark environment as used when setup is successfully executed
+                    if self.config:  # Only mark as used if there were actual setup operations
+                        self.is_environment_used = True
+                    break
+                else:
+                    logger.error(
+                        "Environment setup failed, retrying (%d/%d)...",
+                        attempt + 1,
+                        MAX_RETRIES,
+                    )
+                    time.sleep(5)
+            else:
+                break
+
+        logger.info("Environment setup complete.")
-def signal_handler(signum, frame):
-    """Handle termination signals (SIGINT, SIGTERM) to gracefully shutdown environments."""
-    global is_terminating, active_environments, processes
-
-    # Avoid duplicate handling
-    if is_terminating:
-        return
-    is_terminating = True
-    logger.info(f"Received signal {signum}. Gracefully shutting down...")
-
-    # Close all registered environments in the main process
-    for env in active_environments:
-        try:
-            logger.info(f"Closing environment...")
-            env.close()
-            logger.info(f"Environment closed successfully")
-        except Exception as e:
-            logger.error(f"Error closing environment: {e}")
-
-    # Send termination signal to all child processes first
-    for p in processes:
-        if p.is_alive():
+        # Upload tools from autoglm package
+        import mm_agents.autoglm
+        tool_dir = os.path.join(os.path.dirname(mm_agents.autoglm.__file__), 'tools', 'package')
+        for file in os.listdir(tool_dir):
+            if os.path.isdir(os.path.join(tool_dir, file)):
+                continue
+            self.setup_controller._upload_file_setup([{
+                "local_path": os.path.join(tool_dir, file),
+                "path": os.path.join('~', file)
+            }])
+
+        # start soffice service for office tools
+        self.setup_controller._launch_setup('soffice --accept="socket,host=localhost,port=2002;urp;" --norestore --nologo --nodefault', shell=True)
+        time.sleep(5)
+
+        observation = self._get_obs()
+        return observation
+
+    def get_current_apps(self):
+        apps_code = r"""import subprocess;
+command = "wmctrl -xl";
+apps = subprocess.run(command, shell=True, capture_output=True, text=True).stdout.strip().split('\n');
+print(apps);"""
+        window_code = r"""import subprocess;
+command = "wmctrl -a :ACTIVE: -v 2>&1 | grep 'Using window' | awk '{print $3}'";
+window_id = subprocess.run(command, shell=True, capture_output=True, text=True).stdout.strip();
+print(window_id);"""
+        apps = self.controller.execute_python_command(apps_code)['output'].strip()
+        apps = ast.literal_eval(apps)
+
+        app_list = {}
+        for app in apps:
+            parts = app.split(maxsplit=4)
+            if len(parts) < 4:
+                continue
+            if parts[1] != '0':
+                continue
+            window_id = parts[0]
+            app_name = '.'.join(parts[2].split('.')[-(math.ceil(parts[2].count('.') / 2)):])
+            title = parts[3]
+            app_list[window_id] = {
+                'app_name': app_name,
+                'title': title
+            }
+        cur_id = self.controller.execute_python_command(window_code)['output'].strip()
+        return app_list, cur_id
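
The app_name expression above is dense; the following self-contained sketch traces it on a sample wmctrl -xl line (the sample output and the helper name are assumptions for illustration):

import math

def derive_app_name(wm_class: str) -> str:
    # Keep the trailing half of the dot-separated WM_CLASS, mirroring the
    # expression in get_current_apps(): "google-chrome.Google-chrome"
    # keeps "Google-chrome"; a four-part class would keep the last two.
    keep = math.ceil(wm_class.count('.') / 2)
    return '.'.join(wm_class.split('.')[-keep:])

# Sample `wmctrl -xl` line (hypothetical; fields: id, desktop, WM_CLASS, host, title)
line = "0x03800003  0 google-chrome.Google-chrome hostname New Tab"
parts = line.split(maxsplit=4)
assert parts[1] == '0'                    # only windows on desktop 0 are kept
app_name = derive_app_name(parts[2])      # -> "Google-chrome"
tool_key = app_name.strip().lower().replace('-', '_')  # -> "google_chrome"
print(app_name, tool_key)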
-            try:
-                logger.info(f"Sending termination signal to process {p.name}...")
-                p.terminate()
-            except Exception as e:
-                logger.error(f"Error sending termination signal to process: {e}")
-
-    # Allow a short time for processes to handle their own cleanup
-    time.sleep(1)
-
-    # Forcefully terminate any processes that didn't exit
-    for p in processes:
-        if p.is_alive():
-            try:
-                logger.info(f"Forcefully terminating process {p.name}...")
-                import signal as sig
-                os.kill(p.pid, sig.SIGKILL)
-            except Exception as e:
-                logger.error(f"Error forcefully terminating process: {e}")
-
-    logger.info("Shutdown complete. Exiting.")
-    sys.exit(0)
+    def maximize_window(self):
+        window_state = r"""import subprocess;
+command = "xprop -id $(xprop -root _NET_ACTIVE_WINDOW | awk -F' ' '{print $5}') _NET_WM_STATE"
+output = subprocess.run(command, shell=True, capture_output=True, text=True).stdout.strip();
+print(output);"""
+        for _ in range(5):
+            try:
+                self.setup_controller._launch_setup('wmctrl -r :ACTIVE: -b add,maximized_vert,maximized_horz', shell=True)
+                time.sleep(2)
+                output = self.controller.execute_python_command(window_state)['output'].strip()
+                if '_NET_WM_STATE_FOCUSED' not in output or '_NET_WM_STATE_SKIP_TASKBAR' in output or '_NET_WM_STATE_MODAL' in output or '_NET_WM_STATE_MAXIMIZED' in output:  # no window, or a popup, or a modal window, or the window is already maximized
+                    return
+            except Exception as e:
+                logger.error(f"Failed to maximize window: {e}")
+            time.sleep(1)
+
+    def _get_obs(self):
+        tool_list = {
+            "libreoffice_calc": "CalcTools",
+            "libreoffice_impress": "ImpressTools",
+            "libreoffice_writer": "WriterTools",
+            "code": "CodeTools",
+            "vlc": "VLCTools",
+            "google_chrome": "BrowserTools"
+        }
+
+        self.maximize_window()
+
+        for i in range(3):
+            try:
+                app_list, cur_id = self.get_current_apps()
+            except Exception as e:
+                if i == 2:
+                    raise e
+                logger.error(f"Failed to get current apps: {e}")
+                time.sleep(1)
+
+        if cur_id in app_list:
+            cur_app = app_list[cur_id]['app_name']
+            tool_name = cur_app.strip().lower().replace('-', '_')
+            if tool_name in tool_list:
+                class_name = tool_list[tool_name]
+                command = f"from {tool_name} import *; "
+                command += f"{class_name}.env_info(); "
+                command += f"{class_name}.print_result();"
+                app_info = self.controller.execute_python_command(command)['output'].strip()
+            else:
+                app_info = None
+        else:
+            cur_app = None
+            app_info = None
-def test(args: argparse.Namespace, test_all_meta: dict) -> None:
-    global processes
-    logger.info("Args: %s", args)
-    all_tasks = distribute_tasks(test_all_meta)
-    logger.info(f"Total tasks: {len(all_tasks)}")
-    with Manager() as manager:
-        shared_scores = manager.list()
-        task_queue = manager.Queue()
-        for item in all_tasks:
-            task_queue.put(item)
-        num_envs = args.num_envs
-        processes = []
-        for i in range(num_envs):
-            p = Process(
-                target=run_env_tasks,
-                args=(task_queue, args, shared_scores),
-                name=f"EnvProcess-{i+1}"
-            )
-            p.daemon = True
-            p.start()
-            processes.append(p)
-            logger.info(f"Started process {p.name} with PID {p.pid}")
-        try:
-            while True:
-                alive_count = 0
-                for idx, p in enumerate(processes):
-                    if not p.is_alive():
-                        logger.warning(f"Process {p.name} died, restarting...")
-                        new_p = Process(
-                            target=run_env_tasks,
-                            args=(task_queue, args, shared_scores),
-                            name=f"EnvProcess-Restart-{idx+1}"
-                        )
-                        new_p.daemon = True
-                        new_p.start()
-                        processes[idx] = new_p
-                        logger.info(f"Restarted process {new_p.name} with PID {new_p.pid}")
-                    else:
-                        alive_count += 1
-                if task_queue.empty():
-                    logger.info("All tasks finished.")
-                    break
-                if alive_count == 0:
-                    logger.error("All processes died, exiting.")
-                    break
-                time.sleep(5)
-            for p in processes:
-                p.join()
-        except KeyboardInterrupt:
-            logger.info("Main process received KeyboardInterrupt. Initiating graceful shutdown...")
-            raise
-        except Exception as e:
-            logger.error(f"Unexpected error while waiting for processes: {e}", exc_info=True)
-            for p in processes:
-                if p.is_alive():
-                    try:
-                        logger.info(f"Terminating process {p.name} due to error...")
-                        p.terminate()
-                    except Exception as term_e:
-                        logger.error(f"Error terminating process {p.name}: {term_e}")
-            raise
-        scores = list(shared_scores)
-        logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")
+        tree = self.controller.get_accessibility_tree()
+        screenshot = self.controller.get_screenshot()
+        if screenshot is None:
+            logger.error("Failed to get screenshot.")
+            screenshot = b''
+
+        return {
+            "screenshot": screenshot,
+            "accessibility_tree": tree,
+            "instruction": self.instruction,
+            "apps": app_list,
+            "cur_window_id": cur_id,
+            "cur_app": cur_app,
+            "app_info": app_info,
+        }
-def get_unfinished(
-    action_space, use_model, observation_type, result_dir, total_file_json
-):
+def get_unfinished(action_space, use_model, observation_type, result_dir, total_file_json):
     target_dir = os.path.join(result_dir, action_space, observation_type, use_model)

     if not os.path.exists(target_dir):
@@ -430,9 +373,7 @@ def get_unfinished(
     for domain, examples in finished.items():
         if domain in total_file_json:
-            total_file_json[domain] = [
-                x for x in total_file_json[domain] if x not in examples
-            ]
+            total_file_json[domain] = [x for x in total_file_json[domain] if x not in examples]

     return total_file_json
@@ -454,13 +395,7 @@ def get_result(action_space, use_model, observation_type, result_dir, total_file
         if "result.txt" in os.listdir(example_path):
             # empty all files under example_id
             try:
-                all_result.append(
-                    float(
-                        open(
-                            os.path.join(example_path, "result.txt"), "r"
-                        ).read()
-                    )
-                )
+                all_result.append(float(open(os.path.join(example_path, "result.txt"), "r").read()))
             except:
                 all_result.append(0.0)
@@ -471,93 +406,176 @@ def get_result(action_space, use_model, observation_type, result_dir, total_file
     print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%")
     return all_result

+
+def _worker_run(task):
+    import json, os, datetime, logging, httpx, backoff
+    from openai import OpenAI, RateLimitError, APIConnectionError
+    from types import SimpleNamespace
+
+    domain, example_id, args = task  # args is an argparse.Namespace
+    logger = logging.getLogger("desktopenv.experiment")
+
+    try:
+        config_file = os.path.join(args.test_config_base_dir, f"examples/{domain}/{example_id}.json")
+        with open(config_file, "r", encoding="utf-8") as f:
+            example = json.load(f)
+        instruction = example["instruction"]
+
+        @backoff.on_exception(backoff.constant, (RateLimitError, APIConnectionError), interval=0.1)
+        def call_llm(messages):
+            logger.info("Calling LLM...")
+            # set api_key and base_url by environment variables
+            engine = OpenAI(timeout=60.0)
+            response = engine.chat.completions.create(
+                model=args.model,
+                messages=messages,
+                max_tokens=args.max_tokens,
+                temperature=args.temperature,
+                top_p=args.top_p,
+            )
+            logger.info("LLM called successfully.")
+            return response.choices[0].message.content
+
+        env = DesktopEnv(
+            provider_name=args.provider_name,
+            region=args.region,
+            client_password=args.client_password,
+            path_to_vm=args.path_to_vm,
+            action_space=args.action_space,
+            screen_size=(args.screen_width, args.screen_height),
+            headless=args.headless,
+            os_type="Ubuntu",
+            require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
+        )
+        agent = AutoGLMAgent(
+            action_space=args.action_space,
+            observation_type=args.observation_type,
+            max_trajectory_length=args.max_trajectory_length,
+            client_password=args.client_password,
+            gen_func=call_llm,
+        )
+
+        example_result_dir = os.path.join(
+            args.result_dir,
+            args.action_space,
+            args.observation_type,
+            args.model,
+            domain,
+            example_id,
+        )
+        os.makedirs(example_result_dir, exist_ok=True)
+
+        local_scores = []
+        try:
+            lib_run_single.run_single_example_autoglm(
+                agent,
+                env,
+                example,
+                args.max_steps,
+                instruction,
+                args,
+                example_result_dir,
+                local_scores,
+            )
+        except Exception as e:
+            logger.error(f"[Parallel task exception] {domain}/{example_id}: {e}")
+            if hasattr(env, "controller") and env.controller is not None:
+                try:
+                    env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
+                except Exception:
+                    pass
+            with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
+                f.write(json.dumps({"Error": f"Exception in {domain}/{example_id}: {str(e)}"}) + "\n")
+        finally:
+            try:
+                env.close()
+            except Exception:
+                pass
+
+        score = None
+        result_path = os.path.join(example_result_dir, "result.txt")
+        if os.path.exists(result_path):
+            try:
+                with open(result_path, "r") as rf:
+                    score = float(rf.read().strip())
+            except Exception:
+                score = 0.0
+        else:
+            score = 0.0
+        logger.info(f"[Finish] {domain}/{example_id} score={score}")
+        return (domain, example_id, score)
+    except Exception as e:
+        logger = logging.getLogger("desktopenv.experiment")
+        logger.error(f"[Initialization failed] {domain}/{example_id}: {e}")
+        return (domain, example_id, 0.0)
+
+
+def test_parallel(args: argparse.Namespace, test_all_meta: dict):
+    from tqdm import tqdm
+
+    tasks = []
+    for domain in test_all_meta:
+        for example_id in test_all_meta[domain]:
+            tasks.append((domain, example_id, args))
+
+    if not tasks:
+        logger.info("No pending tasks")
+        return
+
+    logger.info(f"Starting parallel execution: {args.num_workers} processes, {len(tasks)} tasks total")
+
+    results = []
+    with Pool(processes=args.num_workers) as pool:
+        for res in tqdm(pool.imap_unordered(_worker_run, tasks), total=len(tasks), desc="Parallel execution"):
+            results.append(res)
+
+    scores = [s for (_, _, s) in results if s is not None]
+    if scores:
+        avg = sum(scores) / len(scores)
+        logger.info(f"Parallel execution completed. Average score: {avg}")
+    else:
+        logger.info("No scores obtained.")
 if __name__ == "__main__":
     ####### The complete version of the list of examples #######
     os.environ["TOKENIZERS_PARALLELISM"] = "false"

-    # Register signal handlers for graceful termination
-    signal.signal(signal.SIGINT, signal_handler)  # Handle Ctrl+C
-    signal.signal(signal.SIGTERM, signal_handler)  # Handle termination signal
-
-    try:
-        # args already defined globally above
-        # save args to json in result_dir/action_space/observation_type/model/args.json
-        path_to_args = os.path.join(
-            args.result_dir,
-            args.action_space,
-            args.observation_type,
-            args.model,
-            "args.json",
-        )
-        os.makedirs(os.path.dirname(path_to_args), exist_ok=True)
-        with open(path_to_args, "w", encoding="utf-8") as f:
-            json.dump(vars(args), f, indent=4)
-
-        with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
-            test_all_meta = json.load(f)
-
-        if args.domain != "all":
-            test_all_meta = {args.domain: test_all_meta[args.domain]}
-
-        test_file_list = get_unfinished(
-            args.action_space,
-            args.model,
-            args.observation_type,
-            args.result_dir,
-            test_all_meta,
-        )
-        left_info = ""
-        for domain in test_file_list:
-            left_info += f"{domain}: {len(test_file_list[domain])}\n"
-        logger.info(f"Left tasks:\n{left_info}")
-
-        get_result(
-            args.action_space,
-            args.model,
-            args.observation_type,
-            args.result_dir,
-            test_all_meta,
-        )
-        test(args, test_file_list)
-    except KeyboardInterrupt:
-        logger.info("Main process received KeyboardInterrupt.")
-        # Signal handler will take care of cleanup
-    except Exception as e:
-        logger.error(f"Unexpected error in main process: {e}", exc_info=True)
-        # Also trigger cleanup for unhandled exceptions
-        signal_handler(signal.SIGTERM, None)
-    finally:
-        # Final cleanup in case any environments or processes remain
-        logger.info("Main process final cleanup...")
-        for env in active_environments:
-            if env is not None:
-                try:
-                    logger.info(f"Closing environment in final cleanup...")
-                    env.close()
-                    logger.info(f"Environment closed successfully in final cleanup")
-                except Exception as e:
-                    logger.error(f"Error during final environment cleanup: {e}")
-        # First try gentle termination
-        for p in processes:
-            if p is not None and p.is_alive():
-                try:
-                    logger.info(f"Terminating process {p.name}...")
-                    p.terminate()
-                except Exception as e:
-                    logger.error(f"Error terminating process: {e}")
-        # Wait a moment for processes to terminate
-        time.sleep(1)
-        # Then force kill if needed
-        for p in processes:
-            if p is not None and p.is_alive():
-                try:
-                    logger.info(f"Force killing process {p.name}...")
-                    os.kill(p.pid, signal.SIGKILL)
-                    logger.info(f"Process {p.name} force killed")
-                except Exception as e:
-                    logger.error(f"Error force killing process: {e}")
+    args = config()
+
+    if args.client_password == "":
+        if args.provider_name == "aws":
+            args.client_password = "osworld-public-evaluation"
+        else:
+            args.client_password = "password"
+    else:
+        args.client_password = args.client_password
+
+    # save args to json in result_dir/action_space/observation_type/model/args.json
+    path_to_args = os.path.join(
+        args.result_dir,
+        args.action_space,
+        args.observation_type,
+        args.model,
+        "args.json",
+    )
+    os.makedirs(os.path.dirname(path_to_args), exist_ok=True)
+    with open(path_to_args, "w", encoding="utf-8") as f:
+        json.dump(vars(args), f, indent=4)
+
+    with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
+        test_all_meta = json.load(f)
+
+    if args.domain != "all":
+        test_all_meta = {args.domain: test_all_meta[args.domain]}
+
+    test_file_list = get_unfinished(
+        args.action_space,
+        args.model,
+        args.observation_type,
+        args.result_dir,
+        test_all_meta,
+    )
+    left_info = ""
+    for domain in test_file_list:
+        left_info += f"{domain}: {len(test_file_list[domain])}\n"
+    logger.info(f"Left tasks:\n{left_info}")
+
+    get_result(
+        args.action_space,
+        args.model,
+        args.observation_type,
+        args.result_dir,
+        test_all_meta,
+    )
+    test_parallel(args, test_file_list)