Adapt for Windows OS; refine README

This commit is contained in:
Timothyxxx
2023-11-27 00:29:09 +08:00
parent 6dee58252e
commit 8c0525c20e
6 changed files with 66 additions and 84 deletions

View File

@@ -1,9 +1,11 @@
from fabric import Connection
from typing import List
class XDoToolController:
    """Run ``xdotool`` commands on a remote machine through an SSH connection."""

    def __init__(self, ssh_connection: "Connection"):
        """Keep the connection that every later xdotool call goes through.

        Args:
            ssh_connection: Connection whose ``run(command, hide=...)`` result
                exposes a ``stdout`` attribute.
        """
        self.ssh_connection = ssh_connection

    def _execute_xdotool_command(self, command: List[str]) -> str:
        """Run one xdotool command with ``DISPLAY=:0`` and return its output.

        Args:
            command: xdotool arguments, e.g. ``["key", "a"]``. A ready-made
                string is also accepted and passed through unchanged.

        Returns:
            The command's standard output with surrounding whitespace removed.
        """
        # Interpolating the list itself would send its Python repr
        # ("['key', 'a']") to the remote shell, so join the arguments first.
        arguments = command if isinstance(command, str) else " ".join(command)
        result = self.ssh_connection.run(f"DISPLAY=:0 xdotool {arguments}", hide=True)
        return result.stdout.strip()

View File

@@ -1,5 +1,5 @@
from enum import Enum
from typing import Literal
from typing import Literal, List, Tuple
import subprocess
from fabric import Connection
import time
@@ -22,19 +22,21 @@ class Action(Enum):
KEY_UP = 6
TYPE = 7
VM_TYPE = Literal['ubuntu', 'windows']
class DesktopEnv(gym.Env):
"""DesktopEnv with OpenAI Gym interface."""
def __init__(self, path_to_vm: str, username: str, password: str,
host: str, snapshot_path: str = "snapshot", vm_os: VM_TYPE = "ubuntu"):
def __init__(self, path_to_vm: str, username: str, password: str,
host: str, snapshot_path: str = "some_point_browser", vm_os: VM_TYPE = "ubuntu"):
self.path_to_vm = path_to_vm
self.username = username
self.password = password
self.host = host
self.snapshot_path = snapshot_path
self.snapshot_path = snapshot_path # todo: handling the logic of snapshot directory
self.screen_width = 800
self.screen_height = 800
# Define the action and observation space
@@ -51,13 +53,15 @@ class DesktopEnv(gym.Env):
# Additional setup
self.metadata = {'render.modes': ['rgb_array']}
# Initialize emulator
print("Initializing...")
self._start_emulator()
self._wait_for_emulator_load()
# set up controllers
self.mouse_controller, self.keyboard_controller = self._create_controllers(vm_os)
def _create_controllers(self, vm_os: VM_TYPE) -> tuple[AbstractMouseController, AbstractKeyboardController]:
def _create_controllers(self, vm_os: VM_TYPE) -> Tuple[AbstractMouseController, AbstractKeyboardController]:
if vm_os == "ubuntu":
ssh_connection = Connection(host=self.host, user=self.username, connect_kwargs={"password": self.password})
mouse_controller = XDoToolMouseController(ssh_connection)
@@ -67,33 +71,29 @@ class DesktopEnv(gym.Env):
keyboard_controller = PythonKeyboardController(http_server=self.host)
else:
raise NotImplementedError(vm_os)
return mouse_controller, keyboard_controller
def _start_emulator(self):
self._execute_command(["vmrun", "start", self.path_to_vm])
def _wait_for_emulator_load(self):
while True:
try:
output = subprocess.check_output(f"vmrun -T ws list", shell=True, stderr=subprocess.STDOUT)
output = output.decode()
if self.path_to_vm.lstrip("~/") in output:
print("VM is running.")
return
break
else:
print("Waiting for VM to start...")
print("Starting VM...")
self._execute_command(["vmrun", "-T", "ws", "start", self.path_to_vm])
time.sleep(5)
except subprocess.CalledProcessError as e:
print(f"Error executing command: {e.output.decode().strip()}")
return
def _execute_command(self, command: list[str]) -> None:
process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
def _execute_command(self, command: List[str]) -> None:
process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
stdout, stderr = process.communicate()
if process.returncode != 0:
print(f"Error executing command: {command}")
print(stderr.decode())
return None
else:
return stdout.decode()
@@ -103,23 +103,27 @@ class DesktopEnv(gym.Env):
def _get_screenshot(self):
image_path = "./screenshot.png"
self._execute_command(["vmrun", "-T", "ws", "-gu", self.username, "-gp", self.password, "captureScreen", self.path_to_vm, image_path])
self._execute_command(
["vmrun", "-T", "ws", "-gu", self.username, "-gp", self.password, "captureScreen", self.path_to_vm,
image_path])
return image_path
def _get_obs(self):
    """Take a screenshot and return its contents as a NumPy array."""
    path = self._get_screenshot()
    # The context manager closes the image file once the array is built.
    with Image.open(path) as screenshot:
        pixels = np.array(screenshot)
    return pixels
def reset(self):
    """Revert the VM to ``self.snapshot_path``, start it and return an observation.

    Returns:
        Whatever ``self._get_obs()`` produces after the emulator is started.
    """
    print("Resetting environment...")
    print("Reverting to snapshot to {}...".format(self.snapshot_path))
    revert_cmd = ["vmrun", "-T", "ws", "revertToSnapshot", self.path_to_vm, self.snapshot_path]
    self._execute_command(revert_cmd)

    print("Starting emulator...")
    self._start_emulator()
    print("Emulator started.")

    return self._get_obs()
def step(self, action):