diff --git a/desktop_env/assets/cursor.png b/desktop_env/assets/cursor.png new file mode 100644 index 0000000..8d9182a Binary files /dev/null and b/desktop_env/assets/cursor.png differ diff --git a/desktop_env/controllers/mouse.py b/desktop_env/controllers/mouse.py index 45961be..cd63175 100644 --- a/desktop_env/controllers/mouse.py +++ b/desktop_env/controllers/mouse.py @@ -2,6 +2,7 @@ from enum import Enum from abc import ABC, abstractmethod from fabric import Connection +import re from .xdotool import XDoToolController from .python import PythonController @@ -13,6 +14,10 @@ class MouseClick(Enum): WHEEL_DOWN = 5 class AbstractMouseController(ABC): + @abstractmethod + def get_mouse(self): + raise NotImplementedError + @abstractmethod def mouse_move(self, x: int, y: int): raise NotImplementedError @@ -66,6 +71,10 @@ class XDoToolMouseController(AbstractMouseController, XDoToolController): def __init__(self, ssh_connection: Connection): super().__init__(ssh_connection=ssh_connection) + @abstractmethod + def get_mouse(self): + raise NotImplementedError + def mouse_move(self, x: int, y: int): self._execute_xdotool_command(f"mousemove {x} {y}") @@ -105,7 +114,13 @@ class XDoToolMouseController(AbstractMouseController, XDoToolController): class PythonMouseController(AbstractMouseController, PythonController): def __init__(self, http_server: str): super().__init__(http_server=http_server) - self.command = "python -c \"import mouse; {command}\"" + self.command = "py -c \"import mouse; {command}\"" + + def get_mouse(self): + response = self._execute_python_command(self.command.format(command=f"print(mouse.get_position())")) + numbers = re.findall(r'-?\d+', response["output"]) + x, y = map(int, numbers) + return x, y def mouse_move(self, x: int, y: int): self._execute_python_command(self.command.format(command=f"mouse.move({x}, {y})")) diff --git a/desktop_env/controllers/python.py b/desktop_env/controllers/python.py index 735bd44..1e6a751 100644 --- a/desktop_env/controllers/python.py +++ b/desktop_env/controllers/python.py @@ -20,6 +20,7 @@ class PythonController: print("Command executed successfully:", response.text) else: print("Failed to execute command. Status code:", response.status_code) + return response.json() except requests.exceptions.RequestException as e: print("An error occurred while trying to execute the command:", e) diff --git a/desktop_env/envs/desktop_env.py b/desktop_env/envs/desktop_env.py index 8755fdc..d032f87 100644 --- a/desktop_env/envs/desktop_env.py +++ b/desktop_env/envs/desktop_env.py @@ -58,6 +58,9 @@ class DesktopEnv(gym.Env): print("Initializing...") self._start_emulator() + # set up controllers + self.mouse_controller, self.keyboard_controller = self._create_controllers(vm_os) + # Get the screen size self.screen_width, self.screen_height = self._get_screensize() @@ -77,8 +80,6 @@ class DesktopEnv(gym.Env): # Additional setup self.metadata = {'render.modes': ['rgb_array']} - # set up controllers - self.mouse_controller, self.keyboard_controller = self._create_controllers(vm_os) def _get_screensize(self): screenshot_path = self._get_obs() @@ -114,7 +115,9 @@ class DesktopEnv(gym.Env): print(f"Error executing command: {e.output.decode().strip()}") def _execute_command(self, command: List[str]) -> None: - subprocess.run(command, shell=True, stderr=subprocess.STDOUT, timeout=60) + result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=60, text=True) + if result.returncode != 0: + raise Exception("\033[91m" + result.stdout + result.stderr + "\033[0m") def _save_state(self): self._execute_command(["vmrun", "-T", "ws" "snapshot", self.path_to_vm, self.snapshot_path]) @@ -126,8 +129,7 @@ class DesktopEnv(gym.Env): if self.password: self._execute_command( - ["vmrun", "-T", "ws", "-gu", self.username, "-gp", self.password, "captureScreen", self.path_to_vm, - image_path]) + ["vmrun", "-T", "ws", "-gu", self.username, "-gp", self.password, "captureScreen", self.path_to_vm, image_path]) else: self._execute_command( ["vmrun", "-T", "ws", "-gu", self.username, "captureScreen", self.path_to_vm, image_path]) @@ -136,7 +138,17 @@ class DesktopEnv(gym.Env): def _get_obs(self): screenshot_image_path = self._get_screenshot() + self._add_cursor(screenshot_image_path) return screenshot_image_path + + def _add_cursor(self, img_path: str): + x, y = self.mouse_controller.get_mouse() + cursor_image = Image.open("./desktop_env/assets/cursor.png") + cursor_image = cursor_image.resize((int(cursor_image.width / 2), int(cursor_image.height / 2))) + + screenshot = Image.open(img_path) + screenshot.paste(cursor_image, (x, y), cursor_image) + screenshot.save(img_path) def reset(self): print("Resetting environment...") diff --git a/desktop_env/windows_server/main.py b/desktop_env/windows_server/main.py index 467b40e..098dca5 100644 --- a/desktop_env/windows_server/main.py +++ b/desktop_env/windows_server/main.py @@ -11,9 +11,11 @@ def execute_command(): # Execute the command without any safety checks. try: - subprocess.run(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + result = subprocess.run(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) return jsonify({ 'status': 'success', + 'output': result.stdout, + 'error': result.stderr }) except Exception as e: return jsonify({