from enum import Enum
import shlex
import subprocess
from typing import Optional

from fabric import Connection
import gymnasium as gym
from gymnasium import spaces
import numpy as np
from PIL import Image


class Action(Enum):
    """Action types understood by :meth:`DesktopEnv.step`."""

    CLICK = 0
    MOUSE_DOWN = 1
    MOUSE_UP = 2
    MOUSE_MOVE = 3
    KEY = 4
    TYPE = 5


class MouseClick(Enum):
    """Mouse buttons; each value is the button number passed to xdotool."""

    LEFT = 1
    MIDDLE = 2
    RIGHT = 3
    WHEEL_UP = 4
    WHEEL_DOWN = 5


class DesktopEnv(gym.Env):
    """DesktopEnv with OpenAI Gym interface.

    Drives a VMware virtual machine: the VM is started/stopped/snapshotted
    with the local ``vmrun`` tool, and input/screenshots are performed over
    SSH (fabric) with ``xdotool`` and ImageMagick ``import`` on display ``:0``.

    Args:
        path_to_vm: Path to the VM's ``.vmx`` file, passed to ``vmrun``.
        username: SSH user on the guest.
        password: SSH password for ``username``.
        host: Guest hostname / IP for SSH.
        snapshot_path: Name of the VMware snapshot used by ``reset`` and
            ``_save_state``.
    """

    def __init__(self, path_to_vm: str, username: str, password: str, host: str,
                 snapshot_path: str = "snapshot"):
        self.path_to_vm = path_to_vm
        self.username = username
        self.password = password
        self.host = host
        self.snapshot_path = snapshot_path
        self.ssh_connection = Connection(host=self.host, user=self.username,
                                         connect_kwargs={"password": password})

        # NOTE(review): assumed guest screen size; the real screenshot size
        # depends on the VM's resolution — confirm it is 800x800.
        self.screen_width = 800
        self.screen_height = 800

        # Define the action and observation space
        self.action_space = spaces.Dict({
            "action_type": spaces.Discrete(len(Action)),
            # MouseClick values are 1..5, so the space must start at 1;
            # Discrete(5) alone yields 0..4 and MouseClick(0) would raise.
            "click_type": spaces.Discrete(len(MouseClick), start=1),
            "x": spaces.Discrete(self.screen_width),
            "y": spaces.Discrete(self.screen_height),
            "key": spaces.MultiDiscrete([128] * 10),  # max 10 characters, ASCII
            "text": spaces.MultiDiscrete([128] * 10),  # max 10 characters, ASCII
        })
        # Image arrays are (rows, cols, channels) == (height, width, 3).
        self.observation_space = spaces.Box(
            low=0, high=255,
            shape=(self.screen_height, self.screen_width, 3),
            dtype=np.uint8,
        )

        # gymnasium reads "render_modes"; the legacy "render.modes" key is
        # kept for any existing readers of the old name.
        self.metadata = {
            "render_modes": ["rgb_array"],
            "render.modes": ["rgb_array"],
        }
        self._start_emulator()

    # ------------------------------------------------------------------ VM

    def _start_emulator(self):
        """Power on the VM via ``vmrun start``."""
        self._execute_command(["vmrun", "start", self.path_to_vm])

    def _save_state(self):
        """Take a VMware snapshot named ``self.snapshot_path``."""
        # vmrun syntax: vmrun snapshot <path to vmx> <snapshot name>
        self._execute_command(["vmrun", "snapshot", self.path_to_vm, self.snapshot_path])

    def _execute_command(self, command: list[str]) -> Optional[str]:
        """Run a local command; return its decoded stdout, or None on failure.

        Failures are reported on stdout (best-effort) rather than raised.
        """
        process = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        if process.returncode != 0:
            print(f"Error executing command: {command}")
            print(process.stderr.decode())
            return None
        return process.stdout.decode()

    # --------------------------------------------------------------- input

    def _execute_xdotool_command(self, command: str) -> str:
        """Run ``xdotool <command>`` on the guest's display :0 and return stdout.

        ``command`` is interpolated into a remote shell command line; callers
        must quote any user-controlled parts.
        """
        result = self.ssh_connection.run(f"DISPLAY=:0 xdotool {command}", hide=True)
        return result.stdout.strip()

    def _click(self, click: MouseClick):
        self._execute_xdotool_command(f"click {click.value}")

    def _mousedown(self, click: MouseClick):
        self._execute_xdotool_command(f"mousedown {click.value}")

    def _mouseup(self, click: MouseClick):
        self._execute_xdotool_command(f"mouseup {click.value}")

    def _mouse_move(self, x: int, y: int):
        self._execute_xdotool_command(f"mousemove {int(x)} {int(y)}")

    def _key(self, key: str):
        """Send one or more whitespace-separated keysyms (e.g. ``"ctrl+c Return"``)."""
        keys = key.split()
        if not keys:
            return  # `xdotool key` with no arguments is an error
        # Quote each keysym separately so multi-key sequences still work but
        # nothing can escape into the remote shell.
        self._execute_xdotool_command("key " + " ".join(shlex.quote(k) for k in keys))

    def _type(self, text: str):
        """Type ``text`` literally, spaces included."""
        if not text:
            return
        # Quoting keeps the text a single argument (unquoted, xdotool would
        # receive several words and drop the spaces) and prevents injection.
        self._execute_xdotool_command(f"type {shlex.quote(text)}")

    @staticmethod
    def _decode_ascii(codes) -> str:
        """Convert an array of ASCII codes to a string, dropping NUL (0) padding."""
        return "".join(chr(int(c)) for c in codes if int(c) != 0)

    # --------------------------------------------------------- observation

    def _get_screenshot(self) -> str:
        """Capture the guest's root window and download it to ./screenshot.png."""
        image_path = "./screenshot.png"
        remote_path = "screenshot.png"  # relative to the SSH user's home dir
        self.ssh_connection.run(f"DISPLAY=:0 import -window root {remote_path}", hide=True)
        try:
            # Reuse the authenticated SSH connection instead of a separate
            # `scp` to a hard-coded host/user that would prompt for a password.
            self.ssh_connection.get(remote_path, local=image_path)
        finally:
            self.ssh_connection.run(f"rm -f {remote_path}", hide=True)
        return image_path

    def _get_obs(self):
        """Return the current screen as a uint8 RGB array (height, width, 3)."""
        screenshot_image_path = self._get_screenshot()
        with Image.open(screenshot_image_path) as img:
            # Force 3 channels: PNG screenshots may be RGBA or palette images.
            return np.array(img.convert("RGB"))

    # ------------------------------------------------------------- gym API

    def reset(self, *, seed=None, options=None):
        """Revert the VM to ``self.snapshot_path`` and return the first observation.

        NOTE(review): gymnasium's API returns ``(observation, info)``; this
        returns only the observation to stay compatible with existing callers.
        NOTE(review): depending on how the snapshot was taken, the VM may need
        to be started again after reverting — confirm.
        """
        super().reset(seed=seed)
        # vmrun syntax: vmrun revertToSnapshot <path to vmx> <snapshot name>
        self._execute_command(["vmrun", "revertToSnapshot", self.path_to_vm, self.snapshot_path])
        observation = self._get_obs()
        return observation

    def step(self, action):
        """Execute one action dict (see ``action_space``) and observe the screen.

        Returns ``(observation, reward, done, info)``.
        NOTE(review): gymnasium expects a 5-tuple
        ``(obs, reward, terminated, truncated, info)``; the 4-tuple is kept
        for compatibility with existing callers.
        """
        action_type = Action(action["action_type"])
        if action_type == Action.CLICK:
            self._click(MouseClick(action["click_type"]))
        elif action_type == Action.MOUSE_DOWN:
            self._mousedown(MouseClick(action["click_type"]))
        elif action_type == Action.MOUSE_UP:
            self._mouseup(MouseClick(action["click_type"]))
        elif action_type == Action.MOUSE_MOVE:
            self._mouse_move(action["x"], action["y"])
        elif action_type == Action.KEY:
            self._key(self._decode_ascii(action["key"]))
        elif action_type == Action.TYPE:
            self._type(self._decode_ascii(action["text"]))

        # Capture new state
        observation = self._get_obs()
        reward = 0  # Define reward calculation
        done = False  # Define episode termination condition
        info = {}
        return observation, reward, done, info

    def render(self, mode="rgb_array"):
        """Return the current screen as an RGB array; only 'rgb_array' is supported."""
        if mode == "rgb_array":
            return self._get_obs()
        raise ValueError("Unsupported render mode: {}".format(mode))

    def close(self):
        """Close the SSH connection and power off the VM."""
        self.ssh_connection.close()
        self._execute_command(["vmrun", "stop", self.path_to_vm])