Refactor with pyautogui
This commit is contained in:
@@ -1,34 +1,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from enum import Enum
|
||||
from typing import Literal, List, Tuple
|
||||
import subprocess
|
||||
from fabric import Connection
|
||||
import time
|
||||
import uuid
|
||||
from typing import List
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium import spaces
|
||||
import numpy as np
|
||||
import uuid
|
||||
from PIL import Image
|
||||
|
||||
from desktop_env.controllers.mouse import MouseClick, AbstractMouseController, XDoToolMouseController, \
|
||||
PythonMouseController
|
||||
from desktop_env.controllers.keyboard import AbstractKeyboardController, XDoToolKeyboardController, \
|
||||
PythonKeyboardController
|
||||
from desktop_env.controllers.python import PythonController
|
||||
|
||||
|
||||
class Action(Enum):
    """Kinds of low-level input events understood by the structured action format.

    The integer values are what the ``"action_type"`` entry of an action
    dictionary carries (``Action(action['action_type'])`` in ``step``).
    """

    # Mouse events
    CLICK = 0
    MOUSE_DOWN = 1
    MOUSE_UP = 2
    MOUSE_MOVE = 3

    # Keyboard events
    KEY = 4
    KEY_DOWN = 5
    KEY_UP = 6
    TYPE = 7
|
||||
|
||||
|
||||
# Guest operating systems accepted for the `vm_os` argument of DesktopEnv;
# the value decides which mouse/keyboard controller pair `_create_controllers` builds.
VM_TYPE = Literal['ubuntu', 'windows']
|
||||
def _execute_command(command: List[str]) -> None:
    """Run an external command and raise if it exits with a non-zero status.

    Args:
        command: Program name followed by its arguments, as an argv list
            (the command is executed directly, not through a shell).

    Raises:
        Exception: The command exited with a non-zero status. The message is
            the captured stdout followed by stderr, wrapped in ANSI escape
            codes that render it in red.
        subprocess.TimeoutExpired: The command did not finish within 60 seconds.
    """
    result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=60, text=True)
    if result.returncode != 0:
        raise Exception("\033[91m" + result.stdout + result.stderr + "\033[0m")
|
||||
|
||||
|
||||
class DesktopEnv(gym.Env):
|
||||
@@ -37,67 +23,20 @@ class DesktopEnv(gym.Env):
|
||||
def __init__(
        self,
        path_to_vm: str,
        username: str,
        password: str = None,
        host: str = "192.168.7.128:5000",
        snapshot_path: str = "base",
        vm_os: VM_TYPE = "ubuntu"
):
    """Start the VM, connect the controllers and define the gym spaces.

    Args:
        path_to_vm: The path to the vmx file of your vm.
        username: Username for your vm.
        password: Password for your vm, if it has one.
        host: Address given to ``PythonController`` as ``http_server`` and
            passed on to the mouse/keyboard controllers built by
            ``_create_controllers``.
        snapshot_path: Snapshot name used by the ``vmrun`` snapshot commands.
        vm_os: Guest operating system; selects the mouse/keyboard controllers.
    """
    # Initialize environment variables
    self.path_to_vm = path_to_vm

    # username and password for your vm
    self.username = username
    self.password = password

    self.host = host
    self.snapshot_path = snapshot_path  # todo: handling the logic of snapshot directory

    # Initialize emulator and controller
    print("Initializing...")
    self._start_emulator()
    self.controller = PythonController(http_server=self.host)

    # set up controllers (must exist before the first observation is taken,
    # because the screenshot is annotated with the mouse position)
    self.mouse_controller, self.keyboard_controller = self._create_controllers(vm_os)

    # Get the screen size
    self.screen_width, self.screen_height = self._get_screensize()

    # Define the action and observation space
    self.action_space = spaces.Dict({
        "action_type": spaces.Discrete(len(Action)),
        "click_type": spaces.Discrete(len(MouseClick)),
        "x": spaces.Discrete(self.screen_width),
        "y": spaces.Discrete(self.screen_height),
        "key": spaces.MultiDiscrete([128] * 10),  # max 10 characters, ASCII
        "text": spaces.MultiDiscrete([128] * 10)  # max 10 characters, ASCII
    })

    # NOTE(review): the shape is declared as (width, height, 3); image arrays are
    # commonly laid out (height, width, 3) - confirm against the consumers.
    self.observation_space = spaces.Box(low=0, high=255, shape=(self.screen_width, self.screen_height, 3),
                                        dtype=np.uint8)

    # Additional setup
    self.metadata = {'render.modes': ['rgb_array']}
|
||||
|
||||
|
||||
def _get_screensize(self):
    """Return the guest screen size as a ``(width, height)`` tuple.

    The size is taken from a freshly captured observation image.
    """
    screenshot_path = self._get_obs()
    # Use the image as a context manager so its file handle is closed right
    # away instead of staying open until the object is garbage collected.
    with Image.open(screenshot_path) as img:
        return img.size
|
||||
|
||||
def _create_controllers(self, vm_os: VM_TYPE) -> Tuple[AbstractMouseController, AbstractKeyboardController]:
    """Build the mouse and keyboard controllers that match the guest OS.

    Args:
        vm_os: ``"ubuntu"`` or ``"windows"``.

    Returns:
        A ``(mouse_controller, keyboard_controller)`` pair.

    Raises:
        NotImplementedError: ``vm_os`` is neither ``"ubuntu"`` nor ``"windows"``.
    """
    if vm_os == "ubuntu":
        # Both xdotool controllers share a single SSH connection to the guest.
        ssh_connection = Connection(host=self.host, user=self.username, connect_kwargs={"password": self.password})
        mouse_controller = XDoToolMouseController(ssh_connection)
        keyboard_controller = XDoToolKeyboardController(ssh_connection)
    elif vm_os == "windows":
        mouse_controller = PythonMouseController(http_server=self.host)
        keyboard_controller = PythonKeyboardController(http_server=self.host)
    else:
        raise NotImplementedError(vm_os)

    return mouse_controller, keyboard_controller
|
||||
# todo: define the action space and the observation space as gym did
|
||||
|
||||
def _start_emulator(self):
|
||||
while True:
|
||||
@@ -109,52 +48,35 @@ class DesktopEnv(gym.Env):
|
||||
break
|
||||
else:
|
||||
print("Starting VM...")
|
||||
self._execute_command(["vmrun", "-T", "ws", "start", self.path_to_vm])
|
||||
_execute_command(["vmrun", "-T", "ws", "start", self.path_to_vm])
|
||||
time.sleep(10)
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Error executing command: {e.output.decode().strip()}")
|
||||
|
||||
def _execute_command(self, command: List[str]) -> None:
    """Run an external command and raise if it exits with a non-zero status.

    Args:
        command: Program name followed by its arguments, as an argv list.

    Raises:
        Exception: The command exited with a non-zero status. The message is
            the captured stdout followed by stderr, wrapped in ANSI escape
            codes that render it in red.
        subprocess.TimeoutExpired: The command did not finish within 60 seconds.
    """
    result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=60, text=True)
    if result.returncode != 0:
        raise Exception("\033[91m" + result.stdout + result.stderr + "\033[0m")
|
||||
|
||||
def _save_state(self):
    """Take a snapshot of the VM under the name ``self.snapshot_path``.

    Raises:
        Exception: ``vmrun`` exited with a non-zero status.
    """
    # "ws" and "snapshot" must be separate argv entries; without the comma
    # Python joins the adjacent literals into the single argument "wssnapshot".
    _execute_command(["vmrun", "-T", "ws", "snapshot", self.path_to_vm, self.snapshot_path])
|
||||
|
||||
def _get_screenshot(self):
    """Fetch a screenshot through the controller and store it on disk.

    Each call writes into its own directory, ``tmp/<uuid4>/``, so earlier
    screenshots are never overwritten.

    Returns:
        Path of the written file, ``tmp/<uuid4>/screenshot.png``.
    """
    random_uuid = str(uuid.uuid4())
    os.makedirs(os.path.join("tmp", random_uuid), exist_ok=True)
    image_path = os.path.join("tmp", random_uuid, "screenshot.png")

    # Get the screenshot and save to the image_path. The controller is the only
    # source: a preceding `vmrun captureScreen` into the same file would be
    # overwritten by this write, so that redundant guest call is not made.
    screenshot = self.controller.get_screenshot()
    with open(image_path, "wb") as f:
        f.write(screenshot)

    return image_path
|
||||
|
||||
def _get_obs(self):
    """Capture the current observation.

    Returns:
        Path of a screenshot file onto which the mouse cursor has been drawn.
    """
    screenshot_image_path = self._get_screenshot()
    # The cursor is pasted onto the saved file in place.
    self._add_cursor(screenshot_image_path)
    return screenshot_image_path
|
||||
|
||||
def _add_cursor(self, img_path: str):
    """Draw the mouse cursor onto the screenshot stored at ``img_path``.

    Args:
        img_path: Path of the screenshot image; the file is overwritten with
            the annotated version.
    """
    x, y = self.mouse_controller.get_mouse()

    # The cursor asset is looked up relative to the current working directory
    # and drawn at half its native size.
    cursor_image = Image.open("./desktop_env/assets/cursor.png")
    cursor_image = cursor_image.resize((int(cursor_image.width / 2), int(cursor_image.height / 2)))

    screenshot = Image.open(img_path)
    # The cursor image is also passed as the paste mask.
    screenshot.paste(cursor_image, (x, y), cursor_image)
    screenshot.save(img_path)
|
||||
|
||||
def reset(self):
|
||||
def reset(self, seed=None, options=None):
|
||||
print("Resetting environment...")
|
||||
|
||||
print("Reverting to snapshot to {}...".format(self.snapshot_path))
|
||||
self._execute_command(["vmrun", "-T", "ws", "revertToSnapshot", self.path_to_vm, self.snapshot_path])
|
||||
_execute_command(["vmrun", "-T", "ws", "revertToSnapshot", self.path_to_vm, self.snapshot_path])
|
||||
time.sleep(5)
|
||||
|
||||
print("Starting emulator...")
|
||||
@@ -165,75 +87,11 @@ class DesktopEnv(gym.Env):
|
||||
return observation
|
||||
|
||||
def step(self, action):
    """Apply one action, or a list of actions, and return the resulting transition.

    Accepted action forms:

    * ``list`` - every element is stepped in order and the result of the last
      one is returned; an empty list executes nothing.
    * ``str`` - forwarded unchanged to ``self.controller.execute_python_command``.
      Our action space is the set of all possible python commands insides `pyautogui`.
    * anything else - treated as a structured action: a mapping with an
      ``'action_type'`` entry (an ``Action`` value) plus the fields that type
      needs (``'click_type'``, ``'x'``/``'y'``, ``'key'`` or ``'text'``), replayed
      through the mouse/keyboard controllers.

    Returns:
        ``(observation, reward, done, info)`` where ``observation`` is the path
        returned by ``_get_obs``. A structured action without ``'action_type'``
        ends the episode (``done`` is ``True``).
    """
    if isinstance(action, list):
        if not action:
            # Nothing to run: report the current state rather than failing on
            # result variables that were never assigned.
            return self._get_obs(), 0, False, {}
        for a in action:
            observation, reward, done, info = self.step(a)
        return observation, reward, done, info

    if isinstance(action, str):
        # String commands must not reach the structured lookup below, where
        # indexing a str with 'action_type' raises TypeError.
        self.controller.execute_python_command(action)
    else:
        try:
            action_type = Action(action['action_type'])
        except KeyError:
            done = True
            return self._get_obs(), 0, done, {}
        self._apply_structured_action(action_type, action)

    # Capture new state
    observation = self._get_obs()
    reward = 0  # todo: Define reward calculation for each example
    done = False  # todo: Define episode termination condition for each example
    info = {}
    return observation, reward, done, info

def _apply_structured_action(self, action_type, action):
    """Replay one structured action through the mouse/keyboard controllers."""
    mouse = self.mouse_controller
    keyboard = self.keyboard_controller

    if action_type in (Action.CLICK, Action.MOUSE_DOWN, Action.MOUSE_UP):
        click = MouseClick(action['click_type'])
        # Wheel events scroll regardless of whether a click, press or release
        # was requested.
        if click == MouseClick.WHEEL_UP:
            mouse.scroll_up()
        elif click == MouseClick.WHEEL_DOWN:
            mouse.scroll_down()
        elif action_type == Action.CLICK:
            if click == MouseClick.LEFT:
                mouse.left_click()
            elif click == MouseClick.MIDDLE:
                mouse.middle_click()
            elif click == MouseClick.RIGHT:
                mouse.right_click()
        elif action_type == Action.MOUSE_DOWN:
            if click == MouseClick.LEFT:
                mouse.left_down()
            elif click == MouseClick.MIDDLE:
                mouse.middle_down()
            elif click == MouseClick.RIGHT:
                mouse.right_down()
        else:  # Action.MOUSE_UP
            if click == MouseClick.LEFT:
                mouse.left_up()
            elif click == MouseClick.MIDDLE:
                mouse.middle_up()
            elif click == MouseClick.RIGHT:
                mouse.right_up()
    elif action_type == Action.MOUSE_MOVE:
        mouse.mouse_move(x=action['x'], y=action['y'])
    elif action_type == Action.KEY:
        keyboard.key(action['key'])
    elif action_type == Action.KEY_DOWN:
        keyboard.key_down(action['key'])
    elif action_type == Action.KEY_UP:
        keyboard.key_up(action['key'])
    elif action_type == Action.TYPE:
        for key in action['text']:
            if key == "\r" or key == "\n":
                keyboard.key("enter")
            else:
                keyboard.key(key)
            # sleep for 0.05 seconds with some random noise; clamp at zero
            # because time.sleep rejects negative durations
            time.sleep(max(0.0, 0.05 + np.random.normal(0, 0.01)))
|
||||
|
||||
@@ -244,4 +102,4 @@ class DesktopEnv(gym.Env):
|
||||
raise ValueError('Unsupported render mode: {}'.format(mode))
|
||||
|
||||
def close(self):
    """Stop the virtual machine.

    Raises:
        Exception: ``vmrun`` exited with a non-zero status.
    """
    # Issue the stop request exactly once.
    _execute_command(["vmrun", "stop", self.path_to_vm])
|
||||
|
||||
Reference in New Issue
Block a user