diff --git a/desktop_env/desktop_env.py b/desktop_env/desktop_env.py index 543c17d..096dc54 100644 --- a/desktop_env/desktop_env.py +++ b/desktop_env/desktop_env.py @@ -101,7 +101,7 @@ class DesktopEnv(gym.Env): provider_name: str = "vmware", region: str = None, path_to_vm: str = None, - snapshot_name: str = "init_state", + snapshot_name: str = "snapshot", action_space: str = "pyautogui", cache_dir: str = "cache", screen_size: Tuple[int] = (int(os.environ.get("SCREEN_WIDTH", 1920)), int(os.environ.get("SCREEN_HEIGHT", 1080))), @@ -117,7 +117,7 @@ class DesktopEnv(gym.Env): provider_name (str): virtualization provider name, default to "vmware" region (str): the region for allocate machines, work for cloud services, default to "us-east-1" path_to_vm (str): path to .vmx file - snapshot_name (str): snapshot name to revert to, default to "init_state" + snapshot_name (str): snapshot name to revert to, default to "snapshot" action_space (str): "computer_13" | "pyautogui" cache_dir (str): cache directory to cache task-related stuffs like reference file for evaluation @@ -265,7 +265,7 @@ class DesktopEnv(gym.Env): self.current_use_proxy = task_use_proxy if self.is_environment_used: - logger.info("Environment has been used, reverting to snapshot {}...".format(self.snapshot_name)) + logger.info("Environment has been used, reverting to snapshot: {}...".format(self.snapshot_name)) self._revert_to_snapshot() logger.info("Starting emulator...") self._start_emulator() @@ -402,6 +402,7 @@ class DesktopEnv(gym.Env): if self.action_space == "computer_13": # the set of all possible actions defined in the action representation + logger.info(f"======executing here======{self.action_space}========================") self.controller.execute_action(action) elif self.action_space == "pyautogui" or self.action_space == "claude_computer_use": if action in ['WAIT', 'FAIL', 'DONE'] or (type(action) == dict and action.get('action_type') in ['WAIT', 'FAIL', 'DONE']): @@ -411,6 +412,8 @@ class DesktopEnv(gym.Env): if type(action) == str: # Fix PyAutoGUI '<' character bug before execution fixed_command = _fix_pyautogui_less_than_bug(action) + logger.info(f"======executing here======{self.action_space}========================") + logger.info(f"Fixed command: {fixed_command}") self.controller.execute_python_command(fixed_command) elif type(action) == dict: # Fix PyAutoGUI '<' character bug before execution diff --git a/lib_run_single.py b/lib_run_single.py index f159fc9..2262ea5 100644 --- a/lib_run_single.py +++ b/lib_run_single.py @@ -13,30 +13,44 @@ def run_single_example(agent, env, example, max_steps, instruction, args, exampl runtime_logger = setup_logger(example, example_result_dir) # Reset environment first to get fresh VM IP - env.reset(task_config=example) + # env.reset(task_config=example) + # logger.info("=======Environment reset completed=======") - # Reset agent with fresh VM IP (for snapshot reverts) - try: - agent.reset(runtime_logger, vm_ip=env.vm_ip) - except Exception as e: - agent.reset(vm_ip=env.vm_ip) + # # Reset agent with fresh VM IP (for snapshot reverts) + # try: + # agent.reset(runtime_logger, vm_ip=env.vm_ip) + # except Exception as e: + # agent.reset(vm_ip=env.vm_ip) - time.sleep(60) # Wait for the environment to be ready + # time.sleep(10) # Wait for the environment to be ready + + # get initial observation + logger.info("Getting initial observation...") obs = env._get_obs() # Get the initial observation + logger.info("Initial observation obtained.") done = False step_idx = 0 - env.controller.start_recording() + if getattr(args, 'enable_recording', False): + env.controller.start_recording() while not done and step_idx < max_steps: + logger.info(f"Step {step_idx + 1} prediction...") response, actions = agent.predict( instruction, obs ) + logger.info(f"Response: {response}") + logger.info(f"Actions: {actions}") + + logger.info(f"Executing actions...") for action in actions: # Capture the timestamp before executing the action action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S%f") logger.info("Step %d: %s", step_idx + 1, action) + + logger.info("执行动作中...") obs, reward, done, info = env.step(action, args.sleep_after_execution) - + logger.info("动作执行完成。") + logger.info("Reward: %.2f", reward) logger.info("Done: %s", done) # Save screenshot and trajectory information @@ -69,7 +83,8 @@ def run_single_example(agent, env, example, max_steps, instruction, args, exampl # Log task completion to results.json log_task_completion(example, result, example_result_dir, args) - env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4")) + if getattr(args, 'enable_recording', False): + env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4")) def setup_logger(example, example_result_dir): diff --git a/quickstart.py b/quickstart.py index b4240d3..009d4ba 100644 --- a/quickstart.py +++ b/quickstart.py @@ -1,5 +1,11 @@ -from desktop_env.desktop_env import DesktopEnv import argparse +import logging + +from desktop_env.desktop_env import DesktopEnv + +logging.basicConfig( + level=logging.INFO, +) example = { "id": "94d95f96-9699-4208-98ba-3c3119edf9c2", diff --git a/run.py b/run.py index 75148e3..e8844f6 100644 --- a/run.py +++ b/run.py @@ -86,6 +86,7 @@ def config() -> argparse.Namespace: parser.add_argument("--screen_height", type=int, default=1080) parser.add_argument("--sleep_after_execution", type=float, default=0.0) parser.add_argument("--max_steps", type=int, default=15) + parser.add_argument("--enable_recording", action="store_true", help="Enable video recording (disabled by default)") # agent config parser.add_argument("--max_trajectory_length", type=int, default=3) @@ -94,10 +95,10 @@ def config() -> argparse.Namespace: ) # lm config - parser.add_argument("--model", type=str, default="gpt-4o") + parser.add_argument("--model", type=str, default="gpt-4-vision-preview") parser.add_argument("--temperature", type=float, default=1.0) parser.add_argument("--top_p", type=float, default=0.9) - parser.add_argument("--max_tokens", type=int, default=1500) + parser.add_argument("--max_tokens", type=int, default=16384) parser.add_argument("--stop_token", type=str, default=None) # example config @@ -147,6 +148,8 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None: action_space=args.action_space, observation_type=args.observation_type, max_trajectory_length=args.max_trajectory_length, + screen_width=args.screen_width, + screen_height=args.screen_height, ) env = DesktopEnv( @@ -155,11 +158,31 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None: action_space=agent.action_space, screen_size=(args.screen_width, args.screen_height), headless=args.headless, - os_type = "Ubuntu", + os_type = "Windows", require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"], ) + # get actual VM screen size after environment initialization + try: + actual_screen_size = env.vm_screen_size + if actual_screen_size and 'width' in actual_screen_size and 'height' in actual_screen_size: + actual_width = actual_screen_size['width'] + actual_height = actual_screen_size['height'] + logger.info(f"Actual VM screen size: {actual_width}x{actual_height}") + + # update agent's screen size if different + if actual_width != args.screen_width or actual_height != args.screen_height: + logger.warning(f"Screen size mismatch! Expected: {args.screen_width}x{args.screen_height}, Actual: {actual_width}x{actual_height}") + agent.screen_width = actual_width + agent.screen_height = actual_height + # replace in system message as well + agent.system_message = agent.system_message.replace( + f"({args.screen_width}, {args.screen_height})", + f"({actual_width}, {actual_height})" + ) + except Exception as e: + logger.warning(f"Unable to get actual VM screen size: {e}") for domain in tqdm(test_all_meta, desc="Domain"): for example_id in tqdm(test_all_meta[domain], desc="Example", leave=False): config_file = os.path.join( @@ -204,8 +227,8 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None: ) except Exception as e: logger.error(f"Exception in {domain}/{example_id}: {e}") - # Only attempt to end recording if controller exists (not Docker provider) - if hasattr(env, 'controller') and env.controller is not None: + # Only attempt to end recording if controller exists (not Docker provider) and recording is enabled + if args.enable_recording and hasattr(env, 'controller') and env.controller is not None: env.controller.end_recording( os.path.join(example_result_dir, "recording.mp4") ) @@ -217,7 +240,7 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None: ) f.write("\n") - env.close() + # env.close() logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")