Refactor experiments and agent implementation

This commit is contained in:
Timothyxxx
2024-03-14 22:32:49 +08:00
parent 71ca8fbe1c
commit 44ff027801
8 changed files with 359 additions and 1944 deletions

View File

@@ -53,8 +53,8 @@ class DesktopEnv(gym.Env):
def __init__(
self,
path_to_vm: str,
snapshot_name: str ="init_state",
action_space: str = "computer_13",
task_config: Dict[str, Any] = None,
tmp_dir: str = "tmp",
cache_dir: str = "cache",
screen_size: Tuple[int] = (1920, 1080),
@@ -64,15 +64,6 @@ class DesktopEnv(gym.Env):
Args:
path_to_vm (str): path to .vmx file
action_space (str): "computer_13" | "pyautogui"
task_config (Dict[str, Any]): manages task configs integratedly,
including
* base snapshot
* task id (uuid)
* instruction
* setup config
* evaluator config
tmp_dir (str): temporary directory to store trajectory stuffs like
the extracted screenshots
cache_dir (str): cache directory to cache task-related stuffs like
@@ -81,6 +72,7 @@ class DesktopEnv(gym.Env):
# Initialize environment variables
self.path_to_vm = os.path.abspath(os.path.expandvars(os.path.expanduser(path_to_vm)))
self.snapshot_name = snapshot_name
self.tmp_dir_base: str = tmp_dir
self.cache_dir_base: str = cache_dir
self.vm_screen_size = screen_size
@@ -88,16 +80,12 @@ class DesktopEnv(gym.Env):
os.makedirs(self.tmp_dir_base, exist_ok=True)
# task-aware stuffs
# todo: handling the logic of snapshot directory
self._set_task_info(task_config)
# Initialize emulator and controller
logger.info("Initializing...")
self._start_emulator()
self.vm_ip = self._get_vm_ip()
self.controller = PythonController(vm_ip=self.vm_ip)
self.setup_controller = SetupController(vm_ip=self.vm_ip, cache_dir=self.cache_dir)
self.setup_controller = SetupController(vm_ip=self.vm_ip, cache_dir=self.cache_dir_base)
# Meta info of the VM, move to the reset() function
self.vm_platform: str = "" # self.controller.get_vm_platform()
@@ -147,7 +135,7 @@ class DesktopEnv(gym.Env):
raise Exception("Failed to get VM IP address!")
def _save_state(self):
_execute_command(["vmrun", "-T", "ws" "snapshot", self.path_to_vm, self.snapshot_path])
_execute_command(["vmrun", "-T", "ws" "snapshot", self.path_to_vm, self.snapshot_name])
def _get_screenshot(self):
# random_uuid = str(uuid.uuid4())
@@ -167,7 +155,6 @@ class DesktopEnv(gym.Env):
return screenshot_image_path
def _set_task_info(self, task_config: Dict[str, Any]):
self.snapshot_path = task_config["snapshot"] # todo: save the snapshot when first start the environment, and then revert to it when reset
self.task_id: str = task_config["id"]
self.cache_dir: str = os.path.join(self.cache_dir_base, self.task_id)
os.makedirs(self.cache_dir, exist_ok=True)
@@ -239,8 +226,8 @@ class DesktopEnv(gym.Env):
)
os.makedirs(os.path.join(self.tmp_dir, "screenshots"))
logger.info("Reverting to snapshot to {}...".format(self.snapshot_path))
_execute_command(["vmrun", "-T", "ws", "revertToSnapshot", self.path_to_vm, self.snapshot_path])
logger.info("Reverting to snapshot to {}...".format(self.snapshot_name))
_execute_command(["vmrun", "-T", "ws", "revertToSnapshot", self.path_to_vm, self.snapshot_name])
time.sleep(5)
print(self.vm_screen_size)