Fix GIMP bug; Speedup the environment, when there is not a11y tree needed, we can do no controller.get
This commit is contained in:
@@ -58,7 +58,8 @@ class DesktopEnv(gym.Env):
|
|||||||
tmp_dir: str = "tmp",
|
tmp_dir: str = "tmp",
|
||||||
cache_dir: str = "cache",
|
cache_dir: str = "cache",
|
||||||
screen_size: Tuple[int] = (1920, 1080),
|
screen_size: Tuple[int] = (1920, 1080),
|
||||||
headless: bool = False
|
headless: bool = False,
|
||||||
|
require_a11y_tree: bool = True,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@@ -77,6 +78,7 @@ class DesktopEnv(gym.Env):
|
|||||||
self.cache_dir_base: str = cache_dir
|
self.cache_dir_base: str = cache_dir
|
||||||
self.vm_screen_size = screen_size # todo: add the logic to get the screen size from the VM
|
self.vm_screen_size = screen_size # todo: add the logic to get the screen size from the VM
|
||||||
self.headless = headless
|
self.headless = headless
|
||||||
|
self.require_a11y_tree = require_a11y_tree
|
||||||
|
|
||||||
os.makedirs(self.tmp_dir_base, exist_ok=True)
|
os.makedirs(self.tmp_dir_base, exist_ok=True)
|
||||||
|
|
||||||
@@ -248,7 +250,7 @@ class DesktopEnv(gym.Env):
|
|||||||
|
|
||||||
observation = {
|
observation = {
|
||||||
"screenshot": self._get_obs(),
|
"screenshot": self._get_obs(),
|
||||||
"accessibility_tree": self.controller.get_accessibility_tree(),
|
"accessibility_tree": self.controller.get_accessibility_tree() if self.require_a11y_tree else None,
|
||||||
}
|
}
|
||||||
return observation
|
return observation
|
||||||
|
|
||||||
@@ -284,7 +286,7 @@ class DesktopEnv(gym.Env):
|
|||||||
|
|
||||||
observation = {
|
observation = {
|
||||||
"screenshot": self._get_obs(),
|
"screenshot": self._get_obs(),
|
||||||
"accessibility_tree": self.controller.get_accessibility_tree(),
|
"accessibility_tree": self.controller.get_accessibility_tree() if self.require_a11y_tree else None,
|
||||||
# "terminal": self.controller.get_terminal_output(),
|
# "terminal": self.controller.get_terminal_output(),
|
||||||
"instruction": self.instruction
|
"instruction": self.instruction
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -77,6 +77,7 @@ from .general import (
|
|||||||
literal_match
|
literal_match
|
||||||
)
|
)
|
||||||
from .gimp import (
|
from .gimp import (
|
||||||
|
check_structure_sim_resized,
|
||||||
check_brightness_decrease_and_structure_sim,
|
check_brightness_decrease_and_structure_sim,
|
||||||
check_contrast_increase_and_structure_sim,
|
check_contrast_increase_and_structure_sim,
|
||||||
check_saturation_increase_and_structure_sim,
|
check_saturation_increase_and_structure_sim,
|
||||||
|
|||||||
@@ -350,7 +350,7 @@ class PromptAgent:
|
|||||||
# {{{1
|
# {{{1
|
||||||
if self.observation_type in ["screenshot", "screenshot_a11y_tree"]:
|
if self.observation_type in ["screenshot", "screenshot_a11y_tree"]:
|
||||||
base64_image = encode_image(obs["screenshot"])
|
base64_image = encode_image(obs["screenshot"])
|
||||||
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])
|
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"]) if self.observation_type == "screenshot_a11y_tree" else None
|
||||||
logger.debug("LINEAR AT: %s", linearized_accessibility_tree)
|
logger.debug("LINEAR AT: %s", linearized_accessibility_tree)
|
||||||
|
|
||||||
if self.observation_type == "screenshot_a11y_tree":
|
if self.observation_type == "screenshot_a11y_tree":
|
||||||
|
|||||||
10
run.py
10
run.py
@@ -95,6 +95,10 @@ def config() -> argparse.Namespace:
|
|||||||
parser.add_argument("--max_tokens", type=int, default=1500)
|
parser.add_argument("--max_tokens", type=int, default=1500)
|
||||||
parser.add_argument("--stop_token", type=str, default=None)
|
parser.add_argument("--stop_token", type=str, default=None)
|
||||||
|
|
||||||
|
# example config
|
||||||
|
parser.add_argument("--domain", type=str, default="all")
|
||||||
|
parser.add_argument("--test_all_meta_path", type=str, default="evaluation_examples/test_all.json")
|
||||||
|
|
||||||
# logging related
|
# logging related
|
||||||
parser.add_argument("--result_dir", type=str, default="./results")
|
parser.add_argument("--result_dir", type=str, default="./results")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
@@ -144,6 +148,7 @@ def test(
|
|||||||
action_space=agent.action_space,
|
action_space=agent.action_space,
|
||||||
screen_size=(args.screen_width, args.screen_height),
|
screen_size=(args.screen_width, args.screen_height),
|
||||||
headless=args.headless,
|
headless=args.headless,
|
||||||
|
require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
|
||||||
)
|
)
|
||||||
|
|
||||||
for domain in tqdm(test_all_meta, desc="Domain"):
|
for domain in tqdm(test_all_meta, desc="Domain"):
|
||||||
@@ -264,9 +269,12 @@ if __name__ == '__main__':
|
|||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
args = config()
|
args = config()
|
||||||
|
|
||||||
with open("evaluation_examples/test_all.json", "r", encoding="utf-8") as f:
|
with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
|
||||||
test_all_meta = json.load(f)
|
test_all_meta = json.load(f)
|
||||||
|
|
||||||
|
if args.domain != "all":
|
||||||
|
test_all_meta = {args.domain: test_all_meta[args.domain]}
|
||||||
|
|
||||||
test_file_list = get_unfinished(
|
test_file_list = get_unfinished(
|
||||||
args.action_space,
|
args.action_space,
|
||||||
args.model,
|
args.model,
|
||||||
|
|||||||
Reference in New Issue
Block a user