add mouse cursor to screenshot

This commit is contained in:
Jing Hua
2023-11-30 17:31:46 +08:00
parent e52ba2ab13
commit ebb5f1cbc5
5 changed files with 37 additions and 7 deletions

View File

@@ -58,6 +58,9 @@ class DesktopEnv(gym.Env):
print("Initializing...")
self._start_emulator()
# set up controllers
self.mouse_controller, self.keyboard_controller = self._create_controllers(vm_os)
# Get the screen size
self.screen_width, self.screen_height = self._get_screensize()
@@ -77,8 +80,6 @@ class DesktopEnv(gym.Env):
# Additional setup
self.metadata = {'render.modes': ['rgb_array']}
# set up controllers
self.mouse_controller, self.keyboard_controller = self._create_controllers(vm_os)
def _get_screensize(self):
screenshot_path = self._get_obs()
@@ -114,7 +115,9 @@ class DesktopEnv(gym.Env):
print(f"Error executing command: {e.output.decode().strip()}")
def _execute_command(self, command: List[str]) -> None:
    """Run *command* once as a child process and raise if it exits non-zero.

    The argument list is handed to the OS directly (no shell), so every
    element reaches the program as exactly one argument.

    Args:
        command: Program name followed by its arguments.

    Raises:
        Exception: If the process exits with a non-zero status. The message
            is the captured stdout followed by stderr, wrapped in the ANSI
            escape codes for red text.
        subprocess.TimeoutExpired: If the process runs longer than 60 seconds.
    """
    # The older ``shell=True`` invocation is intentionally gone: combined with
    # a list argv it only executes the first element on POSIX, and keeping it
    # next to the call below would run every command twice.
    result = subprocess.run(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        timeout=60,
        text=True,
    )
    if result.returncode != 0:
        raise Exception("\033[91m" + result.stdout + result.stderr + "\033[0m")
def _save_state(self):
    """Invoke ``vmrun ... snapshot`` with ``self.path_to_vm`` and ``self.snapshot_path``."""
    # "ws" and "snapshot" must be separate list items. Without the comma,
    # Python's implicit string concatenation fuses them into "wssnapshot".
    self._execute_command(["vmrun", "-T", "ws", "snapshot", self.path_to_vm, self.snapshot_path])
@@ -126,8 +129,7 @@ class DesktopEnv(gym.Env):
if self.password:
self._execute_command(
["vmrun", "-T", "ws", "-gu", self.username, "-gp", self.password, "captureScreen", self.path_to_vm,
image_path])
["vmrun", "-T", "ws", "-gu", self.username, "-gp", self.password, "captureScreen", self.path_to_vm, image_path])
else:
self._execute_command(
["vmrun", "-T", "ws", "-gu", self.username, "captureScreen", self.path_to_vm, image_path])
@@ -136,7 +138,17 @@ class DesktopEnv(gym.Env):
def _get_obs(self):
    """Take a screenshot, overlay the mouse cursor on it, and return the image path."""
    obs_path = self._get_screenshot()
    # _add_cursor edits the file in place, so the same path is returned.
    self._add_cursor(obs_path)
    return obs_path
def _add_cursor(self, img_path: str):
    """Draw the cursor sprite onto the screenshot at the current mouse position.

    The image at *img_path* is overwritten in place.

    Args:
        img_path: Path of the screenshot file to modify.
    """
    x, y = self.mouse_controller.get_mouse()

    # Context managers release the file handles that Image.open keeps.
    with Image.open("./desktop_env/assets/cursor.png") as cursor_file:
        # The sprite doubles as its own paste mask, which needs an alpha
        # band; converting to RGBA guarantees one for any PNG mode.
        cursor_image = cursor_file.convert("RGBA")

    # Draw the sprite at half its native size, never collapsing to 0 px.
    half_size = (max(1, cursor_image.width // 2), max(1, cursor_image.height // 2))
    cursor_image = cursor_image.resize(half_size)

    # Copy the pixels out so the source file is closed before it is rewritten.
    with Image.open(img_path) as screenshot_file:
        screenshot = screenshot_file.copy()

    screenshot.paste(cursor_image, (x, y), cursor_image)
    screenshot.save(img_path)
def reset(self):
print("Resetting environment...")