ver Jan5th

debugged
This commit is contained in:
David Chang
2024-01-05 15:20:47 +08:00
parent 5fedf5b891
commit eeb8a120d6
17 changed files with 158 additions and 86 deletions

View File

@@ -17,10 +17,12 @@ from desktop_env.controllers.setup import SetupController
# from desktop_env.evaluators import eval_funcs
from desktop_env.evaluators import metrics, getters
import logging
logger = logging.getLogger("desktopenv.env")
Metric = Callable[[Any, Any], float]
Getter = Callable[[gym.Env, Dict[str, Any]], Any]
def _execute_command(command: List[str]) -> None:
if command[:4] == ["vmrun", "-T", "ws", "start"]:
p = subprocess.Popen(command)
@@ -72,7 +74,7 @@ class DesktopEnv(gym.Env):
self._set_task_info(task_config)
# Initialize emulator and controller
print("Initializing...")
logger.info("Initializing...")
self._start_emulator()
self.host = f"http://{self._get_vm_ip()}:5000"
self.controller = PythonController(http_server=self.host)
@@ -98,26 +100,26 @@ class DesktopEnv(gym.Env):
output: List[str] = output.splitlines()
# if self.path_to_vm.lstrip("~/") in output:
if self.path_to_vm in output:
print("VM is running.")
logger.info("VM is running.")
break
else:
print("Starting VM...")
logger.info("Starting VM...")
_execute_command(["vmrun", "-T", "ws", "start", self.path_to_vm])
time.sleep(3)
except subprocess.CalledProcessError as e:
print(f"Error executing command: {e.output.decode().strip()}")
logger.error(f"Error executing command: {e.output.decode().strip()}")
def _get_vm_ip(self):
max_retries = 10
print("Getting IP Address...")
logger.info("Getting IP Address...")
for _ in range(max_retries):
try:
output = _execute_command(["vmrun", "-T", "ws", "getGuestIPAddress", self.path_to_vm]).strip()
print(f"IP address: {output}")
logger.info(f"IP address: {output}")
return output
except:
time.sleep(5)
print("Retrying...")
logger.info("Retrying...")
raise Exception("Failed to get VM IP address!")
def _save_state(self):
@@ -156,38 +158,38 @@ class DesktopEnv(gym.Env):
self.metric_options: Dict[str, Any] = self.evaluator.get("options", {})
def reset(self, task_config: Optional[Dict[str, Any]] = None, seed=None, options=None):
print("Resetting environment...")
logger.info("Resetting environment...")
print("Switching task...")
logger.info("Switching task...")
if task_config is not None:
self._set_task_info(task_config)
self.setup_controller.reset_cache_dir(self.cache_dir)
print("Setting counters...")
logger.info("Setting counters...")
self._traj_no += 1
self._step_no = 0
self.action_history.clear()
print("Setup new temp dir...")
logger.info("Setup new temp dir...")
self.tmp_dir = tempfile.mkdtemp(
prefix="{:d}.{:}.".format(self._traj_no, self.task_id),
dir=self.tmp_dir_base
)
os.makedirs(os.path.join(self.tmp_dir, "screenshots"))
print("Reverting to snapshot to {}...".format(self.snapshot_path))
logger.info("Reverting to snapshot to {}...".format(self.snapshot_path))
_execute_command(["vmrun", "-T", "ws", "revertToSnapshot", self.path_to_vm, self.snapshot_path])
time.sleep(5)
print("Starting emulator...")
logger.info("Starting emulator...")
self._start_emulator()
print("Emulator started.")
logger.info("Emulator started.")
print("Setting up environment...")
logger.info("Setting up environment...")
self.setup_controller.setup(self.config)
time.sleep(5)
print("Environment setup complete.")
logger.info("Environment setup complete.")
observation = self._get_obs()
return observation