Merge remote-tracking branch 'upstream/feat/aws-provider-support'
This commit is contained in:
@@ -49,7 +49,7 @@ class SetupController:
|
|||||||
def reset_cache_dir(self, cache_dir: str):
|
def reset_cache_dir(self, cache_dir: str):
|
||||||
self.cache_dir = cache_dir
|
self.cache_dir = cache_dir
|
||||||
|
|
||||||
def setup(self, config: List[Dict[str, Any]]):
|
def setup(self, config: List[Dict[str, Any]])-> bool:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
config (List[Dict[str, Any]]): list of dict like {str: Any}. each
|
config (List[Dict[str, Any]]): list of dict like {str: Any}. each
|
||||||
@@ -64,13 +64,18 @@ class SetupController:
|
|||||||
# make sure connection can be established
|
# make sure connection can be established
|
||||||
logger.info(f"try to connect {self.http_server}")
|
logger.info(f"try to connect {self.http_server}")
|
||||||
retry = 0
|
retry = 0
|
||||||
while retry < 30:
|
while retry < 50:
|
||||||
try:
|
try:
|
||||||
_ = requests.get(self.http_server + "/terminal")
|
_ = requests.get(self.http_server + "/terminal")
|
||||||
break
|
break
|
||||||
except:
|
except:
|
||||||
time.sleep(5)
|
time.sleep(5)
|
||||||
retry += 1
|
retry += 1
|
||||||
|
logger.info(f"retry: {retry}/50")
|
||||||
|
|
||||||
|
if retry == 50:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
for cfg in config:
|
for cfg in config:
|
||||||
config_type: str = cfg["type"]
|
config_type: str = cfg["type"]
|
||||||
@@ -84,6 +89,8 @@ class SetupController:
|
|||||||
getattr(self, setup_function)(**parameters)
|
getattr(self, setup_function)(**parameters)
|
||||||
|
|
||||||
logger.info("SETUP: %s(%s)", setup_function, str(parameters))
|
logger.info("SETUP: %s(%s)", setup_function, str(parameters))
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
def _download_setup(self, files: List[Dict[str, str]]):
|
def _download_setup(self, files: List[Dict[str, str]]):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -18,15 +18,15 @@ logger = logging.getLogger("desktopenv.env")
|
|||||||
Metric = Callable[[Any, Any], float]
|
Metric = Callable[[Any, Any], float]
|
||||||
Getter = Callable[[gym.Env, Dict[str, Any]], Any]
|
Getter = Callable[[gym.Env, Dict[str, Any]], Any]
|
||||||
|
|
||||||
|
MAX_RETRIES = 5
|
||||||
|
|
||||||
class DesktopEnv(gym.Env):
|
class DesktopEnv(gym.Env):
|
||||||
"""
|
"""
|
||||||
DesktopEnv with OpenAI Gym interface. It provides a desktop environment for setting and evaluating desktop automation tasks.
|
DesktopEnv with OpenAI Gym interface. It provides a desktop environment for setting and evaluating desktop automation tasks.
|
||||||
"""
|
"""
|
||||||
#TODO:provider_name: str = "vmware",
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
provider_name: str = "aws",
|
provider_name: str = "vmware",
|
||||||
region: str = None,
|
region: str = None,
|
||||||
path_to_vm: str = None,
|
path_to_vm: str = None,
|
||||||
snapshot_name: str = "init_state",
|
snapshot_name: str = "init_state",
|
||||||
@@ -37,6 +37,7 @@ class DesktopEnv(gym.Env):
|
|||||||
require_a11y_tree: bool = True,
|
require_a11y_tree: bool = True,
|
||||||
require_terminal: bool = False,
|
require_terminal: bool = False,
|
||||||
os_type: str = "Ubuntu",
|
os_type: str = "Ubuntu",
|
||||||
|
enable_proxy: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@@ -51,10 +52,13 @@ class DesktopEnv(gym.Env):
|
|||||||
headless (bool): whether to run the VM in headless mode
|
headless (bool): whether to run the VM in headless mode
|
||||||
require_a11y_tree (bool): whether to require accessibility tree
|
require_a11y_tree (bool): whether to require accessibility tree
|
||||||
require_terminal (bool): whether to require terminal output
|
require_terminal (bool): whether to require terminal output
|
||||||
|
os_type (str): operating system type, default to "Ubuntu"
|
||||||
|
enable_proxy (bool): whether to enable proxy support, default to False
|
||||||
"""
|
"""
|
||||||
# Initialize VM manager and vitualization provider
|
# Initialize VM manager and vitualization provider
|
||||||
self.region = region
|
self.region = region
|
||||||
self.provider_name = provider_name
|
self.provider_name = provider_name
|
||||||
|
self.enable_proxy = enable_proxy # Store proxy enablement setting
|
||||||
|
|
||||||
# Default TODO:
|
# Default TODO:
|
||||||
self.server_port = 5000
|
self.server_port = 5000
|
||||||
@@ -145,6 +149,7 @@ class DesktopEnv(gym.Env):
|
|||||||
self.provider.stop_emulator(self.path_to_vm)
|
self.provider.stop_emulator(self.path_to_vm)
|
||||||
|
|
||||||
def reset(self, task_config: Optional[Dict[str, Any]] = None, seed=None, options=None) -> Dict[str, Any]:
|
def reset(self, task_config: Optional[Dict[str, Any]] = None, seed=None, options=None) -> Dict[str, Any]:
|
||||||
|
|
||||||
# Reset to certain task in OSWorld
|
# Reset to certain task in OSWorld
|
||||||
logger.info("Resetting environment...")
|
logger.info("Resetting environment...")
|
||||||
logger.info("Switching task...")
|
logger.info("Switching task...")
|
||||||
@@ -153,48 +158,66 @@ class DesktopEnv(gym.Env):
|
|||||||
self._step_no = 0
|
self._step_no = 0
|
||||||
self.action_history.clear()
|
self.action_history.clear()
|
||||||
|
|
||||||
# Check and handle proxy requirement changes BEFORE starting emulator
|
for attempt in range(MAX_RETRIES):
|
||||||
if task_config is not None:
|
# Check and handle proxy requirement changes BEFORE starting emulator
|
||||||
task_use_proxy = task_config.get("proxy", False)
|
if task_config is not None:
|
||||||
if task_use_proxy != self.current_use_proxy:
|
# Only consider task proxy requirement if proxy is enabled at system level
|
||||||
logger.info(f"Task proxy requirement changed: {self.current_use_proxy} -> {task_use_proxy}")
|
task_use_proxy = task_config.get("proxy", False) and self.enable_proxy
|
||||||
|
if not self.enable_proxy and task_config.get("proxy", False):
|
||||||
|
logger.info("Task requires proxy but proxy is disabled at system level, ignoring proxy requirement.")
|
||||||
|
|
||||||
# Close current provider if it exists
|
if task_use_proxy != self.current_use_proxy:
|
||||||
if hasattr(self, 'provider') and self.provider:
|
logger.info(f"Task proxy requirement changed: {self.current_use_proxy} -> {task_use_proxy}")
|
||||||
try:
|
|
||||||
self.provider.stop_emulator(self.path_to_vm)
|
# Close current provider if it exists
|
||||||
except Exception as e:
|
if hasattr(self, 'provider') and self.provider:
|
||||||
logger.warning(f"Failed to stop current provider: {e}")
|
try:
|
||||||
|
self.provider.stop_emulator(self.path_to_vm)
|
||||||
# Create new provider with appropriate proxy setting
|
except Exception as e:
|
||||||
self.current_use_proxy = task_use_proxy
|
logger.warning(f"Failed to stop current provider: {e}")
|
||||||
self.manager, self.provider = create_vm_manager_and_provider(
|
|
||||||
self.provider_name,
|
# Create new provider with appropriate proxy setting
|
||||||
self.region,
|
self.current_use_proxy = task_use_proxy
|
||||||
use_proxy=task_use_proxy
|
self.manager, self.provider = create_vm_manager_and_provider(
|
||||||
)
|
self.provider_name,
|
||||||
|
self.region,
|
||||||
if task_use_proxy:
|
use_proxy=task_use_proxy
|
||||||
logger.info("Using proxy-enabled AWS provider.")
|
)
|
||||||
else:
|
|
||||||
logger.info("Using regular AWS provider.")
|
if task_use_proxy:
|
||||||
|
logger.info("Using proxy-enabled AWS provider.")
|
||||||
|
else:
|
||||||
|
logger.info("Using regular AWS provider.")
|
||||||
|
|
||||||
logger.info("Reverting to snapshot to {}...".format(self.snapshot_name))
|
|
||||||
self._revert_to_snapshot()
|
|
||||||
logger.info("Starting emulator...")
|
|
||||||
self._start_emulator()
|
|
||||||
logger.info("Emulator started.")
|
|
||||||
|
|
||||||
if task_config is not None:
|
|
||||||
self._set_task_info(task_config)
|
|
||||||
self.setup_controller.reset_cache_dir(self.cache_dir)
|
|
||||||
logger.info("Setting up environment...")
|
|
||||||
self.setup_controller.setup(self.config)
|
|
||||||
logger.info("Environment setup complete.")
|
|
||||||
|
|
||||||
if task_config.get("proxy", False):
|
logger.info("Reverting to snapshot to {}...".format(self.snapshot_name))
|
||||||
# If using proxy, set up the proxy configuration
|
self._revert_to_snapshot()
|
||||||
self.setup_controller._proxy_setup()
|
logger.info("Starting emulator...")
|
||||||
|
self._start_emulator()
|
||||||
|
logger.info("Emulator started.")
|
||||||
|
|
||||||
|
if task_config is not None:
|
||||||
|
self._set_task_info(task_config)
|
||||||
|
self.setup_controller.reset_cache_dir(self.cache_dir)
|
||||||
|
logger.info("Setting up environment...")
|
||||||
|
success = self.setup_controller.setup(self.config)
|
||||||
|
if success:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
logger.error(
|
||||||
|
"Environment setup failed, retrying (%d/%d)...",
|
||||||
|
attempt + 1,
|
||||||
|
MAX_RETRIES,
|
||||||
|
)
|
||||||
|
time.sleep(5)
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
logger.info("Environment setup complete.")
|
||||||
|
|
||||||
|
if task_config.get("proxy", False) and self.enable_proxy:
|
||||||
|
# If using proxy and proxy is enabled, set up the proxy configuration
|
||||||
|
self.setup_controller._proxy_setup()
|
||||||
|
|
||||||
observation = self._get_obs()
|
observation = self._get_obs()
|
||||||
return observation
|
return observation
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ DEFAULT_REGION = "us-east-1"
|
|||||||
# todo: public the AMI images
|
# todo: public the AMI images
|
||||||
# ami-05e7d7bd279ea4f14
|
# ami-05e7d7bd279ea4f14
|
||||||
IMAGE_ID_MAP = {
|
IMAGE_ID_MAP = {
|
||||||
"us-east-1": "ami-00509b93f2216f419",
|
"us-east-1": "ami-00674d875de9addc1",
|
||||||
"ap-east-1": "ami-0c092a5b8be4116f5",
|
"ap-east-1": "ami-0c092a5b8be4116f5",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user