# Files
# sci-gui-agent-benchmark/desktop_env/envs/desktop_env.py
# 2023-11-30 17:31:46 +08:00
#
# 248 lines
# 9.2 KiB
# Python

import os
from enum import Enum
from typing import Literal, List, Tuple
import subprocess
from fabric import Connection
import time
import gymnasium as gym
from gymnasium import spaces
import numpy as np
import uuid
from PIL import Image
from desktop_env.controllers.mouse import MouseClick, AbstractMouseController, XDoToolMouseController, \
PythonMouseController
from desktop_env.controllers.keyboard import AbstractKeyboardController, XDoToolKeyboardController, \
PythonKeyboardController
class Action(Enum):
    """Kinds of input action understood by ``DesktopEnv.step``.

    Members are numbered consecutively from 0 in declaration order;
    ``step`` converts an action dict's ``action_type`` entry with
    ``Action(value)``.
    """

    # Mouse actions first, then keyboard actions; values 0..7.
    (
        CLICK,
        MOUSE_DOWN,
        MOUSE_UP,
        MOUSE_MOVE,
        KEY,
        KEY_DOWN,
        KEY_UP,
        TYPE,
    ) = range(8)
# Guest operating systems DesktopEnv can drive (dispatched in _create_controllers).
VM_TYPE = Literal["ubuntu", "windows"]
class DesktopEnv(gym.Env):
    """DesktopEnv with OpenAI Gym interface.

    Controls a VMware virtual machine through the ``vmrun`` CLI (start,
    snapshot, revert, screenshot, stop) and injects input through the
    mouse/keyboard controllers selected by ``vm_os``.

    NOTE(review): ``reset``/``step``/``render`` return the *path* of a PNG
    screenshot written under ``./tmp``, not an array matching
    ``observation_space``.
    """

    def __init__(
            self,
            path_to_vm: str,
            username: str,
            password: str = None,
            host: str = "192.168.7.128:5000",
            snapshot_path: str = "base",
            vm_os: VM_TYPE = "ubuntu"
    ):
        """Attach to (starting if necessary) the VM and set up controllers and spaces.

        Args:
            path_to_vm: Path to the ``.vmx`` file of the VM.
            username: Guest user name, passed to ``vmrun -gu`` and, for
                ``ubuntu``, used for the SSH connection.
            password: Guest password; ``-gp`` is omitted from ``vmrun`` when falsy.
            host: Address given to the SSH connection (``ubuntu``) or to the
                HTTP-based controllers (``windows``).
            snapshot_path: Snapshot name used by ``reset`` and ``_save_state``.
            vm_os: Guest OS; selects the controller implementation.

        Raises:
            NotImplementedError: if ``vm_os`` is not a supported value.
        """
        # The path to the vmx file of your vm
        self.path_to_vm = path_to_vm
        # username and password for your vm
        self.username = username
        self.password = password
        self.host = host
        self.snapshot_path = snapshot_path  # todo: handling the logic of snapshot directory

        # Initialize emulator
        print("Initializing...")
        self._start_emulator()

        # Controllers must exist before the first observation: _get_obs
        # queries the mouse position to draw the cursor.
        self.mouse_controller, self.keyboard_controller = self._create_controllers(vm_os)

        # Screen size is measured from a real screenshot of the guest.
        self.screen_width, self.screen_height = self._get_screensize()

        # Define the action and observation space
        self.action_space = spaces.Dict({
            "action_type": spaces.Discrete(len(Action)),
            "click_type": spaces.Discrete(len(MouseClick)),
            "x": spaces.Discrete(self.screen_width),
            "y": spaces.Discrete(self.screen_height),
            "key": spaces.MultiDiscrete([128] * 10),  # max 10 characters, ASCII
            "text": spaces.MultiDiscrete([128] * 10)  # max 10 characters, ASCII
        })
        # Image arrays are laid out (rows, columns, channels), i.e.
        # (height, width, 3) -- not (width, height, 3).
        self.observation_space = spaces.Box(low=0, high=255,
                                            shape=(self.screen_height, self.screen_width, 3),
                                            dtype=np.uint8)

        # 'render.modes' is the legacy gym key; gymnasium reads 'render_modes'.
        self.metadata = {'render.modes': ['rgb_array'], 'render_modes': ['rgb_array']}

    def _get_screensize(self) -> Tuple[int, int]:
        """Return ``(width, height)`` in pixels of a freshly captured screenshot."""
        screenshot_path = self._get_obs()
        with Image.open(screenshot_path) as img:
            return img.size

    def _create_controllers(self, vm_os: VM_TYPE) -> Tuple[AbstractMouseController, AbstractKeyboardController]:
        """Build the mouse and keyboard controllers for the guest OS.

        ``ubuntu`` uses the XDoTool controllers sharing one Fabric SSH
        connection to ``self.host``; ``windows`` uses the Python controllers
        with ``self.host`` as their HTTP server.

        Raises:
            NotImplementedError: for any other ``vm_os``.
        """
        if vm_os == "ubuntu":
            ssh_connection = Connection(host=self.host, user=self.username, connect_kwargs={"password": self.password})
            mouse_controller = XDoToolMouseController(ssh_connection)
            keyboard_controller = XDoToolKeyboardController(ssh_connection)
        elif vm_os == "windows":
            mouse_controller = PythonMouseController(http_server=self.host)
            keyboard_controller = PythonKeyboardController(http_server=self.host)
        else:
            raise NotImplementedError(vm_os)
        return mouse_controller, keyboard_controller

    def _start_emulator(self):
        """Block until ``vmrun -T ws list`` reports this VM as running.

        If the VM's path is not in the listing, the VM is started and the
        listing is re-checked after 10 seconds. A failing ``list`` call is
        printed and retried after a short pause. Errors from the ``start``
        command (raised by ``_execute_command``) propagate.
        """
        # Strips any leading '~' and '/' characters (a character set, not a
        # prefix) so a home-relative or absolute path still matches, as a
        # substring, the absolute path printed by vmrun.
        vm_path_fragment = self.path_to_vm.lstrip("~/")
        while True:
            try:
                output = subprocess.check_output(["vmrun", "-T", "ws", "list"], stderr=subprocess.STDOUT)
            except subprocess.CalledProcessError as e:
                print(f"Error executing command: {e.output.decode().strip()}")
                # Avoid hammering vmrun in a tight loop while it keeps failing.
                time.sleep(1)
                continue
            if vm_path_fragment in output.decode():
                print("VM is running.")
                break
            print("Starting VM...")
            self._execute_command(["vmrun", "-T", "ws", "start", self.path_to_vm])
            time.sleep(10)

    def _execute_command(self, command: List[str]) -> None:
        """Run ``command`` (argument list, no shell) with a 60 s timeout.

        Raises:
            Exception: with the command's stdout+stderr (colored red) when it
                exits with a non-zero status.
            subprocess.TimeoutExpired: if it runs longer than 60 seconds.
        """
        result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=60, text=True)
        if result.returncode != 0:
            raise Exception("\033[91m" + result.stdout + result.stderr + "\033[0m")

    def _save_state(self):
        """Take a VM snapshot named ``self.snapshot_path``."""
        # Each vmrun token must be its own list element; a missing comma here
        # previously fused "ws" and "snapshot" into "wssnapshot".
        self._execute_command(["vmrun", "-T", "ws", "snapshot", self.path_to_vm, self.snapshot_path])

    def _get_screenshot(self) -> str:
        """Capture the guest screen via ``vmrun captureScreen``.

        Returns:
            Path ``tmp/<uuid4>/screenshot.png`` (relative to the CWD) of the
            captured image.
        """
        image_dir = os.path.join("tmp", str(uuid.uuid4()))
        os.makedirs(image_dir, exist_ok=True)
        image_path = os.path.join(image_dir, "screenshot.png")

        command = ["vmrun", "-T", "ws", "-gu", self.username]
        if self.password:
            command += ["-gp", self.password]
        command += ["captureScreen", self.path_to_vm, image_path]
        self._execute_command(command)
        return image_path

    def _get_obs(self) -> str:
        """Capture a screenshot, draw the mouse cursor on it, and return its path."""
        screenshot_image_path = self._get_screenshot()
        self._add_cursor(screenshot_image_path)
        return screenshot_image_path

    def _add_cursor(self, img_path: str):
        """Overlay a half-size cursor image at the current mouse position, in place.

        The cursor asset is loaded from ``./desktop_env/assets/cursor.png``
        relative to the current working directory and is used as its own
        paste mask (presumably an RGBA image -- confirm).
        """
        x, y = self.mouse_controller.get_mouse()
        with Image.open("./desktop_env/assets/cursor.png") as cursor_file:
            cursor_image = cursor_file.resize((int(cursor_file.width / 2), int(cursor_file.height / 2)))
        with Image.open(img_path) as screenshot:
            screenshot.paste(cursor_image, (x, y), cursor_image)
            screenshot.save(img_path)

    def reset(self):
        """Revert the VM to ``self.snapshot_path``, wait for it to run, and return an observation path."""
        print("Resetting environment...")
        print("Reverting to snapshot to {}...".format(self.snapshot_path))
        self._execute_command(["vmrun", "-T", "ws", "revertToSnapshot", self.path_to_vm, self.snapshot_path])
        time.sleep(5)

        print("Starting emulator...")
        self._start_emulator()
        print("Emulator started.")

        observation = self._get_obs()
        return observation

    def _mouse_button_action(self, action_type: Action, click: MouseClick) -> None:
        """Send a click, press, or release of ``click`` to the mouse controller.

        Wheel "buttons" always scroll, whichever of CLICK / MOUSE_DOWN /
        MOUSE_UP was requested. Unmapped ``click`` values are ignored.
        Controller methods are looked up by name only when needed.
        """
        button_methods = {
            Action.CLICK: {
                MouseClick.LEFT: "left_click",
                MouseClick.MIDDLE: "middle_click",
                MouseClick.RIGHT: "right_click",
            },
            Action.MOUSE_DOWN: {
                MouseClick.LEFT: "left_down",
                MouseClick.MIDDLE: "middle_down",
                MouseClick.RIGHT: "right_down",
            },
            Action.MOUSE_UP: {
                MouseClick.LEFT: "left_up",
                MouseClick.MIDDLE: "middle_up",
                MouseClick.RIGHT: "right_up",
            },
        }[action_type]
        wheel_methods = {
            MouseClick.WHEEL_UP: "scroll_up",
            MouseClick.WHEEL_DOWN: "scroll_down",
        }
        method_name = button_methods.get(click) or wheel_methods.get(click)
        if method_name is not None:
            getattr(self.mouse_controller, method_name)()

    def step(self, action):
        """Execute one action (or a list of actions) and return the new state.

        Args:
            action: A dict with an ``action_type`` (an ``Action`` value) plus
                the fields that type needs (``click_type``, ``x``/``y``,
                ``key`` or ``text``), or a list of such dicts executed in order.

        Returns:
            ``(observation, reward, done, info)``. ``observation`` is a
            screenshot path. ``reward`` is always 0. ``done`` is True only
            when the dict has no ``action_type``. For a list, the result of
            the last action is returned; an empty list performs no action
            and returns the current observation with ``done=False``.

        Raises:
            ValueError: if ``action_type`` or ``click_type`` is not a valid value.
        """
        if isinstance(action, list):
            if not action:
                # Previously raised UnboundLocalError; report the unchanged screen.
                return self._get_obs(), 0, False, {}
            for a in action:
                observation, reward, done, info = self.step(a)
            return observation, reward, done, info
        # todo: handle the case when the action is not a single action

        try:
            raw_action_type = action['action_type']
        except KeyError:
            # An action without a type ends the episode.
            return self._get_obs(), 0, True, {}
        action_type = Action(raw_action_type)

        if action_type in (Action.CLICK, Action.MOUSE_DOWN, Action.MOUSE_UP):
            self._mouse_button_action(action_type, MouseClick(action['click_type']))
        elif action_type == Action.MOUSE_MOVE:
            self.mouse_controller.mouse_move(x=action['x'], y=action['y'])
        elif action_type == Action.KEY:
            self.keyboard_controller.key(action['key'])
        elif action_type == Action.KEY_DOWN:
            self.keyboard_controller.key_down(action['key'])
        elif action_type == Action.KEY_UP:
            self.keyboard_controller.key_up(action['key'])
        elif action_type == Action.TYPE:
            for key in action['text']:
                # Line breaks are sent as the "enter" key.
                if key == "\r" or key == "\n":
                    self.keyboard_controller.key("enter")
                else:
                    self.keyboard_controller.key(key)

        # Sleep ~0.05 s with small Gaussian jitter; clamp at 0 because
        # time.sleep rejects negative durations.
        time.sleep(max(0.0, 0.05 + np.random.normal(0, 0.01)))

        # Capture new state
        observation = self._get_obs()
        reward = 0  # Define reward calculation
        done = False  # Define episode termination condition
        info = {}
        return observation, reward, done, info

    def render(self, mode='rgb_array'):
        """Return a fresh observation (screenshot path) for ``mode='rgb_array'``.

        Raises:
            ValueError: for any other mode.
        """
        if mode == 'rgb_array':
            return self._get_obs()
        else:
            raise ValueError('Unsupported render mode: {}'.format(mode))

    def close(self):
        """Stop the VM with ``vmrun -T ws stop``."""
        # -T ws added for consistency with every other vmrun invocation.
        self._execute_command(["vmrun", "-T", "ws", "stop", self.path_to_vm])