Files
sci-gui-agent-benchmark/desktop_env/envs/desktop_env.py
2023-10-30 00:28:33 +08:00

143 lines
5.0 KiB
Python

import shlex
import subprocess
from enum import Enum
from typing import Optional

import gymnasium as gym
import numpy as np
from fabric import Connection
from gymnasium import spaces
from PIL import Image


class Action(Enum):
    """Kinds of input event an agent can request in a single step.

    The values are consecutive integers starting at 0 so that they line up
    with a ``Discrete(len(Action))`` action-type space.
    """

    # Mouse events
    CLICK = 0
    MOUSE_DOWN = 1
    MOUSE_UP = 2
    MOUSE_MOVE = 3

    # Keyboard events
    KEY = 4
    TYPE = 5
class MouseClick(Enum):
    """Mouse buttons, numbered the way they are passed to ``xdotool``.

    Each member's value is used verbatim as the button argument of the
    ``click`` / ``mousedown`` / ``mouseup`` xdotool sub-commands.
    """

    # Physical buttons
    LEFT = 1
    MIDDLE = 2
    RIGHT = 3

    # Scroll wheel, reported as extra buttons
    WHEEL_UP = 4
    WHEEL_DOWN = 5
class DesktopEnv(gym.Env):
    """Gym-style environment that drives a VMware guest desktop.

    The VM is controlled on the host through the ``vmrun`` CLI, input events
    are injected in the guest by running ``xdotool`` over SSH, and
    observations are screenshots of the guest's root window.

    NOTE(review): ``reset`` returns only the observation and ``step`` returns
    the 4-tuple ``(observation, reward, done, info)``. That is not the
    Gymnasium ``(obs, info)`` / 5-tuple convention; it is kept unchanged so
    that existing callers keep working.
    """

    def __init__(self, path_to_vm: str, username: str, password: str,
                 host: str, snapshot_path: str = "snapshot"):
        """Store the settings, build the SSH connection object and start the VM.

        Args:
            path_to_vm: VM path handed to ``vmrun``.
            username: SSH user name in the guest.
            password: SSH password for ``username``.
            host: Host name or IP address of the guest.
            snapshot_path: Snapshot identifier handed to ``vmrun snapshot``
                and ``vmrun revertToSnapshot``.
        """
        self.path_to_vm = path_to_vm
        self.username = username
        self.password = password
        self.host = host
        self.snapshot_path = snapshot_path
        self.ssh_connection = Connection(
            host=self.host,
            user=self.username,
            connect_kwargs={"password": password},
        )

        self.screen_width = 800
        self.screen_height = 800

        # Define the action and observation space
        self.action_space = spaces.Dict({
            "action_type": spaces.Discrete(len(Action)),
            # MouseClick values run 1..5, so the space has to start at 1 too;
            # a 0-based space would yield an invalid 0 and never reach 5.
            "click_type": spaces.Discrete(len(MouseClick), start=1),
            "x": spaces.Discrete(self.screen_width),
            "y": spaces.Discrete(self.screen_height),
            "key": spaces.MultiDiscrete([128] * 10),  # max 10 characters, ASCII
            "text": spaces.MultiDiscrete([128] * 10),  # max 10 characters, ASCII
        })
        # Image arrays are laid out (rows, columns, channels) = (height, width, 3).
        self.observation_space = spaces.Box(
            low=0,
            high=255,
            shape=(self.screen_height, self.screen_width, 3),
            dtype=np.uint8,
        )

        # Additional setup
        self.metadata = {'render.modes': ['rgb_array']}
        self._start_emulator()

    def _start_emulator(self):
        """Power on the VM with ``vmrun start``."""
        self._execute_command(["vmrun", "start", self.path_to_vm])

    def _execute_command(self, command: list[str]) -> Optional[str]:
        """Run ``command`` on the host and return its decoded stdout.

        Returns:
            The command's stdout as text, or ``None`` when it exits with a
            non-zero status (the command and its stderr are printed).
        """
        completed = subprocess.run(
            command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
        )
        if completed.returncode != 0:
            print(f"Error executing command: {command}")
            print(completed.stderr.decode())
            return None
        return completed.stdout.decode()

    def _execute_xdotool_command(self, command: str) -> str:
        """Run ``xdotool <command>`` in the guest on display ``:0``.

        Args:
            command: The xdotool sub-command and its arguments as one
                shell-ready string.

        Returns:
            The stripped stdout of the remote command.
        """
        result = self.ssh_connection.run(f"DISPLAY=:0 xdotool {command}", hide=True)
        return result.stdout.strip()

    @staticmethod
    def _decode_ascii(codes) -> str:
        """Convert a sequence of ASCII codes to text, dropping NUL (0) codes.

        The ``key`` and ``text`` spaces are fixed-length, so shorter strings
        presumably arrive padded with zeros; NUL must not reach the shell.
        """
        return ''.join(chr(int(code)) for code in codes if code)

    def _save_state(self):
        """Take a snapshot of the VM named ``self.snapshot_path``."""
        # vmrun expects the VM path first, then the snapshot name.
        self._execute_command(
            ["vmrun", "snapshot", self.path_to_vm, self.snapshot_path]
        )

    def _click(self, click: MouseClick):
        """Press and release the given mouse button."""
        self._execute_xdotool_command(f"click {click.value}")

    def _mousedown(self, click: MouseClick):
        """Press (and hold) the given mouse button."""
        self._execute_xdotool_command(f"mousedown {click.value}")

    def _mouseup(self, click: MouseClick):
        """Release the given mouse button."""
        self._execute_xdotool_command(f"mouseup {click.value}")

    def _mouse_move(self, x: int, y: int):
        """Move the pointer to screen coordinates ``(x, y)``."""
        # int() rejects anything that is not a plain number before it is
        # interpolated into the remote shell command.
        self._execute_xdotool_command(f"mousemove {int(x)} {int(y)}")

    def _key(self, key: str):
        """Send one or more whitespace-separated keystrokes (e.g. ``ctrl+c``)."""
        # Quote each keystroke separately: xdotool still receives them as
        # separate arguments, but shell metacharacters are neutralised.
        keystrokes = " ".join(shlex.quote(part) for part in key.split())
        self._execute_xdotool_command(f"key {keystrokes}")

    def _type(self, text: str):
        """Type ``text`` literally in the guest."""
        # Quoting keeps spaces and shell metacharacters part of the text
        # instead of letting the remote shell interpret them.
        self._execute_xdotool_command(f"type {shlex.quote(text)}")

    def _get_screenshot(self):
        """Capture the guest's root window and copy it to the host.

        Returns:
            Local path of the downloaded PNG file.
        """
        image_path = "./screenshot.png"
        self.ssh_connection.run("DISPLAY=:0 import -window root ~/screenshot.png")
        # Copy from the configured guest rather than a hard-coded user/IP.
        remote_file = f"{self.username}@{self.host}:~/screenshot.png"
        self._execute_command(["scp", remote_file, image_path])
        self.ssh_connection.run("rm -rf ~/screenshot.png")
        return image_path

    def _get_obs(self):
        """Return the current screen as an RGB ``numpy`` array."""
        screenshot_image_path = self._get_screenshot()
        with Image.open(screenshot_image_path) as img:
            # Normalise to three channels so the array matches the declared
            # observation space even if the PNG has alpha or a palette.
            return np.array(img.convert("RGB"))

    def reset(self):
        """Revert the VM to the snapshot and return a fresh observation."""
        self._execute_command(
            ["vmrun", "revertToSnapshot", self.path_to_vm, self.snapshot_path]
        )
        # Per the vmrun reference, a VM reverted to a powered-on snapshot is
        # left suspended, so it has to be started again before it can be used.
        self._start_emulator()
        observation = self._get_obs()
        return observation

    def step(self, action):
        """Execute one action in the guest and observe the result.

        Args:
            action: Mapping with the keys of ``self.action_space``; only the
                entries relevant to ``action['action_type']`` are read.

        Returns:
            ``(observation, reward, done, info)``. The reward is always 0,
            ``done`` is always ``False`` and ``info`` is an empty dict.

        Raises:
            ValueError: If ``action_type`` or ``click_type`` is not a valid
                ``Action`` / ``MouseClick`` value.
        """
        action_type = Action(action['action_type'])
        if action_type == Action.CLICK:
            self._click(MouseClick(action['click_type']))
        elif action_type == Action.MOUSE_DOWN:
            self._mousedown(MouseClick(action['click_type']))
        elif action_type == Action.MOUSE_UP:
            self._mouseup(MouseClick(action['click_type']))
        elif action_type == Action.MOUSE_MOVE:
            self._mouse_move(action['x'], action['y'])
        elif action_type == Action.KEY:
            key_sequence = self._decode_ascii(action['key'])
            self._key(key_sequence)
        elif action_type == Action.TYPE:
            text = self._decode_ascii(action['text'])
            self._type(text)

        # Capture new state
        observation = self._get_obs()
        reward = 0  # Define reward calculation
        done = False  # Define episode termination condition
        info = {}
        return observation, reward, done, info

    def render(self, mode='rgb_array'):
        """Return the current screen for ``mode='rgb_array'``.

        Raises:
            ValueError: For any other render mode.
        """
        if mode == 'rgb_array':
            return self._get_obs()
        else:
            raise ValueError('Unsupported render mode: {}'.format(mode))

    def close(self):
        """Release the SSH connection and power off the VM."""
        # Close the SSH session first; it cannot outlive the guest anyway.
        self.ssh_connection.close()
        self._execute_command(["vmrun", "stop", self.path_to_vm])