Support Docker VM manager and provider (#75)
* Add docker provider framework * Update VM download link * Add stop container * Update docker manager & provider * Update * Update * Update provider
This commit is contained in:
@@ -12,9 +12,10 @@ logger = logging.getLogger("desktopenv.pycontroller")
|
||||
|
||||
class PythonController:
|
||||
def __init__(self, vm_ip: str,
|
||||
server_port: int,
|
||||
pkgs_prefix: str = "import pyautogui; import time; pyautogui.FAILSAFE = False; {command}"):
|
||||
self.vm_ip = vm_ip
|
||||
self.http_server = f"http://{vm_ip}:5000"
|
||||
self.http_server = f"http://{vm_ip}:{server_port}"
|
||||
self.pkgs_prefix = pkgs_prefix # fixme: this is a hacky way to execute python commands. fix it and combine it with installation of packages
|
||||
self.retry_times = 3
|
||||
self.retry_interval = 5
|
||||
|
||||
@@ -28,10 +28,12 @@ FILE_PATH = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
|
||||
class SetupController:
|
||||
def __init__(self, vm_ip: str, cache_dir: str):
|
||||
def __init__(self, vm_ip: str, server_port: int, chromium_port: int, cache_dir: str):
|
||||
self.vm_ip: str = vm_ip
|
||||
self.http_server: str = f"http://{vm_ip}:5000"
|
||||
self.http_server_setup_root: str = f"http://{vm_ip}:5000/setup"
|
||||
self.server_port: int = server_port
|
||||
self.chromium_port: int = chromium_port
|
||||
self.http_server: str = f"http://{vm_ip}:{server_port}"
|
||||
self.http_server_setup_root: str = f"http://{vm_ip}:{server_port}/setup"
|
||||
self.cache_dir: str = cache_dir
|
||||
|
||||
def reset_cache_dir(self, cache_dir: str):
|
||||
@@ -348,7 +350,7 @@ class SetupController:
|
||||
# Chrome setup
|
||||
def _chrome_open_tabs_setup(self, urls_to_open: List[str]):
|
||||
host = self.vm_ip
|
||||
port = 9222 # fixme: this port is hard-coded, need to be changed from config file
|
||||
port = self.chromium_port # fixme: this port is hard-coded, need to be changed from config file
|
||||
|
||||
remote_debugging_url = f"http://{host}:{port}"
|
||||
logger.info("Connect to Chrome @: %s", remote_debugging_url)
|
||||
@@ -399,7 +401,7 @@ class SetupController:
|
||||
time.sleep(5) # Wait for Chrome to finish launching
|
||||
|
||||
host = self.vm_ip
|
||||
port = 9222 # fixme: this port is hard-coded, need to be changed from config file
|
||||
port = self.server_port # fixme: this port is hard-coded, need to be changed from config file
|
||||
|
||||
remote_debugging_url = f"http://{host}:{port}"
|
||||
with sync_playwright() as p:
|
||||
|
||||
@@ -26,7 +26,7 @@ class DesktopEnv(gym.Env):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
provider_name: str = "vmware",
|
||||
provider_name: str = "docker",
|
||||
region: str = None,
|
||||
path_to_vm: str = None,
|
||||
snapshot_name: str = "init_state",
|
||||
@@ -54,6 +54,11 @@ class DesktopEnv(gym.Env):
|
||||
"""
|
||||
# Initialize VM manager and vitualization provider
|
||||
self.region = region
|
||||
|
||||
# Default
|
||||
self.server_port = 5000
|
||||
self.chromium_port = 9222
|
||||
self.vnc_port = 8006
|
||||
self.manager, self.provider = create_vm_manager_and_provider(provider_name, region)
|
||||
|
||||
self.os_type = os_type
|
||||
@@ -73,8 +78,9 @@ class DesktopEnv(gym.Env):
|
||||
self.require_terminal = require_terminal
|
||||
|
||||
# Initialize emulator and controller
|
||||
logger.info("Initializing...")
|
||||
self._start_emulator()
|
||||
if provider_name != "docker": # Check if this is applicable to other VM providers
|
||||
logger.info("Initializing...")
|
||||
self._start_emulator()
|
||||
|
||||
# mode: human or machine
|
||||
self.instruction = None
|
||||
@@ -92,9 +98,14 @@ class DesktopEnv(gym.Env):
|
||||
self.provider.start_emulator(self.path_to_vm, self.headless, self.os_type)
|
||||
|
||||
# Get the ip from the virtual machine, and setup the controller
|
||||
self.vm_ip = self.provider.get_ip_address(self.path_to_vm)
|
||||
self.controller = PythonController(vm_ip=self.vm_ip)
|
||||
self.setup_controller = SetupController(vm_ip=self.vm_ip, cache_dir=self.cache_dir_base)
|
||||
vm_ip_ports = self.provider.get_ip_address(self.path_to_vm).split(':')
|
||||
self.vm_ip = vm_ip_ports[0]
|
||||
if len(vm_ip_ports) > 1:
|
||||
self.server_port = int(vm_ip_ports[1])
|
||||
self.chromium_port = int(vm_ip_ports[2])
|
||||
self.vnc_port = int(vm_ip_ports[3])
|
||||
self.controller = PythonController(vm_ip=self.vm_ip, server_port=self.server_port)
|
||||
self.setup_controller = SetupController(vm_ip=self.vm_ip, server_port=self.server_port, chromium_port=self.chromium_port, cache_dir=self.cache_dir_base)
|
||||
|
||||
def _revert_to_snapshot(self):
|
||||
# Revert to certain snapshot of the virtual machine, and refresh the path to vm and ip of vm
|
||||
|
||||
@@ -54,7 +54,8 @@ def get_info_from_website(env, config: Dict[Any, Any]) -> Any:
|
||||
"""
|
||||
try:
|
||||
host = env.vm_ip
|
||||
port = 9222 # fixme: this port is hard-coded, need to be changed from config file
|
||||
port = env.chromium_port # fixme: this port is hard-coded, need to be changed from config file
|
||||
server_port = env.server_port
|
||||
remote_debugging_url = f"http://{host}:{port}"
|
||||
with sync_playwright() as p:
|
||||
# connect to remote Chrome instance
|
||||
@@ -68,7 +69,7 @@ def get_info_from_website(env, config: Dict[Any, Any]) -> Any:
|
||||
"--remote-debugging-port=1337"
|
||||
], "shell": False})
|
||||
headers = {"Content-Type": "application/json"}
|
||||
requests.post("http://" + host + ":5000/setup" + "/launch", headers=headers, data=payload)
|
||||
requests.post("http://" + host + ":" + server_port + "/setup" + "/launch", headers=headers, data=payload)
|
||||
time.sleep(5)
|
||||
browser = p.chromium.connect_over_cdp(remote_debugging_url)
|
||||
|
||||
@@ -454,7 +455,8 @@ def get_extensions_installed_from_shop(env, config: Dict[str, str]):
|
||||
|
||||
def get_page_info(env, config: Dict[str, str]):
|
||||
host = env.vm_ip
|
||||
port = 9222 # fixme: this port is hard-coded, need to be changed from config file
|
||||
port = env.chromium_port # fixme: this port is hard-coded, need to be changed from config file
|
||||
server_port = env.server_port
|
||||
url = config["url"]
|
||||
|
||||
remote_debugging_url = f"http://{host}:{port}"
|
||||
@@ -478,7 +480,7 @@ def get_page_info(env, config: Dict[str, str]):
|
||||
], "shell": False})
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
requests.post("http://" + host + ":5000/setup" + "/launch", headers=headers, data=payload)
|
||||
requests.post("http://" + host + ":" + server_port + "/setup" + "/launch", headers=headers, data=payload)
|
||||
time.sleep(5)
|
||||
browser = p.chromium.connect_over_cdp(remote_debugging_url)
|
||||
|
||||
@@ -505,7 +507,8 @@ def get_page_info(env, config: Dict[str, str]):
|
||||
|
||||
def get_open_tabs_info(env, config: Dict[str, str]):
|
||||
host = env.vm_ip
|
||||
port = 9222 # fixme: this port is hard-coded, need to be changed from config file
|
||||
port = env.chromium_port # fixme: this port is hard-coded, need to be changed from config file
|
||||
server_port = env.server_port
|
||||
|
||||
remote_debugging_url = f"http://{host}:{port}"
|
||||
with sync_playwright() as p:
|
||||
@@ -528,7 +531,7 @@ def get_open_tabs_info(env, config: Dict[str, str]):
|
||||
], "shell": False})
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
requests.post("http://" + host + ":5000/setup" + "/launch", headers=headers, data=payload)
|
||||
requests.post("http://" + host + ":" + server_port + "/setup" + "/launch", headers=headers, data=payload)
|
||||
time.sleep(5)
|
||||
try:
|
||||
browser = p.chromium.connect_over_cdp(remote_debugging_url)
|
||||
@@ -643,7 +646,7 @@ def get_active_tab_info(env, config: Dict[str, str]):
|
||||
logger.error("Failed to get the url of active tab")
|
||||
return None
|
||||
host = env.vm_ip
|
||||
port = 9222 # fixme: this port is hard-coded, need to be changed from config file
|
||||
port = env.chromium_port # fixme: this port is hard-coded, need to be changed from config file
|
||||
|
||||
remote_debugging_url = f"http://{host}:{port}"
|
||||
with sync_playwright() as p:
|
||||
@@ -683,7 +686,8 @@ def get_pdf_from_url(env, config: Dict[str, str]) -> str:
|
||||
_path = os.path.join(env.cache_dir, config["dest"])
|
||||
|
||||
host = env.vm_ip
|
||||
port = 9222 # fixme: this port is hard-coded, need to be changed from config file
|
||||
port = env.chromium_port # fixme: this port is hard-coded, need to be changed from config file
|
||||
server_port = env.server_port
|
||||
|
||||
remote_debugging_url = f"http://{host}:{port}"
|
||||
|
||||
@@ -706,7 +710,7 @@ def get_pdf_from_url(env, config: Dict[str, str]) -> str:
|
||||
], "shell": False})
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
requests.post("http://" + host + ":5000/setup" + "/launch", headers=headers, data=payload)
|
||||
requests.post("http://" + host + ":" + server_port + "/setup" + "/launch", headers=headers, data=payload)
|
||||
time.sleep(5)
|
||||
browser = p.chromium.connect_over_cdp(remote_debugging_url)
|
||||
|
||||
@@ -721,7 +725,8 @@ def get_pdf_from_url(env, config: Dict[str, str]) -> str:
|
||||
# fixme: needs to be changed (maybe through post-processing) since it's not working
|
||||
def get_chrome_saved_address(env, config: Dict[str, str]):
|
||||
host = env.vm_ip
|
||||
port = 9222 # fixme: this port is hard-coded, need to be changed from config file
|
||||
port = env.chromium_port # fixme: this port is hard-coded, need to be changed from config file
|
||||
server_port = env.server_port
|
||||
|
||||
remote_debugging_url = f"http://{host}:{port}"
|
||||
with sync_playwright() as p:
|
||||
@@ -744,7 +749,7 @@ def get_chrome_saved_address(env, config: Dict[str, str]):
|
||||
], "shell": False})
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
requests.post("http://" + host + ":5000/setup" + "/launch", headers=headers, data=payload)
|
||||
requests.post("http://" + host + ":" + server_port + "/setup" + "/launch", headers=headers, data=payload)
|
||||
time.sleep(5)
|
||||
browser = p.chromium.connect_over_cdp(remote_debugging_url)
|
||||
|
||||
@@ -800,7 +805,8 @@ def get_number_of_search_results(env, config: Dict[str, str]):
|
||||
# todo: move into the config file
|
||||
url, result_selector = "https://google.com/search?q=query", '.search-result'
|
||||
host = env.vm_ip
|
||||
port = 9222 # fixme: this port is hard-coded, need to be changed from config file
|
||||
port = env.chromium_port # fixme: this port is hard-coded, need to be changed from config file
|
||||
server_port = env.server_port
|
||||
|
||||
remote_debugging_url = f"http://{host}:{port}"
|
||||
with sync_playwright() as p:
|
||||
@@ -822,7 +828,7 @@ def get_number_of_search_results(env, config: Dict[str, str]):
|
||||
], "shell": False})
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
requests.post("http://" + host + ":5000/setup" + "/launch", headers=headers, data=payload)
|
||||
requests.post("http://" + host + ":" + server_port + "/setup" + "/launch", headers=headers, data=payload)
|
||||
time.sleep(5)
|
||||
browser = p.chromium.connect_over_cdp(remote_debugging_url)
|
||||
page = browser.new_page()
|
||||
@@ -1145,7 +1151,8 @@ def get_active_tab_html_parse(env, config: Dict[str, Any]):
|
||||
logger.error("active_tab_url is not a string")
|
||||
return None
|
||||
host = env.vm_ip
|
||||
port = 9222 # fixme: this port is hard-coded, need to be changed from config file
|
||||
port = env.chromium_port # fixme: this port is hard-coded, need to be changed from config file
|
||||
server_port = env.server_port
|
||||
|
||||
remote_debugging_url = f"http://{host}:{port}"
|
||||
with sync_playwright() as p:
|
||||
@@ -1168,7 +1175,7 @@ def get_active_tab_html_parse(env, config: Dict[str, Any]):
|
||||
], "shell": False})
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
requests.post("http://" + host + ":5000/setup" + "/launch", headers=headers, data=payload)
|
||||
requests.post("http://" + host + ":" + server_port + "/setup" + "/launch", headers=headers, data=payload)
|
||||
time.sleep(5)
|
||||
browser = p.chromium.connect_over_cdp(remote_debugging_url)
|
||||
target_page = None
|
||||
@@ -1237,7 +1244,8 @@ def get_gotoRecreationPage_and_get_html_content(env, config: Dict[str, Any]):
|
||||
especially used for www.recreation.gov examples
|
||||
"""
|
||||
host = env.vm_ip
|
||||
port = 9222 # fixme: this port is hard-coded, need to be changed from config file
|
||||
port = env.chromium_port # fixme: this port is hard-coded, need to be changed from config file
|
||||
server_port = env.server_port
|
||||
|
||||
remote_debugging_url = f"http://{host}:{port}"
|
||||
with sync_playwright() as p:
|
||||
@@ -1259,7 +1267,7 @@ def get_gotoRecreationPage_and_get_html_content(env, config: Dict[str, Any]):
|
||||
], "shell": False})
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
requests.post("http://" + host + ":5000/setup" + "/launch", headers=headers, data=payload)
|
||||
requests.post("http://" + host + ":" + server_port + "/setup" + "/launch", headers=headers, data=payload)
|
||||
time.sleep(5)
|
||||
browser = p.chromium.connect_over_cdp(remote_debugging_url)
|
||||
page = browser.new_page()
|
||||
|
||||
@@ -7,7 +7,7 @@ logger = logging.getLogger("desktopenv.getters.general")
|
||||
|
||||
def get_vm_command_line(env, config: Dict[str, str]):
|
||||
vm_ip = env.vm_ip
|
||||
port = 5000
|
||||
port = env.server_port
|
||||
command = config["command"]
|
||||
shell = config.get("shell", False)
|
||||
|
||||
@@ -23,7 +23,7 @@ def get_vm_command_line(env, config: Dict[str, str]):
|
||||
|
||||
def get_vm_command_error(env, config: Dict[str, str]):
|
||||
vm_ip = env.vm_ip
|
||||
port = 5000
|
||||
port = env.server_port
|
||||
command = config["command"]
|
||||
shell = config.get("shell", False)
|
||||
|
||||
|
||||
@@ -1,12 +1,4 @@
|
||||
from desktop_env.providers.base import VMManager, Provider
|
||||
from desktop_env.providers.vmware.manager import VMwareVMManager
|
||||
from desktop_env.providers.vmware.provider import VMwareProvider
|
||||
from desktop_env.providers.aws.manager import AWSVMManager
|
||||
from desktop_env.providers.aws.provider import AWSProvider
|
||||
from desktop_env.providers.azure.manager import AzureVMManager
|
||||
from desktop_env.providers.azure.provider import AzureProvider
|
||||
from desktop_env.providers.virtualbox.manager import VirtualBoxVMManager
|
||||
from desktop_env.providers.virtualbox.provider import VirtualBoxProvider
|
||||
|
||||
def create_vm_manager_and_provider(provider_name: str, region: str):
|
||||
"""
|
||||
@@ -14,12 +6,24 @@ def create_vm_manager_and_provider(provider_name: str, region: str):
|
||||
"""
|
||||
provider_name = provider_name.lower().strip()
|
||||
if provider_name == "vmware":
|
||||
from desktop_env.providers.vmware.manager import VMwareVMManager
|
||||
from desktop_env.providers.vmware.provider import VMwareProvider
|
||||
return VMwareVMManager(), VMwareProvider(region)
|
||||
elif provider_name == "virtualbox":
|
||||
from desktop_env.providers.virtualbox.manager import VirtualBoxVMManager
|
||||
from desktop_env.providers.virtualbox.provider import VirtualBoxProvider
|
||||
return VirtualBoxVMManager(), VirtualBoxProvider(region)
|
||||
elif provider_name in ["aws", "amazon web services"]:
|
||||
from desktop_env.providers.aws.manager import AWSVMManager
|
||||
from desktop_env.providers.aws.provider import AWSProvider
|
||||
return AWSVMManager(), AWSProvider(region)
|
||||
elif provider_name == "azure":
|
||||
from desktop_env.providers.azure.manager import AzureVMManager
|
||||
from desktop_env.providers.azure.provider import AzureProvider
|
||||
return AzureVMManager(), AzureProvider(region)
|
||||
elif provider_name == "docker":
|
||||
from desktop_env.providers.docker.manager import DockerVMManager
|
||||
from desktop_env.providers.docker.provider import DockerProvider
|
||||
return DockerVMManager(), DockerProvider(region)
|
||||
else:
|
||||
raise NotImplementedError(f"{provider_name} not implemented!")
|
||||
|
||||
123
desktop_env/providers/docker/manager.py
Normal file
123
desktop_env/providers/docker/manager.py
Normal file
@@ -0,0 +1,123 @@
|
||||
import os
|
||||
import platform
|
||||
import random
|
||||
import re
|
||||
|
||||
import threading
|
||||
from filelock import FileLock
|
||||
import uuid
|
||||
import zipfile
|
||||
|
||||
from time import sleep
|
||||
import shutil
|
||||
import psutil
|
||||
import subprocess
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
import docker
|
||||
|
||||
import logging
|
||||
|
||||
from desktop_env.providers.base import VMManager
|
||||
|
||||
logger = logging.getLogger("desktopenv.providers.vmware.VMwareVMManager")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
MAX_RETRY_TIMES = 10
|
||||
RETRY_INTERVAL = 5
|
||||
|
||||
UBUNTU_X86_URL = "https://huggingface.co/datasets/xlangai/ubuntu_osworld/resolve/main/Ubuntu.qcow2"
|
||||
VMS_DIR = "./docker_vm_data"
|
||||
|
||||
# Determine the platform and CPU architecture to decide the correct VM image to download
|
||||
# if platform.system() == 'Darwin': # macOS
|
||||
# # if os.uname().machine == 'arm64': # Apple Silicon
|
||||
# URL = UBUNTU_ARM_URL
|
||||
# # else:
|
||||
# # url = UBUNTU_X86_URL
|
||||
# elif platform.machine().lower() in ['amd64', 'x86_64']:
|
||||
# URL = UBUNTU_X86_URL
|
||||
# else:
|
||||
# raise Exception("Unsupported platform or architecture")
|
||||
URL = UBUNTU_X86_URL
|
||||
|
||||
DOWNLOADED_FILE_NAME = URL.split('/')[-1]
|
||||
|
||||
if platform.system() == 'Windows':
|
||||
docker_path = r"C:\Program Files\Docker\Docker"
|
||||
os.environ["PATH"] += os.pathsep + docker_path
|
||||
|
||||
def _download_vm(vms_dir: str):
|
||||
# Download the virtual machine image
|
||||
logger.info("Downloading the virtual machine image...")
|
||||
downloaded_size = 0
|
||||
|
||||
URL = UBUNTU_X86_URL
|
||||
DOWNLOADED_FILE_NAME = URL.split('/')[-1]
|
||||
downloaded_file_name = DOWNLOADED_FILE_NAME
|
||||
|
||||
os.makedirs(vms_dir, exist_ok=True)
|
||||
|
||||
while True:
|
||||
downloaded_file_path = os.path.join(vms_dir, downloaded_file_name)
|
||||
headers = {}
|
||||
if os.path.exists(downloaded_file_path):
|
||||
downloaded_size = os.path.getsize(downloaded_file_path)
|
||||
headers["Range"] = f"bytes={downloaded_size}-"
|
||||
|
||||
with requests.get(URL, headers=headers, stream=True) as response:
|
||||
if response.status_code == 416:
|
||||
# This means the range was not satisfiable, possibly the file was fully downloaded
|
||||
logger.info("Fully downloaded or the file size changed.")
|
||||
break
|
||||
|
||||
response.raise_for_status()
|
||||
total_size = int(response.headers.get('content-length', 0))
|
||||
|
||||
with open(downloaded_file_path, "ab") as file, tqdm(
|
||||
desc="Progress",
|
||||
total=total_size,
|
||||
unit='iB',
|
||||
unit_scale=True,
|
||||
unit_divisor=1024,
|
||||
initial=downloaded_size,
|
||||
ascii=True
|
||||
) as progress_bar:
|
||||
try:
|
||||
for data in response.iter_content(chunk_size=1024):
|
||||
size = file.write(data)
|
||||
progress_bar.update(size)
|
||||
except (requests.exceptions.RequestException, IOError) as e:
|
||||
logger.error(f"Download error: {e}")
|
||||
sleep(RETRY_INTERVAL)
|
||||
logger.error("Retrying...")
|
||||
else:
|
||||
logger.info("Download succeeds.")
|
||||
break # Download completed successfully
|
||||
|
||||
class DockerVMManager(VMManager):
|
||||
def __init__(self, registry_path=""):
|
||||
pass
|
||||
|
||||
def add_vm(self, vm_path):
|
||||
pass
|
||||
|
||||
def check_and_clean(self):
|
||||
pass
|
||||
|
||||
def delete_vm(self, vm_path):
|
||||
pass
|
||||
|
||||
def initialize_registry(self):
|
||||
pass
|
||||
|
||||
def list_free_vms(self):
|
||||
return os.path.join(VMS_DIR, DOWNLOADED_FILE_NAME)
|
||||
|
||||
def occupy_vm(self, vm_path):
|
||||
pass
|
||||
|
||||
def get_vm_path(self, os_type, region):
|
||||
if not os.path.exists(os.path.join(VMS_DIR, DOWNLOADED_FILE_NAME)):
|
||||
_download_vm(VMS_DIR)
|
||||
return os.path.join(VMS_DIR, DOWNLOADED_FILE_NAME)
|
||||
67
desktop_env/providers/docker/provider.py
Normal file
67
desktop_env/providers/docker/provider.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import subprocess
|
||||
import time
|
||||
import docker
|
||||
import psutil
|
||||
import requests
|
||||
|
||||
from desktop_env.providers.base import Provider
|
||||
|
||||
logger = logging.getLogger("desktopenv.providers.vmware.VMwareProvider")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
WAIT_TIME = 3
|
||||
RETRY_INTERVAL = 1
|
||||
|
||||
class DockerProvider(Provider):
|
||||
def __init__(self, region: str):
|
||||
self.client = docker.from_env()
|
||||
self.vnc_port = self._get_available_port(8006)
|
||||
self.server_port = self._get_available_port(5000)
|
||||
# self.remote_debugging_port = self._get_available_port(1337)
|
||||
self.chromium_port = self._get_available_port(9222)
|
||||
self.environment = {"DISK_SIZE": "32G", "RAM_SIZE": "4G", "CPU_CORES": "4"} # Modify if needed
|
||||
|
||||
@staticmethod
|
||||
def _get_available_port(port: int):
|
||||
while port < 65354:
|
||||
if port not in [conn.laddr.port for conn in psutil.net_connections()]:
|
||||
return port
|
||||
port += 1
|
||||
|
||||
def start_emulator(self, path_to_vm: str, headless: bool, os_type: str):
|
||||
logger.info(f"Occupying ports: {self.vnc_port}, {self.server_port}, {self.chromium_port}")
|
||||
self.container = self.client.containers.run("happysixd/osworld-docker", environment=self.environment, cap_add=["NET_ADMIN"], devices=["/dev/kvm"], volumes={os.path.abspath(path_to_vm): {"bind": "/Ubuntu.qcow2", "mode": "ro"}}, ports={8006: self.vnc_port, 5000: self.server_port, 9222: self.chromium_port}, detach=True)
|
||||
def download_screenshot(ip, port):
|
||||
url = f"http://{ip}:{port}/screenshot"
|
||||
try:
|
||||
# max trey times 1, max timeout 1
|
||||
response = requests.get(url, timeout=(10, 10))
|
||||
if response.status_code == 200:
|
||||
return True
|
||||
except Exception as e:
|
||||
time.sleep(RETRY_INTERVAL)
|
||||
return False
|
||||
|
||||
# Try downloading the screenshot until successful
|
||||
while not download_screenshot("localhost", self.server_port):
|
||||
logger.info("Check whether the virtual machine is ready...")
|
||||
|
||||
def get_ip_address(self, path_to_vm: str) -> str:
|
||||
return f"localhost:{self.server_port}:{self.chromium_port}:{self.vnc_port}"
|
||||
|
||||
def save_state(self, path_to_vm: str, snapshot_name: str):
|
||||
raise NotImplementedError("Not available for Docker.")
|
||||
|
||||
def revert_to_snapshot(self, path_to_vm: str, snapshot_name: str):
|
||||
pass
|
||||
|
||||
def stop_emulator(self, path_to_vm: str):
|
||||
logger.info("Stopping VM...")
|
||||
self.container.stop()
|
||||
self.container.remove()
|
||||
time.sleep(WAIT_TIME)
|
||||
|
||||
# docker run -it --rm -e "DISK_SIZE=64G" -e "RAM_SIZE=8G" -e "CPU_CORES=8" --volume /home/$USER/osworld/docker_vm_data/Ubuntu.qcow2:/Ubuntu.qcow2:ro --cap-add NET_ADMIN --device /dev/kvm -p 8008:8006 -p 5002:5000 happysixd/osworld-docker
|
||||
2
main.py
2
main.py
@@ -78,7 +78,7 @@ def human_agent():
|
||||
result = env.evaluate()
|
||||
logger.info("Result: %.2f", result)
|
||||
|
||||
# env.close()
|
||||
env.close()
|
||||
logger.info("Environment closed.")
|
||||
|
||||
|
||||
|
||||
@@ -53,7 +53,3 @@ wrapt_timeout_decorator
|
||||
gdown
|
||||
tiktoken
|
||||
groq
|
||||
boto3
|
||||
azure-identity
|
||||
azure-mgmt-compute
|
||||
azure-mgmt-network
|
||||
|
||||
Reference in New Issue
Block a user