From 8be2a409678e4446c8ebab9789b24996c42b38f9 Mon Sep 17 00:00:00 2001
From: Dunjie Lu <127488745+ludunjie1219@users.noreply.github.com>
Date: Sat, 2 Nov 2024 22:28:23 +0800
Subject: [PATCH] Docker (#92)

* multi_env

* multi_env

---------

Co-authored-by: Timothyxxx <384084775@qq.com>
---
 .gitignore                               |   6 +
 desktop_env/desktop_env.py               |  16 +-
 desktop_env/providers/__init__.py        |   2 +
 desktop_env/providers/docker/provider.py | 146 ++++++---
 run.py                                   |   3 +-
 run_multienv.py                          | 360 +++++++++++++++++++++++
 show_result.py                           |   2 +-
 7 files changed, 493 insertions(+), 42 deletions(-)
 create mode 100644 run_multienv.py

diff --git a/.gitignore b/.gitignore
index cb02398..cae53ca 100644
--- a/.gitignore
+++ b/.gitignore
@@ -187,3 +187,9 @@ test2.xlsx
 # vm info
 .vms
 /vm_data
+docker_vm_data
+vmware_vm_data
+.vmware*
+
+# result
+**/result*/**/*
diff --git a/desktop_env/desktop_env.py b/desktop_env/desktop_env.py
index cf05d03..d1c9cf3 100644
--- a/desktop_env/desktop_env.py
+++ b/desktop_env/desktop_env.py
@@ -26,7 +26,7 @@ class DesktopEnv(gym.Env):
 
     def __init__(
             self,
-            provider_name: str = "vmware",
+            provider_name: str = "docker",
             region: str = None,
             path_to_vm: str = None,
             snapshot_name: str = "init_state",
@@ -36,7 +36,7 @@ class DesktopEnv(gym.Env):
             headless: bool = False,
             require_a11y_tree: bool = True,
             require_terminal: bool = False,
-            os_type: str = "Ubuntu",
+            os_type: str = "Windows",
     ):
         """
         Args:
@@ -60,6 +60,18 @@ class DesktopEnv(gym.Env):
         self.chromium_port = 9222
         self.vnc_port = 8006
         self.manager, self.provider = create_vm_manager_and_provider(provider_name, region)
+        # self.server_port = server_port or 5000
+        # self.chromium_port = chromium_port or 9222
+        # self.vnc_port = vnc_port or 8006
+
+        # # Initialize provider with custom ports
+        # self.manager, self.provider = create_vm_manager_and_provider(
+        #     provider_name,
+        #     region,
+        #     vnc_port=self.vnc_port,
+        #     server_port=self.server_port,
+        #     chromium_port=self.chromium_port
+        # )
 
         self.os_type = os_type
 
diff --git a/desktop_env/providers/__init__.py b/desktop_env/providers/__init__.py
index 7c95382..5dbde19 100644
--- a/desktop_env/providers/__init__.py
+++ b/desktop_env/providers/__init__.py
@@ -1,5 +1,6 @@
 from desktop_env.providers.base import VMManager, Provider
 
+# def create_vm_manager_and_provider(provider_name: str, region: str, vnc_port: int = None, server_port: int = None, chromium_port: int = None):
 def create_vm_manager_and_provider(provider_name: str, region: str):
     """
     Factory function to get the Virtual Machine Manager and Provider instances based on the provided provider name.
@@ -24,6 +25,7 @@ def create_vm_manager_and_provider(provider_name: str, region: str):
     elif provider_name == "docker":
         from desktop_env.providers.docker.manager import DockerVMManager
         from desktop_env.providers.docker.provider import DockerProvider
+        # return DockerVMManager(), DockerProvider(region, vnc_port, server_port, chromium_port)
         return DockerVMManager(), DockerProvider(region)
     else:
         raise NotImplementedError(f"{provider_name} not implemented!")
diff --git a/desktop_env/providers/docker/provider.py b/desktop_env/providers/docker/provider.py
index 8300b74..eadfb93 100644
--- a/desktop_env/providers/docker/provider.py
+++ b/desktop_env/providers/docker/provider.py
@@ -30,58 +30,128 @@ class DockerProvider(Provider):
         self.chromium_port = None
         self.environment = {"DISK_SIZE": "32G", "RAM_SIZE": "4G", "CPU_CORES": "4"}  # Modify if needed
-        temp_dir = Path(os.getenv('TEMP' if platform.system() == 'Windows' else '/tmp'))
+        # temp_dir = Path(os.getenv('TEMP' if platform.system() == 'Windows' else '/tmp'))
+        temp_dir = Path(os.getenv('TEMP') if platform.system() == 'Windows' else '/tmp')
         self.lock_file = temp_dir / "docker_port_allocation.lck"
         self.lock_file.parent.mkdir(parents=True, exist_ok=True)
 
-    def _get_available_port(self, port: int, lock_file: Path = None):
-        if lock_file is None:
-            lock_file = self.lock_file
-        lock = FileLock(str(lock_file), timeout=LOCK_TIMEOUT)
-        with lock:
-            while port < 65354:
-                if port not in [conn.laddr.port for conn in psutil.net_connections()]:
-                    return port
-                port += 1
+    def _get_used_ports(self):
+        """Get all currently used ports (both system and Docker)."""
+        # Get system ports
+        system_ports = set(conn.laddr.port for conn in psutil.net_connections())
+
+        # Get Docker container ports
+        docker_ports = set()
+        for container in self.client.containers.list():
+            ports = container.attrs['NetworkSettings']['Ports']
+            if ports:
+                for port_mappings in ports.values():
+                    if port_mappings:
+                        docker_ports.update(int(p['HostPort']) for p in port_mappings)
+
+        return system_ports | docker_ports
+
+    def _get_available_port(self, start_port: int) -> int:
+        """Find next available port starting from start_port."""
+        used_ports = self._get_used_ports()
+        port = start_port
+        while port < 65354:
+            if port not in used_ports:
+                return port
+            port += 1
+        raise PortAllocationError(f"No available ports found starting from {start_port}")
+
+    def _wait_for_vm_ready(self, timeout: int = 300):
+        """Wait for VM to be ready by checking screenshot endpoint."""
+        start_time = time.time()
+
+        def check_screenshot():
+            try:
+                response = requests.get(
+                    f"http://localhost:{self.server_port}/screenshot",
+                    timeout=(10, 10)
+                )
+                return response.status_code == 200
+            except Exception:
+                return False
+
+        while time.time() - start_time < timeout:
+            if check_screenshot():
+                return True
+            logger.info("Checking if virtual machine is ready...")
+            time.sleep(RETRY_INTERVAL)
+
+        raise TimeoutError("VM failed to become ready within timeout period")
+
     def start_emulator(self, path_to_vm: str, headless: bool, os_type: str):
-        self.vnc_port = self._get_available_port(8006)
-        self.server_port = self._get_available_port(5000)
-        # self.remote_debugging_port = self._get_available_port(1337)
-        self.chromium_port = self._get_available_port(9222)
-        logger.info(f"Occupying ports: {self.vnc_port}, {self.server_port}, {self.chromium_port}")
-        self.container = self.client.containers.run("happysixd/osworld-docker", environment=self.environment,
-                                                     cap_add=["NET_ADMIN"], devices=["/dev/kvm"], volumes={
-                                                         os.path.abspath(path_to_vm): {"bind": "/System.qcow2", "mode": "ro"}},
-                                                     ports={8006: self.vnc_port, 5000: self.server_port,
-                                                            9222: self.chromium_port}, detach=True)
+        # Use a single lock for all port allocation and container startup
+        lock = FileLock(str(self.lock_file), timeout=LOCK_TIMEOUT)
+
+        try:
+            with lock:
+                # Allocate all required ports
+                self.vnc_port = self._get_available_port(8006)
+                self.server_port = self._get_available_port(5000)
+                self.chromium_port = self._get_available_port(9222)
 
-        def download_screenshot(ip, port):
-            url = f"http://{ip}:{port}/screenshot"
-            try:
-                # max trey times 1, max timeout 1
-                response = requests.get(url, timeout=(10, 10))
-                if response.status_code == 200:
-                    return True
-            except Exception as e:
-                time.sleep(RETRY_INTERVAL)
-            return False
+                # Start container while still holding the lock
+                self.container = self.client.containers.run(
+                    "happysixd/osworld-docker",
+                    environment=self.environment,
+                    cap_add=["NET_ADMIN"],
+                    devices=["/dev/kvm"],
+                    volumes={
+                        os.path.abspath(path_to_vm): {
+                            "bind": "/System.qcow2",
+                            "mode": "ro"
+                        }
+                    },
+                    ports={
+                        8006: self.vnc_port,
+                        5000: self.server_port,
+                        9222: self.chromium_port
+                    },
+                    detach=True
+                )
 
-        # Try downloading the screenshot until successful
-        while not download_screenshot("localhost", self.server_port):
-            logger.info("Check whether the virtual machine is ready...")
+                logger.info(f"Started container with ports - VNC: {self.vnc_port}, "
+                            f"Server: {self.server_port}, Chrome: {self.chromium_port}")
+
+            # Wait for VM to be ready
+            self._wait_for_vm_ready()
+
+        except Exception as e:
+            # Clean up if anything goes wrong
+            if self.container:
+                try:
+                    self.container.stop()
+                    self.container.remove()
+                except:
+                    pass
+            raise e
 
     def get_ip_address(self, path_to_vm: str) -> str:
+        if not all([self.server_port, self.chromium_port, self.vnc_port]):
+            raise RuntimeError("VM not started - ports not allocated")
         return f"localhost:{self.server_port}:{self.chromium_port}:{self.vnc_port}"
 
     def save_state(self, path_to_vm: str, snapshot_name: str):
-        raise NotImplementedError("Not available for Docker.")
+        raise NotImplementedError("Snapshots not available for Docker provider")
 
     def revert_to_snapshot(self, path_to_vm: str, snapshot_name: str):
         pass
 
     def stop_emulator(self, path_to_vm: str):
-        logger.info("Stopping VM...")
-        self.container.stop()
-        self.container.remove()
-        time.sleep(WAIT_TIME)
+        if self.container:
+            logger.info("Stopping VM...")
+            try:
+                self.container.stop()
+                self.container.remove()
+                time.sleep(WAIT_TIME)
+            except Exception as e:
+                logger.error(f"Error stopping container: {e}")
+            finally:
+                self.container = None
+                self.server_port = None
+                self.vnc_port = None
+                self.chromium_port = None
diff --git a/run.py b/run.py
index d85bb05..d0d91cd 100644
--- a/run.py
+++ b/run.py
@@ -91,7 +91,7 @@ def config() -> argparse.Namespace:
     )
 
     # lm config
-    parser.add_argument("--model", type=str, default="gpt-4-0125-preview")
+    parser.add_argument("--model", type=str, default="gpt-4o")
    parser.add_argument("--temperature", type=float, default=1.0)
     parser.add_argument("--top_p", type=float, default=0.9)
     parser.add_argument("--max_tokens", type=int, default=1500)
@@ -150,6 +150,7 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
         action_space=agent.action_space,
         screen_size=(args.screen_width, args.screen_height),
         headless=args.headless,
+        os_type = "Ubuntu",
         require_a11y_tree=args.observation_type
         in ["a11y_tree", "screenshot_a11y_tree", "som"],
     )
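For reference, the port-allocation pattern introduced in desktop_env/providers/docker/provider.py above reduces to the sketch below: one FileLock is held across both port selection and container startup, so two environments started concurrently cannot pick the same host ports. This is an illustrative sketch only, not part of the patch; note that the hunk raises PortAllocationError, which is not defined in the lines shown, so the sketch falls back to RuntimeError, and the lock path and timeout are assumed.

# Illustrative sketch of the lock-held-through-startup allocation pattern
# used by DockerProvider above (assumed lock path and timeout).
import psutil
from filelock import FileLock

def next_free_port(start_port: int) -> int:
    # Mirror the provider's scan: skip any port a local socket already uses.
    used = {c.laddr.port for c in psutil.net_connections() if c.laddr}
    port = start_port
    while port < 65354:
        if port not in used:
            return port
        port += 1
    raise RuntimeError(f"No available ports found starting from {start_port}")

lock = FileLock("/tmp/docker_port_allocation.lck", timeout=10)
with lock:
    vnc_port = next_free_port(8006)
    server_port = next_free_port(5000)
    chromium_port = next_free_port(9222)
    # ...the container is started here, before the lock is released,
    # so the chosen ports are bound before anyone else can scan them...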
diff --git a/run_multienv.py b/run_multienv.py
new file mode 100644
index 0000000..8a8a417
--- /dev/null
+++ b/run_multienv.py
@@ -0,0 +1,360 @@
+"""Script to run end-to-end evaluation on the benchmark.
+Utils and basic architecture credit to https://github.com/web-arena-x/webarena/blob/main/run.py.
+"""
+
+import argparse
+import datetime
+import json
+import logging
+import os
+import sys
+from typing import List, Dict
+import math
+from tqdm import tqdm
+from multiprocessing import Process, Manager
+import lib_run_single
+from desktop_env.desktop_env import DesktopEnv
+from mm_agents.agent import PromptAgent
+
+# import wandb
+
+
+# Logger Configs {{{ #
+logger = logging.getLogger()
+logger.setLevel(logging.DEBUG)
+
+datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
+
+file_handler = logging.FileHandler(
+    os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8"
+)
+debug_handler = logging.FileHandler(
+    os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8"
+)
+stdout_handler = logging.StreamHandler(sys.stdout)
+sdebug_handler = logging.FileHandler(
+    os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8"
+)
+
+file_handler.setLevel(logging.INFO)
+debug_handler.setLevel(logging.DEBUG)
+stdout_handler.setLevel(logging.INFO)
+sdebug_handler.setLevel(logging.DEBUG)
+
+formatter = logging.Formatter(
+    fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
+)
+file_handler.setFormatter(formatter)
+debug_handler.setFormatter(formatter)
+stdout_handler.setFormatter(formatter)
+sdebug_handler.setFormatter(formatter)
+
+stdout_handler.addFilter(logging.Filter("desktopenv"))
+sdebug_handler.addFilter(logging.Filter("desktopenv"))
+
+logger.addHandler(file_handler)
+logger.addHandler(debug_handler)
+logger.addHandler(stdout_handler)
+logger.addHandler(sdebug_handler)
+# }}} Logger Configs #
+
+logger = logging.getLogger("desktopenv.experiment")
+
+
+def config() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(
+        description="Run end-to-end evaluation on the benchmark"
+    )
+
+    # environment config
+    parser.add_argument("--path_to_vm", type=str, default=None)
+    parser.add_argument(
+        "--headless", action="store_true", help="Run in headless machine"
+    )
+    parser.add_argument(
+        "--action_space", type=str, default="pyautogui", help="Action type"
+    )
+    parser.add_argument(
+        "--observation_type",
+        choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"],
+        default="a11y_tree",
+        help="Observation type",
+    )
+    parser.add_argument("--screen_width", type=int, default=1920)
+    parser.add_argument("--screen_height", type=int, default=1080)
+    parser.add_argument("--sleep_after_execution", type=float, default=0.0)
+    parser.add_argument("--max_steps", type=int, default=15)
+
+    # agent config
+    parser.add_argument("--max_trajectory_length", type=int, default=3)
+    parser.add_argument(
+        "--test_config_base_dir", type=str, default="evaluation_examples"
+    )
+
+    # lm config
+    parser.add_argument("--model", type=str, default="gpt-4o")
+    parser.add_argument("--temperature", type=float, default=1.0)
+    parser.add_argument("--top_p", type=float, default=0.9)
+    parser.add_argument("--max_tokens", type=int, default=1500)
+    parser.add_argument("--stop_token", type=str, default=None)
+
+    # example config
+    parser.add_argument("--domain", type=str, default="all")
+    parser.add_argument(
+        "--test_all_meta_path", type=str, default="evaluation_examples/test_all.json"
+    )
+
+    # logging related
parser.add_argument("--result_dir", type=str, default="./results") + parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to run in parallel") + + args = parser.parse_args() + return args + + +def distribute_tasks(test_all_meta: dict, num_envs: int) -> List[Dict]: + """Distribute tasks evenly across environments.""" + # Flatten the tasks into a single list + all_tasks = [] + for domain, examples in test_all_meta.items(): + for example_id in examples: + all_tasks.append((domain, example_id)) + + # Calculate tasks per environment + tasks_per_env = math.ceil(len(all_tasks) / num_envs) + + # Distribute tasks + distributed_tasks = [] + for i in range(num_envs): + env_tasks = {} + start_idx = i * tasks_per_env + end_idx = min((i + 1) * tasks_per_env, len(all_tasks)) + + for domain, example_id in all_tasks[start_idx:end_idx]: + if domain not in env_tasks: + env_tasks[domain] = [] + env_tasks[domain].append(example_id) + + distributed_tasks.append(env_tasks) + + return distributed_tasks + + + +def run_env_tasks(env_idx: int, env: DesktopEnv, agent: PromptAgent, env_tasks: dict, args: argparse.Namespace, shared_scores: list): + """Run tasks for a single environment.""" + logger.info(f"Executing tasks in environment {env_idx + 1}/{args.num_envs}") + + for domain in tqdm(env_tasks, desc=f"Env{env_idx+1}-Domain"): + for example_id in tqdm(env_tasks[domain], desc="Example", leave=False): + config_file = os.path.join( + args.test_config_base_dir, f"examples/{domain}/{example_id}.json" + ) + with open(config_file, "r", encoding="utf-8") as f: + example = json.load(f) + + logger.info(f"[Env {env_idx+1}][Domain]: {domain}") + logger.info(f"[Env {env_idx+1}][Example ID]: {example_id}") + logger.info(f"[Env {env_idx+1}][Instruction]: {example['instruction']}") + + example_result_dir = os.path.join( + args.result_dir, + args.action_space, + args.observation_type, + args.model, + domain, + example_id, + ) + os.makedirs(example_result_dir, exist_ok=True) + + try: + lib_run_single.run_single_example( + agent, + env, + example, + args.max_steps, + example["instruction"], + args, + example_result_dir, + shared_scores, + ) + except Exception as e: + logger.error(f"Exception in Env{env_idx+1} {domain}/{example_id}: {e}") + env.controller.end_recording( + os.path.join(example_result_dir, "recording.mp4") + ) + with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f: + f.write( + json.dumps( + {"Error": f"Time limit exceeded in {domain}/{example_id}"} + ) + ) + f.write("\n") + + env.close() + + +def test(args: argparse.Namespace, test_all_meta: dict) -> None: + logger.info("Args: %s", args) + + distributed_tasks = distribute_tasks(test_all_meta, args.num_envs) + + # First, set up all environments + logger.info("Setting up all environments...") + envs = [] + agents = [] + + for env_idx in range(args.num_envs): + logger.info(f"Setting up environment {env_idx + 1}/{args.num_envs}") + + agent = PromptAgent( + model=args.model, + max_tokens=args.max_tokens, + top_p=args.top_p, + temperature=args.temperature, + action_space=args.action_space, + observation_type=args.observation_type, + max_trajectory_length=args.max_trajectory_length, + ) + agents.append(agent) + + env = DesktopEnv( + path_to_vm=args.path_to_vm, + action_space=agent.action_space, + screen_size=(args.screen_width, args.screen_height), + headless=args.headless, + os_type="Ubuntu", + require_a11y_tree=args.observation_type + in ["a11y_tree", "screenshot_a11y_tree", "som"], + ) + envs.append(env) + + logger.info("All 
+
+    # Create a shared list for scores across processes
+    with Manager() as manager:
+        shared_scores = manager.list()
+
+        # Create and start processes for each environment
+        processes = []
+        for env_idx, (env, agent, env_tasks) in enumerate(zip(envs, agents, distributed_tasks)):
+            p = Process(
+                target=run_env_tasks,
+                args=(env_idx, env, agent, env_tasks, args, shared_scores)
+            )
+            processes.append(p)
+            p.start()
+
+        # Wait for all processes to complete
+        for p in processes:
+            p.join()
+
+        # Convert shared list to regular list
+        scores = list(shared_scores)
+
+    logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")
+
+
+def get_unfinished(
+    action_space, use_model, observation_type, result_dir, total_file_json
+):
+    target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
+
+    if not os.path.exists(target_dir):
+        return total_file_json
+
+    finished = {}
+    for domain in os.listdir(target_dir):
+        finished[domain] = []
+        domain_path = os.path.join(target_dir, domain)
+        if os.path.isdir(domain_path):
+            for example_id in os.listdir(domain_path):
+                if example_id == "onboard":
+                    continue
+                example_path = os.path.join(domain_path, example_id)
+                if os.path.isdir(example_path):
+                    if "result.txt" not in os.listdir(example_path):
+                        # empty all files under example_id
+                        for file in os.listdir(example_path):
+                            os.remove(os.path.join(example_path, file))
+                    else:
+                        finished[domain].append(example_id)
+
+    if not finished:
+        return total_file_json
+
+    for domain, examples in finished.items():
+        if domain in total_file_json:
+            total_file_json[domain] = [
+                x for x in total_file_json[domain] if x not in examples
+            ]
+
+    return total_file_json
+
+
+def get_result(action_space, use_model, observation_type, result_dir, total_file_json):
+    target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
+    if not os.path.exists(target_dir):
+        print("New experiment, no result yet.")
+        return None
+
+    all_result = []
+
+    for domain in os.listdir(target_dir):
+        domain_path = os.path.join(target_dir, domain)
+        if os.path.isdir(domain_path):
+            for example_id in os.listdir(domain_path):
+                example_path = os.path.join(domain_path, example_id)
+                if os.path.isdir(example_path):
+                    if "result.txt" in os.listdir(example_path):
+                        # empty all files under example_id
+                        try:
+                            all_result.append(
+                                float(
+                                    open(
+                                        os.path.join(example_path, "result.txt"), "r"
+                                    ).read()
+                                )
+                            )
+                        except:
+                            all_result.append(0.0)
+
+    if not all_result:
+        print("New experiment, no result yet.")
+        return None
+    else:
+        print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%")
+        return all_result
+
+
+if __name__ == "__main__":
+    ####### The complete version of the list of examples #######
+    os.environ["TOKENIZERS_PARALLELISM"] = "false"
+    args = config()
+
+    with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
+        test_all_meta = json.load(f)
+
+    if args.domain != "all":
+        test_all_meta = {args.domain: test_all_meta[args.domain]}
+
+    test_file_list = get_unfinished(
+        args.action_space,
+        args.model,
+        args.observation_type,
+        args.result_dir,
+        test_all_meta,
+    )
+    left_info = ""
+    for domain in test_file_list:
+        left_info += f"{domain}: {len(test_file_list[domain])}\n"
+    logger.info(f"Left tasks:\n{left_info}")
+
+    get_result(
+        args.action_space,
+        args.model,
+        args.observation_type,
+        args.result_dir,
+        test_all_meta,
+    )
+    test(args, test_file_list)
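The chunking performed by distribute_tasks() in run_multienv.py above is contiguous rather than round-robin: tasks from all domains are flattened into one list, then sliced into blocks of ceil(N / num_envs), so the last environment may receive fewer examples. A standalone illustration with made-up task IDs (not imported from the new script, which configures logging at import time):

# Standalone illustration of the slicing rule used by distribute_tasks()
# in run_multienv.py; the domain and task names are invented for the example.
import math

def chunk_tasks(test_all_meta: dict, num_envs: int) -> list:
    all_tasks = [(d, e) for d, examples in test_all_meta.items() for e in examples]
    per_env = math.ceil(len(all_tasks) / num_envs)
    buckets = []
    for i in range(num_envs):
        env_tasks = {}
        for domain, example_id in all_tasks[i * per_env:(i + 1) * per_env]:
            env_tasks.setdefault(domain, []).append(example_id)
        buckets.append(env_tasks)
    return buckets

print(chunk_tasks({"chrome": ["t1", "t2", "t3"], "gimp": ["t4", "t5"]}, 2))
# [{'chrome': ['t1', 't2', 't3']}, {'gimp': ['t4', 't5']}]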
diff --git a/show_result.py b/show_result.py
index 241d0d9..36563d7 100644
--- a/show_result.py
+++ b/show_result.py
@@ -68,4 +68,4 @@ def get_result(action_space, use_model, observation_type, result_dir):
 
 
 if __name__ == '__main__':
-    get_result("pyautogui", "gpt-4-vision-preview", "screenshot", "./results")
+    get_result("pyautogui", "gpt-4o", "a11y_tree", "./results")
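Taken together, the changes above make the Docker-backed environment the default provider and add a multiprocessing harness (run_multienv.py) whose per-example results are later aggregated by show_result.py from results/<action_space>/<observation_type>/<model>/<domain>/<example_id>/result.txt. A minimal usage sketch, assuming the reset()/close() interface DesktopEnv already exposes and an example JSON path that is purely illustrative:

# Minimal sketch (not part of the patch): drive one Docker-backed environment
# directly, using the defaults this patch introduces. The example path below
# is hypothetical; real task configs live under evaluation_examples/examples/.
import json
from desktop_env.desktop_env import DesktopEnv

env = DesktopEnv(
    provider_name="docker",   # new default from this patch
    os_type="Ubuntu",         # run.py and run_multienv.py pass this explicitly
    action_space="pyautogui",
    headless=True,
)

with open("evaluation_examples/examples/chrome/some_example.json") as f:  # illustrative path
    task_config = json.load(f)

obs = env.reset(task_config=task_config)
env.close()

# The parallel harness would typically be launched along the lines of:
#   python run_multienv.py --headless --observation_type a11y_tree --model gpt-4o --num_envs 4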