from enum import Enum
from typing import Literal, List, Optional, Tuple
import subprocess
import time

from fabric import Connection
import gymnasium as gym
from gymnasium import spaces
import numpy as np
from PIL import Image

from desktop_env.controllers.mouse import MouseClick, AbstractMouseController, XDoToolMouseController, \
    PythonMouseController
from desktop_env.controllers.keyboard import AbstractKeyboardController, XDoToolKeyboardController, \
    PythonKeyboardController


class Action(Enum):
    """Kinds of input events an agent can request through ``DesktopEnv.step``."""
    CLICK = 0
    MOUSE_DOWN = 1
    MOUSE_UP = 2
    MOUSE_MOVE = 3
    KEY = 4
    KEY_DOWN = 5
    KEY_UP = 6
    TYPE = 7


VM_TYPE = Literal['ubuntu', 'windows']

# Name of the mouse-controller method to call for each (action type, button) pair.
# Method names are stored as strings and resolved lazily so that only the
# method that is actually needed gets looked up on the controller.
_BUTTON_METHODS = {
    Action.CLICK: {
        MouseClick.LEFT: "left_click",
        MouseClick.MIDDLE: "middle_click",
        MouseClick.RIGHT: "right_click",
    },
    Action.MOUSE_DOWN: {
        MouseClick.LEFT: "left_down",
        MouseClick.MIDDLE: "middle_down",
        MouseClick.RIGHT: "right_down",
    },
    Action.MOUSE_UP: {
        MouseClick.LEFT: "left_up",
        MouseClick.MIDDLE: "middle_up",
        MouseClick.RIGHT: "right_up",
    },
}

# Wheel "clicks" map to the same scroll call for CLICK, MOUSE_DOWN and MOUSE_UP.
_WHEEL_METHODS = {
    MouseClick.WHEEL_UP: "scroll_up",
    MouseClick.WHEEL_DOWN: "scroll_down",
}

# Name of the keyboard-controller method to call for each key action type.
_KEY_METHODS = {
    Action.KEY: "key",
    Action.KEY_DOWN: "key_down",
    Action.KEY_UP: "key_up",
}

# Seconds to wait between polls of ``vmrun list`` while the VM is coming up.
_VM_POLL_INTERVAL = 5


class DesktopEnv(gym.Env):
    """DesktopEnv with OpenAI Gym interface.

    The environment drives a VMware virtual machine through the ``vmrun``
    command line tool (start / revert / screenshot / stop) and forwards mouse
    and keyboard actions to the guest through controller objects.

    Note: ``reset`` returns only the observation and ``step`` returns the
    4-tuple ``(observation, reward, done, info)``.
    """

    def __init__(self, path_to_vm: str, username: str, password: str,
                 host: str, snapshot_path: str = "some_point_browser", vm_os: VM_TYPE = "ubuntu"):
        """Start the VM if needed and set up the spaces and input controllers.

        Args:
            path_to_vm: Path of the VM passed to ``vmrun``.
            username: Guest user name (used for ``vmrun -gu`` and for SSH).
            password: Guest password (used for ``vmrun -gp`` and for SSH).
            host: Address of the guest, used for SSH (ubuntu) or as the
                ``http_server`` of the Python controllers (windows).
            snapshot_path: Snapshot name that ``reset`` reverts to.
            vm_os: Guest operating system; selects the controller pair.

        Raises:
            NotImplementedError: If ``vm_os`` is not a supported value.
        """
        super().__init__()

        self.path_to_vm = path_to_vm
        self.username = username
        self.password = password
        self.host = host
        self.snapshot_path = snapshot_path  # todo: handling the logic of snapshot directory

        self.screen_width = 800
        self.screen_height = 800
        # Define the action and observation space
        self.action_space = spaces.Dict({
            "action_type": spaces.Discrete(len(Action)),
            "click_type": spaces.Discrete(len(MouseClick)),
            "x": spaces.Discrete(self.screen_width),
            "y": spaces.Discrete(self.screen_height),
            "key": spaces.MultiDiscrete([128] * 10),  # max 10 characters, ASCII
            "text": spaces.MultiDiscrete([128] * 10)  # max 10 characters, ASCII
        })

        # np.array(PIL image) is laid out as (height, width, channels), so the
        # declared shape follows that order.
        self.observation_space = spaces.Box(low=0, high=255,
                                            shape=(self.screen_height, self.screen_width, 3),
                                            dtype=np.uint8)

        # Additional setup
        self.metadata = {'render.modes': ['rgb_array']}

        # Initialize emulator
        print("Initializing...")
        self._start_emulator()

        # set up controllers
        self.mouse_controller, self.keyboard_controller = self._create_controllers(vm_os)

    def _create_controllers(self, vm_os: VM_TYPE) -> Tuple[AbstractMouseController, AbstractKeyboardController]:
        """Build the mouse/keyboard controller pair for the given guest OS.

        For ``"ubuntu"`` both XDoTool controllers share one SSH connection to
        ``self.host``; for ``"windows"`` the Python controllers are pointed at
        ``self.host`` as their HTTP server.

        Raises:
            NotImplementedError: If ``vm_os`` is neither of the two.
        """
        if vm_os == "ubuntu":
            ssh_connection = Connection(host=self.host, user=self.username,
                                        connect_kwargs={"password": self.password})
            mouse_controller = XDoToolMouseController(ssh_connection)
            keyboard_controller = XDoToolKeyboardController(ssh_connection)
        elif vm_os == "windows":
            mouse_controller = PythonMouseController(http_server=self.host)
            keyboard_controller = PythonKeyboardController(http_server=self.host)
        else:
            raise NotImplementedError(vm_os)

        return mouse_controller, keyboard_controller

    def _start_emulator(self):
        """Block until ``vmrun list`` reports this VM, starting it if necessary.

        Polls forever: each round lists the running VMs, and if this VM is
        not among them issues ``vmrun start``. Every retry (including a failed
        ``vmrun list``) is followed by a pause so the loop never spins hot.
        """
        while True:
            try:
                output = subprocess.check_output("vmrun -T ws list", shell=True,
                                                 stderr=subprocess.STDOUT).decode()
            except subprocess.CalledProcessError as e:
                print(f"Error executing command: {e.output.decode().strip()}")
            else:
                # lstrip removes any leading '~' and '/' characters, so this
                # is a substring match of the remaining path against the
                # paths printed by vmrun.
                if self.path_to_vm.lstrip("~/") in output:
                    print("VM is running.")
                    return
                print("Starting VM...")
                self._execute_command(["vmrun", "-T", "ws", "start", self.path_to_vm])
            time.sleep(_VM_POLL_INTERVAL)

    def _execute_command(self, command: List[str]) -> Optional[str]:
        """Run ``command`` and return its decoded stdout, or ``None`` on failure.

        Args:
            command: Program and arguments as an argv list.

        Returns:
            The decoded standard output when the process exits with status 0;
            ``None`` (after printing an error) when it exits non-zero or
            cannot be launched at all.
        """
        try:
            # The argv list is executed directly. Combining a list with
            # shell=True would, on POSIX, run only the first element and
            # silently drop every argument.
            process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        except OSError:
            # e.g. the executable is missing; keep the print-and-return-None
            # contract instead of raising.
            print(f"Error executing command: {command}")
            return None

        stdout, _stderr = process.communicate()
        if process.returncode != 0:
            print(f"Error executing command: {command}")
            return None
        return stdout.decode()

    def _save_state(self):
        """Take a snapshot of the VM named ``self.snapshot_path``."""
        self._execute_command(["vmrun", "-T", "ws", "snapshot", self.path_to_vm, self.snapshot_path])

    def _get_screenshot(self):
        """Ask ``vmrun captureScreen`` to write a screenshot and return its path.

        Returns:
            The fixed path ``"./screenshot.png"``. The path is returned even
            if the capture command failed.
        """
        image_path = "./screenshot.png"
        self._execute_command(
            ["vmrun", "-T", "ws", "-gu", self.username, "-gp", self.password, "captureScreen", self.path_to_vm,
             image_path])
        return image_path

    def _get_obs(self):
        """Capture a screenshot and return it as a numpy array."""
        screenshot_image_path = self._get_screenshot()
        with Image.open(screenshot_image_path) as img:
            return np.array(img)

    def reset(self):
        """Revert the VM to ``self.snapshot_path``, restart it and observe.

        Returns:
            The first observation after the revert (observation only, not an
            ``(observation, info)`` pair).
        """
        print("Resetting environment...")

        print("Reverting to snapshot to {}...".format(self.snapshot_path))
        self._execute_command(["vmrun", "-T", "ws", "revertToSnapshot", self.path_to_vm, self.snapshot_path])

        print("Starting emulator...")
        self._start_emulator()
        print("Emulator started.")

        observation = self._get_obs()
        return observation

    @staticmethod
    def _decode_ascii(codes) -> str:
        """Turn a sequence of integer character codes into a string."""
        return ''.join(map(chr, codes))

    def _perform_mouse_button_action(self, action_type: Action, click) -> None:
        """Run the controller call for a CLICK / MOUSE_DOWN / MOUSE_UP action.

        Wheel values scroll regardless of the action type; a click value with
        no mapping is ignored.
        """
        method_name = _BUTTON_METHODS[action_type].get(click) or _WHEEL_METHODS.get(click)
        if method_name is not None:
            getattr(self.mouse_controller, method_name)()

    def step(self, action):
        """Apply one action to the guest and return the resulting observation.

        Args:
            action: Mapping with an ``"action_type"`` entry (an ``Action``
                value) plus the entries that type needs: ``"click_type"``
                for click/down/up, ``"x"`` and ``"y"`` for mouse moves,
                ``"key"`` for key actions and ``"text"`` for typing. ``"key"``
                and ``"text"`` are sequences of integer character codes.

        Returns:
            ``(observation, reward, done, info)``; reward is always 0, done is
            always False and info is an empty dict for now.

        Raises:
            ValueError: If ``action_type`` or ``click_type`` is not a valid
                enum value.
        """
        action_type = Action(action['action_type'])

        if action_type in _BUTTON_METHODS:
            self._perform_mouse_button_action(action_type, MouseClick(action['click_type']))
        elif action_type == Action.MOUSE_MOVE:
            self.mouse_controller.mouse_move(x=action['x'], y=action['y'])
        elif action_type in _KEY_METHODS:
            key_sequence = self._decode_ascii(action['key'])  # Convert integer array to string
            getattr(self.keyboard_controller, _KEY_METHODS[action_type])(key_sequence)
        elif action_type == Action.TYPE:
            text = self._decode_ascii(action['text'])  # Convert integer array to string
            self.keyboard_controller.type(text)

        # Capture new state
        observation = self._get_obs()
        reward = 0  # Define reward calculation
        done = False  # Define episode termination condition
        info = {}
        return observation, reward, done, info

    def render(self, mode='rgb_array'):
        """Return the current screen as an array.

        Raises:
            ValueError: If ``mode`` is anything other than ``'rgb_array'``.
        """
        if mode == 'rgb_array':
            return self._get_obs()
        else:
            raise ValueError('Unsupported render mode: {}'.format(mode))

    def close(self):
        """Stop the VM."""
        # "-T ws" added so the host type matches every other vmrun call here.
        self._execute_command(["vmrun", "-T", "ws", "stop", self.path_to_vm])