Merge branch 'main' into zdy
This commit is contained in:
190
desktop_env/envs/actions.py
Normal file
190
desktop_env/envs/actions.py
Normal file
@@ -0,0 +1,190 @@
|
||||
# Upper bounds used for the pointer-coordinate ranges below.
X_MAX = 1920  # TODO: get the screen resolution
Y_MAX = 1080

# Key names accepted by the PRESS / KEY_DOWN / KEY_UP / HOTKEY actions.
# NOTE(review): this looks like a copy of pyautogui's KEYBOARD_KEYS -- confirm.
KEYBOARD_KEYS = [
    # whitespace and printable ASCII characters
    '\t', '\n', '\r', ' ', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/',
    '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
    ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`',
    'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
    'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
    '{', '|', '}', '~',
    # named keys
    'accept', 'add', 'alt', 'altleft', 'altright', 'apps', 'backspace',
    'browserback', 'browserfavorites', 'browserforward', 'browserhome',
    'browserrefresh', 'browsersearch', 'browserstop',
    'capslock', 'clear', 'convert', 'ctrl', 'ctrlleft', 'ctrlright',
    'decimal', 'del', 'delete', 'divide', 'down',
    'end', 'enter', 'esc', 'escape', 'execute',
    'f1', 'f10', 'f11', 'f12', 'f13', 'f14', 'f15', 'f16', 'f17', 'f18', 'f19',
    'f2', 'f20', 'f21', 'f22', 'f23', 'f24',
    'f3', 'f4', 'f5', 'f6', 'f7', 'f8', 'f9',
    'final', 'fn', 'hanguel', 'hangul', 'hanja', 'help', 'home',
    'insert', 'junja', 'kana', 'kanji',
    'launchapp1', 'launchapp2', 'launchmail', 'launchmediaselect', 'left',
    'modechange', 'multiply', 'nexttrack', 'nonconvert',
    'num0', 'num1', 'num2', 'num3', 'num4', 'num5', 'num6', 'num7', 'num8', 'num9', 'numlock',
    'pagedown', 'pageup', 'pause', 'pgdn', 'pgup', 'playpause', 'prevtrack',
    'print', 'printscreen', 'prntscrn', 'prtsc', 'prtscr',
    'return', 'right',
    'scrolllock', 'select', 'separator', 'shift', 'shiftleft', 'shiftright',
    'sleep', 'stop', 'subtract',
    'tab', 'up', 'volumedown', 'volumemute', 'volumeup',
    'win', 'winleft', 'winright', 'yen',
    'command', 'option', 'optionleft', 'optionright',
]


def _param(kind, allowed, optional):
    """Return one parameter spec with the keys ``type``, ``range``, ``optional``."""
    return {"type": kind, "range": allowed, "optional": optional}


def _xy(optional):
    """Return the ``x``/``y`` coordinate specs bounded by ``X_MAX``/``Y_MAX``."""
    return {
        "x": _param(float, [0, X_MAX], optional),
        "y": _param(float, [0, Y_MAX], optional),
    }


def _button():
    """Return the optional mouse ``button`` spec."""
    return {"button": _param(str, ["left", "right", "middle"], True)}


def _key():
    """Return the required ``key`` spec; its range is the shared KEYBOARD_KEYS list."""
    return {"key": _param(str, KEYBOARD_KEYS, False)}


def _action(action_type, note, parameters):
    """Return one action description entry."""
    return {"action_type": action_type, "note": note, "parameters": parameters}


# Every entry describes one action: its name, a human-readable note, and the
# type / allowed range / optionality of each of its parameters.
ACTION_SPACE = [
    _action(
        "MOVE_TO",
        "move the cursor to the specified position",
        _xy(optional=False),
    ),
    _action(
        "CLICK",
        "click the left button if the button not specified, otherwise click the specified button; click at the current position if x and y are not specified, otherwise click at the specified position",
        {
            **_button(),
            **_xy(optional=True),
            "num_clicks": _param(int, [1, 2, 3], True),
        },
    ),
    _action(
        "MOUSE_DOWN",
        "press the left button if the button not specified, otherwise press the specified button",
        _button(),
    ),
    _action(
        "MOUSE_UP",
        "release the left button if the button not specified, otherwise release the specified button",
        _button(),
    ),
    _action(
        "RIGHT_CLICK",
        "right click at the current position if x and y are not specified, otherwise right click at the specified position",
        _xy(optional=True),
    ),
    _action(
        "DOUBLE_CLICK",
        "double click at the current position if x and y are not specified, otherwise double click at the specified position",
        _xy(optional=True),
    ),
    _action(
        "DRAG_TO",
        "drag the cursor to the specified position with the left button pressed",
        _xy(optional=False),
    ),
    _action(
        "SCROLL",
        "scroll the mouse wheel up or down",
        {
            "dx": _param(int, None, False),
            "dy": _param(int, None, False),
        },
    ),
    _action(
        "TYPING",
        "type the specified text",
        {"text": _param(str, None, False)},
    ),
    _action(
        "PRESS",
        "press the specified key and release it",
        _key(),
    ),
    _action(
        "KEY_DOWN",
        "press the specified key",
        _key(),
    ),
    _action(
        "KEY_UP",
        "release the specified key",
        _key(),
    ),
    _action(
        "HOTKEY",
        "press the specified key combination",
        # The range is a one-element list wrapping KEYBOARD_KEYS.
        {"keys": _param(list, [KEYBOARD_KEYS], False)},
    ),
]
|
||||
@@ -1,203 +1,186 @@
|
||||
from enum import Enum
|
||||
from typing import Literal
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
from fabric import Connection
|
||||
import time
|
||||
import uuid
|
||||
import platform
|
||||
from typing import List
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium import spaces
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import requests
|
||||
|
||||
from desktop_env.controllers.mouse import MouseClick, AbstractMouseController, XDoToolMouseController, PythonMouseController
|
||||
from desktop_env.controllers.keyboard import AbstractKeyboardController, XDoToolKeyboardController, PythonKeyboardController
|
||||
from desktop_env.controllers.python import PythonController
|
||||
from desktop_env.controllers.setup import SetupController
|
||||
from desktop_env.evaluators import eval_funcs
|
||||
|
||||
class Action(Enum):
    """Kinds of low-level input action, identified by the integers 0-5."""

    # Members take consecutive values in declaration order:
    # CLICK=0, MOUSE_DOWN=1, MOUSE_UP=2, MOUSE_MOVE=3, KEY=4, TYPE=5.
    CLICK, MOUSE_DOWN, MOUSE_UP, MOUSE_MOVE, KEY, TYPE = range(6)
|
||||
|
||||
# Type alias naming the two guest operating systems accepted wherever a
# ``vm_os`` argument is annotated with it.
VM_TYPE = Literal['ubuntu', 'windows']
|
||||
def _execute_command(command: List[str]) -> "str | None":
    """Run a host command and return its captured standard output.

    ``vmrun -T ws start ...`` is special-cased: it runs with inherited stdio,
    is waited on, and its exit status is not inspected, so a failed start is
    not reported from here.

    Args:
        command: Program and arguments as an argv-style list.

    Returns:
        The command's stdout as text, or ``None`` for the ``vmrun -T ws start``
        special case.

    Raises:
        RuntimeError: The command exited with a non-zero status. The message
            carries stdout followed by stderr, wrapped in ANSI red.
        subprocess.TimeoutExpired: The command ran for more than 60 seconds
            (not applied to the ``start`` special case).
    """
    if command[:4] == ["vmrun", "-T", "ws", "start"]:
        # Output is deliberately not piped for this one command.
        # NOTE(review): presumably because `vmrun start` does not return
        # promptly when its output is captured -- confirm.
        subprocess.Popen(command).wait()
        return None

    result = subprocess.run(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        timeout=60,
        text=True,
    )
    if result.returncode != 0:
        # RuntimeError is a subclass of Exception, so existing handlers still catch it.
        raise RuntimeError("\033[91m" + result.stdout + result.stderr + "\033[0m")
    return result.stdout
|
||||
|
||||
|
||||
class DesktopEnv(gym.Env):
    """DesktopEnv with OpenAI Gym interface.

    The environment boots a VMware guest with ``vmrun`` and then talks to an
    HTTP server at port 5000 of the guest's IP address through
    ``PythonController`` (actions, screenshots, file download) and
    ``SetupController`` (task setup).
    """

    def __init__(
            self,
            path_to_vm: str,
            snapshot_path: str = "base",
            instruction: str = None,
            config: dict = None,
            evaluator: dict = None,
            action_space: str = "computer_13",
    ):
        """Boot the VM and connect the controllers.

        Args:
            path_to_vm: Path of the VM passed to every ``vmrun`` call.
            snapshot_path: Snapshot name used by ``reset`` and ``_save_state``.
            instruction: Task instruction echoed back in every ``step`` observation.
            config: Setup description handed to ``SetupController.setup`` on ``reset``.
            evaluator: Dict with the keys ``"func"`` and ``"paths"`` read by ``evaluate``.
            action_space: ``"computer_13"`` or ``"pyautogui"``; selects how ``step``
                dispatches an action.
        """
        # mode: human or machine
        # Checked first so an unsupported value fails before the VM is booted.
        assert action_space in ["computer_13", "pyautogui"], \
            "Unsupported action_space: {}".format(action_space)
        self.action_space = action_space
        # todo: define the action space and the observation space as gym did, or extend theirs

        # Initialize environment variables
        self.path_to_vm = path_to_vm
        self.snapshot_path = snapshot_path  # todo: handling the logic of snapshot directory

        # Initialize emulator and controller
        print("Initializing...")
        self._start_emulator()
        self.host = f"http://{self._get_vm_ip()}:5000"
        self.controller = PythonController(http_server=self.host)
        self.setup_controller = SetupController(http_server=self.host)
        self.instruction = instruction
        self.config = config
        self.evaluator = evaluator

    def _start_emulator(self):
        """Block until ``vmrun -T ws list`` reports this VM, starting it as needed."""
        while True:
            try:
                output = subprocess.check_output("vmrun -T ws list", shell=True, stderr=subprocess.STDOUT)
                output = output.decode()
                # lstrip("~/") removes any leading '~' and '/' characters (not the
                # literal prefix "~/"); the remainder is matched as a substring.
                if self.path_to_vm.lstrip("~/") in output:
                    print("VM is running.")
                    break
                else:
                    print("Starting VM...")
                    _execute_command(["vmrun", "-T", "ws", "start", self.path_to_vm])
                    time.sleep(3)
            except subprocess.CalledProcessError as e:
                # Listing failed: report it and try again on the next iteration.
                print(f"Error executing command: {e.output.decode().strip()}")

    def _get_vm_ip(self):
        """Return the guest IP address reported by ``vmrun getGuestIPAddress``.

        Makes up to 10 attempts, sleeping 5 seconds after each failure.

        Raises:
            Exception: Every attempt failed.
        """
        max_retries = 10
        print("Getting IP Address...")
        for _ in range(max_retries):
            try:
                output = _execute_command(["vmrun", "-T", "ws", "getGuestIPAddress", self.path_to_vm]).strip()
                print(f"IP address: {output}")
                return output
            except Exception:
                # Not a bare `except:` so Ctrl-C / SystemExit can still abort the wait.
                time.sleep(5)
                print("Retrying...")
        raise Exception("Failed to get VM IP address!")

    def _save_state(self):
        """Take a snapshot of the VM under the name ``self.snapshot_path``."""
        # "ws" and "snapshot" must be separate arguments; without the comma the
        # adjacent literals concatenate into the single token "wssnapshot".
        _execute_command(["vmrun", "-T", "ws", "snapshot", self.path_to_vm, self.snapshot_path])

    def _get_screenshot(self):
        """Fetch a screenshot from the controller and return the path it was saved to.

        Each call writes to a fresh ``tmp/<uuid4>/screenshot.png``.
        """
        random_uuid = str(uuid.uuid4())
        os.makedirs(os.path.join("tmp", random_uuid), exist_ok=True)
        image_path = os.path.join("tmp", random_uuid, "screenshot.png")

        # Get the screenshot and save to the image_path
        # NOTE(review): assumes get_screenshot() returns raw bytes -- confirm.
        screenshot = self.controller.get_screenshot()
        with open(image_path, "wb") as f:
            f.write(screenshot)

        return image_path

    def _get_obs(self):
        """Return the current observation: the path of a freshly saved screenshot."""
        screenshot_image_path = self._get_screenshot()
        return screenshot_image_path

    def reset(self, seed=None, options=None):
        """Revert to the snapshot, restart the VM, run task setup, return an observation.

        ``seed`` and ``options`` are accepted but not used.
        """
        print("Resetting environment...")

        print("Reverting to snapshot to {}...".format(self.snapshot_path))
        _execute_command(["vmrun", "-T", "ws", "revertToSnapshot", self.path_to_vm, self.snapshot_path])
        time.sleep(5)

        print("Starting emulator...")
        self._start_emulator()
        print("Emulator started.")

        print("Setting up environment...")
        self.setup_controller.setup(self.config)

        time.sleep(5)
        print("Environment setup complete.")

        observation = self._get_obs()
        return observation

    def step(self, action, pause=0.5):
        """Execute one action and return ``(observation, reward, done, info)``.

        Args:
            action: Passed to ``controller.execute_action`` when the action space is
                ``"computer_13"``, or to ``controller.execute_python_command`` when it
                is ``"pyautogui"``.
            pause: Seconds to sleep after the action before taking the screenshot.

        Returns:
            A tuple of the observation dict (keys ``"screenshot"`` and
            ``"instruction"``), a reward that is always ``0``, a ``done`` flag that
            is always ``False``, and an empty ``info`` dict.
        """
        # fixme: add reminding logic here, decide if the action is valid for the current action_space
        if self.action_space == "computer_13":
            # the set of all possible actions defined in the action representation
            self.controller.execute_action(action)
        elif self.action_space == "pyautogui":
            # the set of all possible python commands insides `pyautogui`
            self.controller.execute_python_command(action)

        # todo: maybe for the better here we need to add a logic to wait until the rendering is done
        time.sleep(pause)
        observation = {
            "screenshot": self._get_obs(),
            "instruction": self.instruction
        }
        reward = 0  # todo: Define reward calculation for each example
        done = False  # todo: Define episode termination condition for each example
        info = {}
        return observation, reward, done, info

    def evaluate(self):
        """
        Evaluate whether the task is successfully completed.

        Every entry of ``self.evaluator["paths"]`` is copied to a local file and the
        local paths are passed, keyed by the entry names, to the function selected by
        ``self.evaluator["func"]`` in ``eval_funcs``; that function's result is returned.
        """
        def copy_file_to_local(_file_info):
            # Each file gets its own tmp/<uuid4>/tmp.xlsx.
            random_uuid = str(uuid.uuid4())
            os.makedirs(os.path.join("tmp", random_uuid), exist_ok=True)
            _path = os.path.join("tmp", random_uuid, "tmp.xlsx")
            if _file_info["type"] == "cloud_file":
                # Download from the URL in chunks.
                url = _file_info["path"]
                response = requests.get(url, stream=True)
                response.raise_for_status()

                with open(_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        if chunk:
                            f.write(chunk)
            elif _file_info["type"] == "vm_file":
                # fixme: stream this part maybe as well
                file = self.controller.get_file(_file_info["path"])
                with open(_path, "wb") as f:
                    f.write(file)
            else:
                raise NotImplementedError

            return _path

        # todo: make this more flexible by refactoring
        eval_func = eval_funcs[self.evaluator["func"]]
        eval_func_vars = {}

        for var_name, file_info in self.evaluator["paths"].items():
            path = copy_file_to_local(file_info)
            eval_func_vars[var_name] = path

        return eval_func(**eval_func_vars)

    def render(self, mode='rgb_array'):
        """Return the current observation for ``mode='rgb_array'``.

        Raises:
            ValueError: ``mode`` is anything other than ``'rgb_array'``.
        """
        if mode == 'rgb_array':
            return self._get_obs()
        else:
            raise ValueError('Unsupported render mode: {}'.format(mode))

    def close(self):
        """Stop the VM."""
        _execute_command(["vmrun", "stop", self.path_to_vm])
|
||||
|
||||
Reference in New Issue
Block a user