feat: implement proxy management for AWS VM provider and enhance task configuration handling
This commit is contained in:
@@ -54,13 +54,17 @@ class DesktopEnv(gym.Env):
|
||||
"""
|
||||
# Initialize VM manager and vitualization provider
|
||||
self.region = region
|
||||
self.provider_name = provider_name
|
||||
|
||||
# Default TODO:
|
||||
self.server_port = 5000
|
||||
self.chromium_port = 9222
|
||||
self.vnc_port = 8006
|
||||
self.vlc_port = 8080
|
||||
self.manager, self.provider = create_vm_manager_and_provider(provider_name, region)
|
||||
|
||||
# Initialize with default (no proxy) provider
|
||||
self.current_use_proxy = False
|
||||
self.manager, self.provider = create_vm_manager_and_provider(provider_name, region, use_proxy=False)
|
||||
|
||||
self.os_type = os_type
|
||||
|
||||
@@ -149,6 +153,32 @@ class DesktopEnv(gym.Env):
|
||||
self._step_no = 0
|
||||
self.action_history.clear()
|
||||
|
||||
# Check and handle proxy requirement changes BEFORE starting emulator
|
||||
if task_config is not None:
|
||||
task_use_proxy = task_config.get("proxy", False)
|
||||
if task_use_proxy != self.current_use_proxy:
|
||||
logger.info(f"Task proxy requirement changed: {self.current_use_proxy} -> {task_use_proxy}")
|
||||
|
||||
# Close current provider if it exists
|
||||
if hasattr(self, 'provider') and self.provider:
|
||||
try:
|
||||
self.provider.stop_emulator(self.path_to_vm)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to stop current provider: {e}")
|
||||
|
||||
# Create new provider with appropriate proxy setting
|
||||
self.current_use_proxy = task_use_proxy
|
||||
self.manager, self.provider = create_vm_manager_and_provider(
|
||||
self.provider_name,
|
||||
self.region,
|
||||
use_proxy=task_use_proxy
|
||||
)
|
||||
|
||||
if task_use_proxy:
|
||||
logger.info("Using proxy-enabled AWS provider.")
|
||||
else:
|
||||
logger.info("Using regular AWS provider.")
|
||||
|
||||
logger.info("Reverting to snapshot to {}...".format(self.snapshot_name))
|
||||
self._revert_to_snapshot()
|
||||
logger.info("Starting emulator...")
|
||||
@@ -184,12 +214,17 @@ class DesktopEnv(gym.Env):
|
||||
return self.controller.get_vm_screen_size()
|
||||
|
||||
def _set_task_info(self, task_config: Dict[str, Any]):
|
||||
"""Set task info (proxy logic is handled in reset method)"""
|
||||
self.task_id: str = task_config["id"]
|
||||
self.cache_dir: str = os.path.join(self.cache_dir_base, self.task_id)
|
||||
os.makedirs(self.cache_dir, exist_ok=True)
|
||||
self.instruction = task_config["instruction"]
|
||||
self.config = task_config["config"] if "config" in task_config else []
|
||||
|
||||
self._set_evaluator_info(task_config)
|
||||
|
||||
def _set_evaluator_info(self, task_config: Dict[str, Any]):
|
||||
"""Set evaluator information from task config"""
|
||||
# evaluator dict
|
||||
# func -> metric function string, or list of metric function strings
|
||||
# conj -> conjunction of multiple metrics if func is a list with length > 1, "and"/"or"
|
||||
|
||||
Reference in New Issue
Block a user