diff --git a/desktop_env/desktop_env.py b/desktop_env/desktop_env.py index 488d4ef..c0893ab 100644 --- a/desktop_env/desktop_env.py +++ b/desktop_env/desktop_env.py @@ -194,7 +194,13 @@ class DesktopEnv(gym.Env): self.require_terminal = require_terminal # Initialize emulator and controller - if provider_name != "docker": # Check if this is applicable to other VM providers + # Docker provider needs delayed initialization due to container lifecycle + if provider_name == "docker": + logger.info("Docker provider detected - will initialize on first reset()") + # Initialize controllers as None for Docker - they'll be set up in reset() + self.controller = None + self.setup_controller = None + else: logger.info("Initializing...") self._start_emulator() @@ -289,6 +295,11 @@ class DesktopEnv(gym.Env): self.is_environment_used = False else: logger.info("Environment is clean, skipping snapshot revert (provider: {}).".format(self.provider_name)) + + # Initialize Docker provider controllers if not already done + if self.provider_name == "docker" and self.controller is None: + logger.info("Initializing Docker provider controllers...") + self._start_emulator() if task_config is not None: if task_config.get("proxy", False) and self.enable_proxy: diff --git a/run.py b/run.py index a915ac2..3730902 100644 --- a/run.py +++ b/run.py @@ -67,6 +67,10 @@ def config() -> argparse.Namespace: # environment config parser.add_argument("--path_to_vm", type=str, default=None) + parser.add_argument( + "--provider_name", type=str, default="vmware", + help="Virtualization provider (vmware, docker, aws, azure, gcp, virtualbox)" + ) parser.add_argument( "--headless", action="store_true", help="Run in headless machine" ) @@ -119,6 +123,7 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None: # set wandb project cfg_args = { "path_to_vm": args.path_to_vm, + "provider_name": args.provider_name, "headless": args.headless, "action_space": args.action_space, "observation_type": args.observation_type, @@ -146,6 +151,7 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None: ) env = DesktopEnv( + provider_name=args.provider_name, path_to_vm=args.path_to_vm, action_space=agent.action_space, screen_size=(args.screen_width, args.screen_height), @@ -199,9 +205,11 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None: ) except Exception as e: logger.error(f"Exception in {domain}/{example_id}: {e}") - env.controller.end_recording( - os.path.join(example_result_dir, "recording.mp4") - ) + # Only attempt to end recording if controller exists (not Docker provider) + if hasattr(env, 'controller') and env.controller is not None: + env.controller.end_recording( + os.path.join(example_result_dir, "recording.mp4") + ) with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f: f.write( json.dumps(