Fix the width and height of the VM screen so that the agent performs more accurately
This commit is contained in:
@@ -9,6 +9,7 @@ import gymnasium as gym
|
||||
from gymnasium import spaces
|
||||
import numpy as np
|
||||
import uuid
|
||||
from PIL import Image
|
||||
|
||||
from desktop_env.controllers.mouse import MouseClick, AbstractMouseController, XDoToolMouseController, \
|
||||
PythonMouseController
|
||||
@@ -39,8 +40,10 @@ class DesktopEnv(gym.Env):
|
||||
username: str,
|
||||
password: str = None,
|
||||
host: str = "192.168.7.128:5000",
|
||||
snapshot_path: str = "initial_state_with_env_set",
|
||||
vm_os: VM_TYPE = "ubuntu"):
|
||||
snapshot_path: str = "base",
|
||||
vm_os: VM_TYPE = "ubuntu"
|
||||
):
|
||||
|
||||
# The path to the vmx file of your vm
|
||||
self.path_to_vm = path_to_vm
|
||||
|
||||
@@ -51,9 +54,13 @@ class DesktopEnv(gym.Env):
|
||||
self.host = host
|
||||
self.snapshot_path = snapshot_path # todo: handling the logic of snapshot directory
|
||||
|
||||
# TODO: get the screen width and height from the vm, or standardize it
|
||||
self.screen_width = 800
|
||||
self.screen_height = 800
|
||||
# Initialize emulator
|
||||
print("Initializing...")
|
||||
self._start_emulator()
|
||||
|
||||
# Get the screen size
|
||||
self.screen_width, self.screen_height = self._get_screensize()
|
||||
|
||||
# Define the action and observation space
|
||||
self.action_space = spaces.Dict({
|
||||
"action_type": spaces.Discrete(len(Action)),
|
||||
@@ -70,13 +77,14 @@ class DesktopEnv(gym.Env):
|
||||
# Additional setup
|
||||
self.metadata = {'render.modes': ['rgb_array']}
|
||||
|
||||
# Initialize emulator
|
||||
print("Initializing...")
|
||||
self._start_emulator()
|
||||
|
||||
# set up controllers
|
||||
self.mouse_controller, self.keyboard_controller = self._create_controllers(vm_os)
|
||||
|
||||
def _get_screensize(self):
    """Return the size of the VM screen as reported by a fresh screenshot.

    Captures an observation via ``self._get_obs()``, which is used here as
    the path to a screenshot image, opens that image and returns its
    ``size`` attribute. The caller unpacks the result as
    ``(screen_width, screen_height)``.
    """
    screenshot_path = self._get_obs()
    # Open the image in a context manager so the underlying file handle is
    # released as soon as the size has been read, instead of leaking until
    # garbage collection.
    with Image.open(screenshot_path) as img:
        return img.size
|
||||
|
||||
def _create_controllers(self, vm_os: VM_TYPE) -> Tuple[AbstractMouseController, AbstractKeyboardController]:
|
||||
if vm_os == "ubuntu":
|
||||
ssh_connection = Connection(host=self.host, user=self.username, connect_kwargs={"password": self.password})
|
||||
@@ -145,7 +153,18 @@ class DesktopEnv(gym.Env):
|
||||
return observation
|
||||
|
||||
def step(self, action):
|
||||
action_type = Action(action['action_type'])
|
||||
if isinstance(action, list):
|
||||
for a in action:
|
||||
observation, reward, done, info = self.step(a)
|
||||
return observation, reward, done, info
|
||||
|
||||
# todo: handle the case when the action is not a single action
|
||||
try:
|
||||
action_type = Action(action['action_type'])
|
||||
except KeyError:
|
||||
done = True
|
||||
return self._get_obs(), 0, done, {}
|
||||
|
||||
if action_type == Action.CLICK:
|
||||
click = MouseClick(action['click_type'])
|
||||
if click == MouseClick.LEFT:
|
||||
@@ -185,17 +204,19 @@ class DesktopEnv(gym.Env):
|
||||
elif action_type == Action.MOUSE_MOVE:
|
||||
self.mouse_controller.mouse_move(x=action['x'], y=action['y'])
|
||||
elif action_type == Action.KEY:
|
||||
key_sequence = ''.join(map(chr, action['key'])) # Convert integer array to string
|
||||
self.keyboard_controller.key(key_sequence)
|
||||
self.keyboard_controller.key(action['key'])
|
||||
elif action_type == Action.KEY_DOWN:
|
||||
key_sequence = ''.join(map(chr, action['key'])) # Convert integer array to string
|
||||
self.keyboard_controller.key_down(key_sequence)
|
||||
self.keyboard_controller.key_down(action['key'])
|
||||
elif action_type == Action.KEY_UP:
|
||||
key_sequence = ''.join(map(chr, action['key'])) # Convert integer array to string
|
||||
self.keyboard_controller.key_up(key_sequence)
|
||||
self.keyboard_controller.key_up(action['key'])
|
||||
elif action_type == Action.TYPE:
|
||||
text = ''.join(map(chr, action['text'])) # Convert integer array to string
|
||||
self.keyboard_controller.type(text)
|
||||
for key in action['text']:
|
||||
if key == "\r" or key == "\n":
|
||||
self.keyboard_controller.key("enter")
|
||||
else:
|
||||
self.keyboard_controller.key(key)
|
||||
# sleep for 0.05 seconds with some random noise
|
||||
time.sleep(0.05 + np.random.normal(0, 0.01))
|
||||
|
||||
# Capture new state
|
||||
observation = self._get_obs()
|
||||
|
||||
Reference in New Issue
Block a user