Improve code logic for password & resolution
This commit is contained in:
@@ -26,18 +26,19 @@ class DesktopEnv(gym.Env):
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
provider_name: str = "vmware",
|
||||
provider_name: str = "aws",
|
||||
region: str = None,
|
||||
path_to_vm: str = None,
|
||||
snapshot_name: str = "init_state",
|
||||
action_space: str = "computer_13",
|
||||
cache_dir: str = "cache",
|
||||
screen_size: Tuple[int] = (int(os.environ.get("SCREEN_WIDTH", 1920)), int(os.environ.get("SCREEN_HEIGHT", 1080))),
|
||||
screen_size: Tuple[int] = (1920, 1080),
|
||||
headless: bool = False,
|
||||
require_a11y_tree: bool = True,
|
||||
require_terminal: bool = False,
|
||||
os_type: str = "Ubuntu",
|
||||
enable_proxy: bool = False,
|
||||
client_password: str = "",
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@@ -59,6 +60,16 @@ class DesktopEnv(gym.Env):
|
||||
self.region = region
|
||||
self.provider_name = provider_name
|
||||
self.enable_proxy = enable_proxy # Store proxy enablement setting
|
||||
if client_password == "":
|
||||
if self.provider_name == "aws":
|
||||
self.client_password = "osworld-public-evaluation"
|
||||
else:
|
||||
self.client_password = "password"
|
||||
else:
|
||||
self.client_password = client_password
|
||||
|
||||
self.screen_width = screen_size[0]
|
||||
self.screen_height = screen_size[1]
|
||||
|
||||
# Default
|
||||
self.server_port = 5000
|
||||
@@ -88,7 +99,7 @@ class DesktopEnv(gym.Env):
|
||||
if provider_name in {"vmware", "virtualbox"} else path_to_vm
|
||||
else:
|
||||
|
||||
self.path_to_vm = self.manager.get_vm_path(os_type=self.os_type, region=region)
|
||||
self.path_to_vm = self.manager.get_vm_path(os_type=self.os_type, region=region, screen_size=(self.screen_width, self.screen_height))
|
||||
try:
|
||||
self.snapshot_name = snapshot_name
|
||||
self.cache_dir_base: str = cache_dir
|
||||
@@ -136,7 +147,7 @@ class DesktopEnv(gym.Env):
|
||||
self.vnc_port = int(vm_ip_ports[3])
|
||||
self.vlc_port = int(vm_ip_ports[4])
|
||||
self.controller = PythonController(vm_ip=self.vm_ip, server_port=self.server_port)
|
||||
self.setup_controller = SetupController(vm_ip=self.vm_ip, server_port=self.server_port, chromium_port=self.chromium_port, vlc_port=self.vlc_port, cache_dir=self.cache_dir_base)
|
||||
self.setup_controller = SetupController(vm_ip=self.vm_ip, server_port=self.server_port, chromium_port=self.chromium_port, vlc_port=self.vlc_port, cache_dir=self.cache_dir_base, client_password=self.client_password, screen_width=self.screen_width, screen_height=self.screen_height)
|
||||
|
||||
def _revert_to_snapshot(self):
|
||||
# Revert to certain snapshot of the virtual machine, and refresh the path to vm and ip of vm
|
||||
@@ -197,7 +208,7 @@ class DesktopEnv(gym.Env):
|
||||
if task_config is not None:
|
||||
if task_config.get("proxy", False) and self.enable_proxy:
|
||||
# If using proxy and proxy is enabled, set up the proxy configuration
|
||||
self.setup_controller._proxy_setup()
|
||||
self.setup_controller._proxy_setup(self.client_password)
|
||||
self._set_task_info(task_config)
|
||||
self.setup_controller.reset_cache_dir(self.cache_dir)
|
||||
logger.info("Setting up environment...")
|
||||
|
||||
Reference in New Issue
Block a user