Refactor experiments and agent implementation
This commit is contained in:
@@ -53,8 +53,8 @@ class DesktopEnv(gym.Env):
|
||||
def __init__(
|
||||
self,
|
||||
path_to_vm: str,
|
||||
snapshot_name: str ="init_state",
|
||||
action_space: str = "computer_13",
|
||||
task_config: Dict[str, Any] = None,
|
||||
tmp_dir: str = "tmp",
|
||||
cache_dir: str = "cache",
|
||||
screen_size: Tuple[int] = (1920, 1080),
|
||||
@@ -64,15 +64,6 @@ class DesktopEnv(gym.Env):
|
||||
Args:
|
||||
path_to_vm (str): path to .vmx file
|
||||
action_space (str): "computer_13" | "pyautogui"
|
||||
|
||||
task_config (Dict[str, Any]): manages task configs integratedly,
|
||||
including
|
||||
* base snapshot
|
||||
* task id (uuid)
|
||||
* instruction
|
||||
* setup config
|
||||
* evaluator config
|
||||
|
||||
tmp_dir (str): temporary directory to store trajectory stuffs like
|
||||
the extracted screenshots
|
||||
cache_dir (str): cache directory to cache task-related stuffs like
|
||||
@@ -81,6 +72,7 @@ class DesktopEnv(gym.Env):
|
||||
|
||||
# Initialize environment variables
|
||||
self.path_to_vm = os.path.abspath(os.path.expandvars(os.path.expanduser(path_to_vm)))
|
||||
self.snapshot_name = snapshot_name
|
||||
self.tmp_dir_base: str = tmp_dir
|
||||
self.cache_dir_base: str = cache_dir
|
||||
self.vm_screen_size = screen_size
|
||||
@@ -88,16 +80,12 @@ class DesktopEnv(gym.Env):
|
||||
|
||||
os.makedirs(self.tmp_dir_base, exist_ok=True)
|
||||
|
||||
# task-aware stuffs
|
||||
# todo: handling the logic of snapshot directory
|
||||
self._set_task_info(task_config)
|
||||
|
||||
# Initialize emulator and controller
|
||||
logger.info("Initializing...")
|
||||
self._start_emulator()
|
||||
self.vm_ip = self._get_vm_ip()
|
||||
self.controller = PythonController(vm_ip=self.vm_ip)
|
||||
self.setup_controller = SetupController(vm_ip=self.vm_ip, cache_dir=self.cache_dir)
|
||||
self.setup_controller = SetupController(vm_ip=self.vm_ip, cache_dir=self.cache_dir_base)
|
||||
|
||||
# Meta info of the VM, move to the reset() function
|
||||
self.vm_platform: str = "" # self.controller.get_vm_platform()
|
||||
@@ -147,7 +135,7 @@ class DesktopEnv(gym.Env):
|
||||
raise Exception("Failed to get VM IP address!")
|
||||
|
||||
def _save_state(self):
|
||||
_execute_command(["vmrun", "-T", "ws" "snapshot", self.path_to_vm, self.snapshot_path])
|
||||
_execute_command(["vmrun", "-T", "ws" "snapshot", self.path_to_vm, self.snapshot_name])
|
||||
|
||||
def _get_screenshot(self):
|
||||
# random_uuid = str(uuid.uuid4())
|
||||
@@ -167,7 +155,6 @@ class DesktopEnv(gym.Env):
|
||||
return screenshot_image_path
|
||||
|
||||
def _set_task_info(self, task_config: Dict[str, Any]):
|
||||
self.snapshot_path = task_config["snapshot"] # todo: save the snapshot when first start the environment, and then revert to it when reset
|
||||
self.task_id: str = task_config["id"]
|
||||
self.cache_dir: str = os.path.join(self.cache_dir_base, self.task_id)
|
||||
os.makedirs(self.cache_dir, exist_ok=True)
|
||||
@@ -239,8 +226,8 @@ class DesktopEnv(gym.Env):
|
||||
)
|
||||
os.makedirs(os.path.join(self.tmp_dir, "screenshots"))
|
||||
|
||||
logger.info("Reverting to snapshot to {}...".format(self.snapshot_path))
|
||||
_execute_command(["vmrun", "-T", "ws", "revertToSnapshot", self.path_to_vm, self.snapshot_path])
|
||||
logger.info("Reverting to snapshot to {}...".format(self.snapshot_name))
|
||||
_execute_command(["vmrun", "-T", "ws", "revertToSnapshot", self.path_to_vm, self.snapshot_name])
|
||||
time.sleep(5)
|
||||
|
||||
print(self.vm_screen_size)
|
||||
|
||||
Reference in New Issue
Block a user