add mouse cursor to screenshot
This commit is contained in:
BIN
desktop_env/assets/cursor.png
Normal file
BIN
desktop_env/assets/cursor.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 4.7 KiB |
@@ -2,6 +2,7 @@ from enum import Enum
|
|||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from fabric import Connection
|
from fabric import Connection
|
||||||
|
import re
|
||||||
|
|
||||||
from .xdotool import XDoToolController
|
from .xdotool import XDoToolController
|
||||||
from .python import PythonController
|
from .python import PythonController
|
||||||
@@ -13,6 +14,10 @@ class MouseClick(Enum):
|
|||||||
WHEEL_DOWN = 5
|
WHEEL_DOWN = 5
|
||||||
|
|
||||||
class AbstractMouseController(ABC):
|
class AbstractMouseController(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def get_mouse(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def mouse_move(self, x: int, y: int):
|
def mouse_move(self, x: int, y: int):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
@@ -66,6 +71,10 @@ class XDoToolMouseController(AbstractMouseController, XDoToolController):
|
|||||||
def __init__(self, ssh_connection: Connection):
|
def __init__(self, ssh_connection: Connection):
|
||||||
super().__init__(ssh_connection=ssh_connection)
|
super().__init__(ssh_connection=ssh_connection)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_mouse(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
def mouse_move(self, x: int, y: int):
|
def mouse_move(self, x: int, y: int):
|
||||||
self._execute_xdotool_command(f"mousemove {x} {y}")
|
self._execute_xdotool_command(f"mousemove {x} {y}")
|
||||||
|
|
||||||
@@ -105,7 +114,13 @@ class XDoToolMouseController(AbstractMouseController, XDoToolController):
|
|||||||
class PythonMouseController(AbstractMouseController, PythonController):
|
class PythonMouseController(AbstractMouseController, PythonController):
|
||||||
def __init__(self, http_server: str):
|
def __init__(self, http_server: str):
|
||||||
super().__init__(http_server=http_server)
|
super().__init__(http_server=http_server)
|
||||||
self.command = "python -c \"import mouse; {command}\""
|
self.command = "py -c \"import mouse; {command}\""
|
||||||
|
|
||||||
|
def get_mouse(self):
|
||||||
|
response = self._execute_python_command(self.command.format(command=f"print(mouse.get_position())"))
|
||||||
|
numbers = re.findall(r'-?\d+', response["output"])
|
||||||
|
x, y = map(int, numbers)
|
||||||
|
return x, y
|
||||||
|
|
||||||
def mouse_move(self, x: int, y: int):
|
def mouse_move(self, x: int, y: int):
|
||||||
self._execute_python_command(self.command.format(command=f"mouse.move({x}, {y})"))
|
self._execute_python_command(self.command.format(command=f"mouse.move({x}, {y})"))
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ class PythonController:
|
|||||||
print("Command executed successfully:", response.text)
|
print("Command executed successfully:", response.text)
|
||||||
else:
|
else:
|
||||||
print("Failed to execute command. Status code:", response.status_code)
|
print("Failed to execute command. Status code:", response.status_code)
|
||||||
|
return response.json()
|
||||||
except requests.exceptions.RequestException as e:
|
except requests.exceptions.RequestException as e:
|
||||||
print("An error occurred while trying to execute the command:", e)
|
print("An error occurred while trying to execute the command:", e)
|
||||||
|
|
||||||
|
|||||||
@@ -58,6 +58,9 @@ class DesktopEnv(gym.Env):
|
|||||||
print("Initializing...")
|
print("Initializing...")
|
||||||
self._start_emulator()
|
self._start_emulator()
|
||||||
|
|
||||||
|
# set up controllers
|
||||||
|
self.mouse_controller, self.keyboard_controller = self._create_controllers(vm_os)
|
||||||
|
|
||||||
# Get the screen size
|
# Get the screen size
|
||||||
self.screen_width, self.screen_height = self._get_screensize()
|
self.screen_width, self.screen_height = self._get_screensize()
|
||||||
|
|
||||||
@@ -77,8 +80,6 @@ class DesktopEnv(gym.Env):
|
|||||||
# Additional setup
|
# Additional setup
|
||||||
self.metadata = {'render.modes': ['rgb_array']}
|
self.metadata = {'render.modes': ['rgb_array']}
|
||||||
|
|
||||||
# set up controllers
|
|
||||||
self.mouse_controller, self.keyboard_controller = self._create_controllers(vm_os)
|
|
||||||
|
|
||||||
def _get_screensize(self):
|
def _get_screensize(self):
|
||||||
screenshot_path = self._get_obs()
|
screenshot_path = self._get_obs()
|
||||||
@@ -114,7 +115,9 @@ class DesktopEnv(gym.Env):
|
|||||||
print(f"Error executing command: {e.output.decode().strip()}")
|
print(f"Error executing command: {e.output.decode().strip()}")
|
||||||
|
|
||||||
def _execute_command(self, command: List[str]) -> None:
|
def _execute_command(self, command: List[str]) -> None:
|
||||||
subprocess.run(command, shell=True, stderr=subprocess.STDOUT, timeout=60)
|
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=60, text=True)
|
||||||
|
if result.returncode != 0:
|
||||||
|
raise Exception("\033[91m" + result.stdout + result.stderr + "\033[0m")
|
||||||
|
|
||||||
def _save_state(self):
|
def _save_state(self):
|
||||||
self._execute_command(["vmrun", "-T", "ws" "snapshot", self.path_to_vm, self.snapshot_path])
|
self._execute_command(["vmrun", "-T", "ws" "snapshot", self.path_to_vm, self.snapshot_path])
|
||||||
@@ -126,8 +129,7 @@ class DesktopEnv(gym.Env):
|
|||||||
|
|
||||||
if self.password:
|
if self.password:
|
||||||
self._execute_command(
|
self._execute_command(
|
||||||
["vmrun", "-T", "ws", "-gu", self.username, "-gp", self.password, "captureScreen", self.path_to_vm,
|
["vmrun", "-T", "ws", "-gu", self.username, "-gp", self.password, "captureScreen", self.path_to_vm, image_path])
|
||||||
image_path])
|
|
||||||
else:
|
else:
|
||||||
self._execute_command(
|
self._execute_command(
|
||||||
["vmrun", "-T", "ws", "-gu", self.username, "captureScreen", self.path_to_vm, image_path])
|
["vmrun", "-T", "ws", "-gu", self.username, "captureScreen", self.path_to_vm, image_path])
|
||||||
@@ -136,8 +138,18 @@ class DesktopEnv(gym.Env):
|
|||||||
|
|
||||||
def _get_obs(self):
|
def _get_obs(self):
|
||||||
screenshot_image_path = self._get_screenshot()
|
screenshot_image_path = self._get_screenshot()
|
||||||
|
self._add_cursor(screenshot_image_path)
|
||||||
return screenshot_image_path
|
return screenshot_image_path
|
||||||
|
|
||||||
|
def _add_cursor(self, img_path: str):
|
||||||
|
x, y = self.mouse_controller.get_mouse()
|
||||||
|
cursor_image = Image.open("./desktop_env/assets/cursor.png")
|
||||||
|
cursor_image = cursor_image.resize((int(cursor_image.width / 2), int(cursor_image.height / 2)))
|
||||||
|
|
||||||
|
screenshot = Image.open(img_path)
|
||||||
|
screenshot.paste(cursor_image, (x, y), cursor_image)
|
||||||
|
screenshot.save(img_path)
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
print("Resetting environment...")
|
print("Resetting environment...")
|
||||||
|
|
||||||
|
|||||||
@@ -11,9 +11,11 @@ def execute_command():
|
|||||||
|
|
||||||
# Execute the command without any safety checks.
|
# Execute the command without any safety checks.
|
||||||
try:
|
try:
|
||||||
subprocess.run(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
result = subprocess.run(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||||
return jsonify({
|
return jsonify({
|
||||||
'status': 'success',
|
'status': 'success',
|
||||||
|
'output': result.stdout,
|
||||||
|
'error': result.stderr
|
||||||
})
|
})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return jsonify({
|
return jsonify({
|
||||||
|
|||||||
Reference in New Issue
Block a user