Add --provider_name parameter to run.py and fix Docker provider initialization (#277)

- Add command-line argument --provider_name to support flexible provider selection
- Default provider remains vmware for backward compatibility
- Fix Docker provider controller initialization issue with delayed setup
- Add safety checks for controller existence in error handling

This enables users to specify different virtualization providers directly
from the command line and resolves Docker container lifecycle issues.
This commit is contained in:
张逸群
2025-07-23 04:09:36 +08:00
committed by GitHub
parent 73de48af75
commit 4d6e0fd031
2 changed files with 23 additions and 4 deletions

View File

@@ -194,7 +194,13 @@ class DesktopEnv(gym.Env):
self.require_terminal = require_terminal
# Initialize emulator and controller
if provider_name != "docker": # Check if this is applicable to other VM providers
# Docker provider needs delayed initialization due to container lifecycle
if provider_name == "docker":
logger.info("Docker provider detected - will initialize on first reset()")
# Initialize controllers as None for Docker - they'll be set up in reset()
self.controller = None
self.setup_controller = None
else:
logger.info("Initializing...")
self._start_emulator()
@@ -289,6 +295,11 @@ class DesktopEnv(gym.Env):
self.is_environment_used = False
else:
logger.info("Environment is clean, skipping snapshot revert (provider: {}).".format(self.provider_name))
# Initialize Docker provider controllers if not already done
if self.provider_name == "docker" and self.controller is None:
logger.info("Initializing Docker provider controllers...")
self._start_emulator()
if task_config is not None:
if task_config.get("proxy", False) and self.enable_proxy:

14
run.py
View File

@@ -67,6 +67,10 @@ def config() -> argparse.Namespace:
# environment config
parser.add_argument("--path_to_vm", type=str, default=None)
parser.add_argument(
"--provider_name", type=str, default="vmware",
help="Virtualization provider (vmware, docker, aws, azure, gcp, virtualbox)"
)
parser.add_argument(
"--headless", action="store_true", help="Run in headless machine"
)
@@ -119,6 +123,7 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
# set wandb project
cfg_args = {
"path_to_vm": args.path_to_vm,
"provider_name": args.provider_name,
"headless": args.headless,
"action_space": args.action_space,
"observation_type": args.observation_type,
@@ -146,6 +151,7 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
)
env = DesktopEnv(
provider_name=args.provider_name,
path_to_vm=args.path_to_vm,
action_space=agent.action_space,
screen_size=(args.screen_width, args.screen_height),
@@ -199,9 +205,11 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
)
except Exception as e:
logger.error(f"Exception in {domain}/{example_id}: {e}")
env.controller.end_recording(
os.path.join(example_result_dir, "recording.mp4")
)
# Only attempt to end recording if controller exists (not Docker provider)
if hasattr(env, 'controller') and env.controller is not None:
env.controller.end_recording(
os.path.join(example_result_dir, "recording.mp4")
)
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
f.write(
json.dumps(