Add support for automatic VM download and configuration, enable auto-scaling management; move metadata retrieval out of the init function to speed up environment initialization.

This commit is contained in:
Timothyxxx
2024-04-21 19:51:15 +08:00
parent 29d2f69556
commit 0b3e7dca24
4 changed files with 349 additions and 162 deletions

View File

@@ -12,6 +12,7 @@ import gymnasium as gym
from desktop_env.controllers.python import PythonController
from desktop_env.controllers.setup import SetupController
from desktop_env.evaluators import metrics, getters
from . import _get_vm_path
logger = logging.getLogger("desktopenv.env")
@@ -45,7 +46,7 @@ class DesktopEnv(gym.Env):
def __init__(
self,
path_to_vm: str,
path_to_vm: str = None,
snapshot_name: str = "init_state",
action_space: str = "computer_13",
cache_dir: str = "cache",
@@ -68,10 +69,10 @@ class DesktopEnv(gym.Env):
"""
# Initialize environment variables
self.path_to_vm = os.path.abspath(os.path.expandvars(os.path.expanduser(path_to_vm)))
self.path_to_vm = os.path.abspath(os.path.expandvars(os.path.expanduser(path_to_vm if path_to_vm else _get_vm_path())))
self.snapshot_name = snapshot_name
self.cache_dir_base: str = cache_dir
self.vm_screen_size = screen_size # todo: add the logic to get the screen size from the VM
# todo: add the logic to get the screen size from the VM
self.headless = headless
self.require_a11y_tree = require_a11y_tree
self.require_terminal = require_terminal
@@ -83,10 +84,6 @@ class DesktopEnv(gym.Env):
self.controller = PythonController(vm_ip=self.vm_ip)
self.setup_controller = SetupController(vm_ip=self.vm_ip, cache_dir=self.cache_dir_base)
# Meta info of the VM
self.vm_platform: str = self.controller.get_vm_platform()
self.vm_screen_size = self.controller.get_vm_screen_size()
# mode: human or machine
self.instruction = None
assert action_space in ["computer_13", "pyautogui"]
@@ -98,6 +95,14 @@ class DesktopEnv(gym.Env):
self._step_no: int = 0
self.action_history: List[Dict[str, any]] = []
@property
def vm_platform(self):
return self.controller.get_vm_platform()
@property
def vm_screen_size(self):
return self.controller.get_vm_screen_size()
def _start_emulator(self):
while True:
try:
@@ -229,10 +234,6 @@ class DesktopEnv(gym.Env):
self._start_emulator()
logger.info("Emulator started.")
logger.info("Get meta info of the VM...")
self.vm_platform = self.controller.get_vm_platform()
self.vm_screen_size = self.controller.get_vm_screen_size()
logger.info("Setting up environment...")
self.setup_controller.setup(self.config)