Support Docker VM manager and provider (#75)
* Add docker provider framework * Update VM download link * Add stop container * Update docker manager & provider * Update * Update * Update provider
This commit is contained in:
@@ -12,9 +12,10 @@ logger = logging.getLogger("desktopenv.pycontroller")
|
|||||||
|
|
||||||
class PythonController:
|
class PythonController:
|
||||||
def __init__(self, vm_ip: str,
|
def __init__(self, vm_ip: str,
|
||||||
|
server_port: int,
|
||||||
pkgs_prefix: str = "import pyautogui; import time; pyautogui.FAILSAFE = False; {command}"):
|
pkgs_prefix: str = "import pyautogui; import time; pyautogui.FAILSAFE = False; {command}"):
|
||||||
self.vm_ip = vm_ip
|
self.vm_ip = vm_ip
|
||||||
self.http_server = f"http://{vm_ip}:5000"
|
self.http_server = f"http://{vm_ip}:{server_port}"
|
||||||
self.pkgs_prefix = pkgs_prefix # fixme: this is a hacky way to execute python commands. fix it and combine it with installation of packages
|
self.pkgs_prefix = pkgs_prefix # fixme: this is a hacky way to execute python commands. fix it and combine it with installation of packages
|
||||||
self.retry_times = 3
|
self.retry_times = 3
|
||||||
self.retry_interval = 5
|
self.retry_interval = 5
|
||||||
|
|||||||
@@ -28,10 +28,12 @@ FILE_PATH = os.path.dirname(os.path.abspath(__file__))
|
|||||||
|
|
||||||
|
|
||||||
class SetupController:
|
class SetupController:
|
||||||
def __init__(self, vm_ip: str, cache_dir: str):
|
def __init__(self, vm_ip: str, server_port: int, chromium_port: int, cache_dir: str):
|
||||||
self.vm_ip: str = vm_ip
|
self.vm_ip: str = vm_ip
|
||||||
self.http_server: str = f"http://{vm_ip}:5000"
|
self.server_port: int = server_port
|
||||||
self.http_server_setup_root: str = f"http://{vm_ip}:5000/setup"
|
self.chromium_port: int = chromium_port
|
||||||
|
self.http_server: str = f"http://{vm_ip}:{server_port}"
|
||||||
|
self.http_server_setup_root: str = f"http://{vm_ip}:{server_port}/setup"
|
||||||
self.cache_dir: str = cache_dir
|
self.cache_dir: str = cache_dir
|
||||||
|
|
||||||
def reset_cache_dir(self, cache_dir: str):
|
def reset_cache_dir(self, cache_dir: str):
|
||||||
@@ -348,7 +350,7 @@ class SetupController:
|
|||||||
# Chrome setup
|
# Chrome setup
|
||||||
def _chrome_open_tabs_setup(self, urls_to_open: List[str]):
|
def _chrome_open_tabs_setup(self, urls_to_open: List[str]):
|
||||||
host = self.vm_ip
|
host = self.vm_ip
|
||||||
port = 9222 # fixme: this port is hard-coded, need to be changed from config file
|
port = self.chromium_port # fixme: this port is hard-coded, need to be changed from config file
|
||||||
|
|
||||||
remote_debugging_url = f"http://{host}:{port}"
|
remote_debugging_url = f"http://{host}:{port}"
|
||||||
logger.info("Connect to Chrome @: %s", remote_debugging_url)
|
logger.info("Connect to Chrome @: %s", remote_debugging_url)
|
||||||
@@ -399,7 +401,7 @@ class SetupController:
|
|||||||
time.sleep(5) # Wait for Chrome to finish launching
|
time.sleep(5) # Wait for Chrome to finish launching
|
||||||
|
|
||||||
host = self.vm_ip
|
host = self.vm_ip
|
||||||
port = 9222 # fixme: this port is hard-coded, need to be changed from config file
|
port = self.chromium_port  # remote_debugging_url below must target Chrome's debugging port, not the control server port
|
||||||
|
|
||||||
remote_debugging_url = f"http://{host}:{port}"
|
remote_debugging_url = f"http://{host}:{port}"
|
||||||
with sync_playwright() as p:
|
with sync_playwright() as p:
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ class DesktopEnv(gym.Env):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
provider_name: str = "vmware",
|
provider_name: str = "docker",
|
||||||
region: str = None,
|
region: str = None,
|
||||||
path_to_vm: str = None,
|
path_to_vm: str = None,
|
||||||
snapshot_name: str = "init_state",
|
snapshot_name: str = "init_state",
|
||||||
@@ -54,6 +54,11 @@ class DesktopEnv(gym.Env):
|
|||||||
"""
|
"""
|
||||||
# Initialize VM manager and vitualization provider
|
# Initialize VM manager and vitualization provider
|
||||||
self.region = region
|
self.region = region
|
||||||
|
|
||||||
|
# Default
|
||||||
|
self.server_port = 5000
|
||||||
|
self.chromium_port = 9222
|
||||||
|
self.vnc_port = 8006
|
||||||
self.manager, self.provider = create_vm_manager_and_provider(provider_name, region)
|
self.manager, self.provider = create_vm_manager_and_provider(provider_name, region)
|
||||||
|
|
||||||
self.os_type = os_type
|
self.os_type = os_type
|
||||||
@@ -73,8 +78,9 @@ class DesktopEnv(gym.Env):
|
|||||||
self.require_terminal = require_terminal
|
self.require_terminal = require_terminal
|
||||||
|
|
||||||
# Initialize emulator and controller
|
# Initialize emulator and controller
|
||||||
logger.info("Initializing...")
|
if provider_name != "docker": # Check if this is applicable to other VM providers
|
||||||
self._start_emulator()
|
logger.info("Initializing...")
|
||||||
|
self._start_emulator()
|
||||||
|
|
||||||
# mode: human or machine
|
# mode: human or machine
|
||||||
self.instruction = None
|
self.instruction = None
|
||||||
@@ -92,9 +98,14 @@ class DesktopEnv(gym.Env):
|
|||||||
self.provider.start_emulator(self.path_to_vm, self.headless, self.os_type)
|
self.provider.start_emulator(self.path_to_vm, self.headless, self.os_type)
|
||||||
|
|
||||||
# Get the ip from the virtual machine, and setup the controller
|
# Get the ip from the virtual machine, and setup the controller
|
||||||
self.vm_ip = self.provider.get_ip_address(self.path_to_vm)
|
vm_ip_ports = self.provider.get_ip_address(self.path_to_vm).split(':')
|
||||||
self.controller = PythonController(vm_ip=self.vm_ip)
|
self.vm_ip = vm_ip_ports[0]
|
||||||
self.setup_controller = SetupController(vm_ip=self.vm_ip, cache_dir=self.cache_dir_base)
|
if len(vm_ip_ports) > 1:
|
||||||
|
self.server_port = int(vm_ip_ports[1])
|
||||||
|
self.chromium_port = int(vm_ip_ports[2])
|
||||||
|
self.vnc_port = int(vm_ip_ports[3])
|
||||||
|
self.controller = PythonController(vm_ip=self.vm_ip, server_port=self.server_port)
|
||||||
|
self.setup_controller = SetupController(vm_ip=self.vm_ip, server_port=self.server_port, chromium_port=self.chromium_port, cache_dir=self.cache_dir_base)
|
||||||
|
|
||||||
def _revert_to_snapshot(self):
|
def _revert_to_snapshot(self):
|
||||||
# Revert to certain snapshot of the virtual machine, and refresh the path to vm and ip of vm
|
# Revert to certain snapshot of the virtual machine, and refresh the path to vm and ip of vm
|
||||||
|
|||||||
@@ -54,7 +54,8 @@ def get_info_from_website(env, config: Dict[Any, Any]) -> Any:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
host = env.vm_ip
|
host = env.vm_ip
|
||||||
port = 9222 # fixme: this port is hard-coded, need to be changed from config file
|
port = env.chromium_port # fixme: this port is hard-coded, need to be changed from config file
|
||||||
|
server_port = env.server_port
|
||||||
remote_debugging_url = f"http://{host}:{port}"
|
remote_debugging_url = f"http://{host}:{port}"
|
||||||
with sync_playwright() as p:
|
with sync_playwright() as p:
|
||||||
# connect to remote Chrome instance
|
# connect to remote Chrome instance
|
||||||
@@ -68,7 +69,7 @@ def get_info_from_website(env, config: Dict[Any, Any]) -> Any:
|
|||||||
"--remote-debugging-port=1337"
|
"--remote-debugging-port=1337"
|
||||||
], "shell": False})
|
], "shell": False})
|
||||||
headers = {"Content-Type": "application/json"}
|
headers = {"Content-Type": "application/json"}
|
||||||
requests.post("http://" + host + ":5000/setup" + "/launch", headers=headers, data=payload)
|
# server_port is an int; convert it before concatenating it into the URL.
requests.post("http://" + host + ":" + str(server_port) + "/setup" + "/launch", headers=headers, data=payload)
|
||||||
time.sleep(5)
|
time.sleep(5)
|
||||||
browser = p.chromium.connect_over_cdp(remote_debugging_url)
|
browser = p.chromium.connect_over_cdp(remote_debugging_url)
|
||||||
|
|
||||||
@@ -454,7 +455,8 @@ def get_extensions_installed_from_shop(env, config: Dict[str, str]):
|
|||||||
|
|
||||||
def get_page_info(env, config: Dict[str, str]):
|
def get_page_info(env, config: Dict[str, str]):
|
||||||
host = env.vm_ip
|
host = env.vm_ip
|
||||||
port = 9222 # fixme: this port is hard-coded, need to be changed from config file
|
port = env.chromium_port # fixme: this port is hard-coded, need to be changed from config file
|
||||||
|
server_port = env.server_port
|
||||||
url = config["url"]
|
url = config["url"]
|
||||||
|
|
||||||
remote_debugging_url = f"http://{host}:{port}"
|
remote_debugging_url = f"http://{host}:{port}"
|
||||||
@@ -478,7 +480,7 @@ def get_page_info(env, config: Dict[str, str]):
|
|||||||
], "shell": False})
|
], "shell": False})
|
||||||
|
|
||||||
headers = {"Content-Type": "application/json"}
|
headers = {"Content-Type": "application/json"}
|
||||||
requests.post("http://" + host + ":5000/setup" + "/launch", headers=headers, data=payload)
|
# server_port is an int; convert it before concatenating it into the URL.
requests.post("http://" + host + ":" + str(server_port) + "/setup" + "/launch", headers=headers, data=payload)
|
||||||
time.sleep(5)
|
time.sleep(5)
|
||||||
browser = p.chromium.connect_over_cdp(remote_debugging_url)
|
browser = p.chromium.connect_over_cdp(remote_debugging_url)
|
||||||
|
|
||||||
@@ -505,7 +507,8 @@ def get_page_info(env, config: Dict[str, str]):
|
|||||||
|
|
||||||
def get_open_tabs_info(env, config: Dict[str, str]):
|
def get_open_tabs_info(env, config: Dict[str, str]):
|
||||||
host = env.vm_ip
|
host = env.vm_ip
|
||||||
port = 9222 # fixme: this port is hard-coded, need to be changed from config file
|
port = env.chromium_port # fixme: this port is hard-coded, need to be changed from config file
|
||||||
|
server_port = env.server_port
|
||||||
|
|
||||||
remote_debugging_url = f"http://{host}:{port}"
|
remote_debugging_url = f"http://{host}:{port}"
|
||||||
with sync_playwright() as p:
|
with sync_playwright() as p:
|
||||||
@@ -528,7 +531,7 @@ def get_open_tabs_info(env, config: Dict[str, str]):
|
|||||||
], "shell": False})
|
], "shell": False})
|
||||||
|
|
||||||
headers = {"Content-Type": "application/json"}
|
headers = {"Content-Type": "application/json"}
|
||||||
requests.post("http://" + host + ":5000/setup" + "/launch", headers=headers, data=payload)
|
# server_port is an int; convert it before concatenating it into the URL.
requests.post("http://" + host + ":" + str(server_port) + "/setup" + "/launch", headers=headers, data=payload)
|
||||||
time.sleep(5)
|
time.sleep(5)
|
||||||
try:
|
try:
|
||||||
browser = p.chromium.connect_over_cdp(remote_debugging_url)
|
browser = p.chromium.connect_over_cdp(remote_debugging_url)
|
||||||
@@ -643,7 +646,7 @@ def get_active_tab_info(env, config: Dict[str, str]):
|
|||||||
logger.error("Failed to get the url of active tab")
|
logger.error("Failed to get the url of active tab")
|
||||||
return None
|
return None
|
||||||
host = env.vm_ip
|
host = env.vm_ip
|
||||||
port = 9222 # fixme: this port is hard-coded, need to be changed from config file
|
port = env.chromium_port # fixme: this port is hard-coded, need to be changed from config file
|
||||||
|
|
||||||
remote_debugging_url = f"http://{host}:{port}"
|
remote_debugging_url = f"http://{host}:{port}"
|
||||||
with sync_playwright() as p:
|
with sync_playwright() as p:
|
||||||
@@ -683,7 +686,8 @@ def get_pdf_from_url(env, config: Dict[str, str]) -> str:
|
|||||||
_path = os.path.join(env.cache_dir, config["dest"])
|
_path = os.path.join(env.cache_dir, config["dest"])
|
||||||
|
|
||||||
host = env.vm_ip
|
host = env.vm_ip
|
||||||
port = 9222 # fixme: this port is hard-coded, need to be changed from config file
|
port = env.chromium_port # fixme: this port is hard-coded, need to be changed from config file
|
||||||
|
server_port = env.server_port
|
||||||
|
|
||||||
remote_debugging_url = f"http://{host}:{port}"
|
remote_debugging_url = f"http://{host}:{port}"
|
||||||
|
|
||||||
@@ -706,7 +710,7 @@ def get_pdf_from_url(env, config: Dict[str, str]) -> str:
|
|||||||
], "shell": False})
|
], "shell": False})
|
||||||
|
|
||||||
headers = {"Content-Type": "application/json"}
|
headers = {"Content-Type": "application/json"}
|
||||||
requests.post("http://" + host + ":5000/setup" + "/launch", headers=headers, data=payload)
|
# server_port is an int; convert it before concatenating it into the URL.
requests.post("http://" + host + ":" + str(server_port) + "/setup" + "/launch", headers=headers, data=payload)
|
||||||
time.sleep(5)
|
time.sleep(5)
|
||||||
browser = p.chromium.connect_over_cdp(remote_debugging_url)
|
browser = p.chromium.connect_over_cdp(remote_debugging_url)
|
||||||
|
|
||||||
@@ -721,7 +725,8 @@ def get_pdf_from_url(env, config: Dict[str, str]) -> str:
|
|||||||
# fixme: needs to be changed (maybe through post-processing) since it's not working
|
# fixme: needs to be changed (maybe through post-processing) since it's not working
|
||||||
def get_chrome_saved_address(env, config: Dict[str, str]):
|
def get_chrome_saved_address(env, config: Dict[str, str]):
|
||||||
host = env.vm_ip
|
host = env.vm_ip
|
||||||
port = 9222 # fixme: this port is hard-coded, need to be changed from config file
|
port = env.chromium_port # fixme: this port is hard-coded, need to be changed from config file
|
||||||
|
server_port = env.server_port
|
||||||
|
|
||||||
remote_debugging_url = f"http://{host}:{port}"
|
remote_debugging_url = f"http://{host}:{port}"
|
||||||
with sync_playwright() as p:
|
with sync_playwright() as p:
|
||||||
@@ -744,7 +749,7 @@ def get_chrome_saved_address(env, config: Dict[str, str]):
|
|||||||
], "shell": False})
|
], "shell": False})
|
||||||
|
|
||||||
headers = {"Content-Type": "application/json"}
|
headers = {"Content-Type": "application/json"}
|
||||||
requests.post("http://" + host + ":5000/setup" + "/launch", headers=headers, data=payload)
|
# server_port is an int; convert it before concatenating it into the URL.
requests.post("http://" + host + ":" + str(server_port) + "/setup" + "/launch", headers=headers, data=payload)
|
||||||
time.sleep(5)
|
time.sleep(5)
|
||||||
browser = p.chromium.connect_over_cdp(remote_debugging_url)
|
browser = p.chromium.connect_over_cdp(remote_debugging_url)
|
||||||
|
|
||||||
@@ -800,7 +805,8 @@ def get_number_of_search_results(env, config: Dict[str, str]):
|
|||||||
# todo: move into the config file
|
# todo: move into the config file
|
||||||
url, result_selector = "https://google.com/search?q=query", '.search-result'
|
url, result_selector = "https://google.com/search?q=query", '.search-result'
|
||||||
host = env.vm_ip
|
host = env.vm_ip
|
||||||
port = 9222 # fixme: this port is hard-coded, need to be changed from config file
|
port = env.chromium_port # fixme: this port is hard-coded, need to be changed from config file
|
||||||
|
server_port = env.server_port
|
||||||
|
|
||||||
remote_debugging_url = f"http://{host}:{port}"
|
remote_debugging_url = f"http://{host}:{port}"
|
||||||
with sync_playwright() as p:
|
with sync_playwright() as p:
|
||||||
@@ -822,7 +828,7 @@ def get_number_of_search_results(env, config: Dict[str, str]):
|
|||||||
], "shell": False})
|
], "shell": False})
|
||||||
|
|
||||||
headers = {"Content-Type": "application/json"}
|
headers = {"Content-Type": "application/json"}
|
||||||
requests.post("http://" + host + ":5000/setup" + "/launch", headers=headers, data=payload)
|
# server_port is an int; convert it before concatenating it into the URL.
requests.post("http://" + host + ":" + str(server_port) + "/setup" + "/launch", headers=headers, data=payload)
|
||||||
time.sleep(5)
|
time.sleep(5)
|
||||||
browser = p.chromium.connect_over_cdp(remote_debugging_url)
|
browser = p.chromium.connect_over_cdp(remote_debugging_url)
|
||||||
page = browser.new_page()
|
page = browser.new_page()
|
||||||
@@ -1145,7 +1151,8 @@ def get_active_tab_html_parse(env, config: Dict[str, Any]):
|
|||||||
logger.error("active_tab_url is not a string")
|
logger.error("active_tab_url is not a string")
|
||||||
return None
|
return None
|
||||||
host = env.vm_ip
|
host = env.vm_ip
|
||||||
port = 9222 # fixme: this port is hard-coded, need to be changed from config file
|
port = env.chromium_port # fixme: this port is hard-coded, need to be changed from config file
|
||||||
|
server_port = env.server_port
|
||||||
|
|
||||||
remote_debugging_url = f"http://{host}:{port}"
|
remote_debugging_url = f"http://{host}:{port}"
|
||||||
with sync_playwright() as p:
|
with sync_playwright() as p:
|
||||||
@@ -1168,7 +1175,7 @@ def get_active_tab_html_parse(env, config: Dict[str, Any]):
|
|||||||
], "shell": False})
|
], "shell": False})
|
||||||
|
|
||||||
headers = {"Content-Type": "application/json"}
|
headers = {"Content-Type": "application/json"}
|
||||||
requests.post("http://" + host + ":5000/setup" + "/launch", headers=headers, data=payload)
|
# server_port is an int; convert it before concatenating it into the URL.
requests.post("http://" + host + ":" + str(server_port) + "/setup" + "/launch", headers=headers, data=payload)
|
||||||
time.sleep(5)
|
time.sleep(5)
|
||||||
browser = p.chromium.connect_over_cdp(remote_debugging_url)
|
browser = p.chromium.connect_over_cdp(remote_debugging_url)
|
||||||
target_page = None
|
target_page = None
|
||||||
@@ -1237,7 +1244,8 @@ def get_gotoRecreationPage_and_get_html_content(env, config: Dict[str, Any]):
|
|||||||
especially used for www.recreation.gov examples
|
especially used for www.recreation.gov examples
|
||||||
"""
|
"""
|
||||||
host = env.vm_ip
|
host = env.vm_ip
|
||||||
port = 9222 # fixme: this port is hard-coded, need to be changed from config file
|
port = env.chromium_port # fixme: this port is hard-coded, need to be changed from config file
|
||||||
|
server_port = env.server_port
|
||||||
|
|
||||||
remote_debugging_url = f"http://{host}:{port}"
|
remote_debugging_url = f"http://{host}:{port}"
|
||||||
with sync_playwright() as p:
|
with sync_playwright() as p:
|
||||||
@@ -1259,7 +1267,7 @@ def get_gotoRecreationPage_and_get_html_content(env, config: Dict[str, Any]):
|
|||||||
], "shell": False})
|
], "shell": False})
|
||||||
|
|
||||||
headers = {"Content-Type": "application/json"}
|
headers = {"Content-Type": "application/json"}
|
||||||
requests.post("http://" + host + ":5000/setup" + "/launch", headers=headers, data=payload)
|
# server_port is an int; convert it before concatenating it into the URL.
requests.post("http://" + host + ":" + str(server_port) + "/setup" + "/launch", headers=headers, data=payload)
|
||||||
time.sleep(5)
|
time.sleep(5)
|
||||||
browser = p.chromium.connect_over_cdp(remote_debugging_url)
|
browser = p.chromium.connect_over_cdp(remote_debugging_url)
|
||||||
page = browser.new_page()
|
page = browser.new_page()
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ logger = logging.getLogger("desktopenv.getters.general")
|
|||||||
|
|
||||||
def get_vm_command_line(env, config: Dict[str, str]):
|
def get_vm_command_line(env, config: Dict[str, str]):
|
||||||
vm_ip = env.vm_ip
|
vm_ip = env.vm_ip
|
||||||
port = 5000
|
port = env.server_port
|
||||||
command = config["command"]
|
command = config["command"]
|
||||||
shell = config.get("shell", False)
|
shell = config.get("shell", False)
|
||||||
|
|
||||||
@@ -23,7 +23,7 @@ def get_vm_command_line(env, config: Dict[str, str]):
|
|||||||
|
|
||||||
def get_vm_command_error(env, config: Dict[str, str]):
|
def get_vm_command_error(env, config: Dict[str, str]):
|
||||||
vm_ip = env.vm_ip
|
vm_ip = env.vm_ip
|
||||||
port = 5000
|
port = env.server_port
|
||||||
command = config["command"]
|
command = config["command"]
|
||||||
shell = config.get("shell", False)
|
shell = config.get("shell", False)
|
||||||
|
|
||||||
|
|||||||
@@ -1,12 +1,4 @@
|
|||||||
from desktop_env.providers.base import VMManager, Provider
|
from desktop_env.providers.base import VMManager, Provider
|
||||||
from desktop_env.providers.vmware.manager import VMwareVMManager
|
|
||||||
from desktop_env.providers.vmware.provider import VMwareProvider
|
|
||||||
from desktop_env.providers.aws.manager import AWSVMManager
|
|
||||||
from desktop_env.providers.aws.provider import AWSProvider
|
|
||||||
from desktop_env.providers.azure.manager import AzureVMManager
|
|
||||||
from desktop_env.providers.azure.provider import AzureProvider
|
|
||||||
from desktop_env.providers.virtualbox.manager import VirtualBoxVMManager
|
|
||||||
from desktop_env.providers.virtualbox.provider import VirtualBoxProvider
|
|
||||||
|
|
||||||
def create_vm_manager_and_provider(provider_name: str, region: str):
|
def create_vm_manager_and_provider(provider_name: str, region: str):
|
||||||
"""
|
"""
|
||||||
@@ -14,12 +6,24 @@ def create_vm_manager_and_provider(provider_name: str, region: str):
|
|||||||
"""
|
"""
|
||||||
provider_name = provider_name.lower().strip()
|
provider_name = provider_name.lower().strip()
|
||||||
if provider_name == "vmware":
|
if provider_name == "vmware":
|
||||||
|
from desktop_env.providers.vmware.manager import VMwareVMManager
|
||||||
|
from desktop_env.providers.vmware.provider import VMwareProvider
|
||||||
return VMwareVMManager(), VMwareProvider(region)
|
return VMwareVMManager(), VMwareProvider(region)
|
||||||
elif provider_name == "virtualbox":
|
elif provider_name == "virtualbox":
|
||||||
|
from desktop_env.providers.virtualbox.manager import VirtualBoxVMManager
|
||||||
|
from desktop_env.providers.virtualbox.provider import VirtualBoxProvider
|
||||||
return VirtualBoxVMManager(), VirtualBoxProvider(region)
|
return VirtualBoxVMManager(), VirtualBoxProvider(region)
|
||||||
elif provider_name in ["aws", "amazon web services"]:
|
elif provider_name in ["aws", "amazon web services"]:
|
||||||
|
from desktop_env.providers.aws.manager import AWSVMManager
|
||||||
|
from desktop_env.providers.aws.provider import AWSProvider
|
||||||
return AWSVMManager(), AWSProvider(region)
|
return AWSVMManager(), AWSProvider(region)
|
||||||
elif provider_name == "azure":
|
elif provider_name == "azure":
|
||||||
|
from desktop_env.providers.azure.manager import AzureVMManager
|
||||||
|
from desktop_env.providers.azure.provider import AzureProvider
|
||||||
return AzureVMManager(), AzureProvider(region)
|
return AzureVMManager(), AzureProvider(region)
|
||||||
|
elif provider_name == "docker":
|
||||||
|
from desktop_env.providers.docker.manager import DockerVMManager
|
||||||
|
from desktop_env.providers.docker.provider import DockerProvider
|
||||||
|
return DockerVMManager(), DockerProvider(region)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"{provider_name} not implemented!")
|
raise NotImplementedError(f"{provider_name} not implemented!")
|
||||||
|
|||||||
123
desktop_env/providers/docker/manager.py
Normal file
123
desktop_env/providers/docker/manager.py
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
import os
|
||||||
|
import platform
|
||||||
|
import random
|
||||||
|
import re
|
||||||
|
|
||||||
|
import threading
|
||||||
|
from filelock import FileLock
|
||||||
|
import uuid
|
||||||
|
import zipfile
|
||||||
|
|
||||||
|
from time import sleep
|
||||||
|
import shutil
|
||||||
|
import psutil
|
||||||
|
import subprocess
|
||||||
|
import requests
|
||||||
|
from tqdm import tqdm
|
||||||
|
import docker
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from desktop_env.providers.base import VMManager
|
||||||
|
|
||||||
|
# Logger for the Docker VM manager module (name previously copy-pasted from the VMware manager).
logger = logging.getLogger("desktopenv.providers.docker.DockerVMManager")
|
||||||
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
MAX_RETRY_TIMES = 10
|
||||||
|
RETRY_INTERVAL = 5
|
||||||
|
|
||||||
|
UBUNTU_X86_URL = "https://huggingface.co/datasets/xlangai/ubuntu_osworld/resolve/main/Ubuntu.qcow2"
|
||||||
|
VMS_DIR = "./docker_vm_data"
|
||||||
|
|
||||||
|
# Determine the platform and CPU architecture to decide the correct VM image to download
|
||||||
|
# if platform.system() == 'Darwin': # macOS
|
||||||
|
# # if os.uname().machine == 'arm64': # Apple Silicon
|
||||||
|
# URL = UBUNTU_ARM_URL
|
||||||
|
# # else:
|
||||||
|
# # url = UBUNTU_X86_URL
|
||||||
|
# elif platform.machine().lower() in ['amd64', 'x86_64']:
|
||||||
|
# URL = UBUNTU_X86_URL
|
||||||
|
# else:
|
||||||
|
# raise Exception("Unsupported platform or architecture")
|
||||||
|
URL = UBUNTU_X86_URL
|
||||||
|
|
||||||
|
DOWNLOADED_FILE_NAME = URL.split('/')[-1]
|
||||||
|
|
||||||
|
if platform.system() == 'Windows':
|
||||||
|
docker_path = r"C:\Program Files\Docker\Docker"
|
||||||
|
os.environ["PATH"] += os.pathsep + docker_path
|
||||||
|
|
||||||
|
def _download_vm(vms_dir: str):
    """Download the Ubuntu VM disk image into ``vms_dir``, resuming a partial file.

    The image URL is ``UBUNTU_X86_URL`` and the local file name is the last
    path component of that URL. If a file with that name already exists, a
    ``Range`` request is sent so that only the missing tail is fetched.

    Args:
        vms_dir: Directory that receives the image; created if missing.

    Raises:
        requests.exceptions.HTTPError: If the server answers with an error
            status other than 416 (raised by ``raise_for_status``).
    """
    logger.info("Downloading the virtual machine image...")
    downloaded_size = 0

    # Local lowercase names so the module-level constants are not shadowed.
    url = UBUNTU_X86_URL
    downloaded_file_name = url.split('/')[-1]

    os.makedirs(vms_dir, exist_ok=True)
    downloaded_file_path = os.path.join(vms_dir, downloaded_file_name)

    while True:
        headers = {}
        if os.path.exists(downloaded_file_path):
            # Ask only for the bytes that are not on disk yet.
            downloaded_size = os.path.getsize(downloaded_file_path)
            headers["Range"] = f"bytes={downloaded_size}-"

        with requests.get(url, headers=headers, stream=True) as response:
            if response.status_code == 416:
                # This means the range was not satisfiable, possibly the file was fully downloaded
                logger.info("Fully downloaded or the file size changed.")
                break

            response.raise_for_status()

            # 206 means the server honoured the Range header. Any other success
            # status carries the whole file, so appending it to the partial
            # file would corrupt the image: start over from byte 0 instead.
            resumed = response.status_code == 206
            if not resumed:
                downloaded_size = 0

            # content-length of a ranged response covers only the remaining
            # bytes; add what is already on disk so the progress bar is right.
            total_size = downloaded_size + int(response.headers.get('content-length', 0))

            with open(downloaded_file_path, "ab" if resumed else "wb") as file, tqdm(
                    desc="Progress",
                    total=total_size,
                    unit='iB',
                    unit_scale=True,
                    unit_divisor=1024,
                    initial=downloaded_size,
                    ascii=True
            ) as progress_bar:
                try:
                    for data in response.iter_content(chunk_size=1024):
                        size = file.write(data)
                        progress_bar.update(size)
                except (requests.exceptions.RequestException, IOError) as e:
                    # Keep what was written and loop again to resume from there.
                    logger.error(f"Download error: {e}")
                    sleep(RETRY_INTERVAL)
                    logger.error("Retrying...")
                else:
                    logger.info("Download succeeds.")
                    break  # Download completed successfully
|
||||||
|
|
||||||
|
class DockerVMManager(VMManager):
    """VM manager for the Docker provider.

    There is no registry of VMs here: every call works with one disk image
    located at ``VMS_DIR/DOWNLOADED_FILE_NAME``. The registry-related methods
    are therefore no-ops that return ``None``.
    """

    def __init__(self, registry_path=""):
        """Do nothing; ``registry_path`` is accepted but not used."""

    def add_vm(self, vm_path):
        """No-op."""

    def check_and_clean(self):
        """No-op."""

    def delete_vm(self, vm_path):
        """No-op."""

    def initialize_registry(self):
        """No-op."""

    def list_free_vms(self):
        """Return the path of the single disk image (a string, not a list)."""
        image_path = os.path.join(VMS_DIR, DOWNLOADED_FILE_NAME)
        return image_path

    def occupy_vm(self, vm_path):
        """No-op."""

    def get_vm_path(self, os_type, region):
        """Return the disk image path, downloading the image first if it is absent.

        ``os_type`` and ``region`` are not used.
        """
        image_path = os.path.join(VMS_DIR, DOWNLOADED_FILE_NAME)
        image_present = os.path.exists(image_path)
        if not image_present:
            _download_vm(VMS_DIR)
        return image_path
|
||||||
67
desktop_env/providers/docker/provider.py
Normal file
67
desktop_env/providers/docker/provider.py
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import platform
|
||||||
|
import subprocess
|
||||||
|
import time
|
||||||
|
import docker
|
||||||
|
import psutil
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from desktop_env.providers.base import Provider
|
||||||
|
|
||||||
|
# Logger for the Docker provider module (name previously copy-pasted from the VMware provider).
logger = logging.getLogger("desktopenv.providers.docker.DockerProvider")
|
||||||
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
WAIT_TIME = 3
|
||||||
|
RETRY_INTERVAL = 1
|
||||||
|
|
||||||
|
class DockerProvider(Provider):
    """Provider that runs the VM disk image inside a Docker container.

    Three container ports (8006, 5000, 9222) are published on host ports
    chosen at construction time; they are stored in ``vnc_port``,
    ``server_port`` and ``chromium_port`` respectively.
    """

    def __init__(self, region: str):
        """Create the Docker client and pick the host ports to publish.

        Args:
            region: Accepted but not used by this provider.
        """
        self.client = docker.from_env()
        # Set by start_emulator(); None while no container is running.
        self.container = None
        self.vnc_port = self._get_available_port(8006)
        self.server_port = self._get_available_port(5000)
        # self.remote_debugging_port = self._get_available_port(1337)
        self.chromium_port = self._get_available_port(9222)
        self.environment = {"DISK_SIZE": "32G", "RAM_SIZE": "4G", "CPU_CORES": "4"}  # Modify if needed

    @staticmethod
    def _get_available_port(port: int):
        """Return the first port >= ``port`` with no local connection bound to it.

        Args:
            port: First candidate port number.

        Returns:
            The lowest port in ``[port, 65535]`` that does not appear as a
            local port in ``psutil.net_connections()``.

        Raises:
            RuntimeError: If every port in the range is in use.
        """
        # Query the connection table once instead of once per candidate port.
        used_ports = {conn.laddr.port for conn in psutil.net_connections() if conn.laddr}
        while port < 65536:
            if port not in used_ports:
                return port
            port += 1
        raise RuntimeError("No available port found.")

    @staticmethod
    def _server_is_ready(ip, port):
        """Return True if ``GET http://ip:port/screenshot`` answers with HTTP 200."""
        url = f"http://{ip}:{port}/screenshot"
        try:
            # 10 s connect timeout, 10 s read timeout.
            response = requests.get(url, timeout=(10, 10))
        except requests.exceptions.RequestException:
            return False
        return response.status_code == 200

    def start_emulator(self, path_to_vm: str, headless: bool, os_type: str):
        """Start the container and block until the in-VM server answers.

        Args:
            path_to_vm: Path of the disk image; mounted read-only at
                ``/Ubuntu.qcow2`` inside the container.
            headless: Not used by this provider.
            os_type: Not used by this provider.
        """
        logger.info(f"Occupying ports: {self.vnc_port}, {self.server_port}, {self.chromium_port}")
        self.container = self.client.containers.run(
            "happysixd/osworld-docker",
            environment=self.environment,
            cap_add=["NET_ADMIN"],
            devices=["/dev/kvm"],
            volumes={os.path.abspath(path_to_vm): {"bind": "/Ubuntu.qcow2", "mode": "ro"}},
            ports={8006: self.vnc_port, 5000: self.server_port, 9222: self.chromium_port},
            detach=True,
        )

        # Poll the screenshot endpoint until it succeeds; pause between
        # attempts so a fast-failing probe does not spin.
        while not self._server_is_ready("localhost", self.server_port):
            logger.info("Check whether the virtual machine is ready...")
            time.sleep(RETRY_INTERVAL)

    def get_ip_address(self, path_to_vm: str) -> str:
        """Return ``"localhost:<server_port>:<chromium_port>:<vnc_port>"``.

        The host and the three published ports are packed into one
        colon-separated string.
        """
        return f"localhost:{self.server_port}:{self.chromium_port}:{self.vnc_port}"

    def save_state(self, path_to_vm: str, snapshot_name: str):
        """Not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError("Not available for Docker.")

    def revert_to_snapshot(self, path_to_vm: str, snapshot_name: str):
        """No-op; returns ``None``."""

    def stop_emulator(self, path_to_vm: str):
        """Stop and remove the running container, if any.

        Args:
            path_to_vm: Not used by this provider.
        """
        if self.container is None:
            # start_emulator() was never called (or the container was already
            # stopped); there is nothing to tear down.
            return
        logger.info("Stopping VM...")
        self.container.stop()
        self.container.remove()
        self.container = None
        time.sleep(WAIT_TIME)
|
||||||
|
|
||||||
|
# docker run -it --rm -e "DISK_SIZE=64G" -e "RAM_SIZE=8G" -e "CPU_CORES=8" --volume /home/$USER/osworld/docker_vm_data/Ubuntu.qcow2:/Ubuntu.qcow2:ro --cap-add NET_ADMIN --device /dev/kvm -p 8008:8006 -p 5002:5000 happysixd/osworld-docker
|
||||||
2
main.py
2
main.py
@@ -78,7 +78,7 @@ def human_agent():
|
|||||||
result = env.evaluate()
|
result = env.evaluate()
|
||||||
logger.info("Result: %.2f", result)
|
logger.info("Result: %.2f", result)
|
||||||
|
|
||||||
# env.close()
|
env.close()
|
||||||
logger.info("Environment closed.")
|
logger.info("Environment closed.")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -53,7 +53,3 @@ wrapt_timeout_decorator
|
|||||||
gdown
|
gdown
|
||||||
tiktoken
|
tiktoken
|
||||||
groq
|
groq
|
||||||
boto3
|
|
||||||
azure-identity
|
|
||||||
azure-mgmt-compute
|
|
||||||
azure-mgmt-network
|
|
||||||
|
|||||||
Reference in New Issue
Block a user