Add --provider_name parameter to run.py and fix Docker provider initialization (#277)

- Add command-line argument --provider_name to support flexible provider selection
- Default provider remains vmware for backward compatibility
- Fix Docker provider controller initialization issue with delayed setup
- Add safety checks for controller existence in error handling

This enables users to specify different virtualization providers directly
from the command line and resolves Docker container lifecycle issues.
This commit is contained in:
张逸群
2025-07-23 04:09:36 +08:00
committed by GitHub
parent 73de48af75
commit 4d6e0fd031
2 changed files with 23 additions and 4 deletions

14
run.py
View File

@@ -67,6 +67,10 @@ def config() -> argparse.Namespace:
# environment config
parser.add_argument("--path_to_vm", type=str, default=None)
parser.add_argument(
"--provider_name", type=str, default="vmware",
help="Virtualization provider (vmware, docker, aws, azure, gcp, virtualbox)"
)
parser.add_argument(
"--headless", action="store_true", help="Run in headless machine"
)
@@ -119,6 +123,7 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
# set wandb project
cfg_args = {
"path_to_vm": args.path_to_vm,
"provider_name": args.provider_name,
"headless": args.headless,
"action_space": args.action_space,
"observation_type": args.observation_type,
@@ -146,6 +151,7 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
)
env = DesktopEnv(
provider_name=args.provider_name,
path_to_vm=args.path_to_vm,
action_space=agent.action_space,
screen_size=(args.screen_width, args.screen_height),
@@ -199,9 +205,11 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
)
except Exception as e:
logger.error(f"Exception in {domain}/{example_id}: {e}")
env.controller.end_recording(
os.path.join(example_result_dir, "recording.mp4")
)
# Only attempt to end recording if controller exists (not Docker provider)
if hasattr(env, 'controller') and env.controller is not None:
env.controller.end_recording(
os.path.join(example_result_dir, "recording.mp4")
)
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
f.write(
json.dumps(