diff --git a/fake_run_single.py b/fake_run_single.py
new file mode 100644
index 0000000..4e90ea8
--- /dev/null
+++ b/fake_run_single.py
@@ -0,0 +1,65 @@
+import datetime
+import json
+import logging
+import os
+import time
+from wrapt_timeout_decorator import *
+
+logger = logging.getLogger("desktopenv.experiment")
+
+
+def run_single_example(agent, env, example, max_steps, instruction, args, example_result_dir, scores):
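+    """Drive a single benchmark example: reset the agent and env, loop
+    predict -> step for up to max_steps, save per-step screenshots and
+    traj.jsonl entries, then write the evaluation result, append it to
+    `scores`, and dump the screen recording into example_result_dir."""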
+ runtime_logger = setup_logger(example, example_result_dir)
+ agent.reset(runtime_logger)
+ env.reset(task_config=example)
+ # time.sleep(60) # Wait for the environment to be ready
+ obs = env._get_obs() # Get the initial observation
+ done = False
+ step_idx = 0
+ env.controller.start_recording()
+ while not done and step_idx < max_steps:
+ response, actions = agent.predict(
+ instruction,
+ obs
+ )
+
+ for action in actions:
+ # Capture the timestamp before executing the action
+ action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
+ logger.info("Step %d: %s", step_idx + 1, action)
+ obs, reward, done, info = env.step(action, args.sleep_after_execution)
+
+ logger.info("Reward: %.2f", reward)
+ logger.info("Done: %s", done)
+ # Save screenshot and trajectory information
+ with open(os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"),
+ "wb") as _f:
+ _f.write(obs['screenshot'])
+ with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
+ f.write(json.dumps({
+ "step_num": step_idx + 1,
+ "action_timestamp": action_timestamp,
+ "action": action,
+ "reward": reward,
+ "done": done,
+ "info": info,
+ "screenshot_file": f"step_{step_idx + 1}_{action_timestamp}.png"
+ }))
+ f.write("\n")
+ if done:
+ logger.info("The episode is done.")
+ break
+ step_idx += 1
+ result = env.evaluate()
+ logger.info("Result: %.2f", result)
+ scores.append(result)
+ with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f:
+ f.write(f"{result}\n")
+ env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
+
+
+def setup_logger(example, example_result_dir):
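+    """Create a per-example logger that writes runtime.log into example_result_dir."""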
+ runtime_logger = logging.getLogger(f"desktopenv.example.{example['id']}")
+ runtime_logger.setLevel(logging.DEBUG)
+ runtime_logger.addHandler(logging.FileHandler(os.path.join(example_result_dir, "runtime.log")))
+ return runtime_logger
diff --git a/run_test_env.py b/run_test_env.py
new file mode 100644
index 0000000..feba441
--- /dev/null
+++ b/run_test_env.py
@@ -0,0 +1,374 @@
+"""Script to run end-to-end evaluation on the benchmark.
+Utils and basic architecture credit to https://github.com/web-arena-x/webarena/blob/main/run.py.
+"""
+
+import argparse
+import datetime
+import json
+import logging
+import os
+import sys
+from typing import List, Dict
+import math
+from tqdm import tqdm
+from multiprocessing import Process, Manager
+import fake_run_single
+from test_env import DesktopEnv
+from mm_agents.agent import PromptAgent
+
+# import wandb
+
+
+# Logger Configs {{{ #
+logger = logging.getLogger()
+logger.setLevel(logging.DEBUG)
+
+datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
+
+os.makedirs("logs", exist_ok=True)  # the file handlers below fail if logs/ is missing
+
+file_handler = logging.FileHandler(
+    os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8"
+)
+debug_handler = logging.FileHandler(
+ os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8"
+)
+stdout_handler = logging.StreamHandler(sys.stdout)
+sdebug_handler = logging.FileHandler(
+ os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8"
+)
+
+file_handler.setLevel(logging.INFO)
+debug_handler.setLevel(logging.DEBUG)
+stdout_handler.setLevel(logging.INFO)
+sdebug_handler.setLevel(logging.DEBUG)
+
+formatter = logging.Formatter(
+ fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
+)
+file_handler.setFormatter(formatter)
+debug_handler.setFormatter(formatter)
+stdout_handler.setFormatter(formatter)
+sdebug_handler.setFormatter(formatter)
+
+stdout_handler.addFilter(logging.Filter("desktopenv"))
+sdebug_handler.addFilter(logging.Filter("desktopenv"))
+
+logger.addHandler(file_handler)
+logger.addHandler(debug_handler)
+logger.addHandler(stdout_handler)
+logger.addHandler(sdebug_handler)
+# }}} Logger Configs #
+
+logger = logging.getLogger("desktopenv.experiment")
+
+
+def config() -> argparse.Namespace:
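+    """Parse command-line arguments for the evaluation run."""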
+ parser = argparse.ArgumentParser(
+ description="Run end-to-end evaluation on the benchmark"
+ )
+
+ # environment config
+ parser.add_argument("--path_to_vm", type=str, default=None)
+ parser.add_argument(
+ "--headless", action="store_true", help="Run in headless machine"
+ )
+ parser.add_argument(
+ "--action_space", type=str, default="pyautogui", help="Action type"
+ )
+ parser.add_argument(
+ "--observation_type",
+ choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"],
+ default="a11y_tree",
+ help="Observation type",
+ )
+ parser.add_argument("--screen_width", type=int, default=1920)
+ parser.add_argument("--screen_height", type=int, default=1080)
+ parser.add_argument("--sleep_after_execution", type=float, default=0.0)
+ parser.add_argument("--max_steps", type=int, default=15)
+
+ # agent config
+ parser.add_argument("--max_trajectory_length", type=int, default=3)
+ parser.add_argument(
+ "--test_config_base_dir", type=str, default="evaluation_examples"
+ )
+
+ # lm config
+ parser.add_argument("--model", type=str, default="gpt-4o")
+ parser.add_argument("--temperature", type=float, default=1.0)
+ parser.add_argument("--top_p", type=float, default=0.9)
+ parser.add_argument("--max_tokens", type=int, default=1500)
+ parser.add_argument("--stop_token", type=str, default=None)
+
+ # example config
+ parser.add_argument("--domain", type=str, default="all")
+ parser.add_argument(
+ "--test_all_meta_path", type=str, default="evaluation_examples/test_all.json"
+ )
+
+ # logging related
+ parser.add_argument("--result_dir", type=str, default="./results")
+ parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to run in parallel")
+
+ # aws config
+ parser.add_argument(
+ "--region", type=str, default="us-east-1", help="AWS region for the VM"
+ )
+
+ args = parser.parse_args()
+ return args
+
+
+def distribute_tasks(test_all_meta: dict, num_envs: int) -> List[Dict]:
+ """Distribute tasks evenly across environments."""
+ # Flatten the tasks into a single list
+ all_tasks = []
+ for domain, examples in test_all_meta.items():
+ for example_id in examples:
+ all_tasks.append((domain, example_id))
+
+ # Calculate tasks per environment
+ tasks_per_env = math.ceil(len(all_tasks) / num_envs)
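+    # e.g. 5 tasks across num_envs=2 -> ceil(5/2) = 3 tasks per env:
+    # env 0 gets all_tasks[0:3], env 1 gets all_tasks[3:5]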
+
+ # Distribute tasks
+ distributed_tasks = []
+ for i in range(num_envs):
+ env_tasks = {}
+ start_idx = i * tasks_per_env
+ end_idx = min((i + 1) * tasks_per_env, len(all_tasks))
+
+ for domain, example_id in all_tasks[start_idx:end_idx]:
+ if domain not in env_tasks:
+ env_tasks[domain] = []
+ env_tasks[domain].append(example_id)
+
+ distributed_tasks.append(env_tasks)
+
+ return distributed_tasks
+
+
+def run_env_tasks(env_idx: int, env: DesktopEnv, agent: PromptAgent, env_tasks: dict, args: argparse.Namespace, shared_scores: list):
+ """Run tasks for a single environment."""
+ logger.info(f"Executing tasks in environment {env_idx + 1}/{args.num_envs}")
+
+ for domain in tqdm(env_tasks, desc=f"Env{env_idx+1}-Domain"):
+ for example_id in tqdm(env_tasks[domain], desc="Example", leave=False):
+ config_file = os.path.join(
+ args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
+ )
+ with open(config_file, "r", encoding="utf-8") as f:
+ example = json.load(f)
+
+ logger.info(f"[Env {env_idx+1}][Domain]: {domain}")
+ logger.info(f"[Env {env_idx+1}][Example ID]: {example_id}")
+ logger.info(f"[Env {env_idx+1}][Instruction]: {example['instruction']}")
+
+ example_result_dir = os.path.join(
+ args.result_dir,
+ args.action_space,
+ args.observation_type,
+ args.model,
+ domain,
+ example_id,
+ )
+ os.makedirs(example_result_dir, exist_ok=True)
+
+ # try:
+ fake_run_single.run_single_example(
+ agent,
+ env,
+ example,
+ args.max_steps,
+ example["instruction"],
+ args,
+ example_result_dir,
+ shared_scores,
+ )
+ # except Exception as e:
+ # logger.error(f"Exception in Env{env_idx+1} {domain}/{example_id}: {e}")
+ # env.controller.end_recording(
+ # os.path.join(example_result_dir, "recording.mp4")
+ # )
+ # with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
+ # f.write(
+ # json.dumps(
+ # {"Error": f"Time limit exceeded in {domain}/{example_id}"}
+ # )
+ # )
+ # f.write("\n")
+
+ env.close()
+
+
+def test(args: argparse.Namespace, test_all_meta: dict) -> None:
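+    """Build one PromptAgent/DesktopEnv pair per parallel environment, run the
+    distributed tasks in a separate process per environment, and log the average
+    of the scores gathered through a Manager-backed shared list."""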
+ logger.info("Args: %s", args)
+
+ distributed_tasks = distribute_tasks(test_all_meta, args.num_envs)
+
+ # First, set up all environments
+ logger.info("Setting up all environments...")
+ envs = []
+ agents = []
+
+ for env_idx in range(args.num_envs):
+ logger.info(f"Setting up environment {env_idx + 1}/{args.num_envs}")
+
+ agent = PromptAgent(
+ model=args.model,
+ max_tokens=args.max_tokens,
+ top_p=args.top_p,
+ temperature=args.temperature,
+ action_space=args.action_space,
+ observation_type=args.observation_type,
+ max_trajectory_length=args.max_trajectory_length,
+ )
+ agents.append(agent)
+
+ env = DesktopEnv(
+ path_to_vm=args.path_to_vm,
+ action_space=agent.action_space,
+
+ provider_name="aws",
+ region="us-east-1",
+ snapshot_name="ami-05e7d7bd279ea4f14",
+
+ screen_size=(args.screen_width, args.screen_height),
+ headless=args.headless,
+ os_type="Ubuntu",
+ require_a11y_tree=args.observation_type
+ in ["a11y_tree", "screenshot_a11y_tree", "som"],
+ )
+ envs.append(env)
+
+ logger.info("All environments are ready. Starting parallel task execution...")
+
+ # Create a shared list for scores across processes
+ with Manager() as manager:
+ shared_scores = manager.list()
+
+ # Create and start processes for each environment
+ processes = []
+ for env_idx, (env, agent, env_tasks) in enumerate(zip(envs, agents, distributed_tasks)):
+ p = Process(
+ target=run_env_tasks,
+ args=(env_idx, env, agent, env_tasks, args, shared_scores)
+ )
+ processes.append(p)
+ p.start()
+
+ # Wait for all processes to complete
+ for p in processes:
+ p.join()
+
+ # Convert shared list to regular list
+ scores = list(shared_scores)
+
+ logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")
+
+
+def get_unfinished(
+ action_space, use_model, observation_type, result_dir, total_file_json
+):
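+    """Return the subset of total_file_json still left to run: examples that already
+    have a result.txt are dropped, and partially finished example directories are
+    cleared so those examples rerun from scratch."""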
+ target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
+
+ if not os.path.exists(target_dir):
+ return total_file_json
+
+ finished = {}
+ for domain in os.listdir(target_dir):
+ finished[domain] = []
+ domain_path = os.path.join(target_dir, domain)
+ if os.path.isdir(domain_path):
+ for example_id in os.listdir(domain_path):
+ if example_id == "onboard":
+ continue
+ example_path = os.path.join(domain_path, example_id)
+ if os.path.isdir(example_path):
+ if "result.txt" not in os.listdir(example_path):
+                        # unfinished example: remove its partial files so it reruns cleanly
+ for file in os.listdir(example_path):
+ os.remove(os.path.join(example_path, file))
+ else:
+ finished[domain].append(example_id)
+
+ if not finished:
+ return total_file_json
+
+ for domain, examples in finished.items():
+ if domain in total_file_json:
+ total_file_json[domain] = [
+ x for x in total_file_json[domain] if x not in examples
+ ]
+
+ return total_file_json
+
+
+def get_result(action_space, use_model, observation_type, result_dir, total_file_json):
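+    """Collect per-example scores from existing result.txt files and print the current
+    success rate; returns None when there are no results yet (the total_file_json
+    argument is currently unused)."""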
+ target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
+ if not os.path.exists(target_dir):
+ print("New experiment, no result yet.")
+ return None
+
+ all_result = []
+
+ for domain in os.listdir(target_dir):
+ domain_path = os.path.join(target_dir, domain)
+ if os.path.isdir(domain_path):
+ for example_id in os.listdir(domain_path):
+ example_path = os.path.join(domain_path, example_id)
+ if os.path.isdir(example_path):
+ if "result.txt" in os.listdir(example_path):
+ # empty all files under example_id
+ try:
+ all_result.append(
+ float(
+ open(
+ os.path.join(example_path, "result.txt"), "r"
+ ).read()
+ )
+ )
+ except:
+ all_result.append(0.0)
+
+ if not all_result:
+ print("New experiment, no result yet.")
+ return None
+ else:
+ print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%")
+ return all_result
+
+
+if __name__ == "__main__":
+ ####### The complete version of the list of examples #######
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+ args = config()
+
+ with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
+ test_all_meta = json.load(f)
+
+ if args.domain != "all":
+ test_all_meta = {args.domain: test_all_meta[args.domain]}
+
+ test_file_list = get_unfinished(
+ args.action_space,
+ args.model,
+ args.observation_type,
+ args.result_dir,
+ test_all_meta,
+ )
+ left_info = ""
+ for domain in test_file_list:
+ left_info += f"{domain}: {len(test_file_list[domain])}\n"
+ logger.info(f"Left tasks:\n{left_info}")
+
+ get_result(
+ args.action_space,
+ args.model,
+ args.observation_type,
+ args.result_dir,
+ test_all_meta,
+ )
+ test(args, test_file_list)
+
+
+# NOTE: path_to_vm can also be a list of VM paths, e.g. ["xxx", "xxx"]
\ No newline at end of file
diff --git a/test_env/__init__.py b/test_env/__init__.py
new file mode 100644
index 0000000..e958d0a
--- /dev/null
+++ b/test_env/__init__.py
@@ -0,0 +1,2 @@
+from .fake_python_controller import PythonController
+from .fake_env import DesktopEnv
\ No newline at end of file
diff --git a/test_env/fake_env.py b/test_env/fake_env.py
new file mode 100644
index 0000000..98543c4
--- /dev/null
+++ b/test_env/fake_env.py
@@ -0,0 +1,128 @@
+import base64
+import os
+from typing import Any, Tuple
+
+from .fake_python_controller import PythonController
+
+
+class DesktopEnv:
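+    """Offline stand-in for the real DesktopEnv: instead of driving a VM it replays
+    two canned observations (screenshot + accessibility tree) from the
+    test_observations/ directory, so agents and the runner can be exercised
+    without provisioning any infrastructure."""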
+ def __init__(
+ self,
+ action_space: str = "computer_13",
+        screen_size: Tuple[int, int] = (1920, 1080),
+ *args: Any,
+ **kwargs: Any,
+ ):
+ self.obs_options = {}
+ self._step_no = 0
+ self.action_history = []
+ self.action_space = action_space
+ self.resolution = screen_size
+ self.controller = PythonController()
+
+ # Load test screenshots and accessibility trees
+ test_obs_dir = os.path.join(os.path.dirname(__file__), "test_observations")
+
+ self.screenshots = [
+ self._load_image(os.path.join(test_obs_dir, "screenshot0.jpg")),
+ self._load_image(os.path.join(test_obs_dir, "screenshot1.jpg")),
+ ]
+ self.accessibility_trees = [
+ self._load_accessibility_tree(os.path.join(test_obs_dir, "a11y_tree0.txt")),
+ self._load_accessibility_tree(os.path.join(test_obs_dir, "a11y_tree1.txt")),
+ ]
+
+ def _get_screenshot(self):
+ if self._step_no == 0:
+ return self.screenshots[0]
+ return self.screenshots[1]
+
+ def _get_accessibility_tree(self):
+ if self._step_no == 0:
+ return self.accessibility_trees[0]
+ return self.accessibility_trees[1]
+
+ def set_obs_options(self, obs_options):
+ print(f"Setting obs options to {obs_options}")
+ self.obs_options = obs_options
+
+ def _load_image(self, image_path):
+ try:
+ with open(image_path, "rb") as image_file:
+                # Read and return the raw image bytes; callers base64-encode them
+                # where needed (e.g. for the claude_computer_use tool result)
+                return image_file.read()
+ except FileNotFoundError:
+ print(f"Error: File not found at {image_path}")
+ except Exception as e:
+ print(f"An error occurred: {e}")
+
+ def _load_accessibility_tree(self, tree_path):
+ try:
+ with open(tree_path, "r") as tree_file:
+ # Read the accessibility tree file
+ tree_data = tree_file.read()
+ return tree_data
+ except FileNotFoundError:
+ print(f"Error: File not found at {tree_path}")
+ except Exception as e:
+ print(f"An error occurred: {e}")
+
+ def _get_obs(self):
+ obs = {}
+ obs["screenshot"] = self._get_screenshot()
+ obs["accessibility_tree"] = self._get_accessibility_tree()
+ obs["terminal"] = ""
+ obs["instruction"] = "Open Chrome browser"
+
+ return obs
+
+ def _start_video_recording(self):
+ pass
+
+ def _stop_video_recording(self):
+ pass
+
+    def step(self, action, pause: float = 0.0) -> Tuple:
+        # `pause` is accepted for signature compatibility with the real env
+        # (fake_run_single passes args.sleep_after_execution) but is ignored here.
+        self._step_no += 1
+        self.action_history.append(action)
+
+        reward = 0.0
+        info = {}
+        terminated = False  # todo: Define episode termination condition for each example
+
+        if action in ("FAIL", "DONE"):
+            terminated = True
+
+        elif self.action_space == "claude_computer_use":
+            tool_result = {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "tool_result",
+                        "tool_use_id": "toolu_01A09q90qw90lq917835lq9",
+                        "content": [
+                            {
+                                "type": "image",
+                                "source": {
+                                    "type": "base64",
+                                    "media_type": "image/jpeg",
+                                    # the fixture is raw JPEG bytes, so encode it here
+                                    "data": base64.b64encode(self.screenshots[1]).decode("utf-8"),
+                                }
+                            }
+                        ]
+                    }
+                ]
+            }
+            info.update({"tool_result": tool_result})
+
+        # Return the Gym-style 4-tuple that fake_run_single.run_single_example unpacks.
+        return (self._get_obs(), reward, terminated, info)
+
+    def evaluate(self):
+        # The fake environment has no real evaluator; return a fixed score so
+        # fake_run_single.run_single_example can log and persist a result.
+        return 0.0
+
+    def close(self):
+        self._step_no = 0
+        self.action_history = []
+        self.obs_options = {}
+        self.controller = None
+
+    def reset(self, *args: Any, **kwargs: Any) -> dict:
+        # Rewind to the first canned observation for each new example
+        self._step_no = 0
+        self.action_history = []
+        return self._get_obs()
\ No newline at end of file
diff --git a/test_env/fake_python_controller.py b/test_env/fake_python_controller.py
new file mode 100644
index 0000000..31a0803
--- /dev/null
+++ b/test_env/fake_python_controller.py
@@ -0,0 +1,50 @@
+from typing import Any, Dict, Optional
+
+
+class PythonController:
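+    """No-op stand-in for the real PythonController: the method set mirrors the real
+    controller's interface, but every call does nothing, so the fake DesktopEnv and
+    the runner can call into it without a VM behind it."""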
+ def __init__(self):
+ pass
+
+ def get_screenshot(self) -> Optional[bytes]:
+ pass
+
+ def get_accessibility_tree(self) -> Optional[str]:
+ pass
+
+ def get_terminal_output(self) -> Optional[str]:
+ pass
+
+ def get_file(self, file_path: str) -> Optional[bytes]:
+ pass
+
+ def execute_python_command(self, command: str) -> None:
+ pass
+
+ def execute_action(self, action: Dict[str, Any]):
+ pass
+
+ # Record video
+ def start_recording(self):
+ pass
+
+ def end_recording(self, dest: str):
+ pass
+
+ # Additional info
+ def get_vm_platform(self):
+ pass
+
+ def get_vm_screen_size(self):
+ pass
+
+ def get_vm_window_size(self, app_class_name: str):
+ pass
+
+ def get_vm_wallpaper(self):
+ pass
+
+ def get_vm_desktop_path(self) -> Optional[str]:
+ pass
+
+    def get_vm_directory_tree(self, path: str) -> Optional[Dict[str, Any]]:
+ pass
diff --git a/test_env/test_observations/a11y_tree0.txt b/test_env/test_observations/a11y_tree0.txt
new file mode 100644
index 0000000..a871d99
--- /dev/null
+++ b/test_env/test_observations/a11y_tree0.txt
@@ -0,0 +1 @@
+Clear272829300102030405060708091011121314151617181920212223242526272829303101020304050607MinimiseRestoreCloseSearch tabsCloseNew TabBackForwardReloadSide panelYouFinish updateAppsAppsManaged bookmarksManaged bookmarksAll BookmarksAll BookmarksCustomise ChromeCustomise ChromeClose
\ No newline at end of file
diff --git a/test_env/test_observations/a11y_tree1.txt b/test_env/test_observations/a11y_tree1.txt
new file mode 100644
index 0000000..a871d99
--- /dev/null
+++ b/test_env/test_observations/a11y_tree1.txt
@@ -0,0 +1 @@
+Clear272829300102030405060708091011121314151617181920212223242526272829303101020304050607MinimiseRestoreCloseSearch tabsCloseNew TabBackForwardReloadSide panelYouFinish updateAppsAppsManaged bookmarksManaged bookmarksAll BookmarksAll BookmarksCustomise ChromeCustomise ChromeClose
\ No newline at end of file
diff --git a/test_env/test_observations/screenshot0.jpg b/test_env/test_observations/screenshot0.jpg
new file mode 100644
index 0000000..d58bfe6
Binary files /dev/null and b/test_env/test_observations/screenshot0.jpg differ
diff --git a/test_env/test_observations/screenshot1.jpg b/test_env/test_observations/screenshot1.jpg
new file mode 100644
index 0000000..d00ebd6
Binary files /dev/null and b/test_env/test_observations/screenshot1.jpg differ