diff --git a/fake_run_single.py b/fake_run_single.py
new file mode 100644
index 0000000..4e90ea8
--- /dev/null
+++ b/fake_run_single.py
@@ -0,0 +1,65 @@
+import datetime
+import json
+import logging
+import os
+import time
+from wrapt_timeout_decorator import *
+
+logger = logging.getLogger("desktopenv.experiment")
+
+
+def run_single_example(agent, env, example, max_steps, instruction, args, example_result_dir, scores):
+    runtime_logger = setup_logger(example, example_result_dir)
+    agent.reset(runtime_logger)
+    env.reset(task_config=example)
+    # time.sleep(60) # Wait for the environment to be ready
+    obs = env._get_obs()  # Get the initial observation
+    done = False
+    step_idx = 0
+    env.controller.start_recording()
+    while not done and step_idx < max_steps:
+        response, actions = agent.predict(
+            instruction,
+            obs
+        )
+
+        for action in actions:
+            # Capture the timestamp before executing the action
+            action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
+            logger.info("Step %d: %s", step_idx + 1, action)
+            obs, reward, done, info = env.step(action, args.sleep_after_execution)
+
+            logger.info("Reward: %.2f", reward)
+            logger.info("Done: %s", done)
+            # Save screenshot and trajectory information
+            with open(os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"),
+                      "wb") as _f:
+                _f.write(obs['screenshot'])
+            with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
+                f.write(json.dumps({
+                    "step_num": step_idx + 1,
+                    "action_timestamp": action_timestamp,
+                    "action": action,
+                    "reward": reward,
+                    "done": done,
+                    "info": info,
+                    "screenshot_file": f"step_{step_idx + 1}_{action_timestamp}.png"
+                }))
+                f.write("\n")
+            if done:
+                logger.info("The episode is done.")
+                break
+        step_idx += 1
+    result = env.evaluate()
+    logger.info("Result: %.2f", result)
+    scores.append(result)
+    with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f:
+        f.write(f"{result}\n")
+    env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
+
+
+def setup_logger(example, example_result_dir):
+    runtime_logger = logging.getLogger(f"desktopenv.example.{example['id']}")
+    runtime_logger.setLevel(logging.DEBUG)
+    runtime_logger.addHandler(logging.FileHandler(os.path.join(example_result_dir, "runtime.log")))
+    return runtime_logger
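
Usage sketch (not part of the patch): run_single_example can be driven directly against the fake environment added below. The task dict, result directory, and agent settings here are illustrative assumptions; PromptAgent and its keyword arguments are taken from run_test_env.py.

    import argparse
    import os

    import fake_run_single
    from test_env import DesktopEnv
    from mm_agents.agent import PromptAgent

    # Hypothetical task config; real ones live under evaluation_examples/examples/<domain>/<id>.json
    example = {"id": "demo-0001", "instruction": "Open Chrome browser"}
    args = argparse.Namespace(sleep_after_execution=0.0)

    agent = PromptAgent(model="gpt-4o", max_tokens=1500, top_p=0.9, temperature=1.0,
                        action_space="pyautogui", observation_type="a11y_tree",
                        max_trajectory_length=3)
    env = DesktopEnv(action_space="pyautogui")

    example_result_dir = "./results/pyautogui/a11y_tree/gpt-4o/demo/demo-0001"
    os.makedirs(example_result_dir, exist_ok=True)

    scores = []
    fake_run_single.run_single_example(agent, env, example, 5, example["instruction"],
                                       args, example_result_dir, scores)
    print(scores)  # one entry per example, taken from env.evaluate()
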
diff --git a/run_test_env.py b/run_test_env.py
new file mode 100644
index 0000000..feba441
--- /dev/null
+++ b/run_test_env.py
@@ -0,0 +1,374 @@
+"""Script to run end-to-end evaluation on the benchmark.
+Utils and basic architecture credit to https://github.com/web-arena-x/webarena/blob/main/run.py.
+"""
+
+import argparse
+import datetime
+import json
+import logging
+import os
+import sys
+from typing import List, Dict
+import math
+from tqdm import tqdm
+from multiprocessing import Process, Manager
+import fake_run_single
+from test_env import DesktopEnv
+from mm_agents.agent import PromptAgent
+
+# import wandb
+
+
+# Logger Configs {{{ #
+logger = logging.getLogger()
+logger.setLevel(logging.DEBUG)
+
+datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
+
+file_handler = logging.FileHandler(
+    os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8"
+)
+debug_handler = logging.FileHandler(
+    os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8"
+)
+stdout_handler = logging.StreamHandler(sys.stdout)
+sdebug_handler = logging.FileHandler(
+    os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8"
+)
+
+file_handler.setLevel(logging.INFO)
+debug_handler.setLevel(logging.DEBUG)
+stdout_handler.setLevel(logging.INFO)
+sdebug_handler.setLevel(logging.DEBUG)
+
+formatter = logging.Formatter(
+    fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
+)
+file_handler.setFormatter(formatter)
+debug_handler.setFormatter(formatter)
+stdout_handler.setFormatter(formatter)
+sdebug_handler.setFormatter(formatter)
+
+stdout_handler.addFilter(logging.Filter("desktopenv"))
+sdebug_handler.addFilter(logging.Filter("desktopenv"))
+
+logger.addHandler(file_handler)
+logger.addHandler(debug_handler)
+logger.addHandler(stdout_handler)
+logger.addHandler(sdebug_handler)
+# }}} Logger Configs #
+
+logger = logging.getLogger("desktopenv.experiment")
+
+
+def config() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(
+        description="Run end-to-end evaluation on the benchmark"
+    )
+
+    # environment config
+    parser.add_argument("--path_to_vm", type=str, default=None)
+    parser.add_argument(
+        "--headless", action="store_true", help="Run in headless machine"
+    )
+    parser.add_argument(
+        "--action_space", type=str, default="pyautogui", help="Action type"
+    )
+    parser.add_argument(
+        "--observation_type",
+        choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"],
+        default="a11y_tree",
+        help="Observation type",
+    )
+    parser.add_argument("--screen_width", type=int, default=1920)
+    parser.add_argument("--screen_height", type=int, default=1080)
+    parser.add_argument("--sleep_after_execution", type=float, default=0.0)
+    parser.add_argument("--max_steps", type=int, default=15)
+
+    # agent config
+    parser.add_argument("--max_trajectory_length", type=int, default=3)
+    parser.add_argument(
+        "--test_config_base_dir", type=str, default="evaluation_examples"
+    )
+
+    # lm config
+    parser.add_argument("--model", type=str, default="gpt-4o")
+    parser.add_argument("--temperature", type=float, default=1.0)
+    parser.add_argument("--top_p", type=float, default=0.9)
+    parser.add_argument("--max_tokens", type=int, default=1500)
+    parser.add_argument("--stop_token", type=str, default=None)
+
+    # example config
+    parser.add_argument("--domain", type=str, default="all")
+    parser.add_argument(
+        "--test_all_meta_path", type=str, default="evaluation_examples/test_all.json"
+    )
+
+    # logging related
+    parser.add_argument("--result_dir", type=str, default="./results")
+    parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to run in parallel")
+
+    # aws config
+    parser.add_argument(
+        "--region", type=str, default="us-east-1", help="AWS region for the VM"
+    )
+
+    args = parser.parse_args()
+    return args
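
A quick sanity check of the parser (not part of the patch). Note that the module attaches file handlers under logs/ at import time, so that directory has to exist first; the flag values below are arbitrary.

    import os
    import sys

    os.makedirs("logs", exist_ok=True)   # run_test_env opens log files here at import time
    from run_test_env import config

    sys.argv = ["run_test_env.py", "--observation_type", "screenshot_a11y_tree",
                "--num_envs", "2", "--domain", "chrome", "--max_steps", "10"]
    args = config()
    print(args.model, args.num_envs, args.result_dir)   # gpt-4o 2 ./results
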
+
+
+def distribute_tasks(test_all_meta: dict, num_envs: int) -> List[Dict]:
+    """Distribute tasks evenly across environments."""
+    # Flatten the tasks into a single list
+    all_tasks = []
+    for domain, examples in test_all_meta.items():
+        for example_id in examples:
+            all_tasks.append((domain, example_id))
+
+    # Calculate tasks per environment
+    tasks_per_env = math.ceil(len(all_tasks) / num_envs)
+
+    # Distribute tasks
+    distributed_tasks = []
+    for i in range(num_envs):
+        env_tasks = {}
+        start_idx = i * tasks_per_env
+        end_idx = min((i + 1) * tasks_per_env, len(all_tasks))
+
+        for domain, example_id in all_tasks[start_idx:end_idx]:
+            if domain not in env_tasks:
+                env_tasks[domain] = []
+            env_tasks[domain].append(example_id)
+
+        distributed_tasks.append(env_tasks)
+
+    return distributed_tasks
+
+
+def run_env_tasks(env_idx: int, env: DesktopEnv, agent: PromptAgent, env_tasks: dict, args: argparse.Namespace, shared_scores: list):
+    """Run tasks for a single environment."""
+    logger.info(f"Executing tasks in environment {env_idx + 1}/{args.num_envs}")
+
+    for domain in tqdm(env_tasks, desc=f"Env{env_idx+1}-Domain"):
+        for example_id in tqdm(env_tasks[domain], desc="Example", leave=False):
+            config_file = os.path.join(
+                args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
+            )
+            with open(config_file, "r", encoding="utf-8") as f:
+                example = json.load(f)
+
+            logger.info(f"[Env {env_idx+1}][Domain]: {domain}")
+            logger.info(f"[Env {env_idx+1}][Example ID]: {example_id}")
+            logger.info(f"[Env {env_idx+1}][Instruction]: {example['instruction']}")
+
+            example_result_dir = os.path.join(
+                args.result_dir,
+                args.action_space,
+                args.observation_type,
+                args.model,
+                domain,
+                example_id,
+            )
+            os.makedirs(example_result_dir, exist_ok=True)
+
+            # try:
+            fake_run_single.run_single_example(
+                agent,
+                env,
+                example,
+                args.max_steps,
+                example["instruction"],
+                args,
+                example_result_dir,
+                shared_scores,
+            )
+            # except Exception as e:
+            #     logger.error(f"Exception in Env{env_idx+1} {domain}/{example_id}: {e}")
+            #     env.controller.end_recording(
+            #         os.path.join(example_result_dir, "recording.mp4")
+            #     )
+            #     with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
+            #         f.write(
+            #             json.dumps(
+            #                 {"Error": f"Time limit exceeded in {domain}/{example_id}"}
+            #             )
+            #         )
+            #         f.write("\n")
+
+    env.close()
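
A worked example of the split distribute_tasks produces (not part of the patch; the ids are made up). Five tasks across two environments gives tasks_per_env = ceil(5 / 2) = 3, so the first environment gets three tasks and the second gets the remaining two, grouped by domain:

    meta = {"chrome": ["a", "b", "c"], "gimp": ["d", "e"]}
    print(distribute_tasks(meta, num_envs=2))
    # [{'chrome': ['a', 'b', 'c']}, {'gimp': ['d', 'e']}]
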
+
+
+def test(args: argparse.Namespace, test_all_meta: dict) -> None:
+    logger.info("Args: %s", args)
+
+    distributed_tasks = distribute_tasks(test_all_meta, args.num_envs)
+
+    # First, set up all environments
+    logger.info("Setting up all environments...")
+    envs = []
+    agents = []
+
+    for env_idx in range(args.num_envs):
+        logger.info(f"Setting up environment {env_idx + 1}/{args.num_envs}")
+
+        agent = PromptAgent(
+            model=args.model,
+            max_tokens=args.max_tokens,
+            top_p=args.top_p,
+            temperature=args.temperature,
+            action_space=args.action_space,
+            observation_type=args.observation_type,
+            max_trajectory_length=args.max_trajectory_length,
+        )
+        agents.append(agent)
+
+        env = DesktopEnv(
+            path_to_vm=args.path_to_vm,
+            action_space=agent.action_space,
+
+            provider_name="aws",
+            region=args.region,  # honour the --region flag instead of a hard-coded value
+            snapshot_name="ami-05e7d7bd279ea4f14",
+
+            screen_size=(args.screen_width, args.screen_height),
+            headless=args.headless,
+            os_type="Ubuntu",
+            require_a11y_tree=args.observation_type
+            in ["a11y_tree", "screenshot_a11y_tree", "som"],
+        )
+        envs.append(env)
+
+    logger.info("All environments are ready. Starting parallel task execution...")
+
+    # Create a shared list for scores across processes
+    with Manager() as manager:
+        shared_scores = manager.list()
+
+        # Create and start processes for each environment
+        processes = []
+        for env_idx, (env, agent, env_tasks) in enumerate(zip(envs, agents, distributed_tasks)):
+            p = Process(
+                target=run_env_tasks,
+                args=(env_idx, env, agent, env_tasks, args, shared_scores)
+            )
+            processes.append(p)
+            p.start()
+
+        # Wait for all processes to complete
+        for p in processes:
+            p.join()
+
+        # Convert shared list to regular list
+        scores = list(shared_scores)
+
+    logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")
+
+
+def get_unfinished(
+    action_space, use_model, observation_type, result_dir, total_file_json
+):
+    target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
+
+    if not os.path.exists(target_dir):
+        return total_file_json
+
+    finished = {}
+    for domain in os.listdir(target_dir):
+        finished[domain] = []
+        domain_path = os.path.join(target_dir, domain)
+        if os.path.isdir(domain_path):
+            for example_id in os.listdir(domain_path):
+                if example_id == "onboard":
+                    continue
+                example_path = os.path.join(domain_path, example_id)
+                if os.path.isdir(example_path):
+                    if "result.txt" not in os.listdir(example_path):
+                        # empty all files under example_id
+                        for file in os.listdir(example_path):
+                            os.remove(os.path.join(example_path, file))
+                    else:
+                        finished[domain].append(example_id)
+
+    if not finished:
+        return total_file_json
+
+    for domain, examples in finished.items():
+        if domain in total_file_json:
+            total_file_json[domain] = [
+                x for x in total_file_json[domain] if x not in examples
+            ]
+
+    return total_file_json
+
+
+def get_result(action_space, use_model, observation_type, result_dir, total_file_json):
+    target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
+    if not os.path.exists(target_dir):
+        print("New experiment, no result yet.")
+        return None
+
+    all_result = []
+
+    for domain in os.listdir(target_dir):
+        domain_path = os.path.join(target_dir, domain)
+        if os.path.isdir(domain_path):
+            for example_id in os.listdir(domain_path):
+                example_path = os.path.join(domain_path, example_id)
+                if os.path.isdir(example_path):
+                    if "result.txt" in os.listdir(example_path):
+                        # read the recorded score; an unreadable result.txt counts as 0.0
+                        try:
+                            all_result.append(
+                                float(
+                                    open(
+                                        os.path.join(example_path, "result.txt"), "r"
+                                    ).read()
+                                )
+                            )
+                        except Exception:
+                            all_result.append(0.0)
+
+    if not all_result:
+        print("New experiment, no result yet.")
+        return None
+    else:
+        print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%")
+        return all_result
+
+
+if __name__ == "__main__":
+    ####### The complete version of the list of examples #######
+    os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+    args = config()
+
+    with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
+        test_all_meta = json.load(f)
+
+    if args.domain != "all":
+        test_all_meta = {args.domain: test_all_meta[args.domain]}
+
+    test_file_list = get_unfinished(
+        args.action_space,
+        args.model,
+        args.observation_type,
+        args.result_dir,
+        test_all_meta,
+    )
+    left_info = ""
+    for domain in test_file_list:
+        left_info += f"{domain}: {len(test_file_list[domain])}\n"
+    logger.info(f"Left tasks:\n{left_info}")
+
+    get_result(
+        args.action_space,
+        args.model,
+        args.observation_type,
+        args.result_dir,
+        test_all_meta,
+    )
+    test(args, test_file_list)
+
+
+# path_to_vm can be a list["xxx","xxx"]
\ No newline at end of file
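
For reference, the on-disk layout that run_env_tasks, get_unfinished and get_result all assume, shown with the default arguments (get_unfinished treats a missing result.txt as an unfinished example and wipes that example's directory before it is re-run):

    results/
    └── pyautogui/                     # args.action_space
        └── a11y_tree/                 # args.observation_type
            └── gpt-4o/                # args.model
                └── chrome/            # domain
                    └── <example_id>/
                        ├── runtime.log
                        ├── step_1_<timestamp>.png
                        ├── traj.jsonl
                        ├── recording.mp4
                        └── result.txt  # written last; marks the example as finished
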
diff --git a/test_env/__init__.py b/test_env/__init__.py
new file mode 100644
index 0000000..e958d0a
--- /dev/null
+++ b/test_env/__init__.py
@@ -0,0 +1,2 @@
+from .fake_python_controller import PythonController
+from .fake_env import DesktopEnv
\ No newline at end of file
diff --git a/test_env/fake_env.py b/test_env/fake_env.py
new file mode 100644
index 0000000..98543c4
--- /dev/null
+++ b/test_env/fake_env.py
@@ -0,0 +1,137 @@
+from typing import Callable, Any, Optional, Tuple
+import os
+from .fake_python_controller import PythonController
+
+
+class DesktopEnv:
+    def __init__(
+        self,
+        action_space: str = "computer_13",
+        screen_size: Tuple[int, int] = (1920, 1080),
+        *args: Any,
+        **kwargs: Any,
+    ):
+        self.obs_options = {}
+        self._step_no = 0
+        self.action_history = []
+        self.action_space = action_space
+        self.resolution = screen_size
+        self.controller = PythonController()
+
+
+        # Load test screenshots and accessibility trees
+        test_obs_dir = os.path.join(os.path.dirname(__file__), "test_observations")
+
+        self.screenshots = [
+            self._load_image(os.path.join(test_obs_dir, "screenshot0.jpg")),
+            self._load_image(os.path.join(test_obs_dir, "screenshot1.jpg")),
+        ]
+        self.accessibility_trees = [
+            self._load_accessibility_tree(os.path.join(test_obs_dir, "a11y_tree0.txt")),
+            self._load_accessibility_tree(os.path.join(test_obs_dir, "a11y_tree1.txt")),
+        ]
+
+    def _get_screenshot(self):
+        if self._step_no == 0:
+            return self.screenshots[0]
+        return self.screenshots[1]
+
+    def _get_accessibility_tree(self):
+        if self._step_no == 0:
+            return self.accessibility_trees[0]
+        return self.accessibility_trees[1]
+
+    def set_obs_options(self, obs_options):
+        print(f"Setting obs options to {obs_options}")
+        self.obs_options = obs_options
+
+    def _load_image(self, image_path):
+        try:
+            with open(image_path, "rb") as image_file:
+                # Read the image file in binary mode
+                image_data = image_file.read()
+                # Return the raw bytes; callers handle any Base64 encoding themselves
+                return image_data
+        except FileNotFoundError:
+            print(f"Error: File not found at {image_path}")
+        except Exception as e:
+            print(f"An error occurred: {e}")
+
+    def _load_accessibility_tree(self, tree_path):
+        try:
+            with open(tree_path, "r") as tree_file:
+                # Read the accessibility tree file
+                tree_data = tree_file.read()
+                return tree_data
+        except FileNotFoundError:
+            print(f"Error: File not found at {tree_path}")
+        except Exception as e:
+            print(f"An error occurred: {e}")
+
+    def _get_obs(self):
+        obs = {}
+        obs["screenshot"] = self._get_screenshot()
+        obs["accessibility_tree"] = self._get_accessibility_tree()
+        obs["terminal"] = ""
+        obs["instruction"] = "Open Chrome browser"
+
+        return obs
+
+    def _start_video_recording(self):
+        pass
+
+    def _stop_video_recording(self):
+        pass
+
+    def step(self, action, pause: float = 0.0) -> Tuple:
+        # `pause` mirrors the real DesktopEnv.step signature used by fake_run_single; the fake env ignores it.
+        self._step_no += 1
+        self.action_history.append(action)
+
+        info = {}
+        terminated = False  # todo: Define episode termination condition for each example
+
+        if action == 'FAIL' or action == 'DONE':
+            terminated = True
+
+        else:
+            if self.action_space == "claude_computer_use":
+                tool_result = {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "tool_result",
+                            "tool_use_id": "toolu_01A09q90qw90lq917835lq9",
+                            "content": [
+                                {
+                                    "type": "image",
+                                    "source": {
+                                        "type": "base64",
+                                        "media_type": "image/jpeg",
+                                        "data": self.screenshots[1],
+                                    }
+                                }
+                            ]
+                        }
+                    ]
+                }
+                info.update({"tool_result": tool_result})
+
+        # Return the same (obs, reward, done, info) tuple as the real DesktopEnv.
+        return (self._get_obs(), 0.0, terminated, info)
+
+    def evaluate(self):
+        # Placeholder score: the fake environment has no per-task evaluator.
+        return 0.0
+
+    def close(self):
+        self._step_no = 0
+        self.action_history = []
+        self.obs_options = {}
+        self.controller = None
+
+    def reset(self, *args: Any, **kwargs: Any) -> dict:
+        # Start from the first canned observation again when the env is reused.
+        self._step_no = 0
+        self.action_history = []
+        return self._get_obs()
\ No newline at end of file
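
A minimal reset/step cycle against the fake environment (not part of the patch; assumes the process runs from the repository root so test_env/test_observations/ resolves, and uses the four-tuple step return expected by fake_run_single):

    from test_env import DesktopEnv

    env = DesktopEnv(action_space="pyautogui")
    obs = env.reset()                     # canned screenshot0 + a11y_tree0
    print(obs["instruction"])             # Open Chrome browser

    obs, reward, done, info = env.step("import pyautogui; pyautogui.hotkey('ctrl', 't')", 0.0)
    print(done, len(obs["screenshot"]))   # False, byte length of screenshot1.jpg

    obs, reward, done, info = env.step("DONE", 0.0)   # 'DONE' / 'FAIL' terminate the episode
    print(done, env.evaluate())           # True 0.0
    env.close()
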
diff --git a/test_env/fake_python_controller.py b/test_env/fake_python_controller.py
new file mode 100644
index 0000000..31a0803
--- /dev/null
+++ b/test_env/fake_python_controller.py
@@ -0,0 +1,50 @@
+from typing import Any, Dict, Optional
+
+
+class PythonController:
+    def __init__(self):
+        pass
+
+    def get_screenshot(self) -> Optional[bytes]:
+        pass
+
+    def get_accessibility_tree(self) -> Optional[str]:
+        pass
+
+    def get_terminal_output(self) -> Optional[str]:
+        pass
+
+    def get_file(self, file_path: str) -> Optional[bytes]:
+        pass
+
+    def execute_python_command(self, command: str) -> None:
+        pass
+
+    def execute_action(self, action: Dict[str, Any]):
+        pass
+
+    # Record video
+    def start_recording(self):
+        pass
+
+    def end_recording(self, dest: str):
+        pass
+
+    # Additional info
+    def get_vm_platform(self):
+        pass
+
+    def get_vm_screen_size(self):
+        pass
+
+    def get_vm_window_size(self, app_class_name: str):
+        pass
+
+    def get_vm_wallpaper(self):
+        pass
+
+    def get_vm_desktop_path(self) -> Optional[str]:
+        pass
+
+    def get_vm_directory_tree(self, path) -> Optional[Dict[str, Any]]:
+        pass
diff --git a/test_env/test_observations/a11y_tree0.txt b/test_env/test_observations/a11y_tree0.txt
new file mode 100644
index 0000000..a871d99
--- /dev/null
+++ b/test_env/test_observations/a11y_tree0.txt
@@ -0,0 +1 @@
+Clear272829300102030405060708091011121314151617181920212223242526272829303101020304050607MinimiseRestoreCloseSearch tabsCloseNew TabBackForwardReloadSide panelYouFinish updateAppsAppsManaged bookmarksManaged bookmarksAll BookmarksAll Bookmarks
GmailGmail
ImagesImages
Web Store
Web Store
Add shortcut
Add shortcut
Customise ChromeCustomise Chrome
CloseCan't update ChromeCloseChrome couldn't update to the latest version, so you're missing out on new features and security fixes.Chrome couldn't update to the latest version, so you're missing out on new features and security fixes.Reinstall Chrome \ No newline at end of file diff --git a/test_env/test_observations/a11y_tree1.txt b/test_env/test_observations/a11y_tree1.txt new file mode 100644 index 0000000..a871d99 --- /dev/null +++ b/test_env/test_observations/a11y_tree1.txt @@ -0,0 +1 @@ +Clear272829300102030405060708091011121314151617181920212223242526272829303101020304050607MinimiseRestoreCloseSearch tabsCloseNew TabBackForwardReloadSide panelYouFinish updateAppsAppsManaged bookmarksManaged bookmarksAll BookmarksAll Bookmarks
GmailGmail
ImagesImages
Web Store
Web Store
Add shortcut
Add shortcut
Customise ChromeCustomise Chrome
CloseCan't update ChromeCloseChrome couldn't update to the latest version, so you're missing out on new features and security fixes.Chrome couldn't update to the latest version, so you're missing out on new features and security fixes.Reinstall Chrome \ No newline at end of file diff --git a/test_env/test_observations/screenshot0.jpg b/test_env/test_observations/screenshot0.jpg new file mode 100644 index 0000000..d58bfe6 Binary files /dev/null and b/test_env/test_observations/screenshot0.jpg differ diff --git a/test_env/test_observations/screenshot1.jpg b/test_env/test_observations/screenshot1.jpg new file mode 100644 index 0000000..d00ebd6 Binary files /dev/null and b/test_env/test_observations/screenshot1.jpg differ
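
The canned observations are plain files, so alternative fixtures can be swapped in under the same names (a sketch, not part of the patch; the source file names and a11y text below are placeholders):

    import os
    import shutil

    fixtures = os.path.join("test_env", "test_observations")
    shutil.copy("my_before.jpg", os.path.join(fixtures, "screenshot0.jpg"))   # served at reset
    shutil.copy("my_after.jpg", os.path.join(fixtures, "screenshot1.jpg"))    # served after any step
    with open(os.path.join(fixtures, "a11y_tree0.txt"), "w") as f:
        f.write("push-button Open Chrome")   # any linearised a11y dump the agent can parse
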