Merge branch 'main' into zdy
This commit is contained in:
@@ -1,203 +1,186 @@
|
||||
from enum import Enum
|
||||
from typing import Literal
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
from fabric import Connection
|
||||
import time
|
||||
import uuid
|
||||
import platform
|
||||
from typing import List
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium import spaces
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import requests
|
||||
|
||||
from desktop_env.controllers.mouse import MouseClick, AbstractMouseController, XDoToolMouseController, PythonMouseController
|
||||
from desktop_env.controllers.keyboard import AbstractKeyboardController, XDoToolKeyboardController, PythonKeyboardController
|
||||
from desktop_env.controllers.python import PythonController
|
||||
from desktop_env.controllers.setup import SetupController
|
||||
from desktop_env.evaluators import eval_funcs
|
||||
|
||||
class Action(Enum):
    """Kinds of low-level input events an environment step can carry.

    Members and their integer values:
    CLICK=0, MOUSE_DOWN=1, MOUSE_UP=2, MOUSE_MOVE=3, KEY=4, TYPE=5.
    """

    # Values are consecutive from 0, in declaration order.
    CLICK, MOUSE_DOWN, MOUSE_UP, MOUSE_MOVE, KEY, TYPE = range(6)
|
||||
|
||||
VM_TYPE = Literal['ubuntu', 'windows']
|
||||
def _execute_command(command: List[str]) -> "str | None":
    """Run an external command (in this file always ``vmrun``) and return its stdout.

    ``vmrun -T ws start ...`` is launched without output capture, without a
    timeout and without an exit-status check. Every other command runs with
    captured output and a 60 second timeout.

    Args:
        command: Program and arguments, handed to ``subprocess`` as a list
            (no shell is involved).

    Returns:
        The captured stdout as text, or ``None`` for the ``start`` command,
        whose output is not captured.

    Raises:
        RuntimeError: A non-``start`` command exited with a non-zero status.
            The message is the command's stdout followed by its stderr,
            wrapped in ANSI red escape codes.
        subprocess.TimeoutExpired: A non-``start`` command ran longer than
            60 seconds.
    """
    if command[:4] == ["vmrun", "-T", "ws", "start"]:
        # NOTE(review): presumably output is not piped here because the VM
        # keeps running after vmrun returns -- confirm before adding capture
        # or a timeout to this branch.
        with subprocess.Popen(command) as process:
            process.wait()
        return None

    result = subprocess.run(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        timeout=60,
        text=True,
    )
    if result.returncode != 0:
        raise RuntimeError("\033[91m" + result.stdout + result.stderr + "\033[0m")
    return result.stdout
|
||||
|
||||
|
||||
class DesktopEnv(gym.Env):
|
||||
"""DesktopEnv with OpenAI Gym interface."""
|
||||
|
||||
def __init__(self, path_to_vm: str, username: str, password: str,
|
||||
host: str, snapshot_path: str = "snapshot", vm_os: VM_TYPE = "ubuntu"):
|
||||
def __init__(
|
||||
self,
|
||||
path_to_vm: str,
|
||||
snapshot_path: str = "base",
|
||||
instruction: str = None,
|
||||
config: dict = None,
|
||||
evaluator: dict = None,
|
||||
action_space: str = "computer_13",
|
||||
):
|
||||
# Initialize environment variables
|
||||
self.path_to_vm = path_to_vm
|
||||
self.username = username
|
||||
self.password = password
|
||||
self.host = host
|
||||
self.snapshot_path = snapshot_path
|
||||
|
||||
self.screen_width = 800
|
||||
self.screen_height = 800
|
||||
# Define the action and observation space
|
||||
self.action_space = spaces.Dict({
|
||||
"action_type": spaces.Discrete(len(Action)),
|
||||
"click_type": spaces.Discrete(len(MouseClick)),
|
||||
"x": spaces.Discrete(self.screen_width),
|
||||
"y": spaces.Discrete(self.screen_height),
|
||||
"key": spaces.MultiDiscrete([128] * 10), # max 10 characters, ASCII
|
||||
"text": spaces.MultiDiscrete([128] * 10) # max 10 characters, ASCII
|
||||
})
|
||||
self.snapshot_path = snapshot_path # todo: handling the logic of snapshot directory
|
||||
|
||||
self.observation_space = spaces.Box(low=0, high=255, shape=(self.screen_width, self.screen_height, 3), dtype=np.uint8)
|
||||
|
||||
# Additional setup
|
||||
self.metadata = {'render.modes': ['rgb_array']}
|
||||
# Initialize emulator and controller
|
||||
print("Initializing...")
|
||||
self._start_emulator()
|
||||
self._wait_for_emulator_load()
|
||||
self.host = f"http://{self._get_vm_ip()}:5000"
|
||||
self.controller = PythonController(http_server=self.host)
|
||||
self.setup_controller = SetupController(http_server=self.host)
|
||||
self.instruction = instruction
|
||||
self.config = config
|
||||
self.evaluator = evaluator
|
||||
|
||||
# set up controllers
|
||||
self.mouse_controller, self.keyboard_controller = self._create_controllers(vm_os)
|
||||
|
||||
def _create_controllers(self, vm_os: VM_TYPE) -> tuple[AbstractMouseController, AbstractKeyboardController]:
    """Build the mouse and keyboard controllers that match the guest OS.

    Args:
        vm_os: ``"ubuntu"`` selects xdotool controllers sharing one SSH
            connection to ``self.host``; ``"windows"`` selects the HTTP-based
            Python controllers pointed at ``self.host``.

    Returns:
        ``(mouse_controller, keyboard_controller)``.

    Raises:
        NotImplementedError: ``vm_os`` is neither ``"ubuntu"`` nor ``"windows"``.
    """
    if vm_os == "windows":
        windows_mouse = PythonMouseController(http_server=self.host)
        windows_keyboard = PythonKeyboardController(http_server=self.host)
        return windows_mouse, windows_keyboard

    if vm_os != "ubuntu":
        raise NotImplementedError(vm_os)

    # Both xdotool controllers issue their commands over the same connection.
    ssh_connection = Connection(
        host=self.host,
        user=self.username,
        connect_kwargs={"password": self.password},
    )
    ubuntu_mouse = XDoToolMouseController(ssh_connection)
    ubuntu_keyboard = XDoToolKeyboardController(ssh_connection)
    return ubuntu_mouse, ubuntu_keyboard
|
||||
# mode: human or machine
|
||||
assert action_space in ["computer_13", "pyautogui"]
|
||||
self.action_space = action_space
|
||||
# todo: define the action space and the observation space as gym did, or extend theirs
|
||||
|
||||
def _start_emulator(self):
    """Issue ``vmrun start`` for the VM stored at ``self.path_to_vm``."""
    start_command = ["vmrun", "start", self.path_to_vm]
    self._execute_command(start_command)
|
||||
|
||||
def _wait_for_emulator_load(self):
    """Block until ``vmrun -T ws list`` mentions this VM, starting it as needed.

    Each pass lists the running VMs. If the VM's path appears in the listing
    the method returns. Otherwise it waits 5 seconds, issues
    ``vmrun -T ws start`` for the VM, waits 3 more seconds and polls again.

    If the ``list`` command itself fails, its output is printed and the
    method returns without the VM having been confirmed as running.
    """
    # Drop only a literal "~/" prefix so the remainder can be matched as a
    # substring of the listing. str.lstrip("~/") would strip any leading run
    # of '~' and '/' characters, which mangles names such as "~/~backup/x.vmx".
    vm_path = self.path_to_vm
    if vm_path.startswith("~/"):
        vm_path = vm_path[2:]

    while True:
        try:
            output = subprocess.check_output("vmrun -T ws list", shell=True, stderr=subprocess.STDOUT)
            output = output.decode()
            if vm_path in output:
                print("VM is running.")
                return

            print("Waiting for VM to start...")
            time.sleep(5)
            print("Starting VM...")
            _execute_command(["vmrun", "-T", "ws", "start", self.path_to_vm])
            time.sleep(3)
        except subprocess.CalledProcessError as e:
            # NOTE(review): returning here looks identical to success for the
            # caller -- confirm whether this should raise instead.
            print(f"Error executing command: {e.output.decode().strip()}")
            return
|
||||
|
||||
def _execute_command(self, command: list[str]) -> "str | None":
    """Run ``command`` and return its decoded stdout, or ``None`` on failure.

    Args:
        command: Program and arguments, passed to ``subprocess`` as a list.

    Returns:
        The command's stdout decoded to ``str`` when it exits with status 0.
        On a non-zero exit status the command and its stderr are printed and
        ``None`` is returned; no exception is raised.
    """
    completed = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    if completed.returncode != 0:
        print(f"Error executing command: {command}")
        print(completed.stderr.decode())
        return None
    return completed.stdout.decode()
|
||||
|
||||
def _execute_xdotool_command(self, command: str) -> str:
    """Run one ``xdotool`` sub-command on the guest through ``self.ssh_connection``.

    Args:
        command: The xdotool arguments as a single string, for example
            ``"click 1"``. It is inserted into the remote command line
            verbatim, so any quoting is the caller's responsibility.

    Returns:
        The remote command's stdout with surrounding whitespace stripped.
    """
    # The remote command is prefixed with DISPLAY=:0 so xdotool addresses
    # X display :0 on the guest.
    remote_command = f"DISPLAY=:0 xdotool {command}"
    result = self.ssh_connection.run(remote_command, hide=True)
    return result.stdout.strip()
|
||||
def _get_vm_ip(self):
    """Return the guest IP address reported by ``vmrun getGuestIPAddress``.

    Up to 10 attempts are made, with a 5 second sleep after each failure.

    Returns:
        The command's stdout with surrounding whitespace stripped.

    Raises:
        Exception: Every attempt failed.
    """
    max_retries = 10
    print("Getting IP Address...")
    for _ in range(max_retries):
        try:
            output = _execute_command(["vmrun", "-T", "ws", "getGuestIPAddress", self.path_to_vm]).strip()
        except Exception:
            # Narrowed from a bare ``except`` so that KeyboardInterrupt and
            # SystemExit still abort the retry loop.
            time.sleep(5)
            print("Retrying...")
        else:
            print(f"IP address: {output}")
            return output
    raise Exception("Failed to get VM IP address!")
|
||||
|
||||
def _save_state(self):
    """Take a VM snapshot named ``self.snapshot_path`` via ``vmrun snapshot``."""
    # "ws" and "snapshot" must be separate list items. Without the comma the
    # adjacent string literals are concatenated into the single argument
    # "wssnapshot".
    self._execute_command(["vmrun", "-T", "ws", "snapshot", self.path_to_vm, self.snapshot_path])
|
||||
|
||||
def _click(self, click: MouseClick):
    """Send xdotool's ``click`` command for the button number ``click.value``."""
    xdotool_args = f"click {click.value}"
    self._execute_xdotool_command(xdotool_args)
|
||||
|
||||
def _mousedown(self, click: MouseClick):
    """Send xdotool's ``mousedown`` command for the button number ``click.value``."""
    xdotool_args = f"mousedown {click.value}"
    self._execute_xdotool_command(xdotool_args)
|
||||
|
||||
def _mouseup(self, click: MouseClick):
    """Send xdotool's ``mouseup`` command for the button number ``click.value``."""
    xdotool_args = f"mouseup {click.value}"
    self._execute_xdotool_command(xdotool_args)
|
||||
|
||||
def _mouse_move(self, x: int, y: int):
    """Send xdotool's ``mousemove`` command with the coordinates ``x`` and ``y``."""
    xdotool_args = f"mousemove {x} {y}"
    self._execute_xdotool_command(xdotool_args)
|
||||
|
||||
def _key(self, key: str):
    """Send xdotool's ``key`` command for the key specification ``key``."""
    xdotool_args = f"key {key}"
    self._execute_xdotool_command(xdotool_args)
|
||||
|
||||
def _type(self, text: str):
    """Send xdotool's ``type`` command so that ``text`` is typed on the guest.

    Args:
        text: Text to type. It is shell-quoted before being placed on the
            remote command line, so spaces and shell metacharacters reach
            xdotool as one literal argument instead of being interpreted by
            the remote shell.
    """
    import shlex

    # shlex.quote leaves plain words untouched and single-quotes the rest.
    self._execute_xdotool_command(f"type {shlex.quote(text)}")
|
||||
# Take a VM snapshot named self.snapshot_path. "ws" and "snapshot" must be
# separate list items: without the comma Python concatenates the adjacent
# literals into the single argument "wssnapshot".
_execute_command(["vmrun", "-T", "ws", "snapshot", self.path_to_vm, self.snapshot_path])
|
||||
|
||||
def _get_screenshot(self):
    """Fetch a screenshot through ``self.controller`` and save it under ``tmp/``.

    The earlier ``vmrun captureScreen`` call is gone: it wrote to
    ``./screenshot.png``, but that path was rebound straight afterwards and
    the captured file was never read.

    Returns:
        Path of the written file, ``tmp/<uuid4>/screenshot.png``.
    """
    # A fresh UUID directory per call keeps successive screenshots from
    # overwriting each other.
    screenshot_dir = os.path.join("tmp", str(uuid.uuid4()))
    os.makedirs(screenshot_dir, exist_ok=True)
    image_path = os.path.join(screenshot_dir, "screenshot.png")

    # Get the screenshot and save to the image_path
    screenshot = self.controller.get_screenshot()
    with open(image_path, "wb") as f:
        f.write(screenshot)

    return image_path
|
||||
|
||||
|
||||
def _get_obs(self):
    """Take a screenshot and return it as a NumPy array.

    The leftover ``OBS 1`` / ``OBS 2`` debug prints and the unreachable
    ``return screenshot_image_path`` after the ``with`` block are removed;
    the returned value is unchanged.

    Returns:
        ``np.array`` built from the image file written by
        ``self._get_screenshot()``.
    """
    screenshot_image_path = self._get_screenshot()
    # Convert inside the ``with`` block so the image file is closed as soon
    # as the array has been built.
    with Image.open(screenshot_image_path) as img:
        return np.array(img)
|
||||
|
||||
def reset(self):
    """Wait for two operator confirmations on stdin; nothing else happens.

    NOTE(review): the ``revertToSnapshot`` call that used to sit between the
    two prompts was commented out, so this version does not touch the VM.
    """
    for prompt in ("Reset #1 PE", "Revert to snapshot #2 PE"):
        input(prompt)
|
||||
def reset(self, seed=None, options=None):
    """Revert the VM to its snapshot, boot it, run task setup and observe.

    The interactive ``input()`` pause, which blocked unattended runs, is
    removed. So is the pre-setup observation, whose value was overwritten
    before it was ever used.

    Args:
        seed: Not used by this method.
        options: Not used by this method.

    Returns:
        The value of ``self._get_obs()`` taken after setup has finished.
    """
    print("Resetting environment...")

    print("Reverting to snapshot to {}...".format(self.snapshot_path))
    _execute_command(["vmrun", "-T", "ws", "revertToSnapshot", self.path_to_vm, self.snapshot_path])
    time.sleep(5)

    print("Starting emulator...")
    self._start_emulator()
    self._wait_for_emulator_load()
    print("Emulator started.")

    print("Setting up environment...")
    self.setup_controller.setup(self.config)

    # NOTE(review): fixed delay, presumably to let the setup steps take
    # effect before the first screenshot -- confirm.
    time.sleep(5)
    print("Environment setup complete.")

    return self._get_obs()
|
||||
|
||||
def step(self, action):
    """Dispatch one action dict to the mouse or keyboard controller.

    Args:
        action: Mapping with ``'action_type'`` (an ``Action`` value) plus,
            depending on the type: ``'click_type'`` (a ``MouseClick`` value)
            for CLICK / MOUSE_DOWN / MOUSE_UP, ``'x'`` and ``'y'`` for
            MOUSE_MOVE, ``'key'`` for KEY and ``'text'`` for TYPE. ``'key'``
            and ``'text'`` are iterables of character codes.

    Returns:
        ``None``.
    """
    action_type = Action(action['action_type'])

    # Button actions share one lookup; only the method-name suffix differs.
    suffix_by_action = {
        Action.CLICK: "click",
        Action.MOUSE_DOWN: "down",
        Action.MOUSE_UP: "up",
    }
    if action_type in suffix_by_action:
        suffix = suffix_by_action[action_type]
        click = MouseClick(action['click_type'])
        # Wheel "buttons" map to scrolling for all three button actions.
        handlers = (
            (MouseClick.LEFT, "left_" + suffix),
            (MouseClick.MIDDLE, "middle_" + suffix),
            (MouseClick.RIGHT, "right_" + suffix),
            (MouseClick.WHEEL_UP, "scroll_up"),
            (MouseClick.WHEEL_DOWN, "scroll_down"),
        )
        for button, method_name in handlers:
            if click == button:
                getattr(self.mouse_controller, method_name)()
                break
        return

    if action_type == Action.MOUSE_MOVE:
        self.mouse_controller.mouse_move(x=action['x'], y=action['y'])
        return

    if action_type == Action.KEY:
        # Convert integer array to string
        key_sequence = ''.join(chr(code) for code in action['key'])
        self.keyboard_controller.key(key_sequence)
        return

    if action_type == Action.TYPE:
        # Convert integer array to string
        text = ''.join(chr(code) for code in action['text'])
        self.keyboard_controller.type(text)
|
||||
def step(self, action, pause=0.5):
    """Execute one action in the VM and return the resulting transition.

    The screenshot is taken once, after the pause. The extra capture that
    used to run before the pause, and the duplicate ``reward`` / ``done``
    assignments, were dead stores and are removed.

    Args:
        action: Passed to ``self.controller.execute_action`` when
            ``self.action_space`` is ``"computer_13"``, or to
            ``self.controller.execute_python_command`` when it is
            ``"pyautogui"``. With any other action space nothing is executed.
        pause: Seconds to sleep after executing the action and before the
            screenshot is taken.

    Returns:
        ``(observation, reward, done, info)``. ``observation`` is a dict with
        the keys ``"screenshot"`` and ``"instruction"``. ``reward`` is
        currently always ``0``, ``done`` always ``False`` and ``info`` an
        empty dict.
    """
    # fixme: add reminding logic here, decide if the action is valid for the current action_space
    if self.action_space == "computer_13":
        # the set of all possible actions defined in the action representation
        self.controller.execute_action(action)
    elif self.action_space == "pyautogui":
        # the set of all possible python commands insides `pyautogui`
        self.controller.execute_python_command(action)

    # todo: maybe for the better here we need to add a logic to wait until the rendering is done
    time.sleep(pause)

    observation = {
        "screenshot": self._get_obs(),
        "instruction": self.instruction,
    }
    reward = 0  # todo: Define reward calculation for each example
    done = False  # todo: Define episode termination condition for each example
    info = {}
    return observation, reward, done, info
|
||||
|
||||
def evaluate(self):
    """
    Evaluate whether the task is successfully completed.

    Every entry of ``self.evaluator["paths"]`` is copied to a fresh local
    file, and the function registered in ``eval_funcs`` under
    ``self.evaluator["func"]`` is called with those local paths as keyword
    arguments (keyword name = key in ``paths``).

    Returns:
        Whatever the selected evaluation function returns.

    Raises:
        NotImplementedError: A file entry has a ``"type"`` other than
            ``"cloud_file"`` or ``"vm_file"``.
        requests.HTTPError: A cloud file download returned an error status.
    """

    def copy_file_to_local(_file_info):
        """Copy one described file to ``tmp/<uuid4>/tmp.xlsx`` and return that path."""
        random_uuid = str(uuid.uuid4())
        os.makedirs(os.path.join("tmp", random_uuid), exist_ok=True)
        # NOTE(review): the local name is always "tmp.xlsx" whatever the
        # source file is -- presumably every evaluator currently reads
        # spreadsheets; confirm.
        _path = os.path.join("tmp", random_uuid, "tmp.xlsx")

        if _file_info["type"] == "cloud_file":
            url = _file_info["path"]
            # The context manager closes the streamed response, and so
            # releases its connection, even if writing fails part-way.
            with requests.get(url, stream=True) as response:
                response.raise_for_status()
                with open(_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        if chunk:
                            f.write(chunk)
        elif _file_info["type"] == "vm_file":
            # fixme: stream this part maybe as well
            file = self.controller.get_file(_file_info["path"])
            with open(_path, "wb") as f:
                f.write(file)
        else:
            raise NotImplementedError

        return _path

    # todo: make this more flexible by refactoring
    eval_func = eval_funcs[self.evaluator["func"]]
    eval_func_vars = {
        var_name: copy_file_to_local(file_info)
        for var_name, file_info in self.evaluator["paths"].items()
    }

    return eval_func(**eval_func_vars)
|
||||
|
||||
def render(self, mode='rgb_array'):
    """Return the current observation for the ``'rgb_array'`` render mode.

    Args:
        mode: Only ``'rgb_array'`` is supported.

    Returns:
        The value of ``self._get_obs()``.

    Raises:
        ValueError: ``mode`` is anything other than ``'rgb_array'``.
    """
    if mode != 'rgb_array':
        raise ValueError('Unsupported render mode: {}'.format(mode))
    return self._get_obs()
|
||||
|
||||
def close(self):
    """Stop the virtual machine with a single ``vmrun stop``.

    The stop command used to be issued twice in a row: once through
    ``self._execute_command`` and once through the module-level helper,
    which raises on a non-zero exit status. Only the first call is kept.
    """
    self._execute_command(["vmrun", "stop", self.path_to_vm])
|
||||
|
||||
Reference in New Issue
Block a user