- Add screen_size parameter to get_vm_path() for all providers (with default 1920x1080) - Add os_type parameter to start_emulator() for Azure and VirtualBox providers - Add region parameter to stop_emulator() for VMware, Docker, and VirtualBox providers - Use *args, **kwargs for better extensibility and parameter consistency - Add documentation comments explaining ignored parameters for interface consistency - Prevents TypeError exceptions when AWS-specific parameters are passed to other providers This ensures all providers can handle the same parameter sets while maintaining backward compatibility and avoiding interface fragmentation.
126 lines
4.1 KiB
Python
126 lines
4.1 KiB
Python
import os
|
|
import platform
|
|
import zipfile
|
|
|
|
from time import sleep
|
|
import requests
|
|
from tqdm import tqdm
|
|
|
|
import logging
|
|
|
|
from desktop_env.providers.base import VMManager
|
|
|
|
logger = logging.getLogger("desktopenv.providers.docker.DockerVMManager")
|
|
logger.setLevel(logging.INFO)
|
|
|
|
MAX_RETRY_TIMES = 10
|
|
RETRY_INTERVAL = 5
|
|
|
|
UBUNTU_X86_URL = "https://huggingface.co/datasets/xlangai/ubuntu_osworld/resolve/main/Ubuntu.qcow2.zip"
|
|
WINDOWS_X86_URL = "https://huggingface.co/datasets/xlangai/windows_osworld/resolve/main/Windows-10-x64.qcow2.zip"
|
|
VMS_DIR = "./docker_vm_data"
|
|
|
|
URL = UBUNTU_X86_URL
|
|
DOWNLOADED_FILE_NAME = URL.split('/')[-1]
|
|
|
|
if platform.system() == 'Windows':
|
|
docker_path = r"C:\Program Files\Docker\Docker"
|
|
os.environ["PATH"] += os.pathsep + docker_path
|
|
|
|
|
|
def _download_vm(vms_dir: str):
|
|
global URL, DOWNLOADED_FILE_NAME
|
|
# Download the virtual machine image
|
|
logger.info("Downloading the virtual machine image...")
|
|
downloaded_size = 0
|
|
|
|
downloaded_file_name = DOWNLOADED_FILE_NAME
|
|
|
|
os.makedirs(vms_dir, exist_ok=True)
|
|
|
|
while True:
|
|
downloaded_file_path = os.path.join(vms_dir, downloaded_file_name)
|
|
headers = {}
|
|
if os.path.exists(downloaded_file_path):
|
|
downloaded_size = os.path.getsize(downloaded_file_path)
|
|
headers["Range"] = f"bytes={downloaded_size}-"
|
|
|
|
with requests.get(URL, headers=headers, stream=True) as response:
|
|
if response.status_code == 416:
|
|
# This means the range was not satisfiable, possibly the file was fully downloaded
|
|
logger.info("Fully downloaded or the file size changed.")
|
|
break
|
|
|
|
response.raise_for_status()
|
|
total_size = int(response.headers.get('content-length', 0))
|
|
|
|
with open(downloaded_file_path, "ab") as file, tqdm(
|
|
desc="Progress",
|
|
total=total_size,
|
|
unit='iB',
|
|
unit_scale=True,
|
|
unit_divisor=1024,
|
|
initial=downloaded_size,
|
|
ascii=True
|
|
) as progress_bar:
|
|
try:
|
|
for data in response.iter_content(chunk_size=1024):
|
|
size = file.write(data)
|
|
progress_bar.update(size)
|
|
except (requests.exceptions.RequestException, IOError) as e:
|
|
logger.error(f"Download error: {e}")
|
|
sleep(RETRY_INTERVAL)
|
|
logger.error("Retrying...")
|
|
else:
|
|
logger.info("Download succeeds.")
|
|
break # Download completed successfully
|
|
|
|
if downloaded_file_name.endswith(".zip"):
|
|
# Unzip the downloaded file
|
|
logger.info("Unzipping the downloaded file...☕️")
|
|
with zipfile.ZipFile(downloaded_file_path, 'r') as zip_ref:
|
|
zip_ref.extractall(vms_dir)
|
|
logger.info("Files have been successfully extracted to the directory: " + str(vms_dir))
|
|
|
|
|
|
class DockerVMManager(VMManager):
|
|
def __init__(self, registry_path=""):
|
|
pass
|
|
|
|
def add_vm(self, vm_path):
|
|
pass
|
|
|
|
def check_and_clean(self):
|
|
pass
|
|
|
|
def delete_vm(self, vm_path):
|
|
pass
|
|
|
|
def initialize_registry(self):
|
|
pass
|
|
|
|
def list_free_vms(self):
|
|
return os.path.join(VMS_DIR, DOWNLOADED_FILE_NAME)
|
|
|
|
def occupy_vm(self, vm_path):
|
|
pass
|
|
|
|
def get_vm_path(self, os_type, region, screen_size=(1920, 1080), **kwargs):
|
|
# Note: screen_size parameter is ignored for Docker provider
|
|
# but kept for interface consistency with other providers
|
|
global URL, DOWNLOADED_FILE_NAME
|
|
if os_type == "Ubuntu":
|
|
URL = UBUNTU_X86_URL
|
|
elif os_type == "Windows":
|
|
URL = WINDOWS_X86_URL
|
|
DOWNLOADED_FILE_NAME = URL.split('/')[-1]
|
|
|
|
if DOWNLOADED_FILE_NAME.endswith(".zip"):
|
|
vm_name = DOWNLOADED_FILE_NAME[:-4]
|
|
else:
|
|
vm_name = DOWNLOADED_FILE_NAME
|
|
|
|
if not os.path.exists(os.path.join(VMS_DIR, vm_name)):
|
|
_download_vm(VMS_DIR)
|
|
return os.path.join(VMS_DIR, vm_name)
|