diff --git a/README.md b/README.md
index f31f363..b2230f6 100644
--- a/README.md
+++ b/README.md
@@ -19,7 +19,8 @@
     3. `sudo ufw disable` (disable firewall - safe for local network, otherwise `sudo ufw allow ssh`)
     4. `ip a` - find ip address
     5. ssh username@
-5. Install screenshot tool
+    6. On host, run `ssh-copy-id @`
+5. Install screenshot tool (in vm)
     1. `sudo apt install imagemagick-6.q16hdri`
     2. `DISPLAY=:0 import -window root screenshot.png`
 6. Get screenshot
diff --git a/desktop_env/__init__.py b/desktop_env/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/desktop_env/envs/__init__.py b/desktop_env/envs/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/desktop_env/envs/desktop_env.py b/desktop_env/envs/desktop_env.py
new file mode 100644
index 0000000..bda6ed8
--- /dev/null
+++ b/desktop_env/envs/desktop_env.py
@@ -0,0 +1,213 @@
+"""Gymnasium environment that drives a VMware guest desktop over SSH."""
+
+from enum import Enum
+import os
+import shlex
+import subprocess
+from typing import Optional
+
+from fabric import Connection
+import gymnasium as gym
+from gymnasium import spaces
+import numpy as np
+from PIL import Image
+
+
+class Action(Enum):
+    """Kinds of input events that ``DesktopEnv.step`` can perform."""
+
+    CLICK = 0
+    MOUSE_DOWN = 1
+    MOUSE_UP = 2
+    MOUSE_MOVE = 3
+    KEY = 4
+    TYPE = 5
+
+
+class MouseClick(Enum):
+    """Mouse buttons; the value is sent to xdotool as the button number."""
+
+    LEFT = 1
+    MIDDLE = 2
+    RIGHT = 3
+    WHEEL_UP = 4
+    WHEEL_DOWN = 5
+
+
+class DesktopEnv(gym.Env):
+    """DesktopEnv with OpenAI Gym interface.
+
+    The VM is controlled on the host with ``vmrun``; input events and
+    screenshots go through an SSH connection into the guest.
+    """
+
+    def __init__(self, path_to_vm: str, username: str, password: str,
+                 host: str, snapshot_path: str = "snapshot"):
+        """Store the connection settings, build the spaces and start the VM.
+
+        Args:
+            path_to_vm: Path to the ``.vmx`` file handed to ``vmrun``.
+            username: SSH user inside the guest.
+            password: SSH password for ``username``.
+            host: Address of the guest as reachable from the host.
+            snapshot_path: Snapshot name used by ``_save_state`` and ``reset``.
+        """
+        # subprocess runs vmrun without a shell, so "~" must be expanded here.
+        self.path_to_vm = os.path.expanduser(path_to_vm)
+        self.username = username
+        self.password = password
+        self.host = host
+        self.snapshot_path = snapshot_path
+        self.ssh_connection = Connection(host=self.host, user=self.username, connect_kwargs={"password": password})
+
+        self.screen_width = 800
+        self.screen_height = 800
+        # Define the action and observation space
+        self.action_space = spaces.Dict({
+            "action_type": spaces.Discrete(len(Action)),
+            # MouseClick values start at 1, so the space has to start there too.
+            "click_type": spaces.Discrete(len(MouseClick), start=1),
+            "x": spaces.Discrete(self.screen_width),
+            "y": spaces.Discrete(self.screen_height),
+            "key": spaces.MultiDiscrete([128] * 10),  # max 10 characters, ASCII
+            "text": spaces.MultiDiscrete([128] * 10)  # max 10 characters, ASCII
+        })
+
+        # np.array of a PIL image puts rows (height) before columns (width).
+        self.observation_space = spaces.Box(low=0, high=255, shape=(self.screen_height, self.screen_width, 3), dtype=np.uint8)
+
+        # Additional setup
+        self.metadata = {'render.modes': ['rgb_array']}
+        self._start_emulator()
+
+    def _start_emulator(self):
+        """Power on the VM with ``vmrun start``."""
+        self._execute_command(["vmrun", "start", self.path_to_vm])
+
+    def _execute_command(self, command: list[str]) -> Optional[str]:
+        """Run ``command`` on the host and return its decoded stdout.
+
+        On a non-zero exit status the command and its stderr are printed and
+        ``None`` is returned instead.
+        """
+        process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+        stdout, stderr = process.communicate()
+        if process.returncode != 0:
+            print(f"Error executing command: {command}")
+            print(stderr.decode())
+            return None
+        else:
+            return stdout.decode()
+
+    def _execute_xdotool_command(self, command: str) -> str:
+        """Run ``xdotool <command>`` on guest display ``:0`` and return its stdout.
+
+        ``command`` is interpolated into a remote shell line, so callers must
+        shell-quote any free-form argument.
+        """
+        result = self.ssh_connection.run(f"DISPLAY=:0 xdotool {command}", hide=True)
+        return result.stdout.strip()
+
+    def _save_state(self):
+        """Snapshot the VM under the name ``self.snapshot_path``."""
+        # vmrun expects the .vmx path before the snapshot name.
+        self._execute_command(["vmrun", "snapshot", self.path_to_vm, self.snapshot_path])
+
+    def _click(self, click: MouseClick):
+        """Send ``xdotool click`` for the given button."""
+        self._execute_xdotool_command(f"click {click.value}")
+
+    def _mousedown(self, click: MouseClick):
+        """Send ``xdotool mousedown`` for the given button."""
+        self._execute_xdotool_command(f"mousedown {click.value}")
+
+    def _mouseup(self, click: MouseClick):
+        """Send ``xdotool mouseup`` for the given button."""
+        self._execute_xdotool_command(f"mouseup {click.value}")
+
+    def _mouse_move(self, x: int, y: int):
+        """Move the pointer to screen coordinates ``(x, y)``."""
+        self._execute_xdotool_command(f"mousemove {x} {y}")
+
+    def _key(self, key: str):
+        """Send one or more whitespace-separated key names via ``xdotool key``."""
+        # Quote each name separately so several keys can still be passed at once.
+        keys = " ".join(shlex.quote(name) for name in key.split())
+        self._execute_xdotool_command(f"key {keys}")
+
+    def _type(self, text: str):
+        """Send ``text`` to the guest via ``xdotool type``."""
+        # Quote so spaces and shell metacharacters reach xdotool verbatim.
+        self._execute_xdotool_command(f"type {shlex.quote(text)}")
+
+    def _get_screenshot(self):
+        """Capture the guest screen and copy it to ``./screenshot.png``.
+
+        Returns:
+            Path of the local copy of the screenshot.
+        """
+        image_path = "./screenshot.png"
+        self.ssh_connection.run("DISPLAY=:0 import -window root screenshot.png")
+        # Fetch from the configured guest rather than a hard-coded address.
+        self._execute_command(["scp", f"{self.username}@{self.host}:~/screenshot.png", image_path])
+        self.ssh_connection.run("rm -rf ~/screenshot.png")
+        return image_path
+
+    def _get_obs(self):
+        """Return the current guest screen as a NumPy array."""
+        screenshot_image_path = self._get_screenshot()
+        with Image.open(screenshot_image_path) as img:
+            return np.array(img)
+
+    def reset(self):
+        """Revert the VM to the stored snapshot and return a fresh observation."""
+        # vmrun expects the .vmx path before the snapshot name.
+        self._execute_command(["vmrun", "revertToSnapshot", self.path_to_vm, self.snapshot_path])
+        observation = self._get_obs()
+
+        return observation
+
+    def step(self, action):
+        """Perform ``action`` in the guest and observe the result.
+
+        Args:
+            action: Mapping with ``action_type`` (an ``Action`` value) plus the
+                fields that type needs: ``click_type``, ``x``/``y``, or
+                ``key``/``text`` as sequences of character codes.
+
+        Returns:
+            Tuple ``(observation, reward, done, info)``.
+        """
+        action_type = Action(action['action_type'])
+        if action_type == Action.CLICK:
+            self._click(MouseClick(action['click_type']))
+        elif action_type == Action.MOUSE_DOWN:
+            self._mousedown(MouseClick(action['click_type']))
+        elif action_type == Action.MOUSE_UP:
+            self._mouseup(MouseClick(action['click_type']))
+        elif action_type == Action.MOUSE_MOVE:
+            self._mouse_move(action['x'], action['y'])
+        elif action_type == Action.KEY:
+            key_sequence = ''.join(map(chr, action['key']))  # Convert integer array to string
+            self._key(key_sequence)
+        elif action_type == Action.TYPE:
+            text = ''.join(map(chr, action['text']))  # Convert integer array to string
+            self._type(text)
+
+        # Capture new state
+        observation = self._get_obs()
+        reward = 0  # Define reward calculation
+        done = False  # Define episode termination condition
+        info = {}
+        return observation, reward, done, info
+
+    def render(self, mode='rgb_array'):
+        """Return the current screen for ``rgb_array``; reject other modes."""
+        if mode == 'rgb_array':
+            return self._get_obs()
+        else:
+            raise ValueError('Unsupported render mode: {}'.format(mode))
+
+    def close(self):
+        """Power off the VM with ``vmrun stop``."""
+        self._execute_command(["vmrun", "stop", self.path_to_vm])
diff --git a/main.py b/main.py
index 8209884..2e3ca27 100644
--- a/main.py
+++ b/main.py
@@ -1,14 +1,81 @@
-from controller import Controller, Action, MouseClick
+from pprint import pprint
+from desktop_env.envs.desktop_env import DesktopEnv, Action, MouseClick
 
-controller = Controller(vm_name="KUbuntu-23.10", username="username", password="password", host="192.168.56.101")
+def get_human_action():
+    """
+    Prompts the human player for an action and returns a structured action.
+
+    Re-prompts until the entered action and click names are valid.
+    """
+    print("\nAvailable actions:", [action.name for action in Action])
+    action_type = None
+    while action_type is None:
+        # Strip the typed answer (not the prompt) and re-prompt on bad names.
+        name = input("Enter the type of action: ").strip().upper()
+        try:
+            action_type = Action[name].value
+        except KeyError:
+            print("Unknown action:", name)
 
-input("enter to continue")
-img = controller.get_state()
-print(img)
-input("enter to continue")
-controller.step(action=Action.MOUSE_MOVE, x=100, y=100)
-input("enter to continue")
-controller.step(action=Action.CLICK, click=MouseClick.LEFT)
-input("enter to continue")
-controller.step(action=Action.TYPE, text="hello world")
+    action = {"action_type": action_type}
+    if action_type in (Action.CLICK.value, Action.MOUSE_DOWN.value, Action.MOUSE_UP.value):
+        print("\n Available clicks:", [action.name for action in MouseClick])
+        click_type = None
+        while click_type is None:
+            name = input("Enter click type: ").strip().upper()
+            try:
+                click_type = MouseClick[name].value
+            except KeyError:
+                print("Unknown click type:", name)
+        action["click_type"] = click_type
+
+    if action_type == Action.MOUSE_MOVE.value:
+        x = int(input("Enter x-coordinate for mouse move: "))
+        y = int(input("Enter y-coordinate for mouse move: "))
+        action["x"] = x
+        action["y"] = y
+
+    if action_type == Action.KEY.value:
+        key = input("Enter the key to press: ")
+        action["key"] = [ord(c) for c in key]
+
+    if action_type == Action.TYPE.value:
+        text = input("Enter the text to type: ")
+        action["text"] = [ord(c) for c in text]
+
+    return action
+
+
+def human_agent():
+    """
+    Runs the Gym environment with human input.
+
+    The VM is stopped on exit even if the loop raises.
+    """
+    env = DesktopEnv(path_to_vm="~/vmware/Ubuntu 64-bit/Ubuntu 64-bit.vmx",
+                     username="user",
+                     password="password",
+                     host="192.168.7.128")
+    try:
+        observation = env.reset()
+        done = False
+
+        while not done:
+            action = get_human_action()
+            observation, reward, done, info = env.step(action)
+            print("Observation:", observation)
+            print("Reward:", reward)
+            print("Info:", info)
+
+            print("================================\n")
+
+            if done:
+                print("The episode is done.")
+                break
+    finally:
+        env.close()
+        print("Environment closed.")
+
 
+if __name__ == "__main__":
+    human_agent()
diff --git a/requirements.txt b/requirements.txt
index 4b72cae..e881093 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,4 @@
 numpy
 Pillow
 fabric
+gymnasium