Refactor with pyautogui

This commit is contained in:
Timothyxxx
2023-12-02 17:52:00 +08:00
parent 3aebeffec3
commit 992d8f8fce
11 changed files with 144 additions and 515 deletions

View File

@@ -1,34 +1,20 @@
from __future__ import annotations
import os
from enum import Enum
from typing import Literal, List, Tuple
import subprocess
from fabric import Connection
import time
import uuid
from typing import List
import gymnasium as gym
from gymnasium import spaces
import numpy as np
import uuid
from PIL import Image
from desktop_env.controllers.mouse import MouseClick, AbstractMouseController, XDoToolMouseController, \
PythonMouseController
from desktop_env.controllers.keyboard import AbstractKeyboardController, XDoToolKeyboardController, \
PythonKeyboardController
from desktop_env.controllers.python import PythonController
class Action(Enum):
CLICK = 0
MOUSE_DOWN = 1
MOUSE_UP = 2
MOUSE_MOVE = 3
KEY = 4
KEY_DOWN = 5
KEY_UP = 6
TYPE = 7
VM_TYPE = Literal['ubuntu', 'windows']
def _execute_command(command: List[str]) -> None:
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=60, text=True)
if result.returncode != 0:
raise Exception("\033[91m" + result.stdout + result.stderr + "\033[0m")
class DesktopEnv(gym.Env):
@@ -37,67 +23,20 @@ class DesktopEnv(gym.Env):
def __init__(
self,
path_to_vm: str,
username: str,
password: str = None,
host: str = "192.168.7.128:5000",
snapshot_path: str = "base",
vm_os: VM_TYPE = "ubuntu"
):
# The path to the vmx file of your vm
# Initialize environment variables
self.path_to_vm = path_to_vm
# username and password for your vm
self.username = username
self.password = password
self.host = host
self.snapshot_path = snapshot_path # todo: handling the logic of snapshot directory
# Initialize emulator
# Initialize emulator and controller
print("Initializing...")
self._start_emulator()
self.controller = PythonController(http_server=self.host)
# set up controllers
self.mouse_controller, self.keyboard_controller = self._create_controllers(vm_os)
# Get the screen size
self.screen_width, self.screen_height = self._get_screensize()
# Define the action and observation space
self.action_space = spaces.Dict({
"action_type": spaces.Discrete(len(Action)),
"click_type": spaces.Discrete(len(MouseClick)),
"x": spaces.Discrete(self.screen_width),
"y": spaces.Discrete(self.screen_height),
"key": spaces.MultiDiscrete([128] * 10), # max 10 characters, ASCII
"text": spaces.MultiDiscrete([128] * 10) # max 10 characters, ASCII
})
self.observation_space = spaces.Box(low=0, high=255, shape=(self.screen_width, self.screen_height, 3),
dtype=np.uint8)
# Additional setup
self.metadata = {'render.modes': ['rgb_array']}
def _get_screensize(self):
screenshot_path = self._get_obs()
img = Image.open(screenshot_path)
return img.size
def _create_controllers(self, vm_os: VM_TYPE) -> Tuple[AbstractMouseController, AbstractKeyboardController]:
if vm_os == "ubuntu":
ssh_connection = Connection(host=self.host, user=self.username, connect_kwargs={"password": self.password})
mouse_controller = XDoToolMouseController(ssh_connection)
keyboard_controller = XDoToolKeyboardController(ssh_connection)
elif vm_os == "windows":
mouse_controller = PythonMouseController(http_server=self.host)
keyboard_controller = PythonKeyboardController(http_server=self.host)
else:
raise NotImplementedError(vm_os)
return mouse_controller, keyboard_controller
# todo: define the action space and the observation space as gym did
def _start_emulator(self):
while True:
@@ -109,52 +48,35 @@ class DesktopEnv(gym.Env):
break
else:
print("Starting VM...")
self._execute_command(["vmrun", "-T", "ws", "start", self.path_to_vm])
_execute_command(["vmrun", "-T", "ws", "start", self.path_to_vm])
time.sleep(10)
except subprocess.CalledProcessError as e:
print(f"Error executing command: {e.output.decode().strip()}")
def _execute_command(self, command: List[str]) -> None:
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=60, text=True)
if result.returncode != 0:
raise Exception("\033[91m" + result.stdout + result.stderr + "\033[0m")
def _save_state(self):
self._execute_command(["vmrun", "-T", "ws" "snapshot", self.path_to_vm, self.snapshot_path])
_execute_command(["vmrun", "-T", "ws" "snapshot", self.path_to_vm, self.snapshot_path])
def _get_screenshot(self):
random_uuid = str(uuid.uuid4())
os.makedirs(os.path.join("tmp", random_uuid), exist_ok=True)
image_path = os.path.join("tmp", random_uuid, "screenshot.png")
if self.password:
self._execute_command(
["vmrun", "-T", "ws", "-gu", self.username, "-gp", self.password, "captureScreen", self.path_to_vm, image_path])
else:
self._execute_command(
["vmrun", "-T", "ws", "-gu", self.username, "captureScreen", self.path_to_vm, image_path])
# Get the screenshot and save to the image_path
screenshot = self.controller.get_screenshot()
with open(image_path, "wb") as f:
f.write(screenshot)
return image_path
def _get_obs(self):
screenshot_image_path = self._get_screenshot()
self._add_cursor(screenshot_image_path)
return screenshot_image_path
def _add_cursor(self, img_path: str):
x, y = self.mouse_controller.get_mouse()
cursor_image = Image.open("./desktop_env/assets/cursor.png")
cursor_image = cursor_image.resize((int(cursor_image.width / 2), int(cursor_image.height / 2)))
screenshot = Image.open(img_path)
screenshot.paste(cursor_image, (x, y), cursor_image)
screenshot.save(img_path)
def reset(self):
def reset(self, seed=None, options=None):
print("Resetting environment...")
print("Reverting to snapshot to {}...".format(self.snapshot_path))
self._execute_command(["vmrun", "-T", "ws", "revertToSnapshot", self.path_to_vm, self.snapshot_path])
_execute_command(["vmrun", "-T", "ws", "revertToSnapshot", self.path_to_vm, self.snapshot_path])
time.sleep(5)
print("Starting emulator...")
@@ -165,75 +87,11 @@ class DesktopEnv(gym.Env):
return observation
def step(self, action):
if isinstance(action, list):
for a in action:
observation, reward, done, info = self.step(a)
return observation, reward, done, info
# todo: handle the case when the action is not a single action
try:
action_type = Action(action['action_type'])
except KeyError:
done = True
return self._get_obs(), 0, done, {}
if action_type == Action.CLICK:
click = MouseClick(action['click_type'])
if click == MouseClick.LEFT:
self.mouse_controller.left_click()
elif click == MouseClick.MIDDLE:
self.mouse_controller.middle_click()
elif click == MouseClick.RIGHT:
self.mouse_controller.right_click()
elif click == MouseClick.WHEEL_UP:
self.mouse_controller.scroll_up()
elif click == MouseClick.WHEEL_DOWN:
self.mouse_controller.scroll_down()
elif action_type == Action.MOUSE_DOWN:
click = MouseClick(action['click_type'])
if click == MouseClick.LEFT:
self.mouse_controller.left_down()
elif click == MouseClick.MIDDLE:
self.mouse_controller.middle_down()
elif click == MouseClick.RIGHT:
self.mouse_controller.right_down()
elif click == MouseClick.WHEEL_UP:
self.mouse_controller.scroll_up()
elif click == MouseClick.WHEEL_DOWN:
self.mouse_controller.scroll_down()
elif action_type == Action.MOUSE_UP:
click = MouseClick(action['click_type'])
if click == MouseClick.LEFT:
self.mouse_controller.left_up()
elif click == MouseClick.MIDDLE:
self.mouse_controller.middle_up()
elif click == MouseClick.RIGHT:
self.mouse_controller.right_up()
elif click == MouseClick.WHEEL_UP:
self.mouse_controller.scroll_up()
elif click == MouseClick.WHEEL_DOWN:
self.mouse_controller.scroll_down()
elif action_type == Action.MOUSE_MOVE:
self.mouse_controller.mouse_move(x=action['x'], y=action['y'])
elif action_type == Action.KEY:
self.keyboard_controller.key(action['key'])
elif action_type == Action.KEY_DOWN:
self.keyboard_controller.key_down(action['key'])
elif action_type == Action.KEY_UP:
self.keyboard_controller.key_up(action['key'])
elif action_type == Action.TYPE:
for key in action['text']:
if key == "\r" or key == "\n":
self.keyboard_controller.key("enter")
else:
self.keyboard_controller.key(key)
# sleep for 0.05 seconds with some random noise
time.sleep(0.05 + np.random.normal(0, 0.01))
# Capture new state
# Our action space is the set of all possible python commands insides `pyautogui`
self.controller.execute_python_command(action)
observation = self._get_obs()
reward = 0 # Define reward calculation
done = False # Define episode termination condition
reward = 0 # todo: Define reward calculation for each example
done = False # todo: Define episode termination condition for each example
info = {}
return observation, reward, done, info
@@ -244,4 +102,4 @@ class DesktopEnv(gym.Env):
raise ValueError('Unsupported render mode: {}'.format(mode))
def close(self):
self._execute_command(["vmrun", "stop", self.path_to_vm])
_execute_command(["vmrun", "stop", self.path_to_vm])