Merge remote-tracking branch 'origin/main'
# Conflicts: # main.py
This commit is contained in:
@@ -9,12 +9,14 @@ from typing import List
|
||||
import gymnasium as gym
|
||||
|
||||
from desktop_env.controllers.python import PythonController
|
||||
from desktop_env.controllers.setup import SetupController
|
||||
|
||||
|
||||
def _execute_command(command: List[str]) -> None:
|
||||
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=60, text=True)
|
||||
if result.returncode != 0:
|
||||
raise Exception("\033[91m" + result.stdout + result.stderr + "\033[0m")
|
||||
return result.stdout
|
||||
|
||||
|
||||
class DesktopEnv(gym.Env):
|
||||
@@ -23,19 +25,19 @@ class DesktopEnv(gym.Env):
|
||||
def __init__(
|
||||
self,
|
||||
path_to_vm: str,
|
||||
host: str = "192.168.7.128:5000",
|
||||
snapshot_path: str = "base",
|
||||
action_space: str = "pyautogui",
|
||||
):
|
||||
# Initialize environment variables
|
||||
self.path_to_vm = path_to_vm
|
||||
self.host = host
|
||||
self.snapshot_path = snapshot_path # todo: handling the logic of snapshot directory
|
||||
|
||||
# Initialize emulator and controller
|
||||
print("Initializing...")
|
||||
self._start_emulator()
|
||||
self.host = f"http://{self._get_vm_ip()}:5000"
|
||||
self.controller = PythonController(http_server=self.host)
|
||||
self.setup_controller = SetupController(http_server=self.host)
|
||||
|
||||
# mode: human or machine
|
||||
assert action_space in ["computer_13", "pyautogui"]
|
||||
@@ -53,10 +55,23 @@ class DesktopEnv(gym.Env):
|
||||
else:
|
||||
print("Starting VM...")
|
||||
_execute_command(["vmrun", "-T", "ws", "start", self.path_to_vm])
|
||||
time.sleep(10)
|
||||
time.sleep(3)
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Error executing command: {e.output.decode().strip()}")
|
||||
|
||||
def _get_vm_ip(self):
|
||||
max_retries = 10
|
||||
print("Getting IP Address...")
|
||||
for _ in range(max_retries):
|
||||
try:
|
||||
output = _execute_command(["vmrun", "-T", "ws", "getGuestIPAddress", self.path_to_vm]).strip()
|
||||
print(f"IP address: {output}")
|
||||
return output
|
||||
except:
|
||||
time.sleep(5)
|
||||
print("Retrying...")
|
||||
raise Exception("Failed to get VM IP address!")
|
||||
|
||||
def _save_state(self):
|
||||
_execute_command(["vmrun", "-T", "ws" "snapshot", self.path_to_vm, self.snapshot_path])
|
||||
|
||||
@@ -105,6 +120,9 @@ class DesktopEnv(gym.Env):
|
||||
done = False # todo: Define episode termination condition for each example
|
||||
info = {}
|
||||
return observation, reward, done, info
|
||||
|
||||
def setup(self, config: dict):
|
||||
self.setup_controller.setup(config)
|
||||
|
||||
def render(self, mode='rgb_array'):
|
||||
if mode == 'rgb_array':
|
||||
|
||||
Reference in New Issue
Block a user