Add --provider_name parameter to run.py and fix Docker provider initialization (#277)
- Add command-line argument --provider_name to support flexible provider selection - Default provider remains vmware for backward compatibility - Fix Docker provider controller initialization issue with delayed setup - Add safety checks for controller existence in error handling This enables users to specify different virtualization providers directly from the command line and resolves Docker container lifecycle issues.
This commit is contained in:
@@ -194,7 +194,13 @@ class DesktopEnv(gym.Env):
|
||||
self.require_terminal = require_terminal
|
||||
|
||||
# Initialize emulator and controller
|
||||
if provider_name != "docker": # Check if this is applicable to other VM providers
|
||||
# Docker provider needs delayed initialization due to container lifecycle
|
||||
if provider_name == "docker":
|
||||
logger.info("Docker provider detected - will initialize on first reset()")
|
||||
# Initialize controllers as None for Docker - they'll be set up in reset()
|
||||
self.controller = None
|
||||
self.setup_controller = None
|
||||
else:
|
||||
logger.info("Initializing...")
|
||||
self._start_emulator()
|
||||
|
||||
@@ -289,6 +295,11 @@ class DesktopEnv(gym.Env):
|
||||
self.is_environment_used = False
|
||||
else:
|
||||
logger.info("Environment is clean, skipping snapshot revert (provider: {}).".format(self.provider_name))
|
||||
|
||||
# Initialize Docker provider controllers if not already done
|
||||
if self.provider_name == "docker" and self.controller is None:
|
||||
logger.info("Initializing Docker provider controllers...")
|
||||
self._start_emulator()
|
||||
|
||||
if task_config is not None:
|
||||
if task_config.get("proxy", False) and self.enable_proxy:
|
||||
|
||||
14
run.py
14
run.py
@@ -67,6 +67,10 @@ def config() -> argparse.Namespace:
|
||||
|
||||
# environment config
|
||||
parser.add_argument("--path_to_vm", type=str, default=None)
|
||||
parser.add_argument(
|
||||
"--provider_name", type=str, default="vmware",
|
||||
help="Virtualization provider (vmware, docker, aws, azure, gcp, virtualbox)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--headless", action="store_true", help="Run in headless machine"
|
||||
)
|
||||
@@ -119,6 +123,7 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
|
||||
# set wandb project
|
||||
cfg_args = {
|
||||
"path_to_vm": args.path_to_vm,
|
||||
"provider_name": args.provider_name,
|
||||
"headless": args.headless,
|
||||
"action_space": args.action_space,
|
||||
"observation_type": args.observation_type,
|
||||
@@ -146,6 +151,7 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
|
||||
)
|
||||
|
||||
env = DesktopEnv(
|
||||
provider_name=args.provider_name,
|
||||
path_to_vm=args.path_to_vm,
|
||||
action_space=agent.action_space,
|
||||
screen_size=(args.screen_width, args.screen_height),
|
||||
@@ -199,9 +205,11 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Exception in {domain}/{example_id}: {e}")
|
||||
env.controller.end_recording(
|
||||
os.path.join(example_result_dir, "recording.mp4")
|
||||
)
|
||||
# Only attempt to end recording if controller exists (not Docker provider)
|
||||
if hasattr(env, 'controller') and env.controller is not None:
|
||||
env.controller.end_recording(
|
||||
os.path.join(example_result_dir, "recording.mp4")
|
||||
)
|
||||
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
|
||||
f.write(
|
||||
json.dumps(
|
||||
|
||||
Reference in New Issue
Block a user