from __future__ import annotations

import logging
import os
import subprocess
import tempfile
import time
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import gymnasium as gym

from desktop_env.controllers.python import PythonController
from desktop_env.controllers.setup import SetupController
from desktop_env.evaluators import metrics, getters

logger = logging.getLogger("desktopenv.env")

Metric = Callable[[Any, Any], float]
Getter = Callable[[gym.Env, Dict[str, Any]], Any]


def _execute_command(command: List[str]) -> Optional[str]:
    """Run a vmrun command; returns its stdout, or None for the blocking `start` case."""
    if command[:4] == ["vmrun", "-T", "ws", "start"]:
        # `vmrun start` blocks until the VM is up, so wait on it without a timeout
        p = subprocess.Popen(command)
        p.wait()
    else:
        result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=60, text=True)
        if result.returncode != 0:
            raise Exception("\033[91m" + result.stdout + result.stderr + "\033[0m")
        return result.stdout


class DesktopEnv(gym.Env):
    """
    DesktopEnv with an OpenAI Gym interface.

    Fixme: refactor the logic when implementing the multi-process version
    """
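
    # Illustrative usage; the .vmx path and task_config values are hypothetical:
    #
    #     env = DesktopEnv("path/to/machine.vmx", action_space="pyautogui", task_config=task_config)
    #     observation = env.reset()
    #     observation, reward, done, info = env.step("pyautogui.click(300, 200)")
    #     score = env.evaluate()
    #     env.close()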

    def __init__(
            self,
            path_to_vm: str,
            action_space: str = "computer_13",
            task_config: Optional[Dict[str, Any]] = None,
            tmp_dir: str = "D:\\NJU\\HKUNLP\\Desktop-Env\\tmp",
            cache_dir: str = "D:\\NJU\\HKUNLP\\Desktop-Env\\cache",
            screen_size: Tuple[int, int] = (1920, 1080)
    ):
        """
        Args:
            path_to_vm (str): path to the .vmx file
            action_space (str): "computer_13" | "pyautogui"

            task_config (Dict[str, Any]): the integrated task configuration,
              including
              * base snapshot
              * task id (uuid)
              * instruction
              * setup config
              * evaluator config

            tmp_dir (str): temporary directory for trajectory artifacts such as
              the extracted screenshots
            cache_dir (str): cache directory for task-related artifacts such as
              the reference files used in evaluation
        """

        # Initialize environment variables
        self.path_to_vm = os.path.abspath(os.path.expandvars(os.path.expanduser(path_to_vm)))
        self.tmp_dir_base: str = tmp_dir
        self.cache_dir_base: str = cache_dir
        self.vm_screen_size = screen_size

        # task-specific state
        # todo: handle the logic of the snapshot directory
        self._set_task_info(task_config)

        # Initialize emulator and controller
        logger.info("Initializing...")
        self._config_screen_size()
        self._start_emulator()
        self.vm_ip = self._get_vm_ip()
        self.controller = PythonController(vm_ip=self.vm_ip)
        self.setup_controller = SetupController(vm_ip=self.vm_ip, cache_dir=self.cache_dir)

        # Meta info of the VM; filled in by reset()
        self.vm_platform: str = ""  # self.controller.get_vm_platform()

        # mode: human or machine
        assert action_space in ["computer_13", "pyautogui"]
        self.action_space = action_space
        # todo: define the action space and the observation space as gym does, or extend them

        # episodic state such as the tmp dir and counters; updated or reset
        # on each call to self.reset()
        self.tmp_dir: str = self.tmp_dir_base  # just an initial value, updated during reset
        self._traj_no: int = -1
        self._step_no: int = 0
        self.action_history: List[Dict[str, Any]] = []

    def _config_screen_size(self):
        def calculate_vram_size(width, height, bits_per_pixel=32):
            """
            Calculate the VRAM size for the given width, height, and color depth.
            Color depth defaults to 32 bits per pixel.
            """
            bytes_per_pixel = bits_per_pixel // 8
            vram_size = width * height * bytes_per_pixel
            return vram_size
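
        # For example, a 1920x1080 display at 32 bits per pixel needs
        # 1920 * 1080 * 4 = 8,294,400 bytes of VRAM.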

        if not os.path.isfile(self.path_to_vm):
            logger.warning(f"The specified vmx file does not exist: {self.path_to_vm}")
            return False

        width, height = self.vm_screen_size
        vram_size = calculate_vram_size(width, height)

        try:
            with open(self.path_to_vm, 'r') as file:
                lines = file.readlines()

            # Drop any existing screen-size and VRAM settings.
            new_lines = []
            for line in lines:
                if "svga.autodetect" in line \
                        or "svga.vramSize" in line \
                        or "displayWidth" in line \
                        or "displayHeight" in line:
                    continue
                new_lines.append(line)

            # Append new settings for the screen size and VRAM.
            new_lines.append('svga.autodetect = "TRUE"\n')
            new_lines.append(f'svga.vramSize = "{vram_size}"\n')
            new_lines.append(f'displayWidth = "{width}"\n')
            new_lines.append(f'displayHeight = "{height}"\n')

            with open(self.path_to_vm, 'w') as file:
                file.writelines(new_lines)
            logger.info(f"Screen size for {self.path_to_vm} set to {width}x{height} with VRAM size {vram_size} bytes")
            return True
        except IOError as e:
            logger.error(f"An IOError occurred: {e}")
            return False
        except Exception as e:
            logger.error(f"An error occurred: {e}")
            return False

    def _start_emulator(self):
        while True:
            try:
                output = subprocess.check_output("vmrun -T ws list", shell=True, stderr=subprocess.STDOUT)
                output = output.decode()
                output: List[str] = output.splitlines()
                if self.path_to_vm in output:
                    logger.info("VM is running.")
                    break
                else:
                    logger.info("Starting VM...")
                    _execute_command(["vmrun", "-T", "ws", "start", self.path_to_vm])
                    time.sleep(3)
            except subprocess.CalledProcessError as e:
                logger.error(f"Error executing command: {e.output.decode().strip()}")

    def _get_vm_ip(self):
        max_retries = 20
        logger.info("Getting IP Address...")
        for _ in range(max_retries):
            try:
                output = _execute_command(["vmrun", "-T", "ws", "getGuestIPAddress", self.path_to_vm]).strip()
                logger.info(f"IP address: {output}")
                return output
            except Exception:
                time.sleep(5)
                logger.info("Retrying...")
        raise Exception("Failed to get VM IP address!")

    def _save_state(self):
        _execute_command(["vmrun", "-T", "ws", "snapshot", self.path_to_vm, self.snapshot_path])

    def _get_screenshot(self):
        image_path: str = os.path.join(self.tmp_dir, "screenshots", "{:d}.png".format(self._step_no))

        # Fetch the screenshot from the VM and save it to image_path
        screenshot = self.controller.get_screenshot()
        with open(image_path, "wb") as f:
            f.write(screenshot)

        return image_path

    def _get_obs(self):
        screenshot_image_path = self._get_screenshot()
        return screenshot_image_path

    def _set_task_info(self, task_config: Dict[str, Any]):
        self.snapshot_path = task_config["snapshot"]
        self.task_id: str = task_config["id"]
        self.cache_dir: str = os.path.join(self.cache_dir_base, self.task_id)
        os.makedirs(self.cache_dir, exist_ok=True)
        self.instruction = task_config["instruction"]
        self.config = task_config["config"]

        # evaluator dict
        # func -> a metric function name, or a list of metric function names
        # conj -> the conjunction of multiple metrics when func is a list with length > 1: "and"/"or"
        # result -> a result getter config, or a list of result getter configs
        # expected (optional) -> an expected getter config, or a list of expected getter configs
        # options (optional) -> metric options, or a list of metric options
        # If func is a list, then result, expected (if present), and options (if present)
        # must be lists of the same length; even if one of the metrics does not need an
        # expected or options field, it must be included in the list as None.
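        #
        # An illustrative example (the function and getter names are hypothetical):
        #     {
        #         "func": ["check_file_exists", "compare_text"],
        #         "conj": "and",
        #         "result": [{"type": "vm_file", ...}, {"type": "vm_file", ...}],
        #         "expected": [None, {"type": "rule", ...}],
        #         "options": [None, {"ignore_case": True}],
        #     }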
        self.evaluator = task_config["evaluator"]
        if isinstance(self.evaluator["func"], list):
            self.metric: List[Metric] = [getattr(metrics, func) for func in self.evaluator["func"]]
        else:
            self.metric: Metric = getattr(metrics, self.evaluator["func"])
        self.metric_conj: str = self.evaluator.get("conj", "and")  # the conjunction of multiple metrics

        if isinstance(self.evaluator["result"], list):
            self.result_getter: List[Getter] = [getattr(getters, "get_{:}".format(res["type"]))
                                                for res in self.evaluator["result"]]
        else:
            self.result_getter: Getter = getattr(getters, "get_{:}".format(self.evaluator["result"]["type"]))

        if "expected" not in self.evaluator:
            self.expected_getter = [None] * len(self.metric) if isinstance(self.metric, list) else None
        elif isinstance(self.evaluator["expected"], list):
            self.expected_getter: List[Optional[Getter]] = [getattr(getters, "get_{:}".format(exp["type"])) if exp else None
                                                            for exp in self.evaluator["expected"]]
        else:
            self.expected_getter: Getter = getattr(getters, "get_{:}".format(self.evaluator["expected"]["type"]))

        self.metric_options: Union[List[Dict[str, Any]], Dict[str, Any]]
        if isinstance(self.evaluator.get("options", {}), list):
            self.metric_options = [opt if opt else {} for opt in self.evaluator["options"]]
        elif "options" in self.evaluator:
            self.metric_options = self.evaluator["options"]
        else:
            self.metric_options = [{}] * len(self.metric) if isinstance(self.metric, list) else {}

        assert (not isinstance(self.evaluator["func"], list)
                or (len(self.metric) == len(self.result_getter) == len(self.expected_getter)
                    == len(self.metric_options)))

    def reset(self, task_config: Optional[Dict[str, Any]] = None, seed=None, options=None) -> Dict[str, Any]:
        logger.info("Resetting environment...")

        logger.info("Switching task...")
        if task_config is not None:
            self._set_task_info(task_config)
            self.setup_controller.reset_cache_dir(self.cache_dir)

        logger.info("Setting counters...")
        self._traj_no += 1
        self._step_no = 0
        self.action_history.clear()

        logger.info("Setting up a new tmp dir...")
        self.tmp_dir = tempfile.mkdtemp(
            prefix="{:d}.{:}.".format(self._traj_no, self.task_id),
            dir=self.tmp_dir_base
        )
        os.makedirs(os.path.join(self.tmp_dir, "screenshots"))

        logger.info("Reverting to snapshot {}...".format(self.snapshot_path))
        _execute_command(["vmrun", "-T", "ws", "revertToSnapshot", self.path_to_vm, self.snapshot_path])
        time.sleep(5)

        self._config_screen_size()
        logger.info("Configured screen size: {}".format(self.vm_screen_size))
        logger.info("Starting emulator...")
        self._start_emulator()
        logger.info("Emulator started.")

        logger.info("Getting meta info of the VM...")
        self.vm_platform = self.controller.get_vm_platform()
        self.vm_screen_size = self.controller.get_vm_screen_size()
        logger.info("Reported screen size: {}".format(self.vm_screen_size))

        logger.info("Setting up environment...")
        self.setup_controller.setup(self.config)

        time.sleep(5)
        logger.info("Environment setup complete.")

        observation = {
            "screenshot": self._get_obs(),
            "accessibility_tree": self.controller.get_accessibility_tree(),
        }
        return observation
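
    # Illustrative actions for step() (the exact computer_13 schema is defined
    # by the action representation; these values are hypothetical):
    #     computer_13: {"action_type": "CLICK", "parameters": {"x": 300, "y": 200}}
    #     pyautogui:   "pyautogui.click(300, 200)"
    #     special:     "WAIT" | "FAIL" | "DONE"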
    def step(self, action, pause=0.5):
        self._step_no += 1
        self.action_history.append(action)

        reward = 0  # todo: define the reward calculation for each example
        done = False  # todo: define the episode termination condition for each example
        info = {}

        # handle the special actions
        if action in ['WAIT', 'FAIL', 'DONE']:
            if action == 'WAIT':
                time.sleep(pause)
            elif action == 'FAIL':
                done = True
                info = {"fail": True}
            elif action == 'DONE':
                done = True
                info = {"done": True}

        # fixme: add reminding logic here to decide whether the action is valid for the current action_space
        if self.action_space == "computer_13":
            # the set of all possible actions defined in the action representation
            self.controller.execute_action(action)
        elif self.action_space == "pyautogui":
            if action in ['WAIT', 'FAIL', 'DONE']:
                self.controller.execute_action(action)
            else:
                # the set of all possible python commands inside `pyautogui`
                self.controller.execute_python_command(action)

        observation = {
            "screenshot": self._get_obs(),
            "accessibility_tree": self.controller.get_accessibility_tree(),
            "terminal": self.controller.get_terminal_output(),
            "instruction": self.instruction
        }

        return observation, reward, done, info
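
    # Example: with func=["check_a", "check_b"] (hypothetical metric names) and
    # conj="and", evaluate() returns 1 only if both metrics pass; with
    # conj="or", it returns 1 as soon as either metric passes.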
    def evaluate(self):
        """
        Evaluate whether the task has been completed successfully.
        """

        self.setup_controller.setup(self.evaluator.get("postconfig", []))

        if isinstance(self.metric, list):
            for idx, metric in enumerate(self.metric):
                try:
                    config = self.evaluator["result"][idx]
                    result_state = self.result_getter[idx](self, config)
                except FileNotFoundError:
                    logger.error("File not found!")
                    if self.metric_conj == 'and':
                        return 0
                    continue  # under "or", a missing file only fails this metric

                expected = self.evaluator["expected"][idx] if "expected" in self.evaluator else None
                expected_state = self.expected_getter[idx](self, expected) if expected else None

                score: int = metric(result_state, expected_state, **self.metric_options[idx]) \
                    if expected_state is not None \
                    else metric(result_state, **self.metric_options[idx])

                if self.metric_conj == 'and' and not bool(score):
                    return 0
                elif self.metric_conj == 'or' and bool(score):
                    return 1
            return 1 if self.metric_conj == 'and' else 0
        else:
            try:
                result_state = self.result_getter(self, self.evaluator["result"])
            except FileNotFoundError:
                logger.error("File not found!")
                return 0

            expected_state = self.expected_getter(self, self.evaluator["expected"]) \
                if "expected" in self.evaluator else None

            score: float = self.metric(result_state, expected_state, **self.metric_options) \
                if expected_state is not None \
                else self.metric(result_state, **self.metric_options)

            return score

    def render(self, mode='rgb_array'):
        if mode == 'rgb_array':
            return self._get_obs()
        else:
            raise ValueError('Unsupported render mode: {}'.format(mode))

    def close(self):
        _execute_command(["vmrun", "stop", self.path_to_vm])