Files
sci-gui-agent-benchmark/desktop_env/desktop_env.py
Tianbao Xie 0cc93543a8 Environment is_used flag; OS domain fix (#219)
* Refactor evaluator structure in LibreOffice Writer example JSON to support multiple expected and result files, enhancing evaluation flexibility.

* Update instance type to t3.large and add VNC access URL logging for allocated VMs, enhancing remote access capabilities.

* Update instance type to t3.large and add VNC access URL logging for allocated VMs, enhancing remote access capabilities.

* Update time format in get_vm_file function to include hours, minutes, and seconds for more precise file naming with time suffix.

* More delay for 936321ce-5236-426a-9a20-e0e3c5dc536f; support one more potential solution.

* Enhance SetupController with configurable retry limit and improved error handling for file opening requests. Introduce new function to compare unique training records, and update logging for better debugging. Adjust JSON examples for evaluation to support multiple expected and result files.

* Clean debug code

* Enhance DesktopEnv to track environment usage for optimized snapshot management. Introduce is_environment_used flag to determine if a snapshot revert is necessary based on provider type. Update setup and step methods to mark environment usage appropriately. Add new execute_with_verification method in SetupController for command execution with result verification, improving reliability. Change AWS instance type to m5.large for better performance and update AMI ID for compatibility. Update file opening logic in main.py to handle both file paths and application commands more effectively.

---------

Co-authored-by: yuanmengqi <yuanmengqi@mail.ustc.edu.cn>
2025-06-28 00:45:53 +08:00

428 lines
20 KiB
Python

from __future__ import annotations
import logging
import os
import time
from typing import Callable, Any, Optional, Tuple
from typing import List, Dict, Union
import gymnasium as gym
from desktop_env.controllers.python import PythonController
from desktop_env.controllers.setup import SetupController
from desktop_env.evaluators import metrics, getters
from desktop_env.providers import create_vm_manager_and_provider
logger = logging.getLogger("desktopenv.env")

# A metric scores a result state, optionally against an expected state, with keyword options.
# DesktopEnv.evaluate calls it either as metric(result, **options) or
# metric(result, expected, **options), so the positional arity is not fixed.
Metric = Callable[..., float]
# A getter is called as getter(env, config) and returns the state the metric consumes.
Getter = Callable[[gym.Env, Dict[str, Any]], Any]

MAX_RETRIES = 5  # Maximum retries for environment setup
class DesktopEnv(gym.Env):
    """
    DesktopEnv with OpenAI Gym interface. It provides a desktop environment for setting and evaluating desktop automation tasks.
    """

    # Control actions handled by the environment itself; accepted either as a bare string
    # or as an action dict whose "action_type" is one of these.
    _SPECIAL_ACTIONS = ("WAIT", "FAIL", "DONE")

    def __init__(
        self,
        provider_name: str = "vmware",
        region: Optional[str] = None,
        path_to_vm: Optional[str] = None,
        snapshot_name: str = "init_state",
        action_space: str = "computer_13",
        cache_dir: str = "cache",
        screen_size: Tuple[int, int] = (1920, 1080),
        headless: bool = False,
        require_a11y_tree: bool = True,
        require_terminal: bool = False,
        os_type: str = "Ubuntu",
        enable_proxy: bool = False,
    ):
        """
        Args:
            provider_name (str): virtualization provider name, one of "docker", "aws", "gcp", "azure",
                "vmware", "virtualbox"; default to "vmware"
            region (str, optional): the region to allocate machines in (used by cloud services);
                passed through to the VM manager/provider unchanged
            path_to_vm (str, optional): path to the VM (e.g. a .vmx file); when omitted, the VM
                manager is asked for one
            snapshot_name (str): snapshot name to revert to, default to "init_state"
            action_space (str): "computer_13" | "pyautogui"
            cache_dir (str): cache directory to cache task-related stuffs like
                reference file for evaluation
            screen_size (Tuple[int, int]): requested screen size of the VM; currently only stored
                on the instance, not applied to the VM
            headless (bool): whether to run the VM in headless mode
            require_a11y_tree (bool): whether to require accessibility tree
            require_terminal (bool): whether to require terminal output
            os_type (str): operating system type, default to "Ubuntu"
            enable_proxy (bool): whether to enable proxy support, default to False

        Raises:
            ValueError: if ``provider_name`` is not one of the supported providers.
            Exception: any error raised while initializing (e.g. starting the emulator) is
                re-raised after a best-effort attempt to stop and delete the VM.
        """
        # Initialize VM manager and virtualization provider
        self.region = region
        self.provider_name = provider_name
        self.enable_proxy = enable_proxy  # system-level switch; tasks may only use a proxy if True

        # Default ports; _start_emulator() overrides them when the provider reports its own
        self.server_port = 5000
        self.chromium_port = 9222
        self.vnc_port = 8006
        self.vlc_port = 8080

        # Initialize with default (no proxy) provider; reset() switches it per task if needed
        self.current_use_proxy = False
        self.manager, self.provider = create_vm_manager_and_provider(provider_name, region, use_proxy=False)
        self.os_type = os_type

        # Track whether environment has been used (step/setup) to optimize snapshot revert
        # docker, aws, gcp, azure are always unused as the emulator starts from a clean state
        # vmware, virtualbox are always used as the emulator starts from a dirty state
        if self.provider_name in {"docker", "aws", "gcp", "azure"}:
            self.is_environment_used = False
        elif self.provider_name in {"vmware", "virtualbox"}:
            self.is_environment_used = True
        else:
            raise ValueError(f"Invalid provider name: {self.provider_name}")

        # Local providers take a filesystem path (expanded to an absolute path);
        # other providers take the identifier as given.
        if path_to_vm:
            if provider_name in {"vmware", "virtualbox"}:
                self.path_to_vm = os.path.abspath(os.path.expandvars(os.path.expanduser(path_to_vm)))
            else:
                self.path_to_vm = path_to_vm
        else:
            self.path_to_vm = self.manager.get_vm_path(os_type=self.os_type, region=region)

        try:
            self.snapshot_name = snapshot_name
            self.cache_dir_base: str = cache_dir
            # todo: add the logic to get the screen size from the VM
            self.screen_size = screen_size
            self.headless = headless
            self.require_a11y_tree = require_a11y_tree
            self.require_terminal = require_terminal

            # Initialize emulator and controller.
            # Docker is not started here; reset() starts it on first use.
            if provider_name != "docker":
                logger.info("Initializing...")
                self._start_emulator()

            self.instruction = None
            assert action_space in ["computer_13", "pyautogui"]
            self.action_space = action_space  # todo: refactor it to the ActType

            # episodic stuffs, like counters, will be updated or reset
            # when calling self.reset()
            self._traj_no: int = -1
            self._step_no: int = 0
            self.action_history: List[Union[str, Dict[str, Any]]] = []
        except Exception as e:
            logger.error(f"Failed to initialize DesktopEnv: {e}")
            # If initialization fails, we should clean up the VM
            try:
                self.close()
                self.manager.delete_vm(self.path_to_vm, self.region)
                logger.info(f"Cleaned up VM {self.path_to_vm}.")
            except Exception as cleanup_error:
                logger.error(f"Failed to clean up VM {self.path_to_vm}: {cleanup_error}")
            raise

    def _start_emulator(self):
        """Power on the VM, read its address from the provider and (re)create both controllers.

        The provider's address string is split on ':'; when it carries more than the IP, fields
        1-4 are taken as the server, chromium, VNC and VLC ports, replacing the defaults.
        """
        self.provider.start_emulator(self.path_to_vm, self.headless, self.os_type)

        address_fields = self.provider.get_ip_address(self.path_to_vm).split(':')
        self.vm_ip = address_fields[0]
        if len(address_fields) > 1:
            self.server_port = int(address_fields[1])
            self.chromium_port = int(address_fields[2])
            self.vnc_port = int(address_fields[3])
            self.vlc_port = int(address_fields[4])

        self.controller = PythonController(vm_ip=self.vm_ip, server_port=self.server_port)
        self.setup_controller = SetupController(
            vm_ip=self.vm_ip,
            server_port=self.server_port,
            chromium_port=self.chromium_port,
            vlc_port=self.vlc_port,
            cache_dir=self.cache_dir_base,
        )

    def _revert_to_snapshot(self):
        """Revert the VM to ``self.snapshot_name`` and re-register it if the provider returned a new path.

        Cloud providers may hand back a different VM identifier after a revert; in that case the old
        entry is deleted from the manager and the new one is added and occupied by this process.
        """
        path_to_vm = self.provider.revert_to_snapshot(self.path_to_vm, self.snapshot_name)
        if path_to_vm and path_to_vm != self.path_to_vm:
            # path_to_vm has to be a new path
            self.manager.delete_vm(self.path_to_vm, self.region)
            self.manager.add_vm(path_to_vm, self.region)
            self.manager.occupy_vm(path_to_vm, os.getpid(), self.region)
            self.path_to_vm = path_to_vm

    def _save_state(self, snapshot_name=None):
        """Save the current virtual machine state under ``snapshot_name``."""
        self.provider.save_state(self.path_to_vm, snapshot_name)

    def close(self):
        """Close (release) the virtual machine by stopping its emulator."""
        self.provider.stop_emulator(self.path_to_vm)

    def reset(self, task_config: Optional[Dict[str, Any]] = None, seed=None, options=None) -> Dict[str, Any]:
        """Reset the environment, optionally switching to a new task.

        The VM is reverted to ``self.snapshot_name`` only if it has been used since it was last
        clean. If ``task_config`` is given, the task is then loaded and its setup config executed,
        retrying up to ``MAX_RETRIES`` times. After the last failed attempt an error is logged and
        the reset continues.

        Args:
            task_config: task description (keys read: "id", "instruction", "config", "evaluator",
                "proxy"). When None, the current task info is kept and only a clean VM is ensured.
            seed: unused; accepted for Gym API compatibility.
            options: unused; accepted for Gym API compatibility.

        Returns:
            The first observation of the episode (see ``_get_obs``).
        """
        # Reset to certain task in OSWorld
        logger.info("Resetting environment...")
        logger.info("Switching task...")
        logger.info("Setting counters...")
        self._traj_no += 1
        self._step_no = 0
        self.action_history.clear()

        for attempt in range(1, MAX_RETRIES + 1):
            # Check and handle proxy requirement changes BEFORE (re)starting the emulator
            if task_config is not None:
                self._apply_task_proxy_requirement(task_config)

            self._restore_clean_state()

            if task_config is None:
                break

            self._set_task_info(task_config)
            self.setup_controller.reset_cache_dir(self.cache_dir)
            logger.info("Setting up environment...")
            success = self.setup_controller.setup(self.config)
            if self.config:
                # Setup operations ran (possibly only partially, if they failed), so the VM is no
                # longer clean: a retry or the next reset() must revert to the snapshot first.
                self.is_environment_used = True
            if success:
                break
            if attempt < MAX_RETRIES:
                logger.error(
                    "Environment setup failed, retrying (%d/%d)...",
                    attempt,
                    MAX_RETRIES,
                )
                time.sleep(5)
            else:
                logger.error(
                    "Environment setup failed after %d attempts; continuing with an incomplete setup.",
                    MAX_RETRIES,
                )

        logger.info("Environment setup complete.")

        if task_config is not None and task_config.get("proxy", False) and self.enable_proxy:
            # If using proxy and proxy is enabled, set up the proxy configuration
            self.setup_controller._proxy_setup()

        return self._get_obs()

    def _apply_task_proxy_requirement(self, task_config: Dict[str, Any]):
        """Recreate the VM manager/provider when the task's effective proxy requirement changes."""
        requested = bool(task_config.get("proxy", False))
        # Only consider task proxy requirement if proxy is enabled at system level
        task_use_proxy = bool(requested and self.enable_proxy)
        if requested and not self.enable_proxy:
            logger.info("Task requires proxy but proxy is disabled at system level, ignoring proxy requirement.")

        if task_use_proxy == self.current_use_proxy:
            return

        logger.info(f"Task proxy requirement changed: {self.current_use_proxy} -> {task_use_proxy}")
        # Close current provider if it exists
        if getattr(self, "provider", None):
            try:
                self.provider.stop_emulator(self.path_to_vm)
            except Exception as e:
                logger.warning(f"Failed to stop current provider: {e}")

        # Create new provider with appropriate proxy setting
        self.current_use_proxy = task_use_proxy
        self.manager, self.provider = create_vm_manager_and_provider(
            self.provider_name,
            self.region,
            use_proxy=task_use_proxy,
        )
        logger.info(
            "Using %s %s provider.",
            "proxy-enabled" if task_use_proxy else "regular",
            self.provider_name,
        )
        # The emulator was stopped above, so the "clean environment" shortcut no longer holds:
        # force a snapshot revert and emulator restart before the task is set up.
        self.is_environment_used = True

    def _restore_clean_state(self):
        """Revert and restart the VM if it was used; otherwise make sure the emulator is running."""
        # Only revert to snapshot if environment has been used (step/setup)
        # This optimization is especially important for cloud providers like AWS
        # where unnecessary snapshot operations are costly and time-consuming
        if self.is_environment_used:
            logger.info("Environment has been used, reverting to snapshot {}...".format(self.snapshot_name))
            self._revert_to_snapshot()
            logger.info("Starting emulator...")
            self._start_emulator()
            logger.info("Emulator started.")
            # Reset the usage flag after reverting
            self.is_environment_used = False
            return

        logger.info("Environment is clean, skipping snapshot revert (provider: {}).".format(self.provider_name))
        if getattr(self, "controller", None) is None:
            # __init__ does not start the emulator for docker, and the clean path above skips the
            # restart, so start it here; otherwise the controllers would never be created.
            logger.info("Starting emulator...")
            self._start_emulator()
            logger.info("Emulator started.")

    def _get_obs(self):
        """Return the observation dict: screenshot, optional a11y tree/terminal output, instruction."""
        # can be customized and scaled
        return {
            "screenshot": self.controller.get_screenshot(),
            "accessibility_tree": self.controller.get_accessibility_tree() if self.require_a11y_tree else None,
            "terminal": self.controller.get_terminal_output() if self.require_terminal else None,
            "instruction": self.instruction
        }

    @property
    def vm_platform(self):
        """Platform reported by the VM controller."""
        return self.controller.get_vm_platform()

    @property
    def vm_screen_size(self):
        """Screen size reported by the VM controller."""
        return self.controller.get_vm_screen_size()

    def _set_task_info(self, task_config: Dict[str, Any]):
        """Set task info (proxy logic is handled in reset method)"""
        self.task_id: str = task_config["id"]
        self.cache_dir: str = os.path.join(self.cache_dir_base, self.task_id)
        os.makedirs(self.cache_dir, exist_ok=True)
        self.instruction = task_config["instruction"]
        self.config = task_config["config"] if "config" in task_config else []
        self._set_evaluator_info(task_config)

    def _set_evaluator_info(self, task_config: Dict[str, Any]):
        """Set evaluator information from task config.

        Evaluator dict:
            func -> metric function string, or list of metric function strings
            conj -> conjunction of multiple metrics if func is a list with length > 1, "and"/"or"
            result -> result getter config, or list of result getter configs
            expected (optional) -> expected getter config, or list of expected getter configs
            options (optional) -> metric options, or list of metric options
        If func is a str list, then result, expected (if exists), options (if exists) should also be
        lists of the same length; even if one of the metrics does not need expected or options, it
        should be included in the list with None.
        """
        self.evaluator = task_config["evaluator"]
        func = self.evaluator["func"]
        multi = isinstance(func, list)

        if multi:
            self.metric = [getattr(metrics, name) for name in func]
        else:
            self.metric = getattr(metrics, func)
        self.metric_conj: str = self.evaluator.get("conj", "and")  # take conjunction of multiple metrics

        # Result getters: one per metric (list form) or a single getter.
        if "result" in self.evaluator and len(self.evaluator["result"]) > 0:
            result_spec = self.evaluator["result"]
            if isinstance(result_spec, list):
                self.result_getter = [getattr(getters, f"get_{res['type']}") for res in result_spec]
            else:
                self.result_getter = getattr(getters, f"get_{result_spec['type']}")
        else:
            self.result_getter = [None] * len(self.metric) if multi else None

        # Expected getters: falsy entries in the list form mean "no expected state for this metric".
        if "expected" in self.evaluator and len(self.evaluator["expected"]) > 0:
            expected_spec = self.evaluator["expected"]
            if isinstance(expected_spec, list):
                self.expected_getter = [
                    getattr(getters, f"get_{exp['type']}") if exp else None for exp in expected_spec
                ]
            else:
                self.expected_getter = getattr(getters, f"get_{expected_spec['type']}")
        else:
            self.expected_getter = [None] * len(self.metric) if multi else None

        # Metric keyword options; falsy entries in the list form become empty dicts.
        options = self.evaluator.get("options", {})
        if isinstance(options, list):
            self.metric_options = [opt if opt else {} for opt in options]
        elif "options" in self.evaluator:
            self.metric_options = options
        elif multi:
            self.metric_options = [{} for _ in self.metric]
        else:
            self.metric_options = {}

        assert (not multi
                or (len(self.metric) == len(self.result_getter) == len(self.expected_getter) == len(
                    self.metric_options)))

    @classmethod
    def _special_action_type(cls, action) -> Optional[str]:
        """Return "WAIT"/"FAIL"/"DONE" if *action* is a special action (string or dict form), else None."""
        if isinstance(action, str):
            action_type = action
        elif isinstance(action, dict):
            action_type = action.get("action_type")
        else:
            return None
        return action_type if action_type in cls._SPECIAL_ACTIONS else None

    def step(self, action, pause=2):
        """Execute one action and return ``(observation, reward, done, info)``.

        ``action`` is a computer_13 action dict, a pyautogui command string, or a special action
        ("WAIT"/"FAIL"/"DONE") given as a string or as a dict with that "action_type".
        "FAIL" and "DONE" set ``done``; "WAIT" sleeps an extra ``pause`` seconds. Every step
        sleeps ``pause`` seconds before observing. The reward is always 0.
        """
        self._step_no += 1
        self.action_history.append(action)
        # Mark environment as used when step is called
        self.is_environment_used = True

        reward = 0  # todo: Define reward calculation for each example
        done = False  # todo: Define episode termination condition for each example
        info = {}

        # Handle the special actions, whether given as a bare string or as an action dict
        special = self._special_action_type(action)
        if special == "WAIT":
            time.sleep(pause)
        elif special == "FAIL":
            done = True
            info = {"fail": True}
        elif special == "DONE":
            done = True
            info = {"done": True}

        if self.action_space == "computer_13":
            # the set of all possible actions defined in the action representation
            self.controller.execute_action(action)
        elif self.action_space == "pyautogui":
            if special is not None:
                self.controller.execute_action(action)
            else:
                # the set of all possible python commands insides `pyautogui`
                self.controller.execute_python_command(action)

        time.sleep(pause)
        observation = self._get_obs()
        return observation, reward, done, info

    def evaluate(self):
        """
        Evaluate whether the task is successfully completed.

        Runs the evaluator's optional "postconfig" setup first. For the "infeasible" task type the
        score is 1 iff the last action was FAIL; for other tasks a trailing FAIL scores 0 and the
        configured metric(s) decide the score otherwise.
        """
        postconfig = self.evaluator.get("postconfig", [])
        self.setup_controller.setup(postconfig)
        # Mark environment as used if there were postconfig setup operations
        if postconfig:
            self.is_environment_used = True

        last_action_failed = bool(self.action_history) and \
            self._special_action_type(self.action_history[-1]) == "FAIL"

        if self.evaluator['func'] == "infeasible":
            return 1 if last_action_failed else 0
        if last_action_failed:
            return 0

        if isinstance(self.metric, list):
            return self._evaluate_metric_list()
        return self._evaluate_single_metric()

    def _evaluate_metric_list(self):
        """Combine several metrics: "and" -> 0 on any zero else mean, "or" -> 1 on any one else max."""
        assert len(self.metric) == len(self.result_getter), "The number of metrics and result getters must be the same"
        if "expected" in self.evaluator:
            assert len(self.metric) == len(self.expected_getter), "The number of metrics and expected getters must be the same"

        results = []
        for idx, metric_fn in enumerate(self.metric):
            options = self.metric_options[idx]
            try:
                result_state = self.result_getter[idx](self, self.evaluator["result"][idx])
            except FileNotFoundError:
                logger.error("File not found!")
                if self.metric_conj == 'and':
                    return 0
                # Without a result state this metric cannot pass; score it 0 and move on
                # instead of evaluating with an undefined/stale result_state.
                results.append(0)
                continue

            expected_getter = self.expected_getter[idx]
            if expected_getter is not None:
                expected_state = expected_getter(self, self.evaluator["expected"][idx])
                score = metric_fn(result_state, expected_state, **options)
            else:
                # None entries mean this metric takes no expected state
                score = metric_fn(result_state, **options)

            if self.metric_conj == 'and' and float(score) == 0.0:
                return 0
            if self.metric_conj == 'or' and float(score) == 1.0:
                return 1
            results.append(score)

        return sum(results) / len(results) if self.metric_conj == 'and' else max(results)

    def _evaluate_single_metric(self):
        """Evaluate the single configured metric; a missing result file scores 0."""
        try:
            result_state = self.result_getter(self, self.evaluator["result"])
        except FileNotFoundError:
            logger.error("File not found!")
            return 0

        if "expected" in self.evaluator and self.expected_getter and self.evaluator["expected"]:
            expected_state = self.expected_getter(self, self.evaluator["expected"])
            return self.metric(result_state, expected_state, **self.metric_options)
        return self.metric(result_state, **self.metric_options)

    def render(self, mode='rgb_array'):
        """Return a screenshot for ``mode='rgb_array'``; any other mode raises ValueError."""
        if mode == 'rgb_array':
            return self.controller.get_screenshot()
        else:
            raise ValueError('Unsupported render mode: {}'.format(mode))