Adapt for Windows OS; refine README
This commit is contained in:
@@ -1,9 +1,11 @@
|
||||
from fabric import Connection
|
||||
from typing import List
|
||||
|
||||
|
||||
class XDoToolController:
    """Run ``xdotool`` commands on a remote machine through an SSH connection."""

    def __init__(self, ssh_connection: "Connection"):
        """Remember the connection used for every xdotool invocation.

        Args:
            ssh_connection: Connection object whose ``run(command, hide=True)``
                executes a shell command and returns a result exposing
                ``stdout``.
        """
        self.ssh_connection = ssh_connection

    def _execute_xdotool_command(self, command: List[str]) -> str:
        """Run ``xdotool`` with the given arguments against display ``:0``.

        Args:
            command: The xdotool arguments, one list item per argument
                (e.g. ``["mousemove", "100", "200"]``). A pre-built string is
                also accepted and is passed through unchanged.

        Returns:
            The command's standard output with surrounding whitespace removed.
        """
        import shlex

        if isinstance(command, str):
            arguments = command
        else:
            # Interpolating the list itself would send its repr
            # ("['mousemove', '100']") to the shell; join the items instead and
            # quote each one so arguments containing spaces stay intact.
            arguments = " ".join(shlex.quote(str(part)) for part in command)

        result = self.ssh_connection.run(f"DISPLAY=:0 xdotool {arguments}", hide=True)
        return result.stdout.strip()
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from enum import Enum
|
||||
from typing import Literal
|
||||
from typing import Literal, List, Tuple
|
||||
import subprocess
|
||||
from fabric import Connection
|
||||
import time
|
||||
@@ -22,19 +22,21 @@ class Action(Enum):
|
||||
KEY_UP = 6
|
||||
TYPE = 7
|
||||
|
||||
|
||||
VM_TYPE = Literal['ubuntu', 'windows']
|
||||
|
||||
|
||||
class DesktopEnv(gym.Env):
|
||||
"""DesktopEnv with OpenAI Gym interface."""
|
||||
|
||||
def __init__(self, path_to_vm: str, username: str, password: str,
|
||||
host: str, snapshot_path: str = "snapshot", vm_os: VM_TYPE = "ubuntu"):
|
||||
def __init__(self, path_to_vm: str, username: str, password: str,
|
||||
host: str, snapshot_path: str = "some_point_browser", vm_os: VM_TYPE = "ubuntu"):
|
||||
self.path_to_vm = path_to_vm
|
||||
self.username = username
|
||||
self.password = password
|
||||
self.host = host
|
||||
self.snapshot_path = snapshot_path
|
||||
|
||||
self.snapshot_path = snapshot_path # todo: handling the logic of snapshot directory
|
||||
|
||||
self.screen_width = 800
|
||||
self.screen_height = 800
|
||||
# Define the action and observation space
|
||||
@@ -51,13 +53,15 @@ class DesktopEnv(gym.Env):
|
||||
|
||||
# Additional setup
|
||||
self.metadata = {'render.modes': ['rgb_array']}
|
||||
|
||||
# Initialize emulator
|
||||
print("Initializing...")
|
||||
self._start_emulator()
|
||||
self._wait_for_emulator_load()
|
||||
|
||||
# set up controllers
|
||||
self.mouse_controller, self.keyboard_controller = self._create_controllers(vm_os)
|
||||
|
||||
def _create_controllers(self, vm_os: VM_TYPE) -> tuple[AbstractMouseController, AbstractKeyboardController]:
|
||||
def _create_controllers(self, vm_os: VM_TYPE) -> Tuple[AbstractMouseController, AbstractKeyboardController]:
|
||||
if vm_os == "ubuntu":
|
||||
ssh_connection = Connection(host=self.host, user=self.username, connect_kwargs={"password": self.password})
|
||||
mouse_controller = XDoToolMouseController(ssh_connection)
|
||||
@@ -67,33 +71,29 @@ class DesktopEnv(gym.Env):
|
||||
keyboard_controller = PythonKeyboardController(http_server=self.host)
|
||||
else:
|
||||
raise NotImplementedError(vm_os)
|
||||
|
||||
|
||||
return mouse_controller, keyboard_controller
|
||||
|
||||
def _start_emulator(self):
|
||||
self._execute_command(["vmrun", "start", self.path_to_vm])
|
||||
|
||||
def _wait_for_emulator_load(self):
|
||||
while True:
|
||||
try:
|
||||
output = subprocess.check_output(f"vmrun -T ws list", shell=True, stderr=subprocess.STDOUT)
|
||||
output = output.decode()
|
||||
if self.path_to_vm.lstrip("~/") in output:
|
||||
print("VM is running.")
|
||||
return
|
||||
break
|
||||
else:
|
||||
print("Waiting for VM to start...")
|
||||
print("Starting VM...")
|
||||
self._execute_command(["vmrun", "-T", "ws", "start", self.path_to_vm])
|
||||
time.sleep(5)
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Error executing command: {e.output.decode().strip()}")
|
||||
return
|
||||
|
||||
def _execute_command(self, command: list[str]) -> None:
|
||||
process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
def _execute_command(self, command: List[str]) -> None:
|
||||
process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
|
||||
stdout, stderr = process.communicate()
|
||||
if process.returncode != 0:
|
||||
print(f"Error executing command: {command}")
|
||||
print(stderr.decode())
|
||||
return None
|
||||
else:
|
||||
return stdout.decode()
|
||||
@@ -103,23 +103,27 @@ class DesktopEnv(gym.Env):
|
||||
|
||||
def _get_screenshot(self):
|
||||
image_path = "./screenshot.png"
|
||||
self._execute_command(["vmrun", "-T", "ws", "-gu", self.username, "-gp", self.password, "captureScreen", self.path_to_vm, image_path])
|
||||
self._execute_command(
|
||||
["vmrun", "-T", "ws", "-gu", self.username, "-gp", self.password, "captureScreen", self.path_to_vm,
|
||||
image_path])
|
||||
return image_path
|
||||
|
||||
|
||||
def _get_obs(self):
    """Take a screenshot and return its contents as a NumPy array.

    The image file is opened in a context manager and converted with
    ``np.array`` while it is still open.
    """
    path = self._get_screenshot()
    with Image.open(path) as screenshot:
        pixels = np.array(screenshot)
    return pixels
|
||||
|
||||
def reset(self):
    """Revert the VM to the configured snapshot, start it and observe it.

    Progress messages are printed before and after each stage.

    Returns:
        The observation produced by ``_get_obs`` after the emulator start
        call has returned.
    """
    snapshot = self.snapshot_path
    print("Resetting environment...")

    print("Reverting to snapshot to {}...".format(snapshot))
    revert_command = ["vmrun", "-T", "ws", "revertToSnapshot", self.path_to_vm, snapshot]
    self._execute_command(revert_command)

    print("Starting emulator...")
    self._start_emulator()
    print("Emulator started.")

    return self._get_obs()
|
||||
|
||||
def step(self, action):
|
||||
|
||||
Reference in New Issue
Block a user