Support Docker VM manager and provider (#75)

* Add docker provider framework

* Update VM download link

* Add stop container

* Update docker manager & provider

* Update

* Update

* Update provider
This commit is contained in:
HappySix
2024-09-28 21:10:40 +08:00
committed by GitHub
parent 3b94cb4b74
commit 6419d707bc
10 changed files with 256 additions and 44 deletions

View File

@@ -12,9 +12,10 @@ logger = logging.getLogger("desktopenv.pycontroller")
class PythonController:
def __init__(self, vm_ip: str,
server_port: int,
pkgs_prefix: str = "import pyautogui; import time; pyautogui.FAILSAFE = False; {command}"):
self.vm_ip = vm_ip
self.http_server = f"http://{vm_ip}:5000"
self.http_server = f"http://{vm_ip}:{server_port}"
self.pkgs_prefix = pkgs_prefix # fixme: this is a hacky way to execute python commands. fix it and combine it with installation of packages
self.retry_times = 3
self.retry_interval = 5

View File

@@ -28,10 +28,12 @@ FILE_PATH = os.path.dirname(os.path.abspath(__file__))
class SetupController:
def __init__(self, vm_ip: str, cache_dir: str):
def __init__(self, vm_ip: str, server_port: int, chromium_port: int, cache_dir: str):
self.vm_ip: str = vm_ip
self.http_server: str = f"http://{vm_ip}:5000"
self.http_server_setup_root: str = f"http://{vm_ip}:5000/setup"
self.server_port: int = server_port
self.chromium_port: int = chromium_port
self.http_server: str = f"http://{vm_ip}:{server_port}"
self.http_server_setup_root: str = f"http://{vm_ip}:{server_port}/setup"
self.cache_dir: str = cache_dir
def reset_cache_dir(self, cache_dir: str):
@@ -348,7 +350,7 @@ class SetupController:
# Chrome setup
def _chrome_open_tabs_setup(self, urls_to_open: List[str]):
host = self.vm_ip
port = 9222 # fixme: this port is hard-coded, need to be changed from config file
port = self.chromium_port # fixme: this port is hard-coded, need to be changed from config file
remote_debugging_url = f"http://{host}:{port}"
logger.info("Connect to Chrome @: %s", remote_debugging_url)
@@ -399,7 +401,7 @@ class SetupController:
time.sleep(5) # Wait for Chrome to finish launching
host = self.vm_ip
port = 9222 # fixme: this port is hard-coded, need to be changed from config file
port = self.server_port # fixme: this port is hard-coded, need to be changed from config file
remote_debugging_url = f"http://{host}:{port}"
with sync_playwright() as p:

View File

@@ -26,7 +26,7 @@ class DesktopEnv(gym.Env):
def __init__(
self,
provider_name: str = "vmware",
provider_name: str = "docker",
region: str = None,
path_to_vm: str = None,
snapshot_name: str = "init_state",
@@ -54,6 +54,11 @@ class DesktopEnv(gym.Env):
"""
# Initialize VM manager and virtualization provider
self.region = region
# Default
self.server_port = 5000
self.chromium_port = 9222
self.vnc_port = 8006
self.manager, self.provider = create_vm_manager_and_provider(provider_name, region)
self.os_type = os_type
@@ -73,8 +78,9 @@ class DesktopEnv(gym.Env):
self.require_terminal = require_terminal
# Initialize emulator and controller
logger.info("Initializing...")
self._start_emulator()
if provider_name != "docker": # Check if this is applicable to other VM providers
logger.info("Initializing...")
self._start_emulator()
# mode: human or machine
self.instruction = None
@@ -92,9 +98,14 @@ class DesktopEnv(gym.Env):
self.provider.start_emulator(self.path_to_vm, self.headless, self.os_type)
# Get the ip from the virtual machine, and setup the controller
self.vm_ip = self.provider.get_ip_address(self.path_to_vm)
self.controller = PythonController(vm_ip=self.vm_ip)
self.setup_controller = SetupController(vm_ip=self.vm_ip, cache_dir=self.cache_dir_base)
vm_ip_ports = self.provider.get_ip_address(self.path_to_vm).split(':')
self.vm_ip = vm_ip_ports[0]
if len(vm_ip_ports) > 1:
self.server_port = int(vm_ip_ports[1])
self.chromium_port = int(vm_ip_ports[2])
self.vnc_port = int(vm_ip_ports[3])
self.controller = PythonController(vm_ip=self.vm_ip, server_port=self.server_port)
self.setup_controller = SetupController(vm_ip=self.vm_ip, server_port=self.server_port, chromium_port=self.chromium_port, cache_dir=self.cache_dir_base)
def _revert_to_snapshot(self):
# Revert to certain snapshot of the virtual machine, and refresh the path to vm and ip of vm

View File

@@ -54,7 +54,8 @@ def get_info_from_website(env, config: Dict[Any, Any]) -> Any:
"""
try:
host = env.vm_ip
port = 9222 # fixme: this port is hard-coded, need to be changed from config file
port = env.chromium_port # fixme: this port is hard-coded, need to be changed from config file
server_port = env.server_port
remote_debugging_url = f"http://{host}:{port}"
with sync_playwright() as p:
# connect to remote Chrome instance
@@ -68,7 +69,7 @@ def get_info_from_website(env, config: Dict[Any, Any]) -> Any:
"--remote-debugging-port=1337"
], "shell": False})
headers = {"Content-Type": "application/json"}
requests.post("http://" + host + ":5000/setup" + "/launch", headers=headers, data=payload)
requests.post("http://" + host + ":" + server_port + "/setup" + "/launch", headers=headers, data=payload)
time.sleep(5)
browser = p.chromium.connect_over_cdp(remote_debugging_url)
@@ -454,7 +455,8 @@ def get_extensions_installed_from_shop(env, config: Dict[str, str]):
def get_page_info(env, config: Dict[str, str]):
host = env.vm_ip
port = 9222 # fixme: this port is hard-coded, need to be changed from config file
port = env.chromium_port # fixme: this port is hard-coded, need to be changed from config file
server_port = env.server_port
url = config["url"]
remote_debugging_url = f"http://{host}:{port}"
@@ -478,7 +480,7 @@ def get_page_info(env, config: Dict[str, str]):
], "shell": False})
headers = {"Content-Type": "application/json"}
requests.post("http://" + host + ":5000/setup" + "/launch", headers=headers, data=payload)
requests.post("http://" + host + ":" + server_port + "/setup" + "/launch", headers=headers, data=payload)
time.sleep(5)
browser = p.chromium.connect_over_cdp(remote_debugging_url)
@@ -505,7 +507,8 @@ def get_page_info(env, config: Dict[str, str]):
def get_open_tabs_info(env, config: Dict[str, str]):
host = env.vm_ip
port = 9222 # fixme: this port is hard-coded, need to be changed from config file
port = env.chromium_port # fixme: this port is hard-coded, need to be changed from config file
server_port = env.server_port
remote_debugging_url = f"http://{host}:{port}"
with sync_playwright() as p:
@@ -528,7 +531,7 @@ def get_open_tabs_info(env, config: Dict[str, str]):
], "shell": False})
headers = {"Content-Type": "application/json"}
requests.post("http://" + host + ":5000/setup" + "/launch", headers=headers, data=payload)
requests.post("http://" + host + ":" + server_port + "/setup" + "/launch", headers=headers, data=payload)
time.sleep(5)
try:
browser = p.chromium.connect_over_cdp(remote_debugging_url)
@@ -643,7 +646,7 @@ def get_active_tab_info(env, config: Dict[str, str]):
logger.error("Failed to get the url of active tab")
return None
host = env.vm_ip
port = 9222 # fixme: this port is hard-coded, need to be changed from config file
port = env.chromium_port # fixme: this port is hard-coded, need to be changed from config file
remote_debugging_url = f"http://{host}:{port}"
with sync_playwright() as p:
@@ -683,7 +686,8 @@ def get_pdf_from_url(env, config: Dict[str, str]) -> str:
_path = os.path.join(env.cache_dir, config["dest"])
host = env.vm_ip
port = 9222 # fixme: this port is hard-coded, need to be changed from config file
port = env.chromium_port # fixme: this port is hard-coded, need to be changed from config file
server_port = env.server_port
remote_debugging_url = f"http://{host}:{port}"
@@ -706,7 +710,7 @@ def get_pdf_from_url(env, config: Dict[str, str]) -> str:
], "shell": False})
headers = {"Content-Type": "application/json"}
requests.post("http://" + host + ":5000/setup" + "/launch", headers=headers, data=payload)
requests.post("http://" + host + ":" + server_port + "/setup" + "/launch", headers=headers, data=payload)
time.sleep(5)
browser = p.chromium.connect_over_cdp(remote_debugging_url)
@@ -721,7 +725,8 @@ def get_pdf_from_url(env, config: Dict[str, str]) -> str:
# fixme: needs to be changed (maybe through post-processing) since it's not working
def get_chrome_saved_address(env, config: Dict[str, str]):
host = env.vm_ip
port = 9222 # fixme: this port is hard-coded, need to be changed from config file
port = env.chromium_port # fixme: this port is hard-coded, need to be changed from config file
server_port = env.server_port
remote_debugging_url = f"http://{host}:{port}"
with sync_playwright() as p:
@@ -744,7 +749,7 @@ def get_chrome_saved_address(env, config: Dict[str, str]):
], "shell": False})
headers = {"Content-Type": "application/json"}
requests.post("http://" + host + ":5000/setup" + "/launch", headers=headers, data=payload)
requests.post("http://" + host + ":" + server_port + "/setup" + "/launch", headers=headers, data=payload)
time.sleep(5)
browser = p.chromium.connect_over_cdp(remote_debugging_url)
@@ -800,7 +805,8 @@ def get_number_of_search_results(env, config: Dict[str, str]):
# todo: move into the config file
url, result_selector = "https://google.com/search?q=query", '.search-result'
host = env.vm_ip
port = 9222 # fixme: this port is hard-coded, need to be changed from config file
port = env.chromium_port # fixme: this port is hard-coded, need to be changed from config file
server_port = env.server_port
remote_debugging_url = f"http://{host}:{port}"
with sync_playwright() as p:
@@ -822,7 +828,7 @@ def get_number_of_search_results(env, config: Dict[str, str]):
], "shell": False})
headers = {"Content-Type": "application/json"}
requests.post("http://" + host + ":5000/setup" + "/launch", headers=headers, data=payload)
requests.post("http://" + host + ":" + server_port + "/setup" + "/launch", headers=headers, data=payload)
time.sleep(5)
browser = p.chromium.connect_over_cdp(remote_debugging_url)
page = browser.new_page()
@@ -1145,7 +1151,8 @@ def get_active_tab_html_parse(env, config: Dict[str, Any]):
logger.error("active_tab_url is not a string")
return None
host = env.vm_ip
port = 9222 # fixme: this port is hard-coded, need to be changed from config file
port = env.chromium_port # fixme: this port is hard-coded, need to be changed from config file
server_port = env.server_port
remote_debugging_url = f"http://{host}:{port}"
with sync_playwright() as p:
@@ -1168,7 +1175,7 @@ def get_active_tab_html_parse(env, config: Dict[str, Any]):
], "shell": False})
headers = {"Content-Type": "application/json"}
requests.post("http://" + host + ":5000/setup" + "/launch", headers=headers, data=payload)
requests.post("http://" + host + ":" + server_port + "/setup" + "/launch", headers=headers, data=payload)
time.sleep(5)
browser = p.chromium.connect_over_cdp(remote_debugging_url)
target_page = None
@@ -1237,7 +1244,8 @@ def get_gotoRecreationPage_and_get_html_content(env, config: Dict[str, Any]):
especially used for www.recreation.gov examples
"""
host = env.vm_ip
port = 9222 # fixme: this port is hard-coded, need to be changed from config file
port = env.chromium_port # fixme: this port is hard-coded, need to be changed from config file
server_port = env.server_port
remote_debugging_url = f"http://{host}:{port}"
with sync_playwright() as p:
@@ -1259,7 +1267,7 @@ def get_gotoRecreationPage_and_get_html_content(env, config: Dict[str, Any]):
], "shell": False})
headers = {"Content-Type": "application/json"}
requests.post("http://" + host + ":5000/setup" + "/launch", headers=headers, data=payload)
requests.post("http://" + host + ":" + server_port + "/setup" + "/launch", headers=headers, data=payload)
time.sleep(5)
browser = p.chromium.connect_over_cdp(remote_debugging_url)
page = browser.new_page()

View File

@@ -7,7 +7,7 @@ logger = logging.getLogger("desktopenv.getters.general")
def get_vm_command_line(env, config: Dict[str, str]):
vm_ip = env.vm_ip
port = 5000
port = env.server_port
command = config["command"]
shell = config.get("shell", False)
@@ -23,7 +23,7 @@ def get_vm_command_line(env, config: Dict[str, str]):
def get_vm_command_error(env, config: Dict[str, str]):
vm_ip = env.vm_ip
port = 5000
port = env.server_port
command = config["command"]
shell = config.get("shell", False)

View File

@@ -1,12 +1,4 @@
from desktop_env.providers.base import VMManager, Provider
from desktop_env.providers.vmware.manager import VMwareVMManager
from desktop_env.providers.vmware.provider import VMwareProvider
from desktop_env.providers.aws.manager import AWSVMManager
from desktop_env.providers.aws.provider import AWSProvider
from desktop_env.providers.azure.manager import AzureVMManager
from desktop_env.providers.azure.provider import AzureProvider
from desktop_env.providers.virtualbox.manager import VirtualBoxVMManager
from desktop_env.providers.virtualbox.provider import VirtualBoxProvider
def create_vm_manager_and_provider(provider_name: str, region: str):
"""
@@ -14,12 +6,24 @@ def create_vm_manager_and_provider(provider_name: str, region: str):
"""
provider_name = provider_name.lower().strip()
if provider_name == "vmware":
from desktop_env.providers.vmware.manager import VMwareVMManager
from desktop_env.providers.vmware.provider import VMwareProvider
return VMwareVMManager(), VMwareProvider(region)
elif provider_name == "virtualbox":
from desktop_env.providers.virtualbox.manager import VirtualBoxVMManager
from desktop_env.providers.virtualbox.provider import VirtualBoxProvider
return VirtualBoxVMManager(), VirtualBoxProvider(region)
elif provider_name in ["aws", "amazon web services"]:
from desktop_env.providers.aws.manager import AWSVMManager
from desktop_env.providers.aws.provider import AWSProvider
return AWSVMManager(), AWSProvider(region)
elif provider_name == "azure":
from desktop_env.providers.azure.manager import AzureVMManager
from desktop_env.providers.azure.provider import AzureProvider
return AzureVMManager(), AzureProvider(region)
elif provider_name == "docker":
from desktop_env.providers.docker.manager import DockerVMManager
from desktop_env.providers.docker.provider import DockerProvider
return DockerVMManager(), DockerProvider(region)
else:
raise NotImplementedError(f"{provider_name} not implemented!")

View File

@@ -0,0 +1,123 @@
import os
import platform
import random
import re
import threading
from filelock import FileLock
import uuid
import zipfile
from time import sleep
import shutil
import psutil
import subprocess
import requests
from tqdm import tqdm
import docker
import logging
from desktop_env.providers.base import VMManager
# Module logger. Named after the Docker manager so that records emitted here
# are not attributed to the VMware manager.
logger = logging.getLogger("desktopenv.providers.docker.DockerVMManager")
logger.setLevel(logging.INFO)

MAX_RETRY_TIMES = 10
RETRY_INTERVAL = 5  # seconds passed to sleep() between download retries

UBUNTU_X86_URL = "https://huggingface.co/datasets/xlangai/ubuntu_osworld/resolve/main/Ubuntu.qcow2"
VMS_DIR = "./docker_vm_data"

# Determine the platform and CPU architecture to decide the correct VM image to download
# NOTE: the selection below is disabled; the x86 image is always used.
# if platform.system() == 'Darwin':  # macOS
#     # if os.uname().machine == 'arm64':  # Apple Silicon
#     URL = UBUNTU_ARM_URL
# # else:
# #     url = UBUNTU_X86_URL
# elif platform.machine().lower() in ['amd64', 'x86_64']:
#     URL = UBUNTU_X86_URL
# else:
#     raise Exception("Unsupported platform or architecture")
URL = UBUNTU_X86_URL
# File name of the image on disk: the last path component of the download URL.
DOWNLOADED_FILE_NAME = URL.split('/')[-1]

if platform.system() == 'Windows':
    # Append the Docker Desktop install directory to PATH for this process.
    docker_path = r"C:\Program Files\Docker\Docker"
    os.environ["PATH"] += os.pathsep + docker_path
def _download_vm(vms_dir: str):
    """Download the Ubuntu VM image into ``vms_dir``, resuming when possible.

    The data is streamed into ``<image name>.part`` and only renamed to the
    final image name once the transfer has finished, so an interrupted
    download never leaves a truncated file under the final name.

    Args:
        vms_dir: Directory that receives the image; created if missing.

    Raises:
        requests.exceptions.RequestException: If the request cannot be made
            or the server answers with an error status. Errors that occur
            while the body is being streamed are logged and retried instead.
    """
    logger.info("Downloading the virtual machine image...")

    url = UBUNTU_X86_URL
    final_path = os.path.join(vms_dir, url.split('/')[-1])
    partial_path = final_path + ".part"
    chunk_size = 1024 * 1024  # 1 MiB keeps per-chunk overhead low for a multi-GB image

    os.makedirs(vms_dir, exist_ok=True)

    while True:
        # Ask only for the bytes that are still missing from a previous attempt.
        downloaded_size = os.path.getsize(partial_path) if os.path.exists(partial_path) else 0
        headers = {"Range": f"bytes={downloaded_size}-"} if downloaded_size else {}

        with requests.get(url, headers=headers, stream=True) as response:
            if response.status_code == 416:
                # This means the range was not satisfiable, possibly the file was fully downloaded
                logger.info("Fully downloaded or the file size changed.")
                break

            response.raise_for_status()

            if response.status_code == 206:
                mode = "ab"
            else:
                # The server ignored the Range header and is sending the whole
                # file again: start over instead of appending to stale bytes.
                mode = "wb"
                downloaded_size = 0

            # With a Range request, content-length covers only the remaining bytes.
            remaining_size = int(response.headers.get('content-length', 0))
            total_size = downloaded_size + remaining_size if remaining_size else None

            with open(partial_path, mode) as file, tqdm(
                desc="Progress",
                total=total_size,
                unit='iB',
                unit_scale=True,
                unit_divisor=1024,
                initial=downloaded_size,
                ascii=True
            ) as progress_bar:
                try:
                    for data in response.iter_content(chunk_size=chunk_size):
                        size = file.write(data)
                        progress_bar.update(size)
                except (requests.exceptions.RequestException, IOError) as e:
                    logger.error("Download error: %s", e)
                    sleep(RETRY_INTERVAL)
                    logger.error("Retrying...")
                else:
                    logger.info("Download succeeds.")
                    break  # Download completed successfully

    # Publish the finished image under its final name.
    os.replace(partial_path, final_path)
class DockerVMManager(VMManager):
    """VM manager for the Docker provider.

    Registry bookkeeping methods are no-ops here; the manager only resolves
    the location of the single downloaded image and fetches it on demand.
    """

    def __init__(self, registry_path=""):
        """Accept ``registry_path`` but keep no state."""

    @staticmethod
    def _image_path():
        # Location of the downloaded image inside the VM data directory.
        return os.path.join(VMS_DIR, DOWNLOADED_FILE_NAME)

    def add_vm(self, vm_path):
        """No-op."""

    def check_and_clean(self):
        """No-op."""

    def delete_vm(self, vm_path):
        """No-op."""

    def initialize_registry(self):
        """No-op."""

    def list_free_vms(self):
        """Return the path of the image file (a single string, not a list)."""
        return self._image_path()

    def occupy_vm(self, vm_path):
        """No-op."""

    def get_vm_path(self, os_type, region):
        """Return the image path, downloading the image first if it is absent.

        ``os_type`` and ``region`` are not used.
        """
        image_missing = not os.path.exists(self._image_path())
        if image_missing:
            _download_vm(VMS_DIR)
        return self._image_path()

View File

@@ -0,0 +1,67 @@
import logging
import os
import platform
import subprocess
import time
import docker
import psutil
import requests
from desktop_env.providers.base import Provider
# Module logger. Named after the Docker provider so that records emitted here
# are not attributed to the VMware provider.
logger = logging.getLogger("desktopenv.providers.docker.DockerProvider")
logger.setLevel(logging.INFO)

WAIT_TIME = 3  # seconds to sleep after the container has been stopped and removed
RETRY_INTERVAL = 1  # seconds to sleep between readiness probes
class DockerProvider(Provider):
    """Provider that runs the VM image inside a Docker container.

    Container ports 8006, 5000 and 9222 are published on host ports picked at
    construction time (stored as ``vnc_port``, ``server_port`` and
    ``chromium_port``).
    """

    def __init__(self, region: str):
        """Connect to the Docker daemon and pick host ports.

        Args:
            region: Accepted so the constructor matches how providers are
                instantiated; not used by this provider.
        """
        self.client = docker.from_env()
        # Set by start_emulator(); None means no container is running yet.
        self.container = None
        self.vnc_port = self._get_available_port(8006)
        self.server_port = self._get_available_port(5000)
        # self.remote_debugging_port = self._get_available_port(1337)
        self.chromium_port = self._get_available_port(9222)
        self.environment = {"DISK_SIZE": "32G", "RAM_SIZE": "4G", "CPU_CORES": "4"}  # Modify if needed

    @staticmethod
    def _get_available_port(port: int):
        """Return the first port at or above ``port`` with no local socket on it.

        Raises:
            RuntimeError: If every port up to 65535 is in use.
        """
        # Snapshot the ports in use once instead of rescanning per candidate.
        used_ports = {conn.laddr.port for conn in psutil.net_connections() if conn.laddr}
        while port <= 65535:
            if port not in used_ports:
                return port
            port += 1
        raise RuntimeError("No available port found")

    @staticmethod
    def _is_server_ready(ip, port):
        """Return True if ``GET /screenshot`` on the server answers with 200."""
        url = f"http://{ip}:{port}/screenshot"
        try:
            # Single attempt; 10 s connect timeout and 10 s read timeout.
            response = requests.get(url, timeout=(10, 10))
        except requests.exceptions.RequestException:
            return False
        return response.status_code == 200

    def start_emulator(self, path_to_vm: str, headless: bool, os_type: str):
        """Start the container and block until the in-VM server responds.

        Args:
            path_to_vm: Path of the image file, mounted read-only into the
                container as ``/Ubuntu.qcow2``.
            headless: Not used by this provider.
            os_type: Not used by this provider.
        """
        logger.info("Occupying ports: %s, %s, %s", self.vnc_port, self.server_port, self.chromium_port)
        self.container = self.client.containers.run(
            "happysixd/osworld-docker",
            environment=self.environment,
            cap_add=["NET_ADMIN"],
            devices=["/dev/kvm"],
            volumes={os.path.abspath(path_to_vm): {"bind": "/Ubuntu.qcow2", "mode": "ro"}},
            ports={8006: self.vnc_port, 5000: self.server_port, 9222: self.chromium_port},
            detach=True,
        )

        # Poll the screenshot endpoint until it succeeds. Sleep after every
        # failed probe (error status as well as connection error) so the loop
        # never spins without a pause.
        # NOTE(review): there is no upper bound on this wait.
        while not self._is_server_ready("localhost", self.server_port):
            logger.info("Check whether the virtual machine is ready...")
            time.sleep(RETRY_INTERVAL)

    def get_ip_address(self, path_to_vm: str) -> str:
        """Return ``localhost:<server_port>:<chromium_port>:<vnc_port>``."""
        return f"localhost:{self.server_port}:{self.chromium_port}:{self.vnc_port}"

    def save_state(self, path_to_vm: str, snapshot_name: str):
        """Not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError("Not available for Docker.")

    def revert_to_snapshot(self, path_to_vm: str, snapshot_name: str):
        """No-op."""
        pass

    def stop_emulator(self, path_to_vm: str):
        """Stop and remove the container, if one was started."""
        if self.container is None:
            # Nothing to stop, e.g. when called before start_emulator().
            return
        logger.info("Stopping VM...")
        self.container.stop()
        self.container.remove()
        self.container = None
        time.sleep(WAIT_TIME)
# docker run -it --rm -e "DISK_SIZE=64G" -e "RAM_SIZE=8G" -e "CPU_CORES=8" --volume /home/$USER/osworld/docker_vm_data/Ubuntu.qcow2:/Ubuntu.qcow2:ro --cap-add NET_ADMIN --device /dev/kvm -p 8008:8006 -p 5002:5000 happysixd/osworld-docker

View File

@@ -78,7 +78,7 @@ def human_agent():
result = env.evaluate()
logger.info("Result: %.2f", result)
# env.close()
env.close()
logger.info("Environment closed.")

View File

@@ -53,7 +53,3 @@ wrapt_timeout_decorator
gdown
tiktoken
groq
boto3
azure-identity
azure-mgmt-compute
azure-mgmt-network