209 lines
8.2 KiB
Python
209 lines
8.2 KiB
Python
from enum import Enum
|
|
from typing import Literal
|
|
import subprocess
|
|
from fabric import Connection
|
|
import time
|
|
|
|
import gymnasium as gym
|
|
from gymnasium import spaces
|
|
import numpy as np
|
|
from PIL import Image
|
|
|
|
from desktop_env.controllers.mouse import MouseClick, AbstractMouseController, XDoToolMouseController, PythonMouseController
|
|
from desktop_env.controllers.keyboard import AbstractKeyboardController, XDoToolKeyboardController, PythonKeyboardController
|
|
|
|
class Action(Enum):
    """Action kinds understood by the environment's ``step`` method.

    The integer value of each member is what an action dict carries under
    its ``"action_type"`` key.
    """

    # Mouse-button actions; the button comes from the action's "click_type".
    CLICK = 0
    MOUSE_DOWN = 1
    MOUSE_UP = 2
    # Move the pointer to the action's "x" / "y" coordinates.
    MOUSE_MOVE = 3
    # Keyboard actions driven by the "key" / "text" character arrays.
    KEY = 4
    TYPE = 5
|
|
|
|
# Guest operating systems for which controllers can be created.
VM_TYPE = Literal["ubuntu", "windows"]
|
|
|
|
class DesktopEnv(gym.Env):
    """DesktopEnv with OpenAI Gym interface.

    Starts a VM through ``vmrun``, forwards actions to OS-specific mouse and
    keyboard controllers, and returns screenshots of the guest display as
    observations.
    """

    # Mouse-controller method name for each button, per mouse action type.
    _MOUSE_BUTTON_METHODS = {
        Action.CLICK: {
            MouseClick.LEFT: "left_click",
            MouseClick.MIDDLE: "middle_click",
            MouseClick.RIGHT: "right_click",
        },
        Action.MOUSE_DOWN: {
            MouseClick.LEFT: "left_down",
            MouseClick.MIDDLE: "middle_down",
            MouseClick.RIGHT: "right_down",
        },
        Action.MOUSE_UP: {
            MouseClick.LEFT: "left_up",
            MouseClick.MIDDLE: "middle_up",
            MouseClick.RIGHT: "right_up",
        },
    }
    # Wheel "buttons" scroll regardless of which mouse action was requested.
    _WHEEL_METHODS = {
        MouseClick.WHEEL_UP: "scroll_up",
        MouseClick.WHEEL_DOWN: "scroll_down",
    }

    def __init__(self, path_to_vm: str, username: str, password: str,
                 host: str, snapshot_path: str = "snapshot", vm_os: VM_TYPE = "ubuntu"):
        """Start the VM, wait until it is listed as running, and create controllers.

        Args:
            path_to_vm: VM path passed to ``vmrun`` (start/stop/snapshot commands).
            username: Guest user for the SSH connection (Ubuntu guests).
            password: Password for ``username``.
            host: Guest address. It is the SSH host on Ubuntu and the
                controllers' ``http_server`` on Windows.
            snapshot_path: Snapshot name passed to ``vmrun snapshot`` /
                ``revertToSnapshot``.
            vm_os: Guest OS; selects the controller implementation.
        """
        self.path_to_vm = path_to_vm
        self.username = username
        self.password = password
        self.host = host
        self.snapshot_path = snapshot_path
        # SSH connection to the guest; only created for Ubuntu guests
        # (see _create_controllers). Screenshots and xdotool helpers need it.
        self.ssh_connection = None

        # NOTE(review): screenshots are not resized, so the real observation
        # shape follows the guest resolution; confirm it is 800x800.
        self.screen_width = 800
        self.screen_height = 800
        # Define the action and observation space
        self.action_space = spaces.Dict({
            "action_type": spaces.Discrete(len(Action)),
            # NOTE(review): Discrete(n) yields 0..n-1; confirm MouseClick's values
            # use that range (the xdotool helpers send click.value as the button).
            "click_type": spaces.Discrete(len(MouseClick)),
            "x": spaces.Discrete(self.screen_width),
            "y": spaces.Discrete(self.screen_height),
            "key": spaces.MultiDiscrete([128] * 10),  # max 10 characters, ASCII
            "text": spaces.MultiDiscrete([128] * 10),  # max 10 characters, ASCII
        })

        # Image arrays are laid out as (rows, columns, channels) = (height, width, 3).
        self.observation_space = spaces.Box(
            low=0, high=255, shape=(self.screen_height, self.screen_width, 3), dtype=np.uint8
        )

        # Additional setup. Gymnasium reads "render_modes"; the legacy
        # "render.modes" key is kept for existing callers.
        self.metadata = {"render_modes": ["rgb_array"], "render.modes": ["rgb_array"]}
        self._start_emulator()
        self._wait_for_emulator_load()

        # set up controllers
        self.mouse_controller, self.keyboard_controller = self._create_controllers(vm_os)

    def _create_controllers(self, vm_os: VM_TYPE) -> tuple[AbstractMouseController, AbstractKeyboardController]:
        """Build the mouse and keyboard controllers for ``vm_os``.

        For Ubuntu guests, the SSH connection shared by the xdotool controllers
        is also stored on ``self.ssh_connection`` for the screenshot helpers.

        Raises:
            NotImplementedError: If ``vm_os`` is neither "ubuntu" nor "windows".
        """
        if vm_os == "ubuntu":
            ssh_connection = Connection(host=self.host, user=self.username, connect_kwargs={"password": self.password})
            self.ssh_connection = ssh_connection
            mouse_controller = XDoToolMouseController(ssh_connection)
            keyboard_controller = XDoToolKeyboardController(ssh_connection)
        elif vm_os == "windows":
            mouse_controller = PythonMouseController(http_server=self.host)
            keyboard_controller = PythonKeyboardController(http_server=self.host)
        else:
            raise NotImplementedError(vm_os)

        return mouse_controller, keyboard_controller

    def _require_ssh_connection(self):
        """Return the guest SSH connection, or raise if none was created."""
        ssh_connection = getattr(self, "ssh_connection", None)
        if ssh_connection is None:
            raise RuntimeError("This operation requires an SSH connection to the guest (Ubuntu VMs only).")
        return ssh_connection

    def _start_emulator(self):
        """Start the VM with ``vmrun start``."""
        self._execute_command(["vmrun", "start", self.path_to_vm])

    def _wait_for_emulator_load(self):
        """Poll ``vmrun -T ws list`` every 5 seconds until this VM appears in it.

        If ``vmrun list`` exits non-zero, its output is printed and the method
        returns without raising. NOTE(review): there is no timeout; this loops
        forever if the VM never shows up in the list.
        """
        import os

        # Match the fully expanded absolute path, and also the path with a leading
        # "~/" removed (the previous heuristic), so both "~/..." and relative
        # paths such as "./vm.vmx" are recognised.
        expanded_path = os.path.abspath(os.path.expanduser(self.path_to_vm))
        home_relative_path = self.path_to_vm[2:] if self.path_to_vm.startswith("~/") else self.path_to_vm

        while True:
            try:
                output = subprocess.check_output(["vmrun", "-T", "ws", "list"], stderr=subprocess.STDOUT)
            except subprocess.CalledProcessError as e:
                print(f"Error executing command: {e.output.decode().strip()}")
                return
            output = output.decode()
            if expanded_path in output or home_relative_path in output:
                print("VM is running.")
                return
            print("Waiting for VM to start...")
            time.sleep(5)

    def _execute_command(self, command: list[str]) -> "str | None":
        """Run ``command`` locally and return its decoded stdout.

        On a non-zero exit status, the command and its stderr are printed and
        ``None`` is returned instead of raising.
        """
        completed = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        if completed.returncode != 0:
            print(f"Error executing command: {command}")
            print(completed.stderr.decode())
            return None
        return completed.stdout.decode()

    def _execute_xdotool_command(self, command: str) -> str:
        """Run ``xdotool <command>`` on the guest's display :0 and return its stripped stdout.

        Raises:
            RuntimeError: If no SSH connection exists (non-Ubuntu guest).
        """
        result = self._require_ssh_connection().run(f"DISPLAY=:0 xdotool {command}", hide=True)
        return result.stdout.strip()

    def _save_state(self):
        """Take a VM snapshot named ``self.snapshot_path``."""
        # Each vmrun argument must be its own list item ("ws" and "snapshot"
        # were previously fused into "wssnapshot" by a missing comma).
        self._execute_command(["vmrun", "-T", "ws", "snapshot", self.path_to_vm, self.snapshot_path])

    def _click(self, click: MouseClick):
        """Click ``click`` via xdotool, using its value as the button number."""
        self._execute_xdotool_command(f"click {click.value}")

    def _mousedown(self, click: MouseClick):
        """Press ``click`` via xdotool."""
        self._execute_xdotool_command(f"mousedown {click.value}")

    def _mouseup(self, click: MouseClick):
        """Release ``click`` via xdotool."""
        self._execute_xdotool_command(f"mouseup {click.value}")

    def _mouse_move(self, x: int, y: int):
        """Move the pointer to (x, y) via xdotool."""
        self._execute_xdotool_command(f"mousemove {x} {y}")

    def _key(self, key: str):
        """Send ``key`` to ``xdotool key``.

        The argument is deliberately unquoted so that several space-separated
        key names are sent as separate keys.
        """
        self._execute_xdotool_command(f"key {key}")

    def _type(self, text: str):
        """Type ``text`` literally via ``xdotool type``."""
        import shlex

        # Quote so spaces and shell metacharacters reach xdotool as one literal
        # string instead of being split or interpreted by the remote shell.
        self._execute_xdotool_command(f"type {shlex.quote(text)}")

    def _get_screenshot(self):
        """Capture the guest's root window and download it to ``./screenshot.png``.

        Returns:
            The local path of the downloaded PNG.

        Raises:
            RuntimeError: If no SSH connection exists (non-Ubuntu guest).
        """
        image_path = "./screenshot.png"
        ssh_connection = self._require_ssh_connection()
        ssh_connection.run("DISPLAY=:0 import -window root screenshot.png", hide=True)
        # Download over the already-authenticated connection. The previous scp
        # call targeted a hard-coded user@192.168.7.128 instead of self.host.
        ssh_connection.get("screenshot.png", local=image_path)
        ssh_connection.run("rm -f ~/screenshot.png", hide=True)
        return image_path

    def _get_obs(self):
        """Return the current guest screenshot as a numpy array."""
        screenshot_image_path = self._get_screenshot()
        with Image.open(screenshot_image_path) as img:
            return np.array(img)

    def reset(self, *, seed=None, options=None):
        """Revert the VM to its snapshot, restart it, and return the first observation.

        Args:
            seed: Forwarded to ``gym.Env.reset`` to seed ``self.np_random``.
            options: Accepted for Gymnasium API compatibility; unused.

        Returns:
            The screenshot observation. NOTE(review): Gymnasium's API returns
            ``(observation, info)``; the observation-only return is kept so
            existing callers keep working.
        """
        super().reset(seed=seed)
        # The former interactive input() pauses between these steps were removed:
        # they blocked on stdin and made the environment unusable unattended.
        self._execute_command(["vmrun", "-T", "ws", "revertToSnapshot", self.path_to_vm, self.snapshot_path])
        self._start_emulator()
        self._wait_for_emulator_load()
        return self._get_obs()

    @staticmethod
    def _decode_chars(codes) -> str:
        """Convert an array of character codes to a string.

        The action space fixes these arrays at 10 codes, so trailing 0 (NUL)
        codes are treated as padding and dropped.
        """
        return "".join(map(chr, codes)).rstrip("\x00")

    def step(self, action):
        """Apply one action and return ``(observation, reward, done, info)``.

        ``action`` is a dict with the keys of ``self.action_space``. The reward
        is always 0 and ``done`` is always False (placeholders). NOTE(review):
        Gymnasium's API returns a 5-tuple ``(obs, reward, terminated, truncated,
        info)``; the 4-tuple is kept for existing callers.
        """
        action_type = Action(action["action_type"])

        if action_type in self._MOUSE_BUTTON_METHODS:
            click = MouseClick(action["click_type"])
            method_name = self._MOUSE_BUTTON_METHODS[action_type].get(click) or self._WHEEL_METHODS.get(click)
            # Click types without a mapping are ignored.
            if method_name is not None:
                getattr(self.mouse_controller, method_name)()
        elif action_type == Action.MOUSE_MOVE:
            self.mouse_controller.mouse_move(x=action["x"], y=action["y"])
        elif action_type == Action.KEY:
            self.keyboard_controller.key(self._decode_chars(action["key"]))
        elif action_type == Action.TYPE:
            self.keyboard_controller.type(self._decode_chars(action["text"]))

        # Capture new state
        observation = self._get_obs()
        reward = 0  # Define reward calculation
        done = False  # Define episode termination condition
        info = {}
        return observation, reward, done, info

    def render(self, mode='rgb_array'):
        """Return the current screenshot for ``mode='rgb_array'``.

        Raises:
            ValueError: For any other mode.
        """
        if mode == 'rgb_array':
            return self._get_obs()
        raise ValueError('Unsupported render mode: {}'.format(mode))

    def close(self):
        """Stop the VM and close the SSH connection, if one was opened."""
        self._execute_command(["vmrun", "stop", self.path_to_vm])
        ssh_connection = getattr(self, "ssh_connection", None)
        if ssh_connection is not None:
            ssh_connection.close()
|