Fix GIMP bug; speed up the environment: when the a11y tree is not needed, skip the controller.get_accessibility_tree() call

This commit is contained in:
Timothyxxx
2024-03-20 22:22:59 +08:00
parent 21e3ce5cba
commit d1e2b12b41
4 changed files with 16 additions and 5 deletions

View File

@@ -58,7 +58,8 @@ class DesktopEnv(gym.Env):
tmp_dir: str = "tmp", tmp_dir: str = "tmp",
cache_dir: str = "cache", cache_dir: str = "cache",
screen_size: Tuple[int] = (1920, 1080), screen_size: Tuple[int] = (1920, 1080),
headless: bool = False headless: bool = False,
require_a11y_tree: bool = True,
): ):
""" """
Args: Args:
@@ -77,6 +78,7 @@ class DesktopEnv(gym.Env):
self.cache_dir_base: str = cache_dir self.cache_dir_base: str = cache_dir
self.vm_screen_size = screen_size # todo: add the logic to get the screen size from the VM self.vm_screen_size = screen_size # todo: add the logic to get the screen size from the VM
self.headless = headless self.headless = headless
self.require_a11y_tree = require_a11y_tree
os.makedirs(self.tmp_dir_base, exist_ok=True) os.makedirs(self.tmp_dir_base, exist_ok=True)
@@ -248,7 +250,7 @@ class DesktopEnv(gym.Env):
observation = { observation = {
"screenshot": self._get_obs(), "screenshot": self._get_obs(),
"accessibility_tree": self.controller.get_accessibility_tree(), "accessibility_tree": self.controller.get_accessibility_tree() if self.require_a11y_tree else None,
} }
return observation return observation
@@ -284,7 +286,7 @@ class DesktopEnv(gym.Env):
observation = { observation = {
"screenshot": self._get_obs(), "screenshot": self._get_obs(),
"accessibility_tree": self.controller.get_accessibility_tree(), "accessibility_tree": self.controller.get_accessibility_tree() if self.require_a11y_tree else None,
# "terminal": self.controller.get_terminal_output(), # "terminal": self.controller.get_terminal_output(),
"instruction": self.instruction "instruction": self.instruction
} }

View File

@@ -77,6 +77,7 @@ from .general import (
literal_match literal_match
) )
from .gimp import ( from .gimp import (
check_structure_sim_resized,
check_brightness_decrease_and_structure_sim, check_brightness_decrease_and_structure_sim,
check_contrast_increase_and_structure_sim, check_contrast_increase_and_structure_sim,
check_saturation_increase_and_structure_sim, check_saturation_increase_and_structure_sim,

View File

@@ -350,7 +350,7 @@ class PromptAgent:
# {{{1 # {{{1
if self.observation_type in ["screenshot", "screenshot_a11y_tree"]: if self.observation_type in ["screenshot", "screenshot_a11y_tree"]:
base64_image = encode_image(obs["screenshot"]) base64_image = encode_image(obs["screenshot"])
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"]) linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"]) if self.observation_type == "screenshot_a11y_tree" else None
logger.debug("LINEAR AT: %s", linearized_accessibility_tree) logger.debug("LINEAR AT: %s", linearized_accessibility_tree)
if self.observation_type == "screenshot_a11y_tree": if self.observation_type == "screenshot_a11y_tree":

10
run.py
View File

@@ -95,6 +95,10 @@ def config() -> argparse.Namespace:
parser.add_argument("--max_tokens", type=int, default=1500) parser.add_argument("--max_tokens", type=int, default=1500)
parser.add_argument("--stop_token", type=str, default=None) parser.add_argument("--stop_token", type=str, default=None)
# example config
parser.add_argument("--domain", type=str, default="all")
parser.add_argument("--test_all_meta_path", type=str, default="evaluation_examples/test_all.json")
# logging related # logging related
parser.add_argument("--result_dir", type=str, default="./results") parser.add_argument("--result_dir", type=str, default="./results")
args = parser.parse_args() args = parser.parse_args()
@@ -144,6 +148,7 @@ def test(
action_space=agent.action_space, action_space=agent.action_space,
screen_size=(args.screen_width, args.screen_height), screen_size=(args.screen_width, args.screen_height),
headless=args.headless, headless=args.headless,
require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
) )
for domain in tqdm(test_all_meta, desc="Domain"): for domain in tqdm(test_all_meta, desc="Domain"):
@@ -264,9 +269,12 @@ if __name__ == '__main__':
os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["TOKENIZERS_PARALLELISM"] = "false"
args = config() args = config()
with open("evaluation_examples/test_all.json", "r", encoding="utf-8") as f: with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
test_all_meta = json.load(f) test_all_meta = json.load(f)
if args.domain != "all":
test_all_meta = {args.domain: test_all_meta[args.domain]}
test_file_list = get_unfinished( test_file_list = get_unfinished(
args.action_space, args.action_space,
args.model, args.model,