Merge branch 'main' into zdy

This commit is contained in:
zdy023
2023-12-19 11:06:17 +08:00
111 changed files with 22918 additions and 497 deletions

View File

@@ -1,203 +1,186 @@
from enum import Enum
from typing import Literal
from __future__ import annotations
import os
import subprocess
from fabric import Connection
import time
import uuid
import platform
from typing import List
import gymnasium as gym
from gymnasium import spaces
import numpy as np
from PIL import Image
import requests
from desktop_env.controllers.mouse import MouseClick, AbstractMouseController, XDoToolMouseController, PythonMouseController
from desktop_env.controllers.keyboard import AbstractKeyboardController, XDoToolKeyboardController, PythonKeyboardController
from desktop_env.controllers.python import PythonController
from desktop_env.controllers.setup import SetupController
from desktop_env.evaluators import eval_funcs
class Action(Enum):
    """Kinds of input event selected by an action's ``action_type`` field."""

    # Mouse events
    CLICK = 0
    MOUSE_DOWN = 1
    MOUSE_UP = 2
    MOUSE_MOVE = 3

    # Keyboard events
    KEY = 4
    TYPE = 5
# Guest operating system names accepted by DesktopEnv._create_controllers.
VM_TYPE = Literal['ubuntu', 'windows']
def _execute_command(command: List[str]) -> "str | None":
    """Run an external command and return its standard output.

    ``vmrun -T ws start`` is special-cased: it is run without capturing its
    output and without a timeout, its exit status is not inspected, and
    nothing is returned for it.

    Args:
        command: Program name followed by its arguments.

    Returns:
        The captured stdout of the command, or ``None`` for the
        ``vmrun -T ws start`` invocation.

    Raises:
        Exception: If a captured command exits with a non-zero status. The
            message is its stdout and stderr wrapped in ANSI red escapes.
        subprocess.TimeoutExpired: If a captured command runs for more than
            60 seconds.
    """
    if command[:4] == ["vmrun", "-T", "ws", "start"]:
        # No pipes and no timeout for the start command; just wait for it.
        subprocess.run(command, check=False)
        return None

    result = subprocess.run(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        timeout=60,
        text=True,
        check=False,
    )
    if result.returncode != 0:
        raise Exception("\033[91m" + result.stdout + result.stderr + "\033[0m")
    return result.stdout
class DesktopEnv(gym.Env):
"""DesktopEnv with OpenAI Gym interface."""
# NOTE(review): this span appears to interleave two revisions of __init__
# (a diff rendered without +/- markers). The first signature takes
# username/password/host/vm_os; the second takes
# instruction/config/evaluator/action_space. The body below reads names from
# both signatures -- confirm which lines are meant to survive before editing.
def __init__(self, path_to_vm: str, username: str, password: str,
host: str, snapshot_path: str = "snapshot", vm_os: VM_TYPE = "ubuntu"):
def __init__(
self,
path_to_vm: str,
snapshot_path: str = "base",
instruction: str = None,
config: dict = None,
evaluator: dict = None,
action_space: str = "computer_13",
):
# Initialize environment variables
self.path_to_vm = path_to_vm
# username/password/host exist only in the first signature above.
self.username = username
self.password = password
self.host = host
self.snapshot_path = snapshot_path
# Fixed 800x800 size used for the x/y action ranges and the observation shape.
self.screen_width = 800
self.screen_height = 800
# Define the action and observation space
self.action_space = spaces.Dict({
"action_type": spaces.Discrete(len(Action)),
"click_type": spaces.Discrete(len(MouseClick)),
"x": spaces.Discrete(self.screen_width),
"y": spaces.Discrete(self.screen_height),
"key": spaces.MultiDiscrete([128] * 10), # max 10 characters, ASCII
"text": spaces.MultiDiscrete([128] * 10) # max 10 characters, ASCII
})
# Second assignment of snapshot_path (same value as the one above).
self.snapshot_path = snapshot_path # todo: handling the logic of snapshot directory
self.observation_space = spaces.Box(low=0, high=255, shape=(self.screen_width, self.screen_height, 3), dtype=np.uint8)
# Additional setup
self.metadata = {'render.modes': ['rgb_array']}
# Initialize emulator and controller
print("Initializing...")
self._start_emulator()
self._wait_for_emulator_load()
# self.host is overwritten here with an HTTP URL on port 5000 built from the
# address returned by _get_vm_ip(); both controllers below talk to that URL.
self.host = f"http://{self._get_vm_ip()}:5000"
self.controller = PythonController(http_server=self.host)
self.setup_controller = SetupController(http_server=self.host)
# Task description, setup configuration and evaluator spec are stored as given.
self.instruction = instruction
self.config = config
self.evaluator = evaluator
# set up controllers
# vm_os exists only in the first signature above.
self.mouse_controller, self.keyboard_controller = self._create_controllers(vm_os)
def _create_controllers(self, vm_os: VM_TYPE) -> tuple[AbstractMouseController, AbstractKeyboardController]:
    """Build the mouse and keyboard controllers for the given guest OS.

    ``"windows"`` guests get the HTTP-based Python controllers pointed at
    ``self.host``; ``"ubuntu"`` guests get the xdotool controllers, which
    share one SSH connection opened with ``self.username`` and
    ``self.password``.

    Raises:
        NotImplementedError: If ``vm_os`` is neither ``"ubuntu"`` nor
            ``"windows"``.
    """
    if vm_os == "windows":
        mouse = PythonMouseController(http_server=self.host)
        keyboard = PythonKeyboardController(http_server=self.host)
        return mouse, keyboard

    if vm_os != "ubuntu":
        raise NotImplementedError(vm_os)

    # Both xdotool controllers share a single SSH connection.
    ssh = Connection(host=self.host, user=self.username, connect_kwargs={"password": self.password})
    mouse = XDoToolMouseController(ssh)
    keyboard = XDoToolKeyboardController(ssh)
    return mouse, keyboard
# NOTE(review): these statements read the `action_space` parameter, which is
# declared only by the keyword-style __init__ signature; they presumably
# belong at the end of __init__ -- confirm the intended placement.
# mode: human or machine
# Only the two named action representations are accepted.
assert action_space in ["computer_13", "pyautogui"]
self.action_space = action_space
# todo: define the action space and the observation space as gym did, or extend theirs
# NOTE(review): this span appears to interleave two revisions (one that
# starts the VM and waits in a separate method, one that polls and starts
# inside a single loop). Nesting cannot be recovered from the flattened text
# -- confirm before restructuring.
def _start_emulator(self):
# Starts the VM through the instance-level command helper (no "-T ws" here).
self._execute_command(["vmrun", "start", self.path_to_vm])
def _wait_for_emulator_load(self):
while True:
try:
# Ask vmrun for the list of running VMs; stderr is folded into the output.
output = subprocess.check_output("vmrun -T ws list", shell=True, stderr=subprocess.STDOUT)
output = output.decode()
# str.lstrip("~/") removes any leading run of '~' and '/' characters (a
# character set, not the literal prefix "~/") before the substring test.
if self.path_to_vm.lstrip("~/") in output:
print("VM is running.")
# As written the `break` directly after `return` is unreachable; the two
# lines look like alternatives from different revisions.
return
break
else:
print("Waiting for VM to start...")
time.sleep(5)
print("Starting VM...")
_execute_command(["vmrun", "-T", "ws", "start", self.path_to_vm])
time.sleep(3)
except subprocess.CalledProcessError as e:
# check_output raised: report the captured output of the failed command.
print(f"Error executing command: {e.output.decode().strip()}")
return
def _execute_command(self, command: list[str]) -> None:
    """Run ``command`` and return its decoded standard output.

    A non-zero exit status is not raised: the command and its decoded stderr
    are printed and ``None`` is returned instead.

    NOTE(review): the ``-> None`` annotation does not match the string that
    is returned on success.
    """
    proc = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    out, err = proc.communicate()

    if proc.returncode == 0:
        return out.decode()

    print(f"Error executing command: {command}")
    print(err.decode())
    return None
def _execute_xdotool_command(self, command: list[str]) -> None:
    """Run ``xdotool <command>`` over ``self.ssh_connection`` with ``DISPLAY=:0``.

    Returns the command's stdout with surrounding whitespace removed.

    NOTE(review): ``command`` is interpolated into the command line as-is and
    a string is returned, so the ``list[str]`` / ``None`` annotations look
    inaccurate -- confirm against callers.
    """
    remote_command = f"DISPLAY=:0 xdotool {command}"
    outcome = self.ssh_connection.run(remote_command, hide=True)
    return outcome.stdout.strip()
def _get_vm_ip(self):
    """Ask ``vmrun`` for the guest's IP address, retrying on failure.

    Up to ten attempts are made, sleeping five seconds after each failure.

    Returns:
        The output of ``vmrun -T ws getGuestIPAddress`` for
        ``self.path_to_vm``, stripped of surrounding whitespace.

    Raises:
        Exception: If every attempt fails.
    """
    max_retries = 10
    print("Getting IP Address...")
    for _ in range(max_retries):
        try:
            output = _execute_command(
                ["vmrun", "-T", "ws", "getGuestIPAddress", self.path_to_vm]
            ).strip()
        except Exception:
            # Not a bare ``except``: KeyboardInterrupt and SystemExit must
            # still be able to stop the retry loop.
            time.sleep(5)
            print("Retrying...")
        else:
            print(f"IP address: {output}")
            return output
    raise Exception("Failed to get VM IP address!")
def _save_state(self):
    """Take a snapshot of the VM, named by ``self.snapshot_path``."""
    # "ws" and "snapshot" must be separate arguments: without the comma the
    # adjacent string literals merge into the single argument "wssnapshot".
    self._execute_command(["vmrun", "-T", "ws", "snapshot", self.path_to_vm, self.snapshot_path])
def _click(self, click: MouseClick):
    """Issue an xdotool ``click`` for the button number in ``click.value``."""
    xdotool_args = f"click {click.value}"
    self._execute_xdotool_command(xdotool_args)
def _mousedown(self, click: MouseClick):
    """Issue an xdotool ``mousedown`` for the button number in ``click.value``."""
    xdotool_args = f"mousedown {click.value}"
    self._execute_xdotool_command(xdotool_args)
def _mouseup(self, click: MouseClick):
    """Issue an xdotool ``mouseup`` for the button number in ``click.value``."""
    xdotool_args = f"mouseup {click.value}"
    self._execute_xdotool_command(xdotool_args)
def _mouse_move(self, x: int, y: int):
    """Issue an xdotool ``mousemove`` to the coordinates ``x`` ``y``."""
    xdotool_args = f"mousemove {x} {y}"
    self._execute_xdotool_command(xdotool_args)
def _key(self, key: str):
    """Issue an xdotool ``key`` command with ``key`` as its argument."""
    xdotool_args = f"key {key}"
    self._execute_xdotool_command(xdotool_args)
def _type(self, text: str):
    """Issue an xdotool ``type`` command with ``text`` as its argument."""
    xdotool_args = f"type {text}"
    self._execute_xdotool_command(xdotool_args)
# "ws" and "snapshot" must be separate arguments: without the comma the
# adjacent string literals merge into the single argument "wssnapshot".
_execute_command(["vmrun", "-T", "ws", "snapshot", self.path_to_vm, self.snapshot_path])
# Capture a screenshot of the guest and return the path of the saved image.
# NOTE(review): two capture strategies appear interleaved here (vmrun
# captureScreen into ./screenshot.png, and an HTTP fetch into a per-call tmp
# directory); `image_path` is reassigned between them -- confirm which is live.
def _get_screenshot(self):
image_path = "./screenshot.png"
# Uses guest credentials (self.username / self.password) with vmrun.
self._execute_command(["vmrun", "-T", "ws", "-gu", self.username, "-gp", self.password, "captureScreen", self.path_to_vm, image_path])
# A fresh UUID directory under tmp/ gives every screenshot a unique path.
random_uuid = str(uuid.uuid4())
os.makedirs(os.path.join("tmp", random_uuid), exist_ok=True)
image_path = os.path.join("tmp", random_uuid, "screenshot.png")
# Get the screenshot and save to the image_path
screenshot = self.controller.get_screenshot()
# The controller's result is written in binary mode.
with open(image_path, "wb") as f:
f.write(screenshot)
return image_path
# Produce the current observation from a fresh screenshot.
# NOTE(review): two return forms appear here -- the decoded image as a numpy
# array, and the screenshot file path. As written the second `return` follows
# the first; they look like alternatives from different revisions.
def _get_obs(self):
# Debug trace printed before the capture.
print("OBS 1")
screenshot_image_path = self._get_screenshot()
# Debug trace printed after the capture.
print("OBS 2")
with Image.open(screenshot_image_path) as img:
return np.array(img)
return screenshot_image_path
# NOTE(review): two `reset` definitions follow back to back, and the body
# mixes interactive input(...) pauses with a print-based flow; this looks
# like two revisions interleaved -- confirm which lines are meant to survive.
def reset(self):
# input() blocks until a line is entered on stdin.
input("Reset #1 PE")
#self._execute_command(["vmrun", "-T", "ws", "revertToSnapshot", self.path_to_vm, self.snapshot_path])
input("Revert to snapshot #2 PE")
# `seed` and `options` are accepted but not read by any line below.
def reset(self, seed=None, options=None):
print("Resetting environment...")
print("Reverting to snapshot to {}...".format(self.snapshot_path))
# Roll the VM back to the snapshot named by self.snapshot_path.
_execute_command(["vmrun", "-T", "ws", "revertToSnapshot", self.path_to_vm, self.snapshot_path])
time.sleep(5)
print("Starting emulator...")
self._start_emulator()
# Another blocking stdin prompt.
input("Started emulator #3 PE")
self._wait_for_emulator_load()
# This observation is overwritten by the later _get_obs() call before return.
observation = self._get_obs()
print("Emulator started.")
print("Setting up environment...")
# Hand the stored task configuration to the setup controller.
self.setup_controller.setup(self.config)
time.sleep(5)
print("Environment setup complete.")
observation = self._get_obs()
return observation
def step(self, action):
    """Dispatch one action dict to the mouse or keyboard controller.

    ``action['action_type']`` selects the ``Action``. Button actions
    (CLICK / MOUSE_DOWN / MOUSE_UP) also read ``action['click_type']``;
    MOUSE_MOVE reads ``action['x']`` and ``action['y']``; KEY and TYPE read
    integer sequences from ``action['key']`` / ``action['text']`` and turn
    them into strings with ``chr``.
    """
    requested = Action(action['action_type'])

    # Suffix of the controller method used by each mouse-button action.
    button_suffix = {
        Action.CLICK: "click",
        Action.MOUSE_DOWN: "down",
        Action.MOUSE_UP: "up",
    }

    if requested in button_suffix:
        button = MouseClick(action['click_type'])
        suffix = button_suffix[requested]
        # The wheel entries scroll whatever the button action is.
        handler_names = {
            MouseClick.LEFT: f"left_{suffix}",
            MouseClick.MIDDLE: f"middle_{suffix}",
            MouseClick.RIGHT: f"right_{suffix}",
            MouseClick.WHEEL_UP: "scroll_up",
            MouseClick.WHEEL_DOWN: "scroll_down",
        }
        handler_name = handler_names.get(button)
        if handler_name is not None:
            getattr(self.mouse_controller, handler_name)()
    elif requested == Action.MOUSE_MOVE:
        self.mouse_controller.mouse_move(x=action['x'], y=action['y'])
    elif requested == Action.KEY:
        # Convert the integer array to a string.
        key_sequence = ''.join(chr(code) for code in action['key'])
        self.keyboard_controller.key(key_sequence)
    elif requested == Action.TYPE:
        # Convert the integer array to a string.
        text = ''.join(chr(code) for code in action['text'])
        self.keyboard_controller.type(text)
# Execute one action through self.controller, wait `pause` seconds, and
# return (observation, reward, done, info).
# NOTE(review): lines from another revision appear mixed in -- `observation`,
# `reward` and `done` are each assigned twice, and the first _get_obs() result
# is discarded. Confirm which assignments are meant to survive.
def step(self, action, pause=0.5):
# fixme: add reminding logic here, decide if the action is valid for the current action_space
if self.action_space == "computer_13":
# the set of all possible actions defined in the action representation
self.controller.execute_action(action)
elif self.action_space == "pyautogui":
# the set of all possible python commands insides `pyautogui`
self.controller.execute_python_command(action)
# Capture new state
observation = self._get_obs()
reward = 0 # Define reward calculation
done = False # Define episode termination condition
# todo: maybe for the better here we need to add a logic to wait until the rendering is done
time.sleep(pause)
# The returned observation pairs a fresh _get_obs() result with the stored
# task instruction.
observation = {
"screenshot": self._get_obs(),
"instruction": self.instruction
}
reward = 0 # todo: Define reward calculation for each example
done = False # todo: Define episode termination condition for each example
info = {}
return observation, reward, done, info
def evaluate(self):
    """
    Evaluate whether the task is successfully completed.

    Each entry of ``self.evaluator["paths"]`` is copied into a fresh
    ``tmp/<uuid>/tmp.xlsx`` file. The local paths are then passed, keyed by
    their variable names, to the function registered in ``eval_funcs`` under
    ``self.evaluator["func"]``.

    Returns:
        Whatever the selected evaluation function returns.

    Raises:
        NotImplementedError: If a file description has a ``"type"`` other
            than ``"cloud_file"`` or ``"vm_file"``.
        requests.RequestException: If downloading a ``"cloud_file"`` fails,
            times out, or returns an HTTP error status.
    """

    def copy_file_to_local(_file_info):
        """Fetch one described file and return the local path of the copy."""
        random_uuid = str(uuid.uuid4())
        os.makedirs(os.path.join("tmp", random_uuid), exist_ok=True)
        _path = os.path.join("tmp", random_uuid, "tmp.xlsx")

        if _file_info["type"] == "cloud_file":
            url = _file_info["path"]
            # The timeout bounds the connect and each read, so a stalled
            # server cannot hang the evaluation; the context manager releases
            # the streamed connection even if writing fails.
            with requests.get(url, stream=True, timeout=60) as response:
                response.raise_for_status()
                with open(_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        if chunk:
                            f.write(chunk)
        elif _file_info["type"] == "vm_file":
            # fixme: stream this part maybe as well
            file = self.controller.get_file(_file_info["path"])
            with open(_path, "wb") as f:
                f.write(file)
        else:
            raise NotImplementedError(
                "Unsupported file type: {!r}".format(_file_info["type"])
            )

        return _path

    # todo: make this more flexible by refactoring
    eval_func = eval_funcs[self.evaluator["func"]]
    eval_func_vars = {
        var_name: copy_file_to_local(file_info)
        for var_name, file_info in self.evaluator["paths"].items()
    }

    return eval_func(**eval_func_vars)
def render(self, mode='rgb_array'):
if mode == 'rgb_array':
return self._get_obs()
@@ -205,4 +188,4 @@ class DesktopEnv(gym.Env):
raise ValueError('Unsupported render mode: {}'.format(mode))
# Stop the VM at self.path_to_vm.
# NOTE(review): the same `vmrun stop` call appears twice, once through the
# instance helper and once through the module-level helper; these look like
# two revisions of one line -- confirm that only one is meant to run.
def close(self):
self._execute_command(["vmrun", "stop", self.path_to_vm])
_execute_command(["vmrun", "stop", self.path_to_vm])