Files
sci-gui-agent-benchmark/desktop_env/envs/desktop_env.py
David Chang eeb8a120d6 ver Jan5th
debugged
2024-01-05 15:20:47 +08:00

244 lines
9.1 KiB
Python

from __future__ import annotations
import os
import subprocess
import time
# import uuid
# import platform
from typing import List, Dict
from typing import Callable, Any, Optional
import tempfile
import gymnasium as gym
# import requests
from desktop_env.controllers.python import PythonController
from desktop_env.controllers.setup import SetupController
# from desktop_env.evaluators import eval_funcs
from desktop_env.evaluators import metrics, getters
import logging
# Module-level logger; all records from this file are emitted under "desktopenv.env".
logger = logging.getLogger("desktopenv.env")
# Type alias for an evaluation metric: a callable taking two positional values
# and returning a float score.
Metric = Callable[[Any, Any], float]
# Type alias for a state getter: a callable taking the environment and a
# config dict and returning whatever state object the metric will consume.
Getter = Callable[[gym.Env, Dict[str, Any]], Any]
def _execute_command(command: List[str]) -> None:
if command[:4] == ["vmrun", "-T", "ws", "start"]:
p = subprocess.Popen(command)
p.wait()
else:
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=60, text=True)
if result.returncode != 0:
raise Exception("\033[91m" + result.stdout + result.stderr + "\033[0m")
return result.stdout
class DesktopEnv(gym.Env):
    """DesktopEnv with OpenAI Gym interface.

    The environment controls a virtual machine through the ``vmrun -T ws``
    command line and talks to an HTTP server inside the guest (port 5000)
    via ``PythonController`` and ``SetupController``.

    Note that ``reset()``/``render()`` return the path of a screenshot file,
    whereas ``step()`` returns a dict observation holding that path together
    with the task instruction.
    """

    def __init__(
            self,
            path_to_vm: str,
            action_space: str = "computer_13",
            task_config: Optional[Dict[str, Any]] = None,
            tmp_dir: str = "tmp",
            cache_dir: str = "cache"
    ):
        """
        Args:
            path_to_vm (str): path to .vmx file
            action_space (str): "computer_13" | "pyautogui"
            task_config (Dict[str, Any]): manages task configs integratedly,
              including
              * base snapshot
              * task id (uuid)
              * instruction
              * setup config
              * evaluator config
              Although the parameter defaults to None, a config is required.
            tmp_dir (str): temporary directory to store trajectory stuffs like
              the extracted screenshots
            cache_dir (str): cache directory to cache task-related stuffs like
              reference file for evaluation

        Raises:
            AssertionError: if `action_space` is not a supported value.
            TypeError: if `task_config` is None.
        """
        # Validate the cheap arguments first so that a bad call fails before
        # the (slow) VM boot below.
        assert action_space in ["computer_13", "pyautogui"], \
            "Unsupported action space: {}".format(action_space)
        if task_config is None:
            raise TypeError("task_config is required to construct DesktopEnv, got None")

        # Initialize environment variables
        self.path_to_vm = os.path.abspath(os.path.expandvars(os.path.expanduser(path_to_vm)))
        self.tmp_dir_base: str = tmp_dir
        self.cache_dir_base: str = cache_dir

        # task-aware stuffs
        # todo: handling the logic of snapshot directory
        self._set_task_info(task_config)

        # Initialize emulator and controller
        logger.info("Initializing...")
        self._start_emulator()
        self.host = f"http://{self._get_vm_ip()}:5000"
        self.controller = PythonController(http_server=self.host)
        self.setup_controller = SetupController(http_server=self.host, cache_dir=self.cache_dir)

        # mode: human or machine
        self.action_space = action_space
        # todo: define the action space and the observation space as gym did, or extend theirs

        # episodic stuffs, like tmp dir and counters, will be updated or reset
        # when calling self.reset()
        self.tmp_dir: str = self.tmp_dir_base  # just an init value, updated during reset
        self._traj_no: int = -1
        self._step_no: int = 0
        self.action_history: List[Dict[str, Any]] = []

    def _start_emulator(self):
        """Block until `self.path_to_vm` shows up in ``vmrun -T ws list``.

        If the VM is not listed, a start command is issued and the list is
        polled again after a short pause. There is no upper bound on the
        number of attempts.
        """
        while True:
            try:
                # argv form (no shell), consistent with every other vmrun call
                raw_output = subprocess.check_output(
                    ["vmrun", "-T", "ws", "list"], stderr=subprocess.STDOUT
                )
            except subprocess.CalledProcessError as e:
                logger.error(f"Error executing command: {e.output.decode().strip()}")
                # back off before retrying instead of spinning in a tight loop
                time.sleep(3)
                continue

            running_vms: List[str] = raw_output.decode().splitlines()
            if self.path_to_vm in running_vms:
                logger.info("VM is running.")
                break

            logger.info("Starting VM...")
            _execute_command(["vmrun", "-T", "ws", "start", self.path_to_vm])
            time.sleep(3)

    def _get_vm_ip(self):
        """Query the guest IP address via vmrun, retrying up to 10 times.

        Returns:
            The stripped output of ``vmrun ... getGuestIPAddress``.

        Raises:
            Exception: if every attempt fails.
        """
        max_retries = 10
        logger.info("Getting IP Address...")
        for _ in range(max_retries):
            try:
                output = _execute_command(
                    ["vmrun", "-T", "ws", "getGuestIPAddress", self.path_to_vm]
                ).strip()
                logger.info(f"IP address: {output}")
                return output
            except Exception:
                # `Exception` (not a bare except) so Ctrl-C / SystemExit still
                # interrupt the retry loop.
                time.sleep(5)
                logger.info("Retrying...")
        raise Exception("Failed to get VM IP address!")

    def _save_state(self):
        """Take a VM snapshot named `self.snapshot_path`."""
        _execute_command(["vmrun", "-T", "ws", "snapshot", self.path_to_vm, self.snapshot_path])

    def _get_screenshot(self):
        """Fetch a screenshot from the controller and store it on disk.

        Returns:
            Path of the written file, ``<tmp_dir>/screenshots/<step_no>.png``.
        """
        screenshot_dir: str = os.path.join(self.tmp_dir, "screenshots")
        # The directory is normally created by reset(); make sure it exists so
        # that a call before the first reset() does not fail on open().
        os.makedirs(screenshot_dir, exist_ok=True)
        image_path: str = os.path.join(screenshot_dir, "{:d}.png".format(self._step_no))

        # Get the screenshot and save to the image_path
        screenshot = self.controller.get_screenshot()
        with open(image_path, "wb") as f:
            f.write(screenshot)

        return image_path

    def _get_obs(self):
        """Return the current observation, i.e. the path of a fresh screenshot."""
        screenshot_image_path = self._get_screenshot()
        return screenshot_image_path

    def _set_task_info(self, task_config: Dict[str, Any]):
        """Load all task-specific attributes from `task_config`.

        Reads the keys "snapshot", "id", "instruction", "config" and
        "evaluator"; creates ``<cache_dir_base>/<task id>`` and resolves the
        metric and getter callables by name from the `metrics` / `getters`
        modules.
        """
        self.snapshot_path = task_config["snapshot"]
        self.task_id: str = task_config["id"]
        self.cache_dir: str = os.path.join(self.cache_dir_base, self.task_id)
        os.makedirs(self.cache_dir, exist_ok=True)
        self.instruction = task_config["instruction"]
        self.config = task_config["config"]

        self.evaluator = task_config["evaluator"]
        self.metric: Metric = getattr(metrics, self.evaluator["func"])
        self.result_getter: Getter = getattr(getters, "get_{:}".format(self.evaluator["result"]["type"]))
        # the "expected" section is optional; without it the metric is called
        # with the result state only (see evaluate())
        if "expected" in self.evaluator:
            self.expected_getter: Getter = getattr(
                getters, "get_{:}".format(self.evaluator["expected"]["type"])
            )
        else:
            self.expected_getter = None
        self.metric_options: Dict[str, Any] = self.evaluator.get("options", {})

    def reset(self, task_config: Optional[Dict[str, Any]] = None, seed=None, options=None):
        """Revert the VM to the task snapshot and run the task setup.

        Args:
            task_config: if given, switch to this task before resetting;
                otherwise the current task is reused.
            seed: accepted for Gym API compatibility; not used.
            options: accepted for Gym API compatibility; not used.

        Returns:
            Path of the screenshot taken after setup.
        """
        logger.info("Resetting environment...")

        logger.info("Switching task...")
        if task_config is not None:
            self._set_task_info(task_config)
            self.setup_controller.reset_cache_dir(self.cache_dir)

        logger.info("Setting counters...")
        self._traj_no += 1
        self._step_no = 0
        self.action_history.clear()

        logger.info("Setup new temp dir...")
        # mkdtemp requires its parent directory to exist already
        os.makedirs(self.tmp_dir_base, exist_ok=True)
        self.tmp_dir = tempfile.mkdtemp(
            prefix="{:d}.{:}.".format(self._traj_no, self.task_id),
            dir=self.tmp_dir_base
        )
        os.makedirs(os.path.join(self.tmp_dir, "screenshots"))

        logger.info("Reverting to snapshot to {}...".format(self.snapshot_path))
        _execute_command(["vmrun", "-T", "ws", "revertToSnapshot", self.path_to_vm, self.snapshot_path])
        time.sleep(5)

        logger.info("Starting emulator...")
        self._start_emulator()
        logger.info("Emulator started.")

        logger.info("Setting up environment...")
        self.setup_controller.setup(self.config)

        time.sleep(5)
        logger.info("Environment setup complete.")

        observation = self._get_obs()
        return observation

    def step(self, action, pause=0.5):
        """Execute one action in the VM and observe the result.

        Args:
            action: an action for the configured action space; it is passed
                to `controller.execute_action` for "computer_13" and to
                `controller.execute_python_command` for "pyautogui".
            pause: seconds to wait after the action before taking the
                screenshot.

        Returns:
            A 4-tuple ``(observation, reward, done, info)`` where observation
            is ``{"screenshot": <path>, "instruction": <str>}``; reward, done
            and info are currently the placeholders 0, False and {}.
        """
        self._step_no += 1

        # fixme: add reminding logic here, decide if the action is valid for the current action_space
        if self.action_space == "computer_13":
            # the set of all possible actions defined in the action representation
            self.controller.execute_action(action)
        elif self.action_space == "pyautogui":
            # the set of all possible python commands insides `pyautogui`
            self.controller.execute_python_command(action)
        self.action_history.append(action)

        # todo: maybe for the better here we need to add a logic to wait until the rendering is done
        time.sleep(pause)
        observation = {
            "screenshot": self._get_obs(),
            "instruction": self.instruction
        }
        reward = 0  # todo: Define reward calculation for each example
        done = False  # todo: Define episode termination condition for each example
        info = {}
        return observation, reward, done, info

    def evaluate(self):
        """
        Evaluate whether the task is successfully completed.

        Runs the evaluator's optional "postconfig" setup, collects the result
        state (and the expected state when the evaluator declares one) and
        returns the value computed by the task metric.
        """
        # "postconfig" is optional, like "expected" and "options"
        if "postconfig" in self.evaluator:
            self.setup_controller.setup(self.evaluator["postconfig"])

        result_state = self.result_getter(self, self.evaluator["result"])
        if "expected" in self.evaluator:
            expected_state = self.expected_getter(self, self.evaluator["expected"])
        else:
            expected_state = None

        if expected_state is not None:
            metric: float = self.metric(result_state, expected_state, **self.metric_options)
        else:
            metric: float = self.metric(result_state, **self.metric_options)

        return metric

    def render(self, mode='rgb_array'):
        """Return the current observation (a screenshot path) for mode 'rgb_array'.

        Raises:
            ValueError: for any other mode.
        """
        if mode == 'rgb_array':
            return self._get_obs()
        else:
            raise ValueError('Unsupported render mode: {}'.format(mode))

    def close(self):
        """Stop the VM."""
        # pass "-T ws" like every other vmrun invocation in this module
        _execute_command(["vmrun", "-T", "ws", "stop", self.path_to_vm])