* Refactor the evaluator structure in the LibreOffice Writer example JSON to support multiple expected and result files, making evaluation more flexible.
* Update the instance type to t3.large and log the VNC access URL for allocated VMs, improving remote access.
* Update the time format in the get_vm_file function to include hours, minutes, and seconds for more precise file naming with a time suffix.
* Add more delay for 936321ce-5236-426a-9a20-e0e3c5dc536f; support one more potential solution.
* Enhance SetupController with a configurable retry limit and improved error handling for file-opening requests. Introduce a new function to compare unique training records, and update logging for easier debugging. Adjust the evaluation JSON examples to support multiple expected and result files.
* Clean up debug code.
* Enhance DesktopEnv to track environment usage for optimized snapshot management: introduce an is_environment_used flag to decide whether a snapshot revert is necessary based on the provider type, and update the setup and step methods to mark environment usage accordingly. Add a new execute_with_verification method to SetupController for command execution with result verification, improving reliability. Change the AWS instance type to m5.large for better performance and update the AMI ID for compatibility. Update the file-opening logic in main.py to handle both file paths and application commands more effectively.

Co-authored-by: yuanmengqi <yuanmengqi@mail.ustc.edu.cn>

from __future__ import annotations

import logging
import os
import time
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import gymnasium as gym

from desktop_env.controllers.python import PythonController
from desktop_env.controllers.setup import SetupController
from desktop_env.evaluators import metrics, getters
from desktop_env.providers import create_vm_manager_and_provider

logger = logging.getLogger("desktopenv.env")

Metric = Callable[[Any, Any], float]
Getter = Callable[[gym.Env, Dict[str, Any]], Any]

MAX_RETRIES = 5  # Maximum retries for environment setup

class DesktopEnv(gym.Env):
    """
    DesktopEnv with an OpenAI Gym interface. It provides a desktop environment
    for setting up and evaluating desktop automation tasks.
    """

    def __init__(
            self,
            provider_name: str = "vmware",
            region: str = None,
            path_to_vm: str = None,
            snapshot_name: str = "init_state",
            action_space: str = "computer_13",
            cache_dir: str = "cache",
            screen_size: Tuple[int, int] = (1920, 1080),
            headless: bool = False,
            require_a11y_tree: bool = True,
            require_terminal: bool = False,
            os_type: str = "Ubuntu",
            enable_proxy: bool = False,
    ):
        """
        Args:
            provider_name (str): virtualization provider name, defaults to "vmware"
            region (str): the region in which to allocate machines; used by cloud providers, defaults to "us-east-1"
            path_to_vm (str): path to the .vmx file
            snapshot_name (str): snapshot name to revert to, defaults to "init_state"
            action_space (str): "computer_13" | "pyautogui"
            cache_dir (str): cache directory for task-related files such as
                reference files for evaluation
            screen_size (Tuple[int, int]): screen size of the VM
            headless (bool): whether to run the VM in headless mode
            require_a11y_tree (bool): whether to require the accessibility tree
            require_terminal (bool): whether to require terminal output
            os_type (str): operating system type, defaults to "Ubuntu"
            enable_proxy (bool): whether to enable proxy support, defaults to False
        """
        # Initialize VM manager and virtualization provider
        self.region = region
        self.provider_name = provider_name
        self.enable_proxy = enable_proxy  # Store the proxy enablement setting

        # Default ports
        self.server_port = 5000
        self.chromium_port = 9222
        self.vnc_port = 8006
        self.vlc_port = 8080

        # Initialize with the default (no-proxy) provider
        self.current_use_proxy = False
        self.manager, self.provider = create_vm_manager_and_provider(provider_name, region, use_proxy=False)

        self.os_type = os_type

        # Track whether the environment has been used (step/setup) to optimize snapshot reverts.
        # docker, aws, gcp, and azure start unused, as the emulator boots from a clean state;
        # vmware and virtualbox start used, as the emulator boots from a dirty state.
        if self.provider_name in {"docker", "aws", "gcp", "azure"}:
            self.is_environment_used = False
        elif self.provider_name in {"vmware", "virtualbox"}:
            self.is_environment_used = True
        else:
            raise ValueError(f"Invalid provider name: {self.provider_name}")

        # Initialize environment variables
        if path_to_vm:
            self.path_to_vm = os.path.abspath(os.path.expandvars(os.path.expanduser(path_to_vm))) \
                if provider_name in {"vmware", "virtualbox"} else path_to_vm
        else:
            self.path_to_vm = self.manager.get_vm_path(os_type=self.os_type, region=region)

        try:
            self.snapshot_name = snapshot_name
            self.cache_dir_base: str = cache_dir
            # todo: add the logic to get the screen size from the VM
            self.headless = headless
            self.require_a11y_tree = require_a11y_tree
            self.require_terminal = require_terminal

            # Initialize emulator and controller
            if provider_name != "docker":  # Check if this is applicable to other VM providers
                logger.info("Initializing...")
                self._start_emulator()

            # mode: human or machine
            self.instruction = None
            assert action_space in ["computer_13", "pyautogui"]
            self.action_space = action_space  # todo: refactor it to the ActType

            # Episodic state, such as counters, is updated or reset
            # when calling self.reset()
            self._traj_no: int = -1
            self._step_no: int = 0
            self.action_history: List[Dict[str, Any]] = []
        except Exception as e:
            logger.error(f"Failed to initialize DesktopEnv: {e}")
            # If initialization fails, clean up the VM
            try:
                self.close()
                self.manager.delete_vm(self.path_to_vm, self.region)
                logger.info(f"Cleaned up VM {self.path_to_vm}.")
            except Exception as cleanup_error:
                logger.error(f"Failed to clean up VM {self.path_to_vm}: {cleanup_error}")
            raise

    def _start_emulator(self):
        # Power on the virtual machine
        self.provider.start_emulator(self.path_to_vm, self.headless, self.os_type)

        # Get the IP of the virtual machine and set up the controllers
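        # Note: judging by the parsing below, get_ip_address may return a
        # colon-separated "ip:server:chromium:vnc:vlc" string (the VM IP plus
        # four forwarded ports) for cloud providers, or just the IP otherwise.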
        vm_ip_ports = self.provider.get_ip_address(self.path_to_vm).split(':')
        self.vm_ip = vm_ip_ports[0]
        if len(vm_ip_ports) > 1:
            self.server_port = int(vm_ip_ports[1])
            self.chromium_port = int(vm_ip_ports[2])
            self.vnc_port = int(vm_ip_ports[3])
            self.vlc_port = int(vm_ip_ports[4])
        self.controller = PythonController(vm_ip=self.vm_ip, server_port=self.server_port)
        self.setup_controller = SetupController(vm_ip=self.vm_ip, server_port=self.server_port,
                                                chromium_port=self.chromium_port, vlc_port=self.vlc_port,
                                                cache_dir=self.cache_dir_base)

    def _revert_to_snapshot(self):
        # Revert to a certain snapshot of the virtual machine, and refresh the VM path and IP,
        # since both can change when the provider is implemented by a cloud service
        path_to_vm = self.provider.revert_to_snapshot(self.path_to_vm, self.snapshot_name)
        if path_to_vm and not path_to_vm == self.path_to_vm:
            # path_to_vm has to be a new path
            self.manager.delete_vm(self.path_to_vm, self.region)
            self.manager.add_vm(path_to_vm, self.region)
            self.manager.occupy_vm(path_to_vm, os.getpid(), self.region)
            self.path_to_vm = path_to_vm

    def _save_state(self, snapshot_name=None):
        # Save the current virtual machine state under a certain snapshot name
        self.provider.save_state(self.path_to_vm, snapshot_name)

    def close(self):
        # Close (release) the virtual machine
        self.provider.stop_emulator(self.path_to_vm)

    def reset(self, task_config: Optional[Dict[str, Any]] = None, seed=None, options=None) -> Dict[str, Any]:
        # Reset to a certain task in OSWorld
        logger.info("Resetting environment...")
        logger.info("Switching task...")
        logger.info("Setting counters...")
        self._traj_no += 1
        self._step_no = 0
        self.action_history.clear()

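        # task_config sketch (the field names are those this class reads; the
        # values shown are illustrative assumptions):
        #   {
        #       "id": "task-uuid",
        #       "instruction": "Rename the file on the desktop ...",
        #       "config": [...],     # setup operations for SetupController
        #       "proxy": False,      # optional; honored only if enable_proxy is set
        #       "evaluator": {...},  # see _set_evaluator_info
        #   }
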
        for attempt in range(MAX_RETRIES):
            # Check and handle proxy requirement changes BEFORE starting the emulator
            if task_config is not None:
                # Only honor the task's proxy requirement if proxy is enabled at the system level
                task_use_proxy = task_config.get("proxy", False) and self.enable_proxy
                if not self.enable_proxy and task_config.get("proxy", False):
                    logger.info("Task requires proxy but proxy is disabled at the system level; ignoring the proxy requirement.")

                if task_use_proxy != self.current_use_proxy:
                    logger.info(f"Task proxy requirement changed: {self.current_use_proxy} -> {task_use_proxy}")

                    # Close the current provider if it exists
                    if hasattr(self, 'provider') and self.provider:
                        try:
                            self.provider.stop_emulator(self.path_to_vm)
                        except Exception as e:
                            logger.warning(f"Failed to stop current provider: {e}")

                    # Create a new provider with the appropriate proxy setting
                    self.current_use_proxy = task_use_proxy
                    self.manager, self.provider = create_vm_manager_and_provider(
                        self.provider_name,
                        self.region,
                        use_proxy=task_use_proxy
                    )

                    if task_use_proxy:
                        logger.info("Using proxy-enabled AWS provider.")
                    else:
                        logger.info("Using regular AWS provider.")

            # Only revert to a snapshot if the environment has been used (step/setup).
            # This optimization matters most for cloud providers like AWS,
            # where unnecessary snapshot operations are costly and time-consuming.
            if self.is_environment_used:
                logger.info("Environment has been used, reverting to snapshot {}...".format(self.snapshot_name))
                self._revert_to_snapshot()
                logger.info("Starting emulator...")
                self._start_emulator()
                logger.info("Emulator started.")
                # Reset the usage flag after reverting
                self.is_environment_used = False
            else:
                logger.info("Environment is clean, skipping snapshot revert (provider: {}).".format(self.provider_name))

            if task_config is not None:
                self._set_task_info(task_config)
                self.setup_controller.reset_cache_dir(self.cache_dir)
                logger.info("Setting up environment...")
                success = self.setup_controller.setup(self.config)
                if success:
                    # Mark the environment as used only if there were actual setup operations
                    if self.config:
                        self.is_environment_used = True
                    break
                else:
                    logger.error(
                        "Environment setup failed, retrying (%d/%d)...",
                        attempt + 1,
                        MAX_RETRIES,
                    )
                    time.sleep(5)
            else:
                break

        logger.info("Environment setup complete.")

        if task_config is not None and task_config.get("proxy", False) and self.enable_proxy:
            # If the task uses a proxy and proxy support is enabled, apply the proxy configuration
            self.setup_controller._proxy_setup()

        observation = self._get_obs()
        return observation

    def _get_obs(self):
        # We provide the screenshot, accessibility tree (optional), terminal output (optional),
        # and instruction. This can be customized and scaled.
        return {
            "screenshot": self.controller.get_screenshot(),
            "accessibility_tree": self.controller.get_accessibility_tree() if self.require_a11y_tree else None,
            "terminal": self.controller.get_terminal_output() if self.require_terminal else None,
            "instruction": self.instruction
        }

    @property
    def vm_platform(self):
        return self.controller.get_vm_platform()

    @property
    def vm_screen_size(self):
        return self.controller.get_vm_screen_size()

    def _set_task_info(self, task_config: Dict[str, Any]):
        """Set task info (proxy logic is handled in the reset method)."""
        self.task_id: str = task_config["id"]
        self.cache_dir: str = os.path.join(self.cache_dir_base, self.task_id)
        os.makedirs(self.cache_dir, exist_ok=True)
        self.instruction = task_config["instruction"]
        self.config = task_config["config"] if "config" in task_config else []

        self._set_evaluator_info(task_config)

    def _set_evaluator_info(self, task_config: Dict[str, Any]):
        """Set evaluator information from the task config."""
        # The evaluator dict:
        #   func -> metric function name, or a list of metric function names
        #   conj -> conjunction ("and"/"or") of multiple metrics when func is a list with length > 1
        #   result -> result getter config, or a list of result getter configs
        #   expected (optional) -> expected getter config, or a list of expected getter configs
        #   options (optional) -> metric options, or a list of metric options
        # If func is a list, then result, expected (if present), and options (if present)
        # must be lists of the same length; even if one of the metrics does not need an
        # expected or options field, it must still appear in the list as None.
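        #
        # Illustrative shape of a multi-metric evaluator (the metric and getter
        # names below are placeholders, not guaranteed to exist in this codebase):
        #   "evaluator": {
        #       "func": ["compare_table", "compare_table"],
        #       "conj": "and",
        #       "result": [{"type": "vm_file", ...}, {"type": "vm_file", ...}],
        #       "expected": [{"type": "cloud_file", ...}, None],
        #       "options": [{"strict": True}, None]
        #   }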
        self.evaluator = task_config["evaluator"]
        self.metric: Metric = [getattr(metrics, func) for func in self.evaluator["func"]] \
            if isinstance(self.evaluator["func"], list) \
            else getattr(metrics, self.evaluator["func"])
        self.metric_conj: str = self.evaluator.get("conj", "and")  # conjunction of multiple metrics

        if "result" in self.evaluator and len(self.evaluator["result"]) > 0:
            self.result_getter: Getter = \
                [getattr(getters, "get_{:}".format(res["type"])) for res in self.evaluator["result"]] \
                if isinstance(self.evaluator["result"], list) \
                else getattr(getters, "get_{:}".format(self.evaluator["result"]["type"]))
        else:
            self.result_getter = [None] * len(self.metric) if isinstance(self.metric, list) else None

        if "expected" in self.evaluator and len(self.evaluator["expected"]) > 0:
            self.expected_getter: Getter = \
                [getattr(getters, "get_{:}".format(exp["type"])) if exp else None for exp in self.evaluator["expected"]] \
                if isinstance(self.evaluator["expected"], list) \
                else getattr(getters, "get_{:}".format(self.evaluator["expected"]["type"]))
        else:
            self.expected_getter = [None] * len(self.metric) if isinstance(self.metric, list) else None

        if isinstance(self.evaluator.get("options", {}), list):
            self.metric_options: Union[List[Dict[str, Any]], Dict[str, Any]] = \
                [opt if opt else {} for opt in self.evaluator["options"]]
        elif "options" in self.evaluator:
            self.metric_options = self.evaluator["options"]
        elif isinstance(self.metric, list):
            self.metric_options = [{}] * len(self.metric)
        else:
            self.metric_options = {}

        assert (not isinstance(self.evaluator["func"], list)
                or (len(self.metric) == len(self.result_getter)
                    == len(self.expected_getter) == len(self.metric_options)))

    def step(self, action, pause=2):
        self._step_no += 1
        self.action_history.append(action)

        # Mark the environment as used when step is called
        self.is_environment_used = True

        reward = 0  # todo: define reward calculation for each example
        done = False  # todo: define episode termination condition for each example
        info = {}

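        # Action formats handled below (illustrative; the exact computer_13 schema
        # is defined by the controller, and the dict keys here are assumptions):
        #   computer_13: {"action_type": "CLICK", "parameters": {"x": 100, "y": 200}}
        #   pyautogui:   "import pyautogui; pyautogui.click(100, 200)"
        #   special:     "WAIT" | "FAIL" | "DONE" (as strings, or as action_type in a dict)
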
        # Handle the special actions
        if action in ['WAIT', 'FAIL', 'DONE'] or (isinstance(action, dict) and action['action_type'] in ['WAIT', 'FAIL', 'DONE']):
            if action == 'WAIT':
                time.sleep(pause)
            elif action == 'FAIL':
                done = True
                info = {"fail": True}
            elif action == 'DONE':
                done = True
                info = {"done": True}

        if self.action_space == "computer_13":
            # the set of all possible actions defined in the action representation
            self.controller.execute_action(action)
        elif self.action_space == "pyautogui":
            if action in ['WAIT', 'FAIL', 'DONE']:
                self.controller.execute_action(action)
            else:
                # the set of all possible python commands inside `pyautogui`
                self.controller.execute_python_command(action)

        time.sleep(pause)
        observation = self._get_obs()

        return observation, reward, done, info

    def evaluate(self):
        """
        Evaluate whether the task has been completed successfully.
        """
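        # Scoring sketch, as implemented below: with conj == "and", any metric equal
        # to 0 short-circuits the score to 0, otherwise the score is the mean of all
        # metric values; with conj == "or", any metric equal to 1 short-circuits to 1,
        # otherwise the score is the max. A single (non-list) metric returns its value
        # directly.
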
        postconfig = self.evaluator.get("postconfig", [])
        self.setup_controller.setup(postconfig)
        # Mark the environment as used if there were postconfig setup operations
        if postconfig:
            self.is_environment_used = True

        if self.evaluator['func'] == "infeasible":
            if len(self.action_history) > 0 and self.action_history[-1] == "FAIL":
                return 1
            else:
                return 0
        else:
            if len(self.action_history) > 0 and self.action_history[-1] == "FAIL":
                return 0

        if isinstance(self.metric, list):
            # Multiple metrics to evaluate whether the task has been completed
            results = []
            assert len(self.metric) == len(self.result_getter), \
                "The number of metrics and result getters must be the same"
            if "expected" in self.evaluator:
                assert len(self.metric) == len(self.expected_getter), \
                    "The number of metrics and expected getters must be the same"
            for idx, metric in enumerate(self.metric):
                try:
                    config = self.evaluator["result"][idx]
                    result_state = self.result_getter[idx](self, config)
                except FileNotFoundError:
                    logger.error("File not found!")
                    if self.metric_conj == 'and':
                        return 0
                    # With "or", skip this metric, since no result state is available
                    continue

                if "expected" in self.evaluator and self.expected_getter and self.evaluator["expected"]:
                    expected_state = self.expected_getter[idx](self, self.evaluator["expected"][idx])
                    score = metric(result_state, expected_state, **self.metric_options[idx])
                else:
                    score = metric(result_state, **self.metric_options[idx])

                if self.metric_conj == 'and' and float(score) == 0.0:
                    return 0
                elif self.metric_conj == 'or' and float(score) == 1.0:
                    return 1
                else:
                    results.append(score)

            return sum(results) / len(results) if self.metric_conj == 'and' else max(results)
        else:
            # A single metric to evaluate whether the task has been completed
            try:
                result_state = self.result_getter(self, self.evaluator["result"])
            except FileNotFoundError:
                logger.error("File not found!")
                return 0

            if "expected" in self.evaluator and self.expected_getter and self.evaluator["expected"]:
                expected_state = self.expected_getter(self, self.evaluator["expected"])
                score: float = self.metric(result_state, expected_state, **self.metric_options)
            else:
                score: float = self.metric(result_state, **self.metric_options)

            return score

    def render(self, mode='rgb_array'):
        if mode == 'rgb_array':
            return self.controller.get_screenshot()
        else:
            raise ValueError('Unsupported render mode: {}'.format(mode))