Clean code; Refactor environment to pass screenshot content instead of path

This commit is contained in:
Timothyxxx
2024-04-13 23:34:01 +08:00
parent b9ae9b72b2
commit 9c75df5dce
10 changed files with 144 additions and 213 deletions

View File

@@ -3,22 +3,16 @@ from __future__ import annotations
import logging
import os
import subprocess
import tempfile
import time
from typing import Callable, Any, Optional, Tuple
# import uuid
# import platform
from typing import List, Dict, Union
import gymnasium as gym
from desktop_env.controllers.python import PythonController
from desktop_env.controllers.setup import SetupController
# from desktop_env.evaluators import eval_funcs
from desktop_env.evaluators import metrics, getters
# import requests
logger = logging.getLogger("desktopenv.env")
Metric = Callable[[Any, Any], float]
@@ -46,8 +40,7 @@ def _execute_command(command: List[str]) -> None:
class DesktopEnv(gym.Env):
"""
DesktopEnv with OpenAI Gym interface.
Fixme: refactor the logic when implementing the multi-process version
DesktopEnv with OpenAI Gym interface. It provides a desktop environment for setting and evaluating desktop automation tasks.
"""
def __init__(
@@ -55,32 +48,33 @@ class DesktopEnv(gym.Env):
path_to_vm: str,
snapshot_name: str = "init_state",
action_space: str = "computer_13",
tmp_dir: str = "tmp",
cache_dir: str = "cache",
screen_size: Tuple[int] = (1920, 1080),
headless: bool = False,
require_a11y_tree: bool = True,
require_terminal: bool = False,
):
"""
Args:
path_to_vm (str): path to .vmx file
snapshot_name (str): snapshot name to revert to, default to "init_state"
action_space (str): "computer_13" | "pyautogui"
tmp_dir (str): temporary directory to store trajectory stuffs like
the extracted screenshots
cache_dir (str): cache directory to cache task-related stuffs like
reference file for evaluation
screen_size (Tuple[int]): screen size of the VM
headless (bool): whether to run the VM in headless mode
require_a11y_tree (bool): whether to require accessibility tree
require_terminal (bool): whether to require terminal output
"""
# Initialize environment variables
self.path_to_vm = os.path.abspath(os.path.expandvars(os.path.expanduser(path_to_vm)))
self.snapshot_name = snapshot_name
self.tmp_dir_base: str = tmp_dir
self.cache_dir_base: str = cache_dir
self.vm_screen_size = screen_size # todo: add the logic to get the screen size from the VM
self.headless = headless
self.require_a11y_tree = require_a11y_tree
os.makedirs(self.tmp_dir_base, exist_ok=True)
self.require_terminal = require_terminal
# Initialize emulator and controller
logger.info("Initializing...")
@@ -89,17 +83,17 @@ class DesktopEnv(gym.Env):
self.controller = PythonController(vm_ip=self.vm_ip)
self.setup_controller = SetupController(vm_ip=self.vm_ip, cache_dir=self.cache_dir_base)
# Meta info of the VM, move to the reset() function
self.vm_platform: str = "" # self.controller.get_vm_platform()
# Meta info of the VM
self.vm_platform: str = self.controller.get_vm_platform()
self.vm_screen_size = self.controller.get_vm_screen_size()
# mode: human or machine
self.instruction = None
assert action_space in ["computer_13", "pyautogui"]
self.action_space = action_space
# todo: define the action space and the observation space as gym did, or extend theirs
# episodic stuffs, like tmp dir and counters, will be updated or reset
# episodic stuffs, like counters, will be updated or reset
# when calling self.reset()
self.tmp_dir: str = self.tmp_dir_base # just an init value, updated during reset
self._traj_no: int = -1
self._step_no: int = 0
self.action_history: List[Dict[str, any]] = []
@@ -140,11 +134,7 @@ class DesktopEnv(gym.Env):
_execute_command(["vmrun", "-T", "ws" "snapshot", self.path_to_vm, self.snapshot_name])
def _get_screenshot(self):
# random_uuid = str(uuid.uuid4())
# os.makedirs(os.path.join("tmp", random_uuid), exist_ok=True)
# image_path = os.path.join("tmp", random_uuid, "screenshot.png")
image_path: str = os.path.join(self.tmp_dir, "screenshots", "{:d}.png".format(self._step_no))
screenshot = None
# Get the screenshot and save to the image_path
max_retries = 20
for _ in range(max_retries):
@@ -153,14 +143,18 @@ class DesktopEnv(gym.Env):
break
time.sleep(1)
with open(image_path, "wb") as f:
f.write(screenshot)
if screenshot is None:
logger.error("Failed to get screenshot!")
return image_path
return screenshot
def _get_obs(self):
screenshot_image_path = self._get_screenshot()
return screenshot_image_path
return {
"screenshot": self._get_screenshot(),
"accessibility_tree": self.controller.get_accessibility_tree() if self.require_a11y_tree else None,
"terminal": self.controller.get_terminal_output() if self.require_terminal else None,
"instruction": self.instruction
}
def _set_task_info(self, task_config: Dict[str, Any]):
self.task_id: str = task_config["id"]
@@ -182,7 +176,7 @@ class DesktopEnv(gym.Env):
if isinstance(self.evaluator["func"], list) \
else getattr(metrics, self.evaluator["func"])
self.metric_conj: str = self.evaluator.get("conj", "and") # take conjunction of multiple metrics
if "result" in self.evaluator and len(self.evaluator["result"])>0:
if "result" in self.evaluator and len(self.evaluator["result"]) > 0:
self.result_getter: Getter = [getattr(getters, "get_{:}".format(res["type"])) for res in
self.evaluator["result"]] \
if isinstance(self.evaluator["result"], list) \
@@ -192,7 +186,7 @@ class DesktopEnv(gym.Env):
if isinstance(self.metric, list) \
else None
if "expected" in self.evaluator and len(self.evaluator["expected"])>0:
if "expected" in self.evaluator and len(self.evaluator["expected"]) > 0:
self.expected_getter: Getter = [getattr(getters, "get_{:}".format(exp["type"])) if exp else None for exp in
self.evaluator["expected"]] \
if isinstance(self.evaluator["expected"], list) \
@@ -227,18 +221,10 @@ class DesktopEnv(gym.Env):
self._step_no = 0
self.action_history.clear()
logger.info("Setup new temp dir...")
self.tmp_dir = tempfile.mkdtemp(
prefix="{:d}.{:}.".format(self._traj_no, self.task_id),
dir=self.tmp_dir_base
)
os.makedirs(os.path.join(self.tmp_dir, "screenshots"))
logger.info("Reverting to snapshot to {}...".format(self.snapshot_name))
_execute_command(["vmrun", "-T", "ws", "revertToSnapshot", self.path_to_vm, self.snapshot_name])
time.sleep(5)
print(self.vm_screen_size)
logger.info("Starting emulator...")
self._start_emulator()
logger.info("Emulator started.")
@@ -246,7 +232,6 @@ class DesktopEnv(gym.Env):
logger.info("Get meta info of the VM...")
self.vm_platform = self.controller.get_vm_platform()
self.vm_screen_size = self.controller.get_vm_screen_size()
print(self.vm_screen_size)
logger.info("Setting up environment...")
self.setup_controller.setup(self.config)
@@ -254,10 +239,7 @@ class DesktopEnv(gym.Env):
time.sleep(5)
logger.info("Environment setup complete.")
observation = {
"screenshot": self._get_obs(),
"accessibility_tree": self.controller.get_accessibility_tree() if self.require_a11y_tree else None,
}
observation = self._get_obs()
return observation
def step(self, action, pause=0.5):
@@ -279,7 +261,6 @@ class DesktopEnv(gym.Env):
done = True
info = {"done": True}
# fixme: add reminding logic here, decide if the action is valid for the current action_space
if self.action_space == "computer_13":
# the set of all possible actions defined in the action representation
self.controller.execute_action(action)
@@ -290,12 +271,7 @@ class DesktopEnv(gym.Env):
# the set of all possible python commands insides `pyautogui`
self.controller.execute_python_command(action)
observation = {
"screenshot": self._get_obs(),
"accessibility_tree": self.controller.get_accessibility_tree() if self.require_a11y_tree else None,
# "terminal": self.controller.get_terminal_output(),
"instruction": self.instruction
}
observation = self._get_obs()
return observation, reward, done, info
@@ -358,7 +334,7 @@ class DesktopEnv(gym.Env):
def render(self, mode='rgb_array'):
if mode == 'rgb_array':
return self._get_obs()
return self._get_screenshot()
else:
raise ValueError('Unsupported render mode: {}'.format(mode))