Refactor with pyautogui
This commit is contained in:
@@ -1,23 +1,6 @@
|
|||||||
# Server Setup Guide
|
# Server Setup Guide
|
||||||
|
|
||||||
- [Linux](#linux)
|
1. Copy and paste the file `server/main.py` to the windows vm
|
||||||
- [Windows](#windows)
|
2. Install the requirements `pip install -r requirements.txt`
|
||||||
|
|
||||||
## Linux
|
|
||||||
|
|
||||||
<https://averagelinuxuser.com/ssh-into-virtualbox/>
|
|
||||||
|
|
||||||
1. `sudo apt install openssh-server`
|
|
||||||
2. `sudo systemctl enable ssh --now`
|
|
||||||
3. `sudo ufw disable` (disable firewall - safe for local network, otherwise `sudo ufw allow ssh`)
|
|
||||||
4. `ip a` - find ip address
|
|
||||||
5. ssh username@<ip_address>
|
|
||||||
6. On host, run `ssh-copy-id <username>@<ip_address>`
|
|
||||||
|
|
||||||
|
|
||||||
## Windows
|
|
||||||
|
|
||||||
1. Copy and paste the file `windows_server/main.py` to the windows vm
|
|
||||||
2. Make sure `mouse` and `keyboard` are installed
|
|
||||||
3. Run the file `python main.py`
|
3. Run the file `python main.py`
|
||||||
4. `ipconfig /all` and find the ip address
|
4. `ipconfig /all` and find the ip address
|
||||||
@@ -1,56 +0,0 @@
|
|||||||
from abc import ABC, abstractmethod
|
|
||||||
from fabric import Connection
|
|
||||||
|
|
||||||
from .xdotool import XDoToolController
|
|
||||||
from .python import PythonController
|
|
||||||
|
|
||||||
class AbstractKeyboardController(ABC):
|
|
||||||
@abstractmethod
|
|
||||||
def type(self, text: str):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def key(self, key: str):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def key_down(self, key: str):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def key_up(self, key: str):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
class XDoToolKeyboardController(AbstractKeyboardController, XDoToolController):
|
|
||||||
def __init__(self, ssh_connection: Connection):
|
|
||||||
super().__init__(ssh_connection=ssh_connection)
|
|
||||||
|
|
||||||
def type(self, text: str):
|
|
||||||
self._execute_xdotool_command(f"type {text}")
|
|
||||||
|
|
||||||
def key(self, key: str):
|
|
||||||
self._execute_xdotool_command(f"key {key}")
|
|
||||||
|
|
||||||
def key_down(self, key: str):
|
|
||||||
self._execute_xdotool_command(f"keydown {key}")
|
|
||||||
|
|
||||||
def key_up(self, key: str):
|
|
||||||
self._execute_xdotool_command(f"keyup {key}")
|
|
||||||
|
|
||||||
class PythonKeyboardController(AbstractKeyboardController, PythonController):
|
|
||||||
def __init__(self, http_server: str):
|
|
||||||
super().__init__(http_server=http_server)
|
|
||||||
self.command = "python -c \"import keyboard; {command}\""
|
|
||||||
|
|
||||||
def type(self, text: str):
|
|
||||||
self._execute_python_command(self.command.format(command=f"keyboard.write('{text}')"))
|
|
||||||
|
|
||||||
def key(self, key: str):
|
|
||||||
self._execute_python_command(self.command.format(command=f"keyboard.press_and_release('{key}')"))
|
|
||||||
|
|
||||||
def key_down(self, key: str):
|
|
||||||
self._execute_python_command(self.command.format(command=f"keyboard.press('{key}')"))
|
|
||||||
|
|
||||||
def key_up(self, key: str):
|
|
||||||
self._execute_python_command(self.command.format(command=f"keyboard.release('{key}')"))
|
|
||||||
@@ -1,162 +0,0 @@
|
|||||||
from enum import Enum
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from fabric import Connection
|
|
||||||
import re
|
|
||||||
|
|
||||||
from .xdotool import XDoToolController
|
|
||||||
from .python import PythonController
|
|
||||||
class MouseClick(Enum):
|
|
||||||
LEFT = 1
|
|
||||||
MIDDLE = 2
|
|
||||||
RIGHT = 3
|
|
||||||
WHEEL_UP = 4
|
|
||||||
WHEEL_DOWN = 5
|
|
||||||
|
|
||||||
class AbstractMouseController(ABC):
|
|
||||||
@abstractmethod
|
|
||||||
def get_mouse(self):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def mouse_move(self, x: int, y: int):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def left_down(self):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def left_up(self):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def left_click(self):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def middle_down(self):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def middle_up(self):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def middle_click(self):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def right_down(self):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def right_up(self):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def right_click(self):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def scroll_up(self):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def scroll_down(self):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
class XDoToolMouseController(AbstractMouseController, XDoToolController):
|
|
||||||
def __init__(self, ssh_connection: Connection):
|
|
||||||
super().__init__(ssh_connection=ssh_connection)
|
|
||||||
|
|
||||||
def get_mouse(self):
|
|
||||||
output = self._execute_xdotool_command(f"")
|
|
||||||
parts = output.split(" ")
|
|
||||||
x = int(parts[0].split(":")[1])
|
|
||||||
y = int(parts[1].split(":")[1])
|
|
||||||
return x, y
|
|
||||||
|
|
||||||
def mouse_move(self, x: int, y: int):
|
|
||||||
self._execute_xdotool_command(f"mousemove {x} {y}")
|
|
||||||
|
|
||||||
def left_down(self):
|
|
||||||
self._execute_xdotool_command(f"mousedown 1")
|
|
||||||
|
|
||||||
def left_up(self):
|
|
||||||
self._execute_xdotool_command(f"mouseup 1")
|
|
||||||
|
|
||||||
def left_click(self):
|
|
||||||
self._execute_xdotool_command(f"click 1")
|
|
||||||
|
|
||||||
def middle_down(self):
|
|
||||||
self._execute_xdotool_command(f"mousedown 2")
|
|
||||||
|
|
||||||
def middle_up(self):
|
|
||||||
self._execute_xdotool_command(f"mouseup 2")
|
|
||||||
|
|
||||||
def middle_click(self):
|
|
||||||
self._execute_xdotool_command(f"click 2")
|
|
||||||
|
|
||||||
def right_down(self):
|
|
||||||
self._execute_xdotool_command(f"mousedown 3")
|
|
||||||
|
|
||||||
def right_up(self):
|
|
||||||
self._execute_xdotool_command(f"mouseup 3")
|
|
||||||
|
|
||||||
def right_click(self):
|
|
||||||
self._execute_xdotool_command(f"click 3")
|
|
||||||
|
|
||||||
def scroll_up(self):
|
|
||||||
self._execute_xdotool_command(f"click 4")
|
|
||||||
|
|
||||||
def scroll_down(self):
|
|
||||||
self._execute_xdotool_command(f"click 5")
|
|
||||||
|
|
||||||
class PythonMouseController(AbstractMouseController, PythonController):
|
|
||||||
def __init__(self, http_server: str):
|
|
||||||
super().__init__(http_server=http_server)
|
|
||||||
self.command = "py -c \"import mouse; {command}\""
|
|
||||||
|
|
||||||
def get_mouse(self):
|
|
||||||
response = self._execute_python_command(self.command.format(command=f"print(mouse.get_position())"))
|
|
||||||
numbers = re.findall(r'-?\d+', response["output"])
|
|
||||||
x, y = map(int, numbers)
|
|
||||||
return x, y
|
|
||||||
|
|
||||||
def mouse_move(self, x: int, y: int):
|
|
||||||
self._execute_python_command(self.command.format(command=f"mouse.move({x}, {y})"))
|
|
||||||
|
|
||||||
def left_down(self):
|
|
||||||
self._execute_python_command(self.command.format(command="mouse.press(button='left')"))
|
|
||||||
|
|
||||||
def left_up(self):
|
|
||||||
self._execute_python_command(self.command.format(command="mouse.release(button='left')"))
|
|
||||||
|
|
||||||
def left_click(self):
|
|
||||||
self._execute_python_command(self.command.format(command="mouse.click(button='left')"))
|
|
||||||
|
|
||||||
def middle_down(self):
|
|
||||||
self._execute_python_command(self.command.format(command="mouse.press(button='middle')"))
|
|
||||||
|
|
||||||
def middle_up(self):
|
|
||||||
self._execute_python_command(self.command.format(command="mouse.release(button='middle')"))
|
|
||||||
|
|
||||||
def middle_click(self):
|
|
||||||
self._execute_python_command(self.command.format(command="mouse.click(button='middle')"))
|
|
||||||
|
|
||||||
def right_down(self):
|
|
||||||
self._execute_python_command(self.command.format(command="mouse.press(button='right')"))
|
|
||||||
|
|
||||||
def right_up(self):
|
|
||||||
self._execute_python_command(self.command.format(command="mouse.release(button='right')"))
|
|
||||||
|
|
||||||
def right_click(self):
|
|
||||||
self._execute_python_command(self.command.format(command="mouse.click(button='right')"))
|
|
||||||
|
|
||||||
def scroll_up(self):
|
|
||||||
self._execute_python_command(self.command.format(command="mouse.wheel(10)"))
|
|
||||||
|
|
||||||
def scroll_down(self):
|
|
||||||
self._execute_python_command(self.command.format(command="mouse.wheel(-10)"))
|
|
||||||
@@ -1,15 +1,31 @@
|
|||||||
import requests
|
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
class PythonController:
|
class PythonController:
|
||||||
def __init__(self, http_server: str):
|
def __init__(self, http_server: str, pkgs_prefix: str = "py -c \"import pyautogui; {command}\""):
|
||||||
self.http_server = http_server
|
self.http_server = http_server
|
||||||
|
self.pkgs_prefix = pkgs_prefix # fixme: this is a hacky way to execute python commands. fix it and combine it with installation of packages
|
||||||
|
|
||||||
def _execute_python_command(self, command: str) -> None:
|
def get_screenshot(self):
|
||||||
payload = json.dumps({
|
"""
|
||||||
"command": command
|
Gets a screenshot from the server. With the cursor.
|
||||||
})
|
"""
|
||||||
|
response = requests.get(self.http_server + "/screenshot")
|
||||||
|
if response.status_code == 200:
|
||||||
|
return response.content
|
||||||
|
else:
|
||||||
|
print("Failed to get screenshot. Status code:", response.status_code)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def execute_python_command(self, command: str) -> None:
|
||||||
|
"""
|
||||||
|
Executes a python command on the server.
|
||||||
|
It can be used to execute the pyautogui commands, or... any other python command. who knows?
|
||||||
|
"""
|
||||||
|
command = self.pkgs_prefix.format(command=command)
|
||||||
|
payload = json.dumps({"command": command})
|
||||||
headers = {
|
headers = {
|
||||||
'Content-Type': 'application/json'
|
'Content-Type': 'application/json'
|
||||||
}
|
}
|
||||||
@@ -23,15 +39,3 @@ class PythonController:
|
|||||||
return response.json()
|
return response.json()
|
||||||
except requests.exceptions.RequestException as e:
|
except requests.exceptions.RequestException as e:
|
||||||
print("An error occurred while trying to execute the command:", e)
|
print("An error occurred while trying to execute the command:", e)
|
||||||
|
|
||||||
|
|
||||||
# example usage
|
|
||||||
if __name__ == '__main__':
|
|
||||||
# replace with your actual server URL of the vm
|
|
||||||
server_url = "http://192.168.7.129:5000"
|
|
||||||
controller = PythonController(server_url)
|
|
||||||
|
|
||||||
# example commands
|
|
||||||
python_command = "python -c \"import keyboard; keyboard.write('hello world')\""
|
|
||||||
python_command = "python -c \"import mouse; mouse.move(100,100);mouse.right_click()\""
|
|
||||||
controller._execute_python_command(python_command)
|
|
||||||
|
|||||||
@@ -1,11 +0,0 @@
|
|||||||
from fabric import Connection
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
|
|
||||||
class XDoToolController:
|
|
||||||
def __init__(self, ssh_connection: Connection):
|
|
||||||
self.ssh_connection = ssh_connection
|
|
||||||
|
|
||||||
def _execute_xdotool_command(self, command: List[str]) -> None:
|
|
||||||
result = self.ssh_connection.run(f"DISPLAY=:0 xdotool {command}", hide=True)
|
|
||||||
return result.stdout.strip()
|
|
||||||
@@ -1,34 +1,20 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from enum import Enum
|
|
||||||
from typing import Literal, List, Tuple
|
|
||||||
import subprocess
|
import subprocess
|
||||||
from fabric import Connection
|
|
||||||
import time
|
import time
|
||||||
|
import uuid
|
||||||
|
from typing import List
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
from gymnasium import spaces
|
|
||||||
import numpy as np
|
|
||||||
import uuid
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
from desktop_env.controllers.mouse import MouseClick, AbstractMouseController, XDoToolMouseController, \
|
from desktop_env.controllers.python import PythonController
|
||||||
PythonMouseController
|
|
||||||
from desktop_env.controllers.keyboard import AbstractKeyboardController, XDoToolKeyboardController, \
|
|
||||||
PythonKeyboardController
|
|
||||||
|
|
||||||
|
|
||||||
class Action(Enum):
|
def _execute_command(command: List[str]) -> None:
|
||||||
CLICK = 0
|
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=60, text=True)
|
||||||
MOUSE_DOWN = 1
|
if result.returncode != 0:
|
||||||
MOUSE_UP = 2
|
raise Exception("\033[91m" + result.stdout + result.stderr + "\033[0m")
|
||||||
MOUSE_MOVE = 3
|
|
||||||
KEY = 4
|
|
||||||
KEY_DOWN = 5
|
|
||||||
KEY_UP = 6
|
|
||||||
TYPE = 7
|
|
||||||
|
|
||||||
|
|
||||||
VM_TYPE = Literal['ubuntu', 'windows']
|
|
||||||
|
|
||||||
|
|
||||||
class DesktopEnv(gym.Env):
|
class DesktopEnv(gym.Env):
|
||||||
@@ -37,67 +23,20 @@ class DesktopEnv(gym.Env):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
path_to_vm: str,
|
path_to_vm: str,
|
||||||
username: str,
|
|
||||||
password: str = None,
|
|
||||||
host: str = "192.168.7.128:5000",
|
host: str = "192.168.7.128:5000",
|
||||||
snapshot_path: str = "base",
|
snapshot_path: str = "base",
|
||||||
vm_os: VM_TYPE = "ubuntu"
|
|
||||||
):
|
):
|
||||||
|
# Initialize environment variables
|
||||||
# The path to the vmx file of your vm
|
|
||||||
self.path_to_vm = path_to_vm
|
self.path_to_vm = path_to_vm
|
||||||
|
|
||||||
# username and password for your vm
|
|
||||||
self.username = username
|
|
||||||
self.password = password
|
|
||||||
|
|
||||||
self.host = host
|
self.host = host
|
||||||
self.snapshot_path = snapshot_path # todo: handling the logic of snapshot directory
|
self.snapshot_path = snapshot_path # todo: handling the logic of snapshot directory
|
||||||
|
|
||||||
# Initialize emulator
|
# Initialize emulator and controller
|
||||||
print("Initializing...")
|
print("Initializing...")
|
||||||
self._start_emulator()
|
self._start_emulator()
|
||||||
|
self.controller = PythonController(http_server=self.host)
|
||||||
|
|
||||||
# set up controllers
|
# todo: define the action space and the observation space as gym did
|
||||||
self.mouse_controller, self.keyboard_controller = self._create_controllers(vm_os)
|
|
||||||
|
|
||||||
# Get the screen size
|
|
||||||
self.screen_width, self.screen_height = self._get_screensize()
|
|
||||||
|
|
||||||
# Define the action and observation space
|
|
||||||
self.action_space = spaces.Dict({
|
|
||||||
"action_type": spaces.Discrete(len(Action)),
|
|
||||||
"click_type": spaces.Discrete(len(MouseClick)),
|
|
||||||
"x": spaces.Discrete(self.screen_width),
|
|
||||||
"y": spaces.Discrete(self.screen_height),
|
|
||||||
"key": spaces.MultiDiscrete([128] * 10), # max 10 characters, ASCII
|
|
||||||
"text": spaces.MultiDiscrete([128] * 10) # max 10 characters, ASCII
|
|
||||||
})
|
|
||||||
|
|
||||||
self.observation_space = spaces.Box(low=0, high=255, shape=(self.screen_width, self.screen_height, 3),
|
|
||||||
dtype=np.uint8)
|
|
||||||
|
|
||||||
# Additional setup
|
|
||||||
self.metadata = {'render.modes': ['rgb_array']}
|
|
||||||
|
|
||||||
|
|
||||||
def _get_screensize(self):
|
|
||||||
screenshot_path = self._get_obs()
|
|
||||||
img = Image.open(screenshot_path)
|
|
||||||
return img.size
|
|
||||||
|
|
||||||
def _create_controllers(self, vm_os: VM_TYPE) -> Tuple[AbstractMouseController, AbstractKeyboardController]:
|
|
||||||
if vm_os == "ubuntu":
|
|
||||||
ssh_connection = Connection(host=self.host, user=self.username, connect_kwargs={"password": self.password})
|
|
||||||
mouse_controller = XDoToolMouseController(ssh_connection)
|
|
||||||
keyboard_controller = XDoToolKeyboardController(ssh_connection)
|
|
||||||
elif vm_os == "windows":
|
|
||||||
mouse_controller = PythonMouseController(http_server=self.host)
|
|
||||||
keyboard_controller = PythonKeyboardController(http_server=self.host)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(vm_os)
|
|
||||||
|
|
||||||
return mouse_controller, keyboard_controller
|
|
||||||
|
|
||||||
def _start_emulator(self):
|
def _start_emulator(self):
|
||||||
while True:
|
while True:
|
||||||
@@ -109,52 +48,35 @@ class DesktopEnv(gym.Env):
|
|||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
print("Starting VM...")
|
print("Starting VM...")
|
||||||
self._execute_command(["vmrun", "-T", "ws", "start", self.path_to_vm])
|
_execute_command(["vmrun", "-T", "ws", "start", self.path_to_vm])
|
||||||
time.sleep(10)
|
time.sleep(10)
|
||||||
except subprocess.CalledProcessError as e:
|
except subprocess.CalledProcessError as e:
|
||||||
print(f"Error executing command: {e.output.decode().strip()}")
|
print(f"Error executing command: {e.output.decode().strip()}")
|
||||||
|
|
||||||
def _execute_command(self, command: List[str]) -> None:
|
|
||||||
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=60, text=True)
|
|
||||||
if result.returncode != 0:
|
|
||||||
raise Exception("\033[91m" + result.stdout + result.stderr + "\033[0m")
|
|
||||||
|
|
||||||
def _save_state(self):
|
def _save_state(self):
|
||||||
self._execute_command(["vmrun", "-T", "ws" "snapshot", self.path_to_vm, self.snapshot_path])
|
_execute_command(["vmrun", "-T", "ws" "snapshot", self.path_to_vm, self.snapshot_path])
|
||||||
|
|
||||||
def _get_screenshot(self):
|
def _get_screenshot(self):
|
||||||
random_uuid = str(uuid.uuid4())
|
random_uuid = str(uuid.uuid4())
|
||||||
os.makedirs(os.path.join("tmp", random_uuid), exist_ok=True)
|
os.makedirs(os.path.join("tmp", random_uuid), exist_ok=True)
|
||||||
image_path = os.path.join("tmp", random_uuid, "screenshot.png")
|
image_path = os.path.join("tmp", random_uuid, "screenshot.png")
|
||||||
|
|
||||||
if self.password:
|
# Get the screenshot and save to the image_path
|
||||||
self._execute_command(
|
screenshot = self.controller.get_screenshot()
|
||||||
["vmrun", "-T", "ws", "-gu", self.username, "-gp", self.password, "captureScreen", self.path_to_vm, image_path])
|
with open(image_path, "wb") as f:
|
||||||
else:
|
f.write(screenshot)
|
||||||
self._execute_command(
|
|
||||||
["vmrun", "-T", "ws", "-gu", self.username, "captureScreen", self.path_to_vm, image_path])
|
|
||||||
|
|
||||||
return image_path
|
return image_path
|
||||||
|
|
||||||
def _get_obs(self):
|
def _get_obs(self):
|
||||||
screenshot_image_path = self._get_screenshot()
|
screenshot_image_path = self._get_screenshot()
|
||||||
self._add_cursor(screenshot_image_path)
|
|
||||||
return screenshot_image_path
|
return screenshot_image_path
|
||||||
|
|
||||||
def _add_cursor(self, img_path: str):
|
def reset(self, seed=None, options=None):
|
||||||
x, y = self.mouse_controller.get_mouse()
|
|
||||||
cursor_image = Image.open("./desktop_env/assets/cursor.png")
|
|
||||||
cursor_image = cursor_image.resize((int(cursor_image.width / 2), int(cursor_image.height / 2)))
|
|
||||||
|
|
||||||
screenshot = Image.open(img_path)
|
|
||||||
screenshot.paste(cursor_image, (x, y), cursor_image)
|
|
||||||
screenshot.save(img_path)
|
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
print("Resetting environment...")
|
print("Resetting environment...")
|
||||||
|
|
||||||
print("Reverting to snapshot to {}...".format(self.snapshot_path))
|
print("Reverting to snapshot to {}...".format(self.snapshot_path))
|
||||||
self._execute_command(["vmrun", "-T", "ws", "revertToSnapshot", self.path_to_vm, self.snapshot_path])
|
_execute_command(["vmrun", "-T", "ws", "revertToSnapshot", self.path_to_vm, self.snapshot_path])
|
||||||
time.sleep(5)
|
time.sleep(5)
|
||||||
|
|
||||||
print("Starting emulator...")
|
print("Starting emulator...")
|
||||||
@@ -165,75 +87,11 @@ class DesktopEnv(gym.Env):
|
|||||||
return observation
|
return observation
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
if isinstance(action, list):
|
# Our action space is the set of all possible python commands insides `pyautogui`
|
||||||
for a in action:
|
self.controller.execute_python_command(action)
|
||||||
observation, reward, done, info = self.step(a)
|
|
||||||
return observation, reward, done, info
|
|
||||||
|
|
||||||
# todo: handle the case when the action is not a single action
|
|
||||||
try:
|
|
||||||
action_type = Action(action['action_type'])
|
|
||||||
except KeyError:
|
|
||||||
done = True
|
|
||||||
return self._get_obs(), 0, done, {}
|
|
||||||
|
|
||||||
if action_type == Action.CLICK:
|
|
||||||
click = MouseClick(action['click_type'])
|
|
||||||
if click == MouseClick.LEFT:
|
|
||||||
self.mouse_controller.left_click()
|
|
||||||
elif click == MouseClick.MIDDLE:
|
|
||||||
self.mouse_controller.middle_click()
|
|
||||||
elif click == MouseClick.RIGHT:
|
|
||||||
self.mouse_controller.right_click()
|
|
||||||
elif click == MouseClick.WHEEL_UP:
|
|
||||||
self.mouse_controller.scroll_up()
|
|
||||||
elif click == MouseClick.WHEEL_DOWN:
|
|
||||||
self.mouse_controller.scroll_down()
|
|
||||||
elif action_type == Action.MOUSE_DOWN:
|
|
||||||
click = MouseClick(action['click_type'])
|
|
||||||
if click == MouseClick.LEFT:
|
|
||||||
self.mouse_controller.left_down()
|
|
||||||
elif click == MouseClick.MIDDLE:
|
|
||||||
self.mouse_controller.middle_down()
|
|
||||||
elif click == MouseClick.RIGHT:
|
|
||||||
self.mouse_controller.right_down()
|
|
||||||
elif click == MouseClick.WHEEL_UP:
|
|
||||||
self.mouse_controller.scroll_up()
|
|
||||||
elif click == MouseClick.WHEEL_DOWN:
|
|
||||||
self.mouse_controller.scroll_down()
|
|
||||||
elif action_type == Action.MOUSE_UP:
|
|
||||||
click = MouseClick(action['click_type'])
|
|
||||||
if click == MouseClick.LEFT:
|
|
||||||
self.mouse_controller.left_up()
|
|
||||||
elif click == MouseClick.MIDDLE:
|
|
||||||
self.mouse_controller.middle_up()
|
|
||||||
elif click == MouseClick.RIGHT:
|
|
||||||
self.mouse_controller.right_up()
|
|
||||||
elif click == MouseClick.WHEEL_UP:
|
|
||||||
self.mouse_controller.scroll_up()
|
|
||||||
elif click == MouseClick.WHEEL_DOWN:
|
|
||||||
self.mouse_controller.scroll_down()
|
|
||||||
elif action_type == Action.MOUSE_MOVE:
|
|
||||||
self.mouse_controller.mouse_move(x=action['x'], y=action['y'])
|
|
||||||
elif action_type == Action.KEY:
|
|
||||||
self.keyboard_controller.key(action['key'])
|
|
||||||
elif action_type == Action.KEY_DOWN:
|
|
||||||
self.keyboard_controller.key_down(action['key'])
|
|
||||||
elif action_type == Action.KEY_UP:
|
|
||||||
self.keyboard_controller.key_up(action['key'])
|
|
||||||
elif action_type == Action.TYPE:
|
|
||||||
for key in action['text']:
|
|
||||||
if key == "\r" or key == "\n":
|
|
||||||
self.keyboard_controller.key("enter")
|
|
||||||
else:
|
|
||||||
self.keyboard_controller.key(key)
|
|
||||||
# sleep for 0.05 seconds with some random noise
|
|
||||||
time.sleep(0.05 + np.random.normal(0, 0.01))
|
|
||||||
|
|
||||||
# Capture new state
|
|
||||||
observation = self._get_obs()
|
observation = self._get_obs()
|
||||||
reward = 0 # Define reward calculation
|
reward = 0 # todo: Define reward calculation for each example
|
||||||
done = False # Define episode termination condition
|
done = False # todo: Define episode termination condition for each example
|
||||||
info = {}
|
info = {}
|
||||||
return observation, reward, done, info
|
return observation, reward, done, info
|
||||||
|
|
||||||
@@ -244,4 +102,4 @@ class DesktopEnv(gym.Env):
|
|||||||
raise ValueError('Unsupported render mode: {}'.format(mode))
|
raise ValueError('Unsupported render mode: {}'.format(mode))
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
self._execute_command(["vmrun", "stop", self.path_to_vm])
|
_execute_command(["vmrun", "stop", self.path_to_vm])
|
||||||
|
|||||||
64
desktop_env/server/main.py
Normal file
64
desktop_env/server/main.py
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
import os
|
||||||
|
import platform
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
import Xlib.display
|
||||||
|
import pyautogui
|
||||||
|
from PIL import ImageGrab
|
||||||
|
from flask import Flask, request, jsonify, send_file
|
||||||
|
|
||||||
|
app = Flask(__name__)
|
||||||
|
|
||||||
|
pyautogui.PAUSE = 0
|
||||||
|
pyautogui.DARWIN_CATCH_UP_TIME = 0
|
||||||
|
|
||||||
|
|
||||||
|
@app.route('/execute', methods=['POST'])
|
||||||
|
def execute_command():
|
||||||
|
data = request.json
|
||||||
|
# The 'command' key in the JSON request should contain the command to be executed.
|
||||||
|
command = data.get('command', '')
|
||||||
|
|
||||||
|
# Execute the command without any safety checks.
|
||||||
|
try:
|
||||||
|
result = subprocess.run(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||||
|
return jsonify({
|
||||||
|
'status': 'success',
|
||||||
|
'output': result.stdout,
|
||||||
|
'error': result.stderr
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
return jsonify({
|
||||||
|
'status': 'error',
|
||||||
|
'message': str(e)
|
||||||
|
}), 500
|
||||||
|
|
||||||
|
|
||||||
|
@app.route('/screenshot', methods=['GET'])
|
||||||
|
def capture_screen_with_cursor():
|
||||||
|
file_path = os.path.join("screenshots", "screenshot.png")
|
||||||
|
user_platform = platform.system()
|
||||||
|
|
||||||
|
# Ensure the screenshots directory exists
|
||||||
|
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||||
|
|
||||||
|
if user_platform == "Windows":
|
||||||
|
screenshot = pyautogui.screenshot()
|
||||||
|
screenshot.save(file_path)
|
||||||
|
elif user_platform == "Linux":
|
||||||
|
# Use xlib to prevent scrot dependency for Linux
|
||||||
|
screen = Xlib.display.Display().screen()
|
||||||
|
size = screen.width_in_pixels, screen.height_in_pixels
|
||||||
|
screenshot = ImageGrab.grab(bbox=(0, 0, size[0], size[1]))
|
||||||
|
screenshot.save(file_path)
|
||||||
|
elif user_platform == "Darwin": # (Mac OS)
|
||||||
|
# Use the screencapture utility to capture the screen with the cursor
|
||||||
|
subprocess.run(["screencapture", "-C", file_path])
|
||||||
|
else:
|
||||||
|
print(f"The platform you're using ({user_platform}) is not currently supported")
|
||||||
|
|
||||||
|
return send_file(file_path, mimetype='image/png')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
app.run(debug=True, host="0.0.0.0")
|
||||||
@@ -1,27 +0,0 @@
|
|||||||
from flask import Flask, request, jsonify
|
|
||||||
import subprocess
|
|
||||||
|
|
||||||
app = Flask(__name__)
|
|
||||||
|
|
||||||
@app.route('/execute', methods=['POST'])
|
|
||||||
def execute_command():
|
|
||||||
data = request.json
|
|
||||||
# The 'command' key in the JSON request should contain the command to be executed.
|
|
||||||
command = data.get('command', '')
|
|
||||||
|
|
||||||
# Execute the command without any safety checks.
|
|
||||||
try:
|
|
||||||
result = subprocess.run(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
|
||||||
return jsonify({
|
|
||||||
'status': 'success',
|
|
||||||
'output': result.stdout,
|
|
||||||
'error': result.stderr
|
|
||||||
})
|
|
||||||
except Exception as e:
|
|
||||||
return jsonify({
|
|
||||||
'status': 'error',
|
|
||||||
'message': str(e)
|
|
||||||
}), 500
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
app.run(debug=True, host="0.0.0.0")
|
|
||||||
@@ -8,29 +8,26 @@ import uuid
|
|||||||
def gpt_4v_agent():
|
def gpt_4v_agent():
|
||||||
api_key = os.environ.get("OPENAI_API_KEY")
|
api_key = os.environ.get("OPENAI_API_KEY")
|
||||||
|
|
||||||
# meta_info = {
|
|
||||||
# "instruction": "Open WSJ website to get latest news",
|
|
||||||
# "task_name": "open_wsj",
|
|
||||||
# "snapshot_path": "base",
|
|
||||||
# }
|
|
||||||
|
|
||||||
meta_info = {
|
meta_info = {
|
||||||
"instruction": "Clear the recycle bin",
|
"instruction": "Open WSJ website to get latest news",
|
||||||
"task_name": "clean_recycle_bin",
|
"task_name": "open_wsj",
|
||||||
"snapshot_path": "base",
|
"snapshot_path": "base",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# meta_info = {
|
||||||
|
# "instruction": "Clear the recycle bin",
|
||||||
|
# "task_name": "clean_recycle_bin",
|
||||||
|
# "snapshot_path": "base",
|
||||||
|
# }
|
||||||
|
|
||||||
agent = GPT4v_Agent(api_key=api_key, instruction=meta_info["instruction"])
|
agent = GPT4v_Agent(api_key=api_key, instruction=meta_info["instruction"])
|
||||||
env = DesktopEnv(
|
env = DesktopEnv(
|
||||||
path_to_vm=r"""C:\Users\tianbaox\Documents\Virtual Machines\Win10\Win10.vmx""",
|
path_to_vm=r"""C:\Users\tianbaox\Documents\Virtual Machines\Win10\Win10.vmx""",
|
||||||
# automitically load the snapshot and start the vm
|
# automitically load the snapshot and start the vm
|
||||||
# path_to_vm="/home/yuri/vmware/Ubuntu 64-bit/Ubuntu 64-bit.vmx",
|
# path_to_vm="/home/yuri/vmware/Ubuntu 64-bit/Ubuntu 64-bit.vmx",
|
||||||
snapshot_path="base",
|
snapshot_path="base",
|
||||||
username="tianbaox",
|
|
||||||
password="951753",
|
|
||||||
# host="192.168.7.128",
|
# host="192.168.7.128",
|
||||||
host="http://192.168.13.128:5000",
|
host="http://192.168.13.128:5000",
|
||||||
vm_os="windows"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# reset the environment to certain snapshot
|
# reset the environment to certain snapshot
|
||||||
|
|||||||
63
main.py
63
main.py
@@ -1,55 +1,31 @@
|
|||||||
from pprint import pprint
|
from desktop_env.envs.desktop_env import DesktopEnv
|
||||||
from desktop_env.envs.desktop_env import DesktopEnv, Action, MouseClick
|
|
||||||
|
|
||||||
def get_human_action():
|
|
||||||
"""
|
|
||||||
Prompts the human player for an action and returns a structured action.
|
|
||||||
"""
|
|
||||||
print("\nAvailable actions:", [action.name for action in Action])
|
|
||||||
action_type = None
|
|
||||||
while action_type not in [action.value for action in Action]:
|
|
||||||
action_type = Action[input("Enter the type of action: ".strip())].value
|
|
||||||
|
|
||||||
action = {"action_type": action_type}
|
|
||||||
|
|
||||||
if action_type == Action.CLICK.value or action_type == Action.MOUSE_DOWN.value or action_type == Action.MOUSE_UP.value:
|
|
||||||
print("\n Available clicks:", [action.name for action in MouseClick])
|
|
||||||
click_type = input("Enter click type: ")
|
|
||||||
action["click_type"] = MouseClick[click_type].value
|
|
||||||
|
|
||||||
if action_type == Action.MOUSE_MOVE.value:
|
|
||||||
x = int(input("Enter x-coordinate for mouse move: "))
|
|
||||||
y = int(input("Enter y-coordinate for mouse move: "))
|
|
||||||
action["x"] = x
|
|
||||||
action["y"] = y
|
|
||||||
|
|
||||||
if action_type == Action.KEY.value:
|
|
||||||
key = input("Enter the key to press: ")
|
|
||||||
action["key"] = [ord(c) for c in key]
|
|
||||||
|
|
||||||
if action_type == Action.TYPE.value:
|
|
||||||
text = input("Enter the text to type: ")
|
|
||||||
action["text"] = [ord(c) for c in text]
|
|
||||||
|
|
||||||
return action
|
|
||||||
|
|
||||||
|
|
||||||
def human_agent():
|
def human_agent():
|
||||||
"""
|
"""
|
||||||
Runs the Gym environment with human input.
|
Runs the Gym environment with human input.
|
||||||
"""
|
"""
|
||||||
env = DesktopEnv(path_to_vm="/home/yuri/vmware/Windows 10 x64/Windows 10 x64.vmx",
|
env = DesktopEnv(
|
||||||
# path_to_vm="/home/yuri/vmware/Ubuntu 64-bit/Ubuntu 64-bit.vmx",
|
path_to_vm=r"""C:\Users\tianbaox\Documents\Virtual Machines\Win10\Win10.vmx""",
|
||||||
username="user",
|
# path_to_vm="/home/yuri/vmware/Ubuntu 64-bit/Ubuntu 64-bit.vmx",
|
||||||
password="password",
|
# host="192.168.7.128",
|
||||||
# host="192.168.7.128",
|
host="http://192.168.13.128:5000",
|
||||||
host="http://192.168.7.129:5000",
|
)
|
||||||
vm_os="windows")
|
|
||||||
observation = env.reset()
|
# reset the environment to certain snapshot
|
||||||
|
# observation = env.reset()
|
||||||
done = False
|
done = False
|
||||||
|
|
||||||
while not done:
|
while not done:
|
||||||
action = get_human_action()
|
# action = get_human_action()
|
||||||
|
|
||||||
|
# action = {
|
||||||
|
# "action_type": 0,
|
||||||
|
# "click_type": 3,
|
||||||
|
# }
|
||||||
|
|
||||||
|
action = "pyautogui.dragTo(100, 200, button='left')"
|
||||||
|
|
||||||
observation, reward, done, info = env.step(action)
|
observation, reward, done, info = env.step(action)
|
||||||
print("Observation:", observation)
|
print("Observation:", observation)
|
||||||
print("Reward:", reward)
|
print("Reward:", reward)
|
||||||
@@ -64,5 +40,6 @@ def human_agent():
|
|||||||
env.close()
|
env.close()
|
||||||
print("Environment closed.")
|
print("Environment closed.")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
human_agent()
|
human_agent()
|
||||||
|
|||||||
@@ -88,6 +88,8 @@ class GPT4v_Agent:
|
|||||||
traj_to_show = []
|
traj_to_show = []
|
||||||
for i in range(len(self.trajectory)):
|
for i in range(len(self.trajectory)):
|
||||||
traj_to_show.append(self.trajectory[i]["content"][0]["text"])
|
traj_to_show.append(self.trajectory[i]["content"][0]["text"])
|
||||||
|
if len(self.trajectory[i]["content"]) > 1:
|
||||||
|
traj_to_show.append("screenshot_obs")
|
||||||
print("Trajectory:", traj_to_show)
|
print("Trajectory:", traj_to_show)
|
||||||
payload = {
|
payload = {
|
||||||
"model": self.model,
|
"model": self.model,
|
||||||
|
|||||||
Reference in New Issue
Block a user