* multi_env

* multi_env

---------

Co-authored-by: Timothyxxx <384084775@qq.com>
This commit is contained in:
Dunjie Lu
2024-11-02 22:28:23 +08:00
committed by GitHub
parent 3933e0d303
commit 8be2a40967
7 changed files with 493 additions and 42 deletions

6
.gitignore vendored
View File

@@ -187,3 +187,9 @@ test2.xlsx
# vm info # vm info
.vms .vms
/vm_data /vm_data
docker_vm_data
vmware_vm_data
.vmware*
# result
**/result*/**/*

View File

@@ -26,7 +26,7 @@ class DesktopEnv(gym.Env):
def __init__( def __init__(
self, self,
provider_name: str = "vmware", provider_name: str = "docker",
region: str = None, region: str = None,
path_to_vm: str = None, path_to_vm: str = None,
snapshot_name: str = "init_state", snapshot_name: str = "init_state",
@@ -36,7 +36,7 @@ class DesktopEnv(gym.Env):
headless: bool = False, headless: bool = False,
require_a11y_tree: bool = True, require_a11y_tree: bool = True,
require_terminal: bool = False, require_terminal: bool = False,
os_type: str = "Ubuntu", os_type: str = "Windows",
): ):
""" """
Args: Args:
@@ -60,6 +60,18 @@ class DesktopEnv(gym.Env):
self.chromium_port = 9222 self.chromium_port = 9222
self.vnc_port = 8006 self.vnc_port = 8006
self.manager, self.provider = create_vm_manager_and_provider(provider_name, region) self.manager, self.provider = create_vm_manager_and_provider(provider_name, region)
# self.server_port = server_port or 5000
# self.chromium_port = chromium_port or 9222
# self.vnc_port = vnc_port or 8006
# # Initialize provider with custom ports
# self.manager, self.provider = create_vm_manager_and_provider(
# provider_name,
# region,
# vnc_port=self.vnc_port,
# server_port=self.server_port,
# chromium_port=self.chromium_port
# )
self.os_type = os_type self.os_type = os_type

View File

@@ -1,5 +1,6 @@
from desktop_env.providers.base import VMManager, Provider from desktop_env.providers.base import VMManager, Provider
# def create_vm_manager_and_provider(provider_name: str, region: str, vnc_port: int = None, server_port: int = None, chromium_port: int = None):
def create_vm_manager_and_provider(provider_name: str, region: str): def create_vm_manager_and_provider(provider_name: str, region: str):
""" """
Factory function to get the Virtual Machine Manager and Provider instances based on the provided provider name. Factory function to get the Virtual Machine Manager and Provider instances based on the provided provider name.
@@ -24,6 +25,7 @@ def create_vm_manager_and_provider(provider_name: str, region: str):
elif provider_name == "docker": elif provider_name == "docker":
from desktop_env.providers.docker.manager import DockerVMManager from desktop_env.providers.docker.manager import DockerVMManager
from desktop_env.providers.docker.provider import DockerProvider from desktop_env.providers.docker.provider import DockerProvider
# return DockerVMManager(), DockerProvider(region, vnc_port, server_port, chromium_port)
return DockerVMManager(), DockerProvider(region) return DockerVMManager(), DockerProvider(region)
else: else:
raise NotImplementedError(f"{provider_name} not implemented!") raise NotImplementedError(f"{provider_name} not implemented!")

View File

@@ -30,58 +30,128 @@ class DockerProvider(Provider):
self.chromium_port = None self.chromium_port = None
self.environment = {"DISK_SIZE": "32G", "RAM_SIZE": "4G", "CPU_CORES": "4"} # Modify if needed self.environment = {"DISK_SIZE": "32G", "RAM_SIZE": "4G", "CPU_CORES": "4"} # Modify if needed
temp_dir = Path(os.getenv('TEMP' if platform.system() == 'Windows' else '/tmp')) # temp_dir = Path(os.getenv('TEMP' if platform.system() == 'Windows' else '/tmp'))
temp_dir = Path(os.getenv('TEMP') if platform.system() == 'Windows' else '/tmp')
self.lock_file = temp_dir / "docker_port_allocation.lck" self.lock_file = temp_dir / "docker_port_allocation.lck"
self.lock_file.parent.mkdir(parents=True, exist_ok=True) self.lock_file.parent.mkdir(parents=True, exist_ok=True)
def _get_available_port(self, port: int, lock_file: Path = None): def _get_used_ports(self):
if lock_file is None: """Get all currently used ports (both system and Docker)."""
lock_file = self.lock_file # Get system ports
lock = FileLock(str(lock_file), timeout=LOCK_TIMEOUT) system_ports = set(conn.laddr.port for conn in psutil.net_connections())
with lock:
while port < 65354: # Get Docker container ports
if port not in [conn.laddr.port for conn in psutil.net_connections()]: docker_ports = set()
return port for container in self.client.containers.list():
port += 1 ports = container.attrs['NetworkSettings']['Ports']
if ports:
for port_mappings in ports.values():
if port_mappings:
docker_ports.update(int(p['HostPort']) for p in port_mappings)
return system_ports | docker_ports
def _get_available_port(self, start_port: int) -> int:
"""Find next available port starting from start_port."""
used_ports = self._get_used_ports()
port = start_port
while port < 65354:
if port not in used_ports:
return port
port += 1
raise PortAllocationError(f"No available ports found starting from {start_port}")
def _wait_for_vm_ready(self, timeout: int = 300):
"""Wait for VM to be ready by checking screenshot endpoint."""
start_time = time.time()
def check_screenshot():
try:
response = requests.get(
f"http://localhost:{self.server_port}/screenshot",
timeout=(10, 10)
)
return response.status_code == 200
except Exception:
return False
while time.time() - start_time < timeout:
if check_screenshot():
return True
logger.info("Checking if virtual machine is ready...")
time.sleep(RETRY_INTERVAL)
raise TimeoutError("VM failed to become ready within timeout period")
def start_emulator(self, path_to_vm: str, headless: bool, os_type: str): def start_emulator(self, path_to_vm: str, headless: bool, os_type: str):
self.vnc_port = self._get_available_port(8006) # Use a single lock for all port allocation and container startup
self.server_port = self._get_available_port(5000) lock = FileLock(str(self.lock_file), timeout=LOCK_TIMEOUT)
# self.remote_debugging_port = self._get_available_port(1337)
self.chromium_port = self._get_available_port(9222) try:
logger.info(f"Occupying ports: {self.vnc_port}, {self.server_port}, {self.chromium_port}") with lock:
self.container = self.client.containers.run("happysixd/osworld-docker", environment=self.environment, # Allocate all required ports
cap_add=["NET_ADMIN"], devices=["/dev/kvm"], volumes={ self.vnc_port = self._get_available_port(8006)
os.path.abspath(path_to_vm): {"bind": "/System.qcow2", "mode": "ro"}}, self.server_port = self._get_available_port(5000)
ports={8006: self.vnc_port, 5000: self.server_port, self.chromium_port = self._get_available_port(9222)
9222: self.chromium_port}, detach=True)
def download_screenshot(ip, port): # Start container while still holding the lock
url = f"http://{ip}:{port}/screenshot" self.container = self.client.containers.run(
try: "happysixd/osworld-docker",
# max retry times 1, max timeout 1 response = requests.get(url, timeout=(10, 10)) if response.status_code == 200:
response = requests.get(url, timeout=(10, 10)) cap_add=["NET_ADMIN"],
if response.status_code == 200: devices=["/dev/kvm"],
return True volumes={
except Exception as e: os.path.abspath(path_to_vm): {
time.sleep(RETRY_INTERVAL) "bind": "/System.qcow2",
return False "mode": "ro"
}
},
ports={
8006: self.vnc_port,
5000: self.server_port,
9222: self.chromium_port
},
detach=True
)
# Try downloading the screenshot until successful logger.info(f"Started container with ports - VNC: {self.vnc_port}, "
while not download_screenshot("localhost", self.server_port): f"Server: {self.server_port}, Chrome: {self.chromium_port}")
logger.info("Check whether the virtual machine is ready...")
# Wait for VM to be ready
self._wait_for_vm_ready()
except Exception as e:
# Clean up if anything goes wrong
if self.container:
try:
self.container.stop()
self.container.remove()
except:
pass
raise e
def get_ip_address(self, path_to_vm: str) -> str: def get_ip_address(self, path_to_vm: str) -> str:
if not all([self.server_port, self.chromium_port, self.vnc_port]):
raise RuntimeError("VM not started - ports not allocated")
return f"localhost:{self.server_port}:{self.chromium_port}:{self.vnc_port}" return f"localhost:{self.server_port}:{self.chromium_port}:{self.vnc_port}"
def save_state(self, path_to_vm: str, snapshot_name: str): def save_state(self, path_to_vm: str, snapshot_name: str):
raise NotImplementedError("Not available for Docker.") raise NotImplementedError("Snapshots not available for Docker provider")
def revert_to_snapshot(self, path_to_vm: str, snapshot_name: str): def revert_to_snapshot(self, path_to_vm: str, snapshot_name: str):
pass pass
def stop_emulator(self, path_to_vm: str): def stop_emulator(self, path_to_vm: str):
logger.info("Stopping VM...") if self.container:
self.container.stop() logger.info("Stopping VM...")
self.container.remove() try:
time.sleep(WAIT_TIME) self.container.stop()
self.container.remove()
time.sleep(WAIT_TIME)
except Exception as e:
logger.error(f"Error stopping container: {e}")
finally:
self.container = None
self.server_port = None
self.vnc_port = None
self.chromium_port = None

3
run.py
View File

@@ -91,7 +91,7 @@ def config() -> argparse.Namespace:
) )
# lm config # lm config
parser.add_argument("--model", type=str, default="gpt-4-0125-preview") parser.add_argument("--model", type=str, default="gpt-4o")
parser.add_argument("--temperature", type=float, default=1.0) parser.add_argument("--temperature", type=float, default=1.0)
parser.add_argument("--top_p", type=float, default=0.9) parser.add_argument("--top_p", type=float, default=0.9)
parser.add_argument("--max_tokens", type=int, default=1500) parser.add_argument("--max_tokens", type=int, default=1500)
@@ -150,6 +150,7 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
action_space=agent.action_space, action_space=agent.action_space,
screen_size=(args.screen_width, args.screen_height), screen_size=(args.screen_width, args.screen_height),
headless=args.headless, headless=args.headless,
os_type = "Ubuntu",
require_a11y_tree=args.observation_type require_a11y_tree=args.observation_type
in ["a11y_tree", "screenshot_a11y_tree", "som"], in ["a11y_tree", "screenshot_a11y_tree", "som"],
) )

360
run_multienv.py Normal file
View File

@@ -0,0 +1,360 @@
"""Script to run end-to-end evaluation on the benchmark.
Utils and basic architecture credit to https://github.com/web-arena-x/webarena/blob/main/run.py.
"""

import argparse
import datetime
import json
import logging
import os
import sys
from typing import List, Dict
import math
from tqdm import tqdm
from multiprocessing import Process, Manager

import lib_run_single
from desktop_env.desktop_env import DesktopEnv
from mm_agents.agent import PromptAgent

# import wandb

# Logger Configs {{{ #
# Four sinks: INFO+ to a "normal" log file, DEBUG+ to a "debug" file,
# INFO+ to stdout, and DEBUG+ to an "sdebug" file.  The stdout and sdebug
# handlers additionally filter to the "desktopenv" logger namespace only.
# NOTE(review): logging.FileHandler does not create directories — this
# assumes a ./logs directory already exists; confirm it is created elsewhere.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

# One timestamp shared by all log files of this run.
datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")

file_handler = logging.FileHandler(
    os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8"
)
debug_handler = logging.FileHandler(
    os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8"
)
stdout_handler = logging.StreamHandler(sys.stdout)
sdebug_handler = logging.FileHandler(
    os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8"
)

file_handler.setLevel(logging.INFO)
debug_handler.setLevel(logging.DEBUG)
stdout_handler.setLevel(logging.INFO)
sdebug_handler.setLevel(logging.DEBUG)

# ANSI-colored format shared by every handler; includes the process name so
# interleaved multi-environment output stays attributable.
formatter = logging.Formatter(
    fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
)
file_handler.setFormatter(formatter)
debug_handler.setFormatter(formatter)
stdout_handler.setFormatter(formatter)
sdebug_handler.setFormatter(formatter)

# Restrict console and sdebug output to the "desktopenv.*" loggers.
stdout_handler.addFilter(logging.Filter("desktopenv"))
sdebug_handler.addFilter(logging.Filter("desktopenv"))

logger.addHandler(file_handler)
logger.addHandler(debug_handler)
logger.addHandler(stdout_handler)
logger.addHandler(sdebug_handler)
# }}} Logger Configs #

# Module-level logger used by the functions below.
logger = logging.getLogger("desktopenv.experiment")
def config() -> argparse.Namespace:
    """Parse and return the command-line arguments for a benchmark run."""
    p = argparse.ArgumentParser(
        description="Run end-to-end evaluation on the benchmark"
    )

    # -- environment ---------------------------------------------------
    p.add_argument("--path_to_vm", type=str, default=None)
    p.add_argument(
        "--headless", action="store_true", help="Run in headless machine"
    )
    p.add_argument(
        "--action_space", type=str, default="pyautogui", help="Action type"
    )
    p.add_argument(
        "--observation_type",
        choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"],
        default="a11y_tree",
        help="Observation type",
    )
    p.add_argument("--screen_width", type=int, default=1920)
    p.add_argument("--screen_height", type=int, default=1080)
    p.add_argument("--sleep_after_execution", type=float, default=0.0)
    p.add_argument("--max_steps", type=int, default=15)

    # -- agent ---------------------------------------------------------
    p.add_argument("--max_trajectory_length", type=int, default=3)
    p.add_argument(
        "--test_config_base_dir", type=str, default="evaluation_examples"
    )

    # -- language model ------------------------------------------------
    p.add_argument("--model", type=str, default="gpt-4o")
    p.add_argument("--temperature", type=float, default=1.0)
    p.add_argument("--top_p", type=float, default=0.9)
    p.add_argument("--max_tokens", type=int, default=1500)
    p.add_argument("--stop_token", type=str, default=None)

    # -- example selection ---------------------------------------------
    p.add_argument("--domain", type=str, default="all")
    p.add_argument(
        "--test_all_meta_path", type=str, default="evaluation_examples/test_all.json"
    )

    # -- logging / parallelism -----------------------------------------
    p.add_argument("--result_dir", type=str, default="./results")
    p.add_argument("--num_envs", type=int, default=1, help="Number of environments to run in parallel")

    return p.parse_args()
def distribute_tasks(test_all_meta: dict, num_envs: int) -> List[Dict]:
    """Split the benchmark into ``num_envs`` contiguous chunks of tasks.

    The {domain: [example_id, ...]} mapping is flattened to (domain, id)
    pairs, cut into slices of ceil(total / num_envs) pairs, and each slice
    is regrouped back into a per-environment {domain: [ids]} dict.
    """
    flat = [
        (domain, example_id)
        for domain, examples in test_all_meta.items()
        for example_id in examples
    ]
    chunk_size = math.ceil(len(flat) / num_envs)

    buckets: List[Dict] = []
    for idx in range(num_envs):
        bucket: Dict[str, list] = {}
        for domain, example_id in flat[idx * chunk_size:(idx + 1) * chunk_size]:
            bucket.setdefault(domain, []).append(example_id)
        buckets.append(bucket)
    return buckets
def run_env_tasks(env_idx: int, env: DesktopEnv, agent: PromptAgent, env_tasks: dict, args: argparse.Namespace, shared_scores: list):
    """Run every assigned (domain, example) task inside one environment.

    Executed in a worker process.  Scores are appended to the
    ``shared_scores`` Manager list by ``lib_run_single.run_single_example``;
    the environment is closed when all tasks finish.

    Args:
        env_idx: zero-based index of this worker (for log prefixes only).
        env: the DesktopEnv this worker owns.
        agent: the PromptAgent driving the environment.
        env_tasks: {domain: [example_id, ...]} assigned to this worker.
        args: the parsed CLI namespace.
        shared_scores: cross-process list collecting per-example scores.
    """
    logger.info(f"Executing tasks in environment {env_idx + 1}/{args.num_envs}")
    for domain in tqdm(env_tasks, desc=f"Env{env_idx+1}-Domain"):
        for example_id in tqdm(env_tasks[domain], desc="Example", leave=False):
            config_file = os.path.join(
                args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
            )
            with open(config_file, "r", encoding="utf-8") as f:
                example = json.load(f)
            logger.info(f"[Env {env_idx+1}][Domain]: {domain}")
            logger.info(f"[Env {env_idx+1}][Example ID]: {example_id}")
            logger.info(f"[Env {env_idx+1}][Instruction]: {example['instruction']}")
            example_result_dir = os.path.join(
                args.result_dir,
                args.action_space,
                args.observation_type,
                args.model,
                domain,
                example_id,
            )
            os.makedirs(example_result_dir, exist_ok=True)
            try:
                lib_run_single.run_single_example(
                    agent,
                    env,
                    example,
                    args.max_steps,
                    example["instruction"],
                    args,
                    example_result_dir,
                    shared_scores,
                )
            except Exception as e:
                logger.error(f"Exception in Env{env_idx+1} {domain}/{example_id}: {e}")
                # Best-effort: the recording may be unavailable if the env died.
                try:
                    env.controller.end_recording(
                        os.path.join(example_result_dir, "recording.mp4")
                    )
                except Exception as rec_err:
                    logger.error(
                        f"Failed to end recording for {domain}/{example_id}: {rec_err}"
                    )
                with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
                    # Record the actual failure; the previous version always
                    # wrote "Time limit exceeded" regardless of the exception.
                    f.write(
                        json.dumps(
                            {"Error": f"{type(e).__name__}: {e} in {domain}/{example_id}"}
                        )
                    )
                    f.write("\n")
    env.close()
def test(args: argparse.Namespace, test_all_meta: dict) -> None:
    """Build one agent + DesktopEnv pair per worker, then run the
    distributed task lists in parallel processes and log the mean score."""
    logger.info("Args: %s", args)

    # One contiguous chunk of tasks per environment.
    distributed_tasks = distribute_tasks(test_all_meta, args.num_envs)

    # First, set up all environments
    logger.info("Setting up all environments...")
    envs = []
    agents = []
    for env_idx in range(args.num_envs):
        logger.info(f"Setting up environment {env_idx + 1}/{args.num_envs}")
        agent = PromptAgent(
            model=args.model,
            max_tokens=args.max_tokens,
            top_p=args.top_p,
            temperature=args.temperature,
            action_space=args.action_space,
            observation_type=args.observation_type,
            max_trajectory_length=args.max_trajectory_length,
        )
        agents.append(agent)
        env = DesktopEnv(
            path_to_vm=args.path_to_vm,
            action_space=agent.action_space,
            screen_size=(args.screen_width, args.screen_height),
            headless=args.headless,
            os_type="Ubuntu",
            require_a11y_tree=args.observation_type
            in ["a11y_tree", "screenshot_a11y_tree", "som"],
        )
        envs.append(env)
    logger.info("All environments are ready. Starting parallel task execution...")

    # Create a shared list for scores across processes
    with Manager() as manager:
        shared_scores = manager.list()
        # Create and start processes for each environment
        # NOTE(review): env/agent objects are handed to child processes; on
        # spawn-based platforms they must be picklable — confirm.
        processes = []
        for env_idx, (env, agent, env_tasks) in enumerate(zip(envs, agents, distributed_tasks)):
            p = Process(
                target=run_env_tasks,
                args=(env_idx, env, agent, env_tasks, args, shared_scores)
            )
            processes.append(p)
            p.start()

        # Wait for all processes to complete
        for p in processes:
            p.join()

        # Convert shared list to regular list (must happen before the
        # Manager shuts down)
        scores = list(shared_scores)

        logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")
def get_unfinished(
    action_space, use_model, observation_type, result_dir, total_file_json
):
    """Filter ``total_file_json`` down to examples without a result yet.

    Scans <result_dir>/<action_space>/<observation_type>/<use_model> for
    per-example directories containing ``result.txt``.  Finished examples
    are removed from the meta dict (mutated in place and returned); partial
    output files of unfinished examples are deleted so they re-run cleanly.

    Args:
        action_space / use_model / observation_type / result_dir: path parts
            of the results tree.
        total_file_json: {domain: [example_id, ...]} of all candidate tasks.

    Returns:
        The (mutated) ``total_file_json`` with finished example ids removed.
    """
    target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
    if not os.path.exists(target_dir):
        return total_file_json

    finished = {}
    for domain in os.listdir(target_dir):
        domain_path = os.path.join(target_dir, domain)
        # Skip stray files: previously `finished[domain] = []` was created
        # before the isdir check, adding spurious entries for non-directories.
        if not os.path.isdir(domain_path):
            continue
        finished[domain] = []
        for example_id in os.listdir(domain_path):
            if example_id == "onboard":
                continue
            example_path = os.path.join(domain_path, example_id)
            if not os.path.isdir(example_path):
                continue
            if "result.txt" not in os.listdir(example_path):
                # Unfinished: wipe partial artifacts so the example restarts clean.
                for file in os.listdir(example_path):
                    os.remove(os.path.join(example_path, file))
            else:
                finished[domain].append(example_id)

    if not finished:
        return total_file_json

    for domain, examples in finished.items():
        if domain in total_file_json:
            total_file_json[domain] = [
                x for x in total_file_json[domain] if x not in examples
            ]
    return total_file_json
def get_result(action_space, use_model, observation_type, result_dir, total_file_json):
    """Collect per-example scores and print the current success rate.

    Walks <result_dir>/<action_space>/<observation_type>/<use_model> and
    reads each example's ``result.txt`` as a float.  Unreadable or malformed
    result files count as 0.0.

    Returns:
        The list of float scores, or None when no results exist yet.
    """
    target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
    if not os.path.exists(target_dir):
        print("New experiment, no result yet.")
        return None

    all_result = []
    for domain in os.listdir(target_dir):
        domain_path = os.path.join(target_dir, domain)
        if not os.path.isdir(domain_path):
            continue
        for example_id in os.listdir(domain_path):
            example_path = os.path.join(domain_path, example_id)
            if not os.path.isdir(example_path):
                continue
            if "result.txt" not in os.listdir(example_path):
                continue
            try:
                # `with` closes the handle; the previous bare open() leaked it.
                with open(os.path.join(example_path, "result.txt"), "r") as f:
                    all_result.append(float(f.read()))
            except (OSError, ValueError):
                # Narrowed from a bare `except:` that also swallowed
                # KeyboardInterrupt; malformed results count as failure.
                all_result.append(0.0)

    if not all_result:
        print("New experiment, no result yet.")
        return None
    else:
        print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%")
        return all_result
if __name__ == "__main__":
    ####### The complete version of the list of examples #######
    # Avoid tokenizer fork warnings/deadlocks in the worker processes.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    args = config()

    with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
        test_all_meta = json.load(f)

    # Optionally restrict the run to a single domain.
    if args.domain != "all":
        test_all_meta = {args.domain: test_all_meta[args.domain]}

    # Drop examples that already have a result.txt; partial runs are wiped
    # inside get_unfinished so they restart cleanly.
    test_file_list = get_unfinished(
        args.action_space,
        args.model,
        args.observation_type,
        args.result_dir,
        test_all_meta,
    )
    left_info = ""
    for domain in test_file_list:
        left_info += f"{domain}: {len(test_file_list[domain])}\n"
    logger.info(f"Left tasks:\n{left_info}")

    # Print the success rate accumulated so far before starting new work.
    get_result(
        args.action_space,
        args.model,
        args.observation_type,
        args.result_dir,
        test_all_meta,
    )
    test(args, test_file_list)

View File

@@ -68,4 +68,4 @@ def get_result(action_space, use_model, observation_type, result_dir):
if __name__ == '__main__': if __name__ == '__main__':
get_result("pyautogui", "gpt-4-vision-preview", "screenshot", "./results") get_result("pyautogui", "gpt-4o", "a11y_tree", "./results")