Mouse and keyboard controllers for Windows and Linux
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
from enum import Enum
|
||||
from typing import Literal
|
||||
import subprocess
|
||||
from fabric import Connection
|
||||
import time
|
||||
@@ -8,6 +9,9 @@ from gymnasium import spaces
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from desktop_env.controllers.mouse import MouseClick, AbstractMouseController, XDoToolMouseController, PythonMouseController
|
||||
from desktop_env.controllers.keyboard import AbstractKeyboardController, XDoToolKeyboardController, PythonKeyboardController
|
||||
|
||||
class Action(Enum):
|
||||
CLICK = 0
|
||||
MOUSE_DOWN = 1
|
||||
@@ -16,24 +20,18 @@ class Action(Enum):
|
||||
KEY = 4
|
||||
TYPE = 5
|
||||
|
||||
# Mouse buttons and wheel directions, numbered 1-5. `step` converts
# action['click_type'] into one of these members via MouseClick(...).
# NOTE(review): a `MouseClick` is also imported from
# desktop_env.controllers.mouse earlier in this view; this later class
# statement would rebind the name - confirm which definition is intended.
class MouseClick(Enum):
|
||||
LEFT = 1
|
||||
MIDDLE = 2
|
||||
RIGHT = 3
|
||||
WHEEL_UP = 4
|
||||
WHEEL_DOWN = 5
|
||||
# Guest operating systems that DesktopEnv accepts for its `vm_os` argument;
# `_create_controllers` picks the controller implementations from this value.
VM_TYPE = Literal['ubuntu', 'windows']
|
||||
|
||||
class DesktopEnv(gym.Env):
|
||||
"""DesktopEnv with OpenAI Gym interface."""
|
||||
|
||||
def __init__(self, path_to_vm: str, username: str, password: str,
|
||||
host: str, snapshot_path: str = "snapshot"):
|
||||
host: str, snapshot_path: str = "snapshot", vm_os: VM_TYPE = "ubuntu"):
|
||||
self.path_to_vm = path_to_vm
|
||||
self.username = username
|
||||
self.password = password
|
||||
self.host = host
|
||||
self.snapshot_path = snapshot_path
|
||||
self.ssh_connection = Connection(host=self.host, user=self.username, connect_kwargs={"password": password})
|
||||
|
||||
self.screen_width = 800
|
||||
self.screen_height = 800
|
||||
@@ -54,6 +52,22 @@ class DesktopEnv(gym.Env):
|
||||
self._start_emulator()
|
||||
self._wait_for_emulator_load()
|
||||
|
||||
# set up controllers
|
||||
self.mouse_controller, self.keyboard_controller = self._create_controllers(vm_os)
|
||||
|
||||
def _create_controllers(self, vm_os: VM_TYPE) -> tuple[AbstractMouseController, AbstractKeyboardController]:
    """Build the mouse and keyboard controllers matching the guest OS.

    For ``"ubuntu"`` a new SSH connection is opened with this environment's
    host, username and password, and the same connection object is handed
    to both XDoTool controllers. For ``"windows"`` both Python controllers
    receive ``self.host`` as their ``http_server``.

    Args:
        vm_os: Guest operating system identifier.

    Returns:
        A ``(mouse_controller, keyboard_controller)`` pair.

    Raises:
        NotImplementedError: If ``vm_os`` is neither ``"ubuntu"`` nor
            ``"windows"``; the offending value is the exception argument.
    """
    if vm_os == "ubuntu":
        # One connection instance is shared by both controllers.
        shared_ssh = Connection(
            host=self.host,
            user=self.username,
            connect_kwargs={"password": self.password},
        )
        return XDoToolMouseController(shared_ssh), XDoToolKeyboardController(shared_ssh)

    if vm_os == "windows":
        return (
            PythonMouseController(http_server=self.host),
            PythonKeyboardController(http_server=self.host),
        )

    raise NotImplementedError(vm_os)
|
||||
|
||||
# Start the virtual machine: hands the argument list
# ["vmrun", "start", <self.path_to_vm>] to self._execute_command
# (that helper is defined outside this view).
def _start_emulator(self):
|
||||
self._execute_command(["vmrun", "start", self.path_to_vm])
|
||||
|
||||
@@ -133,19 +147,49 @@ class DesktopEnv(gym.Env):
|
||||
def step(self, action):
|
||||
action_type = Action(action['action_type'])
|
||||
if action_type == Action.CLICK:
|
||||
self._click(MouseClick(action['click_type']))
|
||||
click = MouseClick(action['click_type'])
|
||||
if click == MouseClick.LEFT:
|
||||
self.mouse_controller.left_click()
|
||||
elif click == MouseClick.MIDDLE:
|
||||
self.mouse_controller.middle_click()
|
||||
elif click == MouseClick.RIGHT:
|
||||
self.mouse_controller.right_click()
|
||||
elif click == MouseClick.WHEEL_UP:
|
||||
self.mouse_controller.scroll_up()
|
||||
elif click == MouseClick.WHEEL_DOWN:
|
||||
self.mouse_controller.scroll_down()
|
||||
elif action_type == Action.MOUSE_DOWN:
|
||||
self._mousedown(MouseClick(action['click_type']))
|
||||
click = MouseClick(action['click_type'])
|
||||
if click == MouseClick.LEFT:
|
||||
self.mouse_controller.left_down()
|
||||
elif click == MouseClick.MIDDLE:
|
||||
self.mouse_controller.middle_down()
|
||||
elif click == MouseClick.RIGHT:
|
||||
self.mouse_controller.right_down()
|
||||
elif click == MouseClick.WHEEL_UP:
|
||||
self.mouse_controller.scroll_up()
|
||||
elif click == MouseClick.WHEEL_DOWN:
|
||||
self.mouse_controller.scroll_down()
|
||||
elif action_type == Action.MOUSE_UP:
|
||||
self._mouseup(MouseClick(action['click_type']))
|
||||
click = MouseClick(action['click_type'])
|
||||
if click == MouseClick.LEFT:
|
||||
self.mouse_controller.left_up()
|
||||
elif click == MouseClick.MIDDLE:
|
||||
self.mouse_controller.middle_up()
|
||||
elif click == MouseClick.RIGHT:
|
||||
self.mouse_controller.right_up()
|
||||
elif click == MouseClick.WHEEL_UP:
|
||||
self.mouse_controller.scroll_up()
|
||||
elif click == MouseClick.WHEEL_DOWN:
|
||||
self.mouse_controller.scroll_down()
|
||||
elif action_type == Action.MOUSE_MOVE:
|
||||
self._mouse_move(action['x'], action['y'])
|
||||
self.mouse_controller.mouse_move(x = action['x'], y = action['y'])
|
||||
elif action_type == Action.KEY:
|
||||
key_sequence = ''.join(map(chr, action['key'])) # Convert integer array to string
|
||||
self.key(key_sequence)
|
||||
self.keyboard_controller.key(key_sequence)
|
||||
elif action_type == Action.TYPE:
|
||||
text = ''.join(map(chr, action['text'])) # Convert integer array to string
|
||||
self._type(text)
|
||||
self.keyboard_controller.type(text)
|
||||
|
||||
# Capture new state
|
||||
observation = self._get_obs()
|
||||
|
||||
Reference in New Issue
Block a user