Add --provider_name parameter to run.py and fix Docker provider initialization (#277)

- Add command-line argument --provider_name to support flexible provider selection
- Default provider remains vmware for backward compatibility
- Fix Docker provider controller initialization issue with delayed setup
- Add safety checks for controller existence in error handling

This enables users to specify different virtualization providers directly
from the command line and resolves Docker container lifecycle issues.
This commit is contained in:
张逸群
2025-07-23 04:09:36 +08:00
committed by GitHub
parent 73de48af75
commit 4d6e0fd031
2 changed files with 23 additions and 4 deletions

View File

@@ -194,7 +194,13 @@ class DesktopEnv(gym.Env):
self.require_terminal = require_terminal self.require_terminal = require_terminal
# Initialize emulator and controller # Initialize emulator and controller
if provider_name != "docker": # Check if this is applicable to other VM providers # Docker provider needs delayed initialization due to container lifecycle
if provider_name == "docker":
logger.info("Docker provider detected - will initialize on first reset()")
# Initialize controllers as None for Docker - they'll be set up in reset()
self.controller = None
self.setup_controller = None
else:
logger.info("Initializing...") logger.info("Initializing...")
self._start_emulator() self._start_emulator()
@@ -289,6 +295,11 @@ class DesktopEnv(gym.Env):
self.is_environment_used = False self.is_environment_used = False
else: else:
logger.info("Environment is clean, skipping snapshot revert (provider: {}).".format(self.provider_name)) logger.info("Environment is clean, skipping snapshot revert (provider: {}).".format(self.provider_name))
# Initialize Docker provider controllers if not already done
if self.provider_name == "docker" and self.controller is None:
logger.info("Initializing Docker provider controllers...")
self._start_emulator()
if task_config is not None: if task_config is not None:
if task_config.get("proxy", False) and self.enable_proxy: if task_config.get("proxy", False) and self.enable_proxy:

14
run.py
View File

@@ -67,6 +67,10 @@ def config() -> argparse.Namespace:
# environment config # environment config
parser.add_argument("--path_to_vm", type=str, default=None) parser.add_argument("--path_to_vm", type=str, default=None)
parser.add_argument(
"--provider_name", type=str, default="vmware",
help="Virtualization provider (vmware, docker, aws, azure, gcp, virtualbox)"
)
parser.add_argument( parser.add_argument(
"--headless", action="store_true", help="Run in headless machine" "--headless", action="store_true", help="Run in headless machine"
) )
@@ -119,6 +123,7 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
# set wandb project # set wandb project
cfg_args = { cfg_args = {
"path_to_vm": args.path_to_vm, "path_to_vm": args.path_to_vm,
"provider_name": args.provider_name,
"headless": args.headless, "headless": args.headless,
"action_space": args.action_space, "action_space": args.action_space,
"observation_type": args.observation_type, "observation_type": args.observation_type,
@@ -146,6 +151,7 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
) )
env = DesktopEnv( env = DesktopEnv(
provider_name=args.provider_name,
path_to_vm=args.path_to_vm, path_to_vm=args.path_to_vm,
action_space=agent.action_space, action_space=agent.action_space,
screen_size=(args.screen_width, args.screen_height), screen_size=(args.screen_width, args.screen_height),
@@ -199,9 +205,11 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
) )
except Exception as e: except Exception as e:
logger.error(f"Exception in {domain}/{example_id}: {e}") logger.error(f"Exception in {domain}/{example_id}: {e}")
env.controller.end_recording( # Only attempt to end recording if controller exists (not Docker provider)
os.path.join(example_result_dir, "recording.mp4") if hasattr(env, 'controller') and env.controller is not None:
) env.controller.end_recording(
os.path.join(example_result_dir, "recording.mp4")
)
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f: with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
f.write( f.write(
json.dumps( json.dumps(