Files
sci-gui-agent-benchmark/desktop_env/envs/desktop_env.py
2023-11-27 00:29:09 +08:00

197 lines
7.9 KiB
Python

import subprocess
import time
from enum import Enum
from typing import List, Literal, Optional, Tuple

import gymnasium as gym
import numpy as np
from fabric import Connection
from gymnasium import spaces
from PIL import Image

from desktop_env.controllers.keyboard import AbstractKeyboardController, XDoToolKeyboardController, PythonKeyboardController
from desktop_env.controllers.mouse import MouseClick, AbstractMouseController, XDoToolMouseController, PythonMouseController
class Action(Enum):
    """Kinds of action understood by ``DesktopEnv.step``.

    The integer value of each member is what an action dict carries in its
    ``action_type`` field.
    """

    # --- pointer actions ---
    CLICK = 0
    MOUSE_DOWN = 1
    MOUSE_UP = 2
    MOUSE_MOVE = 3
    # --- keyboard actions ---
    KEY = 4
    KEY_DOWN = 5
    KEY_UP = 6
    TYPE = 7
# Guest operating systems for which DesktopEnv can build input controllers.
VM_TYPE = Literal["ubuntu", "windows"]
class DesktopEnv(gym.Env):
    """DesktopEnv with OpenAI Gym interface.

    Manages a virtual machine through VMware's ``vmrun`` CLI (host type
    ``-T ws``): start, snapshot, revert-to-snapshot, screenshot and stop.
    Mouse/keyboard input is sent through controller objects selected by the
    guest OS (``vm_os``). Observations are screenshots as numpy arrays.
    """

    # Controller method called for each MouseClick, per pointer action type.
    # WHEEL_UP / WHEEL_DOWN scroll one step regardless of the action type.
    _MOUSE_METHODS = {
        Action.CLICK: {
            MouseClick.LEFT: "left_click",
            MouseClick.MIDDLE: "middle_click",
            MouseClick.RIGHT: "right_click",
            MouseClick.WHEEL_UP: "scroll_up",
            MouseClick.WHEEL_DOWN: "scroll_down",
        },
        Action.MOUSE_DOWN: {
            MouseClick.LEFT: "left_down",
            MouseClick.MIDDLE: "middle_down",
            MouseClick.RIGHT: "right_down",
            MouseClick.WHEEL_UP: "scroll_up",
            MouseClick.WHEEL_DOWN: "scroll_down",
        },
        Action.MOUSE_UP: {
            MouseClick.LEFT: "left_up",
            MouseClick.MIDDLE: "middle_up",
            MouseClick.RIGHT: "right_up",
            MouseClick.WHEEL_UP: "scroll_up",
            MouseClick.WHEEL_DOWN: "scroll_down",
        },
    }

    # Keyboard controller method called for each key action type.
    _KEY_METHODS = {
        Action.KEY: "key",
        Action.KEY_DOWN: "key_down",
        Action.KEY_UP: "key_up",
    }

    def __init__(self, path_to_vm: str, username: str, password: str,
                 host: str, snapshot_path: str = "some_point_browser", vm_os: VM_TYPE = "ubuntu"):
        """Start the VM if needed and build the input controllers.

        Args:
            path_to_vm: VM path passed to every ``vmrun`` invocation.
            username: Guest user name; used for ``vmrun captureScreen`` and,
                for ubuntu guests, for the SSH connection.
            password: Password for ``username``.
            host: SSH host (ubuntu) or HTTP server address (windows) handed to
                the input controllers.
            snapshot_path: Snapshot name used by ``reset`` and ``_save_state``.
            vm_os: Guest OS; selects the controller implementation.

        Raises:
            NotImplementedError: if ``vm_os`` is not a supported value.
        """
        self.path_to_vm = path_to_vm
        self.username = username
        self.password = password
        self.host = host
        self.snapshot_path = snapshot_path  # todo: handling the logic of snapshot directory
        # NOTE(review): screenshots are not resized; this assumes the guest
        # display really is 800x800 — confirm.
        self.screen_width = 800
        self.screen_height = 800

        # Define the action and observation space
        self.action_space = spaces.Dict({
            "action_type": spaces.Discrete(len(Action)),
            "click_type": spaces.Discrete(len(MouseClick)),
            "x": spaces.Discrete(self.screen_width),
            "y": spaces.Discrete(self.screen_height),
            "key": spaces.MultiDiscrete([128] * 10),  # max 10 characters, ASCII
            "text": spaces.MultiDiscrete([128] * 10)  # max 10 characters, ASCII
        })
        # Image arrays from numpy/PIL are (rows, cols, channels) = (height, width, 3).
        self.observation_space = spaces.Box(
            low=0, high=255, shape=(self.screen_height, self.screen_width, 3), dtype=np.uint8)

        # Additional setup. "render_modes" is the key gymnasium reads;
        # "render.modes" is kept for callers using the older gym key.
        self.metadata = {'render.modes': ['rgb_array'], 'render_modes': ['rgb_array']}

        # Initialize emulator
        print("Initializing...")
        self._start_emulator()

        # set up controllers
        self.mouse_controller, self.keyboard_controller = self._create_controllers(vm_os)

    def _create_controllers(self, vm_os: VM_TYPE) -> Tuple[AbstractMouseController, AbstractKeyboardController]:
        """Return ``(mouse_controller, keyboard_controller)`` for the guest OS.

        ubuntu guests are driven over an SSH ``Connection`` with the xdotool
        controllers; windows guests use the Python controllers pointed at
        ``self.host`` as an HTTP server.

        Raises:
            NotImplementedError: for any other ``vm_os``.
        """
        if vm_os == "ubuntu":
            ssh_connection = Connection(host=self.host, user=self.username, connect_kwargs={"password": self.password})
            mouse_controller = XDoToolMouseController(ssh_connection)
            keyboard_controller = XDoToolKeyboardController(ssh_connection)
        elif vm_os == "windows":
            mouse_controller = PythonMouseController(http_server=self.host)
            keyboard_controller = PythonKeyboardController(http_server=self.host)
        else:
            raise NotImplementedError(vm_os)
        return mouse_controller, keyboard_controller

    def _start_emulator(self):
        """Block until ``vmrun list`` reports the VM as running, starting it if not.

        Retries indefinitely, waiting 5 seconds between attempts.
        """
        while True:
            try:
                output = subprocess.check_output("vmrun -T ws list", shell=True, stderr=subprocess.STDOUT)
                output = output.decode(errors="replace")
                # str.lstrip removes any leading '~' and '/' *characters* (a
                # character set, not the "~/" prefix). The remainder is used
                # as a loose substring match against the VM paths that
                # `vmrun list` prints.
                if self.path_to_vm.lstrip("~/") in output:
                    print("VM is running.")
                    break
                print("Starting VM...")
                self._execute_command(["vmrun", "-T", "ws", "start", self.path_to_vm])
                time.sleep(5)
            except subprocess.CalledProcessError as e:
                print(f"Error executing command: {e.output.decode(errors='replace').strip()}")
                # Back off instead of spinning on a persistent failure.
                time.sleep(5)

    def _execute_command(self, command: List[str]) -> Optional[str]:
        """Run ``command`` and return its decoded stdout, or None on failure.

        The argument list is executed directly, without a shell. With
        ``shell=True`` on POSIX, only ``command[0]`` would be run as the
        shell script. The remaining items would become arguments of the shell
        itself, so ``vmrun`` would receive none of its arguments.

        Failures (non-zero exit status, or the executable not being found)
        are printed, not raised.
        """
        try:
            result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        except OSError as e:
            print(f"Error executing command: {command}")
            print(e)
            return None
        if result.returncode != 0:
            print(f"Error executing command: {command}")
            stderr = result.stderr.decode(errors="replace").strip()
            if stderr:
                print(stderr)
            return None
        return result.stdout.decode(errors="replace")

    def _save_state(self):
        """Take a ``vmrun`` snapshot of the VM named ``self.snapshot_path``."""
        self._execute_command(["vmrun", "-T", "ws", "snapshot", self.path_to_vm, self.snapshot_path])

    def _get_screenshot(self):
        """Capture the guest screen to ``./screenshot.png`` and return that path.

        The same relative path is overwritten on every call.
        """
        image_path = "./screenshot.png"
        self._execute_command(
            ["vmrun", "-T", "ws", "-gu", self.username, "-gp", self.password, "captureScreen", self.path_to_vm,
             image_path])
        return image_path

    def _get_obs(self):
        """Take a screenshot and return it as a numpy array.

        NOTE(review): the channel count follows the PNG's mode; assumed RGB to
        match ``observation_space`` — confirm.
        """
        screenshot_image_path = self._get_screenshot()
        with Image.open(screenshot_image_path) as img:
            return np.array(img)

    def reset(self):
        """Revert the VM to ``self.snapshot_path``, ensure it runs, return an observation.

        Returns only the observation (not gymnasium's ``(obs, info)`` pair).
        """
        print("Resetting environment...")
        print("Reverting to snapshot to {}...".format(self.snapshot_path))
        self._execute_command(["vmrun", "-T", "ws", "revertToSnapshot", self.path_to_vm, self.snapshot_path])
        print("Starting emulator...")
        self._start_emulator()
        print("Emulator started.")
        observation = self._get_obs()
        return observation

    @staticmethod
    def _decode_chars(codes) -> str:
        """Turn a sequence of character codes into a string, dropping NUL (0) entries.

        The ``key``/``text`` spaces are fixed-length ``MultiDiscrete(10)``
        arrays, so code 0 is treated as padding for shorter strings. It is not
        forwarded to the controllers as a literal NUL character.
        """
        return ''.join(map(chr, codes)).replace('\x00', '')

    def step(self, action):
        """Apply one action dict and return ``(observation, reward, done, info)``.

        Uses the older 4-tuple gym convention. ``reward`` is always 0 and
        ``done`` always False (placeholders). ``x``/``y`` are used only by
        MOUSE_MOVE; click/press/release controller calls receive no
        coordinates.
        """
        action_type = Action(action['action_type'])
        if action_type in self._MOUSE_METHODS:
            click = MouseClick(action['click_type'])
            method_name = self._MOUSE_METHODS[action_type].get(click)
            if method_name is not None:
                getattr(self.mouse_controller, method_name)()
        elif action_type == Action.MOUSE_MOVE:
            self.mouse_controller.mouse_move(x=action['x'], y=action['y'])
        elif action_type in self._KEY_METHODS:
            key_sequence = self._decode_chars(action['key'])
            getattr(self.keyboard_controller, self._KEY_METHODS[action_type])(key_sequence)
        elif action_type == Action.TYPE:
            text = self._decode_chars(action['text'])
            self.keyboard_controller.type(text)

        # Capture new state
        observation = self._get_obs()
        reward = 0  # Define reward calculation
        done = False  # Define episode termination condition
        info = {}
        return observation, reward, done, info

    def render(self, mode='rgb_array'):
        """Return the current screen as an array; only ``'rgb_array'`` is supported.

        Raises:
            ValueError: for any other ``mode``.
        """
        if mode == 'rgb_array':
            return self._get_obs()
        else:
            raise ValueError('Unsupported render mode: {}'.format(mode))

    def close(self):
        """Stop the VM via ``vmrun``, using the same ``-T ws`` host type as the other calls."""
        self._execute_command(["vmrun", "-T", "ws", "stop", self.path_to_vm])