fix(docker): add file lock for port allocation to prevent race conditions

This commit is contained in:
Timothyxxx
2024-11-02 14:12:57 +08:00
parent 324371e78b
commit 3933e0d303
2 changed files with 45 additions and 45 deletions

View File

@@ -1,26 +1,16 @@
import os
import platform
import random
import re
import threading
from filelock import FileLock
import uuid
import zipfile
from time import sleep
import shutil
import psutil
import subprocess
import requests
from tqdm import tqdm
import docker
import logging
from desktop_env.providers.base import VMManager
# Module-level logger for the Docker VM manager. The earlier assignment to the
# "vmware.VMwareVMManager" logger was overwritten on the very next line, so
# only the Docker-specific logger is kept.
logger = logging.getLogger("desktopenv.providers.docker.DockerVMManager")
logger.setLevel(logging.INFO)

# Retry budget for this module (its use site is not visible in this excerpt).
MAX_RETRY_TIMES = 10
@@ -30,17 +20,6 @@ UBUNTU_X86_URL = "https://huggingface.co/datasets/xlangai/ubuntu_osworld/resolve
WINDOWS_X86_URL = "https://huggingface.co/datasets/xlangai/windows_osworld/resolve/main/Windows-10-x64.qcow2.zip"
VMS_DIR = "./docker_vm_data"
# Determine the platform and CPU architecture to decide the correct VM image to download
# if platform.system() == 'Darwin': # macOS
# # if os.uname().machine == 'arm64': # Apple Silicon
# URL = UBUNTU_ARM_URL
# # else:
# # url = UBUNTU_X86_URL
# elif platform.machine().lower() in ['amd64', 'x86_64']:
# URL = UBUNTU_X86_URL
# else:
# raise Exception("Unsupported platform or architecture")
URL = UBUNTU_X86_URL
DOWNLOADED_FILE_NAME = URL.split('/')[-1]
@@ -48,6 +27,7 @@ if platform.system() == 'Windows':
docker_path = r"C:\Program Files\Docker\Docker"
os.environ["PATH"] += os.pathsep + docker_path
def _download_vm(vms_dir: str):
global URL, DOWNLOADED_FILE_NAME
# Download the virtual machine image
@@ -102,6 +82,7 @@ def _download_vm(vms_dir: str):
zip_ref.extractall(vms_dir)
logger.info("Files have been successfully extracted to the directory: " + str(vms_dir))
class DockerVMManager(VMManager):
def __init__(self, registry_path=""):
pass
@@ -139,4 +120,4 @@ class DockerVMManager(VMManager):
if not os.path.exists(os.path.join(VMS_DIR, vm_name)):
_download_vm(VMS_DIR)
return os.path.join(VMS_DIR, vm_name)
return os.path.join(VMS_DIR, vm_name)

View File

@@ -1,42 +1,61 @@
import logging
import os
import platform
import subprocess
import time
import docker
import psutil
import requests
from filelock import FileLock
from pathlib import Path
from desktop_env.providers.base import Provider
# Module-level logger for the Docker provider. The earlier assignment to the
# "vmware.VMwareProvider" logger was overwritten on the very next line, so
# only the Docker-specific logger is kept.
logger = logging.getLogger("desktopenv.providers.docker.DockerProvider")
logger.setLevel(logging.INFO)

WAIT_TIME = 3  # seconds slept after the container has been stopped and removed
RETRY_INTERVAL = 1  # seconds slept after a failed screenshot (readiness) probe
LOCK_TIMEOUT = 10  # timeout passed to the port-allocation FileLock
class PortAllocationError(Exception):
    """Exception type for failures while allocating a host port."""
class DockerProvider(Provider):
def __init__(self, region: str):
    """Create the Docker client and initialise port and lock-file state.

    Args:
        region: Accepted for interface compatibility; not read here.
    """
    # Local import: keeps this fix independent of the module's import block.
    import tempfile

    self.client = docker.from_env()
    self.server_port = None
    self.vnc_port = None
    self.chromium_port = None
    self.environment = {"DISK_SIZE": "32G", "RAM_SIZE": "4G", "CPU_CORES": "4"}  # Modify if needed

    # '/tmp' is a path, not an environment-variable name, so it must not be
    # passed through os.getenv(): that returned None on non-Windows hosts and
    # Path(None) raised TypeError. Only the Windows branch reads the environment.
    if platform.system() == 'Windows':
        temp_dir = Path(os.getenv('TEMP') or tempfile.gettempdir())
    else:
        temp_dir = Path('/tmp')

    # Lock file used by _get_available_port to serialise port scans.
    self.lock_file = temp_dir / "docker_port_allocation.lck"
    self.lock_file.parent.mkdir(parents=True, exist_ok=True)
def _get_available_port(self, port: int, lock_file: Path = None):
    """Return the first port at or above *port* that no local socket is using.

    The scan runs while holding a file lock so that concurrent scans in
    other processes do not interleave.

    Args:
        port: First port number to try.
        lock_file: Lock file to use; defaults to ``self.lock_file``.

    Returns:
        The lowest port ``>= port`` not present among the local ports
        reported by ``psutil.net_connections()``.

    Raises:
        PortAllocationError: If every port from *port* through 65535 is in
            use (previously the method fell through and returned ``None``).
        filelock.Timeout: If the lock is not acquired within ``LOCK_TIMEOUT``.
    """
    if lock_file is None:
        lock_file = self.lock_file

    with FileLock(str(lock_file), timeout=LOCK_TIMEOUT):
        # Take one snapshot of the bound ports instead of re-querying psutil
        # for every candidate.
        used_ports = {conn.laddr.port for conn in psutil.net_connections()}
        # 65535 is the highest valid port number; the former bound of 65354
        # stopped the search early.
        for candidate in range(port, 65536):
            if candidate not in used_ports:
                # NOTE: the lock only covers the scan. The returned port is
                # not reserved, so it can still be taken before the caller
                # actually binds it.
                return candidate

    raise PortAllocationError(f"No available port found in range {port}-65535")
def start_emulator(self, path_to_vm: str, headless: bool, os_type: str):
self.vnc_port = self._get_available_port(8006)
self.server_port = self._get_available_port(5000)
# self.remote_debugging_port = self._get_available_port(1337)
self.chromium_port = self._get_available_port(9222)
logger.info(f"Occupying ports: {self.vnc_port}, {self.server_port}, {self.chromium_port}")
self.container = self.client.containers.run("happysixd/osworld-docker", environment=self.environment, cap_add=["NET_ADMIN"], devices=["/dev/kvm"], volumes={os.path.abspath(path_to_vm): {"bind": "/System.qcow2", "mode": "ro"}}, ports={8006: self.vnc_port, 5000: self.server_port, 9222: self.chromium_port}, detach=True)
self.container = self.client.containers.run("happysixd/osworld-docker", environment=self.environment,
cap_add=["NET_ADMIN"], devices=["/dev/kvm"], volumes={
os.path.abspath(path_to_vm): {"bind": "/System.qcow2", "mode": "ro"}},
ports={8006: self.vnc_port, 5000: self.server_port,
9222: self.chromium_port}, detach=True)
def download_screenshot(ip, port):
url = f"http://{ip}:{port}/screenshot"
try:
@@ -47,22 +66,22 @@ class DockerProvider(Provider):
except Exception as e:
time.sleep(RETRY_INTERVAL)
return False
# Try downloading the screenshot until successful
while not download_screenshot("localhost", self.server_port):
logger.info("Check whether the virtual machine is ready...")
def get_ip_address(self, path_to_vm: str) -> str:
    """Return the endpoints as ``localhost:<server>:<chromium>:<vnc>``.

    Args:
        path_to_vm: Unused; the address is built from the stored ports only.
    """
    return f"localhost:{self.server_port}:{self.chromium_port}:{self.vnc_port}"
def save_state(self, path_to_vm: str, snapshot_name: str):
    """Unsupported for Docker; always raises.

    Raises:
        NotImplementedError: Unconditionally.
    """
    raise NotImplementedError("Not available for Docker.")
def revert_to_snapshot(self, path_to_vm: str, snapshot_name: str):
    """Do nothing and return ``None``; both arguments are ignored."""
    pass
def stop_emulator(self, path_to_vm: str):
    """Stop and remove the running container, then pause briefly.

    Args:
        path_to_vm: Unused; the container handle stored on the instance is used.
    """
    logger.info("Stopping VM...")
    self.container.stop()
    self.container.remove()
    # Single pause after teardown. The duplicated sleep (old and new diff
    # lines both present) doubled the wait for no benefit.
    time.sleep(WAIT_TIME)