From d0b37f0831e8a55b3718d490069c06735955f9fa Mon Sep 17 00:00:00 2001 From: FredWuCZ Date: Sat, 28 Sep 2024 12:49:29 +0800 Subject: [PATCH] Update --- desktop_env/controllers/setup.py | 8 +++-- desktop_env/desktop_env.py | 9 ++--- desktop_env/evaluators/getters/chrome.py | 42 ++++++++++++++--------- desktop_env/evaluators/getters/general.py | 2 +- desktop_env/providers/docker/manager.py | 34 +++++++++++++----- desktop_env/providers/docker/provider.py | 7 ++-- main.py | 2 +- requirements.txt | 4 --- 8 files changed, 67 insertions(+), 41 deletions(-) diff --git a/desktop_env/controllers/setup.py b/desktop_env/controllers/setup.py index dfac4b3..f03e9cd 100644 --- a/desktop_env/controllers/setup.py +++ b/desktop_env/controllers/setup.py @@ -28,8 +28,10 @@ FILE_PATH = os.path.dirname(os.path.abspath(__file__)) class SetupController: - def __init__(self, vm_ip: str, server_port: int, cache_dir: str): + def __init__(self, vm_ip: str, server_port: int, chromium_port: int, cache_dir: str): self.vm_ip: str = vm_ip + self.server_port: int = server_port + self.chromium_port: int = chromium_port self.http_server: str = f"http://{vm_ip}:{server_port}" self.http_server_setup_root: str = f"http://{vm_ip}:{server_port}/setup" self.cache_dir: str = cache_dir @@ -348,7 +350,7 @@ class SetupController: # Chrome setup def _chrome_open_tabs_setup(self, urls_to_open: List[str]): host = self.vm_ip - port = 9222 # fixme: this port is hard-coded, need to be changed from config file + port = self.chromium_port # fixme: this port is hard-coded, need to be changed from config file remote_debugging_url = f"http://{host}:{port}" logger.info("Connect to Chrome @: %s", remote_debugging_url) @@ -399,7 +401,7 @@ class SetupController: time.sleep(5) # Wait for Chrome to finish launching host = self.vm_ip - port = 9222 # fixme: this port is hard-coded, need to be changed from config file + port = self.server_port # fixme: this port is hard-coded, need to be changed from config file remote_debugging_url = f"http://{host}:{port}" with sync_playwright() as p: diff --git a/desktop_env/desktop_env.py b/desktop_env/desktop_env.py index a03ebb2..9b24d05 100644 --- a/desktop_env/desktop_env.py +++ b/desktop_env/desktop_env.py @@ -78,8 +78,9 @@ class DesktopEnv(gym.Env): self.require_terminal = require_terminal # Initialize emulator and controller - logger.info("Initializing...") - self._start_emulator() + if provider_name != "docker": # Check if this is applicable to other VM providers + logger.info("Initializing...") + self._start_emulator() # mode: human or machine self.instruction = None @@ -103,8 +104,8 @@ class DesktopEnv(gym.Env): self.server_port = int(vm_ip_ports[1]) self.chromium_port = int(vm_ip_ports[2]) self.vnc_port = int(vm_ip_ports[3]) - self.controller = PythonController(vm_ip=self.vm_ip) - self.setup_controller = SetupController(vm_ip=self.vm_ip, server_port=self.server_port, cache_dir=self.cache_dir_base) + self.controller = PythonController(vm_ip=self.vm_ip, server_port=self.server_port) + self.setup_controller = SetupController(vm_ip=self.vm_ip, server_port=self.server_port, chromium_port=self.chromium_port, cache_dir=self.cache_dir_base) def _revert_to_snapshot(self): # Revert to certain snapshot of the virtual machine, and refresh the path to vm and ip of vm diff --git a/desktop_env/evaluators/getters/chrome.py b/desktop_env/evaluators/getters/chrome.py index 94c0dda..ba384de 100644 --- a/desktop_env/evaluators/getters/chrome.py +++ b/desktop_env/evaluators/getters/chrome.py @@ -54,7 +54,8 @@ def get_info_from_website(env, config: Dict[Any, Any]) -> Any: """ try: host = env.vm_ip - port = 9222 # fixme: this port is hard-coded, need to be changed from config file + port = env.chromium_port # fixme: this port is hard-coded, need to be changed from config file + server_port = env.server_port remote_debugging_url = f"http://{host}:{port}" with sync_playwright() as p: # connect to remote Chrome instance @@ -68,7 +69,7 @@ def get_info_from_website(env, config: Dict[Any, Any]) -> Any: "--remote-debugging-port=1337" ], "shell": False}) headers = {"Content-Type": "application/json"} - requests.post("http://" + host + ":5000/setup" + "/launch", headers=headers, data=payload) + requests.post("http://" + host + ":" + server_port + "/setup" + "/launch", headers=headers, data=payload) time.sleep(5) browser = p.chromium.connect_over_cdp(remote_debugging_url) @@ -454,7 +455,8 @@ def get_extensions_installed_from_shop(env, config: Dict[str, str]): def get_page_info(env, config: Dict[str, str]): host = env.vm_ip - port = 9222 # fixme: this port is hard-coded, need to be changed from config file + port = env.chromium_port # fixme: this port is hard-coded, need to be changed from config file + server_port = env.server_port url = config["url"] remote_debugging_url = f"http://{host}:{port}" @@ -478,7 +480,7 @@ def get_page_info(env, config: Dict[str, str]): ], "shell": False}) headers = {"Content-Type": "application/json"} - requests.post("http://" + host + ":5000/setup" + "/launch", headers=headers, data=payload) + requests.post("http://" + host + ":" + server_port + "/setup" + "/launch", headers=headers, data=payload) time.sleep(5) browser = p.chromium.connect_over_cdp(remote_debugging_url) @@ -505,7 +507,8 @@ def get_page_info(env, config: Dict[str, str]): def get_open_tabs_info(env, config: Dict[str, str]): host = env.vm_ip - port = 9222 # fixme: this port is hard-coded, need to be changed from config file + port = env.chromium_port # fixme: this port is hard-coded, need to be changed from config file + server_port = env.server_port remote_debugging_url = f"http://{host}:{port}" with sync_playwright() as p: @@ -528,7 +531,7 @@ def get_open_tabs_info(env, config: Dict[str, str]): ], "shell": False}) headers = {"Content-Type": "application/json"} - requests.post("http://" + host + ":5000/setup" + "/launch", headers=headers, data=payload) + requests.post("http://" + host + ":" + server_port + "/setup" + "/launch", headers=headers, data=payload) time.sleep(5) try: browser = p.chromium.connect_over_cdp(remote_debugging_url) @@ -643,7 +646,7 @@ def get_active_tab_info(env, config: Dict[str, str]): logger.error("Failed to get the url of active tab") return None host = env.vm_ip - port = 9222 # fixme: this port is hard-coded, need to be changed from config file + port = env.chromium_port # fixme: this port is hard-coded, need to be changed from config file remote_debugging_url = f"http://{host}:{port}" with sync_playwright() as p: @@ -683,7 +686,8 @@ def get_pdf_from_url(env, config: Dict[str, str]) -> str: _path = os.path.join(env.cache_dir, config["dest"]) host = env.vm_ip - port = 9222 # fixme: this port is hard-coded, need to be changed from config file + port = env.chromium_port # fixme: this port is hard-coded, need to be changed from config file + server_port = env.server_port remote_debugging_url = f"http://{host}:{port}" @@ -706,7 +710,7 @@ def get_pdf_from_url(env, config: Dict[str, str]) -> str: ], "shell": False}) headers = {"Content-Type": "application/json"} - requests.post("http://" + host + ":5000/setup" + "/launch", headers=headers, data=payload) + requests.post("http://" + host + ":" + server_port + "/setup" + "/launch", headers=headers, data=payload) time.sleep(5) browser = p.chromium.connect_over_cdp(remote_debugging_url) @@ -721,7 +725,8 @@ def get_pdf_from_url(env, config: Dict[str, str]) -> str: # fixme: needs to be changed (maybe through post-processing) since it's not working def get_chrome_saved_address(env, config: Dict[str, str]): host = env.vm_ip - port = 9222 # fixme: this port is hard-coded, need to be changed from config file + port = env.chromium_port # fixme: this port is hard-coded, need to be changed from config file + server_port = env.server_port remote_debugging_url = f"http://{host}:{port}" with sync_playwright() as p: @@ -744,7 +749,7 @@ def get_chrome_saved_address(env, config: Dict[str, str]): ], "shell": False}) headers = {"Content-Type": "application/json"} - requests.post("http://" + host + ":5000/setup" + "/launch", headers=headers, data=payload) + requests.post("http://" + host + ":" + server_port + "/setup" + "/launch", headers=headers, data=payload) time.sleep(5) browser = p.chromium.connect_over_cdp(remote_debugging_url) @@ -800,7 +805,8 @@ def get_number_of_search_results(env, config: Dict[str, str]): # todo: move into the config file url, result_selector = "https://google.com/search?q=query", '.search-result' host = env.vm_ip - port = 9222 # fixme: this port is hard-coded, need to be changed from config file + port = env.chromium_port # fixme: this port is hard-coded, need to be changed from config file + server_port = env.server_port remote_debugging_url = f"http://{host}:{port}" with sync_playwright() as p: @@ -822,7 +828,7 @@ def get_number_of_search_results(env, config: Dict[str, str]): ], "shell": False}) headers = {"Content-Type": "application/json"} - requests.post("http://" + host + ":5000/setup" + "/launch", headers=headers, data=payload) + requests.post("http://" + host + ":" + server_port + "/setup" + "/launch", headers=headers, data=payload) time.sleep(5) browser = p.chromium.connect_over_cdp(remote_debugging_url) page = browser.new_page() @@ -1145,7 +1151,8 @@ def get_active_tab_html_parse(env, config: Dict[str, Any]): logger.error("active_tab_url is not a string") return None host = env.vm_ip - port = 9222 # fixme: this port is hard-coded, need to be changed from config file + port = env.chromium_port # fixme: this port is hard-coded, need to be changed from config file + server_port = env.server_port remote_debugging_url = f"http://{host}:{port}" with sync_playwright() as p: @@ -1168,7 +1175,7 @@ def get_active_tab_html_parse(env, config: Dict[str, Any]): ], "shell": False}) headers = {"Content-Type": "application/json"} - requests.post("http://" + host + ":5000/setup" + "/launch", headers=headers, data=payload) + requests.post("http://" + host + ":" + server_port + "/setup" + "/launch", headers=headers, data=payload) time.sleep(5) browser = p.chromium.connect_over_cdp(remote_debugging_url) target_page = None @@ -1237,7 +1244,8 @@ def get_gotoRecreationPage_and_get_html_content(env, config: Dict[str, Any]): especially used for www.recreation.gov examples """ host = env.vm_ip - port = 9222 # fixme: this port is hard-coded, need to be changed from config file + port = env.chromium_port # fixme: this port is hard-coded, need to be changed from config file + server_port = env.server_port remote_debugging_url = f"http://{host}:{port}" with sync_playwright() as p: @@ -1259,7 +1267,7 @@ def get_gotoRecreationPage_and_get_html_content(env, config: Dict[str, Any]): ], "shell": False}) headers = {"Content-Type": "application/json"} - requests.post("http://" + host + ":5000/setup" + "/launch", headers=headers, data=payload) + requests.post("http://" + host + ":" + server_port + "/setup" + "/launch", headers=headers, data=payload) time.sleep(5) browser = p.chromium.connect_over_cdp(remote_debugging_url) page = browser.new_page() diff --git a/desktop_env/evaluators/getters/general.py b/desktop_env/evaluators/getters/general.py index d5d965c..2f5ed32 100644 --- a/desktop_env/evaluators/getters/general.py +++ b/desktop_env/evaluators/getters/general.py @@ -7,7 +7,7 @@ logger = logging.getLogger("desktopenv.getters.general") def get_vm_command_line(env, config: Dict[str, str]): vm_ip = env.vm_ip - port = 5000 + port = env.server_port command = config["command"] shell = config.get("shell", False) diff --git a/desktop_env/providers/docker/manager.py b/desktop_env/providers/docker/manager.py index 36dfbf1..e4b09a1 100644 --- a/desktop_env/providers/docker/manager.py +++ b/desktop_env/providers/docker/manager.py @@ -25,7 +25,9 @@ logger.setLevel(logging.INFO) MAX_RETRY_TIMES = 10 RETRY_INTERVAL = 5 -UBUNTU_X86_URL = "docker-osworld-x86" + +UBUNTU_X86_URL = "https://huggingface.co/datasets/xlangai/ubuntu_osworld/resolve/main/Ubuntu.qcow2" +VMS_DIR = "./docker_vm_data" # Determine the platform and CPU architecture to decide the correct VM image to download # if platform.system() == 'Darwin': # macOS @@ -37,7 +39,6 @@ UBUNTU_X86_URL = "docker-osworld-x86" # URL = UBUNTU_X86_URL # else: # raise Exception("Unsupported platform or architecture") - URL = UBUNTU_X86_URL DOWNLOADED_FILE_NAME = URL.split('/')[-1] @@ -46,10 +47,7 @@ if platform.system() == 'Windows': docker_path = r"C:\Program Files\Docker\Docker" os.environ["PATH"] += os.pathsep + docker_path -UBUNTU_X86_URL = "https://huggingface.co/datasets/xlangai/ubuntu_osworld/resolve/main/Ubuntu.qcow2" -VMS_DIR = "./vmware_vm_data" - -def __download_vm(vms_dir: str): +def _download_vm(vms_dir: str): # Download the virtual machine image logger.info("Downloading the virtual machine image...") downloaded_size = 0 @@ -58,6 +56,8 @@ def __download_vm(vms_dir: str): DOWNLOADED_FILE_NAME = URL.split('/')[-1] downloaded_file_name = DOWNLOADED_FILE_NAME + os.makedirs(vms_dir, exist_ok=True) + while True: downloaded_file_path = os.path.join(vms_dir, downloaded_file_name) headers = {} @@ -99,7 +99,25 @@ class DockerVMManager(VMManager): def __init__(self, registry_path=""): pass - def get_vm_path(self, region): + def add_vm(self, vm_path): + pass + + def check_and_clean(self): + pass + + def delete_vm(self, vm_path): + pass + + def initialize_registry(self): + pass + + def list_free_vms(self): + return os.path.join(VMS_DIR, DOWNLOADED_FILE_NAME) + + def occupy_vm(self, vm_path): + pass + + def get_vm_path(self, os_type, region): if not os.path.exists(os.path.join(VMS_DIR, DOWNLOADED_FILE_NAME)): - __download_vm(VMS_DIR) + _download_vm(VMS_DIR) return os.path.join(VMS_DIR, DOWNLOADED_FILE_NAME) \ No newline at end of file diff --git a/desktop_env/providers/docker/provider.py b/desktop_env/providers/docker/provider.py index 6920a9e..9d09d69 100644 --- a/desktop_env/providers/docker/provider.py +++ b/desktop_env/providers/docker/provider.py @@ -20,7 +20,7 @@ class DockerProvider(Provider): self.server_port = self._get_available_port(5000) # self.remote_debugging_port = self._get_available_port(1337) self.chromium_port = self._get_available_port(9222) - self.environment = {"DISK_SIZE": "64G", "RAM_SIZE": "4G", "CPU_CORES": "2"} # Modify if needed + self.environment = {"DISK_SIZE": "64G", "RAM_SIZE": "8G", "CPU_CORES": "8"} # Modify if needed @staticmethod def _get_available_port(port: int): @@ -30,8 +30,9 @@ class DockerProvider(Provider): port += 1 def start_emulator(self, path_to_vm: str, headless: bool, os_type: str): + logger.info(f"Occupying ports: {self.vnc_port}, {self.server_port}, {self.chromium_port}") # self.container = self.client.containers.run('qemux/qemu-docker', environment={"DISK_SIZE": "64G", "RAM_SIZE": "6G", "CPU_CORES": "8"}, volumes={"/Users/happysix/Programs/HKUNLP/Qemu/Ubuntu.qcow2": {"bind": "/Ubuntu.qcow2", "mode": "ro"}, "/Users/happysix/Programs/HKUNLP/Qemu/snapshot.qcow2": {"bind": "/boot.qcow2", "mode": "rw"}}, cap_add=["NET_ADMIN"], ports={8006: self.vnc_port, 5000: self.server_port}, detach=True) - self.container = self.client.containers.run(path_to_vm, environment=self.environment, cap_add=["NET_ADMIN"], volumes={"/Users/happysix/Programs/HKUNLP/Qemu/Ubuntu.qcow2": {"bind": "/Ubuntu.qcow2", "mode": "ro"}}, ports={8006: self.vnc_port, 5000: self.server_port}, detach=True) + self.container = self.client.containers.run("osworld-docker", environment=self.environment, cap_add=["NET_ADMIN"], devices=["/dev/kvm"], volumes={os.path.abspath(path_to_vm): {"bind": "/Ubuntu.qcow2", "mode": "ro"}}, ports={8006: self.vnc_port, 5000: self.server_port}, detach=True) def get_ip_address(self, path_to_vm: str) -> str: return f"localhost:{self.server_port}:{self.chromium_port}:{self.vnc_port}" @@ -47,4 +48,4 @@ class DockerProvider(Provider): self.container.stop(WAIT_TIME) self.container.remove() -# docker run -it --rm -e "DISK_SIZE=64G" -e "RAM_SIZE=8G" -e "CPU_CORES=8" --volume C:\Users\admin\Documents\Ubuntu.qcow2:/boot.qcow2 --cap-add NET_ADMIN --device /dev/kvm -p 8006:8006 -p 5000:5000 qemux/qemu-docker \ No newline at end of file +# docker run -it --rm -e "DISK_SIZE=64G" -e "RAM_SIZE=8G" -e "CPU_CORES=8" --volume C:\Users\admin\Documents\OSWorld\docker_vm_data\Ubuntu.qcow2:/Ubuntu.qcow2:ro --cap-add NET_ADMIN --device /dev/kvm -p 8006:8006 -p 5000:5000 osworld-docker \ No newline at end of file diff --git a/main.py b/main.py index 4c6817f..6f03227 100644 --- a/main.py +++ b/main.py @@ -78,7 +78,7 @@ def human_agent(): result = env.evaluate() logger.info("Result: %.2f", result) - # env.close() + env.close() logger.info("Environment closed.") diff --git a/requirements.txt b/requirements.txt index 439f85b..2cc96ab 100644 --- a/requirements.txt +++ b/requirements.txt @@ -53,7 +53,3 @@ wrapt_timeout_decorator gdown tiktoken groq -boto3 -azure-identity -azure-mgmt-compute -azure-mgmt-network