Merge remote-tracking branch 'upstream/feat/aws-provider-support'

This commit is contained in:
yuanmengqi
2025-06-10 02:36:46 +00:00
3 changed files with 74 additions and 44 deletions

View File

@@ -49,7 +49,7 @@ class SetupController:
def reset_cache_dir(self, cache_dir: str): def reset_cache_dir(self, cache_dir: str):
self.cache_dir = cache_dir self.cache_dir = cache_dir
def setup(self, config: List[Dict[str, Any]]): def setup(self, config: List[Dict[str, Any]])-> bool:
""" """
Args: Args:
config (List[Dict[str, Any]]): list of dict like {str: Any}. each config (List[Dict[str, Any]]): list of dict like {str: Any}. each
@@ -64,13 +64,18 @@ class SetupController:
# make sure connection can be established # make sure connection can be established
logger.info(f"try to connect {self.http_server}") logger.info(f"try to connect {self.http_server}")
retry = 0 retry = 0
while retry < 30: while retry < 50:
try: try:
_ = requests.get(self.http_server + "/terminal") _ = requests.get(self.http_server + "/terminal")
break break
except: except:
time.sleep(5) time.sleep(5)
retry += 1 retry += 1
logger.info(f"retry: {retry}/50")
if retry == 50:
return False
for cfg in config: for cfg in config:
config_type: str = cfg["type"] config_type: str = cfg["type"]
@@ -84,6 +89,8 @@ class SetupController:
getattr(self, setup_function)(**parameters) getattr(self, setup_function)(**parameters)
logger.info("SETUP: %s(%s)", setup_function, str(parameters)) logger.info("SETUP: %s(%s)", setup_function, str(parameters))
return True
def _download_setup(self, files: List[Dict[str, str]]): def _download_setup(self, files: List[Dict[str, str]]):
""" """

View File

@@ -18,15 +18,15 @@ logger = logging.getLogger("desktopenv.env")
Metric = Callable[[Any, Any], float] Metric = Callable[[Any, Any], float]
Getter = Callable[[gym.Env, Dict[str, Any]], Any] Getter = Callable[[gym.Env, Dict[str, Any]], Any]
MAX_RETRIES = 5
class DesktopEnv(gym.Env): class DesktopEnv(gym.Env):
""" """
DesktopEnv with OpenAI Gym interface. It provides a desktop environment for setting and evaluating desktop automation tasks. DesktopEnv with OpenAI Gym interface. It provides a desktop environment for setting and evaluating desktop automation tasks.
""" """
#TODO:provider_name: str = "vmware",
def __init__( def __init__(
self, self,
provider_name: str = "aws", provider_name: str = "vmware",
region: str = None, region: str = None,
path_to_vm: str = None, path_to_vm: str = None,
snapshot_name: str = "init_state", snapshot_name: str = "init_state",
@@ -37,6 +37,7 @@ class DesktopEnv(gym.Env):
require_a11y_tree: bool = True, require_a11y_tree: bool = True,
require_terminal: bool = False, require_terminal: bool = False,
os_type: str = "Ubuntu", os_type: str = "Ubuntu",
enable_proxy: bool = False,
): ):
""" """
Args: Args:
@@ -51,10 +52,13 @@ class DesktopEnv(gym.Env):
headless (bool): whether to run the VM in headless mode headless (bool): whether to run the VM in headless mode
require_a11y_tree (bool): whether to require accessibility tree require_a11y_tree (bool): whether to require accessibility tree
require_terminal (bool): whether to require terminal output require_terminal (bool): whether to require terminal output
os_type (str): operating system type, default to "Ubuntu"
enable_proxy (bool): whether to enable proxy support, default to False
""" """
# Initialize VM manager and virtualization provider # Initialize VM manager and virtualization provider
self.region = region self.region = region
self.provider_name = provider_name self.provider_name = provider_name
self.enable_proxy = enable_proxy # Store proxy enablement setting
# Default TODO: # Default TODO:
self.server_port = 5000 self.server_port = 5000
@@ -145,6 +149,7 @@ class DesktopEnv(gym.Env):
self.provider.stop_emulator(self.path_to_vm) self.provider.stop_emulator(self.path_to_vm)
def reset(self, task_config: Optional[Dict[str, Any]] = None, seed=None, options=None) -> Dict[str, Any]: def reset(self, task_config: Optional[Dict[str, Any]] = None, seed=None, options=None) -> Dict[str, Any]:
# Reset to certain task in OSWorld # Reset to certain task in OSWorld
logger.info("Resetting environment...") logger.info("Resetting environment...")
logger.info("Switching task...") logger.info("Switching task...")
@@ -153,48 +158,66 @@ class DesktopEnv(gym.Env):
self._step_no = 0 self._step_no = 0
self.action_history.clear() self.action_history.clear()
# Check and handle proxy requirement changes BEFORE starting emulator for attempt in range(MAX_RETRIES):
if task_config is not None: # Check and handle proxy requirement changes BEFORE starting emulator
task_use_proxy = task_config.get("proxy", False) if task_config is not None:
if task_use_proxy != self.current_use_proxy: # Only consider task proxy requirement if proxy is enabled at system level
logger.info(f"Task proxy requirement changed: {self.current_use_proxy} -> {task_use_proxy}") task_use_proxy = task_config.get("proxy", False) and self.enable_proxy
if not self.enable_proxy and task_config.get("proxy", False):
logger.info("Task requires proxy but proxy is disabled at system level, ignoring proxy requirement.")
# Close current provider if it exists if task_use_proxy != self.current_use_proxy:
if hasattr(self, 'provider') and self.provider: logger.info(f"Task proxy requirement changed: {self.current_use_proxy} -> {task_use_proxy}")
try:
self.provider.stop_emulator(self.path_to_vm) # Close current provider if it exists
except Exception as e: if hasattr(self, 'provider') and self.provider:
logger.warning(f"Failed to stop current provider: {e}") try:
self.provider.stop_emulator(self.path_to_vm)
# Create new provider with appropriate proxy setting except Exception as e:
self.current_use_proxy = task_use_proxy logger.warning(f"Failed to stop current provider: {e}")
self.manager, self.provider = create_vm_manager_and_provider(
self.provider_name, # Create new provider with appropriate proxy setting
self.region, self.current_use_proxy = task_use_proxy
use_proxy=task_use_proxy self.manager, self.provider = create_vm_manager_and_provider(
) self.provider_name,
self.region,
if task_use_proxy: use_proxy=task_use_proxy
logger.info("Using proxy-enabled AWS provider.") )
else:
logger.info("Using regular AWS provider.") if task_use_proxy:
logger.info("Using proxy-enabled AWS provider.")
else:
logger.info("Using regular AWS provider.")
logger.info("Reverting to snapshot to {}...".format(self.snapshot_name))
self._revert_to_snapshot()
logger.info("Starting emulator...")
self._start_emulator()
logger.info("Emulator started.")
if task_config is not None:
self._set_task_info(task_config)
self.setup_controller.reset_cache_dir(self.cache_dir)
logger.info("Setting up environment...")
self.setup_controller.setup(self.config)
logger.info("Environment setup complete.")
if task_config.get("proxy", False): logger.info("Reverting to snapshot to {}...".format(self.snapshot_name))
# If using proxy, set up the proxy configuration self._revert_to_snapshot()
self.setup_controller._proxy_setup() logger.info("Starting emulator...")
self._start_emulator()
logger.info("Emulator started.")
if task_config is not None:
self._set_task_info(task_config)
self.setup_controller.reset_cache_dir(self.cache_dir)
logger.info("Setting up environment...")
success = self.setup_controller.setup(self.config)
if success:
break
else:
logger.error(
"Environment setup failed, retrying (%d/%d)...",
attempt + 1,
MAX_RETRIES,
)
time.sleep(5)
else:
break
logger.info("Environment setup complete.")
if task_config.get("proxy", False) and self.enable_proxy:
# If using proxy and proxy is enabled, set up the proxy configuration
self.setup_controller._proxy_setup()
observation = self._get_obs() observation = self._get_obs()
return observation return observation

View File

@@ -33,7 +33,7 @@ DEFAULT_REGION = "us-east-1"
# todo: publish the AMI images # todo: publish the AMI images
# ami-05e7d7bd279ea4f14 # ami-05e7d7bd279ea4f14
IMAGE_ID_MAP = { IMAGE_ID_MAP = {
"us-east-1": "ami-00509b93f2216f419", "us-east-1": "ami-00674d875de9addc1",
"ap-east-1": "ami-0c092a5b8be4116f5", "ap-east-1": "ami-0c092a5b8be4116f5",
} }