diff --git a/desktop_env/envs/desktop_env.py b/desktop_env/envs/desktop_env.py index 7dd70b6..b443a4a 100644 --- a/desktop_env/envs/desktop_env.py +++ b/desktop_env/envs/desktop_env.py @@ -58,7 +58,8 @@ class DesktopEnv(gym.Env): tmp_dir: str = "tmp", cache_dir: str = "cache", screen_size: Tuple[int] = (1920, 1080), - headless: bool = False + headless: bool = False, + require_a11y_tree: bool = True, ): """ Args: @@ -77,6 +78,7 @@ class DesktopEnv(gym.Env): self.cache_dir_base: str = cache_dir self.vm_screen_size = screen_size # todo: add the logic to get the screen size from the VM self.headless = headless + self.require_a11y_tree = require_a11y_tree os.makedirs(self.tmp_dir_base, exist_ok=True) @@ -248,7 +250,7 @@ class DesktopEnv(gym.Env): observation = { "screenshot": self._get_obs(), - "accessibility_tree": self.controller.get_accessibility_tree(), + "accessibility_tree": self.controller.get_accessibility_tree() if self.require_a11y_tree else None, } return observation @@ -284,7 +286,7 @@ class DesktopEnv(gym.Env): observation = { "screenshot": self._get_obs(), - "accessibility_tree": self.controller.get_accessibility_tree(), + "accessibility_tree": self.controller.get_accessibility_tree() if self.require_a11y_tree else None, # "terminal": self.controller.get_terminal_output(), "instruction": self.instruction } diff --git a/desktop_env/evaluators/metrics/__init__.py b/desktop_env/evaluators/metrics/__init__.py index 61bb025..341e138 100644 --- a/desktop_env/evaluators/metrics/__init__.py +++ b/desktop_env/evaluators/metrics/__init__.py @@ -77,6 +77,7 @@ from .general import ( literal_match ) from .gimp import ( + check_structure_sim_resized, check_brightness_decrease_and_structure_sim, check_contrast_increase_and_structure_sim, check_saturation_increase_and_structure_sim, diff --git a/mm_agents/agent.py b/mm_agents/agent.py index f2d4b5c..e9f1147 100644 --- a/mm_agents/agent.py +++ b/mm_agents/agent.py @@ -350,7 +350,7 @@ class PromptAgent: # {{{1 if self.observation_type in ["screenshot", "screenshot_a11y_tree"]: base64_image = encode_image(obs["screenshot"]) - linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"]) + linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"]) if self.observation_type == "screenshot_a11y_tree" else None logger.debug("LINEAR AT: %s", linearized_accessibility_tree) if self.observation_type == "screenshot_a11y_tree": diff --git a/run.py b/run.py index 92e989a..e6f67f9 100644 --- a/run.py +++ b/run.py @@ -95,6 +95,10 @@ def config() -> argparse.Namespace: parser.add_argument("--max_tokens", type=int, default=1500) parser.add_argument("--stop_token", type=str, default=None) + # example config + parser.add_argument("--domain", type=str, default="all") + parser.add_argument("--test_all_meta_path", type=str, default="evaluation_examples/test_all.json") + # logging related parser.add_argument("--result_dir", type=str, default="./results") args = parser.parse_args() @@ -144,6 +148,7 @@ def test( action_space=agent.action_space, screen_size=(args.screen_width, args.screen_height), headless=args.headless, + require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"], ) for domain in tqdm(test_all_meta, desc="Domain"): @@ -264,9 +269,12 @@ if __name__ == '__main__': os.environ["TOKENIZERS_PARALLELISM"] = "false" args = config() - with open("evaluation_examples/test_all.json", "r", encoding="utf-8") as f: + with open(args.test_all_meta_path, "r", encoding="utf-8") as f: test_all_meta = json.load(f) + if args.domain != "all": + test_all_meta = {args.domain: test_all_meta[args.domain]} + test_file_list = get_unfinished( args.action_space, args.model,